
I'm trying to fine-tune my BERT-based QA model (PyTorch) on the TPU v3-8 provided by Kaggle. During validation I use a ParallelLoader to run predictions on all 8 cores at the same time, but afterwards I don't know how to gather the results back from each core (in the order corresponding to the original dataset) so that I can compute the overall EM & F1 scores. Can anybody help? Code:

import gc
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
from transformers import AdamW  # assuming the Hugging Face AdamW; torch.optim.AdamW would work as well

def _run():
    MAX_LEN = 192 # maximum text length in a batch (cannot be too high due to memory constraints)
    BATCH_SIZE = 16 # batch size (cannot be too high due to memory constraints)
    EPOCHS = 2 # number of epochs
    train_sampler = torch.utils.data.distributed.DistributedSampler(
          tokenized_datasets['train'],
          num_replicas=xm.xrt_world_size(), # tell PyTorch how many devices (TPU cores) we are using for training
          rank=xm.get_ordinal(), # tell PyTorch which device (core) we are on currently
          shuffle=True
    )
    
    train_data_loader = torch.utils.data.DataLoader(
        tokenized_datasets['train'],
        batch_size=BATCH_SIZE,
        sampler=train_sampler,
        drop_last=True,
        num_workers=0,
    )
    
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
          tokenized_datasets['validation'],
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False
    )
    
    valid_data_loader = torch.utils.data.DataLoader(
        tokenized_datasets['validation'],
        batch_size=BATCH_SIZE,
        sampler=valid_sampler,
        drop_last=False,
        num_workers=0
    )
    
    device = xm.xla_device() # device (single TPU core)
    model = model.to(device) # put model onto the TPU core
    xm.master_print('done loading model')
    xm.master_print(xm.xrt_world_size(),'as size')
    
    lr = 0.5e-5 * xm.xrt_world_size()
    optimizer = AdamW(model.parameters(), lr=lr) # define our optimizer
    
    for epoch in range(EPOCHS):
        gc.collect() 
        # use ParallelLoader (provided by PyTorch XLA) for TPU-core-specific dataloading:
        para_loader = pl.ParallelLoader(train_data_loader, [device]) 
        xm.master_print('parallel loader created... training now')
        gc.collect()
        
        # call training loop
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=None)
        del para_loader
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        gc.collect()

        model.eval()
        # call evaluation loop
        print("call evaluation loop")
        start_logits, end_logits = eval_loop_fn(para_loader.per_device_loader(device), model, device)
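What I've considered so far: the docs mention xm.mesh_reduce(tag, data, reduce_fn), which collects a Python value from every core at a rendezvous and passes the per-core list to reduce_fn. So I imagine gathering the logits with something like the sketch below (the tag strings and the _concat helper are placeholder names I made up, and I haven't verified that this runs):

import numpy as np
import torch_xla.core.xla_model as xm

def _concat(per_core):
    # mesh_reduce hands reduce_fn a list with one entry per TPU core
    # (ordered by ordinal, as far as I can tell)
    return np.concatenate(per_core, axis=0)

# assuming start_logits / end_logits are this core's stacked predictions,
# as numpy arrays or CPU tensors
start_all = xm.mesh_reduce('eval_start_logits', np.asarray(start_logits), _concat)
end_all = xm.mesh_reduce('eval_end_logits', np.asarray(end_logits), _concat)

Even if that gather works, I still don't see how to restore the original dataset order: as far as I understand, with shuffle=False the DistributedSampler gives core k the examples k, k + 8, k + 16, ..., so a plain concatenation interleaves the dataset rather than reproducing it. Is there a standard way to undo that (or a better primitive than mesh_reduce, e.g. xm.all_gather on the device tensors) before computing EM/F1?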