4

I am a fresh starter with PyTorch. Strangely I cannot find anything related to this, although it seems rather simple.

I want to structure my batches with specific examples, like all examples in a batch having the same label, or a batch being filled with examples from just 2 classes.

How would I do that? To me, the data loader seems like the right place for this rather than the dataset, since the data loader is responsible for the batches and the dataset is not. Is that correct?

Is there a simple minimal example?

  • You might need to look into custom samplers; they're basically an intermediate layer between the data loader and the dataset, which is where that kind of logic seems to fit in. – Ivan Feb 05 '21 at 15:16
  • I forgot to say I use an iterable dataset. –  Feb 05 '21 at 15:21
  • Thanks. I think this is what I need. Do you know if by default the dataloader uses both the sampler and a batch sampler? So, in principle, I have to customize both and not just a sampler. Anyway, why do I need a batch sampler at all? Doesn't the sampler alone yield batches? –  Feb 05 '21 at 15:54

1 Answer

17

TLDR;

  1. Default DataLoader only uses a sampler, not a batch sampler.

  2. You can define either a sampler or a batch sampler; providing a batch sampler overrides the sampler-related options.

  3. The sampler only yields the sequence of dataset elements, not the actual batches (this is handled by the data loader, depending on batch_size).


To answer your initial question: working with a sampler on an iterable-style dataset doesn't seem to be possible, cf. this GitHub issue (still open). Also see the related note in pytorch/dataloader.py.
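
For illustration, here is a minimal sketch of what happens if you try anyway. The ItDS class is just a made-up example, and the exact error message may differ between PyTorch versions:

from torch.utils.data import DataLoader, IterableDataset, SequentialSampler

class ItDS(IterableDataset):
    # hypothetical iterable-style dataset: no __getitem__, no known length
    def __iter__(self):
        return iter(range(10))

ds = ItDS()
DataLoader(ds, batch_size=2)                          # fine: plain batching still works
DataLoader(ds, sampler=SequentialSampler(range(10)))  # should raise a ValueError: samplers are not supported with an IterableDataset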


Samplers (for map-style datasets):

That aside, if you are switching to a map-style dataset, here are some details on how samplers and batch samplers work. You have access to a dataset's underlying data using indices, just like you would with a list (since map-style datasets implement __getitem__). In other words, your dataset elements are all dataset[i], for i in [0, len(dataset) - 1].

Here is a toy dataset:

from torch.utils.data import Dataset

class DS(Dataset):
    # toy map-style dataset: element at index i is simply i * 10
    def __getitem__(self, index):
        return index * 10

    def __len__(self):
        return 10

In a general use case you would just give torch.utils.data.DataLoader the arguments batch_size and shuffle. By default, shuffle is set to False, which means torch.utils.data.SequentialSampler is used. Otherwise (if shuffle is True), torch.utils.data.RandomSampler is used. The sampler defines how the data loader accesses the dataset (i.e. in which order it accesses it).

The above dataset (DS) has 10 elements. The indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, and 9. They map to elements 0, 10, 20, 30, 40, 50, 60, 70, 80, and 90. So with a batch size of 2:

  • SequentialSampler: DataLoader(ds, batch_size=2) (implicitly shuffle=False), identical to DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds)). The dataloader will deliver tensor([0, 10]), tensor([20, 30]), tensor([40, 50]), tensor([60, 70]), and tensor([80, 90]).

  • RandomSampler: DataLoader(ds, batch_size=2, shuffle=True), identical to DataLoader(ds, batch_size=2, sampler=RandomSampler(ds)). The dataloader will sample randomly each time you iterate through it. For instance: tensor([50, 40]), tensor([90, 80]), tensor([0, 60]), tensor([10, 20]), and tensor([30, 70]). But the sequence will be different if you iterate through the dataloader a second time!
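
To make this concrete, here is a small runnable sketch using the toy DS above (the shuffled output will of course differ from run to run):

from torch.utils.data import DataLoader, SequentialSampler, RandomSampler

ds = DS()

# sequential access: tensor([0, 10]), tensor([20, 30]), ...
print(list(DataLoader(ds, batch_size=2)))
print(list(DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds))))  # equivalent

# random access: a new permutation every time you iterate
print(list(DataLoader(ds, batch_size=2, shuffle=True)))
print(list(DataLoader(ds, batch_size=2, sampler=RandomSampler(ds))))      # equivalent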


Batch sampler

Providing batch_sampler will override batch_size, shuffle, sampler, and drop_last altogether. It is meant to define exactly the batch elements and their content. For instance:

>>> DataLoader(ds, batch_sampler=[[1,2,3], [6,5,4], [7,8], [0,9]])

Will yield tensor([10, 20, 30]), tensor([60, 50, 40]), tensor([70, 80]), and tensor([ 0, 90]).
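
Instead of hard-coding the index lists, you can also build a batch sampler from a regular sampler with torch.utils.data.BatchSampler, which makes the sampler / batch sampler relationship explicit. A sketch with the same toy DS:

from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

ds = DS()
bs = BatchSampler(SequentialSampler(ds), batch_size=3, drop_last=False)

print(list(bs))                                # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
print(list(DataLoader(ds, batch_sampler=bs)))  # tensor([ 0, 10, 20]), ..., tensor([90])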


Batch sampling by class

Let's say you just want to have 2 elements (different or not) of each class in your batch and exclude any further examples of each class, ensuring that no 3 examples of the same class end up in the same batch.

Let's say you have a dataset with four classes. Here is how I would do it. First, keep track of dataset indices for each class.

from torch.utils.data import Dataset

class DS(Dataset):
    def __init__(self, data):
        super(DS, self).__init__()
        self.data = data

        # four classes: positive odd, positive even, negative odd, negative even
        self.indices = [[] for _ in range(4)]
        for i, x in enumerate(data):
            if x > 0 and x % 2: self.indices[0].append(i)
            if x > 0 and not x % 2: self.indices[1].append(i)
            if x < 0 and x % 2: self.indices[2].append(i)
            if x < 0 and not x % 2: self.indices[3].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

For example:

>>> ds = DS([1, 6, 7, -5, 10, -6, 8, 6, 1, -3, 9, -21, -13, 11, -2, -4, -21, 4])

Will give:

>>> ds.classes()
[[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]

Then, for the batch sampler, the easiest way is to create a list of available class indices, repeating each class index as many times as there are dataset elements in that class.

In the dataset defined above, we have 5 items from class 0, 5 from class 1, 5 from class 2, and 3 from class 3. Therefore we want to construct [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3]. We will shuffle it. Then, from this list and the dataset classes content (ds.classes()) we will be able to construct the batches.

import copy
import random

# small helper to flatten a list of lists (this is not torch.flatten)
def flatten(lst):
    return [x for sub in lst for x in sub]

class Sampler():
    def __init__(self, classes):
        self.classes = classes

    def __iter__(self):
        # deep copy: we pop elements from the per-class index lists below
        classes = copy.deepcopy(self.classes)
        # one entry per dataset element, holding that element's class index
        indices = flatten([[i for _ in range(len(klass))] for i, klass in enumerate(classes)])
        random.shuffle(indices)
        # group the shuffled class indices into pairs (batch_size = 2)
        grouped = zip(*[iter(indices)]*2)

        res = []
        for a, b in grouped:
            # take one dataset index from each of the two drawn classes
            res.append((classes[a].pop(), classes[b].pop()))
        return iter(res)

Note - deep copying the list is required since we're popping elements from it.

A possible output of this sampler would be:

[(15, 14), (16, 17), (7, 12), (11, 6), (13, 10), (5, 4), (9, 8), (2, 0), (3, 1)]

At this point we can simply use torch.utils.data.DataLoader:

>>> dl = DataLoader(ds, batch_sampler=Sampler(ds.classes()))

Which could yield something like:

[tensor([ 4, -4]), tensor([-21,  11]), tensor([-13,   6]), tensor([9, 1]), tensor([  8, -21]), tensor([-3, 10]), tensor([ 6, -2]), tensor([-5,  7]), tensor([-6,  1])]

An easier approach

Here is another, easier, approach. It does not guarantee that all elements of the dataset are returned, though on average most of them will be.

First sample class_per_batch classes, then build each batch by drawing batch_size elements from those selected classes (for each element: first pick a class from that class subset, then pick a data point from that class).

import random

class Sampler():
    def __init__(self, classes, class_per_batch, batch_size):
        self.classes = classes
        self.n_batches = sum([len(x) for x in classes]) // batch_size
        self.class_per_batch = class_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        # draw the subset of classes to sample from
        classes = random.sample(range(len(self.classes)), self.class_per_batch)

        batches = []
        for _ in range(self.n_batches):
            batch = []
            for i in range(self.batch_size):
                # pick a class from the subset, then a dataset index from that class
                klass = random.choice(classes)
                batch.append(random.choice(self.classes[klass]))
            batches.append(batch)
        return iter(batches)

You can try it this way:

>>> s = Sampler(ds.classes(), class_per_batch=2, batch_size=4)
>>> list(s)
[[16, 0, 0, 9], [10, 8, 11, 2], [16, 9, 16, 8], [2, 9, 2, 3]]

>>> dl = DataLoader(ds, batch_sampler=s)
>>> list(iter(dl))
[tensor([ -5,  -6, -21, -13]), tensor([ -4,  -4, -13, -13]), tensor([ -3, -21,  -2,  -5]), tensor([-3, -5, -4, -6])]
Ivan
  • Thanks! I'm starting to get it! I do not yet know whether I should use a simple Sampler or a BatchSampler for my case: let's say I just want to have 2 elements (different or not) of each class in my batch and exclude further examples of each class, so ensuring that no 3 examples of a class are inside the batch... –  Feb 05 '21 at 18:05
  • I have updated my answer, hope this helps. – Ivan Feb 05 '21 at 19:49
  • 1
    Sorry, but where and what is the flatten() function? The flatten from torch does not work... –  Feb 08 '21 at 06:42
  • It would be nice if you could explain your batch sampler example at the end. What do you want to achieve with these examples? Like, why do I get pairs out of the batches? –  Feb 08 '21 at 06:45
  • The objective was to sample data points from <=2 classes at a time (*per batch*). Here the example I gave is with `batch_size=2` and a maximum number of different classes per batch `= 2`. – Ivan Feb 08 '21 at 07:23
  • Is this now a simple sampler or a batch sampler? That is confusing... Where does the batch size actually go in the sampler then? I mean, how does the sampler know when to stop filling the batches? That is needed if I want to control the specific number of elements in my batch, right? –  Feb 08 '21 at 07:29
  • It's a batch sampler: `DataLoader(ds, batch_sampler=Sampler(ds.classes()))`. It generates the whole data-point index sequence and provides it as an iterator to the dataloader. The general idea is this: sample from the classes first (draw as many class indices as you need; in my example above I took `2`, *cf.* `grouped = zip(*[iter(indices)]*2)`). Then you can draw as many elements (*i.e.* `batch_size`) from that pool of classes. Here I took one sample from each of the two classes. – Ivan Feb 08 '21 at 07:37
  • How can I check what the sampler is doing? If I do: for sample_batched in enumerate(dl): print(i_batch, sample_batched) –  Feb 08 '21 at 07:40
  • I cannot see anything –  Feb 08 '21 at 07:40
  • I thought this 2 is my batch_size? Why do you want tuples? How would I incorporate the batch_size in this for loop? –  Feb 08 '21 at 07:42
  • This line `classes[a].pop(), classes[b].pop()` takes one sample from each class. If you want to extend this implementation for an arbitrary `batch_size`, you should sample multiple elements from those groups of index classes. – Ivan Feb 08 '21 at 07:52
  • So this 2 in the zip() is actually the batch size? Can I make it an argument like in the dataloader? –  Feb 08 '21 at 08:05
  • Sorry, but if I do data = next(data_iterator) then I get a single batch and it has 2 elements, so this 2 is my batch size? Or how else would I specify this? –  Feb 08 '21 at 08:11
  • OK, that is really magic I do not understand. Why do I get 2 elements per tuple in res, if I append classes[a] etc., which are lists of length greater than 2... That is really strange –  Feb 08 '21 at 08:23
  • Thanks. I think this zip() and pop() is only for your showcase, right? My actual need is then: iterate through each class and pick (in the same or a random order) 1 example and put it into the batch. The next batches should then pick other elements, not the same ones, such that all elements are unique across the batches. If I have more batches then I start again and also pick duplicates... –  Feb 08 '21 at 08:54
  • Yes, my initial implementation (the one with `zip`...) made sure elements did not appear twice in a single epoch. While the latter implementation does not. This is why I recommend you make an implementation from the former version for arbitrary `batch_size` and `class_per_batch` parameters. – Ivan Feb 08 '21 at 09:17
  • Thanks, this looks like what I want and can be easily customized, like not using sample, and excluding multiple examples for each class. So, a sampler needs to already yield the batches as an iterator? –  Feb 08 '21 at 09:34
  • Could you please rephrase that? I'm not sure I understand what you mean. – Ivan Feb 08 '21 at 09:44
  • In practice I have unequal numbers of elements per class and want to avoid the same class appearing twice per batch as much as possible. So for 30 classes the batch should not be larger than 30. I think for me it is no problem if a sample appears multiple times, but across different batches (if the class has fewer examples). Because for a class with fewer examples, I have to reuse some elements to reach batch_size... –  Feb 08 '21 at 09:47
  • So in a sampler I need to define the batches for all data points? So I have to know the total number of batches beforehand? I cannot somehow generate single batches on the fly? –  Feb 08 '21 at 11:57
  • The only way you can use a batch sampler (or sampler for that matter) is if you use a map-style dataset (which has a length) and **not** an iterable dataset (*i.e.* one whose length is unknown). – Ivan Feb 08 '21 at 11:59
  • What do you say about my 2nd last comment about different number of samples per class? –  Feb 08 '21 at 12:31
  • If you don't want to have more than one sample from the same class in each batch, then simply draw `batch_size` classes from the list of classes, then from each sampled class draw a single sample. As a result, you will get `batch_size` elements per batch coming from `batch_size` distinct classes. – Ivan Feb 08 '21 at 12:55
  • Yes, if I always have batch_size classes. For more classes than batch_size I think I can handle this in training: you just ignore some classes in that particular batch when training your model. But I cannot use a batch size larger than the number of classes... –  Feb 08 '21 at 13:18
  • So, now I get an error because there is no __len__. What should this len method return? The number of data points or the number of batches? –  Feb 09 '21 at 12:06
  • Assuming you are referring to the `__len__` function inside your dataset, then it refers to the total number of data points. – Ivan Feb 09 '21 at 13:04
  • No, the sampler also has a len method. But I think it is just the number of batches –  Feb 09 '21 at 13:17