
I am trying to run inference with a GPT-2-like model on a large dataset (26k samples). To speed things up I would like to process it in batches, but when I try this it runs out of CUDA memory (OOM) after a few batches. That it fails only after some batches seems strange to me, since I would expect memory use to be more or less constant across batches. This is my code:

import torch

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

sentences = ["<START_TOK>" + s + "<END_TOK>" + tokenizer.eos_token for s in sentences]

inputs = tokenizer(sentences, return_tensors="pt", padding=True, max_length=1024, truncation=True)

device = torch.device("cuda:0")
inputs = inputs.to(device)
model = model.to(device)
model.eval()
res = []
with torch.no_grad():
    output_sequences = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=1024,
            pad_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=100,
            top_p=0.9,
            temperature=0.85
        )
    output_sequences = output_sequences.cpu()  # not really sure this is useful, just tried, but the problem remained
    for i in range(len(sentences)):
        res.append(tokenizer.decode(output_sequences[i]))
model.train()
return res

What could be the problem?

  • Check the max sequence length of each batch and see whether your GPU crashes on the batch with the longest one. – KonstantinosKokos May 03 '21 at 08:21
  • @KonstantinosKokos Yes, checking the longest sentence in each batch, the crash always happens on the batch containing the longest sentence so far (how long that is depends on the batch size). But shouldn't the tokenizer pad them all to a fixed length of 1024, so that tensors of that size are what enter GPU memory? Am I wrong? – Vitto May 03 '21 at 12:13
  • I *assume* the tokenizer pads to the longest sequence in the batch, rather than to 1024. – KonstantinosKokos May 03 '21 at 15:26
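
For illustration (not from the original thread): a minimal sketch, assuming the standard Hugging Face tokenizer call signature, of the behavior described in the last comment. With padding=True the batch is only padded to its own longest sequence, so tensor sizes grow with the longest sample seen so far; padding="max_length" forces every batch to the same fixed size:

# padding=True pads to the longest sequence in *this* batch,
# so the input_ids shape varies from batch to batch
batch = tokenizer(sentences, return_tensors="pt", padding=True,
                  max_length=1024, truncation=True)
print(batch["input_ids"].shape)  # (batch_size, longest_in_batch)

# padding="max_length" pads every batch to the fixed maximum,
# making memory use constant (but maximal) across batches
batch = tokenizer(sentences, return_tensors="pt", padding="max_length",
                  max_length=1024, truncation=True)
print(batch["input_ids"].shape)  # always (batch_size, 1024)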

0 Answers