
I am trying to use PyTorch with MLflow. Currently, the predict method only accepts data of type pd.DataFrame or np.ndarray. Is there a way to override this and write a custom predict method without writing a completely new loader_module?

The source code for the predict method can be found in MLflow's pytorch module.

  • Welcome to Stack Overflow. Please take the [tour] to learn how Stack Overflow works and read [ask] on how to improve the quality of your question. Then [edit] your question to include your source code as a [mcve], which can be tested by others. Also add your code to the question itself, not on an external site. – Progman Jul 23 '21 at 12:00

1 Answer


You cannot override the built-in loader_module. Instead, you can use MLflow's custom Python model (mlflow.pyfunc.PythonModel) to define custom PyTorch functionality inside the model:

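# Define the model class
import mlflow.pyfunc


class CustomPytorchWrapper(mlflow.pyfunc.PythonModel):

    def load_context(self, context):
        # Load the PyTorch model here, e.g. from a path in context.artifacts
        self.model = None  # placeholder: replace with your model-loading code

    def predict(self, context, model_input):
        # Convert model_input to whatever your model expects, then run inference
        raise NotImplementedError("define your custom predict logic here")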
tRex002