Unable to run inference with a Spark ML Pipeline model built using custom Transformers/Estimators

I had some custom requirements for transforming the raw data, and those operations were not available in the `pyspark.ml` module. To implement them, I created a custom transformer that extends the `Estimator`, `HasInputCol`, `HasOutputCol`, `MLWritable` and `MLReadable` classes, i.e.,
```python
from pyspark.ml.pipeline import Transformer, Estimator
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.util import MLReadable, MLWritable
```
I was able to tune it using Hyperopt and to train and evaluate it on the whole dataset. I also logged the model with MLflow and saved it using `pyspark.ml.PipelineModel.write`.
However, when I tried to load the pipeline model for inference, it failed because of the custom stage's `__init__` method.

I don't understand why the constructor is called when the model is loaded, given that the fitted values are already stored in the object.
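A likely explanation: for Python-only stages, PySpark's `DefaultParamsReader` (the reader behind `DefaultParamsReadable` and Python pipeline persistence) typically rebuilds a stage by calling its class with no arguments and then setting the saved Param values on the new instance. Plain attributes such as `self.inputCols` or `self.freq_values_` are not part of that saved metadata. The pure-Python simulation below needs no Spark; `MiniParams`, `save_stage`, `load_stage`, `FragileCleaner` and `RobustCleaner` are hypothetical stand-ins, not PySpark APIs. It reproduces what happens to a constructor like the `CategoricalCleaner` one in this question when it is called with no arguments.

```python
import json


class MiniParams:
    """Hypothetical stand-in for pyspark.ml.param.Params (illustration only)."""

    def __init__(self):
        self._paramMap = {}  # only values stored here get persisted

    def set(self, name, value):
        self._paramMap[name] = value
        return self

    def get(self, name):
        return self._paramMap.get(name)


def save_stage(stage):
    """Mimics DefaultParamsWriter: persist the class name plus the param map as JSON."""
    return json.dumps({"class": type(stage).__name__, "paramMap": stage._paramMap})


def load_stage(cls, metadata_json):
    """Mimics DefaultParamsReader: call cls() with NO arguments, then restore params."""
    metadata = json.loads(metadata_json)
    instance = cls()  # a constructor that needs real inputs fails right here
    for name, value in metadata["paramMap"].items():
        instance.set(name, value)
    return instance


class FragileCleaner(MiniParams):
    """Same constructor logic as the question: plain attributes, inputCols assumed non-None."""

    def __init__(self, inputCols=None, outputCols=None):
        super().__init__()
        self.inputCols = inputCols  # plain attribute: never saved
        self.outputCols = outputCols or [c + "_freq_enc" for c in inputCols]


class RobustCleaner(MiniParams):
    """No-argument construction is safe; values are kept as params so they survive saving."""

    def __init__(self, inputCols=None, outputCols=None):
        super().__init__()
        if inputCols is not None:
            self.set("inputCols", inputCols)
            self.set("outputCols", outputCols or [c + "_freq_enc" for c in inputCols])


restored = load_stage(RobustCleaner, save_stage(RobustCleaner(inputCols=["merchant"])))
print(restored.get("outputCols"))  # ['merchant_freq_enc']

try:
    load_stage(FragileCleaner, save_stage(FragileCleaner(inputCols=["merchant"])))
    fragile_error = None
except TypeError as exc:
    fragile_error = str(exc)
print(fragile_error)  # 'NoneType' object is not iterable
```

In real PySpark the reader also resolves the class from the module path recorded at save time, so the custom classes must be importable under that same path in the inference environment.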
Here's the initial part of the custom transformer:
```python
from pyspark.sql import functions as F


class CategoricalCleaner(Estimator, MLWritable, MLReadable, HasInputCol, HasOutputCol):
    """
    Custom transformer to encode binary columns and handle high-cardinality categorical columns.
    High-cardinality columns are handled by selecting the top n most frequent values among
    fraudulent transactions and replacing all other values with "Other_colname".
    """

    def __init__(self, inputCols=None, outputCols=None, top_n=10):
        """
        Args:
            inputCols (List[str]): List of input column names to encode.
            outputCols (List[str]): List of output column names for the encoded columns.
                Default is to use the input column names with "_freq_enc" appended.
            top_n (int): Number of top values to select for each input column. Default is 10.
        """
        super(CategoricalCleaner, self).__init__()
        self.inputCols = inputCols
        self.outputCols = outputCols or [col + '_freq_enc' for col in inputCols]
        self.top_n = top_n
        # Dictionary to store the top n frequent values for each input column
        self.freq_values_ = {}

    def _fit(self, dataset):
        """
        Fit the transformer on the input dataset to determine the top n frequent values
        for each input column.

        Args:
            dataset (pyspark.sql.DataFrame): Input dataset containing the columns to be encoded.

        Returns:
            self
        """
        # Loop over the input columns and compute the top n frequent values for fraud transactions
        for inputCol, outputCol in zip(self.inputCols, self.outputCols):
            # Select the top n frequent values among fraud transactions
            top_values = (dataset
                          .filter(dataset.is_fraud == 1)
                          .groupBy(inputCol)
                          .agg(F.count('*').alias('count'))
                          .orderBy(F.desc('count'))
                          .limit(self.top_n)
                          .select(inputCol)
                          .rdd.flatMap(lambda x: x)
                          .collect()
                          )
            self.freq_values_[inputCol] = top_values
        return self

    # ----------------------------- TRUNCATED -----------------------------
```
Here's the screenshot of the error I'm facing:
**If anyone has worked on this kind of development, please help! It would be great if someone could share a working example of how to do this.**

**Thank you!**