I'm currently working on implementing a pytorch version of this tensorflow project. The authors published a paper, which can be found here.
Dataset formats and where to find them
The datasets are .tfrecord files, which can be downloaded with a script the authors provide that fetches them from a storage.google-apis site. I was able to download the files (if you don't want to run their script, you can find an example of the dataset files in this drive link; there are .npz versions of some of the datasets there, and I'd like to have all of them in that format for workflow purposes). The .npz versions of the dataset can be loaded as dictionaries with keys "simulation_trajectory_0", "simulation_trajectory_1", ..., "simulation_trajectory_999". Each of these keys holds an array of shape [320, x, 2], where x varies from around 100 to 2000. This makes sense, since they are simulations of the trajectories of x particles, in 2 dimensions, over 320 timesteps.
My objective
My objective is to read these tfrecords and, ideally, convert them into .npz files, which I then load with numpy and convert to torch tensors. I'd be content even with reading them with tensorflow and converting them directly into torch tensors, but since I'm working on a group project in pytorch, I'd like to handle the data preprocessing on my own and not force my colleagues to use tensorflow too.
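For reference, this is how I already consume the existing .npz files and what I want to end up with in pytorch (the file name and key below are just taken from my own copies of the converted files):

import numpy as np
import torch

# "valid.npz" and the key name come from my own copies of the converted files
data = np.load("valid.npz")                 # allow_pickle=True may be needed, depending on how the arrays were saved
traj = data["simulation_trajectory_0"]      # numpy array of shape [320, x, 2]
positions = torch.from_numpy(traj).float()  # torch tensor for the rest of the pipeline
print(positions.shape)                      # torch.Size([320, x, 2])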
The problems, and what I tried
I am very new to tensorflow, so I have not been able to read the files properly. From what I could gather, a tfrecord file is a serialization of a large dataset, meant to be read with streaming methods rather than handled all at once. I read many questions on Stack Overflow and the TF documentation, and it seems I need to work out the features of my .tfrecord file, write a parsing function, and pass the file plus the parsing function to tf.data.TFRecordDataset.
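As far as I understand, the general pattern looks roughly like this (the feature name and dtype below are placeholders, since figuring out the real schema is exactly my problem):

import tensorflow as tf

# placeholder parsing function: "some_feature" and its dtype are made up,
# they are not the real schema of these files
def parse_fn(record_bytes):
    return tf.io.parse_single_example(
        record_bytes,
        {"some_feature": tf.io.FixedLenFeature([], dtype=tf.int64)},
    )

dataset = tf.data.TFRecordDataset("valid.tfrecord").map(parse_fn)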
I was able to extract the features from my .tfrecord file using this code:
import tensorflow as tf

def list_record_features(tfrecords_path):
    # Dict of extracted feature information
    features = {}
    # Iterate records
    for rec in tf.data.TFRecordDataset([str(tfrecords_path)]):
        # Get record bytes
        example_bytes = rec.numpy()
        # Parse example protobuf message
        example = tf.train.Example()
        example.ParseFromString(example_bytes)
        # Iterate example features
        for key, value in example.features.feature.items():
            # Kind of data in the feature
            kind = value.WhichOneof('kind')
            # Size of data in the feature
            size = len(getattr(value, kind).value)
            # Check if feature was seen before
            if key in features:
                # Check if values match, use None otherwise
                kind2, size2 = features[key]
                if kind != kind2:
                    kind = None
                if size != size2:
                    size = None
            # Save feature data
            features[key] = (kind, size)
    return features
(I found this code in this question.) Running

features = list_record_features("valid.tfrecord")
features

gives me this output:

{'key': ('int64_list', 1), 'particle_type': ('bytes_list', 1)}
And now, if I write a decode function and give it to TFRecordDataset:
def decode_fn(record_bytes):
    return tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        {"x": tf.io.FixedLenFeature([], dtype=tf.train.Int64List),
         "y": tf.io.FixedLenFeature([], dtype=tf.train.BytesList)}  # these two types give an error
    )

for batch in tf.data.TFRecordDataset("valid.tfrecord").map(decode_fn):
    print("x = {x:.4f}, y = {y:.4f}".format(**batch))
I don't understand what to put after dtype=, and I don't know how I would then go about getting the data into a proper format that I can save to .npz.
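Based on the output of list_record_features above, my best guess is that the dtypes should be things like tf.int64 and tf.string rather than the *List protos, roughly like the sketch below, but I'm not confident that's right, and even if it is, it only covers key and particle_type and still doesn't get me the actual particle positions:

# my best guess at a schema, based on the ('int64_list', 1) / ('bytes_list', 1) output above
guessed_schema = {
    "key": tf.io.FixedLenFeature([], dtype=tf.int64),
    "particle_type": tf.io.FixedLenFeature([], dtype=tf.string),
}

def decode_fn_guess(record_bytes):
    return tf.io.parse_single_example(record_bytes, guessed_schema)

for batch in tf.data.TFRecordDataset("valid.tfrecord").map(decode_fn_guess):
    print(batch["key"].numpy(), batch["particle_type"].numpy()[:20])
    break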
The only thing remotely similar to what I need is this question, but that just gives me this sort of table when I run it:

   key  particle_type  position  step_context
0    0  ...  [b'\xe6\x1fB?\xa1\x12\xa8>\x19\xd6>?\xdcr\xa9>...  [b'\x00\x00\xc0\x7f', , , , , ...
1    1  ...  [b'\x1d\x8a>?\x0b\xb8s>F\xd3>?\xa0T\x80>\x8f\x...  [b'\x00\x00\xc0\x7f', , , , , ...
2    2  ...  [b'\x02s\xb7>\x11y\x8e>f\xdf\xb0>J\xd9\x8f>\xa...  [b'\x00\x00\xc0\x7f', , , , , ...
3    3  ...  [b']\x02\xfa>\xc9\x97\x93>\xc2n\xf3>\x02\xf8\x...  [b'\x00\x00\xc0\x7f', , , , , ...
4    4  ...  [b'\xd3q:?\x80\xb6\x98>\x05(7?\xb8\x16\x9a>\xa...  [b'\x00\x00\xc0\x7f', , , , , ...
...
I'm sorry for the long question, but this is my first post and I wanted to make sure all the necessary info was included. Any help or ideas are much appreciated.