
We are investigating transitioning our ML pipelines from a set of manual steps into a TFX pipeline. However, I have some questions for which I would like some additional insight.

We typically perform the following steps (for an image classification task):

  1. Load image data and meta-data
  2. Filter out ‘bad’ data based on meta-data
  3. Determine image-based statistics (classic image processing in Python):
    1. Image-level characteristics
    2. Image-region characteristics (regions are determined by a fine-tuned EfficientDet model)
  4. Filter out ‘bad’ data based on image statistics
  5. Generate TFRecords from this image and meta-data
  6. Oversample certain TFRecords for class balancing (using tf.data)
  7. Train an image classifier

Now, I’m trying to map this onto the typical example TFX pipeline.

This however raises a number of questions:

  1. I see two options:

    • ExampleGen uses a CSV file containing pointers to the images and meta-data to be loaded (step 1 above). However:

      • If this CSV file contains a path to an image file, can ExampleGen then load the image data and add it to its output?
      • Is the output of ExampleGen a streaming output, or a dump of all example data?
    • ExampleGen takes TFRecords as input (the output of step 5 above)

      -> This implies that we would still need to implement steps 1-5 outside of TFX… which would decrease the value of TFX for us…

    Could you please advise on the best way forward?

  2. Can StatisticsGen also generate statistics on a per-example basis (for example, some image (region) characteristics based on classic image processing)? Or should this be implemented in ExampleGen? Or…?

  3. Can the calculated statistics be cached using the metadata store? If yes, is there an example of this available?

    Calculating image-based characteristics using classic image processing is slow. If new data becomes available and triggers the TFX input component to be executed, ideally the already-calculated statistics would be loaded from the cache.

  4. Is it correct that ExampleValidator may reject some examples (e.g. missing data, outliers, …)?

  5. How can class balancing at the network input side (rather than via the loss function) be achieved in this setup? Normally we do this by oversampling our TFRecords using tf.data. If this is done at the ExampleGen level, the ExampleValidator may still reject some examples, potentially unbalancing the data again. This may not seem like a big issue for large-data ML tasks, but it becomes crucial for small-data ML tasks (as is typically the case in a healthcare setting). So I would expect a TFX component for this before the Transform component, but that component would need access to all the data, not a streaming view (see my earlier question on the ExampleGen output)…

Thank you for your insights.

1 Answer


I'll try to address most of your questions based on my experience with TFX.

  1. I have a Dataflow job that I run to pre-process my images, labels, features, etc. and turn all of that into TFRecords. It lives outside of TFX and is run only when there are data refreshes.

You can do the same; here is a very simple snippet of the code that I use to resize all my images and create simple features.

import tensorflow as tf

def _int64_feature(value):
  # Wrap an integer in a tf.train.Feature.
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
  # Wrap raw bytes in a tf.train.Feature.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def create_tf_example(image_string, label, image_resize_size=224):
  # Resize a JPEG image and pack it, with its shape and integer label,
  # into a tf.train.Example. Returns None for undecodable images.
  try:
    image = tf.io.decode_jpeg(image_string)
    image = tf.image.resize(image, [image_resize_size, image_resize_size])
    # resize returns float32 in [0, 255]; rescale to [0, 1] so that
    # convert_image_dtype maps the values back to the uint8 range.
    image = tf.image.convert_image_dtype(image / 255.0, dtype=tf.uint8)
    image_shape = image.shape

    image = tf.io.encode_jpeg(image, quality=100)

    feature = {
        'height': _int64_feature(image_shape[0]),
        'width': _int64_feature(image_shape[1]),
        'depth': _int64_feature(image_shape[2]),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image.numpy()),
    }

    return tf.train.Example(features=tf.train.Features(feature=feature))
  except tf.errors.InvalidArgumentError:
    print('image could not be decoded')
    return None
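
A hypothetical driver for this function; the output file name and the samples iterable of (jpeg_bytes, int_label) pairs are assumptions, not part of the original snippet:

with tf.io.TFRecordWriter('train.tfrecord') as writer:
  for image_string, label in samples:  # samples is assumed to exist
    tf_example = create_tf_example(image_string, label)
    if tf_example is not None:  # skip images that failed to decode
      writer.write(tf_example.SerializeToString())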
  2. Once I have the data in TFRecord format, I use the ImportExampleGen component to load the data into my TFX pipeline. This is followed by StatisticsGen, which computes statistics on the features.

When running all of this in the cloud, it uses Dataflow under the covers in batch mode.
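
For reference, a minimal sketch of that wiring using the TFX 1.x API; the TFRecord location is a placeholder:

from tfx import v1 as tfx

# Load externally produced TFRecords into the pipeline.
example_gen = tfx.components.ImportExampleGen(
    input_base='gs://my-bucket/tfrecords')  # placeholder path

# Compute feature statistics over the imported examples.
statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])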

  3. Your metadata store only caches pipeline metadata; the data itself is cached in a GCS bucket, and the metadata store knows where it lives. So when you re-run your pipeline with caching set to True, ImportExampleGen, StatisticsGen, SchemaGen, and Transform will not be rerun if the data hasn't changed. This has huge benefits in time and cost.
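
Caching itself is a flag on the pipeline object; a minimal sketch, where the name, paths, and component list are placeholders:

from tfx import v1 as tfx

pipeline = tfx.dsl.Pipeline(
    pipeline_name='image-pipeline',                # placeholder
    pipeline_root='gs://my-bucket/pipeline-root',  # placeholder
    components=[example_gen, statistics_gen],      # your full component list
    metadata_connection_config=(
        tfx.orchestration.metadata.sqlite_metadata_connection_config(
            'metadata.db')),
    enable_cache=True,  # reuse artifacts when component inputs are unchanged
)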

  4. ExampleValidator outputs an artifact to let you know what anomalies are in your data. I created a custom component that takes the ExampleValidator artifact as input, and if my data doesn't meet certain criteria, I kill the pipeline by throwing an error in this component. I wish there were a component that could simply stop the pipeline, but I haven't found one, so my workaround is to throw an error, which stops the pipeline from progressing further.
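
A sketch of such a gate as a TFX 1.x Python-function component; the anomalies file name and split-directory layout are assumptions that vary between TFX versions, and a local filesystem is assumed (use tf.io.gfile for GCS paths):

import os

from google.protobuf import text_format
from tensorflow_metadata.proto.v0 import anomalies_pb2
from tfx import v1 as tfx

@tfx.dsl.components.component
def AnomalyGate(
    anomalies: tfx.dsl.components.InputArtifact[
        tfx.types.standard_artifacts.ExampleAnomalies]):
  # One sub-directory per split (e.g. 'Split-train'); assumed layout.
  for split_dir in os.listdir(anomalies.uri):
    path = os.path.join(anomalies.uri, split_dir, 'anomalies.pbtxt')
    if not os.path.exists(path):
      continue
    proto = anomalies_pb2.Anomalies()
    with open(path) as f:
      text_format.Parse(f.read(), proto)
    if proto.anomaly_info:
      # Raising makes this component fail, which halts the pipeline run.
      raise RuntimeError(
          f'Anomalies in {split_dir}: {list(proto.anomaly_info.keys())}')

It would be wired up as AnomalyGate(anomalies=example_validator.outputs['anomalies']) and placed downstream of ExampleValidator.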

  5. Usually when I create a TFX pipeline, it is done to automate the machine learning process. At that point we have already done class balancing, feature selection, etc., since those fall more under the pre-processing stage.

I guess you could technically create a custom component that takes a StatisticsGen artifact, parses it, tries to do some class balancing, and creates a new dataset with balanced classes. But honestly, I think it is better to do it at the pre-processing stage.
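
If you do balance at the pre-processing stage with tf.data, one common pattern is per-class sampling. A rough sketch, assuming the dataset yields (image, label) pairs with integer labels in [0, num_classes):

import tensorflow as tf

def oversample(dataset, num_classes):
  # One infinite stream per class; sampling them with uniform weights
  # repeats minority classes more often than majority ones.
  per_class = [
      dataset.filter(lambda x, y, c=c: tf.equal(y, c)).repeat()
      for c in range(num_classes)
  ]
  weights = [1.0 / num_classes] * num_classes
  # Newer TF versions expose this as tf.data.Dataset.sample_from_datasets.
  return tf.data.experimental.sample_from_datasets(per_class, weights)

The result is an infinite stream, so cap it with .take(...) or a fixed number of training steps per epoch to get one balanced pass over the data.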

Juan Acevedo