
This is the code I am using to print the original, unreduced images of the first 100 MNIST digits, but it keeps giving me an error. Even after trying a lot, I could not find the solution. Any suggestions would be appreciated.

from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt

mnist = fetch_openml('mnist_784')
X = mnist["data"]
y = mnist["target"]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
pca = PCA()
pca.fit(X_train)
cumsum = np.cumsum(pca.explained_variance_ratio_)
d = np.argmax(cumsum >= 0.90) + 1

# Set up a figure 8 inches by 8 inches
fig = plt.figure(figsize=(8,8))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(100):
    ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])
    ax.imshow(X_train[i].reshape(28,28), cmap=plt.cm.bone, interpolation='nearest')
    plt.show()
desertnaut
"giving me an error" is not helpful for possible respondents. What error? Please see how to create a minimal reproducible example (and also fix your code indentation). – desertnaut May 26 '21 at 08:12
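The line `d = np.argmax(cumsum >= 0.90) + 1` in the question picks the smallest number of principal components whose cumulative explained variance reaches 90%. `np.argmax` on a boolean array returns the index of the first `True`, and the `+ 1` turns that zero-based index into a count. The sketch below replays the computation on a made-up variance-ratio array (the numbers are illustrative, not MNIST's):

```python
import numpy as np

# Illustrative explained-variance ratios (not taken from MNIST)
ratios = np.array([0.5, 0.3, 0.15, 0.05])

cumsum = np.cumsum(ratios)        # approximately [0.5, 0.8, 0.95, 1.0]
reached = cumsum >= 0.90          # [False, False, True, True]
d = int(np.argmax(reached)) + 1   # first True is at index 2, so d = 3

print(d)  # 3
```

scikit-learn can also make this choice itself: passing a float between 0 and 1, for example `PCA(n_components=0.90)`, keeps enough components to explain that fraction of the variance.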

2 Answers

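import numpy as np
import matplotlib.pyplot as plt

images = np.asarray(X_train)  # row indexing works whether X_train is an array or a DataFrame

fig, ax = plt.subplots(10, 10, figsize=(8, 8))  # 10x10 grid to match the two loops below

for i in range(10):
    for j in range(10):
        # MNIST images have 784 pixels, i.e. 28x28 (8x8 is the size of sklearn's small digits dataset)
        ax[i, j].imshow(images[10 * i + j].reshape(28, 28), cmap='binary')
        ax[i, j].axis('off')

plt.show()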
Anurag Dhadse
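If you plot image grids often, the loop can be wrapped in a small helper. This is a sketch; the name `plot_digit_grid` and its parameters are hypothetical, not part of Matplotlib or scikit-learn. It converts the input with `np.asarray` so that row `k` is selected positionally even when the data arrives as a pandas DataFrame, and `side` sets the image width (28 for MNIST, 8 for `sklearn.datasets.load_digits`).

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_digit_grid(images, n_rows=10, n_cols=10, side=28):
    """Draw the first n_rows * n_cols rows of `images` as side x side pictures.

    Hypothetical helper: `images` is any 2-D array-like with side*side columns.
    """
    images = np.asarray(images)  # a DataFrame becomes a plain 2-D array, so images[k] is row k
    if len(images) < n_rows * n_cols:
        raise ValueError("not enough images to fill the grid")
    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 8), squeeze=False)
    for k, ax in enumerate(axes.ravel()):
        ax.imshow(images[k].reshape(side, side), cmap="binary", interpolation="nearest")
        ax.axis("off")  # hide ticks and frame
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    return fig
```

Usage would be `fig = plot_digit_grid(X_train[:100])` followed by a single `plt.show()` once the whole grid has been drawn.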

The problem is that your `plt.show()` call is still inside the loop. Move it out of the loop and the figure will show fine. Give the code below a try:

from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

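# as_frame=False returns NumPy arrays, so X_train[i] selects a row (on a DataFrame it would look up a column)
mnist = fetch_openml('mnist_784', as_frame=False)
X = mnist["data"]
y = mnist["target"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20,
                                                    random_state=44)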
pca = PCA()
pca.fit(X_train)

fig = plt.figure(figsize=(8,8))
fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
for i in range(100):
    ax = fig.add_subplot(10, 10, i+1, xticks=[], yticks=[])
    ax.imshow(X_train[i].reshape(28,28), cmap=plt.cm.bone, interpolation='nearest')

plt.show()
DataDude
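A likely cause of the unspecified error, assuming a recent scikit-learn release: `fetch_openml` can return `mnist["data"]` as a pandas DataFrame rather than a NumPy array. On a DataFrame, `X_train[i]` looks up a column labelled `i`, which does not exist, so it raises `KeyError`. Passing `as_frame=False` to `fetch_openml`, or indexing rows with `.iloc[i]`, avoids this. The sketch below shows the difference on a tiny made-up DataFrame (the `pixel1` to `pixel4` column names only mimic the style of the OpenML MNIST columns):

```python
import numpy as np
import pandas as pd

# Tiny made-up stand-in for the MNIST feature frame: 3 "images" of 4 pixels each
df = pd.DataFrame(np.arange(12).reshape(3, 4),
                  columns=["pixel1", "pixel2", "pixel3", "pixel4"])

try:
    df[0]  # [] with a scalar selects a column by label; there is no column named 0
    column_lookup_failed = False
except KeyError:
    column_lookup_failed = True

first_row = df.iloc[0].to_numpy()  # positional row access on the DataFrame
as_array = df.to_numpy()           # or convert once and index rows as in NumPy

print(column_lookup_failed)        # True
print(first_row.tolist())          # [0, 1, 2, 3]
print(as_array[0].tolist())        # [0, 1, 2, 3]
```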