I am trying to read 7 bands from TFRecords to create a dataset using the following function. Am I doing this correctly? I ask because my CNN is performing very poorly and I want to be sure:
# Define the feature spec used to parse the TFRecord data.
# Band names as stored in the TFRecords; this order is also the channel order
# of the image tensor built by parse_image.
bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']
# Each band is parsed as a fixed-size [height, width] array.
patch_dimensions_flat = [40, 40]

# One float32 column per band.
image_columns = [
    tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.float32)
    for _ in bands
]

# Optional int64 label band. It is kept OUT of image_columns on purpose:
# zip() truncates to its shortest input, so an extra column appended to
# image_columns would be silently dropped by the dict(zip(...)) below and the
# label would never be parsed.
# TODO(review): set label_key to the label property name used when the
# TFRecords were exported; leave it None if the records carry no label.
label_key = None
label_column = tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.int64)

# Map each feature name to its parsing spec.
image_features_dict = dict(zip(bands, image_columns))
if label_key is not None:
    image_features_dict[label_key] = label_column
# Define a parsing function to parse a single example from a TFRecord.


def _random_flip_rotate(image):
    """Randomly flip left-right, flip up-down, then rotate by 0/90/180/270 degrees.

    The same random draw is applied to every channel of ``image``, so stacking
    an image and its per-pixel label along the last axis keeps them aligned.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(image, k=k)


def parse_image(example_proto, augment=True, label_key=None):
    """Parse one serialized tf.train.Example into a normalized multi-band image.

    The features named in ``bands`` are stacked, in that order, along the last
    axis, giving a float32 tensor of shape
    ``(patch_dimensions_flat[0], patch_dimensions_flat[1], len(bands))``.
    This is a 3-D tensor; batching the dataset adds the leading batch axis.

    Args:
        example_proto: Scalar string tensor holding one serialized example.
        augment: If True (the default, matching the original behaviour), apply
            a random left-right flip, up-down flip and 90-degree rotation.
            Pass False when mapping validation/test data so evaluation is
            deterministic.
        label_key: Optional name of a per-pixel label feature that must be
            present in ``image_features_dict``. When given, the label receives
            exactly the same flips/rotation as the image and ``(image, label)``
            is returned instead of ``image``.

    Returns:
        The image tensor, or ``(image, label)`` when ``label_key`` is set.

    Note:
        Malformed records are not caught here. Inside ``Dataset.map`` this
        function is traced into a graph, so parse errors are raised when the
        dataset is iterated, where a Python try/except cannot catch them. To
        skip bad records, apply ``tf.data.experimental.ignore_errors()``
        (or ``Dataset.ignore_errors()`` in newer TF) to the dataset instead.
    """
    parsed_features = tf.io.parse_single_example(example_proto, image_features_dict)

    # Stack the [H, W] band arrays into a single [H, W, C] image.
    image = tf.stack([parsed_features[band] for band in bands], axis=-1)

    # Min-max scale to [0, 1] using this patch's min and max over all bands.
    # divide_no_nan yields 0 instead of NaN for a constant patch (max == min,
    # e.g. an all-nodata tile); a single NaN would otherwise poison the loss.
    # NOTE(review): per-patch, all-band scaling discards absolute values and lets
    # the widest-range band dominate; fixed per-band statistics computed over the
    # training set may suit a CNN better — confirm for this data.
    image_min = tf.reduce_min(image)
    image = tf.math.divide_no_nan(image - image_min, tf.reduce_max(image) - image_min)

    if label_key is not None:
        label = parsed_features[label_key]
        if augment:
            # Augment image and label as one tensor so they stay pixel-aligned.
            combined = tf.concat(
                [image, tf.cast(label, tf.float32)[..., tf.newaxis]], axis=-1
            )
            combined = _random_flip_rotate(combined)
            image = combined[..., :-1]
            label = tf.cast(tf.round(combined[..., -1]), label.dtype)
        return image, label

    if augment:
        image = _random_flip_rotate(image)
    return image