We have TFRecord files in which each file contains a single example, but each feature holds a list of values (one value per row). We read them with tf.data.Dataset as follows:
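```python
import tensorflow as tf

n_rows_per_record_file = 100

def parse_tfrecord_to_example(record_bytes):
    # Every file is expected to hold exactly n_rows_per_record_file values in "my_col".
    col_map = {
        "my_col": tf.io.FixedLenFeature(
            shape=[n_rows_per_record_file], dtype=tf.int64
        )
    }
    return tf.io.parse_single_example(record_bytes, col_map)

ds = (
    tf.data.TFRecordDataset(file_paths)
    .map(parse_tfrecord_to_example)
)
```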
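To reproduce the file layout described above (one tf.train.Example per file, with "my_col" stored as an int64 list holding one value per row), here is a minimal sketch. The helper names `write_single_example_file` and `parse_variable_length` and the temporary paths are hypothetical. It also shows that `tf.io.FixedLenSequenceFeature` with `allow_missing=True` parses the list into a dense 1-D tensor without a hard-coded length, which may be enough if the per-file count is only needed to size the feature.

```python
import os
import tempfile

import tensorflow as tf


def write_single_example_file(path, values):
    # Hypothetical helper: write one tf.train.Example whose "my_col" feature
    # holds `values` (one int64 per row) into its own TFRecord file.
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "my_col": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=values)
                )
            }
        )
    )
    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())


def parse_variable_length(record_bytes):
    # FixedLenSequenceFeature with allow_missing=True yields a dense 1-D
    # tensor whose length is whatever the record contains, so no per-file
    # row count is needed at parse time.
    col_map = {
        "my_col": tf.io.FixedLenSequenceFeature(
            shape=[], dtype=tf.int64, allow_missing=True
        )
    }
    return tf.io.parse_single_example(record_bytes, col_map)


tmp_dir = tempfile.mkdtemp()
file_paths = [
    os.path.join(tmp_dir, "a.tfrecord"),
    os.path.join(tmp_dir, "b.tfrecord"),
]
write_single_example_file(file_paths[0], [1, 2, 3])
write_single_example_file(file_paths[1], [4, 5])

ds = tf.data.TFRecordDataset(file_paths).map(parse_variable_length)
lengths = [int(tf.size(ex["my_col"])) for ex in ds]
print(lengths)
```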
Instead of using a fixed constant for n_rows_per_record_file, we would like to look up the number of rows for each file from its file path.
Any ideas on how to achieve this?
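One way a path-keyed lookup could run inside the pipeline is with `tf.lookup.StaticHashTable`, assuming the row counts are available up front as a Python dict `shapes` mapping file path to row count. This is a sketch, not a confirmed solution: `make_dataset` and `write_single_example_file` are hypothetical names, and because the looked-up count is a tensor, the sketch parses with `FixedLenSequenceFeature` and checks the length at runtime rather than passing the count to `FixedLenFeature`.

```python
import os
import tempfile

import tensorflow as tf


def write_single_example_file(path, values):
    # Hypothetical test helper: one tf.train.Example per file, "my_col" = values.
    example = tf.train.Example(features=tf.train.Features(feature={
        "my_col": tf.train.Feature(int64_list=tf.train.Int64List(value=values)),
    }))
    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())


def make_dataset(file_paths, shapes):
    """Assumes `shapes` is a Python dict {file_path: n_rows} known up front."""
    # Move the Python dict into an in-graph lookup table keyed by path.
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(shapes.keys()), dtype=tf.string),
            values=tf.constant(list(shapes.values()), dtype=tf.int64),
        ),
        default_value=-1,  # returned for paths missing from `shapes`
    )

    def read_one_file(path):
        # `path` is a scalar tf.string Tensor; the lookup runs inside the graph.
        n_rows = table.lookup(path)

        def parse(record_bytes):
            parsed = tf.io.parse_single_example(
                record_bytes,
                {"my_col": tf.io.FixedLenSequenceFeature(
                    shape=[], dtype=tf.int64, allow_missing=True)},
            )
            values = parsed["my_col"]
            # Fail at runtime if the parsed length differs from the lookup.
            check = tf.debugging.assert_equal(
                tf.size(values, out_type=tf.int64), n_rows)
            with tf.control_dependencies([check]):
                values = tf.identity(values)
            return {"my_col": values, "n_rows": n_rows}

        return tf.data.TFRecordDataset(path).map(parse)

    # flat_map flattens the per-file datasets into one stream of examples.
    return tf.data.Dataset.from_tensor_slices(file_paths).flat_map(read_one_file)


tmp_dir = tempfile.mkdtemp()
shapes = {
    os.path.join(tmp_dir, "a.tfrecord"): 3,
    os.path.join(tmp_dir, "b.tfrecord"): 2,
}
for path, n in shapes.items():
    write_single_example_file(path, list(range(n)))

file_paths = list(shapes.keys())
for ex in make_dataset(file_paths, shapes):
    print(int(ex["n_rows"]), ex["my_col"].numpy().tolist())
```

The table is built once from the dict, so every lookup stays in TensorFlow ops and the pipeline does not need to call back into Python per file.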
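We tried something like this:
```python
# shapes: Python dict mapping file path -> number of rows in that file.
def get_shape(filepath):
    # Fails: inside .map(), filepath is a symbolic tf.Tensor, not a str.
    return filepath, shapes[filepath]

# Assumes parse_tfrecord_to_example is changed to take the row count as a
# second argument and use it as the FixedLenFeature shape.
ds = (
    tf.data.list_files(file_paths)
    .map(get_shape)
    .flat_map(
        lambda f, shape: tf.data.TFRecordDataset(f).map(
            lambda record_bytes: parse_tfrecord_to_example(record_bytes, shape)
        )
    )
)
```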
but this fails because tf.data traces the mapped functions into a graph: inside get_shape, filepath is a symbolic tf.Tensor rather than a Python string, so it cannot be used as a key into the Python dict shapes.
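If the Python dict `shapes` should stay as it is, the lookup in `get_shape` can be wrapped in `tf.py_function`, which runs its body eagerly so the path tensor can be converted with `.numpy()`. A minimal sketch with hypothetical paths; note that the `tf.py_function` body is not serialized into the graph and runs under the Python GIL, and the returned count is still a tensor, so this sketch only returns it alongside the path.

```python
import tensorflow as tf

# Hypothetical stand-in for the `shapes` dict: {file_path: n_rows}.
shapes = {"a.tfrecord": 3, "b.tfrecord": 2}


def get_shape(filepath):
    # Inside Dataset.map, `filepath` is a symbolic tf.string Tensor, so it
    # cannot index a Python dict directly. tf.py_function runs the inner
    # function eagerly, where `.numpy()` returns the path as Python bytes.
    def _py_lookup(path_tensor):
        return shapes[path_tensor.numpy().decode("utf-8")]

    n_rows = tf.py_function(_py_lookup, inp=[filepath], Tout=tf.int64)
    n_rows.set_shape([])  # py_function does not infer a static shape
    return filepath, n_rows


ds = tf.data.Dataset.from_tensor_slices(list(shapes.keys())).map(get_shape)
for path, n_rows in ds:
    print(path.numpy().decode("utf-8"), int(n_rows))
```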