[Edit #1 after @mrry comment] I am using the (great & amazing) Dataset API along with tf.contrib.data.rejection_resample to impose a specific target distribution on the input training pipeline.
Before adding tf.contrib.data.rejection_resample to the input_fn I used the one-shot iterator. Alas, once I started using rejection_resample I switched to dataset.make_initializable_iterator() - this is because the resampling introduces stateful variables into the pipeline, and the iterator must be initialized AFTER all variables in the input pipeline have been initialized, as @mrry wrote here.
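To illustrate the ordering constraint I mean, here is a standalone snippet outside my Estimator setup (TF 1.x-style tf.data, not my actual pipeline):

import tensorflow as tf

dataset = tf.data.Dataset.range(10)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # Any stateful parts of the pipeline must be initialized first...
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    # ...and only then the iterator itself.
    sess.run(iterator.initializer)
    print(sess.run(next_element))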
I am passing the input_fn to an Estimator, wrapped by an Experiment.
The problem is - where to hook the initialization of the iterator? If I try:
dataset = dataset.batch(batch_size)
if self.balance:
    dataset = tf.contrib.data.rejection_resample(dataset, self.class_mapping_function, self.dist_target)
    iterator = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
    iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()
print(image_batch)
and the mapping function:
def class_mapping_function(self, feature, label):
    """
    Used with dataset.map() to return a class numeric ID per element.
    The function maps a nested structure of tensors (having shapes and types defined by dataset.output_shapes
    and dataset.output_types) to a scalar tf.int32 tensor. Values should be in [0, num_classes).
    """
    # For simplicity, return the label itself, as I assume it is numeric...
    return tf.cast(label, tf.int32)  # <-- I guess this is the bug
the iterator does not receive the Tensor shape as it does with the one-shot iterator.
For example, with the one-shot iterator the tensor has the correct shape:
Tensor("train_input_fn/IteratorGetNext:0", shape=(?, 100, 100, 3), dtype=float32, device=/device:CPU:0)
But when using the initializable iterator, it is missing tensor shape info:
Tensor("train_input_fn/IteratorGetNext:0", shape=(?,), dtype=int32, device=/device:CPU:0)
Any help will be so appreciated!
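For context, one approach I am considering for the "where to hook" part is running the initializer from a SessionRunHook passed to the Estimator (just a sketch on my side, assuming the TF 1.x tf.train.SessionRunHook API; it does not address the shape issue above):

import tensorflow as tf

class IteratorInitializerHook(tf.train.SessionRunHook):
    """Runs the dataset iterator initializer once the session is created."""

    def __init__(self):
        self.initializer_func = None  # set inside input_fn

    def after_create_session(self, session, coord):
        # At this point all variables (and tables) have been initialized,
        # so it is safe to initialize the iterator.
        self.initializer_func(session)

# Inside input_fn (sketch): hook.initializer_func = lambda sess: sess.run(iterator.initializer)
# The hook is then passed to the Estimator / Experiment as a training hook.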
[Edit #2 - following @mrry's comment that it seems like another dataset] Perhaps the real issue here is not the init sequence of the iterator but the mapping function used by tf.contrib.data.rejection_resample, which returns tf.int32. But then I wonder how the mapping function should be defined, so that the dataset keeps a shape like (?, 100, 100, 3)...
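One way to inspect what the resampled dataset actually produces would be to print its output types and shapes (Dataset.output_types / Dataset.output_shapes in TF 1.x; a debugging sketch only):

balanced = tf.contrib.data.rejection_resample(dataset, self.class_mapping_function, self.dist_target)
# What structure does the resampled dataset yield?
print(balanced.output_types)
print(balanced.output_shapes)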
[Edit #3]: From the implementation of rejection_resample:
class_values_ds = dataset.map(class_func)
So it makes sense that class_func is mapped over the dataset's elements and yields a dataset of tf.int32 class values.
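If that is the case, the resampled dataset presumably carries those class values alongside (or instead of) the original (image, label) elements, which would explain the (?,) int32 tensor above. A sketch of what I would try next, under the assumption (which I have not verified against this TF version) that the resampled dataset yields (class_value, example) pairs:

balanced = tf.contrib.data.rejection_resample(
    dataset, self.class_mapping_function, self.dist_target)
# Assumption: each element is (class_value, (image, label));
# drop the class value to restore the original structure.
dataset = balanced.map(lambda class_value, example: example)
iterator = dataset.make_initializable_iterator()
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
image_batch, label_batch = iterator.get_next()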