
I have registered a scikit-learn model on my MLflow Tracking server, and I am loading it with mlflow.sklearn.load_model(model_uri).

Now I would like to access the model's signature so I can get a list of its required inputs/features and retrieve them from my feature store by name. I can't find any utility or method in the mlflow API or the MlflowClient API that lets me access a signature or an inputs/outputs attribute, even though the UI shows a list of inputs and outputs under each version of the model.

I know I can find the input sample and the model configuration in the model's artifacts, but that would require downloading the artifacts and loading them manually in my script. I don't strictly need to avoid that, but I am surprised that I can't simply return the signature as a dictionary, the same way I can return a run's parameters or metrics.

Mike

2 Answers


You can access the model's signature without opening the MLmodel file yourself: it is exposed on the loaded model. Once the model is loaded, you can read its attributes, such as its signature, or call other pyfunc-defined methods.

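import mlflow

model = mlflow.pyfunc.load_model("runs:/<run_id>/model")
# Public accessor; the private model._model_meta._signature holds the same object
print(model.metadata.signature)
# Input column names in signature order (assumes a column-based signature was logged)
print(model.metadata.get_input_schema().input_names())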
  • Thank you, I'll test this when I get a chance and accept your answer! – Mike Mar 02 '22 at 18:20
  • Does the MLflow signature ensure the same column order as in model training, so that it can be used in production? – Shubh Apr 11 '22 at 11:25
  • @Shubh I haven't tested that; I never got around to using MLflow in my project. However, since column order is so important for all downstream tasks after retrieving the signature, I would be surprised if the signature didn't preserve it. My concern would be that Python dictionaries don't preserve order, so I suppose a bug that messes up the order could easily creep into MLflow's source code. – Mike Sep 13 '22 at 17:48
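On the ordering question: for a column-based signature, MLflow serializes the inputs as a JSON array, which is ordered, and since Python 3.7 plain dicts also keep insertion order. Whether the order is enforced at prediction time is a separate matter not settled in this thread, so a defensive option is to build each row explicitly in signature order from whatever your feature store returns and fail loudly when a feature is missing. A minimal sketch; the helper `build_row` and the sample data are hypothetical, made up for illustration:

```python
from typing import Any, List, Mapping


def build_row(input_names: List[str], features: Mapping[str, Any]) -> List[Any]:
    """Arrange feature values in the order given by input_names.

    Hypothetical helper: input_names would come from the model signature,
    features from a feature-store lookup keyed by feature name.
    """
    missing = [name for name in input_names if name not in features]
    if missing:
        raise KeyError(f"features missing from store response: {missing}")
    # Extra keys in `features` are ignored; the order follows the signature only.
    return [features[name] for name in input_names]


# Made-up data: the store returns keys in a different order, plus an extra key
signature_names = ["age", "income"]
store_response = {"income": 52000.0, "customer_id": 17, "age": 41.0}
print(build_row(signature_names, store_response))  # [41.0, 52000.0]
```

Keying the lookup by name and ordering by the signature means the row layout never depends on how the feature store happens to order its response.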

A more canonical way, which does not refer to a specific run or to a model flavor such as pyfunc, is to use mlflow.models.get_model_info:

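import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient(mlflow.get_tracking_uri())

# Artifact location of version 10 of the registered model 'toy-model'
model_uri = client.get_model_version_download_uri('toy-model', '10')
model_info = mlflow.models.get_model_info(model_uri)
print(model_info.signature_dict)  # public accessor for _signature_dict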
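The signature dictionary stores each schema as a JSON string (that is how `ModelSignature.to_dict()` serializes it), so getting the plain list of feature names the question asks for takes one `json.loads`. A minimal sketch, assuming a column-based signature; the helper `input_feature_names` and the sample dictionary are hypothetical, made up for illustration:

```python
import json
from typing import Any, Dict, List


def input_feature_names(signature_dict: Dict[str, Any]) -> List[str]:
    """Return the input column names, in signature order.

    Hypothetical helper: assumes signature_dict["inputs"] is a JSON string
    holding a list of column specs such as {"name": "x", "type": "double"}.
    Specs without a name (e.g. tensor-based signatures) are rejected.
    """
    raw_inputs = signature_dict.get("inputs")
    if raw_inputs is None:
        raise ValueError("signature has no 'inputs' entry")
    # Accept either the serialized JSON string or an already-parsed list
    specs = json.loads(raw_inputs) if isinstance(raw_inputs, str) else raw_inputs
    names = [spec.get("name") for spec in specs]
    if any(name is None for name in names):
        raise ValueError("some inputs are unnamed (tensor-based signature?)")
    return names


# Made-up example in the shape of model_info.signature_dict
example_signature = {
    "inputs": '[{"name": "age", "type": "double"}, {"name": "income", "type": "double"}]',
    "outputs": '[{"type": "long"}]',
}
print(input_feature_names(example_signature))  # ['age', 'income']
```

The returned list can then be used directly as the set of feature names to request from the feature store.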
Maciej Skorski