I am using a pre-trained BERT model to classify ASR-generated transcript segments and am currently using Optuna to identify the optimal hyperparameters. I wish to modify this code to use cross-validation whilst using Optuna to find the best hyperparameters and evaluate the "best" BERT model on a hold-out test dataset. I need to ensure that all segments belonging to a particular Partcipant_ID are grouped together in the same fold, to prevent data leakage. I am not entirely sure how to modify the code below to achieve this. Can anyone advise how to go about this?
class Label(Enum):
PC = 0
BT = 1
class Dataset(torch.utils.data.Dataset):
def __init__(self, df, tokenizer):
self.participantIDs = df['Participant_ID']
self.labels = [Label[label] for label in df['Diagnosis']]
self.texts = [tokenizer.encode_plus(text, padding='max_length', max_length=64, truncation=True, return_tensors="pt") for text in df['Segment_Words']]
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
batch_texts = self.texts[idx]
batch_y = self.labels[idx].value
return batch_texts['input_ids'].squeeze(), batch_texts['attention_mask'].squeeze(), batch_y
class BertClassifier(nn.Module):
def __init__(self, dropout=0.1):
super(BertClassifier, self).__init__()
self.bert = BertModel.from_pretrained('bert-base-cased')
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(self.bert.config.hidden_size, 2) # Assuming 2 classes
def forward(self, input_ids=None, attention_mask=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
pooled_output = outputs.pooler_output
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
return logits
def train_epoch(model, dataloader, criterion, optimizer, device, mode):
total_loss = 0
total_correct = 0
total_samples = 0
model.train() if mode == 'train' else model.eval()
with torch.set_grad_enabled(mode == 'train'):
for input_ids, attention_mask, labels in tqdm(dataloader):
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(input_ids, attention_mask)
loss = criterion(outputs, labels)
if mode == 'train':
loss.backward()
optimizer.step()
total_loss += loss.item()
predictions = outputs.argmax(dim=1)
total_correct += (predictions == labels).sum().item()
total_samples += len(labels)
avg_loss = total_loss / len(dataloader)
avg_accuracy = total_correct / total_samples
return avg_loss, avg_accuracy
def train(model, train_data, val_data, tokenizer, learning_rate, epochs, batch_size):
train_dataset = Dataset(train_data, tokenizer)
val_dataset = Dataset(val_data, tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)
if use_cuda:
model = model.to(device)
criterion = criterion.to(device)
for epoch_num in range(epochs):
train_loss, train_accuracy = train_epoch(model, train_dataloader, criterion, optimizer, device, 'train')
val_loss, val_accuracy = train_epoch(model, val_dataloader, criterion, optimizer, device, 'val')
print(f'Epochs: {epoch_num + 1} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}')
print(f'Epochs: {epoch_num + 1} | Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}')
def evaluate(model, test_data, tokenizer, return_misclassified=False):
test_dataset = Dataset(test_data, tokenizer)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
misclassified = [] if return_misclassified else None
total_correct = 0
total_samples = 0
model.eval() # Set model to evaluation mode
with torch.no_grad():
for batch_index, batch in enumerate(test_dataloader):
input_ids, attention_mask, test_labels = [item.to(device) for item in batch]
output = model(input_ids, attention_mask)
predictions = output.argmax(dim=1)
correct = (predictions == test_labels).sum().item()
total_correct += correct
total_samples += len(test_labels)
if return_misclassified:
incorrect_indices = (predictions != test_labels).nonzero().flatten()
for idx_in_batch in incorrect_indices:
global_idx = batch_index * test_dataloader.batch_size + idx_in_batch.item()
participant_id = test_data.iloc[global_idx]['Participant_ID']
segment_number = idx_in_batch.item() + 1
misclassified.append(f"{participant_id}_{segment_number}")
accuracy = total_correct / total_samples
print(f'Test Accuracy: {accuracy:.4f}')
if return_misclassified:
accuracy_with_misclassified = 1 - (len(misclassified) / len(test_data))
return misclassified, accuracy_with_misclassified
else:
return accuracy
def objective(trial, train_data, val_data, test_data, tokenizer):
learning_rate = trial.suggest_float('learning_rate', 1e-6, 1e-3, log=True)
dropout = trial.suggest_float('dropout', 0.1, 0.9)
batch_size = trial.suggest_int('batch_size', 4, 64)
epochs = trial.suggest_int('epochs', 5, 30)
model = BertClassifier(dropout=dropout)
train(model, train_data, val_data, tokenizer, learning_rate, epochs, batch_size)
return evaluate(model, test_data, tokenizer)
def perform_group_shuffle_split(data_frame, test_size, val_size):
gss = GroupShuffleSplit(n_splits=1, test_size=test_size)
idx1, idx2 = next(gss.split(data_frame, groups=data_frame.Participant_ID))
train_data_df, test_data_df = data_frame.iloc[idx1], data_frame.iloc[idx2]
gss2 = GroupShuffleSplit(n_splits=1, test_size=val_size / (1 - test_size))
idx3, idx4 = next(gss2.split(train_data_df, groups=train_data_df.Participant_ID))
train_data_df, val_data_df = train_data_df.iloc[idx3], train_data_df.iloc[idx4]
return train_data_df, val_data_df, test_data_df
def generate_confusion_matrix(model, test_dataloader, device):
true_labels = np.array([])
predicted_labels = np.array([])
model.eval()
with torch.no_grad():
for batch in test_dataloader:
input_ids, attention_mask, test_labels = [item.to(device) for item in batch]
outputs = model(input_ids, attention_mask)
predictions = outputs.argmax(dim=1)
true_labels = np.concatenate((true_labels, test_labels.cpu().numpy()))
predicted_labels = np.concatenate((predicted_labels, predictions.cpu().numpy()))
conf_matrix = confusion_matrix(true_labels, predicted_labels)
return conf_matrix
def plot_confusion_matrix(conf_matrix, class_labels):
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
def main():
original_stdout = sys.stdout
with open('/usr/not-backed-up/el18saj/data/trial_output3.txt', 'w') as f:
sys.stdout = f
transcript_data = '/usr/not-backed-up/el18saj/data/Transcripts/WhisperX_Transcripts/SegmentLevelTranscripts/NoPauses/3_segments/Initial_Prompt_True/Zoom/WhisperX_InitialPromptTrue_ZoomID_3_seconds.csv'
print(transcript_data)
print("Learning_Rate_Range: 1e-6 to 1e-3, Dropout_Range: 0.1-0.9, Batch_Size_Range: 4-64, Batch_Size: 5-30")
transcript_df = pd.read_csv(transcript_data)
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
df_train, df_val, df_test = perform_group_shuffle_split(transcript_df, test_size=0.2, val_size=0.125)
print("Participant IDs in training dataset:", ', '.join(map(str, df_train['Participant_ID'].unique())))
print("Participant IDs in validation dataset:", ', '.join(map(str, df_val['Participant_ID'].unique())))
print("Participant IDs in test dataset:", ', '.join(map(str, df_test['Participant_ID'].unique())))
study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(lambda trial: objective(trial, df_train, df_val, df_test, tokenizer), n_trials=50)
# Print results for each trial
for trial_number, trial in enumerate(study.trials, start=1):
trial_params = trial.params
trial_value = trial.value
print(f"Trial {trial_number} finished with value: {trial_value} and parameters: {trial_params}")
# Print the parameters for the best trial
print("Parameters for the Best Trial:")
for key, value in study.best_trial.params.items():
print(key, value)
print("Best Objective Value:", study.best_value)
# Train a model using the best trial's parameters
best_params = study.best_params
best_model = BertClassifier(dropout=best_params['dropout'])
train(best_model, df_train, df_val, tokenizer, best_params['learning_rate'], best_params['epochs'], best_params['batch_size'])
# Save the trained model
model_filename = '/usr/not-backed-up/el18saj/Code/SavedModels/BERT_model_1.pth' # Change this path to where you want to save the model
torch.save(best_model.state_dict(), model_filename)
# Evaluate the best model on the test dataset
test_dataset = Dataset(df_test, tokenizer)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=8)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
best_model = best_model.to(device)
misclassified_ids, test_accuracy = evaluate(best_model, df_test, tokenizer, return_misclassified=True)
print(f"Test Accuracy using Best Model: {test_accuracy:.4f}")
print(misclassified_ids)
# Generate confusion matrix
conf_matrix = generate_confusion_matrix(best_model, test_dataloader, device)
# Print the confusion matrix
print("Confusion Matrix:")
for row in conf_matrix:
print(' '.join([str(value) for value in row]))
sys.stdout = original_stdout
if __name__ == "__main__":
main()