28

I'm playing with the Dataset API in TensorFlow v1.3. It's great. It is possible to map a dataset with a function as described here. How can I pass a function that takes an additional argument, for example arg1?

def _parse_function(example_proto, arg1):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

Of course,

dataset = dataset.map(_parse_function)

will not work since there is no way to pass in arg1.

mikkola
AmirHJ
  • Just an idea: maybe we can fake this by passing in an instance of a Python class that stores arg1 as a member and defines a __call__ method. – Yao Zhang Sep 18 '17 at 06:39
  • What kind of argument is `arg1`? If it is a regular Python variable (not a TensorFlow one), you can define `_parse_function` inside another function in which `arg1` is known, and then you don't have to pass it in at all. – CNugteren Nov 21 '17 at 13:00
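The closure idea from the last comment can be sketched in plain Python. This is only an illustration under assumptions: `make_parse_function` and the toy `records` list are hypothetical names, the body is a stand-in for real parsing, and the built-in `map` stands in for `dataset.map`, which likewise calls the function with the dataset element only.

```python
def make_parse_function(arg1):
    """Return a one-argument function that remembers arg1 (a closure)."""
    def _parse_function(example):
        # arg1 is captured from the enclosing scope, so the caller
        # only has to supply the element itself.
        return example * arg1
    return _parse_function

# Built-in map is a stand-in for dataset.map (assumption for illustration).
records = [0, 1, 2, 3, 4]
scaled = list(map(make_parse_function(2), records))
print(scaled)
```

The same factory could be handed to `dataset.map(make_parse_function(arg1))`, since the returned function takes only the element.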

3 Answers

44

Here is an example using a lambda expression to wrap the function to which we want to pass an argument:

import tensorflow as tf
def fun(x, arg):
    return x * arg

my_arg = tf.constant(2, dtype=tf.int64)
ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: fun(x, my_arg))

In the above, the signature of the function provided to map must match the structure of the dataset's elements, so the lambda expression has to be written to match it. Here that is simple: each element has a single component, x, which takes the values 0 through 4.

If necessary, you can pass in an arbitrary number of external arguments from outside the dataset: ds = ds.map(lambda x: my_other_fun(x, arg1, arg2, arg3)), and so on.

To verify that the above works, we can observe that the mapping indeed multiplies each dataset element by two:

iterator = ds.make_initializable_iterator()
next_x = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)

    while True:
        try:
            print(sess.run(next_x))
        except tf.errors.OutOfRangeError:
            break

The output:

0
2
4
6
8
mikkola
9

You can also use functools.partial to bind the extra parameter:

def _parse_function(arg1, example_proto):
  features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),
              "label": tf.FixedLenFeature((), tf.int32, default_value=0)}
  parsed_features = tf.parse_single_example(example_proto, features)
  return parsed_features["image"], parsed_features["label"]

Note that the parameter order has changed: partial binds positional arguments from the left, so arg1 must come first. You can then wrap the function with a value for that parameter as follows:

from functools import partial

arg1 = ...
dataset = dataset.map(partial(_parse_function, arg1))
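If reordering the parameters is undesirable, partial can also bind the extra argument by keyword, which leaves the original (example_proto, arg1) order intact. The sketch below is plain Python under assumptions: the function body is a stand-in for real parsing, the sample strings are made up, and the built-in `map` stands in for `dataset.map`.

```python
from functools import partial

def _parse_function(example_proto, arg1):
    # Stand-in body: a real version would parse example_proto with TensorFlow.
    return example_proto, arg1

# Binding arg1 by keyword leaves example_proto as the only positional parameter.
parse_with_arg = partial(_parse_function, arg1="extra")

# Built-in map is a stand-in for dataset.map (assumption for illustration).
result = list(map(parse_with_arg, ["rec0", "rec1"]))
print(result)
```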
Faylixe
1

Another solution is to use a class wrapper. In the following code, the parameter shape is stored on the instance so that the parse method can use it.

class MyDataSets:

    def __init__(self, shape):
        # Store the extra argument on the instance.
        self.shape = shape

    def parse_sample(self, sample):
        features = { ... }
        f = tf.parse_example([sample], features=features)

        image_raw = tf.decode_raw(f['image_raw'], tf.uint8)
        image = tf.reshape(image_raw, self.shape)

        label = tf.cast(f['label'], tf.int32)

        return image, label

    def init(self):
        ds = tf.data.TFRecordDataset(...)
        # The bound method already carries self, so map only supplies the sample.
        ds = ds.map(self.parse_sample)
        ...
        return ds.make_initializable_iterator()
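A variant of the same idea, along the lines of the `__call__` suggestion in the comments on the question, makes the instance itself callable. The sketch below is plain Python only: `Scaler` and `factor` are hypothetical names, the built-in `map` stands in for `dataset.map`, and whether a given TensorFlow version accepts such an object directly is an assumption not verified here.

```python
class Scaler:
    """Callable object that carries its extra argument as state."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        # Only the element is passed in; factor comes from the instance.
        return x * self.factor

scaler = Scaler(3)
# Built-in map is a stand-in for dataset.map (assumption for illustration).
out = list(map(scaler, range(4)))
print(out)
```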
redice li