PySpark code using pandas UDF functions works fine with df.limit(20).collect() and with writing 20 records to CSV. But when I try to write 100 records to CSV, it fails with a java.io.EOFException. The same code works fine with regular UDF functions (not pandas UDFs). I'm not clear on how to troubleshoot this and am looking for guidance.
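What I understand so far: pandas UDFs receive their input in Arrow record batches (controlled by spark.sql.execution.arrow.maxRecordsPerBatch, default 10000), so the Python worker sees all 100 rows in a single batch. One thing I can try, though I haven't confirmed it's the cause, is shrinking the batches so each predict() call handles fewer rows:

# shrink Arrow batches so each pandas UDF call sees fewer rows (20 is a guess)
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20")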
Pandas UDF to find names using a Flair NLP model:
#Define Flair Function - pandas udf
from typing import Iterator

import pandas as pd
from pyspark.sql.functions import pandas_udf
from flair.data import Sentence

# tagger is a Flair SequenceTagger loaded once at module level, e.g.:
# from flair.models import SequenceTagger
# tagger = SequenceTagger.load("ner")

@pandas_udf("array<string>")
def pu_find_names_flair(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for textSr in iterator:
        sentences = [Sentence(text) for text in textSr]
        tagger.predict(sentences)
        Flairresult = []
        for sentence in sentences:
            Flairresultentity = []
            # keep only person, location and organization entities
            for entity in sentence.get_spans('ner'):
                if entity.get_label("ner").value in ["PER", "LOC", "ORG"]:
                    Flairresultentity.append(entity.text)
            Flairresult.append(Flairresultentity)
        yield pd.Series(Flairresult)
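For what it's worth, the batch logic can be exercised outside Spark with plain pandas, to check whether Flair itself struggles at a batch of ~100 rows (a minimal local sketch; the model name, sample text and batch size are placeholders):

# Local sanity check of the Flair logic, bypassing Spark/Arrow entirely
import pandas as pd
from flair.data import Sentence
from flair.models import SequenceTagger

tagger = SequenceTagger.load("ner")
batch = pd.Series(["John called from Boston about Acme Corp."] * 100)
sentences = [Sentence(t) for t in batch]
tagger.predict(sentences)  # if this exhausts memory locally, the Spark worker likely will too
print(sentences[0].get_spans("ner"))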
Another pandas UDF to find and replace strings passed in with the text:
#Mask pii in Call transcripts - pandas udf
import re
from typing import Tuple

@pandas_udf("string")
def pu_mask_all_pii(iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """
    Pandas UDF to remove PII from text.
    :param: iterator over pairs of df columns (text, list of PII tokens)
    :return: iterator of masked text
    """
    for text_sr, pii_list_sr in iterator:
        result = []
        for text, tokens in zip(text_sr, pii_list_sr):
            # mask longer tokens first so their substrings don't pre-empt them
            tokens = sorted(tokens, key=len, reverse=True)
            pat = '|'.join(map(re.escape, tokens))
            text = re.sub(pat, lambda g: 'X' * len(g.group()), str(text), flags=re.IGNORECASE)
            result.append(text)
        yield pd.Series(result)
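The masking logic itself can be checked the same way with plain Python (made-up sample data):

import re

text = "John Smith called from Boston."
tokens = sorted(["John Smith", "Boston"], key=len, reverse=True)
pat = '|'.join(map(re.escape, tokens))
print(re.sub(pat, lambda g: 'X' * len(g.group()), text, flags=re.IGNORECASE))
# XXXXXXXXXX called from XXXXXX.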
PySpark dataframe code:
from pyspark.sql.functions import (col, lit, array, array_distinct,
                                   array_except, concat, concat_ws)

exclude_pii_words = ['@@@@@@@@', '#####']
dfs = dfs.limit(100)
dfs = dfs.withColumn("pii_flairner", pu_find_names_flair(col("CALL_TRANSCRIPT")))
# combine entities from all three NER methods, dropping duplicates and excluded words
dfs = dfs.withColumn('pii_allmethods', array_except(
    array_distinct(concat(dfs["pii_bertner"], dfs["pii_regex_spacy"], dfs["pii_flairner"])),
    array(*map(lit, exclude_pii_words))))
dfs = dfs.withColumn('FULL_TRANSCRIPT', pu_mask_all_pii(col("CALL_TRANSCRIPT"),
                                                        col("pii_allmethods")))
dfs = dfs.withColumn('pii_flairner', concat_ws(",", col("pii_flairner")))
dfs = dfs.withColumn('pii_regex_spacy', concat_ws(",", col("pii_regex_spacy")))
dfs = dfs.withColumn('pii_bertner', concat_ws(",", col("pii_bertner")))
dfs = dfs.withColumn('pii_allmethods', concat_ws(",", col("pii_allmethods")))

dfs.limit(20).select('FULL_TRANSCRIPT').collect()  # works fine
dfs.coalesce(1).write.option("header", "true").mode("overwrite").csv("<filename>")  # errors
Exception
23/03/27 00:44:19 WARN TaskSetManager: Lost task 4.0 in stage 4.0 (TID 14) (10.192.99.70 executor 1): org.apache.spark.SparkException: Python worker exited unexpectedly (crashed)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:678)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator$$anonfun$1.applyOrElse(PythonRunner.scala:667)
at scala.runtime.AbstractPartialFunction.apply(AbstractPartialFunction.scala:38)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:107)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:50)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:595)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:489)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:757)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator$GroupedIterator.fill(Iterator.scala:1209)
at scala.collection.Iterator$GroupedIterator.hasNext(Iterator.scala:1215)
at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
at org.apache.spark.api.python.PythonRDD$.writeIteratorToStream(PythonRDD.scala:442)
at org.apache.spark.sql.execution.python.PythonUDFRunner$$anon$1.writeIteratorToStream(PythonUDFRunner.scala:53)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:521)
at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2241)
at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:313)
Caused by: java.io.EOFException
at java.io.DataInputStream.readInt(DataInputStream.java:392)
at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:88)
... 23 more
Package versions:
pyarrow==11.0.0
numpy==1.23.3
pandas==1.4.4
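What I plan to try next to narrow this down, given the stack trace shows the Python worker crashing while the JVM reads its Arrow output (the memory sizes below are guesses for my cluster, not known-good values):

# 1. Run the UDFs over all 100 rows without the CSV write, to rule the write in or out
dfs.limit(100).select('FULL_TRANSCRIPT').collect()

# 2. If that also crashes, give the Python workers more headroom; these must be set
#    when the session/cluster is created, not changed at runtime
from pyspark.sql import SparkSession
spark = (SparkSession.builder
         .config("spark.executor.memory", "8g")
         .config("spark.executor.pyspark.memory", "4g")
         .getOrCreate())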