
I'm training a neural network for linear regression on embeddings of protein data. The problem I'm facing is that I get decent training and validation loss scores, but once I run the model on the testing dataset (whose labels I don't have access to) the results are a lot worse.

Here's what I'm doing:

import torch
import pandas as pd
from tqdm import tqdm

# load embedding tensors & training csv
wt_emb = torch.load("train/train_wt.pt")
mut_emb = torch.load("train/train_mut.pt")
df = pd.read_csv("train/train.csv")

from sklearn.model_selection import train_test_split

# Creating a combined dataset
combined_data = list(zip(wt_emb, mut_emb, df['ddg' if 'ddg' in df else 'ID']))

# Splitting the data
train_data, val_data = train_test_split(combined_data, test_size=0.2)  # 80% training, 20% validation

# Separating the embeddings and the targets
wt_emb_train, mut_emb_train, df_train = zip(*train_data)
wt_emb_val, mut_emb_val, df_val = zip(*val_data)

# Converting back to the original types
wt_emb_train = torch.stack(wt_emb_train)
mut_emb_train = torch.stack(mut_emb_train)
df_train = pd.DataFrame(df_train, columns=['ddg'])

wt_emb_val = torch.stack(wt_emb_val)
mut_emb_val = torch.stack(mut_emb_val)
df_val = pd.DataFrame(df_val, columns=['ddg'])
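
Side note: I don't pass a random_state to train_test_split, so every run trains and validates on a different 80/20 split. A minimal sketch of how I could pin the seeds for reproducible runs (same variable names as above; the seed value is arbitrary):

import random
import numpy as np

SEED = 42  # arbitrary

# seed every RNG the pipeline touches
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# seeded split, so train/val are identical across runs
train_data, val_data = train_test_split(combined_data, test_size=0.2, random_state=SEED)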

# Building the dataset class
class EmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, mut_pt, wt_pt, data_df):
        self.pt_mut = mut_pt
        self.pt_wt = wt_pt
        self.df = data_df

    def __len__(self):
        return self.pt_mut.shape[0]

    def __getitem__(self, index):
        # return the ddg target during training; fall back to the row ID at test time
        if "ddg" in self.df:
            df_out = torch.Tensor([self.df["ddg"][index]])
        else:
            df_out = torch.Tensor([self.df["ID"][index]])
        return self.pt_mut[index, :], self.pt_wt[index, :], df_out

# creating training dataset and dataloader
train_dataset = EmbeddingDataset(wt_emb_train, mut_emb_train, df_train)
val_dataset = EmbeddingDataset(wt_emb_val, mut_emb_val, df_val)
# preparing a dataloader for the training
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
# preparing a dataloader for the validation
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)
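
As a quick sanity check (not part of the pipeline), I grab one batch and confirm the shapes line up with the 1280-dimensional embeddings the model below expects:

data_mut, data_wt, target = next(iter(train_dataloader))
print(data_mut.shape, data_wt.shape, target.shape)  # expected: (32, 1280), (32, 1280), (32, 1)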


import torch.nn.functional as F

HIDDEN_UNITS_POS_CONTACT = 32
DROPOUT_PROB = 0.5

class StabilityModel(torch.nn.Module):
    def __init__(self):
        super(StabilityModel, self).__init__()
        self.fc1 = torch.nn.Linear(1280 * 2, HIDDEN_UNITS_POS_CONTACT)
        self.dropout1 = torch.nn.Dropout(DROPOUT_PROB)
        self.fc2 = torch.nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def forward(self, x, y):
        # concatenate the two embeddings along the feature dimension
        combined = torch.cat((x, y), 1)
        hidden = F.relu(self.fc1(combined))
        hidden = self.dropout1(hidden)
        return self.fc2(hidden)  # single regression output (predicted ddg)
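
For context, the model is quite small. A one-liner I use to confirm the parameter count (nothing assumed beyond the class above):

# (2560*32 + 32) weights/biases in fc1 plus (32 + 1) in fc2 = 81,985 trainable parameters
n_params = sum(p.numel() for p in StabilityModel().parameters() if p.requires_grad)
print(n_params)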

# Example of training script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = StabilityModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Define a learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = torch.nn.MSELoss()
num_epochs = 10
best_loss = float('inf')
for epoch in range(num_epochs):
    epoch_loss = 0
    model.train()
    for batch_idx, (data_mut, data_wt, target) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}"):
        data_mut = data_mut.to(device)
        data_wt = data_wt.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data_mut, data_wt)
        loss = torch.sqrt(criterion(output, target))  # RMSE on the batch
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(train_dataloader)
    scheduler.step()  # advance the LR schedule once per epoch
    print(f"Epoch {epoch+1} loss: {epoch_loss}")

    # Save the model if it has the best training loss so far
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), "best_model.pt")
print("Training completed.")

# Load the best model for prediction
best_model = StabilityModel().to(device)
best_model.load_state_dict(torch.load("best_model.pt"))
best_model.eval()

# Perform prediction using the best model on the validation dataset
predictions = []
targets = []
with torch.no_grad():
    for batch_idx, (data_mut, data_wt, target) in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc="Validation"):
        data_mut = data_mut.to(device)
        data_wt = data_wt.to(device)

        output = best_model(data_mut, data_wt)
        predictions.extend(output.cpu().numpy())
        targets.extend(target.numpy())

print("Prediction completed.")

import numpy as np
from sklearn.metrics import mean_squared_error

# Convert predictions and targets to NumPy arrays
predictions = np.array(predictions)
targets = np.array(targets)

# Calculate root mean squared error (RMSE)
rmse = np.sqrt(mean_squared_error(targets, predictions))
print("Validation RMSE:", rmse)

# You can perform further analysis or evaluation metrics on the predictions and targets
# For example, you can calculate the mean absolute error (MAE)
mae = np.mean(np.abs(targets - predictions))
print("Validation MAE:", mae)

# load embedding tensors & test csv
wt_test_emb = torch.load("test/test_wt.pt")
mut_test_emb = torch.load("test/test_mut.pt")
df_test = pd.read_csv("test/test.csv")

# creating testing dataset and loading the embedding
test_dataset = EmbeddingDataset(wt_test_emb, mut_test_emb, df_test)
# preparing a dataloader for the testing
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)

df_result = pd.DataFrame()
with torch.no_grad():
    for batch_idx, (data_mut, data_wt, target) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc="Testing"):
        x1 = data_wt.to(device)
        x2 = data_mut.to(device)
        ids = target  # at test time the dataset yields row IDs instead of targets
        # make prediction
        y_pred = best_model(x1, x2)
        df_result = pd.concat([df_result, pd.DataFrame({"ID": ids.squeeze().cpu().numpy().astype(int), "ddg": y_pred.squeeze().cpu().numpy()})])
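
Finally, I write the predictions out for submission (the file name is my own choice; the column layout matches the DataFrame built above):

# write predictions to a submission file
df_result.reset_index(drop=True).to_csv("submission.csv", index=False)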

My best guess is that the model is overfitting the training data, but otherwise I have no idea why I'm getting these bad results.

I get an RMSE of roughly 0.85 for training and 0.86 for validation, but 1.70 or 1.66 on the testing set.

Any idea why this is happening?

Additional point: while tweaking, I once got 1.53, which is the best score I've managed, but I haven't been able to reproduce that result since.
