I try to translate a dataframe from English to Persian. I use a pertrained langauge model for this purpose, but it's slow, how can I speed it up?
model_size = "base"
model_name = f"persiannlp/mt5-{model_size}-parsinlu-translation_en_fa"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
with open(fname, 'w') as tsvfile:
writer = csv.writer(tsvfile, delimiter='\t')
#writer.writerow(["prefix", "input_text", "target_text"])
writer.writerow(["input_text", "target_text"])
for i, row in tqdm(df.iterrows(), total=maxlen):
prompt = row["input_text"]
target = row["target_text"]
if trans:
input_ids = tokenizer.encode(prompt, return_tensors="pt")
res = model.generate(input_ids)
prompt = tokenizer.batch_decode(res, skip_special_tokens=True)
input_ids = tokenizer.encode(target, return_tensors="pt")
res = model.generate(input_ids)
target = tokenizer.batch_decode(res, skip_special_tokens=True)
writer.writerow([prompt, target])
print("saved in ", fname)