I have been training a U-Net model for multiclass semantic segmentation in Python using TensorFlow and TensorFlow Datasets.

I've noticed that one of my classes seems to be underrepresented in training. After doing some research, I found out about sample weights and thought they might be a good solution to my problem, but I've been having trouble deciphering the documentation on how to use them or finding examples of them in use.

Could someone explain how sample weights come into play with datasets during training, or point me to an example where they are implemented? Even knowing what type of input the model.fit function expects would be helpful.

1 Answer

From the documentation of tf.keras model.fit():

sample_weight

[...] This argument is not supported when x is a dataset, generator, or keras.utils.Sequence instance, instead provide the sample_weights as the third element of x.

What is meant by that? This is demonstrated for the Dataset case in one of the official documentation tutorials:

import numpy as np
import tensorflow as tf

# x_train, y_train, and get_compiled_model() are defined earlier in the linked tutorial.
sample_weight = np.ones(shape=(len(y_train),))
sample_weight[y_train == 5] = 2.0

# Create a Dataset that includes sample weights
# (3rd element in the return tuple).
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train, sample_weight))

# Shuffle and slice the dataset.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

model = get_compiled_model()
model.fit(train_dataset, epochs=1)

See the link for a full-fledged example.
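
For the segmentation case in the question, the labels are per-pixel masks rather than scalar class ids, so the third element of each dataset example becomes a per-pixel weight map. Below is a minimal sketch of the same pattern using Dataset.map; the class count, the class_weights values, and the train_ds / add_pixel_weights names are assumptions made up for illustration, not part of the linked tutorial:

import tensorflow as tf

# Made-up per-class weights for a 4-class problem; class 3 stands in for the
# under-represented class and is given extra weight.
class_weights = tf.constant([1.0, 1.0, 1.0, 3.0])

def add_pixel_weights(image, mask):
    # mask is assumed to hold integer class ids of shape (H, W).
    # tf.gather looks up a weight for every pixel, producing a weight map
    # with the same spatial shape as the mask.
    weights = tf.gather(class_weights, tf.cast(mask, tf.int32))
    return image, mask, weights

# train_ds is assumed to yield (image, mask) pairs, as in the question.
# train_ds = train_ds.map(add_pixel_weights).shuffle(1024).batch(16)
# model.fit(train_ds, epochs=10)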

  • From the link, I found this information helpful: a "sample weights" array is an array of numbers that specify how much weight each sample in a batch should have in computing the total loss. It is commonly used in imbalanced classification problems (the idea being to give more weight to rarely-seen classes). – pakira79 Apr 16 '22 at 22:17
  • There is already a `class_weight` parameter on the `fit` method for weighting classes in classification problems; I think this is more useful for weighting actual samples, for example giving more weight to more recent samples. – Maro Aug 22 '22 at 21:34
  • `class_weight` is interesting for a fairly narrow range of problems. That is, if one only has one instance of each class in the target, say 10 classes needing 10 weights. But for time-series models, it is common to have each class computed independently over time within the target. Thus, with 100 time bins and 10 classes, one needs 1000 weights, not just the 10 weights that `class_weight` allows for. To my mind, `sample_weights` are much more general and more powerful. But it depends on the requirements for the model. – Hephaestus Apr 24 '23 at 00:17
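
To make the point in the comments above concrete, a sample weight simply scales that sample's contribution to the loss before the batch reduction. Here is a small standalone sketch; the toy labels, predictions, and weights are made up for illustration:

import tensorflow as tf

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Three toy samples: true class ids and predicted class probabilities.
y_true = tf.constant([0, 1, 2])
y_pred = tf.constant([[0.8, 0.1, 0.1],
                      [0.2, 0.6, 0.2],
                      [0.3, 0.3, 0.4]])

unweighted = loss_fn(y_true, y_pred)

# Double the influence of the third sample (standing in for a rare class).
weighted = loss_fn(y_true, y_pred, sample_weight=tf.constant([1.0, 1.0, 2.0]))

print(float(unweighted), float(weighted))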