I have a Flask app running on Google Cloud Run that needs to download a large model (GPT-2 from Hugging Face). The download takes a while, so I am trying to set things up so that the model is downloaded only once, on deployment, and then served from disk on subsequent visits. That is, my main Flask app app.py imports a script containing the following code:
```python
import torch
# from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelWithLMHead

# Disable gradient calculation - useful for inference
torch.set_grad_enabled(False)

# Check if a GPU is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and model from the local directory;
# if that fails, download them and save a local copy first
try:
    tokenizer = AutoTokenizer.from_pretrained("./gpt2-xl")
    model = AutoModelWithLMHead.from_pretrained("./gpt2-xl")
except Exception:
    print('No model found! Downloading....')
    AutoTokenizer.from_pretrained('gpt2').save_pretrained('./gpt2-xl')
    AutoModelWithLMHead.from_pretrained('gpt2').save_pretrained('./gpt2-xl')
    tokenizer = AutoTokenizer.from_pretrained("./gpt2-xl")
    model = AutoModelWithLMHead.from_pretrained("./gpt2-xl")

model = model.to(device)
```
This basically tries to load the downloaded model and, if that fails, downloads a fresh copy. I have autoscaling set to a minimum of 1 instance, which I thought meant one instance would always be running, so the downloaded files would persist between visits. But the app keeps having to redownload the model, which freezes it for anyone trying to use it at that moment. I am trying to recreate something like this app https://text-generator-gpt2-app-6q7gvhilqq-lz.a.run.app/ which does not appear to have the same load-time issue. In the Flask app itself I have the following:
```python
@app.route('/')
@cross_origin()
def index():
    prompt = wp[random.randint(0, len(wp) - 1)]
    res = generate(prompt, size=75)
    generated = res.split(prompt)[-1] + '\n \n...TO BE CONTINUED'
    # generated = prompt
    return flask.render_template('main.html', prompt=prompt, output=generated)

if __name__ == "__main__":
    app.run(host='0.0.0.0',
            debug=True,
            port=PORT)
```
But it still seems to redownload the model every few hours. How can I stop the app from re-downloading the model (and freezing for people who want to try it)?
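For reference, the only workaround I can think of so far is to download the model at build time instead, so the weights are baked into the container image and every new Cloud Run instance starts with them already on disk. A rough, untested Dockerfile sketch of that idea (it assumes my requirements.txt includes torch and transformers, and that app.py loads from ./gpt2-xl as in the code above):

```dockerfile
FROM python:3.9-slim

WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Download the model once, at image build time, into the
# directory that app.py expects, so no instance ever has
# to download it at request time.
RUN python -c "from transformers import AutoTokenizer, AutoModelWithLMHead; AutoTokenizer.from_pretrained('gpt2').save_pretrained('./gpt2-xl'); AutoModelWithLMHead.from_pretrained('gpt2').save_pretrained('./gpt2-xl')"

COPY . .
CMD ["python", "app.py"]
```

Would that be the right approach here, or is there a way to make the files written by the current code persist across instances?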