I have a PyTorch tensor of the following form:
import torch

a = torch.tensor([[[2, 3],
                   [3, 4],
                   [4, 5]],
                  [[3, 6],
                   [6, 2],
                   [2, -1]],
                  [[float('nan'), 1],
                   [2, 3],
                   [3, 2]]])
I would like to return another tensor with the NaN removed, and with all the other entries at the same index along that dimension removed too (here, the first row of every 3x2 block). So I expect:
a_clean = torch.tensor([[[3, 4],
                         [4, 5]],
                        [[6, 2],
                         [2, -1]],
                        [[2, 3],
                         [3, 2]]])
Any ideas on how to achieve this?
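One possible approach, sketched here under the assumption (taken from the expected output) that "the same dimension" means the same row index along dim 1 in every 2-D block: build a boolean mask over dim 1 that is `True` wherever any block has a NaN in that row, then index with its negation. It uses only `torch.isnan`, `Tensor.any(dim=...)` and boolean-mask indexing; the name `nan_rows` is just illustrative.

```python
import torch

a = torch.tensor([[[2, 3],
                   [3, 4],
                   [4, 5]],
                  [[3, 6],
                   [6, 2],
                   [2, -1]],
                  [[float('nan'), 1],
                   [2, 3],
                   [3, 2]]])

# Reduce over columns (dim 2), then over blocks (dim 0): the result has
# shape (3,) and is True for every row index that holds a NaN anywhere.
nan_rows = torch.isnan(a).any(dim=2).any(dim=0)

# Boolean-mask indexing on dim 1 drops those row indices from every block at once.
a_clean = a[:, ~nan_rows]

print(a_clean.shape)  # torch.Size([3, 2, 2])
```

Because the original tensor contains `float('nan')`, `a_clean` is a floating-point tensor. If the intent were instead to drop the affected column (dim 2), the same idea works by reducing over dims 1 and 0 and indexing with `a[:, :, ~mask]`.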