
I'm developing a project in PyTorch that requires returning the weights (the state_dict) of a PyTorch model as a Flask endpoint response. The simplest version of the code would be:

@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()  # a dict[str, torch.Tensor]
    return model_weights  # fails: Flask cannot JSON-serialize the tensors
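Flask builds JSON responses on top of Python's json module, so the same TypeError can be reproduced without Flask. A minimal sketch (json_error is a hypothetical helper, and the tensor is a stand-in for one state_dict entry):

```python
import json

import torch


def json_error(obj):
    """Return json's error message for obj, or None if it serializes fine."""
    try:
        json.dumps(obj)
    except TypeError as err:
        return str(err)
    return None


# A zero tensor standing in for one entry of a real state_dict.
payload = {"fc.weight": torch.zeros(2, 3)}
print(json_error(payload))  # Object of type Tensor is not JSON serializable
```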

However, it is not that simple, because torch.Tensor is not JSON serializable. So I tried converting the tensors to lists (which are JSON serializable), and that works:

@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()  # a dict[str, torch.Tensor]
    # Convert every tensor to nested Python lists so the dict is JSON serializable
    model_weights = {k: v.tolist() for k, v in model_weights.items()}
    return model_weights
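With the list approach the receiver must rebuild the tensors itself, and the list form drops each tensor's dtype (and device). A minimal sketch of the round trip, assuming the receiver has a model with the same architecture; the helper names and the nn.Linear(3, 2) stand-in are hypothetical:

```python
import json

import torch
from torch import nn


def state_dict_to_lists(state_dict):
    # tolist() creates one Python number per element, which is why this
    # path is slow and yields large JSON for real models.
    return {k: v.tolist() for k, v in state_dict.items()}


def lists_to_state_dict(lists, reference):
    # Lists carry no dtype, so take it from a reference state_dict with the
    # same keys (here, the receiving model's own state_dict).
    return {k: torch.tensor(v, dtype=reference[k].dtype) for k, v in lists.items()}


sender = nn.Linear(3, 2)
receiver = nn.Linear(3, 2)
payload = json.dumps(state_dict_to_lists(sender.state_dict()))
receiver.load_state_dict(lists_to_state_dict(json.loads(payload), receiver.state_dict()))
print(torch.equal(receiver.weight, sender.weight))  # True
```

float32 values survive this trip exactly, because each one widens losslessly to a Python float and json writes floats in a round-trippable form; the cost is speed and payload size, not precision.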

However, this process is very slow and doesn't meet my requirements. I also tried converting the tensors to bytes, but that runs into the same problem: bytes are not JSON serializable either. So I suspect a JSON response isn't the solution. I'm not a Flask expert, but I've read about Flask's send_file method; however, I'm not sure how to use it in this case (or even whether it's a possible solution), since I don't have a mimetype for the dictionary.
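If the response must stay JSON, one common workaround (not used in this thread) is to serialize with torch.save and base64-encode the bytes into a string. A minimal sketch; the "weights" key and the helper names are hypothetical:

```python
import base64
import io
import json

import torch
from torch import nn


def state_dict_to_json(state_dict):
    """Pack a state_dict into a JSON string."""
    buf = io.BytesIO()
    torch.save(state_dict, buf)  # compact binary, no per-element Python objects
    # base64 maps arbitrary bytes to ASCII text that JSON can carry (~4/3 the size).
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return json.dumps({"weights": encoded})


def json_to_state_dict(payload):
    """Inverse of state_dict_to_json."""
    raw = base64.b64decode(json.loads(payload)["weights"])
    return torch.load(io.BytesIO(raw))


model = nn.Linear(4, 2)  # hypothetical stand-in for the real model
restored = json_to_state_dict(state_dict_to_json(model.state_dict()))
print(sorted(restored))  # ['bias', 'weight']
```

This avoids the per-element cost of tolist(), but base64 still inflates the payload by about a third, which is one reason a raw binary response is usually preferable.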

Does anybody know a better way to do this?

amlarraz
  • Can you clarify the goal of your project? Sending weights as a response like this will be slow for any reasonably sized model, and can produce fairly large files if you convert to JSON. How will the weights be used, and what's the workflow? – dbish Aug 07 '22 at 19:35
  • One option I've used for larger files is to just write to S3 or another file store and send the user the link in a more async way. – dbish Aug 07 '22 at 20:02
  • The project is a kind of federated training, so many nodes will be sharing weights with each other many times. Because of that, I need a fast and reliable way to send weights. I've done that using HTTP POST (sending the weights as files), but it doesn't meet my requirements; I need to return the weights. Any thoughts? – amlarraz Aug 08 '22 at 17:29

1 Answer


I've just found a solution. It returns the weights as a file using Flask's send_file method: the weights are saved to an in-memory binary buffer with torch.save and sent with the mimetype 'application/octet-stream', as suggested in this question. The final endpoint code is:

import io

import torch
from flask import send_file

@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()
    to_send = io.BytesIO()  # in-memory binary buffer
    torch.save(model_weights, to_send, _use_new_zipfile_serialization=False)
    to_send.seek(0)  # rewind so send_file reads from the start
    return send_file(to_send, mimetype='application/octet-stream')

And to load it on the other side:

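import io

import torch

# response is assumed to be the requests.Response from the call to /send_weights
weights = torch.load(io.BytesIO(response.content))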
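Both halves can be checked together without a running server by using Flask's test client, whose resp.data plays the role of requests' response.content. A minimal sketch, assuming an nn.Linear(4, 2) as a hypothetical stand-in for the real model; it leaves out the answer's _use_new_zipfile_serialization=False flag, which presumably only matters if the receiver runs a PyTorch older than 1.6 that cannot read the default zip format:

```python
import io

import torch
from flask import Flask, send_file
from torch import nn

app = Flask(__name__)
model = nn.Linear(4, 2)  # hypothetical stand-in for the trained model


@app.route("/send_weights", methods=["GET", "POST"])
def send_weights():
    # Serialize the state_dict into an in-memory binary buffer.
    buf = io.BytesIO()
    torch.save(model.state_dict(), buf)
    buf.seek(0)  # rewind so send_file streams from the start
    return send_file(buf, mimetype="application/octet-stream")


def fetch_weights():
    """Call the endpoint through the test client and rebuild the state_dict."""
    client = app.test_client()
    resp = client.get("/send_weights")
    return resp.status_code, torch.load(io.BytesIO(resp.data))


if __name__ == "__main__":
    status, weights = fetch_weights()
    fresh = nn.Linear(4, 2)
    fresh.load_state_dict(weights)  # receiving side applies the weights
    print(status, torch.equal(fresh.weight, model.weight))  # 200 True
```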

Hope it will help somebody.

amlarraz