
I am currently trying to write a vectorized UDF in Snowpark:

import pandas as pd
from snowflake.snowpark.functions import pandas_udf
from snowflake.snowpark.types import StringType

@pandas_udf(  
  name='EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.GET_REVIEW_CLASSIFICATION',
  session=new_session,
  is_permanent=True,
  replace=True,
  imports=[
      '@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION/bart-large-mnli.joblib'
  ],
  input_types=[StringType()],
  return_type=[StringType()],
  stage_location='@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION',
  packages=['cachetools==4.2.2', 'transformers==4.14.1']
)
def get_review_classification(sentences: pd.Series) -> pd.Series:
  # Classify using the available categories
  candidate_labels = ['customer support', 'product experience', 'account issues']
  classifier = read_model()

  # Apply the model
  predictions = []
  for sentence in sentences:
      result = classifier(sentence, candidate_labels)
      if 'scores' in result and 'labels' in result:
          category_idx = pd.Series(result['scores']).idxmax()
          predictions.append(result['labels'][category_idx])
      else:
          predictions.append(None)
  return pd.Series(predictions)

However, I am getting a TypeError from the get_review_classification definition:

Cell In[55], line 17
      1 from snowflake.snowpark.functions import pandas_udf
      2 from snowflake.snowpark.types import StringType
      4 @pandas_udf(  
      5     name='EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.GET_REVIEW_CLASSIFICATION',
      6     session=new_session,
      7     is_permanent=True,
      8     replace=True,
      9     imports=[
     10         '@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION/bart-large-mnli.joblib'
     11     ],
     12     input_types=[StringType()],
     13     return_type=[StringType()],
     14     stage_location='@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION',
     15     packages=['cachetools==4.2.2', 'transformers==4.14.1']
...
TypeError: invalid type 

I am following this Medium article: https://medium.com/snowflake/deploying-pre-trained-llms-in-snowflake-75a0d07ef03d

Can anyone please advise on this error?

markalex
Sam777



You are passing the following as input_types and return_type:

input_types=[StringType()],
return_type=[StringType()],

whereas the article uses PandasSeriesType for both the input and the return type (note also that return_type takes a single type, not a list):

from snowflake.snowpark.types import PandasSeriesType, StringType

input_types=[PandasSeriesType(StringType())],
return_type=PandasSeriesType(StringType()),
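For reference, here is a minimal sketch of the question's code with that fix applied. It is not taken from the article and is untested against a live Snowflake account. The batch logic is moved into a plain function so it can be checked locally with a stub classifier, and the UDF is registered inside a helper that receives the session. `CANDIDATE_LABELS`, `classify_batch`, and `register_review_classifier` are hypothetical names introduced here, and `read_model` is assumed to be the model-loading helper from the article.

```python
import pandas as pd

# Labels taken from the question.
CANDIDATE_LABELS = ['customer support', 'product experience', 'account issues']


def classify_batch(sentences: pd.Series, classifier) -> pd.Series:
    """Return one label (or None) per sentence in the batch (hypothetical helper)."""
    predictions = []
    for sentence in sentences:
        result = classifier(sentence, CANDIDATE_LABELS)
        if 'scores' in result and 'labels' in result:
            # Position of the highest score, used to look up the matching label.
            category_idx = pd.Series(result['scores']).idxmax()
            predictions.append(result['labels'][category_idx])
        else:
            predictions.append(None)
    # Reuse the input index so output rows line up with input rows.
    return pd.Series(predictions, index=sentences.index, dtype=object)


def register_review_classifier(session, read_model):
    """Register the vectorized UDF with the corrected type arguments (hypothetical helper)."""
    # Imported here so classify_batch can be tested without Snowpark installed.
    from snowflake.snowpark.functions import pandas_udf
    from snowflake.snowpark.types import PandasSeriesType, StringType

    @pandas_udf(
        name='EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.GET_REVIEW_CLASSIFICATION',
        session=session,
        is_permanent=True,
        replace=True,
        imports=[
            '@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION/bart-large-mnli.joblib'
        ],
        # input_types is a list with one entry per UDF argument.
        input_types=[PandasSeriesType(StringType())],
        # return_type is a single type, not a list.
        return_type=PandasSeriesType(StringType()),
        stage_location='@EXERCISE_CO2_VS_TEMPERATURE.GLOBAL_TEMPERATURES.ZERO_SHOT_CLASSIFICATION',
        packages=['cachetools==4.2.2', 'transformers==4.14.1'],
    )
    def get_review_classification(sentences: pd.Series) -> pd.Series:
        # read_model is assumed to be the article's loader for the imported joblib model.
        return classify_batch(sentences, read_model())

    return get_review_classification
```

Calling `register_review_classifier(new_session, read_model)` would register the permanent UDF. `classify_batch` can be exercised without Snowflake by passing any callable that returns a dict with `labels` and `scores` keys.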
Sergiu