0

I'm trying to write a `filter_words` function as a `pandas_udf`.

Here are the functions I am using:

   @udf_annotator(returnType=ArrayType(StructType([StructField("position", IntegerType(), True),
                                                   StructField("tokens", StringType(), True)])))
   def position_words(tokens):
       """Pair each token of one row with its 0-based index.

       :param tokens: one row's token list, or None when the array is null.
       :return: list of ``(position, token)`` tuples matching the declared
           ``position`` / ``tokens`` struct fields, or None for a null row.
       """
       # NOTE(review): udf_annotator presumably wraps pyspark's row-at-a-time
       # ``udf`` -- confirm. A null array then arrives here as None, and
       # enumerate(None) would raise TypeError, so propagate the null instead.
       if tokens is None:
           return None
       return [(int(i), token) for i, token in enumerate(tokens)]
    
    @pandas_udf(returnType=ArrayType(StructType([StructField("position", IntegerType(), True),
                                                    StructField("word", StringType(), True)])))
    def filter_words(lst2, lang2):
        """Clean every token of every row in one pandas-UDF batch.

        The reported ``Could not convert 'position' with type str: tried to
        convert to int32`` shows that struct elements of the ``position`` array
        arrive here as dicts keyed by field name. Unpacking them with
        ``for pos, word in lst`` therefore yields the *keys* ``'position'`` and
        ``'tokens'``. Fields are read by name instead. Tuples and Rows are still
        accepted and are unpacked positionally.

        :param lst2: pandas Series; each value is one row's array of
            ``{position, tokens}`` structs, or None for a null array.
        :param lang2: pandas Series of language codes (currently unused).
        :return: pandas Series (object dtype, same index as ``lst2``) whose
            values are lists of ``(position, word)`` tuples, or None for null rows.
        """
        # Compile once per batch instead of once per token.
        url_re = re.compile(
            r"((https?|ftps?|file)?:\/\/)?(?:[\w\d!#$&'()*\+,:;=?@[\]\-_.~]|(?:%[0-9a-fA-F][0-9a-fA-F]))+" +
            "\\.([\\w\\d]{2,6})(\\/(?:[\\w\\d!#$&'()*\\+,:;=?@[\\]\\-_.~]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)*")
        mention_re = re.compile(r"[@#]\w+")
        word_re = re.compile(r"""(?:\w\.)+\w\.?|\w{2,}-?\w*-*\w*""")

        def position_and_word(item):
            # Dict (pandas UDF struct) -> by field name; tuple/Row -> positional.
            if isinstance(item, dict):
                return item["position"], item["tokens"]
            pos, word = item
            return pos, word

        def filter_word2(lst, lang):
            if lst is None:
                return None  # null array in, null array out
            filtered_tokens = []
            for item in lst:
                pos, word = position_and_word(item)
                if not word:
                    continue  # skip null and empty tokens
                text = url_re.sub("", word)
                text = mention_re.sub("", text)
                text = text.replace("'", " ")
                filtered_tokens.append((pos, " ".join(word_re.findall(text))))
            return filtered_tokens

        results = [filter_word2(lst, lang) for lst, lang in zip(lst2, lang2)]
        # Explicit object dtype keeps each value a Python list even for an
        # empty batch, and the index mirrors the input Series.
        return pd.Series(results, index=lst2.index, dtype=object)

Here I create an example DataFrame and apply the functions to it:

import random
langs = "eng rus tuk jpn arb fin fra cmn".split()


def random_text(length):
    """Build a random string of ``length`` characters.

    Each character is picked independently with ``random.choice`` from the
    fixed alphabet string below (which includes spaces), so the result is
    reproducible under ``random.seed``.
    """
    alphabet = 'sdfsdfg jkhkhkj jh kh'
    chars = []
    for _ in range(length):
        chars.append(random.choice(alphabet))
    return ''.join(chars)

# NOTE(review): assumes `pd` (pandas), `F` (pyspark.sql.functions) and an
# active SparkSession named `spark` are defined elsewhere -- not shown here.
n_rows = 100000
df = pd.DataFrame({'text': [random_text(10) for _ in range(n_rows)],
                   'lang': [random.choice(langs) for _ in range(n_rows)]})
# F.split requires a pattern argument; calling it with only the column
# raises TypeError before any UDF runs.
sdf = (
    spark.createDataFrame(df)
    .withColumn('tokens', F.split(F.col('text'), ' '))
    .withColumn("position", position_words(F.col("tokens")))
    .withColumn("position_filt", filter_words(F.col("position"), F.col("lang")))
)

But unfortunately I get an error:

pyarrow.lib.ArrowInvalid: Could not convert 'position' with type str: tried to convert to int32

I would like to keep `filter_words` as a `pandas_udf`.

Rory
  • 471
  • 2
  • 11

1 Answer

1

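The error is not caused by passing a column (`F.col("position")`) to `filter_words`. A `pandas_udf` is meant to be applied to Spark columns, and Spark hands it each batch as a pandas Series. The real problem is how the struct elements of the `position` array arrive inside the UDF: each element is a `dict` such as `{'position': 0, 'tokens': 'foo'}`, not a tuple. `for pos, word in lst` therefore unpacks the dictionary's *keys*, so every output element becomes `('position', 'tokens')`. Arrow then fails trying to store the string `'position'` in the `IntegerType` field, which is exactly the `Could not convert 'position' with type str: tried to convert to int32` message. The fix is to:

- read the fields by name (`item['position']`, `item['tokens']`);
- keep applying the UDF with `withColumn`;
- pass a pattern to `F.split`, which requires one.

Here's an updated version of your code: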

python
import random
import re
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, IntegerType, StringType, StructType, StructField
import pandas as pd

langs = "eng rus tuk jpn arb fin fra cmn".split()


def random_text(length):
    """Build a random string of ``length`` characters.

    Each character is picked independently with ``random.choice`` from the
    fixed alphabet string below (which includes spaces), so the result is
    reproducible under ``random.seed``.
    """
    alphabet = 'sdfsdfg jkhkhkj jh kh'
    chars = []
    for _ in range(length):
        chars.append(random.choice(alphabet))
    return ''.join(chars)

@pandas_udf(returnType=ArrayType(StructType([StructField("position", IntegerType(), True),
                                            StructField("tokens", StringType(), True)])))
def position_words(tokens):
    """Pair every token of every row with its 0-based index in that row.

    As a pandas UDF this receives a whole batch: ``tokens`` is a pandas Series
    whose values are the per-row token arrays (None for a null array).
    Enumerating the Series itself would pair *row* numbers with whole arrays,
    so each row is enumerated separately.

    :return: pandas Series (object dtype, same index as ``tokens``) of lists of
        ``(position, token)`` tuples, or None for null rows.
    """
    def _enumerate_row(row):
        # One row's token array -> [(0, tok0), (1, tok1), ...]; nulls stay null.
        if row is None:
            return None
        return [(int(i), token) for i, token in enumerate(row)]

    return pd.Series([_enumerate_row(row) for row in tokens],
                     index=tokens.index, dtype=object)

# URLs with an optional scheme (http, https, ftp, ftps, file), e.g. "www.example.com".
_URL_RE = re.compile(
    r"((https?|ftps?|file)?:\/\/)?(?:[\w\d!#$&'()*\+,:;=?@[\]\-_.~]|(?:%[0-9a-fA-F][0-9a-fA-F]))+" +
    "\\.([\\w\\d]{2,6})(\\/(?:[\\w\\d!#$&'()*\\+,:;=?@[\\]\\-_.~]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)*")
# Hashtags / mentions such as "#tag" or "@user".
_MENTION_RE = re.compile(r"[@#]\w+")
# Dotted abbreviations ("e.g.") or runs of 2+ word characters with optional hyphens.
_WORD_RE = re.compile(r"""(?:\w\.)+\w\.?|\w{2,}-?\w*-*\w*""")


def _position_and_word(item):
    """Return ``(position, word)`` from one struct element.

    Inside a pandas UDF an ``array<struct>`` element arrives as a dict keyed by
    field name; unpacking such a dict directly would yield the field *names*.
    Tuples (including pyspark Rows) are unpacked positionally.
    """
    if isinstance(item, dict):
        return item["position"], item["tokens"]
    position, word = item
    return position, word


def filter_word2(lst, lang):
    """Strip URLs, hashtags/mentions and short fragments from each token.

    :param lst: one row's sequence of ``{position, tokens}`` structs (dicts) or
        ``(position, token)`` pairs, or None for a null array.
    :param lang: the row's language code (currently unused).
    :return: list of ``(position, cleaned_word)`` tuples, with null and empty
        tokens dropped, or None when ``lst`` is None.
    """
    if lst is None:
        return None  # null array in, null array out
    filtered_tokens = []
    for item in lst:
        pos, word = _position_and_word(item)
        if not word:
            continue  # skip null and empty tokens
        text = _URL_RE.sub("", word)
        text = _MENTION_RE.sub("", text)
        text = text.replace("'", " ")
        filtered_tokens.append((pos, " ".join(_WORD_RE.findall(text))))
    return filtered_tokens

@pandas_udf(returnType=ArrayType(StructType([StructField("position", IntegerType(), True),
                                            StructField("word", StringType(), True)])))
def filter_words(lst2, lang2):
    """Apply ``filter_word2`` to every row of one pandas-UDF batch.

    :param lst2: pandas Series of per-row ``array<struct<position, tokens>>`` values.
    :param lang2: pandas Series of language codes, passed through to ``filter_word2``.
    :return: pandas Series (object dtype, same index as ``lst2``) of lists of
        ``(position, word)`` tuples.
    """
    results = [filter_word2(lst, lang) for lst, lang in zip(lst2, lang2)]
    # Explicit object dtype: older pandas versions default an empty Series to
    # float64 rather than object; the index mirrors the input Series.
    return pd.Series(results, index=lst2.index, dtype=object)

from pyspark.sql import functions as F

df = pd.DataFrame({'text': [random_text(10) for _ in range(100000)],
                   'lang': [random.choice(langs) for _ in range(100000)]})

# Keep the whole pipeline in Spark: UDF objects returned by @pandas_udf are
# applied to Columns via withColumn. They are not meant to be called on pandas
# Series after toPandas().
# NOTE(review): assumes an active SparkSession named `spark` -- not created here.
sdf = (
    spark.createDataFrame(df)
    # F.split requires a pattern argument.
    .withColumn('tokens', F.split(F.col('text'), ' '))
    .withColumn("position", position_words(F.col("tokens")))
    .withColumn("position_filt", filter_words(F.col("position"), F.col("lang")))
)

# Output the resulting dataframe
sdf.show()

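In the updated code, `position_words` is declared with `@pandas_udf` instead of the original `udf_annotator`. `filter_word2` is defined at module level, so `filter_words` only loops over the rows of each batch and calls it for each row.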