
I have a 2D RaggedTensor that holds, for each row of a dense tensor, the column indices I want to take from that row, e.g.:

[
    [0,4],
    [1,2,3],
    [5]
]

applied to

[
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
]

gives

[
    [200,20],
    [315,401,20],
    [105]
]

How can I do this as efficiently as possible (preferably using only tf functions)? I believe functions like tf.gather_nd accept RaggedTensors, but I cannot figure out how to use them here.

AAC

1 Answer


You can use tf.gather with the batch_dims keyword argument. With batch_dims=1, row i of indices selects elements from row i of tensor:

>>> tf.gather(tensor,indices,batch_dims=1)
<tf.RaggedTensor [[200, 20], [315, 401, 20], [105]]>
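A self-contained version of the snippet above, assuming TensorFlow 2.x in eager mode. For comparison it also builds the same result with tf.gather_nd, since the question asks about it: each flat index is paired with its row id (from value_rowids()) to form (row, column) coordinates, and the gathered values are regrouped with the original row_splits. This second part is only a sketch of what batch_dims=1 computes, not necessarily how TensorFlow implements it internally.

```python
import tensorflow as tf

tensor = tf.constant([
    [200, 305, 400, 20, 20, 105],
    [200, 315, 401, 20, 20, 167],
    [200, 7, 402, 20, 20, 105],
])
indices = tf.ragged.constant([[0, 4], [1, 2, 3], [5]])

# batch_dims=1: row i of `indices` picks columns from row i of `tensor`.
result = tf.gather(tensor, indices, batch_dims=1)

# Equivalent with tf.gather_nd: build explicit (row, column) coordinates.
row_ids = indices.value_rowids()  # row id of every flat index, int64
cols = tf.cast(indices.flat_values, row_ids.dtype)  # match dtypes for tf.stack
coords = tf.stack([row_ids, cols], axis=1)  # shape [num_indices, 2]
flat = tf.gather_nd(tensor, coords)  # 1-D tensor of gathered values
# Regroup the flat values using the same row partition as `indices`.
result_nd = tf.RaggedTensor.from_row_splits(flat, indices.row_splits)

print(result.to_list())     # [[200, 20], [315, 401, 20], [105]]
print(result_nd.to_list())  # [[200, 20], [315, 401, 20], [105]]
```

The tf.gather call is the simpler choice; the gather_nd variant is mainly useful to see that the ragged result is just a flat gather plus the original row partition.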
Lescurel