I'm training a Linear Regression Neural Network on embeddings of protein data. The problem I'm facing is that for the training and validations loss scores I'm getting decent results but once I try the testing dataset (which I don't have access to) I get results which are a lot worse.
Here's what I'm doing:
wt_emb = torch.load("train/train_wt.pt")
mut_emb = torch.load("train/train_mut.pt")
df = pd.read_csv("train/train.csv")
from sklearn.model_selection import train_test_split
# Creating a combined dataset
combined_data = list(zip(wt_emb, mut_emb, df['ddg' if 'ddg' in df else 'ID']))
# Splitting the data
train_data, val_data = train_test_split(combined_data, test_size=0.2) # 80% training, 20% validation
# Separating the embeddings and the targets
wt_emb_train, mut_emb_train, df_train = zip(*train_data)
wt_emb_val, mut_emb_val, df_val = zip(*val_data)
# Converting back to the original types
wt_emb_train = torch.stack(wt_emb_train)
mut_emb_train = torch.stack(mut_emb_train)
df_train = pd.DataFrame(df_train, columns=['ddg'])
wt_emb_val = torch.stack(wt_emb_val)
mut_emb_val = torch.stack(mut_emb_val)
df_val = pd.DataFrame(df_val, columns=['ddg'])
# Building the dataset class
class EmbeddingDataset(torch.utils.data.Dataset):
def __init__(self,mut_pt, wt_pt, data_df):
self.pt_mut = mut_pt
self.pt_wt = wt_pt
self.df = data_df
def __len__(self):
return self.pt_mut.shape[0]
def __getitem__(self, index):
o1=self.pt_mut[index,:]
o2=self.pt_wt[index,:]
if "ddg" in self.df:
df_out=torch.Tensor([self.df["ddg"][index]])
else:
df_out=torch.Tensor([self.df["ID"][index]])
return self.pt_mut[index,:],self.pt_wt[index,:],df_out
# creating training dataset and dataloader
train_dataset = EmbeddingDataset(wt_emb_train, mut_emb_train, df_train)
val_dataset = EmbeddingDataset(wt_emb_val, mut_emb_val, df_val)
# preparing a dataloader for the training
train_dataloader = torch.utils.data.dataloader.DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
num_workers=2,
)
# preparing a dataloader for the validation
val_dataloader = torch.utils.data.dataloader.DataLoader(
val_dataset,
batch_size=32,
shuffle=False,
num_workers=2,
)
import torch.nn.functional as F
HIDDEN_UNITS_POS_CONTACT = 32
DROPOUT_PROB = 0.5
class StabilityModel(torch.nn.Module):
def __init__(self):
super(StabilityModel, self).__init__()
self.fc1 = torch.nn.Linear(1280 * 2, HIDDEN_UNITS_POS_CONTACT)
self.dropout1 = torch.nn.Dropout(DROPOUT_PROB)
self.fc2 = torch.nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)
def forward(self, x, y):
outputs_pos_concat = torch.cat((x, y), 1)
fc1_outputs = F.relu(self.fc1(outputs_pos_concat))
fc1_outputs = self.dropout1(fc1_outputs)
logits = self.fc2(fc1_outputs)
return logits
# Example of training script
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = StabilityModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Define a learning rate scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = torch.nn.MSELoss()
num_epochs = 10
best_loss = float('inf')
for epoch in range(num_epochs):
epoch_loss = 0
model.train()
for batch_idx, (data_mut, data_wt, target) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}"):
data_mut = data_mut.to(device)
data_wt = data_wt.to(device)
target = target.to(device)
optimizer.zero_grad()
output = model(data_mut, data_wt)
loss = torch.sqrt(criterion(output, target))
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_loss /= len(train_dataloader)
print(f"Epoch {epoch+1} loss: {epoch_loss}")
# Save the model if it has the best validation loss so far
if epoch_loss < best_loss:
best_loss = epoch_loss
torch.save(model.state_dict(), "best_model.pt")
print("Training completed.")
# Load the best model for prediction
best_model = StabilityModel().to(device)
best_model.load_state_dict(torch.load("best_model.pt"))
best_model.eval()
# Perform prediction using the best model on the validation dataset
predictions = []
targets = []
with torch.no_grad():
for batch_idx, (data_mut, data_wt, target) in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc="Validation"):
data_mut = data_mut.to(device)
data_wt = data_wt.to(device)
output = best_model(data_mut, data_wt)
predictions.extend(output.cpu().numpy())
targets.extend(target.numpy())
print("Prediction completed.")
import numpy as np
from sklearn.metrics import mean_squared_error
# Convert predictions and targets to NumPy arrays
predictions = np.array(predictions)
targets = np.array(targets)
# Calculate root mean squared error (RMSE)
rmse = np.sqrt(mean_squared_error(targets, predictions))
print("Validation RMSE:", rmse)
# You can perform further analysis or evaluation metrics on the predictions and targets
# For example, you can calculate the mean absolute error (MAE)
mae = np.mean(np.abs(targets - predictions))
print("Validation MAE:", mae)
# load embedding tensors & traing csv
wt_test_emb = torch.load("test/test_wt.pt")
mut_test_emb = torch.load("test/test_mut.pt")
df_test = pd.read_csv("test/test.csv")
# creating testing dataset and loading the embedding
test_dataset = EmbeddingDataset(wt_test_emb,mut_test_emb,df_test)
# preparing a dataloader for the testing
test_dataloader = torch.utils.data.dataloader.DataLoader(
test_dataset,
batch_size=32,
shuffle=False,
num_workers=2,
)
df_result = pd.DataFrame()
with torch.no_grad():
for batch_idx, (data_mut,data_wt , target) in tqdm(enumerate(test_dataloader)):
x1 = data_wt.to(device)
x2 = data_mut.to(device)
id = target.to(device)
# make prediction
y_pred = best_model(x1,x2)
df_result = pd.concat([df_result, pd.DataFrame({"ID":id.squeeze().cpu().numpy().astype(int) , "ddg" : y_pred.squeeze().cpu().numpy()})])
The problem I seem to be facing is that the model is overfitting the training data or otherwise I have no idea why I'm getting these bad results.
I would get something like RMSE: 0.85 for training and 0.86 for validation but I would get 1.70 or 1.66 for the testing set.
Any idea why this is happening ?
Additional point: while tweaking I somehow got 1.53 which was the best I got but I couldn't reproduce the results ever since.