0

I am trying to read 7 bands from TFRecord files to create a dataset, using the function below. Am I doing this correctly? I ask because my CNN is performing very poorly and I want to be sure the input pipeline is not the cause:

# Define the feature columns for parsing the TFRecord data.
# Each spectral band is stored as one float32 patch of shape patch_dimensions_flat.
bands = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7']

# NOTE(review): key of the int64 label patch inside the TFRecords -- TODO set
# this to the name actually used when the records were written. Previously the
# int64 column below was zipped against `bands` alone; zip() stops at the
# shorter sequence, so that 8th column was silently dropped and never parsed.
# With an explicit key, a wrong name fails loudly at parse time instead.
label_key = 'label'
features = bands + [label_key]

patch_dimensions_flat = [40, 40]
image_columns = [
    tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.float32)
    for _ in bands
]

# Label patch: same spatial shape as the bands, but integer-valued.
image_columns += [tf.io.FixedLenFeature(shape=patch_dimensions_flat, dtype=tf.int64)]

# Create a dictionary mapping feature names to feature columns
# (one entry per band, plus the label).
image_features_dict = dict(zip(features, image_columns))

    # Define a parsing function to parse a single example from a TFRecord
    def parse_image(example_proto, augment=True):
        """Parse one serialized example into a normalized, optionally augmented image.

        Reads the float32 patches named in the module-level ``bands`` list using
        ``image_features_dict``. The patches are stacked along a new last axis,
        so the result has shape
        ``(patch_dimensions_flat[0], patch_dimensions_flat[1], len(bands))``.

        Args:
            example_proto: Scalar string tensor holding one serialized example
                read from a TFRecord file.
            augment: If True (the default, same as before), apply random
                left/right and up/down flips and a random rotation by a multiple
                of 90 degrees. Pass False for validation/test data, e.g.
                ``dataset.map(lambda x: parse_image(x, augment=False))``.

        Returns:
            A float32 image tensor min-max scaled to [0, 1]. A single min/max is
            taken over all bands of the patch, and a constant patch maps to all
            zeros. When called eagerly on a record that raises
            ``tf.errors.InvalidArgumentError``, prints the error and returns None.
        """
        try:
            parsed_features = tf.io.parse_single_example(example_proto, image_features_dict)
            # Each band is already a (rows, cols) patch because of the
            # FixedLenFeature shape. Stacking on a new last axis gives
            # (rows, cols, bands) directly, with no per-band reshape/concat.
            image = tf.stack([parsed_features[band] for band in bands], axis=-1)

            # Min-max scale to [0, 1]. divide_no_nan yields 0 instead of NaN when
            # max == min (e.g. an all-nodata patch); one NaN patch would
            # otherwise poison the loss and gradients.
            # NOTE(review): min/max are shared across all bands, so a band with a
            # much larger range compresses the others, and per-patch scaling
            # discards absolute brightness. Consider per-band scaling
            # (reduce over axes [0, 1] with keepdims=True) or dataset-wide
            # statistics if the model underperforms.
            image_min = tf.reduce_min(image)
            image = tf.math.divide_no_nan(image - image_min, tf.reduce_max(image) - image_min)

            if augment:
                image = tf.image.random_flip_left_right(image)
                image = tf.image.random_flip_up_down(image)
                # Randomly rotate by 0, 90, 180 or 270 degrees. The shape is
                # unchanged only because the patch is square.
                image = tf.image.rot90(image, k=tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32))

            return image
        except tf.errors.InvalidArgumentError as e:
            # NOTE: this handler only fires when the function runs eagerly.
            # Inside Dataset.map the function is traced into a graph, so parse
            # errors surface while iterating the dataset and bypass this block
            # (and None would not be a valid map output anyway). To skip corrupt
            # records there, use dataset.ignore_errors() (or
            # tf.data.experimental.ignore_errors() on older TensorFlow).
            print(f"Error parsing example: {e}")
            return None
riskiem
  • 187
  • 2
  • 7

0 Answers