My dataset's __getitem__ function returns a torch.stft() tensor of shape M x N x D, where N is the length of the audio input series and varies between items. Each item is read inside __getitem__. I would like batches to be concatenated along the second dimension (N), so that iterating over the DataLoader yields data shaped M x (N x batch_size) x D. Is there a way to do this?

nickyfot
matlio

1 Answer

You can do this with a custom collate function, passed to the DataLoader:

import torch
from torch.utils.data import DataLoader

M = 20
D = 12
N = 30
a = torch.rand((M,N,D))
b = torch.rand((M,N,D))

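def my_collate(batch):
    # batch is a list of M x N x D tensors; torch.stack requires them all to have the same N
    c = torch.stack(batch, dim=1)   # M x B x N x D
    return c.permute(0, 2, 1, 3)    # M x N x B x D

c = my_collate([a, b])  # output shape M x N x B x D -> torch.Size([20, 30, 2, 12])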

Then pass the function to the DataLoader through collate_fn (datasetObject stands for your own dataset instance):

loader = DataLoader(dataset=datasetObject, batch_size=1, collate_fn=my_collate)
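If the items really do differ in N, torch.stack will raise an error because it needs equally sized tensors. A minimal sketch of an alternative, assuming you want the M x (sum of N) x D layout from the question: concatenate along dim=1 with torch.cat instead. ToySTFTDataset and cat_collate below are hypothetical names used only for illustration, and the random tensors stand in for real torch.stft() output.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySTFTDataset(Dataset):
    """Hypothetical stand-in: item i is an M x N_i x D tensor with its own N_i."""

    def __init__(self, lengths, M=20, D=12):
        self.lengths = lengths
        self.M = M
        self.D = D

    def __len__(self):
        return len(self.lengths)

    def __getitem__(self, idx):
        # Random data in place of a real torch.stft() result
        return torch.rand(self.M, self.lengths[idx], self.D)


def cat_collate(batch):
    # batch is a list of M x N_i x D tensors; only dim 1 may differ between them
    return torch.cat(batch, dim=1)  # M x sum(N_i) x D


dataset = ToySTFTDataset(lengths=[30, 45, 10, 25])
loader = DataLoader(dataset, batch_size=2, collate_fn=cat_collate)

# shuffle defaults to False, so the batches are items (0, 1) and (2, 3)
shapes = [tuple(x.shape) for x in loader]
print(shapes)  # [(20, 75, 12), (20, 35, 12)]
```

Note that after concatenation the boundaries between items are no longer visible in the tensor, so if you need them later you would have to return the individual lengths from the collate function as well.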
nickyfot