Hi, is there any way to apply a transformation to only certain batches?
Specifically, I want to apply a transformation to just the last batch of every epoch.
Here is what I tried:
import torch

class test(torch.utils.data.Dataset):
    def __init__(self):
        self.source = [i for i in range(10)]

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        print(idx)
        return self.source[idx]

ds = test()
dl = torch.utils.data.DataLoader(dataset = ds, batch_size = 3,
                                 shuffle = False, num_workers = 5)
for i in dl:
    print(i)
I thought that if I could see the idx values in order, I could apply the transformation to specific batches.
However, with num_workers = 5 the output is:
0
1
2
3
964
57
8
tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])
which is not what I expected.
Without num_workers (i.e. num_workers = 0) the output is:
0
1
2
tensor([0, 1, 2])
3
4
5
tensor([3, 4, 5])
6
7
8
tensor([6, 7, 8])
9
tensor([9])
So my questions are:
- Why are the idx values printed in this order when num_workers is set?
- How can I apply a transformation to certain batches (or certain idx values)?
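
One possible direction, offered as a rough sketch rather than a definitive answer: with num_workers > 0, each batch is loaded by a separate worker process running in parallel, so the print(idx) calls from different workers interleave, while the batches themselves still arrive in order (as the tensors in the output above show). That suggests deciding on the batch position in the loop that consumes the DataLoader instead of inside __getitem__. In the sketch below, RangeDataset, last_batch_transform and iterate_with_last_batch_transform are hypothetical names made up for illustration, the transform (multiply by 10) is a placeholder, and num_workers = 0 is used only to keep the example self-contained.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Same toy data as in the question: the integers 0..9."""

    def __init__(self):
        self.source = list(range(10))

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        return self.source[idx]


def last_batch_transform(batch):
    # Hypothetical placeholder transform, for illustration only.
    return batch * 10


def iterate_with_last_batch_transform(loader, transform):
    """Collect all batches of one epoch, transforming only the last one."""
    last_batch_idx = len(loader) - 1  # len(loader) is the number of batches
    batches = []
    for batch_idx, batch in enumerate(loader):
        if batch_idx == last_batch_idx:
            batch = transform(batch)
        batches.append(batch)
    return batches


# Assumption: num_workers=0 keeps the sketch simple; the batch position
# check does not depend on how many workers load the data.
loader = DataLoader(RangeDataset(), batch_size=3, shuffle=False, num_workers=0)
batches = iterate_with_last_batch_transform(loader, last_batch_transform)
for b in batches:
    print(b)
```

Because the check uses the batch position in the consuming loop rather than the idx seen by a worker, it does not rely on the order in which workers happen to call __getitem__. In a real training loop the same `if batch_idx == len(loader) - 1` test could sit directly inside the per-epoch loop.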