
I have an A100 (Colab Pro) with 40GB GPU memory and want to fine-tune an LLM utilizing the GPU's full capacity.

When I increase the per_device_train_batch_size argument in the Trainer's TrainingArguments to anything other than 1, I receive this error:

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

per_device_train_batch_size=1 works perfectly fine, but GPU memory utilization is only about 7.4GB out of 40GB during fine-tuning with a batch size of 1, so I do not believe this is an OOM issue. With the batch size set to 2 or anything higher, training breaks with the error above. I also tried setting auto_find_batch_size=True; it still fails with the same error.
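
For reference, this is roughly how I read off the memory figure above; a quick sketch using torch.cuda (the numbers can differ a bit from what nvidia-smi shows):

import torch

# Rough memory check on GPU 0 (the device that device_map={'':0} uses)
allocated = torch.cuda.memory_allocated(0) / 1024**3   # tensors currently allocated, in GiB
reserved = torch.cuda.memory_reserved(0) / 1024**3     # memory held by PyTorch's caching allocator, in GiB
free, total = torch.cuda.mem_get_info(0)                # raw free/total device memory, in bytes
print(f"allocated {allocated:.1f} GiB, reserved {reserved:.1f} GiB")
print(f"free {free / 1024**3:.1f} GiB of {total / 1024**3:.1f} GiB")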

  1. Is this the correct way of trying to maximize GPU utilization?
  2. Why am I receiving the error? (A debugging sketch I plan to try is below.)
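
For question 2, this is a minimal debugging sketch I intend to try, assuming the generic CUDA advice of setting CUDA_LAUNCH_BLOCKING=1 also applies to this error:

import os

# Make CUDA kernel launches synchronous so the traceback points at the op that
# actually fails, rather than at a later cublasCreate call. This has to be set
# before the CUDA context is created (i.e. before the model is moved to the GPU).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

and then re-run training with per_device_train_batch_size=2 to hopefully get a more precise traceback.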

My full code is located on Colab here. This is the gist of it:

Loading model with BitsAndBytes

import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM, LlamaTokenizer

# 4-bit NF4 quantization with nested (double) quantization; compute in bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16
)
# Load the quantized model entirely on GPU 0
model = LlamaForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map={'':0},
    quantization_config=bnb_config
)
tokenizer = LlamaTokenizer.from_pretrained(MODEL_NAME)
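
As a sanity check on the 7.4GB figure, I also print the quantized model's own footprint (get_memory_footprint is a standard transformers model method, but treat this as a sketch):

# Size of the 4-bit model's parameters and buffers, in GiB
print(f"model footprint: {model.get_memory_footprint() / 1024**3:.1f} GiB")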

Gradient checkpointing, k-bit training, LoRA

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# LoRA adapters on the attention projections only
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=0.05,
    bias='none',
    task_type='CAUSAL_LM'
)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)  # prepares the quantized model for training (norms in fp32, input grads enabled)
model = get_peft_model(model, lora_config)
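
After wrapping with PEFT I print what is actually trainable (print_trainable_parameters comes from peft), just as a quick sanity check:

# Reports trainable vs. total parameter counts; only the LoRA weights should be trainable
model.print_trainable_parameters()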

Trainer

import transformers

trainer = transformers.Trainer(
    model=model,
    train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=16,   # anything above 1 triggers the CUBLAS error
        auto_find_batch_size=True,
        gradient_accumulation_steps=4,
        warmup_steps=30,
        num_train_epochs=1,
        learning_rate=2e-4,
        fp16=True,
        output_dir='./output',
        logging_steps=1,
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False   # avoid the use_cache warning with gradient checkpointing
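
Finally, this is roughly how I launch training and check peak usage afterwards (a sketch; max_memory_allocated only counts memory allocated by PyTorch itself):

import torch

trainer.train()

# Peak memory allocated by PyTorch tensors during the run, in GiB
print(f"peak allocated: {torch.cuda.max_memory_allocated(0) / 1024**3:.1f} GiB")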
