My dataset's __getitem__ function returns the output of torch.stft(): an M x N x D tensor, where N depends on the length of the audio input and therefore varies between items. Each item is read inside __getitem__. I would like batches to be concatenated along the second dimension (N), so that iterating over the DataLoader yields data shaped M x (N x batch_size) x D.
Is there a way to do this?
1 Answer
You can do this with a custom collate function, passed to the DataLoader:
import torch
from torch.utils.data import DataLoader

M = 20
D = 12
N = 30
a = torch.rand((M, N, D))
b = torch.rand((M, N, D))

def my_collate(batch):
    # stack adds a batch axis at dim 1: M x B x N x D (all items must share the same N)
    c = torch.stack(batch, dim=1)
    # swap the batch and N axes: M x N x B x D
    return c.permute(0, 2, 1, 3)

c = my_collate([a, b])  # output shape MxNxBxD -> torch.Size([20, 30, 2, 12])
Then pass it to the DataLoader:
loader = DataLoader(dataset=datasetObject, batch_size=1, collate_fn=my_collate)
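
Note that torch.stack needs every item in the batch to have the same N, so my_collate as written would raise an error on variable-length items. A minimal sketch of a variant that may fit the question more closely, assuming the goal is plain concatenation along dim 1: torch.cat only requires M and D to match, and gives M x (N_1 + ... + N_B) x D. The names ToySpectrograms and cat_collate are hypothetical, and the random tensors stand in for real torch.stft() output.

```python
import torch
from torch.utils.data import DataLoader, Dataset

M, D = 20, 12  # same illustrative sizes as above


class ToySpectrograms(Dataset):
    """Hypothetical stand-in dataset: item i is an M x N_i x D tensor with varying N_i."""

    def __init__(self, lengths):
        self.lengths = lengths

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        # Random data in place of a real torch.stft() result
        return torch.rand(M, self.lengths[idx], D)


def cat_collate(batch):
    # Record each item's N so the batch can be split apart again later
    lengths = torch.tensor([item.shape[1] for item in batch])
    # Concatenate along dim 1: only M and D have to match across items
    return torch.cat(batch, dim=1), lengths


loader = DataLoader(ToySpectrograms([30, 45, 25, 50]), batch_size=2, collate_fn=cat_collate)
shapes = [(tuple(x.shape), lengths.tolist()) for x, lengths in loader]
```

If the individual items are needed again downstream, torch.split(x, lengths.tolist(), dim=1) should recover them from a concatenated batch.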

nickyfot