1

I have the following function that I am struggling with:

def load_trained_bert(
    num_classes: int, path_to_model: Union[Path, str]
) -> Tuple[BertForSequenceClassification, device]:
    """Load a fine-tuned BERT classifier from a folder and prepare it for inference.

    Parameters
    ----------
    num_classes: int
        Forwarded to ``from_pretrained`` as the ``num_classes`` keyword
    path_to_model: Union[Path, str]
        Folder handed to ``from_pretrained``; the fine-tuned weight file is
        looked up in the same folder

    Returns
    -------
      Tuple[BertForSequenceClassification, device]
        BERT model in evaluation mode and the device it was moved to

    """

    # Prefer the GPU when one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the model from the files in the folder.
    # NOTE(review): Hugging Face documents this keyword as `num_labels`;
    # confirm that `num_classes` is accepted by the installed version.
    bert_options = {
        "num_classes": num_classes,
        "output_attentions": False,
        "output_hidden_states": False,
    }
    model = BertForSequenceClassification.from_pretrained(path_to_model, **bert_options)

    # Locate the fine-tuned weight file inside the same folder
    weights_path = os.path.join(path_to_model, get_weight_file(path_to_model))

    # Deserialize onto the CPU first so loading works without a GPU;
    # the model is moved to the selected device afterwards
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    model.to(device)
    # Inference mode
    model.eval()

    return model, device

In general I am not sure how to test this function, but I thought it would be a good idea to at least check the parameters that `from_pretrained` is called with:

class LoadModelTest(TestCase):
    """Checks the arguments ``load_trained_bert`` forwards to ``from_pretrained``."""

    @patch("abox.util.model_conversion.get_weight_file", return_value="test.model")
    def test_load_trained_bert(self, get_weight_file):
        """Call the loader on "./model" and verify the ``from_pretrained`` call."""
        # Replace the real from_pretrained with a Mock so the call can be inspected.
        # NOTE: assigned directly on the class, so it is not restored afterwards.
        fake_from_pretrained = Mock()
        BertForSequenceClassification.from_pretrained = fake_from_pretrained

        # NOTE(review): torch.load is not replaced here, so the loader presumably
        # still tries to read "test.model" from "./model" on disk.
        load_trained_bert(path_to_model="./model", num_classes=16)

        expected_kwargs = {
            "num_classes": 16,
            "output_attentions": False,
            "output_hidden_states": False,
        }
        fake_from_pretrained.assert_called_with("./model", **expected_kwargs)

This results in the following error:

FileNotFoundError: [Errno 2] No such file or directory: './model\\test.model'

Now it's getting difficult... I have no idea what to do with the following snippet:

# Reads the weight file at path_to_weights onto the CPU and copies it into the model.
# NOTE(review): presumably the call behind the FileNotFoundError, since torch.load opens the real file.
model.load_state_dict(torch.load(path_to_weights, map_location=torch.device("cpu")))

Can anyone help me here?

Data Mastery
  • 1,555
  • 4
  • 18
  • 60

0 Answers0