
I have a class imbalance problem and have been experimenting with a weighted Random Forest using the implementation in scikit-learn (>= 0.16).

I have noticed that the implementation takes a class_weight parameter in the tree constructor and a sample_weight parameter in the fit method to help deal with class imbalance. The two seem to be multiplied together, though, to determine the final weight.

I have trouble understanding the following:

  • In what stages of the tree construction/training/prediction are those weights used? I have seen some papers on weighted trees, but I am not sure what scikit-learn implements.
  • What exactly is the difference between class_weight and sample_weight?
user36047

2 Answers


RandomForests are built on Trees, which are very well documented. Check how Trees use the sample weighting:

  • User guide on decision trees - tells exactly what algorithm is used
  • Decision tree API - explains how sample_weight is used by trees (which for random forests, as you have determined, is the product of class_weight and sample_weight).
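
To make this concrete: sample weights enter split selection by replacing raw class counts with weighted class counts in the impurity criterion. Below is a simplified pure-Python sketch of a weighted Gini computation (the real implementation is Cython inside sklearn.tree, so the function here is illustrative only):

import numpy as np

def weighted_gini(y, sample_weight):
    # Gini impurity with weighted class frequencies in place of raw counts;
    # a simplified stand-in for scikit-learn's Cython criterion.
    total = sample_weight.sum()
    impurity = 1.0
    for cls in np.unique(y):
        p = sample_weight[y == cls].sum() / total  # weighted class frequency
        impurity -= p ** 2
    return impurity

y = np.array([0, 0, 0, 1])
print(weighted_gini(y, np.ones(4)))                  # unweighted: 0.375
print(weighted_gini(y, np.array([1., 1., 1., 3.])))  # minority upweighted: 0.5

Upweighting the minority class makes splits that isolate it look more attractive, which is the mechanism through which both weighting parameters influence training.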

As for the difference between class_weight and sample_weight: much can be determined simply from the nature of their datatypes. sample_weight is a 1D array of length n_samples, assigning an explicit weight to each example used for training. class_weight is either a dictionary mapping each class to a weight for that class (e.g., {1: .9, 2: .5, 3: .01}), or a string telling sklearn how to automatically determine this dictionary.

So the training weight for a given example is the product of its explicitly named sample_weight (or 1 if sample_weight is not provided) and its class_weight (or 1 if class_weight is not provided).
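
As a concrete usage sketch (the synthetic dataset and the particular weights are just for illustration), both parameters can be combined in a single fit:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)

# class_weight: one weight per class, fixed at construction time
clf = RandomForestClassifier(class_weight={0: 1.0, 1: 9.0}, random_state=0)

# sample_weight: one weight per training example, supplied at fit time
w = np.ones(len(y))
w[:50] = 2.0  # e.g., trust the first 50 examples more

# effective per-example weight = sample_weight * weight of that example's class
clf.fit(X, y, sample_weight=w)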

Andreus
  • You are right about DT as the base classifier. I am also interested in how those weights are used during training (e.g., computing the impurity of a decision node) and during prediction. – user36047 Jun 12 '15 at 16:25
  • Check out the documentation for `sample_weight` under the `fit()` method: [Decision tree API](http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier) – Andreus Jun 12 '15 at 16:29
  • Although it is not crystal clear, I think everything will be clearer when I digest the material in the links :) – user36047 Jun 12 '15 at 16:49

If we look at the source code, RandomForestClassifier is sub-classed from the ForestClassifier class, which in turn is sub-classed from the BaseForest class, and the fit() method is actually defined in the BaseForest class. As the OP pointed out, the interaction between class_weight and sample_weight determines the sample weights used to fit each decision tree of the random forest.

If we inspect the _validate_y_class_weight(), fit() and _parallel_build_trees() methods, we can better understand the interaction between the class_weight, sample_weight and bootstrap parameters. In particular:

  • if class_weight is passed to the RandomForestClassifier() constructor but no sample_weight is passed to fit(), the expanded class weights are used as the sample weights
  • if both sample_weight and class_weight are passed, they are multiplied together to determine the final sample weights used to train each individual decision tree
  • if class_weight=None, then sample_weight alone determines the final sample weights (by default, if it is also None, then samples are equally weighted).

The relevant part in the source code may be summarized as follows.

from sklearn.utils import compute_sample_weight

# class weights are expanded into per-sample weights at fit() time, except for
# "balanced_subsample" with bootstrapping, which is handled per tree (see below)
if class_weight == "balanced_subsample" and not bootstrap:
    expanded_class_weight = compute_sample_weight("balanced", y)
elif class_weight is not None and class_weight != "balanced_subsample":
    expanded_class_weight = compute_sample_weight(class_weight, y)
else:
    expanded_class_weight = None

# the expanded class weights are multiplied into any user-supplied sample_weight
if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight
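
One way to sanity-check this interaction (a quick experiment, not part of the source) is to note that with bootstrap=False, fitting with class_weight should be exactly equivalent to fitting with the expanded weights passed as sample_weight:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import compute_sample_weight

X, y = make_classification(n_samples=200, weights=[0.8, 0.2], random_state=0)

# fit once with class_weight, once with the equivalent expanded sample weights
rf1 = RandomForestClassifier(class_weight="balanced", bootstrap=False,
                             random_state=0).fit(X, y)
rf2 = RandomForestClassifier(bootstrap=False, random_state=0).fit(
    X, y, sample_weight=compute_sample_weight("balanced", y))

# identical weights should produce identical forests
assert np.allclose(rf1.predict_proba(X), rf2.predict_proba(X))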

With bootstrap=True, observations are randomly selected for each individual tree; this is implemented by multiplying the per-tree sample weights by the bootstrap counts before each tree's fit() is called. The relevant (abridged) code looks like the following.

import numpy as np
from sklearn.utils import check_random_state, compute_sample_weight

if bootstrap:
    if sample_weight is None:
        sample_weight = np.ones((X.shape[0],), dtype=np.float64)

    # draw a bootstrap sample for this tree (n_samples_bootstrap is the
    # per-tree sample size, derived from the max_samples parameter)
    indices = check_random_state(tree.random_state).randint(
        0, X.shape[0], n_samples_bootstrap)
    # weight each observation by how many times it was drawn
    sample_counts = np.bincount(indices, minlength=X.shape[0])
    sample_weight *= sample_counts

    # "balanced_subsample": recompute class weights on this tree's bootstrap sample
    if class_weight == "balanced_subsample":
        sample_weight *= compute_sample_weight("balanced", y, indices=indices)
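
To see what the "balanced_subsample" branch changes, note that compute_sample_weight accepts the bootstrap indices, so class frequencies are measured on each tree's resample rather than on the full training set (the indices below are a made-up bootstrap draw):

import numpy as np
from sklearn.utils import compute_sample_weight

y = np.array([0, 0, 0, 0, 1, 1])

# "balanced": frequencies over the full training set
print(compute_sample_weight("balanced", y))
# [0.75 0.75 0.75 0.75 1.5 1.5]

# "balanced_subsample": frequencies over one tree's bootstrap indices
indices = np.array([0, 1, 4, 5, 4, 5])  # hypothetical draw, class 1 over-represented
print(compute_sample_weight("balanced", y, indices=indices))
# class 0 now gets the larger weight because it is rarer in the resample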
cottontail