I am currently trying to write a vectorized (pandas) UDF in Snowpark:
import pandas as pd
from snowflake.snowpark.functions import pandas_udf
from snowflake.snowpark.types import StringType
@pandas_udf(
    name='EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.GET_REVIEW_CLASSIFICATION',
    session=new_session,
    is_permanent=True,
    replace=True,
    imports=[
        '@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION/bart-large-mnli.joblib'
    ],
    # input_types is a list (one DataType per argument)...
    input_types=[StringType()],
    # ...but return_type must be a single DataType instance. Passing
    # ``[StringType()]`` here is what raises ``TypeError: invalid type``
    # when the UDF is registered.
    return_type=StringType(),
    stage_location='@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION',
    packages=['cachetools==4.2.2', 'transformers==4.14.1'],
)
def get_review_classification(sentences: pd.Series) -> pd.Series:
    """Classify each review text into one of a fixed set of categories.

    Runs the classifier returned by ``read_model()`` on every sentence in the
    batch, using the candidate labels below. For each sentence it picks the
    label with the highest score.

    Args:
        sentences: One batch of review texts from the UDF's single string
            argument.

    Returns:
        A Series of the same length and index as ``sentences``. It holds the
        best-scoring label for each row, or ``None`` when the input is
        missing or the classifier result has no ``scores``/``labels`` keys.
    """
    candidate_labels = ['customer support', 'product experience', 'account issues']
    # NOTE(review): ``read_model`` is not defined in this snippet. It is
    # presumably the helper from the referenced article that loads the
    # imported ``bart-large-mnli.joblib`` file. Confirm it is defined (and,
    # ideally, cached so the model is not reloaded per batch) before
    # registering this UDF.
    classifier = read_model()

    predictions = []
    for sentence in sentences:
        # Presumably SQL NULLs arrive here as None/NaN; don't feed them to
        # the model.
        if sentence is None or pd.isna(sentence):
            predictions.append(None)
            continue
        result = classifier(sentence, candidate_labels)
        if 'scores' in result and 'labels' in result:
            category_idx = pd.Series(result['scores']).idxmax()
            predictions.append(result['labels'][category_idx])
        else:
            predictions.append(None)
    # Reuse the input index so each output row lines up with its input row.
    return pd.Series(predictions, index=sentences.index)
However, I'm getting a `TypeError` when registering the `get_review_classification` function:
Cell In[55], line 17
1 from snowflake.snowpark.functions import pandas_udf
2 from snowflake.snowpark.types import StringType
4 @pandas_udf(
5 name='EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.GET_REVIEW_CLASSIFICATION',
6 session=new_session,
7 is_permanent=True,
8 replace=True,
9 imports=[
10 '@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION/bart-large-mnli.joblib'
11 ],
12 input_types=[StringType()],
13 return_type=[StringType()],
14 stage_location='@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION',
15 packages=['cachetools==4.2.2', 'transformers==4.14.1']
.
.
.
.
TypeError: invalid type
I am following this Medium article: https://medium.com/snowflake/deploying-pre-trained-llms-in-snowflake-75a0d07ef03d
Can anyone please advise on this error?