I have a Spark DataFrame with two columns containing PySpark SparseVectors.
And I have to calculate a pairwise cosine similarity between them.
import pandas as pd
from pyspark.sql.types import *
from pyspark.sql import functions as f
from pyspark.ml.feature import Tokenizer, HashingTF, IDF
phrases = ['I have a few texts here',
'This is a text number one',
'This is a text number two']
schema = StructType(fields=[
StructField('col1', IntegerType()),
StructField('col2', StringType())
])
df = sqlContext.createDataFrame(pd.DataFrame({'col1': range(1, len(phrases) + 1), 'col2': phrases}), schema=schema)
tokenizer = Tokenizer(inputCol='col2', outputCol='words')
df2 = tokenizer.transform(df)
tf_hasher = HashingTF(numFeatures=15, binary=True, inputCol=tokenizer.getOutputCol(), outputCol='tf_vector')
df3 = tf_hasher.transform(df2)
idf = IDF(inputCol=tf_hasher.getOutputCol(), outputCol='tf_idf_vector')
idf_model = idf.fit(df3)
df4 = idf_model.transform(df3)
left = df4[['col1', 'col2', 'tf_idf_vector']].withColumnRenamed('col1', 'col1_1')\
.withColumnRenamed('col2', 'col2_1')\
.withColumnRenamed('tf_idf_vector', 'tf_idf_vector_1')
right = df4[['col1', 'col2', 'tf_idf_vector']].withColumnRenamed('col1', 'col1_2')\
.withColumnRenamed('col2', 'col2_2')\
.withColumnRenamed('tf_idf_vector', 'tf_idf_vector_2')
df5 = left.crossJoin(right)
def cosine_similarity(u, v):
return v.dot(u) / (v.norm(2) * u.norm(2))
cosine_similarity_udf = f.udf(cosine_similarity, FloatType())
df6 = df5.withColumn('cosine_similarity', cosine_similarity_udf(df5.tf_idf_vector_1, df5.tf_idf_vector_2))
df6.show(10)
I expect the following result (like, if I calculated in Pandas):
Py4JJavaError: An error occurred while calling o871.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1078 in stage 88.0 failed 4 times, most recent failure: Lost task 1078.3 in stage 88.0 (TID 21038, executor 23): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)
Spark version is 2.4.4. How can I solve this? Thanks.