I have a dataframe:
import pyspark.sql.functions as F
sdf1 = spark.createDataFrame(
    [
        (2022, 1, ["apple", "edible"]),
        (2022, 1, ["edible", "fruit"]),
        (2022, 1, ["orange", "sweet"]),
        (2022, 4, ["flowering ", "plant"]),
        (2022, 3, ["green", "kiwi"]),
        (2022, 3, ["kiwi", "fruit"]),
        (2022, 3, ["fruit", "popular"]),
        (2022, 3, ["yellow", "lemon"]),
    ],
    [
        "year",
        "id",
        "bigram",
    ],
)
sdf1.show(truncate=False)
+----+---+-------------------+
|year|id |bigram |
+----+---+-------------------+
|2022|1 |[apple, edible] |
|2022|1 |[edible, fruit] |
|2022|1 |[orange, sweet] |
|2022|4 |[flowering , plant]|
|2022|3 |[green, kiwi] |
|2022|3 |[kiwi, fruit] |
|2022|3 |[fruit, popular] |
|2022|3 |[yellow, lemon] |
+----+---+-------------------+
I wrote a function that chains bigrams that share a word (the last word of one bigram is the first word of another) into a longer n-gram, e.g. [apple, edible] and [edible, fruit] become [apple, edible, fruit]. I then apply this function to the collected column:
from networkx import DiGraph, dfs_labeled_edges
# Grouping
sdf = (
    sdf1.groupby("year", "id")
    .agg(F.collect_set("bigram").alias("collect_bigramm"))
    .withColumn("size", F.size("collect_bigramm"))
)
data_collect = sdf.collect()
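For reference, at this point sdf looks like this (the order of rows and of the elements inside collect_set is not guaranteed):
sdf.show(truncate=False)
+----+---+-----------------------------------------------------------------+----+
|year|id |collect_bigramm                                                  |size|
+----+---+-----------------------------------------------------------------+----+
|2022|4  |[[flowering , plant]]                                            |1   |
|2022|1  |[[edible, fruit], [orange, sweet], [apple, edible]]              |3   |
|2022|3  |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4   |
+----+---+-----------------------------------------------------------------+----+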
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    graph = DiGraph()
    # data_collect holds the grouped rows collected on the driver (see above)
    for row in data_collect:
        if row["size"] > 1:
            # compare every bigram with the ones after it and connect them
            # whenever the last word of one equals the first word of the other
            for i, lst1 in enumerate(lst):
                while i < len(lst) - 1:
                    lst2 = lst[i + 1]
                    if lst1[0] == lst2[1]:
                        graph.add_edge(lst2[0], lst2[1])
                        graph.add_edge(lst1[0], lst1[1])
                    elif lst1[1] == lst2[0]:
                        graph.add_edge(lst1[0], lst1[1])
                        graph.add_edge(lst2[0], lst2[1])
                    i = i + 1
    # walk the graph depth-first and collect the nodes of each forward path
    gen = dfs_labeled_edges(graph)
    lst_tmp = []
    lst_res = []
    f = 0
    for g in list(gen):
        if (g[2] == "forward") and (g[0] != g[1]):
            f = 1
            lst_tmp.append(g[0])
            lst_tmp.append(g[1])
        if g[2] == "nontree":
            continue
        if g[2] == "reverse":
            if f == 1:
                lst_res.append(lst_tmp.copy())
            f = 0
            if g[0] in lst_tmp:
                lst_tmp.remove(g[0])
            if g[1] in lst_tmp:
                lst_tmp.remove(g[1])
    if lst_res != []:
        # keep the first path, dropping duplicated nodes while preserving order
        lst_res = [
            ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]
        ]
    if lst_res == []:
        lst_res = None
    return lst_res
sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)
Output:
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm |size|new_col |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4 |[[flowering , plant]] |1 |null |
|2022|1 |[[edible, fruit], [orange, sweet], [apple, edible]] |3 |[apple, edible, fruit] |
|2022|3 |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4 |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+
But now I want to use a pandas UDF instead. I would like to do the groupby first and build the collect_bigramm column inside the function, so that all the original columns stay in the dataframe, plus a new column holding the lst_res array computed in the function. A sketch of what I have in mind follows the stub below.
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import IntegerType, StructField, StructType

schema2 = StructType(
    [
        StructField("year", IntegerType(), True),
        StructField("id", IntegerType(), True),
        StructField("bigram", ArrayType(StringType(), True), True),
        StructField("new_col", ArrayType(StringType(), True), True),
        StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
    ]
)
@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):
    graph = DiGraph()
    for index, row in df.iterrows():
        # Instead of the variable lst, I need to use the column sdf["collect_bigramm"] here
        ...
    return df
sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)
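To make the intent clearer, this is roughly the shape I imagine the filled-in version having. It is only a sketch: chain_bigrams is a hypothetical plain-Python helper holding the same DiGraph / dfs_labeled_edges logic as the udf above, and here I start from sdf1 so every group still carries the original bigram rows:
# Rough sketch of the intent, not working code.
# chain_bigrams(list_of_bigrams) -> list or None is assumed to contain
# the same graph logic as the plain udf above.
@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc_sketch(pdf):
    bigrams = [list(b) for b in pdf["bigram"]]   # the group's bigrams, as in collect_bigramm
    chained = chain_bigrams(bigrams)             # the lst_res array for this group (or None)
    pdf = pdf.copy()
    pdf["collect_bigramm"] = [bigrams] * len(pdf)  # keep the collected list on every row
    pdf["new_col"] = [chained] * len(pdf)          # add lst_res as a new column
    return pdf

sdf_new = sdf1.groupby(["year", "id"]).apply(myfunc_sketch)
Is this the right way to structure the grouped-map function, and how do I get the collect_bigramm column inside it while keeping all the other columns?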