I have a PySpark DataFrame that looks like this:

+---+----+----+
| id|col1|col2|
+---+----+----+
|  1|   1|   3|
|  2| NaN|   4|
|  3|   3|   5|
+---+----+----+

I would like to sum col1 and col2 so that the result looks like this:

+---+----+----+---+
| id|col1|col2|sum|
+---+----+----+---+
|  1|   1|   3|  4|
|  2| NaN|   4|  4|
|  3|   3|   5|  8|
+---+----+----+---+

Here's what I have tried:

import pandas as pd
import pyspark.sql.functions as F

# Assumes an active SparkSession named `spark`.
test = pd.DataFrame({
    'id': [1, 2, 3],
    'col1': [1, None, 3],
    'col2': [3, 4, 5]
})
test = spark.createDataFrame(test)
test.withColumn('sum', F.col('col1') + F.col('col2')).show()

This code returns:

+---+----+----+---+
| id|col1|col2|sum|
+---+----+----+---+
|  1|   1|   3|  4|
|  2| NaN|   4|NaN| # <-- I want a 4 here, not this NaN
|  3|   3|   5|  8|
+---+----+----+---+

Can anyone help me with this?

mck
itscarlayall

2 Answers

Use F.nanvl to replace NaN with a given value (0 here):

import pyspark.sql.functions as F

result = test.withColumn('sum', F.nanvl(F.col('col1'), F.lit(0)) + F.col('col2'))

To address your comment (return NaN only when both columns are NaN, and otherwise treat NaN as 0):

result = test.withColumn('sum', 
    F.when(
        F.isnan(F.col('col1')) & F.isnan(F.col('col2')), 
        F.lit(float('nan'))
    ).otherwise(
        F.nanvl(F.col('col1'), F.lit(0)) + F.nanvl(F.col('col2'), F.lit(0))
    )
)
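
As a sanity check on the expression above, here is a minimal plain-Python sketch of the per-row rule it encodes. The function name `nan_aware_sum` is hypothetical and only models the logic for two float values; it is not part of PySpark.

```python
import math

def nan_aware_sum(a: float, b: float) -> float:
    """Hypothetical plain-Python model of the when/otherwise expression."""
    # Both inputs NaN: keep NaN, mirroring the F.when(...) branch.
    if math.isnan(a) and math.isnan(b):
        return float("nan")
    # Otherwise treat each NaN as 0, mirroring F.nanvl(col, F.lit(0)).
    a0 = 0.0 if math.isnan(a) else a
    b0 = 0.0 if math.isnan(b) else b
    return a0 + b0

nan = float("nan")
print(nan_aware_sum(1.0, 3.0))  # 4.0
print(nan_aware_sum(nan, 4.0))  # 4.0
print(nan_aware_sum(nan, nan))  # nan
```

The three calls correspond to a row with no NaN, a row with one NaN, and a row where both columns are NaN.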
mck

In case someone has more than two columns, here is a more general solution. Note that this handles NaN only; null is a different value and is not skipped by this function.

import pyspark.sql.functions as F

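def with_notnan_sum(df, cols=None):
    # Default to every column in the DataFrame.
    if cols is None:
        cols = df.columns
    # Sum the columns, counting each NaN as 0.
    df = df.withColumn("_sum", sum(F.when(F.isnan(c), 0).otherwise(df[c]) for c in cols))
    # Count how many of the columns are not NaN.
    df = df.withColumn("_count", sum(F.when(F.isnan(c), 0).otherwise(1) for c in cols))
    # Keep the sum only if at least one value was not NaN; otherwise the result is null.
    df = df.withColumn("notnan_sum", F.when(F.col("_count") > 0, F.col("_sum")))
    # Drop the helper columns.
    df = df.drop("_sum", "_count")
    return df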

test = spark.createDataFrame(
    [
        (1, 1., 3),
        (2, float('nan'), 4),
        (3, 3., 5),
    ], 
    ('id', 'col1', 'col2'),
)
test.show()
test = with_notnan_sum(test, ['col1', 'col2'])
test.show()

This prints:

+---+----+----+
| id|col1|col2|
+---+----+----+
|  1| 1.0|   3|
|  2| NaN|   4|
|  3| 3.0|   5|
+---+----+----+

+---+----+----+----------+
| id|col1|col2|notnan_sum|
+---+----+----+----------+
|  1| 1.0|   3|       4.0|
|  2| NaN|   4|       4.0|
|  3| 3.0|   5|       8.0|
+---+----+----+----------+
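
To make the NaN versus null distinction concrete, here is a rough plain-Python model of what with_notnan_sum computes for a single row. The name `notnan_sum_model` is hypothetical, and None stands in for Spark's null. The model assumes that Spark's isnan returns false for null and that adding null to a number yields null; treat it as a sketch for reasoning about expected values, not as a substitute for running the Spark code.

```python
import math
from typing import Optional, Sequence

def notnan_sum_model(values: Sequence[Optional[float]]) -> Optional[float]:
    """Hypothetical single-row model of with_notnan_sum; None plays the role of null."""
    # Assumption: a null is not caught by isnan, so it flows into the
    # addition and makes the whole sum null.
    if any(v is None for v in values):
        return None
    # NaN values are skipped (counted as 0 in the Spark version).
    kept = [v for v in values if not math.isnan(v)]
    # No non-NaN value left: the F.when without otherwise yields null.
    if not kept:
        return None
    return float(sum(kept))

nan = float("nan")
print(notnan_sum_model([1.0, 3.0]))   # 4.0
print(notnan_sum_model([nan, 4.0]))   # 4.0
print(notnan_sum_model([nan, nan]))   # None
print(notnan_sum_model([None, 4.0]))  # None
```

If nulls should also be skipped, they would need separate handling in the Spark version, since the NaN check alone does not cover them.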
savfod