I imported the CIFAR10 dataset via `tensorflow_datasets.load()`.
This gives me:

`<PrefetchDataset element_spec={'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'image': TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), 'label': TensorSpec(shape=(), dtype=tf.int64, name=None)}>`
This dataset has an `id` column. I want to remove it because it causes an exception in JAX. Why does JAX throw an unfiltered stack trace here? I guess I could convert the dataset into a pandas DataFrame, but is there a better way?
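One option that stays inside `tf.data` is to map over the dataset and rebuild each example dict without the `id` key. The object returned by `tfds.load()` is a `tf.data.Dataset`, so `Dataset.map` applies to it directly; alternatively, `tfds.load(..., as_supervised=True)` yields `(image, label)` tuples and never includes `id`. The sketch below is a minimal illustration that does not download anything: `make_fake_cifar10` and `drop_id` are hypothetical helpers, and the fake dataset is only assumed to mirror the element spec shown above. After dropping `id`, `as_numpy_iterator()` yields dicts of NumPy arrays, which JAX can consume.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the tfds CIFAR10 split: a tiny dataset with the
# same structure ('id' string, 32x32x3 uint8 image, int64 label).
def make_fake_cifar10(n=4):
    return tf.data.Dataset.from_tensor_slices({
        "id": tf.constant([f"train_{i}" for i in range(n)]),
        "image": tf.zeros((n, 32, 32, 3), dtype=tf.uint8),
        "label": tf.range(n, dtype=tf.int64),
    })

def drop_id(example):
    # Keep every feature except the string 'id' (JAX arrays have no string dtype).
    return {k: v for k, v in example.items() if k != "id"}

# The same .map(drop_id) call should work on the dataset returned by tfds.load().
ds = make_fake_cifar10().map(drop_id).batch(2)

for batch in ds.as_numpy_iterator():
    # Each batch is a dict of NumPy arrays: 'image' (2, 32, 32, 3) uint8, 'label' (2,) int64.
    images, labels = batch["image"], batch["label"]
    print(images.shape, labels.dtype)
```

Compared with converting to a pandas DataFrame, this keeps the pipeline lazy, so you can still chain `.shuffle()`, `.batch()`, and `.prefetch()` before handing the NumPy batches to JAX.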