
I was looking at different ways to build custom TensorFlow datasets. I am used to PyTorch's datasets, but when I looked at TensorFlow's, I saw this example:

import time

import tensorflow as tf


class ArtificialDataset(tf.data.Dataset):
  def _generator(num_samples):
    # Opening the file
    time.sleep(0.03)

    for sample_idx in range(num_samples):
      # Reading data (line, record) from the file
      time.sleep(0.015)

      yield (sample_idx,)

  def __new__(cls, num_samples=3):
    # Returns the dataset built by from_generator, not an ArtificialDataset instance
    return tf.data.Dataset.from_generator(
        cls._generator,
        output_signature=tf.TensorSpec(shape=(1,), dtype=tf.int64),
        args=(num_samples,)
    )

But two questions came up:

  1. It looks like all this does is call the static method tf.data.Dataset.from_generator from __new__ when the object is instantiated. So why not just call it directly? What is the point of subclassing tf.data.Dataset? Are any methods from tf.data.Dataset even used?
  2. Is there a way to write it like a data generator, where one fills in an __iter__ method while inheriting from tf.data.Dataset? Something like:
class MyDataLoader(tf.data.Dataset):
  def __init__(self, path, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.data = pd.read_csv(path)

  def __iter__(self):
    for datum in self.data.iterrows():
      yield datum

Thank you all very much!

Nicolas Shu

1 Answer


Question 1

That example just encapsulates a generator-backed dataset in a class. It inherits from tf.data.Dataset because from_generator() returns a tf.data.Dataset-based object. However, no inherited tf.data.Dataset methods are used in the example. So the answer to question 1 is yes: from_generator() can be called directly, without the class.
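To see why the subclass adds nothing, here is a minimal pure-Python sketch of the same `__new__` pattern. The names `Base`, `Wrapper`, and the list-returning `from_generator` are hypothetical stand-ins, not TensorFlow code; the point is only that when `__new__` returns an object that is not an instance of the class, Python hands that object back unchanged and skips `__init__`.

```python
class Base:
    @staticmethod
    def from_generator(n):
        # Hypothetical stand-in for a factory such as tf.data.Dataset.from_generator.
        return list(range(n))


class Wrapper(Base):
    init_called = False

    def __new__(cls, n=3):
        # Returns whatever the factory builds, not an instance of Wrapper.
        return cls.from_generator(n)

    def __init__(self, n=3):
        # Skipped by Python because __new__ did not return a Wrapper instance.
        Wrapper.init_called = True


obj = Wrapper(4)
print(type(obj).__name__, isinstance(obj, Wrapper), Wrapper.init_called)
```

So `Wrapper(4)` behaves like a plain function call to the factory, which is why calling the factory directly gives the same object.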

Question 2

Yes. It can be done like that.

Another, similar approach is to use tf.keras.utils.Sequence.
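As a rough sketch of the shape such a class takes: tf.keras.utils.Sequence is built around `__len__` (number of batches) and `__getitem__` (one batch by index). The class below is a plain-Python stand-in so that it runs without TensorFlow; `BatchedSamples` and its parameters are hypothetical names, and in real use the base class would be tf.keras.utils.Sequence and each batch would usually be an (inputs, targets) pair.

```python
import math


class BatchedSamples:
    """Plain-Python stand-in; the real base class would be tf.keras.utils.Sequence."""

    def __init__(self, samples, batch_size):
        self.samples = list(samples)
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.samples) / self.batch_size)

    def __getitem__(self, batch_idx):
        # Return the batch at position batch_idx.
        if not 0 <= batch_idx < len(self):
            raise IndexError(batch_idx)
        start = batch_idx * self.batch_size
        return self.samples[start:start + self.batch_size]


loader = BatchedSamples(range(5), batch_size=2)
batches = [loader[i] for i in range(len(loader))]
print(batches)
```

Indexing by batch rather than yielding from a generator lets the consumer request batches in any order, which is the main design difference from the from_generator() route.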

Eljas Hyyrynen
Up-voted for the suggestion to go the `tf.keras.utils.Sequence` route. Just like in PyTorch, inherit from an existing class, then define `__init__`, `__len__`, and `__getitem__`. – cdahms Jan 02 '23 at 19:38