I have a tf.data.Dataset called train_dataset such that:

a, b = next(iter(train_dataset))
print(a.shape, b.shape)

Output:

(100, 128, 128, 2) (100, 128, 128, 1)
That is, my batch size is 100 and my inputs are 128×128 complex-valued images, stored so that the first channel holds the real part and the second channel holds the imaginary part. The targets are grayscale, i.e., single-channel images.
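(For reproducibility, a stand-in dataset with the same shapes and dtypes can be built from random data; the 500-example size here is arbitrary:)

import numpy as np
import tensorflow as tf

# Stand-in for my real data: random values, same shapes and dtypes.
inputs = np.random.rand(500, 128, 128, 2).astype(np.float32)
targets = np.random.rand(500, 128, 128, 1).astype(np.float32)
train_dataset = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(100)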
For pre-processing the dataset, I need to find the mean image of each batch and subtract it from every image in that batch (note that each of my input images has two channels and is of shape (128, 128, 2)).
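In other words, for a single batch the computation I want is just this (a sketch, where images stands for one input batch of shape (100, 128, 128, 2)):

batch_mean = tf.reduce_mean(images, axis=0)  # mean image, shape (128, 128, 2)
images = images - batch_mean                 # broadcasts over the batch dimension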
I tried tf.data.Dataset.reduce() as follows to find the mean of the input images:
train_mean = train_dataset.reduce(0., tf.math.add) / tf.cast(train_dataset.cardinality(), tf.float32)
But it gives the following error:
InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [100,128,128,2] != values[1].shape = [100,128,128,1]
[[{{node add/y}}]] [Op:ReduceDataset]
Any help in debugging the error, or any other suggestion for computing the dataset mean, will be highly appreciated.
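I suspect the reduce fails because each dataset element is an (input, target) tuple, so tf.math.add receives the whole tuple and tries to combine two tensors of different shapes. If that's right, mapping the dataset down to a single component first should at least run; a sketch of what I mean (assuming every batch has exactly 100 elements, e.g., via drop_remainder=True):

# Sum only the input images across batches, then average.
image_ds = train_dataset.map(lambda image, target: image)
num_batches = tf.cast(train_dataset.cardinality(), tf.float32)
batch_sum = image_ds.reduce(
    tf.zeros([100, 128, 128, 2]),  # initial state matching the element shape
    tf.math.add)
mean_image = tf.reduce_mean(batch_sum / num_batches, axis=0)  # shape (128, 128, 2)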
Following @HakanAkgun's suggestion in the comments, I tried using np.mean() inside a py_function (followed by another py_function for linearly scaling to the (0, 1) range), but the py_functions don't seem to affect the dataset values at all. I expected the second print statement below to show a different minimum value and the third one to print a minimum of 0, but the minimum doesn't change at all.
def remove_mean(image, target):
    # Subtract the per-batch mean image from every image in the batch.
    image_mean = np.mean(image, axis=0)
    target_mean = np.mean(target, axis=0)
    image = image - image_mean
    target = target - target_mean
    return image, target
def linear_scaling(image, target):
    # Rescale each image (per channel) to the (0, 1) range.
    image_min = np.min(image, axis=(1, 2), keepdims=True)
    image_max = np.max(image, axis=(1, 2), keepdims=True)
    image = (image - image_min) / (image_max - image_min)
    target_min = np.min(target, axis=(1, 2), keepdims=True)
    target_max = np.max(target, axis=(1, 2), keepdims=True)
    target = (target - target_min) / (target_max - target_min)
    return image, target
a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0, :, :, :]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0, :, :, :]))
train_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0, :, :, :]))

Output:
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)