I'm using PyTorch Lightning and PyTorch for LSTM classification. Whenever I train the model, this error shows:
File "<string>", line 1, in <module>
File "C:\Users\hisha\anaconda3\envs\FYP\lib\multiprocessing\spawn.py", line 116, in spawn_main
exitcode = _main(fd, parent_sentinel)
File "C:\Users\hisha\anaconda3\envs\FYP\lib\multiprocessing\spawn.py", line 126, in _main
self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GaitDataset' on <module '__main__' (built-in)>
I'm implementing time series classification using sequences. The original model and data module code is below:
class GaitDataset(Dataset):
def __init__(self, sequences):
self.sequences = sequences
def __len__(self):
return len(self.sequences)
def __getitem__(self, idx):
sequence, label = self.sequences[idx]
return dict(
sequence=torch.Tensor(sequence.to_numpy()),
label=torch.tensor(label).long()
)
class GaitDataModule(pl.LightningDataModule):
def __init__(self, train_sequences, test_sequences, batch_size):
super().__init__()
self.train_sequences = train_sequences
self.test_sequences = test_sequences
self.batch_size = batch_size
def setup(self, stage=None):
self.train_dataset = GaitDataset(self.train_sequences)
self.test_dataset = GaitDataset(self.test_sequences)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=cpu_count()
)
def val_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=cpu_count()
)
def test_dataloader(self):
return DataLoader(
self.test_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=2
)
N_EPOCHS = 250
BATCH_SIZE = 64
data_module = GaitDataModule(train_sequences, test_sequences, BATCH_SIZE)
class SequenceModel(nn.Module):
def __init__(self, n_features, n_classes, n_hidden=256, n_layers=3):
super().__init__()
self.lstm = nn.LSTM(
input_size=n_features,
hidden_size=n_hidden,
num_layers=n_layers,
batch_first=True,
dropout=0.5
)
self.classifier = nn.Linear(n_hidden, n_classes)
def forward(self, x):
self.lstm.flatten_parameters()
_, (hidden, _) = self.lstm(x)
out = hidden[-1]
return self.classifier(out)
class GaitPredictor(pl.LightningModule):
def __init__(self, n_features: int, n_classes: int):
super().__init__()
self.model = SequenceModel(n_features, n_classes)
self.criterion = nn.CrossEntropyLoss()
def forward(self, x, labels=None):
output = self.model(x)
loss = 0
if labels is not None:
loss = self.criterion(output, labels)
return loss, output
def training_step(self, batch, batch_idx):
sequences = batch['sequence']
labels = batch['label']
loss, outputs = self(sequences, labels)
predictions = torch.argmax(outputs, dim=1)
step_accuracy = accuracy(predictions, labels)
self.log('train_loss', loss, prog_bar=True, logger=True)
self.log('train_accuracy', step_accuracy, prog_bar=True, logger=True)
return {'loss': loss, 'accuracy': step_accuracy}
def validation_step(self, batch, batch_idx):
sequences = batch['sequence']
labels = batch['label']
loss, outputs = self(sequences, labels)
predictions = torch.argmax(outputs, dim=1)
step_accuracy = accuracy(predictions, labels)
self.log('val_loss', loss, prog_bar=True, logger=True)
self.log('val_accuracy', step_accuracy, prog_bar=True, logger=True)
return {'loss': loss, 'accuracy': step_accuracy}
def test_step(self, batch, batch_idx):
sequences = batch['sequence']
labels = batch['label']
loss, outputs = self(sequences, labels)
predictions = torch.argmax(outputs, dim=1)
step_accuracy = accuracy(predictions, labels)
self.log('test_loss', loss, prog_bar=True, logger=True)
self.log('test_accuracy', step_accuracy, prog_bar=True, logger=True)
return {'loss': loss, 'accuracy': step_accuracy}
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=0.0001)
model = GaitPredictor(
n_features=len(FeatureColumns),
n_classes=len(label_encoder.classes_)
)
checkpoint_callback = ModelCheckpoint(
dirpath='checkpoints',
filename='best-checkpoint',
save_top_k=1,
verbose=True,
monitor='val_loss',
mode='min'
)
logger = TensorBoardLogger('lightning_logs', name='Gait')
trainer = pl.Trainer(
checkpoint_callback=checkpoint_callback,
logger=logger,
max_epochs=N_EPOCHS,
gpus=1,
progress_bar_refresh_rate=30
)
trainer.fit(model, data_module)
How to fix this error? I am using GaitDataset class which describes gait sequences for training, but it seems that Python is unable to import the class properly.
I followed this tutorial from YouTube: https://youtu.be/PCgrgHgy26c
Environment:
- python 3.9.7
- pytorch 1.11.0
- pytorch-lightning 1.5.10
Conda environment configured in Pycharm