0

I would like to make a sampler for my dataloader. I have 12 unique classes in my dataset and it is really important that there is no more than one element of each class in each batch. It also doesn't matter how big the batch size is as long as this requirement is fulfilled. I've tried the weighted random sampler, but it still gives double elements in 40% of cases (with batch size = 4). This is what I have for the weighted sampler but I don't know where to go from here:

def get_targets(dataset):
    """
    Collect the 'patient' label of every sample in the dataset.

    Parameters
    ----------
    dataset : indexable collection
        Each item is a mapping with (at least) a 'patient' key.

    Returns
    -------
    list
        One label per sample, in dataset order.
    """
    return [dataset[idx]['patient'] for idx in range(len(dataset))]

def class_weights(target):
    """
    Map each unique class label to the inverse of its frequency.

    Parameters
    ----------
    target : sequence
        All labels in the dataset (one per sample).

    Returns
    -------
    dict
        label -> 1 / (number of samples with that label).
    """
    # Local import so the snippet stays self-contained; Counter does the
    # per-label counting in a single O(n) pass instead of the original
    # O(n * k) rescan of `target` for every unique label.
    from collections import Counter

    counts = Counter(target)
    print("Number of unique patients...", len(counts))
    # Inverse frequency: rarer labels get proportionally larger weights.
    return {patient: 1 / count for patient, count in counts.items()}

def make_sampler(dataset):
    """
    Build a WeightedRandomSampler using inverse-class-frequency weights.

    Each sample's draw probability is proportional to 1/count(its class),
    so rare classes are drawn about as often as common ones.

    NOTE(review): a weighted sampler only balances class frequencies — it
    does NOT guarantee at most one element per class in a batch.  For that
    hard constraint a custom batch sampler (grouping indices by class and
    emitting one index per class per batch) is required.

    Parameters
    ----------
    dataset : indexable collection
        Each item is a mapping with a 'patient' key (the class label).

    Returns
    -------
    torch.utils.data.WeightedRandomSampler
        Sampler over `len(dataset)` indices, sampling with replacement.
    """
    targets = get_targets(dataset)
    weights_by_class = class_weights(targets)
    # Build per-sample weights as float64 up front; torch.from_numpy on a
    # float64 array already yields a DoubleTensor, which avoids the
    # deprecated string-based Tensor.type('torch.DoubleTensor') cast.
    sample_weights = np.asarray([weights_by_class[t] for t in targets],
                                dtype=np.float64)
    return WeightedRandomSampler(torch.from_numpy(sample_weights),
                                 len(sample_weights))

0 answers