I'm developing a project in PyTorch that needs to return the weights (the state_dict) of a PyTorch model as the response of a Flask endpoint. A minimal version of the code looks like this:
# assumes `app` is a Flask instance and `model` a torch.nn.Module defined elsewhere
@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()  # a dict[str, torch.Tensor]
    return model_weights
However, it is not that simple, because a torch.Tensor is not JSON serializable. So I tried converting each tensor to a list (which is JSON serializable), and that works:
@app.route('/send_weights', methods=['GET', 'POST'])
def send_weights():
    model_weights = model.state_dict()  # a dict[str, torch.Tensor]
    # Convert each tensor to a (nested) Python list so Flask can JSON-encode it
    model_weights = {k: v.tolist() for k, v in model_weights.items()}
    return model_weights
However, this conversion is very slow and doesn't meet my requirements. I also tried converting the tensors to bytes, but that runs into the same problem: bytes are not JSON serializable either. So I suspect a JSON response is not the right solution. I'm not a Flask expert, but I've read about Flask's send_file function. I'm not sure how to use it in this case (or even whether it would work), since I don't have a mimetype for the dictionary.
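For reference, the "bytes is not JSON serializable" problem can be worked around without `tolist()` by base64-encoding each tensor's raw buffer and sending its dtype and shape alongside it. This is a minimal sketch, not a standard PyTorch API: the helper names `encode_state_dict` and `decode_state_dict` are hypothetical, it assumes NumPy is installed, and it assumes every tensor has a NumPy-compatible dtype (bfloat16, for example, has no NumPy equivalent and would fail in `.numpy()`).

```python
import base64
import json

import numpy as np
import torch


def encode_state_dict(state_dict):
    """Hypothetical helper: dict[str, torch.Tensor] -> JSON-serializable dict."""
    encoded = {}
    for name, tensor in state_dict.items():
        # Move to CPU and view as a NumPy array (assumes a NumPy-compatible dtype).
        array = tensor.detach().cpu().numpy()
        encoded[name] = {
            "dtype": str(array.dtype),
            "shape": list(array.shape),
            # Raw C-order bytes -> ASCII text that JSON can carry.
            "data": base64.b64encode(array.tobytes()).decode("ascii"),
        }
    return encoded


def decode_state_dict(encoded):
    """Hypothetical helper: inverse of encode_state_dict."""
    state_dict = {}
    for name, entry in encoded.items():
        raw = base64.b64decode(entry["data"])
        array = np.frombuffer(raw, dtype=entry["dtype"]).reshape(entry["shape"])
        # frombuffer returns a read-only view; copy so torch gets a writable array.
        state_dict[name] = torch.from_numpy(array.copy())
    return state_dict


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)  # placeholder model
    payload = json.dumps(encode_state_dict(model.state_dict()))
    model.load_state_dict(decode_state_dict(json.loads(payload)))
```

Base64 inflates the payload by about a third, but it avoids building one Python object per tensor element, which is where most of the `tolist()` cost comes from.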
Does anybody know a better way to do this?