I have a torch tensor of the following form:

a = torch.tensor([[[2,3],
                   [3,4],
                   [4,5]],
                  [[3,6],
                   [6,2],
                   [2,-1]],
                  [[float('nan'), 1],
                   [2,3],
                   [3,2]]])

I would like to return another tensor with the NaN removed, and with every entry that shares its index along dimension 1 removed as well, i.e. that row is dropped from every block. So I expect:

a_clean = torch.tensor([[[3,4],
                         [4,5]],
                        [[6,2],
                         [2,-1]],
                        [[2,3],
                         [3,2]]])

Any ideas on how to achieve this?

Brian61354270
konstant

1 Answer

This can be accomplished using Tensor.isnan, Tensor.any, and some creative indexing:

>>> a[:, ~a.isnan().any(dim=2).any(dim=0), :]
tensor([[[ 3.,  4.],
         [ 4.,  5.]],
        [[ 6.,  2.],
         [ 2., -1.]],
        [[ 2.,  3.],
         [ 3.,  2.]]])
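Splitting the one-liner into named steps makes the intermediate shapes visible. This is a sketch that reuses the tensor from the question; the variable names (`nan_mask`, `row_has_nan`, `idx_has_nan`, `keep`) are illustrative only.

```python
import torch

a = torch.tensor([[[2, 3], [3, 4], [4, 5]],
                  [[3, 6], [6, 2], [2, -1]],
                  [[float('nan'), 1], [2, 3], [3, 2]]])  # shape (3, 3, 2)

nan_mask = a.isnan()                  # (3, 3, 2): True where an element is NaN
row_has_nan = nan_mask.any(dim=2)     # (3, 3): True where a row holds a NaN
idx_has_nan = row_has_nan.any(dim=0)  # (3,): True where index i of dim 1 holds a NaN in any block
keep = ~idx_has_nan                   # (3,): [False, True, True]

a_clean = a[:, keep, :]               # boolean mask on dim 1 gives shape (3, 2, 2)
print(a_clean)
```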

Note that you're trying to remove entries along dimension 1, so the indexing takes place in dimension 1. Reducing the result of isnan() with any() across every other dimension (dimension 2, then dimension 0) yields a boolean mask over dimension 1 that is True wherever that index contains a NaN in any block. Negating that mask with ~ and using it as the index for dimension 1 gives the expression above.
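If you need the same operation along a different dimension, or on a tensor of a different rank, the pattern generalizes: reduce isnan() over every dimension except the target one, then keep only the indices with no NaN. The helper below is a sketch; `drop_nan_slices` is a hypothetical name, not a PyTorch function, and it assumes a PyTorch version that provides `Tensor.movedim`. It moves the target dimension to the front and flattens the rest, so a single `any(dim=1)` call does the reduction without depending on multi-dimension support in `any()`.

```python
import torch


def drop_nan_slices(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: drop every index along `dim` whose slice contains a NaN."""
    dim = dim % t.dim()  # normalise negative dims, e.g. -1 -> last dimension
    # Put `dim` first, then flatten all other dimensions into one.
    moved = t.movedim(dim, 0)
    has_nan = moved.isnan().reshape(moved.shape[0], -1).any(dim=1)
    # Integer indices of the slices that contain no NaN at all.
    keep = (~has_nan).nonzero(as_tuple=True)[0]
    return t.index_select(dim, keep)


a = torch.tensor([[[2, 3], [3, 4], [4, 5]],
                  [[3, 6], [6, 2], [2, -1]],
                  [[float('nan'), 1], [2, 3], [3, 2]]])

a_clean = drop_nan_slices(a, dim=1)  # same result as the one-liner above
print(a_clean)
```

With `dim=0` the same call drops the whole third block instead, and with `dim=-1` it drops the first column of every row.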

Brian61354270