
I currently have the following code:

def _join_intent_types(df):
  mappings = {
    'PastNews': 'ContextualInformation',
    'ContinuingNews': 'News',
    'KnownAlready': 'OriginalEvent',
    'SignificantEventChange': 'NewSubEvent',
  }
  return df.withColumn('Categories', posexplode('Categories').alias('i', 'val'))\
           .when(col('val').isin(mappings), mappings[col('i')])\
           .otherwise(col('val'))

But I'm not sure if my syntax is right. What I'm trying to do is operate on a column of lists such as:

['EmergingThreats', 'Factoid', 'KnownAlready']

and replace strings within that Array with the mappings in the dictionary provided, i.e.

['EmergingThreats', 'Factoid', 'OriginalEvent']

I understand this is possible with a UDF, but I'm worried about how that would impact performance and scalability.

A sample of the original table:

+------------------+-----------------------------------------------------------+
|postID            |Categories                                                 |
+------------------+-----------------------------------------------------------+
|266269932671606786|[EmergingThreats, Factoid, KnownAlready]                   |
|266804609954234369|[Donations, ServiceAvailable, ContinuingNews]              |
|266250638852243457|[EmergingThreats, Factoid, ContinuingNews]                 |
|266381928989589505|[EmergingThreats, MultimediaShare, Factoid, ContinuingNews]|
|266223346520297472|[EmergingThreats, Factoid, KnownAlready]                   |
+------------------+-----------------------------------------------------------+

I'd like the code to replace strings in those arrays with their new mappings, provided they exist in the dictionary. If not, leave them as they are:

+------------------+-------------------------------------------------+          
|postID            |Categories                                       |
+------------------+-------------------------------------------------+
|266269932671606786|[EmergingThreats, Factoid, OriginalEvent]        |
|266804609954234369|[Donations, ServiceAvailable, News]              |
|266250638852243457|[EmergingThreats, Factoid, News]                 |
|266381928989589505|[EmergingThreats, MultimediaShare, Factoid, News]|
|266223346520297472|[EmergingThreats, Factoid, OriginalEvent]        |
+------------------+-------------------------------------------------+
ZygD
apgsov

4 Answers


Using explode + collect_list is expensive. This is untested, but should work for Spark 2.4+:

from pyspark.sql.functions import expr

# `mappings` is the dict from the question; one pass over the array per mapping entry
for k, v in mappings.items():
    df = df.withColumn(
        'Categories',
        expr("transform(sequence(0, size(Categories)-1), x -> replace(Categories[x], '{k}', '{v}'))".format(k=k, v=v))
    )

You can also convert the mappings into a CASE/WHEN statement and then apply it with the Spark SQL `transform` function:

sql_expr = "transform(Categories, x -> CASE x {} ELSE x END)".format(
    " ".join("WHEN '{}' THEN '{}'".format(k, v) for k, v in mappings.items())
)
# this yields the following SQL expression:
# transform(Categories, x -> 
#   CASE x 
#     WHEN 'PastNews' THEN 'ContextualInformation' 
#     WHEN 'ContinuingNews' THEN 'News' 
#     WHEN 'KnownAlready' THEN 'OriginalEvent' 
#     WHEN 'SignificantEventChange' THEN 'NewSubEvent' 
#     ELSE x 
#   END
# )

df.withColumn('Categories', expr(sql_expr)).show(truncate=False)

For older versions of Spark, a UDF may be preferred.
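
A rough sketch of what that could look like (untested, and assuming `mappings` is the dict defined in the question):

from pyspark.sql.functions import col, udf
from pyspark.sql.types import ArrayType, StringType

@udf(returnType=ArrayType(StringType()))
def remap_categories(cats):
    # Replace each element that has a mapping; leave the rest unchanged
    return [mappings.get(c, c) for c in cats]

df = df.withColumn('Categories', remap_categories(col('Categories')))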

jxc
pault

You can `explode` the Categories column, then use `na.replace` with the dictionary, followed by a `groupby` and an aggregation back to arrays using `collect_list`:

import pyspark.sql.functions as F

out = (df.select(F.col("postID"), F.explode("Categories").alias("Categories"))
         .na.replace(mappings)
         .groupby("postID")
         .agg(F.collect_list("Categories").alias("Categories")))

out.show(truncate=False)

+------------------+-------------------------------------------------+
|postID            |Categories                                       |
+------------------+-------------------------------------------------+
|266269932671606786|[EmergingThreats, Factoid, OriginalEvent]        |
|266250638852243457|[EmergingThreats, Factoid, News]                 |
|266381928989589505|[EmergingThreats, MultimediaShare, Factoid, News]|
|266804609954234369|[Donations, ServiceAvailable, News]              |
|266223346520297472|[EmergingThreats, Factoid, OriginalEvent]        |
+------------------+-------------------------------------------------+

UPDATE:

As discussed in the comments, you can consider using a UDF instead, with performance in mind:

from pyspark.sql.types import ArrayType, StringType

def fun(x):
    return [mappings.get(i, i) for i in x]

myudf = F.udf(fun, ArrayType(StringType()))
df.withColumn("Categories", myudf(F.col("Categories"))).show(truncate=False)

+------------------+-------------------------------------------------+
|postID            |Categories                                       |
+------------------+-------------------------------------------------+
|266269932671606786|[EmergingThreats, Factoid, OriginalEvent]        |
|266804609954234369|[Donations, ServiceAvailable, News]              |
|266250638852243457|[EmergingThreats, Factoid, News]                 |
|266381928989589505|[EmergingThreats, MultimediaShare, Factoid, News]|
|266223346520297472|[EmergingThreats, Factoid, OriginalEvent]        |
+------------------+-------------------------------------------------+
anky
  • Why `.na.replace` rather than just `.replace`? – Sreeram TP Apr 17 '20 at 11:11
  • @SreeramTP you're right, `replace` will work here as well. Thanks – anky Apr 17 '20 at 11:16
  • [Considering the high cost of the explode + collect_list idiom, using `transform` for Spark 2.4+ or a `udf` otherwise is almost exclusively preferred, despite its intrinsic cost.](https://stackoverflow.com/a/53486896/5858851) – pault Apr 17 '20 at 14:19
  • @pault okay, thank you, will keep that noted. Thanks for linking me these answers – anky Apr 17 '20 at 14:22
  • @pault Got it, struggling with the execution part, may have to give this more time. However, do you think even a simple `for loop udf` is more performant than explode+agg? Asking for my future reference – anky Apr 17 '20 at 14:54
  • Generally speaking, yes, but I suppose it depends on the size of your data and the distribution. – pault Apr 17 '20 at 15:13

You can do this with a series of steps:

import pandas as pd
from pyspark.sql import functions as F
from itertools import chain

# Build a small sample DataFrame via pandas, then convert it to a Spark DataFrame
df = pd.DataFrame()
df['postID'] = [266269932671606786, 266804609954234369, 266250638852243457]
df['Categories'] = [
    ['EmergingThreats', 'Factoid', 'KnownAlready'],
    ['Donations', 'ServiceAvailable', 'ContinuingNews'],
    ['EmergingThreats', 'Factoid', 'ContinuingNews'],
]

sdf = spark.createDataFrame(df)

mappings = {
    'PastNews': 'ContextualInformation',
    'ContinuingNews': 'News',
    'KnownAlready': 'OriginalEvent',
    'SignificantEventChange': 'NewSubEvent',
    'Donations': 'x'
}

mapping_expr = F.create_map([F.lit(x) for x in chain(*mappings.items())])

sdf.select(F.col("postID"), F.explode("Categories").alias("Categories")) \
            .withColumn("Categories", F.coalesce(mapping_expr.getItem(F.col("Categories")), F.col('Categories'))) \
            .groupBy('postID').agg(F.collect_list('Categories').alias('Categories')) \
            .show(truncate=False)


+------------------+-----------------------------------------+
|postID            |Categories                               |
+------------------+-----------------------------------------+
|266250638852243457|[EmergingThreats, Factoid, News]         |
|266804609954234369|[x, ServiceAvailable, News]              |
|266269932671606786|[EmergingThreats, Factoid, OriginalEvent]|
+------------------+-----------------------------------------+
Sreeram TP

Spark 3.1+

First, `create_map` from the dict, then `transform` and, most importantly, `coalesce`.

map_col = F.create_map([F.lit(x) for i in mappings.items() for x in i])
df = df.withColumn('col', F.transform('col', lambda x: F.coalesce(map_col[x], x)))

Full example:

from pyspark.sql import functions as F
mappings = {
    'PastNews': 'ContextualInformation',
    'ContinuingNews': 'News',
    'KnownAlready': 'OriginalEvent',
    'SignificantEventChange': 'NewSubEvent',
}
df = spark.createDataFrame(
    [('266269932671606786', ['EmergingThreats', 'Factoid', 'KnownAlready']),
     ('266804609954234369', ['Donations', 'ServiceAvailable', 'ContinuingNews']),
     ('266250638852243457', ['EmergingThreats', 'Factoid', 'ContinuingNews']),
     ('266381928989589505', ['EmergingThreats', 'MultimediaShare', 'Factoid', 'ContinuingNews']),
     ('266223346520297472', ['EmergingThreats', 'Factoid', 'KnownAlready'])],
    ['postID', 'Categories'])

map_col = F.create_map([F.lit(x) for i in mappings.items() for x in i])
df = df.withColumn("Categories", F.transform('Categories', lambda x: F.coalesce(map_col[x], x)))

df.show(truncate=0)
# +------------------+-------------------------------------------------+
# |postID            |Categories                                       |
# +------------------+-------------------------------------------------+
# |266269932671606786|[EmergingThreats, Factoid, OriginalEvent]        |
# |266804609954234369|[Donations, ServiceAvailable, News]              |
# |266250638852243457|[EmergingThreats, Factoid, News]                 |
# |266381928989589505|[EmergingThreats, MultimediaShare, Factoid, News]|
# |266223346520297472|[EmergingThreats, Factoid, OriginalEvent]        |
# +------------------+-------------------------------------------------+
ZygD