2

I'm using PyTorch Lightning and PyTorch for LSTM classification. Whenever I train the model, this error shows:

File "<string>", line 1, in <module>
  File "C:\Users\hisha\anaconda3\envs\FYP\lib\multiprocessing\spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "C:\Users\hisha\anaconda3\envs\FYP\lib\multiprocessing\spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'GaitDataset' on <module '__main__' (built-in)>

I'm implementing time series classification using sequences. The original model and data module code is below:

class GaitDataset(Dataset):
    """Map-style dataset over pre-built ``(sequence, label)`` pairs.

    Each stored entry is unpacked as ``(sequence, label)`` on access; the
    sequence object must provide ``to_numpy()``.
    """

    def __init__(self, sequences):
        # Stored as given; nothing is converted until an item is requested.
        self.sequences = sequences

    def __len__(self):
        """Return how many ``(sequence, label)`` pairs are stored."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with ``sequence`` and ``label`` tensors."""
        features, target = self.sequences[idx]
        # torch.Tensor(...) builds a tensor of the default floating dtype.
        sequence_tensor = torch.Tensor(features.to_numpy())
        # Labels are cast to int64.
        label_tensor = torch.tensor(target).long()
        return {'sequence': sequence_tensor, 'label': label_tensor}

class GaitDataModule(pl.LightningDataModule):
    """Lightning data module that serves gait sequences through DataLoaders.

    Args:
        train_sequences: Collection passed to ``GaitDataset`` for training.
        test_sequences: Collection passed to ``GaitDataset``; it backs both
            the validation loader and the test loader.
        batch_size: Batch size used by every loader.
        num_workers: Number of DataLoader worker processes for every loader.
            ``None`` (default) keeps the previous behaviour: ``cpu_count()``
            workers for train/validation and 2 for test. Pass ``0`` to load
            batches in the main process; no worker process is started, so
            the dataset never has to be pickled and re-imported by a child
            (the step that fails with "Can't get attribute 'GaitDataset'"
            when the class only exists in an interactive ``__main__``).
    """

    def __init__(self, train_sequences, test_sequences, batch_size,
                 num_workers=None):
        super().__init__()
        self.train_sequences = train_sequences
        self.test_sequences = test_sequences
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Wrap the raw sequence collections in ``GaitDataset`` objects."""
        self.train_dataset = GaitDataset(self.train_sequences)
        self.test_dataset = GaitDataset(self.test_sequences)

    def _make_loader(self, dataset, shuffle, default_workers):
        """Build a DataLoader, honouring an explicit ``num_workers`` override."""
        if self.num_workers is None:
            workers = default_workers
        else:
            workers = self.num_workers
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=workers
        )

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, True, cpu_count())

    def val_dataloader(self):
        """Return the validation loader."""
        # NOTE(review): validation reads the *test* dataset, so any model
        # selection on val_loss sees the test data — confirm this is intended.
        return self._make_loader(self.test_dataset, False, cpu_count())

    def test_dataloader(self):
        """Return the test loader."""
        return self._make_loader(self.test_dataset, False, 2)


# Training hyper-parameters.
N_EPOCHS = 250   # passed to the Trainer as max_epochs
BATCH_SIZE = 64  # batch size used by every DataLoader of the data module

# Data module built from the prepared train/test sequence collections.
data_module = GaitDataModule(
    train_sequences=train_sequences,
    test_sequences=test_sequences,
    batch_size=BATCH_SIZE,
)


class SequenceModel(nn.Module):
    """Stacked LSTM followed by a linear classification head.

    Args:
        n_features: Size of each time step's feature vector (LSTM input size).
        n_classes: Number of output logits.
        n_hidden: Hidden size of every LSTM layer.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between LSTM layers. It is
            ignored (set to 0) when ``n_layers`` is 1, because ``nn.LSTM``
            only applies dropout between layers and warns otherwise.
    """

    def __init__(self, n_features, n_classes, n_hidden=256, n_layers=3,
                 dropout=0.5):
        super().__init__()

        # A single-layer LSTM has no "between layers" position for dropout.
        lstm_dropout = dropout if n_layers > 1 else 0.0

        self.lstm = nn.LSTM(
            input_size=n_features,
            hidden_size=n_hidden,
            num_layers=n_layers,
            batch_first=True,
            dropout=lstm_dropout
        )
        self.classifier = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        """Return class logits for a batch-first input of sequences.

        ``x`` is fed to an LSTM built with ``batch_first=True``, so its
        layout is ``(batch, time, n_features)``.
        """
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(x)

        # hidden[-1] is the final hidden state of the top LSTM layer.
        out = hidden[-1]
        return self.classifier(out)


class GaitPredictor(pl.LightningModule):
    """LightningModule that trains ``SequenceModel`` with cross-entropy loss.

    Args:
        n_features: Feature count forwarded to ``SequenceModel``.
        n_classes: Class count forwarded to ``SequenceModel``.
    """

    def __init__(self, n_features: int, n_classes: int):
        super().__init__()
        self.model = SequenceModel(n_features, n_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, labels=None):
        """Return ``(loss, output)``; ``loss`` is 0 when no labels are given."""
        output = self.model(x)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def _shared_step(self, batch, stage):
        """Run one batch, log ``<stage>_loss`` / ``<stage>_accuracy``, return both.

        ``batch`` is a dict with ``'sequence'`` and ``'label'`` entries;
        ``stage`` is the metric-name prefix (``'train'``, ``'val'`` or
        ``'test'``).
        """
        sequences = batch['sequence']
        labels = batch['label']
        loss, outputs = self(sequences, labels)
        # The predicted class is the index of the largest logit per sample.
        predictions = torch.argmax(outputs, dim=1)
        step_accuracy = accuracy(predictions, labels)

        self.log(f'{stage}_loss', loss, prog_bar=True, logger=True)
        self.log(f'{stage}_accuracy', step_accuracy, prog_bar=True, logger=True)

        return {'loss': loss, 'accuracy': step_accuracy}

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy for one batch."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and accuracy for one batch."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss and accuracy for one batch."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=0.0001)."""
        return optim.Adam(self.parameters(), lr=0.0001)


# Keep model construction and training behind the main guard. The DataLoaders
# start worker processes; with the "spawn" start method (the only one on
# Windows) each worker re-imports this module, and unguarded top-level code
# would run again inside every worker.
# NOTE: the guard cannot help when the code is executed from an interactive
# console, where ``__main__`` is not an importable file. In that case define
# GaitDataset in a module and import it, or use num_workers=0.
if __name__ == '__main__':
    model = GaitPredictor(
        n_features=len(FeatureColumns),
        n_classes=len(label_encoder.classes_)
    )

    # Keep only the checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath='checkpoints',
        filename='best-checkpoint',
        save_top_k=1,
        verbose=True,
        monitor='val_loss',
        mode='min'
    )

    logger = TensorBoardLogger('lightning_logs', name='Gait')

    trainer = pl.Trainer(
        # Callback instances belong in `callbacks`; `checkpoint_callback`
        # is a boolean switch.
        callbacks=[checkpoint_callback],
        logger=logger,
        max_epochs=N_EPOCHS,
        gpus=1,
        progress_bar_refresh_rate=30
    )

    trainer.fit(model, data_module)

How can I fix this error? I am using the GaitDataset class, which wraps the gait sequences used for training, but it seems that Python is unable to import the class properly.

I followed this tutorial from YouTube: https://youtu.be/PCgrgHgy26c

Environment:

  • python 3.9.7
  • pytorch 1.11.0
  • pytorch-lightning 1.5.10

Conda environment configured in PyCharm

  • Please trim your code to make it easier to find your problem. Follow these guidelines to create a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). – itprorh66 May 23 '22 at 15:16

1 Answer

1

This happens when the data-loading worker processes are started: each spawned process has to unpickle the dataset, and it cannot resolve 'GaitDataset' on the `__main__` module. I guess this is actually a bug in Lightning, since it should use the fully qualified name to resolve it.

Funnily enough, you can import it in your main file and it will be found. Add:

from your.package.here import GaitDataset

It gets funny if you have other wrapper logic happening.

Edit: See https://stackoverflow.com/a/68279928/1615430 for the cause of this error

RookieGuy
  • 517
  • 7
  • 18