
What is the best way to perform a batched slice in PyTorch/NumPy? I know that when the slicing indices are constant, the operation can be done like this:

import torch

batch_size = 2
data = torch.zeros((batch_size, 1, 256, 256))
x_start = 10
x_stop = 20

y_start = 10
y_stop = 20
data[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1
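As a side note on the constant-bounds case, the index `torch.arange(batch_size)` on the batch axis is not actually required: a plain `:` applies the same window to every sample. Below is a minimal NumPy sketch of that, assuming the same `(batch, channels, height, width)` layout. The same slicing expression should also work unchanged on a torch tensor.

```python
import numpy as np

batch_size = 2
data = np.zeros((batch_size, 1, 256, 256))
y_start, y_stop = 10, 20
x_start, x_stop = 10, 20

# ':' on the batch and channel axes selects every sample and channel, so the
# same 10x10 window is set in all of them using only basic slicing.
data[:, :, y_start:y_stop, x_start:x_stop] = 1
```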

But how do I handle the case where the start and stop values are different for each sample? E.g.:

import torch

batch_size = 2
data = torch.zeros((batch_size, 1, 256, 256))
x_start = [10, 5]
x_stop  = [20, 30]

y_start = [10, 5]
y_stop = [20, 30]
data[torch.arange(batch_size), :, y_start:y_stop, x_start:x_stop] = 1  # crashes: a slice's start/stop must be a single integer, not a list

I guess I could do it in a for loop, but I was wondering if there is a more pythonic way to do it.
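For reference, the loop version mentioned above could look like the minimal sketch below. It assumes NumPy arrays shaped `(batch, channels, height, width)` and one bound per sample in each list; `fill_boxes_loop` is a hypothetical helper name, not a library function. Each iteration uses ordinary basic slicing, so the same body should also work on a torch tensor.

```python
import numpy as np


def fill_boxes_loop(data, y_start, y_stop, x_start, x_stop, value=1):
    """Hypothetical helper: for every sample i, set
    data[i, :, y_start[i]:y_stop[i], x_start[i]:x_stop[i]] = value.

    Assumes data has shape (batch, channels, height, width) and that each
    bound sequence has exactly one entry per sample.
    """
    for i, (y0, y1, x0, x1) in enumerate(zip(y_start, y_stop, x_start, x_stop)):
        # Plain slice assignment on one sample; modifies data in place.
        data[i, :, y0:y1, x0:x1] = value
    return data


data = np.zeros((2, 1, 256, 256))
fill_boxes_loop(data, y_start=[10, 5], y_stop=[20, 30], x_start=[10, 5], x_stop=[20, 30])
```

Each sample gets its own window (10x10 for the first, 25x25 for the second). Since every assignment is a cheap slice, the loop is usually fine when the batch is small.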

Asked by Alexus; edited by ndrwnaguib.
  • In your second example the differences between `x_start` / `x_stop` and `y_start` / `y_stop` values are not constant, so what would you expect the shape of the sliced array to be? In any case, I'm not sure I understand what you want. I thought you wanted different start/end indices for each batch item, but then the size of the lists should match `batch_size`. – jdehesa Oct 09 '20 at 14:41
  • Sorry, there was a typo in my demo code: batch_size should be equal to the length of x_start. Basically, for each sample in the batch I want to set a different square to 1. – Alexus Oct 09 '20 at 15:39
  • Ah, I see what you mean now. If the batch size is not big, I would probably consider doing it in a loop. The thing is, to do it "in one go", you will essentially have to generate arrays with the indices of all the positions you want to modify (which takes extra time and memory) and then use "advanced indexing" to actually set the values (more expensive than simple slicing). I'm talking about the NumPy case; I don't have experience with PyTorch, but I wouldn't expect it to be too different. – jdehesa Oct 09 '20 at 16:50
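A loop-free variant of the "do it in one go" idea is sketched below. Instead of building explicit index arrays, this swaps in a boolean mask built by broadcasting comparisons against `arange` grids. It is a minimal sketch, assuming NumPy arrays shaped `(batch, channels, height, width)` and non-negative bounds (negative indices do not wrap around here as they would in a Python slice). `fill_boxes_mask` is a hypothetical helper name. As the comment above points out, this trades the loop for an extra `batch x height x width` boolean array.

```python
import numpy as np


def fill_boxes_mask(data, y_start, y_stop, x_start, x_stop, value=1):
    """Hypothetical helper: set each sample's half-open window
    [y_start[i], y_stop[i]) x [x_start[i], x_stop[i]) to value,
    without a Python loop over the batch.

    Assumes data has shape (batch, channels, height, width) and
    non-negative bounds.
    """
    # Reshape per-sample bounds to (B, 1, 1) so they broadcast over H and W.
    y0 = np.asarray(y_start)[:, None, None]
    y1 = np.asarray(y_stop)[:, None, None]
    x0 = np.asarray(x_start)[:, None, None]
    x1 = np.asarray(x_stop)[:, None, None]

    _, _, height, width = data.shape
    rows = np.arange(height)[None, :, None]  # shape (1, H, 1)
    cols = np.arange(width)[None, None, :]   # shape (1, 1, W)

    # Shape (B, H, W): True exactly inside each sample's window.
    mask = (rows >= y0) & (rows < y1) & (cols >= x0) & (cols < x1)

    # Add the channel axis and broadcast to data's full shape, then use
    # boolean-mask assignment to write all windows at once.
    data[np.broadcast_to(mask[:, None], data.shape)] = value
    return data


data = np.zeros((2, 1, 256, 256))
fill_boxes_mask(data, y_start=[10, 5], y_stop=[20, 30], x_start=[10, 5], x_stop=[20, 30])
```

The same comparisons should carry over to PyTorch using `torch.arange` and broadcasting, with the final write done by something like `data.masked_fill_(mask[:, None], value)`. That translation is an untested assumption, so check it against your torch version.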

0 Answers