Also posted the question at https://github.com/tensorflow/transform/issues/261

I am using tft in TFX and need to transform lists of string class labels into multi-hot indicators inside preprocessing_fn. Essentially:

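import tensorflow as tf
import tensorflow_transform as tft

# inputs['label'] is assumed to be a VarLenFeature, i.e. a tf.SparseTensor of strings.
vocab = tft.vocabulary(inputs['label'])
outputs['label'] = tf.cast(
    tf.sparse.to_indicator(
        tft.apply_vocabulary(inputs['label'], vocab),
        vocab_size=VOCAB_SIZE,  # placeholder: this is the value I cannot obtain
    ),
    tf.int64,
)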
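To make the target concrete, here is a plain-Python sketch of the same two steps (no TensorFlow; build_vocabulary and multi_hot are hypothetical helpers, not tft APIs). An analysis pass builds the vocabulary, and only after it has run is the row width, the VOCAB_SIZE above, known. The descending-count ordering loosely mirrors tft.vocabulary's default; the alphabetical tie-break is an arbitrary choice here.

```python
from typing import Dict, List, Sequence


def build_vocabulary(label_lists: Sequence[Sequence[str]]) -> List[str]:
    """Analysis pass: unique labels by descending count, ties broken alphabetically."""
    counts: Dict[str, int] = {}
    for labels in label_lists:
        for label in labels:
            counts[label] = counts.get(label, 0) + 1
    return sorted(counts, key=lambda label: (-counts[label], label))


def multi_hot(label_lists: Sequence[Sequence[str]], vocab: List[str]) -> List[List[int]]:
    """Transform pass: one 0/1 row per example, as wide as the vocabulary."""
    index = {label: i for i, label in enumerate(vocab)}
    vocab_size = len(vocab)  # only known once build_vocabulary has run
    rows = []
    for labels in label_lists:
        row = [0] * vocab_size
        for label in labels:
            if label in index:  # labels missing from the vocabulary are skipped
                row[index[label]] = 1
        rows.append(row)
    return rows


data = [["cat", "dog"], ["dog"], ["bird", "dog"]]
vocab = build_vocabulary(data)
print(vocab)                   # ['dog', 'bird', 'cat']
print(multi_hot(data, vocab))  # [[1, 0, 1], [1, 0, 0], [1, 1, 0]]
```

Because every row has exactly len(vocab) entries, the output has a fixed width; the difficulty in TFT is getting that width as a value the exported graph can treat as known.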

I am trying to get VOCAB_SIZE from the result of tft.vocabulary, but couldn't find a way to reconcile deferred execution with the requirement for known shapes. The closest I got is below, but it fails at saved-model export because the shape of label is unknown.

def _make_table_initializer(filename_tensor):
    # Each whole line of the vocabulary file is a key; its line number is the value.
    return tf.lookup.TextFileInitializer(
        filename=filename_tensor,
        key_dtype=tf.string,
        key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64,
        value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
    )

def _vocab_size(deferred_vocab_filename_tensor):
    initializer = _make_table_initializer(deferred_vocab_filename_tensor)
    table = tf.lookup.StaticHashTable(initializer, default_value=-1)
    # table.size() is only computed when the graph runs, so it has no static value.
    return table.size()

deferred_vocab_and_filename = tft.vocabulary(inputs['label'])
vocab_applied = tft.apply_vocabulary(inputs['label'], deferred_vocab_and_filename)
vocab_size = _vocab_size(deferred_vocab_and_filename)
outputs['label'] = tf.cast(
    # With a runtime vocab_size, the indicator's static shape is (None, None).
    tf.sparse.to_indicator(vocab_applied, vocab_size=vocab_size),
    tf.int64,
)

This fails with:

ValueError: Feature label (Tensor("Identity_3:0", shape=(None, None), dtype=int64)) had invalid shape (None, None) for FixedLenFeature: apart from the batch dimension, all dimensions must have known size [while running 'Analyze/CreateSavedModel[tf_v2_only]/CreateSavedModel']

Any idea how to achieve this?

1 Answer

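As suggested in a comment on the GitHub issue, you can use tft.experimental.get_vocabulary_size_by_name to get the vocabulary size. It looks the vocabulary up by name, so create the vocabulary with an explicit vocab_filename in tft.vocabulary.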

I also found that tft.get_num_buckets_for_transformed_feature works the same way: https://www.tensorflow.org/tfx/transform/api_docs/python/tft/get_num_buckets_for_transformed_feature – ynait Feb 25 '22 at 18:09