0

I am using a pre-trained BERT model to classify ASR-generated transcript segments and am currently using Optuna to identify the optimal hyperparameters. I wish to modify this code to use cross-validation whilst using Optuna to find the best hyperparameters, and then evaluate the "best" BERT model on a hold-out test dataset. I need to ensure that all segments belonging to a particular Participant_ID are grouped together in the same fold, to prevent data leakage. I am not entirely sure how to modify the code below to achieve this. Can anyone advise how to go about this?

# Diagnosis classes. Each member's integer value is the class index the
# classifier is trained to predict (Dataset yields ``label.value`` as the target).
Label = Enum('Label', [('PC', 0), ('BT', 1)])

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of pre-tokenised transcript segments and their labels.

    Expects ``df`` to provide ``Participant_ID``, ``Diagnosis`` (a ``Label``
    member name) and ``Segment_Words`` columns. Every segment is encoded once,
    up front, padded/truncated to 64 tokens.
    """

    def __init__(self, df, tokenizer):
        encode_options = dict(padding='max_length', max_length=64, truncation=True, return_tensors="pt")
        self.participantIDs = df['Participant_ID']
        self.labels = [Label[diagnosis] for diagnosis in df['Diagnosis']]
        self.texts = [tokenizer.encode_plus(segment, **encode_options) for segment in df['Segment_Words']]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, class_index)`` for segment ``idx``."""
        encoding = self.texts[idx]
        # The tokenizer returns tensors with a leading batch dimension of 1;
        # squeeze it so the DataLoader can stack samples into a batch.
        input_ids = encoding['input_ids'].squeeze()
        attention_mask = encoding['attention_mask'].squeeze()
        return input_ids, attention_mask, self.labels[idx].value

class BertClassifier(nn.Module):
    """Pre-trained BERT encoder topped with dropout and a linear classification head.

    Args:
        dropout: dropout probability applied to BERT's ``pooler_output``.
        num_classes: number of output classes (default 2).
        pretrained_model_name: checkpoint passed to ``BertModel.from_pretrained``;
            it should match the tokenizer used to build the inputs.
    """

    def __init__(self, dropout=0.1, num_classes=2, pretrained_model_name='bert-base-cased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids=None, attention_mask=None):
        """Return unnormalised class scores (one column per class) for a batch."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        return self.classifier(pooled_output)

def train_epoch(model, dataloader, criterion, optimizer, device, mode):
    """Run one pass over ``dataloader``, training the model or only evaluating it.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``.
        dataloader: yields ``(input_ids, attention_mask, labels)`` batches.
        criterion: loss applied to ``(logits, labels)``; assumed to use mean
            reduction, as ``nn.CrossEntropyLoss`` does by default.
        optimizer: stepped after every batch when ``mode == 'train'``; unused
            otherwise.
        device: device each batch tensor is moved to.
        mode: ``'train'`` to update weights; any other value puts the model in
            eval mode and disables gradients.

    Returns:
        ``(avg_loss, avg_accuracy)`` averaged over samples, or ``(nan, nan)``
        when the dataloader yields no samples.
    """
    is_training = mode == 'train'
    model.train(is_training)  # train(False) is exactly model.eval()

    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.set_grad_enabled(is_training):
        for input_ids, attention_mask, labels in tqdm(dataloader):
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # Convert the batch-mean loss back into a sum so that a smaller final
            # batch is not weighted the same as a full one.
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        return float('nan'), float('nan')
    return total_loss / total_samples, total_correct / total_samples

def train(model, train_data, val_data, tokenizer, learning_rate, epochs, batch_size):
    """Fine-tune ``model`` with Adam and cross-entropy loss, validating after every epoch.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``. It is
            moved in place to CUDA when available.
        train_data: DataFrame of training segments (reshuffled every epoch).
        val_data: DataFrame of validation segments.
        tokenizer: tokenizer used to build the ``Dataset`` objects.
        learning_rate: Adam learning rate.
        epochs: number of passes over ``train_data``.
        batch_size: batch size for both the training and validation loaders.

    Returns:
        ``(val_loss, val_accuracy)`` from the final epoch, or ``(None, None)``
        when ``epochs`` is 0.
    """
    train_dataloader = torch.utils.data.DataLoader(Dataset(train_data, tokenizer), batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(Dataset(val_data, tokenizer), batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model before constructing the optimizer, as PyTorch recommends,
    # so the optimizer is built over the parameters in their final location.
    model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    val_loss = val_accuracy = None
    for epoch_num in range(1, epochs + 1):
        train_loss, train_accuracy = train_epoch(model, train_dataloader, criterion, optimizer, device, 'train')
        val_loss, val_accuracy = train_epoch(model, val_dataloader, criterion, optimizer, device, 'val')

        print(f'Epochs: {epoch_num} | Train Loss: {train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}')
        print(f'Epochs: {epoch_num} | Val Loss: {val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}')

    return val_loss, val_accuracy

def evaluate(model, test_data, tokenizer, return_misclassified=False, batch_size=8):
    """Print and return ``model``'s accuracy on the segments in ``test_data``.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``.
        test_data: DataFrame with ``Participant_ID``, ``Diagnosis`` and
            ``Segment_Words`` columns (as consumed by ``Dataset``).
        tokenizer: tokenizer used to encode the segments.
        return_misclassified: when True, also collect an identifier for every
            misclassified segment.
        batch_size: evaluation batch size.

    Returns:
        ``accuracy`` (NaN when ``test_data`` is empty), or, when
        ``return_misclassified`` is True, ``(misclassified, accuracy)``. Each
        ``misclassified`` entry is ``"<Participant_ID>_<n>"``, where ``n`` is the
        1-based position of that row among its participant's rows in ``test_data``.
    """
    test_dataloader = torch.utils.data.DataLoader(Dataset(test_data, tokenizer), batch_size=batch_size)
    # Send inputs to wherever the model's weights live, so a model that was never
    # moved to the GPU still works on a CUDA-capable machine.
    device = next(model.parameters()).device

    misclassified = []
    if return_misclassified:
        participant_ids = test_data['Participant_ID'].tolist()
        # NOTE(review): this equals the true segment number only if each
        # participant's rows are stored in segment order; confirm against the CSV.
        segment_numbers = (test_data.groupby('Participant_ID', sort=False).cumcount() + 1).tolist()

    total_correct = 0
    # Rows consumed so far. The loader does not shuffle, so this maps a position
    # within the current batch back to a row position in test_data.
    seen = 0

    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids, attention_mask, labels = [item.to(device) for item in batch]
            predictions = model(input_ids, attention_mask).argmax(dim=1)
            total_correct += (predictions == labels).sum().item()

            if return_misclassified:
                for idx_in_batch in (predictions != labels).nonzero().flatten().tolist():
                    row = seen + idx_in_batch
                    misclassified.append(f"{participant_ids[row]}_{segment_numbers[row]}")
            seen += len(labels)

    accuracy = total_correct / seen if seen else float('nan')
    print(f'Test Accuracy: {accuracy:.4f}')

    if return_misclassified:
        return misclassified, accuracy
    return accuracy

def objective(trial, train_data, val_data, test_data, tokenizer, n_splits=None):
    """Optuna objective: train a ``BertClassifier`` with sampled hyperparameters and score it.

    The score is always a validation accuracy, never the test accuracy. Choosing
    hyperparameters by their test-set score leaks the test set into model
    selection, which makes the final hold-out evaluation optimistically biased.

    Args:
        trial: the Optuna trial used to sample hyperparameters.
        train_data: training DataFrame.
        val_data: validation DataFrame, participant-disjoint from ``train_data``.
        test_data: hold-out DataFrame. Accepted to keep the call signature, but
            deliberately NOT used here; evaluate on it once, after the search.
        tokenizer: tokenizer passed through to ``train`` and ``evaluate``.
        n_splits: if given (must be >= 2), ``train_data`` and ``val_data`` are
            pooled and scored with group k-fold cross-validation. All segments of
            a ``Participant_ID`` stay in the same fold, and the mean of the
            per-fold accuracies is returned. If ``None`` (default), the single
            train/validation split is used.

    Returns:
        Validation accuracy (mean over folds in cross-validation mode).
    """
    learning_rate = trial.suggest_float('learning_rate', 1e-6, 1e-3, log=True)
    dropout = trial.suggest_float('dropout', 0.1, 0.9)
    batch_size = trial.suggest_int('batch_size', 4, 64)
    epochs = trial.suggest_int('epochs', 5, 30)

    if not n_splits:
        model = BertClassifier(dropout=dropout)
        train(model, train_data, val_data, tokenizer, learning_rate, epochs, batch_size)
        return evaluate(model, val_data, tokenizer)

    from sklearn.model_selection import GroupKFold

    pooled = pd.concat([train_data, val_data])
    folds = GroupKFold(n_splits=n_splits).split(pooled, groups=pooled['Participant_ID'])
    fold_accuracies = []
    for fold_train_idx, fold_val_idx in folds:
        fold_train = pooled.iloc[fold_train_idx]
        fold_val = pooled.iloc[fold_val_idx]
        # A fresh model per fold, so no fold starts from weights already fitted
        # on another fold's validation participants.
        model = BertClassifier(dropout=dropout)
        train(model, fold_train, fold_val, tokenizer, learning_rate, epochs, batch_size)
        fold_accuracies.append(evaluate(model, fold_val, tokenizer))
    return sum(fold_accuracies) / len(fold_accuracies)

def perform_group_shuffle_split(data_frame, test_size, val_size, random_state=None):
    """Split ``data_frame`` into participant-disjoint train, validation and test sets.

    All rows that share a ``Participant_ID`` land in the same set, so no
    participant's segments leak between training and evaluation.

    Args:
        data_frame: DataFrame with a ``Participant_ID`` column.
        test_size: fraction of the data reserved for the test set.
        val_size: fraction for the validation set, expressed relative to the
            whole ``data_frame``; it is rescaled for the second split.
        random_state: seed (or RandomState) forwarded to both
            ``GroupShuffleSplit`` calls. Pass an int for reproducible splits;
            ``None`` gives a different split on every run.

    Note:
        ``GroupShuffleSplit`` treats float sizes as proportions of *groups*
        (participants), so the resulting row proportions are only approximate.

    Returns:
        ``(train_df, val_df, test_df)``.

    Raises:
        ValueError: if ``test_size + val_size`` is not below 1.
    """
    if test_size + val_size >= 1:
        raise ValueError(f"test_size + val_size must be < 1, got {test_size} + {val_size}")

    outer = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    dev_idx, test_idx = next(outer.split(data_frame, groups=data_frame['Participant_ID']))
    dev_df, test_df = data_frame.iloc[dev_idx], data_frame.iloc[test_idx]

    # After removing the test set, val_size must be rescaled to a fraction of what remains.
    inner = GroupShuffleSplit(n_splits=1, test_size=val_size / (1 - test_size), random_state=random_state)
    train_idx, val_idx = next(inner.split(dev_df, groups=dev_df['Participant_ID']))
    train_df, val_df = dev_df.iloc[train_idx], dev_df.iloc[val_idx]

    return train_df, val_df, test_df

def generate_confusion_matrix(model, test_dataloader, device, labels=None):
    """Run ``model`` over ``test_dataloader`` and return the confusion matrix of its predictions.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``. It is
            not moved here, so it should already be on ``device``.
        test_dataloader: yields ``(input_ids, attention_mask, labels)`` batches.
        device: device each batch tensor is moved to.
        labels: optional class indices forwarded to ``confusion_matrix``. Pass
            e.g. ``[0, 1]`` to keep a fixed 2x2 layout even when a class is
            absent from both the targets and the predictions.

    Returns:
        ``confusion_matrix(true, predicted, labels=labels)``: rows are true
        classes and columns are predicted classes.
    """
    true_batches = []
    predicted_batches = []

    model.eval()
    with torch.no_grad():
        for batch in test_dataloader:
            input_ids, attention_mask, batch_labels = [item.to(device) for item in batch]
            predictions = model(input_ids, attention_mask).argmax(dim=1)
            true_batches.append(batch_labels.cpu().numpy())
            predicted_batches.append(predictions.cpu().numpy())

    # Concatenate once at the end rather than on every batch, which was quadratic.
    # np.concatenate rejects an empty list, so fall back to empty arrays.
    true_labels = np.concatenate(true_batches) if true_batches else np.array([])
    predicted_labels = np.concatenate(predicted_batches) if predicted_batches else np.array([])

    return confusion_matrix(true_labels, predicted_labels, labels=labels)

def plot_confusion_matrix(conf_matrix, class_labels):
    """Show ``conf_matrix`` as an annotated heatmap with true classes on the y-axis."""
    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_labels,
        yticklabels=class_labels,
    )
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, **heatmap_options)
    for set_text, text in ((plt.xlabel, 'Predicted'), (plt.ylabel, 'True'), (plt.title, 'Confusion Matrix')):
        set_text(text)
    plt.show()

def _print_study_summary(study):
    """Print each trial's value and parameters, then the best trial's parameters and value."""
    for trial_number, trial in enumerate(study.trials, start=1):
        print(f"Trial {trial_number} finished with value: {trial.value} and parameters: {trial.params}")

    print("Parameters for the Best Trial:")
    for key, value in study.best_trial.params.items():
        print(key, value)
    print("Best Objective Value:", study.best_value)


def main():
    """Run the grouped split, the Optuna search, final training and test evaluation.

    Everything printed goes to ``trial_output3.txt``. ``sys.stdout`` is restored
    even if one of the steps raises.
    """
    from contextlib import redirect_stdout

    output_path = '/usr/not-backed-up/el18saj/data/trial_output3.txt'
    transcript_data = '/usr/not-backed-up/el18saj/data/Transcripts/WhisperX_Transcripts/SegmentLevelTranscripts/NoPauses/3_segments/Initial_Prompt_True/Zoom/WhisperX_InitialPromptTrue_ZoomID_3_seconds.csv'
    model_filename = '/usr/not-backed-up/el18saj/Code/SavedModels/BERT_model_1.pth'  # Change this path to where you want to save the model

    with open(output_path, 'w') as f, redirect_stdout(f):
        print(transcript_data)
        print("Learning_Rate_Range: 1e-6 to 1e-3, Dropout_Range: 0.1-0.9, Batch_Size_Range: 4-64, Epochs_Range: 5-30")
        transcript_df = pd.read_csv(transcript_data)
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

        df_train, df_val, df_test = perform_group_shuffle_split(transcript_df, test_size=0.2, val_size=0.125)

        for split_name, split_df in (("training", df_train), ("validation", df_val), ("test", df_test)):
            print(f"Participant IDs in {split_name} dataset:", ', '.join(map(str, split_df['Participant_ID'].unique())))

        study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=42))
        study.optimize(lambda trial: objective(trial, df_train, df_val, df_test, tokenizer), n_trials=50)
        _print_study_summary(study)

        # Train a fresh model with the best trial's hyperparameters, then save it.
        best_params = study.best_params
        best_model = BertClassifier(dropout=best_params['dropout'])
        train(best_model, df_train, df_val, tokenizer, best_params['learning_rate'], best_params['epochs'], best_params['batch_size'])
        torch.save(best_model.state_dict(), model_filename)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        best_model = best_model.to(device)

        # Evaluate the best model on the held-out test participants.
        misclassified_ids, test_accuracy = evaluate(best_model, df_test, tokenizer, return_misclassified=True)
        print(f"Test Accuracy using Best Model: {test_accuracy:.4f}")
        print(misclassified_ids)

        test_dataloader = torch.utils.data.DataLoader(Dataset(df_test, tokenizer), batch_size=8)
        conf_matrix = generate_confusion_matrix(best_model, test_dataloader, device)

        print("Confusion Matrix:")
        for row in conf_matrix:
            print(' '.join(str(value) for value in row))

if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not when imported.
    main()




   
csStudent2102
  • 75
  • 1
  • 7

0 Answers0