I'm trying to use a generative model (Llama2) for a binary classification task, and I want to obtain the positive score, i.e. the model's confidence in the positive label.

I tried to use `compute_transition_scores`, but I'm not sure how to correctly get a confidence value between 0 and 1.

Here is my current code:

```python
import torch
from transformers import AutoModelForCausalLM

# peft_config, tokenizer, test_sample and pred_label are defined elsewhere (not shown in the question)
model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    # quantization_config=bnb_config,
    torch_dtype='auto',
    device_map='auto',
    offload_folder="offload", offload_state_dict=True
)
pos_scores = []
input_ids = tokenizer(test_sample, return_tensors="pt").input_ids
tokens_for_summary = 1  # generate exactly one token: "1" (positive) or "0" (negative)
output_tokens = input_ids.shape[1] + tokens_for_summary

# Greedy decoding, keeping the per-step scores so transition scores can be computed
outputs = model.generate(inputs=input_ids, do_sample=False, max_length=output_tokens,
                         pad_token_id=tokenizer.eos_token_id,
                         output_scores=True, return_dict_in_generate=True)
# exp() of the transition score of the single generated token
score = float(torch.exp(model.compute_transition_scores(outputs.sequences, outputs.scores)).cpu())

if pred_label == 1:
    pos_scores.append(score)
elif pred_label == 0:
    pos_scores.append(-1 * score)  # reverse the sign of all samples for which the prediction was 0
```
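For context: `compute_transition_scores` takes a `normalize_logits` argument that defaults to `False`. As far as I can tell from the `transformers` docs, with greedy decoding `outputs.scores` holds the processed logits of each step rather than log-probabilities, so `torch.exp` of the returned score is `exp(logit)`, which is unbounded. With `normalize_logits=True` the scores are log-softmax normalized first, and their exponential lies between 0 and 1. The sketch below reproduces that arithmetic in plain Python on a made-up logit vector (the numbers are hypothetical, not Llama2 outputs), so the difference is visible without loading a model.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

# Hypothetical next-token logits over a tiny 4-token vocabulary.
logits = [1.5, -0.3, 9.2, 4.1]
# Greedy decoding picks the argmax (index 2 here).
chosen = max(range(len(logits)), key=lambda i: logits[i])

# Analogous to exp(compute_transition_scores(..., normalize_logits=False)):
raw_score = math.exp(logits[chosen])

# Analogous to exp(compute_transition_scores(..., normalize_logits=True)):
prob = math.exp(log_softmax(logits)[chosen])

print(f"exp(raw logit) = {raw_score:.1f}")  # far above 1, not a probability
print(f"probability    = {prob:.4f}")       # between 0 and 1
```

The only change on the model side would be passing `normalize_logits=True`; the probability it yields is taken over the whole vocabulary, not just the two label tokens.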

However, the scores I get are too high to be probabilities in the 0-1 range. I've considered applying a sigmoid, but I'm not sure whether that is the correct approach.
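On the sigmoid idea: a sigmoid of a single token's logit is not a meaningful class probability, because softmax is shift-invariant and a logit only means something relative to the other logits at the same step. If the prompt constrains the answer to the tokens "1" and "0", one option is to compare just those two logits at the generated position: the two-way softmax over them equals `sigmoid(logit_1 - logit_0)`, a positive-label score between 0 and 1 that renormalizes over the two labels only. The sketch below checks that identity in plain Python with hypothetical logit values. In a real run the two logits would presumably come from `outputs.scores[0][0, id_1]` and `outputs.scores[0][0, id_0]`, where `id_1` and `id_0` are hypothetical names for the vocabulary ids of the label tokens; the exact ids depend on how the Llama2 tokenizer splits "1" and "0" (for example, with or without a leading-space token).

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def positive_score(logit_pos, logit_neg):
    """P(label "1") when the choice is restricted to the tokens "1" and "0".

    Equivalent to softmax([logit_pos, logit_neg])[0], written as a sigmoid
    of the logit difference.
    """
    return sigmoid(logit_pos - logit_neg)

# Hypothetical logits of the "1" and "0" tokens at the generated position.
logit_1, logit_0 = 3.2, 1.7

p_pos = positive_score(logit_1, logit_0)

# Same value computed directly as a two-way softmax.
two_way = math.exp(logit_1) / (math.exp(logit_1) + math.exp(logit_0))

print(round(p_pos, 4), round(two_way, 4))
```

Assuming greedy decoding always emits one of the two label tokens, a prediction of `0` simply corresponds to `p_pos < 0.5`, so this score already orders samples by confidence in the positive label without a sign flip.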

How should I do that? Thank you!

Ofir
  • Which model are you trying to use? What is the ultimate purpose of the binary classification? Is there a sample input and expected output you can give? Is there a reason to use the CausalLM class instead of `from transformers import AutoConfig, AutoModelForSequenceClassification`? – alvas Aug 12 '23 at 10:45
  • @alvas I'm trying to do binary classification with Llama2. Yes, there is a sample input and expected output: the generative model completes a single token (1 for the positive label, 0 for the negative label). – Ofir Aug 12 '23 at 10:47
  • How is `peft_config` initialized? And is there a reason you're not using `AutoModelForSequenceClassification`? – alvas Aug 12 '23 at 10:50
  • Also, can you post an example input for `test_sample`? – alvas Aug 12 '23 at 10:50
