
I'm writing my first machine learning code in Python, but I ran into an error while building the confusion matrix for my multiclass model.

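# Imports (assumed; the original snippet does not show them)
import numpy as np
from sklearn import metrics
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Defining the model: 22 input features, 5 output classes

model = Sequential()

model.add(Dense(32, input_shape=(22,), activation='tanh'))
model.add(Dense(16, activation='tanh'))
model.add(Dense(6, activation='tanh'))
model.add(Dense(5, activation='softmax'))

model.compile(Adam(learning_rate=0.004), 'sparse_categorical_crossentropy', metrics=['accuracy'])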

# Fitting the model and predicting

model.fit(X_train, Y_train, epochs=1)

Y_pred = model.predict(X_test)

Y_pred = Y_pred.astype(int)

Y_test_class = np.argmax(Y_test, axis=0)
Y_pred_class = np.argmax(Y_pred, axis=0)

# Classification report and confusion matrix for the predicted values

print(metrics.classification_report(Y_test_class, Y_pred_class))
print(metrics.confusion_matrix(Y_test_class, Y_pred_class))

I'm getting this error:

TypeError: Singleton array 3045 cannot be considered a valid collection.

Here are some details about my test data.

X_test[:5]:

[['0' '1' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
  '0' '1' '0' '0']
 ['1' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
  '0' '0' '0' '0']
 ['1' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
  '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
  '0' '0' '0' '0']
 ['0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0' '0'
  '0' '1' '1' '0']]

Y_test[:5]:

['1' '2' '2' '2' '2']
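The sample output shows that both `X_test` and `Y_test` hold strings ('0', '1', '2') rather than numbers. A minimal sketch of casting them to numeric arrays, assuming the full arrays use the same string format as these samples (the arrays below are just the first rows copied from the output above):

```python
import numpy as np

# First two rows of X_test and the first five labels of Y_test, copied from the output above
X_sample = np.array([
    ['0', '1'] + ['0'] * 17 + ['1', '0', '0'],  # 22 features, '1' at positions 1 and 19
    ['1'] + ['0'] * 21,                          # 22 features, '1' at position 0
])
Y_sample = np.array(['1', '2', '2', '2', '2'])

# Cast the string arrays to numeric dtypes before training or computing metrics
X_numeric = X_sample.astype(np.float32)  # features as floats for Keras
Y_numeric = Y_sample.astype(np.int64)    # integer class labels for sparse_categorical_crossentropy

print(X_numeric.shape)  # (2, 22)
print(Y_numeric)        # [1 2 2 2 2]
```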

The resulting shapes are:

Y_test_class ==> ()
Y_pred_class ==> (5,)
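These shapes follow from how `np.argmax` treats `axis=0`. A minimal NumPy sketch, using a hypothetical 1-D label vector and a hypothetical `(n_samples, 5)` probability array as stand-ins for the real `Y_test` and `Y_pred`:

```python
import numpy as np

# Hypothetical stand-ins: a 1-D vector of integer labels, and softmax-like
# probabilities with one row per sample and one column per class (5 classes)
y_true = np.array([1, 2, 2, 2, 2])
y_prob = np.array([
    [0.10, 0.60, 0.10, 0.10, 0.10],
    [0.05, 0.05, 0.80, 0.05, 0.05],
    [0.10, 0.10, 0.50, 0.20, 0.10],
    [0.20, 0.10, 0.40, 0.20, 0.10],
    [0.10, 0.10, 0.70, 0.05, 0.05],
])

# axis=0 on a 1-D array returns a single index for the whole array: a 0-d scalar
print(np.shape(np.argmax(y_true, axis=0)))  # ()

# axis=0 on a 2-D array returns, for each class column, the row with the largest value
print(np.argmax(y_prob, axis=0).shape)      # (5,)

# axis=1 returns the most likely class for each sample: one label per row
print(np.argmax(y_prob, axis=1))            # [1 2 2 2 2]
```

A 0-d scalar like the first result is not a collection of samples, which is consistent with scikit-learn's "Singleton array ... cannot be considered a valid collection" error.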

Asma Mekki


Have you looked at:

Multilabel-indicator is not supported for confusion matrix

Depending on whether you're using one-hot encoding (OHE) or a vector of labels, this might help.
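To make the distinction concrete, here is a hedged sketch of a helper (`to_class_indices` is a hypothetical name, not a library function) that turns either form into the 1-D vector of integer class indices that `metrics.confusion_matrix` expects:

```python
import numpy as np

def to_class_indices(y):
    """Return a 1-D integer vector of class indices.

    Hypothetical helper: a 2-D input is treated as one-hot rows or per-class
    probabilities (one row per sample); a 1-D input is treated as labels already.
    """
    y = np.asarray(y)
    if y.ndim == 2:
        # One row per sample: pick the column with the largest value
        return np.argmax(y, axis=1)
    if y.ndim == 1:
        # Already a label vector (possibly strings such as '1', '2'): just cast it
        return y.astype(np.int64)
    raise ValueError(f"expected a 1-D or 2-D array, got shape {y.shape}")

# A label vector, a one-hot matrix and a probability matrix for the same three samples
labels = np.array(['1', '2', '2'])
one_hot = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 1]])
probs = np.array([[0.2, 0.7, 0.1], [0.1, 0.3, 0.6], [0.0, 0.4, 0.6]])

print(to_class_indices(labels))   # [1 2 2]
print(to_class_indices(one_hot))  # [1 2 2]
print(to_class_indices(probs))    # [1 2 2]
```

The key point is that a label vector needs no `argmax` at all, while one-hot or probability matrices need it along `axis=1`, not `axis=0`.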

scottbaker
  • Thanks for the link! It helps. I think the problem is that `Y_test_class` is empty while the shape of `Y_test` is `(32863,)`. Do you have any idea how I can solve that, please? – Asma Mekki Mar 11 '20 at 03:30
  • I'm not familiar with your data, but check that `astype(int)` isn't messing it up. Otherwise, I'd expect `np.argmax` to give the same result without the `astype(int)` line. – scottbaker Mar 11 '20 at 03:46
  • That's true! But even when I remove that line, the error persists. I've added some details about my test data; I hope that helps! – Asma Mekki Mar 11 '20 at 04:12
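On the `astype(int)` question: casting softmax probabilities to `int` truncates every value below 1.0 to 0, so the class information is gone before `argmax` runs. Below is a minimal sketch of the evaluation step, assuming `Y_test` is a 1-D vector of string labels (as in the sample above) and `Y_pred` is the `(n_samples, 5)` output of `model.predict`; the arrays are made-up stand-ins, not the asker's data:

```python
import numpy as np
from sklearn import metrics

# Made-up stand-ins: string labels and softmax-like probabilities for 5 classes
Y_test = np.array(['1', '2', '2', '0', '4'])
Y_pred = np.array([
    [0.05, 0.80, 0.05, 0.05, 0.05],
    [0.10, 0.20, 0.60, 0.05, 0.05],
    [0.10, 0.50, 0.30, 0.05, 0.05],
    [0.70, 0.10, 0.10, 0.05, 0.05],
    [0.05, 0.05, 0.10, 0.20, 0.60],
])

# Casting probabilities to int truncates them all to 0, so argmax always returns class 0
print(np.argmax(Y_pred.astype(int), axis=1))  # [0 0 0 0 0]

# Labels are already class indices, so only cast them; reduce predictions row by row
Y_test_class = Y_test.astype(np.int64)
Y_pred_class = np.argmax(Y_pred, axis=1)
print(Y_pred_class)  # [1 2 1 0 4]

# labels=[0..4] keeps the matrix 5x5 even if a class is missing from this sample
print(metrics.confusion_matrix(Y_test_class, Y_pred_class, labels=[0, 1, 2, 3, 4]))
```

Passing `labels` explicitly is a design choice for small or unbalanced test sets; without it, scikit-learn sizes the matrix from the classes that actually appear in either vector.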