
I am trying to fine-tune Transformer-XL for language modeling.

from transformers import TransfoXLTokenizer, TransfoXLModel

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TransfoXLModel.from_pretrained("transfo-xl-wt103")

I have a .csv file with only one column ('text') that contains paragraphs.

from sklearn.model_selection import train_test_split

df_train, df_test = train_test_split(df, test_size=0.2)
df_train.to_csv('train.csv',index=False)
df_test.to_csv('test.csv',index=False)
from datasets import load_dataset
dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"})
encoded_dataset = dataset.map(lambda t: tokenizer(t['text'],  truncation=True, padding='max_length'), batched=True, load_from_cache_file=False)
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, mlm_probability=0.15
)
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=500,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    data_collator=data_collator,
)
trainer.train()

The last line fails with the error:

forward() got an unexpected keyword argument 'labels'

I have no idea where the 'labels' argument comes from. The encoded dataset has the following structure:

DatasetDict({
    train: Dataset({
        features: ['text', 'input_ids'],
        num_rows: 18
    })
    test: Dataset({
        features: ['text', 'input_ids'],
        num_rows: 5
    })
})
The labels are created by your collate function, `DataCollatorForLanguageModeling`, and are the targets for training. The class `TransfoXLModel` doesn't accept `labels`; maybe you can try `TransfoXLLMHeadModel` instead. – cronoik Mar 13 '23 at 19:56
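
To make the mechanism in that comment concrete, here is a toy sketch in plain Python. It does not use `transformers` at all: `toy_causal_lm_collator`, `ToyBaseModel`, and `ToyLMHeadModel` are hypothetical stand-ins invented for illustration. The assumption is that the trainer passes the collated batch to the model as keyword arguments, and that a causal-LM collator (`mlm=False`) adds a `labels` entry derived from `input_ids`.

```python
# Toy stand-ins only: none of these are real transformers classes.

def toy_causal_lm_collator(examples):
    """Batch the input_ids and copy them into a 'labels' entry,
    roughly what a causal-LM collator does."""
    input_ids = [list(ex["input_ids"]) for ex in examples]
    labels = [list(ids) for ids in input_ids]
    return {"input_ids": input_ids, "labels": labels}


class ToyBaseModel:
    """Headless model: forward() has no 'labels' parameter."""

    def forward(self, input_ids):
        return {"hidden": input_ids}

    def __call__(self, **batch):
        # The batch dict is unpacked into keyword arguments.
        return self.forward(**batch)


class ToyLMHeadModel(ToyBaseModel):
    """Model with an LM head: forward() accepts 'labels' and returns a loss."""

    def forward(self, input_ids, labels=None):
        output = {"logits": input_ids}
        if labels is not None:
            output["loss"] = 0.0  # placeholder, not a real loss computation
        return output


batch = toy_causal_lm_collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6]}])

# The headless model rejects the extra 'labels' key added by the collator.
try:
    ToyBaseModel()(**batch)
    base_error = None
except TypeError as exc:
    base_error = str(exc)

# The LM-head variant accepts the same batch.
head_output = ToyLMHeadModel()(**batch)
print(base_error)
print(sorted(head_output))
```

So the dataset itself never needs a `labels` column: the key appears at collation time, which is why it is missing from the `features` list shown in the question. In the real code, the analogous change would be loading `TransfoXLLMHeadModel` instead of `TransfoXLModel`, as the comment suggests; that is untested here.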
