0

I'm porting PyTorch code to Flashlight. Is there an ArrayFire or Flashlight equivalent of PyTorch's squeeze and unsqueeze?

processed_query = self.query_layer(query.unsqueeze(1))

energies = energies.squeeze(-1)

How can I convert this to ArrayFire (or Flashlight) code?

minty99

1 Answer

1

You can do this with the af::moddims function, which reinterprets an array's dimensions, provided the total number of elements stays the same:

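af::array a = af::randu(10, 1, 10, 10);
af::array squeezed_a = af::moddims(a, 10, 10, 10);               // squeeze: drop the size-1 axis
af::array unsqueezed_a = af::moddims(squeezed_a, 10, 1, 10, 10); // unsqueeze: reinsert it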
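
If you don't want to hard-code the shape each time, the target dimensions can be computed from the current ones. Here is a minimal sketch using only the standard library, assuming the usual ArrayFire convention that every array has four dimensions with unused trailing ones set to 1. The names Dims, squeeze_dims and unsqueeze_dims are hypothetical helpers made up for this illustration; they are not part of ArrayFire or Flashlight. Axes are zero-based and non-negative, so a PyTorch axis such as -1 has to be resolved by the caller first.

```cpp
#include <array>
#include <stdexcept>

// ArrayFire arrays always carry four dimensions; unused trailing ones are 1.
using Dims = std::array<long long, 4>;

// Hypothetical helper: shape after removing axis `axis` if it has size 1.
// Like PyTorch's squeeze(dim), a non-unit axis is left untouched.
Dims squeeze_dims(Dims d, int axis) {
    if (axis < 0 || axis > 3) throw std::out_of_range("axis must be in [0, 3]");
    if (d[axis] != 1) return d;
    for (int i = axis; i < 3; ++i) d[i] = d[i + 1];  // shift later axes left
    d[3] = 1;                                        // pad the freed slot
    return d;
}

// Hypothetical helper: shape after inserting a size-1 axis at `axis`.
// Needs a free trailing axis, since the shape is capped at four dimensions.
Dims unsqueeze_dims(Dims d, int axis) {
    if (axis < 0 || axis > 3) throw std::out_of_range("axis must be in [0, 3]");
    if (d[3] != 1) throw std::invalid_argument("no free axis left to insert into");
    for (int i = 3; i > axis; --i) d[i] = d[i - 1];  // shift later axes right
    d[axis] = 1;
    return d;
}

// Intended use with ArrayFire (not compiled as part of this sketch):
//   af::dim4 s = a.dims();
//   Dims t = squeeze_dims({s[0], s[1], s[2], s[3]}, 1);
//   af::array b = af::moddims(a, af::dim4(t[0], t[1], t[2], t[3]));
```

Adding or removing a size-1 axis never changes the element count, so the result is always a valid shape for af::moddims.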
Umar Arshad