I am trying to load a fine-tuned GPT-2 model when my Flask app starts. The model is loaded in the app's init function using:
import torch
from transformers import GPT2Tokenizer
app.modelgpt2 = torch.load('models/model_gpt2.pt', map_location=torch.device('cpu'))
app.modelgpt2tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
But when I run the prediction task in the snippet below:
from flask import current_app

input_ids = current_app.modelgpt2tokenizer.encode("sample sentence here", return_tensors='pt')
sample_outputs = current_app.modelgpt2.generate(
    input_ids,
    do_sample=True,
    top_k=50,
    min_length=30,
    max_length=300,
    top_p=0.95,
    temperature=0.7,
    num_return_sequences=1,
)
it throws the error from the question title: AttributeError: 'GPT2Model' object has no attribute 'gradient_checkpointing'
Here is the traceback, starting from the model.generate call:

  File "/venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/venv/lib/python3.8/site-packages/transformers/generation_utils.py", line 1017, in generate
    return self.sample(
  File "/venv/lib/python3.8/site-packages/transformers/generation_utils.py", line 1531, in sample
    outputs = self(
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1044, in forward
    transformer_outputs = self.transformer(
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 861, in forward
    print(self.gradient_checkpointing)
  File "/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1177, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'GPT2Model' object has no attribute 'gradient_checkpointing'
I checked modeling_gpt2.py: by default, self.gradient_checkpointing is set to False in the constructor of the class.
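One possible explanation, offered as an assumption because the post does not say how models/model_gpt2.pt was written: if the file was created with torch.save(model) under an older transformers version, torch.load unpickles the object without running the constructor, so an attribute that only the newer constructor sets is never created. The sketch below reproduces that behaviour with plain pickle; the Model class is a hypothetical stand-in for GPT2Model, and no torch or transformers code is involved.

```python
import pickle


class Model:
    """Hypothetical 'old' version: the constructor sets only `weights`."""

    def __init__(self):
        self.weights = [0.1, 0.2]


# Serialize an instance of the old class; pickle stores the class name and
# the instance's __dict__, not the constructor logic.
blob = pickle.dumps(Model())


class Model:  # noqa: F811  (deliberate redefinition: the "upgraded" class)
    """Hypothetical 'new' version: the constructor also sets a flag."""

    def __init__(self):
        self.weights = [0.1, 0.2]
        self.gradient_checkpointing = False

    def forward(self):
        # Mirrors the failing line in the traceback: reads the new attribute.
        return self.gradient_checkpointing


# Unpickling resolves the name to the *new* class but restores only the saved
# __dict__; Model.__init__ is not called, so the new attribute is missing.
restored = pickle.loads(blob)
print(vars(restored))  # {'weights': [0.1, 0.2]}

try:
    restored.forward()
except AttributeError as exc:
    print("AttributeError:", exc)
```

As far as I know, torch.nn.Module unpickling behaves the same way: its __setstate__ restores the saved attribute dictionary and does not call __init__, so a constructor default like gradient_checkpointing = False would not apply to an object saved before that line existed.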