
I have the following code which creates a new column based on combinations of columns in my dataframe, minus duplicates:

import itertools as it
import pandas as pd 

df = pd.DataFrame({
  'a': [3,4,5,6,3], 
  'b': [5,7,1,0,5], 
  'c': [3,4,2,1,3], 
  'd': [2,0,1,5,9]
})

orig_cols = df.columns 
for r in range(2, df.shape[1] + 1):
    for cols in it.combinations(orig_cols, r):
        df["_".join(cols)] = df.loc[:, cols].sum(axis=1)

df


I need to generate the same results in PySpark using a UDF. What would the equivalent PySpark code be?

Shubham Sharma
jack homareau

2 Answers


There's no need for a UDF. Native Spark column functions are enough (here df is the Spark DataFrame holding columns a to d):

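from itertools import combinations

from pyspark.sql import functions as F

# One summed column per combination of 2 or more of the original columns.
# sum() is Python's built-in here, adding Column objects together.
sums = [
    sum(map(F.col, c)).alias('_'.join(c))
    for r in range(2, len(df.columns) + 1)
    for c in combinations(df.columns, r)
]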

df = df.select('*', *sums)

df.show()

+---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
|  a|  b|  c|  d|a_b|a_c|a_d|b_c|b_d|c_d|a_b_c|a_b_d|a_c_d|b_c_d|a_b_c_d|
+---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
|  3|  5|  3|  2|  8|  6|  5|  8|  7|  5|   11|   10|    8|   10|     13|
|  4|  7|  4|  0| 11|  8|  4| 11|  7|  4|   15|   11|    8|   11|     15|
|  5|  1|  2|  1|  6|  7|  6|  3|  2|  3|    8|    7|    8|    4|      9|
|  6|  0|  1|  5|  6|  7| 11|  1|  5|  6|    7|   11|   12|    6|     12|
|  3|  5|  3|  9|  8|  6| 12|  8| 14| 12|   11|   17|   15|   17|     20|
+---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
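The number of generated columns grows quickly with the number of input columns: for n columns there are 2^n - n - 1 combinations of size two or more. A minimal sketch in plain Python (no Spark needed) that lists the names the comprehension above would produce; `combo_names` is a hypothetical helper used only for illustration:

```python
from itertools import combinations


def combo_names(columns):
    """Return the '_'-joined name of every combination of 2 or more columns."""
    return [
        "_".join(c)
        for r in range(2, len(columns) + 1)
        for c in combinations(columns, r)
    ]


names = combo_names(["a", "b", "c", "d"])
print(len(names))  # 11, i.e. 2**4 - 4 - 1
print(names)
```

With 4 columns this gives the 11 extra columns shown in the table; with 20 columns it would already be more than a million, so it may be worth capping `r` for wide DataFrames.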
Shubham Sharma

This is the PySpark version of your code. You can adapt the UDF and how it is registered as needed.

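import itertools as it

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()

# One tuple per row, matching the DataFrame in the question
data = [
    (3, 5, 3, 2),
    (4, 7, 4, 0),
    (5, 1, 2, 1),
    (6, 0, 1, 5),
    (3, 5, 3, 9)
]

col_names = ['a', 'b', 'c', 'd']
df = spark.createDataFrame(data, col_names)

# Every combination of 2 up to len(col_names) columns
# (the upper bound is the number of columns, not the number of rows)
combination_cols = []
for r in range(2, len(col_names) + 1):
    combination_cols += it.combinations(col_names, r)

def do_something(*args):
    # Add up the values of the columns passed in
    result = 0
    for x in args:
        result += x
    return result

do_something_udf = udf(do_something, IntegerType())

# Append one new column per combination
output_df = df
for tup in combination_cols:
    output_df = output_df.withColumn("_".join(tup), do_something_udf(*tup))

output_df.show()

Output:

+---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
|  a|  b|  c|  d|a_b|a_c|a_d|b_c|b_d|c_d|a_b_c|a_b_d|a_c_d|b_c_d|a_b_c_d|
+---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
|  3|  5|  3|  2|  8|  6|  5|  8|  7|  5|   11|   10|    8|   10|     13|
|  4|  7|  4|  0| 11|  8|  4| 11|  7|  4|   15|   11|    8|   11|     15|
|  5|  1|  2|  1|  6|  7|  6|  3|  2|  3|    8|    7|    8|    4|      9|
|  6|  0|  1|  5|  6|  7| 11|  1|  5|  6|    7|   11|   12|    6|     12|
|  3|  5|  3|  9|  8|  6| 12|  8| 14| 12|   11|   17|   15|   17|     20|
+---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+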
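Because the UDF body is ordinary Python, its logic can be checked without a Spark session. A minimal sketch, assuming the first row of the question's DataFrame; `row_sums` is a hypothetical helper that is not part of the answer, and it only mimics what the UDF would compute for one row:

```python
from itertools import combinations


def do_something(*args):
    # Same logic as the UDF body: add up the values passed in
    result = 0
    for x in args:
        result += x
    return result


def row_sums(row):
    """Hypothetical helper: apply do_something to every combination
    of 2 or more columns of a single row given as a dict."""
    cols = list(row)
    return {
        "_".join(c): do_something(*(row[k] for k in c))
        for r in range(2, len(cols) + 1)
        for c in combinations(cols, r)
    }


first_row = {"a": 3, "b": 5, "c": 3, "d": 2}
sums_for_row = row_sums(first_row)
print(sums_for_row["a_b"], sums_for_row["a_b_c_d"])  # 8 13
```

Since the function only adds columns, the native column arithmetic from the other answer is usually the cheaper choice in Spark; a UDF makes more sense once the per-row logic cannot be expressed with built-in functions.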