
I have a time series classification task in which I must output one of 3 classes for every time stamp t.

All data is labeled per frame.

The data set contains more than 3 classes, and they are also imbalanced.

My net has to see all samples sequentially, because it uses them as historical information.
Therefore, I can't simply remove the samples of the irrelevant classes during preprocessing.

For a frame labeled with any class other than those 3, I don't care about the prediction, and it should not contribute to the loss.


How do I do this correctly in PyTorch?

Gulzar

1 Answer


Following this discussion (which I could not find through Google), there are two options, both of them arguments of torch.nn.CrossEntropyLoss:

Option 1

If there is only one class to ignore, use ignore_index=class_index when instantiating the loss.
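A minimal sketch of Option 1, assuming the net emits per-frame logits of shape (N, C, T) with C = 3, and that the single class to ignore carries the hypothetical label 3. The sizes, IGNORED_CLASS and the random data are illustrative only.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: N sequences, C = 3 predicted classes, T frames per sequence.
N, C, T = 2, 3, 5
IGNORED_CLASS = 3  # hypothetical label of the one class we don't care about

torch.manual_seed(0)
logits = torch.randn(N, C, T)             # network output per frame: (N, C, T)
targets = torch.tensor([[0, 1, 3, 2, 1],
                        [2, 3, 3, 0, 0]])  # frame labels: (N, T), int64

# Frames whose target equals ignore_index contribute nothing to the loss, and the
# default "mean" reduction averages only over the remaining (non-ignored) frames.
criterion = nn.CrossEntropyLoss(ignore_index=IGNORED_CLASS)
loss = criterion(logits, targets)
print(f"loss = {loss.item():.4f}")
```

The ignored label does not have to be a valid output index (PyTorch's default ignore_index is -100), so a 3-output network can still skip frames labeled with a fourth class. If several labels should be skipped this way, one variant not covered by the original discussion is to first map all of them to that single value, e.g. with `targets.masked_fill(...)`.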

Option 2

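If there are more classes to ignore, pass weight=weights, where weights is a tensor with weights.shape == (n_classes,) whose entries for all ignored classes are zero, i.e. torch.sum(weights[ignored_classes]) == 0 (with non-negative weights).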
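A minimal sketch of Option 2. Because weight must have one entry per output class, this assumes the net outputs scores for all n_classes labels in the data set, not only the 3 of interest. The choice of 6 classes, ignoring classes 3, 4 and 5, and the random data are hypothetical.

```python
import torch
import torch.nn as nn

# Hypothetical setup: 6 labels in the data set; only 0, 1, 2 matter.
n_classes = 6
ignored_classes = [3, 4, 5]
N, T = 2, 5

# One weight per class; zero weight removes a class from the loss entirely.
weights = torch.ones(n_classes)
weights[ignored_classes] = 0.0
assert weights.shape == (n_classes,)
assert torch.sum(weights[ignored_classes]) == 0

torch.manual_seed(0)
logits = torch.randn(N, n_classes, T)     # per-frame scores for every class: (N, C, T)
targets = torch.tensor([[0, 4, 3, 2, 1],
                        [5, 1, 3, 0, 2]])  # frame labels: (N, T)

criterion = nn.CrossEntropyLoss(weight=weights)
loss = criterion(logits, targets)
print(f"loss = {loss.item():.4f}")
```

With the default reduction="mean", PyTorch divides by the sum of the weights of the targets, so zero-weight frames drop out of both the numerator and the denominator. Since weight is also the parameter PyTorch documents for unbalanced training sets, the non-zero entries could additionally be tuned for the class imbalance mentioned in the question.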

Gulzar