
I am trying to write an RNN for the Quickdraw dataset. My problem is that the data, which is stored as NDJSON, contains nested lists (lists of lists of lists).

One row looks like this:

{"word":"The Eiffel Tower","countrycode":"GB","timestamp":"2017-03-11 14:47:44.05242 UTC","recognized":true,"key_id":"5027286841556992","drawing":[[[0,22,37,64,255],[218,220,227,228,211]],[[76,95,135,141,150,159,166,180,186,201],[220,138,31,0,63,79,117,150,191,224]],[[94,104,111,119,127,141,143,142,180,191],[212,167,149,80,59,41,30,134,202,232]],[[109,127,137,147,150,162,172,185],[122,120,104,97,99,124,128,128]],[[75,130,158],[162,159,150]]]}
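In this row, every element of "drawing" is one pen stroke stored as a pair of equal-length lists, [x coordinates, y coordinates], so both the number of strokes and the number of points vary from drawing to drawing. One common way to feed such data to an RNN is to flatten each drawing into a single sequence of points. The sketch below does that for a shortened copy of the row, adding a third column that marks where the pen is lifted. The helper name strokes_to_points and the pen-lift convention are hypothetical choices for illustration; if a file also stores timestamps per stroke, only the first two lists are read.

```python
import json

import numpy as np

# The example row from the question, shortened to its first two strokes.
row = '{"word":"The Eiffel Tower","drawing":[[[0,22,37,64,255],[218,220,227,228,211]],[[76,95,135,141,150,159,166,180,186,201],[220,138,31,0,63,79,117,150,191,224]]]}'


def strokes_to_points(drawing):
    """Hypothetical helper: flatten [xs, ys] strokes into a (num_points, 3) array.

    Columns are x, y and a flag that is 1 on the last point of each stroke
    (where the pen is lifted) and 0 everywhere else.
    """
    points = []
    for stroke in drawing:
        xs, ys = stroke[0], stroke[1]  # ignore any extra lists (e.g. timestamps)
        for i, (x, y) in enumerate(zip(xs, ys)):
            pen_up = 1 if i == len(xs) - 1 else 0
            points.append((x, y, pen_up))
    return np.asarray(points, dtype=np.float32)


record = json.loads(row)  # NDJSON is one JSON object per line
points = strokes_to_points(record["drawing"])
print(record["word"], points.shape)  # The Eiffel Tower (15, 3)
```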

I only need the word and drawing fields, and extracting those works, but I can't figure out how to convert the data to tensors. This is the error I get:

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.
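The message points at the dtype rather than the values: because each row holds nested lists of different lengths, df['drawing'].to_numpy() cannot form a regular numeric array and returns an array of Python objects (dtype=object), which torch.from_numpy does not accept. The NumPy-only sketch below reproduces that on a toy DataFrame and shows one common workaround: turn each drawing into a (num_points, 2) array, then zero-pad all of them to the same length and keep the true lengths separately. The toy data and the helper names drawing_to_xy and pad_batch are hypothetical.

```python
import numpy as np
import pandas as pd

# Toy DataFrame shaped like the Quickdraw rows: "drawing" holds nested lists
# whose number of strokes and points differs from row to row.
df = pd.DataFrame({
    "word": ["cat", "dog", "cat"],
    "drawing": [
        [[[0, 10, 20], [0, 5, 0]]],               # 1 stroke, 3 points
        [[[0, 10], [0, 0]], [[5, 5], [0, 9]]],    # 2 strokes, 4 points
        [[[1, 2], [3, 4]]],                       # 1 stroke, 2 points
    ],
})

# This is the problem: the column becomes an array of Python objects.
raw = df["drawing"].to_numpy()
print(raw.dtype)  # object


def drawing_to_xy(drawing):
    """Hypothetical helper: concatenate all [xs, ys] strokes into a (num_points, 2) array."""
    return np.concatenate([np.stack([xs, ys], axis=1) for xs, ys in drawing])


def pad_batch(seqs, pad_value=0.0):
    """Hypothetical helper: pad (T_i, F) arrays into one (N, T_max, F) float32 array."""
    lengths = np.array([len(s) for s in seqs], dtype=np.int64)
    out = np.full((len(seqs), lengths.max(), seqs[0].shape[1]), pad_value, dtype=np.float32)
    for i, s in enumerate(seqs):
        out[i, : len(s)] = s  # copy the real points, leave the rest as padding
    return out, lengths


X, lengths = pad_batch([drawing_to_xy(d) for d in raw])
print(X.dtype, X.shape, lengths)  # float32 (3, 4, 2) [3 4 2]
```

Both X and lengths now have supported numeric dtypes, so torch.from_numpy accepts them. The labels in df['word'] are strings, so that column is also an object array and needs to be mapped to integer class indices before it can become a tensor.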

I tried to find a solution, but I couldn't find anything that worked. Here is part of my code:

import glob

import ndjson
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

files = glob.glob("data/*.ndjson")

data = []
for file in files:  # iterate over the list of files
    print(file)
    with open(file, "r") as fin:
        data.extend(ndjson.load(fin))  # append this file's rows instead of overwriting data
df = pd.DataFrame(data)
X = df['drawing'].to_numpy()
y = df['word'].to_numpy()  # the target column is named 'word'
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)

# Convert the data to PyTorch tensors and create DataLoaders for the training and validation sets
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
val_dataset = torch.utils.data.TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True)

I get the error on the line that creates train_dataset.
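TensorDataset expects every sample to be a slice of one regular tensor, which variable-length drawings are not. A common alternative is a small custom Dataset that returns one variable-length tensor per drawing, together with a collate_fn that pads each batch with torch.nn.utils.rnn.pad_sequence. The sketch below assumes the simplified [xs, ys] stroke format of the row shown above and also maps the word labels to integers; QuickdrawDataset and collate are hypothetical names, and the toy lists stand in for df['drawing'] and df['word'].

```python
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class QuickdrawDataset(Dataset):
    """Hypothetical Dataset: one (num_points, 2) float tensor and one int label per drawing."""

    def __init__(self, drawings, words):
        # Map each distinct word to an integer class index.
        self.classes = sorted(set(words))
        self.class_to_idx = {w: i for i, w in enumerate(self.classes)}
        self.drawings = list(drawings)
        self.labels = [self.class_to_idx[w] for w in words]

    def __len__(self):
        return len(self.drawings)

    def __getitem__(self, idx):
        # Concatenate all [xs, ys] strokes of this drawing into one point sequence.
        xy = np.concatenate([np.stack([xs, ys], axis=1) for xs, ys in self.drawings[idx]])
        return torch.from_numpy(xy.astype(np.float32)), self.labels[idx]


def collate(batch):
    """Hypothetical collate_fn: zero-pad the sequences in a batch and keep their lengths."""
    seqs, labels = zip(*batch)
    lengths = torch.tensor([len(s) for s in seqs], dtype=torch.long)
    padded = pad_sequence(list(seqs), batch_first=True)  # shape (B, T_max, 2)
    return padded, lengths, torch.tensor(labels, dtype=torch.long)


# Toy stand-ins for df['drawing'] and df['word'].
drawings = [
    [[[0, 10, 20], [0, 5, 0]]],
    [[[0, 10], [0, 0]], [[5, 5], [0, 9]]],
    [[[1, 2], [3, 4]]],
]
words = ["cat", "dog", "cat"]

loader = DataLoader(QuickdrawDataset(drawings, words), batch_size=3, shuffle=False, collate_fn=collate)
x, lengths, y = next(iter(loader))
print(x.shape, lengths.tolist(), y.tolist())  # torch.Size([3, 4, 2]) [3, 4, 2] [0, 1, 0]
```

With the real data, the lists returned by train_test_split could be passed in directly. Because collate also returns the true lengths, a padded batch can be wrapped with torch.nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) so that the RNN skips the padded steps.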

kmkurn
SanSiro8
