
I'm running this code and I get an error from the fit function:

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda = LinearDiscriminantAnalysis(shrinkage='auto')
lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))

Here is the error:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-34-ec552dd1faa1> in <module>
      1 lda = LinearDiscriminantAnalysis(shrinkage='auto')
----> 2 lda.fit(np.random.rand(3,2),np.random.randint((1,1,1)))
      3 LinearDiscriminantAnalysis()

~/anaconda3/lib/python3.8/site-packages/sklearn/discriminant_analysis.py in fit(self, X, y)
    581         if self.solver == "svd":
    582             if self.shrinkage is not None:
--> 583                 raise NotImplementedError("shrinkage not supported")
    584             if self.covariance_estimator is not None:
    585                 raise ValueError(

NotImplementedError: shrinkage not supported

How can I fix it? (I got the same error after upgrading scikit-learn, and also on Google Colab.)
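Separately from the shrinkage error, the labels in the snippet above are worth a check. When high is omitted, np.random.randint treats its first argument as the exclusive upper bound and samples from [0, low), so np.random.randint((1,1,1)) always returns three zeros and y contains a single class, which leaves LDA nothing to discriminate between. The quick check below is a sketch; y_fixed is a hypothetical set of replacement labels, not the asker's data.

```python
import numpy as np

# With high=None, randint(low) samples from [0, low).
# Every "low" here is 1, so each label is drawn from [0, 1), i.e. always 0.
y_question = np.random.randint((1, 1, 1))
print(y_question)             # [0 0 0]
print(np.unique(y_question))  # [0]  -> only one class

# Hypothetical replacement labels for the 3 rows of X: two distinct classes.
y_fixed = np.array([0, 1, 1])
print(np.unique(y_fixed))     # [0 1]
```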

desertnaut

1 Answer


shrinkage is not supported by the 'svd' solver, which is the default solver of LinearDiscriminantAnalysis. You can use this parameter with the other solvers, 'lsqr' or 'eigen', as follows:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto').fit(X_train, y_train)
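To check the fix end to end, here is a minimal sketch on synthetic two-class data (the data, the seed and the variable names are assumptions for illustration, not the asker's X_train/y_train). It shows that the default solver is 'svd', that fitting with shrinkage under it raises NotImplementedError (the exact message wording may differ between scikit-learn versions), and that 'lsqr' and 'eigen' both fit with shrinkage='auto'.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Hypothetical toy data (not the asker's): two Gaussian blobs, 20 samples each.
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(20, 2)),
    rng.normal(loc=3.0, scale=1.0, size=(20, 2)),
])
y = np.repeat([0, 1], 20)  # labels: twenty 0s followed by twenty 1s

# Without an explicit solver, the default 'svd' is used, and it rejects shrinkage.
default_lda = LinearDiscriminantAnalysis(shrinkage="auto")
print(default_lda.solver)  # svd
try:
    default_lda.fit(X, y)
except NotImplementedError as exc:
    print("svd solver:", exc)

# 'lsqr' and 'eigen' both accept shrinkage="auto" (Ledoit-Wolf estimate).
models = {}
for solver in ("lsqr", "eigen"):
    models[solver] = LinearDiscriminantAnalysis(solver=solver, shrinkage="auto").fit(X, y)
    print(solver, "training accuracy:", models[solver].score(X, y))
```

Either solver works for classification here. If you also need dimensionality reduction via transform, prefer 'eigen', since the 'lsqr' solver does not implement transform.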
Antoine Dubuis