
I have a tf.data.Dataset called train_dataset such that:

a, b = next(iter(train_dataset))
print(a.shape, b.shape)

Output: (100, 128, 128, 2) (100, 128, 128, 1)

That is, my batch size is 100, and the input images are 128x128 complex-valued images in which the first channel holds the real part of the image and the second channel holds the imaginary part. The target images are grayscale, i.e., single-channel images.

To preprocess the dataset, I need to find the mean image of each batch and subtract it from every image in that batch (note that each of my input images has two channels, i.e., shape (128, 128, 2)).
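A minimal sketch of per-batch mean subtraction using only TensorFlow ops inside Dataset.map, so no py_function is needed. The helper name subtract_batch_mean and the small random dataset standing in for train_dataset are hypothetical; only the input images are centered here, and the targets are passed through unchanged.

```python
import tensorflow as tf

def subtract_batch_mean(images, targets):
    # Mean over the batch axis (axis 0); keepdims=True gives shape (1, H, W, C),
    # which broadcasts against every image in the batch.
    image_mean = tf.reduce_mean(images, axis=0, keepdims=True)
    return images - image_mean, targets

# Hypothetical stand-in for train_dataset: 10 random pairs, batched by 5.
images = tf.random.normal([10, 128, 128, 2])
targets = tf.random.normal([10, 128, 128, 1])
train_dataset = tf.data.Dataset.from_tensor_slices((images, targets)).batch(5)

# Dataset.map returns a new dataset, so the result is reassigned.
train_dataset = train_dataset.map(subtract_batch_mean)

a, b = next(iter(train_dataset))
print(a.shape, b.shape)  # (5, 128, 128, 2) (5, 128, 128, 1)
```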

I tried tf.data.Dataset.reduce() as follows to find the mean of the input images:

train_mean = train_dataset.reduce(0., tf.math.add)/ tf.cast(train_dataset.cardinality(),tf.float32)

But it gives the following error:

InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [100,128,128,2] != values[1].shape = [100,128,128,1]
     [[{{node add/y}}]] [Op:ReduceDataset]

Any help debugging this error, or any other suggestion for computing the dataset mean, would be highly appreciated.
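The error appears to come from how reduce calls its function: the reducer receives (state, element), and here each element is the (images, targets) tuple, so tf.math.add tries to pack that tuple into a single tensor, which fails because the two components have 2 and 1 channels. Dividing by cardinality() would also give the mean of per-batch sums (roughly batch_size times the mean image) rather than a per-image mean. A sketch under those assumptions that maps the dataset to its image component first and accumulates a running sum plus an image count; the names accumulate and initial_state, and the random data standing in for train_dataset, are hypothetical.

```python
import tensorflow as tf

# Hypothetical stand-in for train_dataset: 10 pairs batched by 4,
# so the last batch holds only 2 images.
images = tf.random.normal([10, 128, 128, 2])
targets = tf.random.normal([10, 128, 128, 1])
train_dataset = tf.data.Dataset.from_tensor_slices((images, targets)).batch(4)

def accumulate(state, batch):
    running_sum, count = state
    # Sum the images in this batch and count them, so a smaller
    # final batch is weighted correctly.
    running_sum += tf.reduce_sum(batch, axis=0)
    count += tf.cast(tf.shape(batch)[0], tf.float32)
    return running_sum, count

initial_state = (tf.zeros([128, 128, 2]), tf.constant(0.0))
# Keep only the input images so reduce never sees the (image, target) tuple.
image_only = train_dataset.map(lambda image, target: image)
total, n = image_only.reduce(initial_state, accumulate)
train_mean = total / n  # mean image over the whole dataset, shape (128, 128, 2)
```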


Following @HakanAkgun's suggestion in the comments, I tried using np.mean() in a py_function (followed by another py_function for linear scaling to the (0, 1) range), but the py_functions don't seem to affect the dataset values at all. I expected the second print statement to show a different min value and the third to show a min value of 0, but the min value does not change.

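import numpy as np
import tensorflow as tf

def remove_mean(image, target):
    # Subtract the batch-mean image (mean over axis 0) from every image in the batch
    image_mean = np.mean(image, axis=0)
    target_mean = np.mean(target, axis=0)
    image = image - image_mean
    target = target - target_mean
    return image, target

def linear_scaling(image, target):
    # Scale each image/channel to [0, 1] using its min/max over height and width (axes 1, 2).
    # Inside tf.py_function the arguments are EagerTensors, so use np.min/np.max:
    # np.ndarray.min(image, ...) raises a TypeError for non-ndarray arguments.
    image_min = np.min(image, axis=(1,2), keepdims=True)
    image_max = np.max(image, axis=(1,2), keepdims=True)
    image = (image-image_min)/(image_max-image_min)

    target_min = np.min(target, axis=(1,2), keepdims=True)
    target_max = np.max(target, axis=(1,2), keepdims=True)
    target = (target-target_min)/(target_max-target_min)
    return image, target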

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(remove_mean, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))

train_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32, tf.float32])))
test_dataset.map(lambda item1, item2: tuple(tf.py_function(linear_scaling, [item1, item2], [tf.float32, tf.float32])))

a, b = next(iter(train_dataset))
print(tf.math.reduce_min(a[0,:,:,:]))


Output:

tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
tf.Tensor(-0.00040511801, shape=(), dtype=float32)
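For comparison, a sketch of the same min-max scaling as linear_scaling written with TensorFlow ops (tf.reduce_min and tf.reduce_max over axes 1 and 2 with keepdims=True), which can be used in Dataset.map directly without py_function. linear_scaling_tf is a hypothetical name and the random batch is illustrative only; like the NumPy version, it divides by zero (producing NaN) if an image channel is constant.

```python
import tensorflow as tf

def linear_scaling_tf(image, target):
    # Per-image, per-channel min/max over height and width (axes 1 and 2),
    # mirroring axis=(1, 2), keepdims=True in the NumPy version.
    def scale(x):
        x_min = tf.reduce_min(x, axis=[1, 2], keepdims=True)
        x_max = tf.reduce_max(x, axis=[1, 2], keepdims=True)
        return (x - x_min) / (x_max - x_min)
    return scale(image), scale(target)

# Illustrative batch with the same layout as in the question.
image = tf.random.normal([4, 128, 128, 2])
target = tf.random.normal([4, 128, 128, 1])
scaled_image, scaled_target = linear_scaling_tf(image, target)
print(float(tf.reduce_min(scaled_image)), float(tf.reduce_max(scaled_image)))
```

In a pipeline this would be applied as train_dataset = train_dataset.map(linear_scaling_tf).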
psj
  • As I understand it, a represents your images, but I couldn't tell what b represents. To find the mean over the batch, have you tried np.mean(Image_batch, axis=0)? – Hakan Akgün Jul 29 '21 at 17:43
  • b represents the target (i.e., ground truth) images. I haven't tried np.mean(Image_batch, axis=0); I suppose I should use a py_function that computes np.mean, is that correct? – psj Jul 30 '21 at 00:53
  • @HakanAkgün I have added the code with np.mean, but the values in the dataset don't seem to be affected by these py_functions (for removing the mean and for linear scaling), since the max and min values don't change after the transformation. Could you clarify how to apply np.mean? – psj Jul 30 '21 at 04:03
  • Since you have arrays, np.mean(arr, axis=n) finds the mean over a given axis and returns an array with that axis removed (for axis=0 on a (batch_size, height, width, channels) array, the result has shape (height, width, channels)). Then you can subtract it. – Hakan Akgün Jul 30 '21 at 10:01
  • Yes, I used it that way in the py_function remove_mean, but even after applying the function to the dataset, the values don't seem to change. – psj Jul 30 '21 at 10:08
  • As far as I can see, you are applying the map function but not reassigning its result back to your variable. Can you try reassigning it with train_dataset = train_dataset.map(...function...) rather than just train_dataset.map(...function...)? – Hakan Akgün Jul 30 '21 at 12:14
  • @HakanAkgün Thanks, that resolved the issue! – psj Jul 30 '21 at 15:11
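Putting the resolution together: Dataset.map returns a new dataset instead of modifying the one it is called on, so its result has to be reassigned. A sketch of a py_function-based version under a few assumptions: Tout lists one dtype per returned tensor, static shapes (which py_function does not propagate) are restored with set_shape, and the names remove_mean_np and remove_mean_tf plus the random data standing in for train_dataset are hypothetical.

```python
import numpy as np
import tensorflow as tf

def remove_mean_np(image, target):
    # Inside tf.py_function the arguments are EagerTensors; .numpy() gives ndarrays.
    image, target = image.numpy(), target.numpy()
    return image - image.mean(axis=0), target - target.mean(axis=0)

def remove_mean_tf(image, target):
    # Tout lists one dtype per returned value: two outputs, two dtypes.
    image_out, target_out = tf.py_function(
        remove_mean_np, [image, target], [tf.float32, tf.float32])
    # py_function drops static shape information; restore it from the inputs.
    image_out.set_shape(image.shape)
    target_out.set_shape(target.shape)
    return image_out, target_out

# Hypothetical stand-in for train_dataset.
images = tf.random.normal([10, 128, 128, 2])
targets = tf.random.normal([10, 128, 128, 1])
train_dataset = tf.data.Dataset.from_tensor_slices((images, targets)).batch(5)

# Reassign: map does not change train_dataset in place.
train_dataset = train_dataset.map(remove_mean_tf)
a, b = next(iter(train_dataset))
print(train_dataset.element_spec[0].shape)
```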
