I am getting `ValueError: shape mismatch: objects cannot be broadcast to a single shape`.
The error occurs when I run the following code:
plt.bar(range(1, 14), pca.explained_variance_ratio_, alpha=0.5,
... align='center')
Traceback (most recent call last):
File "<stdin>", line 2, in <module>
File "/me/anaconda3/envs/new36/lib/python3.6/site-packages/matplotlib/pyplot.py", line 2648, in bar
ret = ax.bar(*args, **kwargs)
File "/me/anaconda3/envs/new36/lib/python3.6/site-packages/matplotlib/__init__.py", line 1717, in inner
return func(ax, *args, **kwargs)
File "/me/anaconda3/envs/new36/lib/python3.6/site-packages/matplotlib/axes/_axes.py", line 2019, in bar
np.atleast_1d(x), height, width, y, linewidth)
File "/me/anaconda3/envs/new36/lib/python3.6/site-packages/numpy/lib/stride_tricks.py", line 249, in broadcast_arrays
shape = _broadcast_shape(*args)
File "/me/anaconda3/envs/new36/lib/python3.6/site-packages/numpy/lib/stride_tricks.py", line 184, in _broadcast_shape
b = np.broadcast(*args[:32])
ValueError: shape mismatch: objects cannot be broadcast to a single shape
This is the full code used to replicate the error:
from sklearn.datasets.samples_generator import make_blobs
from pandas import DataFrame
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
# Build a synthetic 3-cluster classification dataset:
# 1000 samples with 10 features each.
X, y = make_blobs(
    n_samples=1000,
    n_features=10,
    centers=3,
    cluster_std=3,
    random_state=1,
)

# Hold out 20% of the samples for testing, stratified on the labels.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Fit the scaler on the training split only, then reuse it on the test split.
sc = StandardScaler()
X_train_std = sc.fit_transform(X_train)
X_test_std = sc.transform(X_test)

# Fit PCA with its default settings and project the training data.
pca = PCA()
X_train_pca = pca.fit_transform(X_train_std)

# Bare expression: echoes the ratios in an interactive session and has no
# effect when run as a script.
pca.explained_variance_ratio_
import matplotlib.pyplot as plt
import numpy as np
# Plot the individual (bars) and cumulative (step line) explained variance.
#
# The x positions must contain exactly one entry per principal component.
# A hard-coded range(1, 14) produces 13 positions, but the data generated
# above has only 10 features, so PCA() yields 10 ratios -- that length
# mismatch is what raised "shape mismatch: objects cannot be broadcast to a
# single shape". Deriving the positions from the array keeps them in sync.
explained_var = pca.explained_variance_ratio_
component_idx = range(1, len(explained_var) + 1)  # 1-based component numbers

plt.bar(component_idx, explained_var, alpha=0.5, align='center')
plt.step(component_idx, np.cumsum(explained_var), where='mid')
plt.ylabel('Explained variance ratio')
plt.xlabel('Principal components')
plt.show()