I'm following an article on Medium called "Machine Learning: Improving Classification accuracy on MNIST using Data Augmentation" and am running into issues with its code. The code is as follows:

from sklearn.datasets import fetch_openml
from scipy.ndimage.interpolation import shift
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import numpy as np


# Fetching MNIST Dataset
mnist = fetch_openml('mnist_784', version=1)

# Get the data and target
X, y = mnist["data"], mnist["target"]

# Split the train and test set
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

# Method to shift the image by given dimension
def shift_image(image, dx, dy):
    image = image.reshape((28, 28))
    shifted_image = shift(image, [dy, dx], cval=0, mode="constant")
    return shifted_image.reshape([-1])


# Creating Augmented Dataset
X_train_augmented = [image for image in X_train]
y_train_augmented = [image for image in y_train]

for dx, dy in ((1,0), (-1,0), (0,1), (0,-1)):
    for image, label in zip(X_train, y_train):
        X_train_augmented.append(shift_image(image, dx, dy))
        y_train_augmented.append(label)


# Shuffle the dataset
shuffle_idx = np.random.permutation(len(X_train_augmented))
X_train_augmented = np.array(X_train_augmented)[shuffle_idx]
y_train_augmented = np.array(y_train_augmented)[shuffle_idx]


# Training on augmented dataset
rf_clf_for_augmented = RandomForestClassifier(random_state=42)
rf_clf_for_augmented.fit(X_train_augmented, y_train_augmented)

# Evaluating the model
y_pred_after_augmented = rf_clf_for_augmented.predict(X_test)
score = accuracy_score(y_test, y_pred_after_augmented)
print("Accuracy score after training on augmented dataset", score)

When I run this code, I get this error:

AttributeError: 'str' object has no attribute 'reshape'

Process finished with exit code 1

def shift_image(image, dx, dy):
  image = image.reshape((28, 28))

Why does this happen? What's the problem, and how can it be resolved?

starball
user2027502
  • Okay, and what is your *question* about this error message? Did you *read* the error message? Did you *understand* it? – Karl Knechtel Apr 21 '21 at 12:24
  • It's ambiguous. Line 18/19 in the code URL above – user2027502 Apr 21 '21 at 14:03
  • I don't understand. What do you find ambiguous about it? It says ` 'str' object has no attribute 'reshape'`. Is there more than one `'str' object` you think it could be talking about? More than one `attribute 'reshape'`? More than one way to `have no` such attribute? – Karl Knechtel Apr 22 '21 at 08:20
  • The line is ambiguous. The inputs come from the MNIST dataset. When you debug and hover, it shows an image. When execution reaches this code, what converted it to a `str`? – user2027502 Apr 22 '21 at 18:14

2 Answers

Could you add the code block where you call this function?

Also, I believe the wrong argument is being passed when the function is called: it looks like a string is being passed by mistake instead of an image array.
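
One likely way a string ends up there: in recent scikit-learn versions `fetch_openml` can return the data as a pandas DataFrame, and iterating over a DataFrame yields its column labels rather than its rows. The sketch below illustrates this with a tiny made-up frame (the `pixel1`..`pixel4` columns and the values are assumptions for illustration, not the real MNIST data):

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the MNIST frame: 3 "images" of 4 pixels each.
X = pd.DataFrame(
    np.arange(12).reshape(3, 4),
    columns=["pixel1", "pixel2", "pixel3", "pixel4"],
)

# Iterating over a DataFrame walks its column labels, not its rows.
items_from_frame = [item for item in X]

# Iterating over the underlying NumPy array walks the rows.
items_from_array = [row for row in X.to_numpy()]

print(type(items_from_frame[0]).__name__)  # str
print(type(items_from_array[0]).__name__)  # ndarray
```

If that is what is happening here, `image` inside `shift_image` would be a column name such as `'pixel1'`, which matches the `'str' object has no attribute 'reshape'` message.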

Dhruv Agarwal

Try:

for dx, dy in ((1,0), (-1,0), (0,1), (0,-1)):
    for image, label in zip(X_train.values, y_train):
        X_train_augmented.append(shift_image(image, dx, dy))
        y_train_augmented.append(label)

Instead of:

for dx, dy in ((1,0), (-1,0), (0,1), (0,-1)):
    for image, label in zip(X_train, y_train):
        X_train_augmented.append(shift_image(image, dx, dy))
        y_train_augmented.append(label)
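
Note that the list comprehension that seeds `X_train_augmented` in the question iterates over `X_train` in the same way, so it probably needs the same treatment. A minimal sketch of the whole augmentation step, assuming `X_train` is a pandas DataFrame and `y_train` a pandas Series, is shown below. The random 5-row frame is a made-up stand-in for MNIST so the snippet runs without a download, and `X_rows` / `y_rows` are names introduced here for illustration. It imports `shift` from `scipy.ndimage` directly, since the `scipy.ndimage.interpolation` path is, as far as I know, deprecated in newer SciPy releases.

```python
import numpy as np
import pandas as pd
from scipy.ndimage import shift


def shift_image(image, dx, dy):
    """Shift a flattened 28x28 image by (dx, dy) pixels, padding with zeros."""
    image = image.reshape((28, 28))
    shifted = shift(image, [dy, dx], cval=0, mode="constant")
    return shifted.reshape(-1)


# Hypothetical stand-in for MNIST: 5 random "images" with string labels.
rng = np.random.default_rng(0)
X_train = pd.DataFrame(rng.integers(0, 256, size=(5, 784)).astype(np.float64))
y_train = pd.Series(["0", "1", "2", "3", "4"])

# Convert once, so every later loop sees rows rather than column labels.
X_rows = X_train.to_numpy()
y_rows = y_train.to_numpy()

# Start from the original samples, then append the four shifted copies.
X_train_augmented = list(X_rows)
y_train_augmented = list(y_rows)

for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
    for image, label in zip(X_rows, y_rows):
        X_train_augmented.append(shift_image(image, dx, dy))
        y_train_augmented.append(label)

X_train_augmented = np.array(X_train_augmented)
y_train_augmented = np.array(y_train_augmented)
print(X_train_augmented.shape)  # (25, 784)
```

Alternatively, passing `as_frame=False` to `fetch_openml` should return NumPy arrays from the start, in which case the original loops ought to work unchanged.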
wjandrea
Deepanshu