
I'm having trouble understanding how the TensorFlow data API (tensorflow.data.Dataset) works. My input is a list of lists of integers that I want to batch, pad, and concatenate. For example, my data looks like this:

data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]

With batch size 3, each row should be split into batches and the last batch of each row padded with zeros:

[[[1, 2, 3], [4, 5, 6], [7, 0, 0]],
 [[1, 2, 3], [4, 0, 0]],
 [[1, 0, 0]]]

and finally, with the batches of all rows concatenated:

[[1, 2, 3], [4, 5, 6], [7, 0, 0],
 [1, 2, 3], [4, 0, 0], [1, 0, 0]]
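
To pin down the intended transformation independently of `tensorflow.data`, here is a minimal plain-Python sketch of it. The function name `chunk_and_pad` and its parameters `size` and `pad_value` are hypothetical names chosen for illustration, not part of any library; it only describes the desired result, not how to get it with the data API.

```python
def chunk_and_pad(rows, size, pad_value=0):
    """Split each row into chunks of `size`, pad short chunks, concatenate."""
    out = []
    for row in rows:
        for start in range(0, len(row), size):
            chunk = row[start:start + size]
            # Right-pad the last chunk of a row up to `size` elements.
            out.append(chunk + [pad_value] * (size - len(chunk)))
    return out


data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]

print(chunk_and_pad(data, 3))
```

Any `tensorflow.data` pipeline that solves the problem should yield the same six rows, in the same order, as this reference.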
Björn Lindqvist
  • What is your question? – Lukasz Tracewski Jun 06 '20 at 11:47
  • @LukaszTracewski How do you transform the data from the input format to the output format using the `tensorflow.data` API. – Björn Lindqvist Jun 06 '20 at 15:29
  • Sorry, somehow I did not get it the first time. Is this how you get the data in, as such lists? I am asking since likely you will have to process the data outside TF to get the desired shape (rectangular data). – Lukasz Tracewski Jun 06 '20 at 16:16
  • @LukaszTracewski The data is in a more complicated format but to simplify the question I've described it as lists of lists. I hope the tools in `tensorflow.data` should be enough to transform it into the structure I want. – Björn Lindqvist Jun 06 '20 at 16:23
  • What you have there can be converted to a `RaggedTensor`, which is not supported by `padded_batch`. IMO you're out of luck. – Lukasz Tracewski Jun 07 '20 at 04:47

1 Answer


It wasn't easy, but I finally got it to work:

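import tensorflow as tf

Dataset = tf.data.Dataset

def batch_each(x):
    # Split one variable-length row into batches of at most 3 elements.
    return Dataset.from_tensor_slices(x).batch(3)

data = [[1, 2, 3, 4, 5, 6, 7],
        [1, 2, 3, 4],
        [1]]
# The rows have different lengths, so build a RaggedTensor from them.
rt = tf.ragged.constant(data)
# flat_map concatenates the per-row batches; padded_batch with a batch
# size of 1 pads each batch to length 3, and unbatch removes the extra
# leading dimension again.
ds = Dataset \
    .from_tensor_slices(rt) \
    .flat_map(batch_each) \
    .padded_batch(1, padded_shapes=(3,)) \
    .unbatch()
for e in ds:
    print(e)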
Björn Lindqvist