
TF version: 2.11

I am trying to train a simple two-input classifier with a TFRecord-based tf.data pipeline.

I cannot manage to convert the dense tensor returned by tf.sparse.to_dense, which holds only a single scalar value, into a one-hot vector with tf.one_hot.

    # get all recorddatasets abspath
    training_names= [record_path+'/'+rec for rec in os.listdir(record_path) if rec.startswith('train')]
    
    # load in tf dataset
    train_dataset = tf.data.TFRecordDataset(training_names[1])
    train_dataset = train_dataset.map(return_xy)
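parse_function is not shown in the question. Judging from the printed SparseTensor further down, the label is parsed with tf.io.VarLenFeature, which always yields a sparse tensor of unknown length. A minimal sketch under that assumption; the image feature specs, the make_example helper and the two in-memory records are hypothetical, and only the feature keys come from the question:

```python
import tensorflow as tf

LEVEL = "label_level3"  # key suffix taken from the eager snippet in the question

# Hypothetical feature spec: two encoded images plus a variable-length int64 label list.
FEATURE_SPEC = {
    "image/encoded_1": tf.io.FixedLenFeature([], tf.string),
    "image/encoded_2": tf.io.FixedLenFeature([], tf.string),
    "image/object/class/" + LEVEL: tf.io.VarLenFeature(tf.int64),
}


def parse_function(example_proto):
    # VarLenFeature produces a SparseTensor whose length is unknown at trace time.
    return tf.io.parse_single_example(example_proto, FEATURE_SPEC)


def make_example(label):
    """Hypothetical helper: serialize one record with dummy image bytes."""
    feature = {
        "image/encoded_1": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"img1"])),
        "image/encoded_2": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"img2"])),
        "image/object/class/" + LEVEL: tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])
        ),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()


# In-memory stand-in for tf.data.TFRecordDataset(...)
ds = tf.data.Dataset.from_tensor_slices([make_example(7), make_example(3)]).map(parse_function)
print(ds.element_spec["image/object/class/" + LEVEL])
```

Because VarLenFeature cannot know how many labels a record holds, everything derived from it inside map starts out with shape (None,).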

Mapping function:

def return_xy(example_proto):

    #parse example
    sample= parse_function(example_proto)

    #decode image 1
    encoded_image1 = sample['image/encoded_1']
    decoded_image1 = decode_image(encoded_image1)

    #decode image 2
    encoded_image2 = sample['image/encoded_2']
    decoded_image2 = decode_image(encoded_image2)

    #decode label 
    print(f"image/object/class/{level}: {sample['image/object/class/' + level]}")

    class_label = tf.sparse.to_dense(sample['image/object/class/'+level])
    print(f'type of class label :{type(class_label)}')
    print(class_label)

    # Convert to one-hot with depth 26. How can I extract only the scalar value, or convert directly with tf.one_hot?
    label_onehot=tf.one_hot(class_label,26)


    # Resize both images to 416x416
    input_left = tf.image.resize(decoded_image1, [416, 416])
    input_right = tf.image.resize(decoded_image2, [416, 416])
    return {'input_3res1': input_left, 'input_5res2': input_right}, label_onehot

Output:

image/object/class/'+level: SparseTensor(indices=Tensor("ParseSingleExample/ParseExample/ParseExampleV2:14", shape=(None, 1), dtype=int64), values=Tensor("ParseSingleExample/ParseExample/ParseExampleV2:31", shape=(None,), dtype=int64), dense_shape=Tensor("ParseSingleExample/ParseExample/ParseExampleV2:48", shape=(1,), dtype=int64))

type of class label :<class 'tensorflow.python.framework.ops.Tensor'>
Tensor("SparseToDense:0", shape=(None,), dtype=int64)
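The printed shape (None,) is what matters here: tf.one_hot appends a depth axis to whatever it receives, so a label vector of unknown length becomes (None, 26) rather than (26,). A minimal sketch with a toy dataset whose elements have that same (None,) shape (the generator and its values are made up for illustration) shows the difference between one-hot encoding the vector and one-hot encoding its first element. Plain indexing such as y[0] works on graph tensors too; only .numpy() is restricted to eager tensors.

```python
import tensorflow as tf

# Toy stand-in for the parsed labels: each element is an int64 vector of unknown length.
labels = tf.data.Dataset.from_generator(
    lambda: iter([[7], [3]]),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int64),
)

# One-hot of the whole vector: the depth axis is appended, giving shape (None, 26).
as_vector = labels.map(lambda y: tf.one_hot(y, 26))

# One-hot of the first element: y[0] is a scalar graph tensor, giving shape (26,).
as_scalar = labels.map(lambda y: tf.one_hot(y[0], 26))

print(as_vector.element_spec.shape)
print(as_scalar.element_spec.shape)
```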

However, I am sure the label is in this tensor, because when I run the parsing eagerly:

raw_dataset = tf.data.TFRecordDataset([rec_file])
parsed_dataset = raw_dataset.map(parse_function) # only parsing

for sample in parsed_dataset:
    class_label=tf.sparse.to_dense(sample['image/object/class/label_level3'])[0]
    print(f'type of class label :{type(class_label)}')
    print(f'labels  from labelmap :{class_label}')

I get this output:

type of class label :<class 'tensorflow.python.framework.ops.EagerTensor'>
labels  from labelmap :7

If I just pick a random number as the label and pass it to tf.one_hot(randint, 26), the model starts training (with meaningless results, obviously).

So the question is: how can I convert

Tensor("SparseToDense:0", shape=(None,), dtype=int64)

to a

Tensor("one_hot:0", shape=(26,), dtype=float32)

What I tried so far

Inside the function passed to train_dataset.map(return_xy), I tried calling .numpy() on the tensors, but that did not work: .numpy() only works on eager tensors.

As I understand it, I cannot use eager execution there, because everything inside return_xy is executed as a graph. I have already tried enabling eager execution, without success; the documentation says:

https://www.tensorflow.org/api_docs/python/tf/config/run_functions_eagerly
Note: This flag has no effect on functions passed into tf.data transformations as arguments. tf.data functions are never executed eagerly and are always executed as a compiled Tensorflow Graph.
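That note explains the behaviour. The function passed to map is traced into a graph, so Python side effects such as print run only during tracing and see symbolic tensors, while graph ops such as tf.print run once per element. A minimal sketch; the log_trace function and the toy dataset are hypothetical:

```python
import tensorflow as tf

trace_log = []  # filled by plain Python code, which only runs while map traces the function


def log_trace(x):
    # Python side effect: executes at trace time, when x is a symbolic graph tensor.
    trace_log.append((type(x).__name__, tf.executing_eagerly()))
    # Graph op: executes for every element when the dataset is iterated.
    tf.print("label value:", x)
    return x


ds = tf.data.Dataset.from_tensor_slices(tf.constant([7, 3], dtype=tf.int64)).map(log_trace)
values = [int(v) for v in ds]

print(trace_log)  # one entry per trace, not per element; executing_eagerly() is False there
print(values)
```

For debugging only, tf.data.experimental.enable_debug_mode() is documented to wrap such functions in tf.py_function so their bodies run eagerly; the label conversion itself does not need it.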

I have also tried tf.py_function, but it only returns another tf.Tensor with an unknown shape:

def get_onehot(tensor):
    class_label=tensor[0]
    return tf.one_hot(class_label,26)

and added this line to return_xy:

    label_onehot=tf.py_function(func=get_onehot, inp=[class_label], Tout=tf.float32)

but the result always has an unknown shape, which I cannot simply fix with .set_shape().
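For comparison, a sketch of the tf.py_function route with the shape declared explicitly. It assumes two things: Tout matches the float32 dtype that tf.one_hot returns, and set_shape is called on the py_function result inside the mapped function, where its shape is still fully unknown, so declaring [26] is compatible. The to_onehot wrapper and the toy dataset are hypothetical; get_onehot follows the question.

```python
import tensorflow as tf

DEPTH = 26  # number of classes, as in the question


def get_onehot(tensor):
    # Runs eagerly inside tf.py_function, so tensor is an EagerTensor here.
    return tf.one_hot(tensor[0], DEPTH)


def to_onehot(class_label):
    # Tout must match what get_onehot returns: tf.one_hot yields float32.
    label_onehot = tf.py_function(func=get_onehot, inp=[class_label], Tout=tf.float32)
    # py_function cannot infer a static shape, so declare it explicitly.
    label_onehot.set_shape([DEPTH])
    return label_onehot


labels = tf.data.Dataset.from_generator(
    lambda: iter([[7], [3]]),
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int64),
)
ds = labels.map(to_onehot)
print(ds.element_spec)
```

The pure TensorFlow approach is usually preferable, since tf.py_function calls back into the Python interpreter for every element.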

1 Answer


I was able to solve the issue using only TensorFlow ops.

tf.gather lets you index into a TensorFlow tensor inside the graph:

class_label_gather = tf.sparse.to_dense(sample['image/object/class/'+level])
class_indices = tf.gather(tf.cast(class_label_gather,dtype=tf.int32),0)
label_onehot=tf.one_hot(class_indices,26)
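An end-to-end sketch of this fix on in-memory records, so it can be run without the TFRecord files. The feature spec contains only the label; the serialize helper, return_y and the label values 7 and 3 are hypothetical, while the label key and depth 26 come from the question. tf.gather(x, 0) is equivalent to x[0] here, and the cast to int32 is optional because tf.one_hot also accepts int64 indices. It assumes every record has at least one label; an empty label list would make the gather fail.

```python
import tensorflow as tf

LABEL_KEY = "image/object/class/label_level3"  # key from the question
NUM_CLASSES = 26

# Hypothetical spec: only the label feature; the images are left out of this sketch.
FEATURE_SPEC = {LABEL_KEY: tf.io.VarLenFeature(tf.int64)}


def serialize(label):
    """Hypothetical helper that builds one serialized tf.train.Example."""
    feature = {LABEL_KEY: tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()


def return_y(example_proto):
    sample = tf.io.parse_single_example(example_proto, FEATURE_SPEC)
    # Dense label vector of unknown length: shape (None,)
    class_label_gather = tf.sparse.to_dense(sample[LABEL_KEY])
    # Take element 0 inside the graph: a scalar, so no .numpy() is needed
    class_indices = tf.gather(tf.cast(class_label_gather, dtype=tf.int32), 0)
    # One-hot of a scalar has the static shape (26,)
    return tf.one_hot(class_indices, NUM_CLASSES)


ds = tf.data.Dataset.from_tensor_slices([serialize(7), serialize(3)]).map(return_y)
print(ds.element_spec)
print([int(tf.argmax(y)) for y in ds])  # [7, 3]
```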
Anton Menshov