I know that usually the batch dimension is axis zero, and I imagine this has a reason: The underlying memory for each item in the batch is contiguous.
My model calls a function that becomes simpler if I have another dimension on the first axis, so that I can use x[k] instead of x[:, k].
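For context, a minimal check of the contiguity point: with the batch on axis 0 of a contiguous tensor, each item x[k] is one contiguous block, and after transposing so the batch sits on axis 1 the same block is simply addressed as x[:, k]:
import torch

x0 = torch.ones(2, 3, 4)          # batch of size 2 on axis 0
print(x0[0].is_contiguous())      # True: each batch item is one contiguous block
x1 = x0.transpose(0, 1)           # shape (3, 2, 4), batch now on axis 1
print(x1[:, 0].is_contiguous())   # still True: same block of memory, different indexing
print(x1.stride())                # (4, 12, 1): the batch axis keeps the largest stride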
Results from arithmetic operations seem to keep the same memory layout:
x = torch.ones(2, 3, 4).transpose(0, 1)    # shape (3, 2, 4), batch of size 2 on axis 1
y = torch.ones_like(x)                     # ones_like preserves x's layout
u = x + 1
v = x + y
print(x.stride(), u.stride(), v.stride())  # all three report the same strides, (4, 12, 1) here
When I create additional variables, I create them with torch.zeros and then transpose, so that the largest stride also ends up on axis 1. For example,
a, b, c = torch.zeros(
    (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1, 2)
will create three tensors with the same batch size x.shape[1].
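A quick sanity check of that layout, with ADDITIONAL_DIM set to an arbitrary placeholder value:
import torch

ADDITIONAL_DIM = 5                       # arbitrary placeholder size
x = torch.ones(2, 3, 4).transpose(0, 1)  # same x as above, shape (3, 2, 4)
a, b, c = torch.zeros(
    (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1, 2)
print(a.shape)                  # torch.Size([5, 2, 3, 4]) -> batch on axis 1
print(a.stride())               # (12, 60, 4, 1) -> largest stride on the batch axis
print(a[:, 0].is_contiguous())  # True: each batch item of a is one contiguous block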
In terms of memory locality, would it make any difference to have
a, b, c = torch.zeros(
    (x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1, 2, 0, *range(3, 4 + len(x.shape[2:])))  # remaining dims stay in place
instead.
Should I care about this at all?
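For reference, a minimal side-by-side of the two constructions (again with an arbitrary placeholder for ADDITIONAL_DIM). The shapes come out identical; what differs is the stride of the batch axis, i.e. how far apart consecutive batch items of each tensor sit in memory:
import torch

ADDITIONAL_DIM = 5                       # arbitrary placeholder size
x = torch.ones(2, 3, 4).transpose(0, 1)  # shape (3, 2, 4), batch on axis 1

# current version: a/b/c index first, batch second, then transpose
a1, b1, c1 = torch.zeros(
    (3, x.shape[1], ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).transpose(1, 2)

# alternative: batch first, then permute the a/b/c index to the front
a2, b2, c2 = torch.zeros(
    (x.shape[1], 3, ADDITIONAL_DIM, x.shape[0]) + x.shape[2:]
).permute(1, 2, 0, *range(3, 4 + len(x.shape[2:])))

print(a1.shape == a2.shape)  # True: torch.Size([5, 2, 3, 4]) in both cases
print(a1.stride())           # (12, 60, 4, 1): a1's batch items are adjacent in memory
print(a2.stride())           # (12, 180, 4, 1): a2's batch items are three times as far
                             # apart, interleaved with those of b2 and c2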