I have a tf.data.Dataset
that I read from a tfrecords
file like so:
import tensorflow as tf
# given an existing record_file
raw_dataset = tf.data.TFRecordDataset(record_file)
example_description = {
"height": tf.io.FixedLenFeature([], tf.int64),
"width": tf.io.FixedLenFeature([], tf.int64),
"channels": tf.io.FixedLenFeature([], tf.int64),
"image": tf.io.FixedLenFeature([], tf.string),
}
dataset = raw_dataset.map(
lambda example: tf.io.parse_single_example(example, example_description)
)
Next, I combine the features into a single image like so:
dataset = dataset.map(_extract_image_from_sample)
# and
def _extract_image_from_sample(sample):
height = tf.cast(sample["height"], tf.int32) # always 1038
width = tf.cast(sample["width"], tf.int32) # always 1366
depth = tf.cast(sample["channels"], tf.int32) # always 3
shape = [height, width, depth]
image = sample["image"]
image = decode_tf_image(image)
image = tf.reshape(image, shape)
return image
At this point, any image in the dataset has shape (None, None, None)
(which surprises me, because I reshape them).
I believe this to be the cause of error, when I try to augment the dataset using tf.keras.preprocessing.image.ImageDataGenerator
:
augmented_dataset = dataset.map(random_image_augmentation)
# and
image_data_generator = tf.keras.preprocessing.image.ImageDataGenerator(
rotation_range=45,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=5.0,
zoom_range=[0.9, 1.2],
fill_mode="reflect",
horizontal_flip=True,
vertical_flip=True,
)
def random_image_augmentation(image: tf.Tensor) -> tf.Tensor:
transform = image_data_generator.get_random_transform(img_shape=image.shape)
image = image_data_generator.apply_transform(image, transform)
return image
This results in the error message:
TypeError: in user code:
# ...
C:\Users\[PATH_TO_ENVIRONMENT]\lib\site-packages\keras_preprocessing\image\image_data_generator.py:778 get_random_transform *
tx *= img_shape[img_row_axis]
TypeError: unsupported operand type(s) for *=: 'float' and 'NoneType'
However, if I don't use graph mode, but eager mode, this works like a charm:
it = iter(dataset)
for i in range(3):
image = it.next()
image = random_image_augmentation(image.numpy())
This leads me to the conclusion that the main error is the missing shape information after reading in the dataset. But I don't know how to define it more explicitly than I already do. Any ideas?