Suppose I am given an already built TensorFlow dataset object (tf.data.Dataset) named data.

Is there a way to tell, by inspecting data, whether repeat/batch/shuffle was called on this object? (And, if possible, to recover other information such as the arguments passed to batch and repeat.)

(I assume eager execution)

edit 1: it seems like the str method carries some information. Looking into that.

edit 2: the attribute output_shapes gives information on the batch size and shapes.

jeandut

1 Answer

The only solution I can think of is to go into the TensorFlow code itself. gen_dataset_ops.py is generated when TensorFlow is built from source, so it can only be found in a local installation.

The other relevant file is dataset_ops.py, which is available at the link below. You can insert a print statement just before the return of the relevant function. For example, the shuffle function from dataset_ops.py:

def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
    """Randomly shuffles the elements of this dataset.
    ...
    """
    print('Dataset shuffled')  # inserted print here
    return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
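
If editing the installed library is not an option, the same idea can be sketched by patching the methods at runtime instead. This is only an illustration: record_calls and FakeDataset are hypothetical names, FakeDataset stands in for tf.data.Dataset, and whether patching the real class catches every call depends on the TensorFlow version, which I have not checked. Like the print approach, it only sees calls made after the patch, so it does not help with a dataset that was already built.

```python
import functools


def record_calls(cls, method_names, log):
    """Patch the named methods of cls so that each call is appended to log."""
    for name in method_names:
        original = getattr(cls, name)

        def make_wrapper(method_name, method):
            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                # Record the call, then delegate to the original method.
                log.append((method_name, args, kwargs))
                return method(self, *args, **kwargs)
            return wrapper

        setattr(cls, name, make_wrapper(name, original))


class FakeDataset:
    """Hypothetical stand-in for tf.data.Dataset, used only for this sketch."""

    def shuffle(self, buffer_size, seed=None):
        return self

    def batch(self, batch_size):
        return self

    def repeat(self, count=None):
        return self


calls = []
record_calls(FakeDataset, ["shuffle", "batch", "repeat"], calls)
FakeDataset().shuffle(100).batch(32).repeat()
print(calls)
```

The log keeps the arguments as well as the method names, so the buffer size and batch size are available afterwards.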

The Dataset object is wrapped in a DatasetV1Adapter, so you can't know anything about it in advance. The only difference in eager mode is that it supports explicit iteration, but it would be extremely inefficient to do something like this:

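import numpy as np
import tensorflow as tf

array = np.random.rand(10)
dataset = tf.data.Dataset.from_tensor_slices(array)
# Iterates over the whole dataset; never terminates after an unbounded repeat().
if len([i for i in dataset]) != array.shape[0]:
    print('repeated')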
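
A cheaper check may be possible in newer releases. The sketch below assumes TensorFlow 2.3 or later (newer than the version discussed here), where Dataset.cardinality() and tf.data.INFINITE_CARDINALITY exist; the variable names are only illustrative. It asks for the number of elements without iterating over the data.

```python
import numpy as np
import tensorflow as tf

array = np.random.rand(10)
dataset = tf.data.Dataset.from_tensor_slices(array)
repeated = dataset.repeat()  # unbounded repeat

# cardinality() returns a scalar int64 tensor; it does not iterate the data.
n_plain = int(dataset.cardinality().numpy())
n_repeated = int(repeated.cardinality().numpy())

# An unbounded repeat is reported with a dedicated sentinel value.
is_infinite = n_repeated == tf.data.INFINITE_CARDINALITY
print(n_plain, is_infinite)
```

This only says that the element count differs from the source array, not which transformation caused it: a finite repeat and a batch both change the count, and some pipelines report tf.data.UNKNOWN_CARDINALITY instead of a number.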

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/data/ops/dataset_ops.py

Sharky