Newbie here... I loaded a TF dataset as follows:

dataset = tf.data.TFRecordDataset(files)
dataset = dataset.map(extract_fn)  # map() returns a new dataset, so keep the result

The dataset contains a string column, and I want to one-hot encode its values. I could do that record by record in extract_fn if I had the indices and the depth, but for now I only have the string values. Is there a TF function that could do this for me? That is:

  • Count the number of distinct values
  • Map each value to an index
  • Create a one-hot encoded column for that
Innat
Marsellus Wallace

1 Answer

I think this does what you want:

import tensorflow as tf
def one_hot_any(a):
    # Save original shape
    s = tf.shape(a)
    # Find unique values
    values, idx = tf.unique(tf.reshape(a, [-1]))
    # One-hot encoding
    n = tf.size(values)
    a_1h_flat = tf.one_hot(idx, n)
    # Reshape to original shape
    a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
    return a_1h, values

# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
# TF 1.x graph mode: the tensors are evaluated inside a session
with tf.Session() as sess:
    print(*sess.run([x_1h, x_vals]), sep='\n')

Output:

[[[1. 0. 0. 0.]
  [0. 1. 0. 0.]]

 [[1. 0. 0. 0.]
  [0. 0. 1. 0.]]

 [[0. 0. 0. 1.]
  [0. 0. 1. 0.]]

 [[0. 1. 0. 0.]
  [0. 0. 1. 0.]]]
[b'a' b'b' b'd' b'c']

The problem, though, is that the result depends on the input: tf.unique returns the values in order of first appearance, so a different input can produce a different value order or even a different one-hot depth. I'm not sure how useful it really is.
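One possible way around that inconsistency is to fix the vocabulary once, from a pass over all the values, and reuse it for every input. The following is only a minimal sketch of that idea in plain Python, not a TensorFlow API: `build_vocab`, `one_hot`, and the sample values are hypothetical names made up for illustration, and sorting the distinct values is just one assumed way to get a stable order.

```python
def build_vocab(values):
    """Map each distinct string to a stable index (sorted order, an assumption)."""
    return {v: i for i, v in enumerate(sorted(set(values)))}


def one_hot(value, vocab):
    """Return a one-hot list for value; values not in vocab give all zeros."""
    row = [0.0] * len(vocab)
    idx = vocab.get(value)
    if idx is not None:
        row[idx] = 1.0
    return row


# Hypothetical column values, gathered in one pass over the whole dataset.
all_values = ['a', 'b', 'a', 'd', 'c', 'd', 'b', 'd']
vocab = build_vocab(all_values)

# Batches with different contents and orderings now share one encoding:
# the depth is always len(vocab) and each value always gets the same index.
batch_1 = [one_hot(v, vocab) for v in ['d', 'a']]
batch_2 = [one_hot(v, vocab) for v in ['a', 'c', 'd']]
```

With the indices and depth fixed up front like this, every record is encoded the same way regardless of which batch it lands in. Inside a tf.data pipeline the Python dictionary would have to be replaced by a TensorFlow lookup structure, which this sketch does not cover.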

jdehesa