
I already have a network that does binary classification. Now I want to extend it to a multi-label classification network. I have already modified the model (using 'sigmoid' as the output layer activation function and 'binary_crossentropy' as the loss). How should I modify the DataGenerator function so that it delivers multi-label outputs?
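The model change described above (a sigmoid output with binary_crossentropy) treats each label as an independent yes/no decision. The NumPy sketch below illustrates that idea: it is not Keras's actual implementation, but it computes per-label sigmoid probabilities and averages the binary cross-entropy over the label axis, which is roughly what Keras reports per sample. The logits and target vector are made-up values.

```python
import numpy as np

def sigmoid(z):
    # Independent probability for each label; the outputs need not sum to 1
    return 1.0 / (1.0 + np.exp(-z))

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0), then average the per-label losses over the last axis
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_label = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return per_label.mean(axis=-1)

# Hypothetical logits for one sample with 6 labels, and its multi-hot target
logits = np.array([[-2.0, 3.0, -1.0, -4.0, 2.5, -3.0]])
y_true = np.array([[0, 1, 0, 0, 1, 0]], dtype=float)

probs = sigmoid(logits)
loss = binary_crossentropy(y_true, probs)
print(probs.round(3))
print(loss)
```

Thresholding each probability at 0.5 recovers a multi-hot prediction, which is why the generator has to supply targets in that same multi-hot shape.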

For example, the original DataGenerator function is:

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples'
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample
            X[i, ] = np.load(train_dir + ID)

            # Store class
            y[i] = self.labels[ID]

        return X, to_categorical(y, num_classes=self.n_classes)

For binary classification, labels[ID] is an integer. For multi-label classification, labels[ID] is instead a sequence like [0 1 0 0 1 0] (supposing I have 6 labels). How can I pass this sequence through the DataGenerator() function?
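The reason the original generator cannot carry such sequences is the shape of `y`: `np.empty((self.batch_size), dtype=int)` allocates one integer per sample, so assigning a 6-element vector to `y[i]` fails. A short NumPy check, using a made-up label vector, shows the difference between a 1-D and a 2-D label array.

```python
import numpy as np

batch_size, n_classes = 4, 6
label = np.array([0, 1, 0, 0, 1, 0])  # hypothetical multi-hot label for one sample

# Original layout: one integer slot per sample
y_1d = np.empty((batch_size), dtype=int)
failed = False
try:
    y_1d[0] = label
except ValueError as err:
    failed = True
    print("1-D y fails:", err)

# Multi-label layout: one row of n_classes entries per sample
y_2d = np.empty((batch_size, n_classes), dtype=int)
y_2d[0] = label
print(y_2d[0])
```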

Thanks a lot!

1 Answer


For multi-class classification, you just need to change the last layer's activation to softmax and the loss to categorical_crossentropy. The data generator looks fine to me, though; only the labels themselves change. Instead of just 0 and 1 you would have more label values, and the to_categorical function will convert them to one-hot encoded outputs. For multi-label classification, if each entry of y holds several label indices, like [1, 3] (so y needs a second dimension, with the same number of indices per sample), then you need to change one thing in your generator's return statement:

    return X, to_categorical(y, num_classes=self.n_classes).sum(axis=1)
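To see why the `.sum(axis=1)` trick works: for a 2-D array of indices, `to_categorical` returns a 3-D array with one one-hot row per index, and summing over the index axis merges each sample's rows into a single multi-hot row. The sketch below uses `np.eye(n_classes)[y]`, which produces the same one-hot values as `to_categorical` for integer indices, so that it runs without TensorFlow; the label indices are made up.

```python
import numpy as np

n_classes = 6
# Hypothetical label indices: two labels per sample
y = np.array([[1, 3],
              [0, 5],
              [2, 4]])

one_hot = np.eye(n_classes)[y]   # shape (3, 2, 6): one one-hot row per index
multi_hot = one_hot.sum(axis=1)  # shape (3, 6): rows merged per sample
print(multi_hot)
```

One caveat: a sample that lists the same index twice ends up with a 2 in that position, so `np.minimum(multi_hot, 1)` (or building the multi-hot vector directly) is safer when duplicates are possible.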

Mahmoud Youssef
  • Thanks very much. However, I was actually trying to do multi-label classification rather than multi-class classification, so the sigmoid activation function may be more suitable than softmax. I changed the dimension of 'y' from 1 to 2 and it works now. – Renee1002 Jun 10 '20 at 11:09
  • I've edited my answer to allow for multi-label encoding. – Mahmoud Youssef Jun 10 '20 at 11:23
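The fix described in the comments, giving `y` a second dimension, can be sketched as below. Because each `labels[ID]` in the question is already a multi-hot vector, it is stored row by row and returned without `to_categorical`. This is a standalone function rather than the original class method: `load_sample` is a hypothetical stand-in for `np.load(train_dir + ID)`, the batch size is taken from the ID list instead of `self.batch_size`, and the IDs, shapes and labels are invented for the example.

```python
import numpy as np

def data_generation(list_IDs_temp, labels, load_sample, dim, n_channels, n_classes):
    """Build one batch of inputs X and multi-hot targets y."""
    batch_size = len(list_IDs_temp)
    X = np.empty((batch_size, *dim, n_channels))
    # 2-D label array: one multi-hot row per sample instead of one integer
    y = np.empty((batch_size, n_classes), dtype=np.float32)

    for i, ID in enumerate(list_IDs_temp):
        X[i] = load_sample(ID)  # stand-in for np.load(train_dir + ID)
        y[i] = labels[ID]       # already multi-hot, e.g. [0, 1, 0, 0, 1, 0]

    # No to_categorical: the targets already match a sigmoid output layer
    return X, y

# Hypothetical data: 2 samples of shape (4, 4, 1) with 6 labels each
labels = {"a.npy": [0, 1, 0, 0, 1, 0], "b.npy": [1, 0, 0, 0, 0, 1]}

def fake_load(ID):
    # Hypothetical loader returning a blank sample
    return np.zeros((4, 4, 1))

X, y = data_generation(["a.npy", "b.npy"], labels, fake_load, (4, 4), 1, 6)
print(X.shape, y.shape)
print(y)
```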