I think the bug is in Transform itself: I hit an OOM when passing it my preprocessing_fn, but I do not hit the same issue when passing the same preprocessing_fn to the function below. Secondly, I face the OOM even when I only have a training split. When you use Transform, the Beam code is essentially abstracted away from you, so how can you override what it is doing? (The closest thing I know of is the pipeline-level beam_pipeline_args, sketched after my pipeline code below.) The only solution I can see is to not use Transform and to use something like the function below instead, the downside being that you lose all the MLMetadata and pipeline benefits, which are significant :(
Here is that function:
import os
import tempfile

import apache_beam as beam
import tensorflow_transform.beam as tft_beam
from tfx_bsl.coders.example_coder import RecordBatchToExamples
from tfx_bsl.public import tfxio


def transform_data(preprocessing_fn, source_path, source_tfrecord_prefix, experiment_path):
    # get_schema_from_transform is a small helper of mine (not shown) that
    # returns the schema for the TFRecords under source_path.
    schema = get_schema_from_transform(source_path)
    source_tfrecord_path = f"{source_path}/tfrecords/{source_tfrecord_prefix}*"
    with beam.Pipeline() as pipeline:
        with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
            tfrecord_tfxio = tfxio.TFExampleRecord(
                file_pattern=source_tfrecord_path, schema=schema)
            raw_data = (
                pipeline
                | 'ReadData' >> beam.io.ReadFromTFRecord(
                    file_pattern=source_tfrecord_path,
                    coder=beam.coders.BytesCoder())
                | 'DecodeTrainData' >> tfrecord_tfxio.BeamSource())
            raw_dataset = (raw_data, tfrecord_tfxio.TensorAdapterConfig())
            transformed_dataset, transform_fn = (
                raw_dataset | tft_beam.AnalyzeAndTransformDataset(
                    preprocessing_fn, output_record_batches=True))
            transformed_data, _ = transformed_dataset
            transform_fn_output = os.path.join(experiment_path, 'transform_output')
            tfrecord_file_path_prefix = os.path.join(
                experiment_path, 'tfrecords', experiment_path)
            data_written = (
                transformed_data
                | 'EncodeTrainData' >> beam.FlatMapTuple(
                    lambda batch, _: RecordBatchToExamples(batch))
                | 'WriteTrainData' >> beam.io.WriteToTFRecord(tfrecord_file_path_prefix))
            _ = (
                transform_fn
                | 'WriteTransformFn' >> tft_beam.WriteTransformFn(transform_fn_output))
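For reference, I invoke it roughly like this (the paths and prefix here are illustrative placeholders, not my exact values):

# Illustrative call only; paths and prefix are placeholders.
from preprocessing_fn import preprocessing_fn  # same module_file used by Transform below

transform_data(
    preprocessing_fn=preprocessing_fn,
    source_path='./identity_transform',
    source_tfrecord_prefix='train',
    experiment_path='./manual_transform_output')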
My code that generates the OOM is as follows:
from tfx import v1 as tfx
import tensorflow_data_validation as tfdv
TRAIN_DATA = 'train_smallest.csv'
LABEL_NAME = 'fare_amount'
BATCH_SIZE=256
ORIGINAL_TFRECORD_PATH='./identity_transform/tfrecords/'
from tfx.components import ImportExampleGen
from tfx.components import StatisticsGen
from tfx.components import SchemaGen
from tfx.components import Transform
from tfx.v1 import proto
from tfx.proto import example_gen_pb2
output = proto.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        proto.SplitConfig.Split(name='train', hash_buckets=3),
        proto.SplitConfig.Split(name='eval', hash_buckets=1)
    ]))
example_gen = ImportExampleGen(input_base=ORIGINAL_TFRECORD_PATH, output_config=output)
stats_options = tfdv.StatsOptions(label_feature=LABEL_NAME)
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples'], stats_options=stats_options)
schema_gen = SchemaGen(
    statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
transform = Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_gen.outputs['schema'],
    module_file='./preprocessing_fn.py')
components = [
    example_gen,
    statistics_gen,
    schema_gen,
    transform,
]
pipeline_name = 'pipeline'
pipeline_root = './pipeline'
metadata_path = './metadata/pipeline/metadata.db'
pipeline = tfx.dsl.Pipeline(
    pipeline_name=pipeline_name,
    pipeline_root=pipeline_root,
    metadata_connection_config=tfx.orchestration.metadata
    .sqlite_metadata_connection_config(metadata_path),
    components=components,
    enable_cache=False)
tfx.orchestration.LocalDagRunner().run(pipeline)
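On the "how can you override what Beam is doing inside Transform" point above: the only knob I am aware of is the pipeline-level beam_pipeline_args, which forwards Beam options to every Beam-based component (including Transform). A minimal sketch of what I mean, using standard DirectRunner flags; I have not confirmed that any combination of these avoids the OOM:

# Sketch only: same pipeline as above, but forwarding Beam options.
pipeline_with_beam_args = tfx.dsl.Pipeline(
    pipeline_name=pipeline_name,
    pipeline_root=pipeline_root,
    metadata_connection_config=tfx.orchestration.metadata
    .sqlite_metadata_connection_config(metadata_path),
    components=components,
    enable_cache=False,
    beam_pipeline_args=[
        '--direct_running_mode=multi_processing',
        '--direct_num_workers=1',
    ])
tfx.orchestration.LocalDagRunner().run(pipeline_with_beam_args)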
My preprocessing_fn is:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_transform as tft
NBUCKETS = 10
def preprocessing_fn(inputs):
    """
    Preprocess input columns into transformed features. This is what goes
    into TensorFlow Transform / Apache Beam.
    """
    # Since we are modifying some features and leaving others unchanged, we
    # start by setting `transformed` to a copy of `inputs`.
    transformed = inputs.copy()
    del transformed["key"]
    transformed['passenger_count'] = tft.scale_to_0_1(
        inputs['passenger_count'])
    # We cannot use the Keras Normalization layer below inside tft, because its
    # learned values would need to be managed carefully.
    # normalizer = tf.keras.layers.Normalization(axis=None,
    #     name="passenger_count_normalizer")
    # normalizer.adapt(inputs['passenger_count'])
    # transformed['other_passenger_count'] = normalizer(
    #     inputs['passenger_count'])
    for col in ['dropoff_longitude', 'dropoff_latitude']:
        # This could be made more robust by taking the shape from the feature spec.
        transformed[col] = tft.sparse_tensor_to_dense_with_shape(
            inputs[col], default_value=tft.mean(inputs[col]), shape=[None, 1])
    for lon_col in ['pickup_longitude', 'dropoff_longitude']:
        # transformed[lon_col] = scale_longitude(inputs[lon_col])
        transformed[lon_col] = (transformed[lon_col] + 78) / 8.
    for lat_col in ['pickup_latitude', 'dropoff_latitude']:
        transformed[lat_col] = (transformed[lat_col] - 37) / 8.
    position_difference = tf.square(
        transformed["dropoff_longitude"] -
        transformed["pickup_longitude"])
    position_difference += tf.square(
        transformed["dropoff_latitude"] -
        transformed["pickup_latitude"])
    transformed['euclidean'] = tf.sqrt(position_difference)
    # Bucketize the rescaled pickup/dropoff coordinates into NBUCKETS bins.
    lat_lon_buckets = [
        bin_edge / NBUCKETS
        for bin_edge in range(0, NBUCKETS)]
    transformed['bucketed_pickup_longitude'] = tft.apply_buckets(
        transformed["pickup_longitude"],
        bucket_boundaries=tf.constant([lat_lon_buckets]))
    transformed["bucketed_pickup_latitude"] = tft.apply_buckets(
        transformed['pickup_latitude'],
        bucket_boundaries=tf.constant([lat_lon_buckets]))
    transformed['bucketed_dropoff_longitude'] = tft.apply_buckets(
        transformed["dropoff_longitude"],
        bucket_boundaries=tf.constant([lat_lon_buckets]))
    transformed['bucketed_dropoff_latitude'] = tft.apply_buckets(
        transformed["dropoff_latitude"],
        bucket_boundaries=tf.constant([lat_lon_buckets]))
    # transformed["pickup_cross"] = tf.sparse.cross(
    #     inputs=[transformed['pickup_latitude_apply_buckets'],
    #             transformed['pickup_longitude_apply_buckets']])
    hash_pickup_crossing_layer = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='one_hot', num_bins=NBUCKETS**2,
        name='hash_pickup_crossing_layer')
    transformed['pickup_location'] = hash_pickup_crossing_layer(
        (transformed['bucketed_pickup_latitude'],
         transformed['bucketed_pickup_longitude']))
    hash_dropoff_crossing_layer = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='one_hot', num_bins=NBUCKETS**2,
        name='hash_dropoff_crossing_layer')
    transformed['dropoff_location'] = hash_dropoff_crossing_layer(
        (transformed['bucketed_dropoff_latitude'],
         transformed['bucketed_dropoff_longitude']))
    hash_pickup_crossing_layer_intermediary = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='int', num_bins=NBUCKETS**2)
    hashed_pickup_intermediary = hash_pickup_crossing_layer_intermediary(
        (transformed['bucketed_pickup_longitude'],
         transformed['bucketed_pickup_latitude']))
    hash_dropoff_crossing_layer_intermediary = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='int', num_bins=NBUCKETS**2)
    hashed_dropoff_intermediary = hash_dropoff_crossing_layer_intermediary(
        (transformed['bucketed_dropoff_longitude'],
         transformed['bucketed_dropoff_latitude']))
    hash_trip_crossing_layer = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='one_hot', num_bins=NBUCKETS ** 3,
        name="hash_trip_crossing_layer")
    transformed['hashed_trip'] = hash_trip_crossing_layer(
        (hashed_pickup_intermediary,
         hashed_dropoff_intermediary))
    # Derive time features from the pickup timestamp.
    seconds_since_1970 = tf.cast(
        tfa.text.parse_time(
            inputs["pickup_datetime"],
            "%Y-%m-%d %H:%M:%S %Z",
            output_unit="SECOND"),
        tf.float32)
    # seconds_since_1970 = fn_seconds_since_1970(inputs['pickup_datetime'])
    hours_since_1970 = seconds_since_1970 / 3600.
    hours_since_1970 = tf.floor(hours_since_1970)
    hour_of_day_intermediary = hours_since_1970 % 24
    transformed['hour_of_day'] = hour_of_day_intermediary
    hour_of_day_intermediary = tf.cast(hour_of_day_intermediary, tf.int32)
    days_since_1970 = seconds_since_1970 / (3600 * 24)
    days_since_1970 = tf.floor(days_since_1970)
    # January 1st 1970 was a Thursday.
    day_of_week_intermediary = (days_since_1970 + 4) % 7
    transformed['day_of_week'] = day_of_week_intermediary
    day_of_week_intermediary = tf.cast(day_of_week_intermediary, tf.int32)
    hashed_crossing_layer = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        num_bins=24 * 7, output_mode="one_hot")
    hashed_crossing_layer_intermediary = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        num_bins=24 * 7, output_mode="int",
        name='hashed_hour_of_day_of_week_layer')
    transformed['hour_of_day_of_week'] = hashed_crossing_layer(
        (hour_of_day_intermediary, day_of_week_intermediary))
    hour_of_day_of_week_intermediary = hashed_crossing_layer_intermediary(
        (hour_of_day_intermediary, day_of_week_intermediary))
    hash_trip_crossing_layer_intermediary = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='int', num_bins=NBUCKETS ** 3)
    hashed_trip_intermediary = hash_trip_crossing_layer_intermediary(
        (hashed_pickup_intermediary, hashed_dropoff_intermediary))
    hash_trip_and_time_layer = tf.keras.layers.experimental.preprocessing.HashedCrossing(
        output_mode='one_hot', num_bins=(NBUCKETS ** 3) * 4,
        name='hash_trip_and_time_layer')
    transformed['hashed_trip_and_time'] = hash_trip_and_time_layer(
        (hashed_trip_intermediary, hour_of_day_of_week_intermediary))
    return transformed
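As a sanity check on the time arithmetic above (1970-01-01 was a Thursday, hence the +4 offset), here is a tiny standalone snippet, independent of TensorFlow, that illustrates the same hour-of-day / day-of-week derivation on the first row of my data:

import datetime

# Standalone illustration of the hour-of-day / day-of-week arithmetic above.
# With the +4 shift, 0 == Sunday, ..., 6 == Saturday.
ts = datetime.datetime(2010, 10, 19, 0, 1, 17, tzinfo=datetime.timezone.utc)
seconds_since_1970 = int(ts.timestamp())
hour_of_day = (seconds_since_1970 // 3600) % 24
day_of_week = (seconds_since_1970 // (3600 * 24) + 4) % 7
assert day_of_week == (ts.weekday() + 1) % 7  # datetime uses Monday == 0
print(hour_of_day, day_of_week)               # 0, 2 (a Tuesday)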
And the data that I'm using is as follows; this is from before it was written to TFRecords using a preprocessing_fn that is essentially f(x) = x:
key,fare_amount,pickup_datetime,pickup_longitude,pickup_latitude,dropoff_longitude,dropoff_latitude,passenger_count
2010-10-19 00:01:17.0000003,11.3,2010-10-19 00:01:17 UTC,-73.948724,40.777489,-73.949569,40.814049,1
2009-08-19 08:58:00.00000031,8.5,2009-08-19 08:58:00 UTC,-74.007743,40.724717,-74.006797,40.751253,1
2011-01-24 06:48:12.0000001,10.9,2011-01-24 06:48:12 UTC,-73.986678,40.742597,-73.955101,40.766174,1
2014-09-13 09:08:00.000000126,15.5,2014-09-13 09:08:00 UTC,-74.00325,40.7083,-73.975935,40.749007,1
2013-04-12 19:12:22.0000002,10,2013-04-12 19:12:22 UTC,-74.005318,40.728261,-73.981724,40.7293,1
2015-03-11 11:09:40.0000003,13,2015-03-11 11:09:40 UTC,-73.998809814453125,40.734573364257812,-73.989830017089844,40.756542205810547,1
2014-02-26 01:06:40.0000001,8,2014-02-26 01:06:40 UTC,-73.985821,40.763299,-74.003947,40.751722,1
2011-12-03 02:34:21.0000002,15.7,2011-12-03 02:34:21 UTC,-73.940638,40.840057,-73.98792,40.768815,2
2012-12-10 15:50:03.0000002,34,2012-12-10 15:50:03 UTC,-73.872871,40.774284,-73.995264,40.739349,1
2013-09-22 21:15:18.0000001,7.5,2013-09-22 21:15:18 UTC,-73.996565,40.718924,-74.007011,40.707672,2
2011-06-13 20:19:00.00000010,8.1,2011-06-13 20:19:00 UTC,-73.981587,40.747238,-73.956932,40.771512,5
2013-04-13 02:28:55.0000006,4,2013-04-13 02:28:55 UTC,0,0,0,0,1
2013-09-08 15:49:49.0000001,5.5,2013-09-08 15:49:49 UTC,-73.96077,40.775805,-73.970084,40.76252,1
2011-10-05 21:59:00.00000098,6.5,2011-10-05 21:59:00 UTC,-74.005052,40.70663,-74.012255,40.718838,1
2014-10-03 04:04:00.00000020,5.5,2014-10-03 04:04:00 UTC,-74.000032,40.732362,-73.99655,40.736532,5
2010-06-18 13:41:28.0000001,12.1,2010-06-18 13:41:28 UTC,-73.998732,40.717089,-73.975146,40.758038,3
2014-08-07 12:06:50.0000001,13,2014-08-07 12:06:50 UTC,-74.008268,40.722489,-73.982861,40.744874,1
2009-08-30 11:56:58.0000005,5.7,2009-08-30 11:56:58 UTC,-74.007583,40.748017,-73.994615,40.751834,1
2012-01-28 01:58:00.000000132,9.7,2012-01-28 01:58:00 UTC,-74.422832,40.84662,-74.49563,40.968122,1
2012-06-01 13:33:00.00000070,6.9,2012-06-01 13:33:00 UTC,-73.987377,40.743832,-73.998615,40.739855,1
2014-07-31 20:00:00.000000140,14.5,2014-07-31 20:00:00 UTC,-73.983415,40.75677,-73.95324,40.77148,1
2009-02-06 09:25:00.000000128,6.9,2009-02-06 09:25:00 UTC,-73.994222,40.727137,-73.987398,40.744425,1
2012-01-07 20:28:30.0000003,6.9,2012-01-07 20:28:30 UTC,-73.9889,40.7591,-74.0053,40.7401,0
2012-06-13 15:57:00.000000105,8.5,2012-06-13 15:57:00 UTC,-73.982535,40.761602,-73.980302,40.748475,1
2014-05-02 10:33:00.00000042,5,2014-05-02 10:33:00 UTC,-73.985352,40.74213,-73.991562,40.742512,1
2013-03-30 00:33:00.000000133,15,2013-03-30 00:33:00 UTC,-73.98616,40.757022,-74.009022,40.707857,1
2011-05-20 23:01:00.00000071,9.3,2011-05-20 23:01:00 UTC,-73.951177,40.77465,-73.972487,40.743393,1
2011-01-27 21:51:00.000000171,8.9,2011-01-27 21:51:00 UTC,-73.989867,40.756748,-73.972143,40.786588,3
2009-03-20 12:46:25.0000001,6.9,2009-03-20 12:46:25 UTC,-73.951526,40.770003,-73.970998,40.754989,1
2013-05-01 09:32:00.000000143,7,2013-05-01 09:32:00 UTC,-73.990302,40.756552,-73.982462,40.760242,5
2010-05-25 10:10:00.00000027,4.9,2010-05-25 10:10:00 UTC,-73.980722,40.779832,-73.971522,40.787518,1
2012-12-18 07:24:00.000000148,6,2012-12-18 07:24:00 UTC,-73.965952,40.776297,-73.950533,40.774467,4
2012-04-18 08:36:00.000000182,7.7,2012-04-18 08:36:00 UTC,-73.98358,40.766182,-73.97922,40.752992,1
2009-05-23 17:11:00.00000092,42.9,2009-05-23 17:11:00 UTC,-73.781909,40.64477,-73.991234,40.687251,2
2013-04-16 08:44:39.0000006,12,2013-04-16 08:44:39 UTC,-73.961365,40.760555,-73.9805,40.753709,1
2014-04-15 18:39:47.0000002,3.5,2014-04-15 18:39:47 UTC,-74.008172,40.737866,-74.007745,40.732653,1
2013-12-22 17:42:24.0000001,8,2013-12-22 17:42:24 UTC,-73.977505,40.742731,-73.980127,40.722385,2
2011-09-07 20:12:37.0000004,14.5,2011-09-07 20:12:37 UTC,-74.002225,40.715,-73.951187,40.728798,1
2014-12-10 21:57:46.0000005,6.5,2014-12-10 21:57:46 UTC,-73.97175,40.760287,0,0,1
2012-11-12 22:11:00.000000101,6.5,2012-11-12 22:11:00 UTC,-73.983615,40.760325,-73.998885,40.760012,1
2015-02-21 19:26:56.0000005,15,2015-02-21 19:26:56 UTC,-73.9959716796875,40.686809539794922,-74.009872436523438,40.718009948730469,1
2010-10-06 03:28:42.0000001,8.9,2010-10-06 03:28:42 UTC,-73.988341,40.72886,-73.984581,40.751519,1
2010-07-09 22:01:22.0000007,4.1,2010-07-09 22:01:22 UTC,-73.959272,40.771833,-73.962304,40.773124,1
2009-05-28 10:45:00.00000032,8.9,2009-05-28 10:45:00 UTC,-73.988872,40.753367,-73.972733,40.753327,1
2013-09-24 18:25:00.000000123,25,2013-09-24 18:25:00 UTC,-74.005197,40.751602,-73.980317,40.676607,1
2009-02-05 08:23:01.0000004,7.3,2009-02-05 08:23:01 UTC,-73.975468,40.759635,-73.991854,40.749352,1
2011-07-03 22:25:04.0000003,28.5,2011-07-03 22:25:04 UTC,-73.776755,40.64523,-73.951802,40.657781,1
2010-06-07 15:20:00.000000164,17.7,2010-06-07 15:20:00 UTC,-73.969625,40.758133,-74.012548,40.713983,1
2012-09-14 01:02:00.000000106,7,2012-09-14 01:02:00 UTC,-73.982777,40.744722,-73.984505,40.732127,1
2013-03-08 23:16:00.00000051,25.5,2013-03-08 23:16:00 UTC,-73.990822,40.734702,-73.945737,40.651117,1
2009-10-30 16:06:00.00000074,10.1,2009-10-30 16:06:00 UTC,-73.981008,40.768257,-73.968412,40.80144,2
2012-12-13 19:08:47.0000004,4.5,2012-12-13 19:08:47 UTC,-73.949347,40.78134,-73.956487,40.777995,1
2009-09-03 18:08:00.000000222,5.3,2009-09-03 18:08:00 UTC,-73.985702,40.753625,-73.989385,40.741143,5
2014-05-09 14:18:00.00000085,22.5,2014-05-09 14:18:00 UTC,-73.994697,40.765992,-74.012857,40.706807,1
2010-06-13 18:07:00.00000026,5.7,2010-06-13 18:07:00 UTC,-73.98811,40.774712,-73.98207,40.763285,1
2009-05-30 22:44:00.0000003,15.7,2009-05-30 22:44:00 UTC,-73.968862,40.791558,-73.971705,40.793732,2
2009-05-03 09:58:47.0000002,3.7,2009-05-03 09:58:47 UTC,-73.966445,40.804635,-73.969422,40.798094,1
2011-02-21 22:48:57.0000004,5.7,2011-02-21 22:48:57 UTC,-73.977624,40.752278,-73.994179,40.751649,1
2009-01-09 22:16:00.00000018,4.5,2009-01-09 22:16:00 UTC,-73.961832,40.763897,-73.96975,40.759523,1
2011-06-17 10:31:00.000000154,7.3,2011-06-17 10:31:00 UTC,-73.963242,40.771425,-73.968752,40.7537,1
2014-11-24 12:38:00.000000143,5,2014-11-24 12:38:00 UTC,-73.969662,40.753422,-73.975425,40.74545,1
2015-05-03 09:46:20.0000007,14.5,2015-05-03 09:46:20 UTC,-73.9213638305664,40.745254516601562,-73.85888671875,40.751045227050781,2
2010-04-05 18:49:28.0000001,8.9,2010-04-05 18:49:28 UTC,-73.960298,40.766187,-73.981875,40.727677,1
2013-12-17 20:12:00.00000030,6.5,2013-12-17 20:12:00 UTC,-73.956007,40.779077,-73.965855,40.765742,1
2010-03-12 21:29:00.000000235,9.3,2010-03-12 21:29:00 UTC,0,0,0,0,3
2011-11-30 17:09:00.000000158,14.1,2011-11-30 17:09:00 UTC,-73.985565,40.731717,-73.981347,40.77369,1
2011-12-18 22:21:00.000000162,12.1,2011-12-18 22:21:00 UTC,-73.995165,40.753117,-73.98827,40.725642,1
2010-09-02 14:39:00.000000201,4.1,2010-09-02 14:39:00 UTC,-73.978898,40.785037,-73.970818,40.789267,1
2014-08-22 16:00:00.00000098,6,2014-08-22 16:00:00 UTC,-73.97484,40.756925,-73.972712,40.761775,3
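In case it helps anyone reproduce: a rough, illustrative sketch of how rows like the above can be written to TFRecords. My actual conversion went through the identity preprocessing_fn mentioned earlier, so treat this as an approximation only; the helper name and output path are placeholders.

import csv
import tensorflow as tf

# Illustrative only: plain CSV -> TFRecord conversion approximating the
# identity transform described above.
FLOAT_COLS = ['fare_amount', 'pickup_longitude', 'pickup_latitude',
              'dropoff_longitude', 'dropoff_latitude', 'passenger_count']
BYTES_COLS = ['key', 'pickup_datetime']

def csv_to_tfrecords(csv_path, tfrecord_path):
    with open(csv_path) as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
        for row in csv.DictReader(f):
            feature = {}
            for col in BYTES_COLS:
                feature[col] = tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[row[col].encode()]))
            for col in FLOAT_COLS:
                feature[col] = tf.train.Feature(
                    float_list=tf.train.FloatList(value=[float(row[col])]))
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

# e.g. csv_to_tfrecords(TRAIN_DATA, './identity_transform/tfrecords/train.tfrecord')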
@tensorflow-support let's solve this issue and make Transform and LocalDagRunner usable so people can solve cool problems with them!