
I have a SQL query like below:

select col4, col5 from TableA where col1 = 'x'
intersect
select col4, col5 from TableA where col1 = 'y'
intersect
select col4, col5 from TableA where col1 = 'z'

How can I convert this SQL to its PySpark equivalent? I can create 3 DFs and then intersect them like below:

df1 ==> select col4, col5 from TableA where col1 = 'x'
df2 ==> select col4, col5 from TableA where col1 = 'y'
df3 ==> select col4, col5 from TableA where col1 = 'z'

df_result = df1.intersect(df2)
df_result = df_result.intersect(df3)
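
In actual PySpark code that would look roughly like this (just a sketch, assuming an active SparkSession named `spark` and that TableA is available as a table or temp view):

from pyspark.sql import functions as F

# assumption: `spark` is an active SparkSession and TableA is a registered table/view
ta = spark.table("TableA")

df1 = ta.where(F.col("col1") == "x").select("col4", "col5")
df2 = ta.where(F.col("col1") == "y").select("col4", "col5")
df3 = ta.where(F.col("col1") == "z").select("col4", "col5")

df_result = df1.intersect(df2).intersect(df3)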

But I feel that's not a good approach to follow if I had more intersect queries.

Also, let's say [x, y, z] is dynamic, meaning it can be like [x, y, z, a, b, ...]

Any suggestion?

  • As it has the apache-spark-sql tag, you can use the same SQL query in Spark SQL – Anjaneya Tripathi Jun 11 '22 at 07:33
  • I want to use PySpark only, else the query will be very big. This is just a simple SQL example I've provided here. Just for single value changes like `x` or `y` or `z`, the entire SQL query has to be appended with another intersect. – Temp Expt Jun 11 '22 at 07:52

1 Answer


If you want to do several consecutive intersects, there's `reduce` available. Put all your dfs in one list and intersect them consecutively:

from functools import reduce
dfs = [df1, df2, ...]
# fold the list, intersecting each dataframe with the running result
df = reduce(lambda a, b: a.intersect(b), dfs)
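
For a dynamic list of values, the dfs list could also be built from the list itself; roughly like this (a sketch, assuming the source dataframe is df and the values are in vals):

from functools import reduce
from pyspark.sql import functions as F

vals = ['x', 'y', 'z']  # can hold any number of values

# one filtered dataframe per value, then fold them together with intersect
dfs = [df.where(F.col('col1') == v).select('col4', 'col5') for v in vals]
result = reduce(lambda a, b: a.intersect(b), dfs)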

But it would be inefficient in your case.


Since all the data comes from the same dataframe, I would suggest a rework. Instead of splitting the df and then recombining the pieces with intersect, do an aggregation and a filter.

Script (Spark 3.1):

from pyspark.sql import functions as F

vals = ['x', 'y', 'z']
arr = F.array([F.lit(v) for v in vals])  # literal array of the required col1 values
df = df.groupBy('col4', 'col5').agg(F.collect_set('col1').alias('set'))  # all col1 values per (col4, col5)
df = df.filter(F.forall(arr, lambda x: F.array_contains('set', x)))  # keep groups containing every required value
df = df.drop('set')

Test:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [(1, 11, 'y'),
     (1, 11, 'y'),
     (1, 11, 'x'),
     (2, 22, 'x'),
     (1, 11, 'z'),
     (4, 44, 'z'),
     (1, 11, 'M')],
    ['col4', 'col5', 'col1'])

vals = ['x', 'y', 'z']
arr = F.array([F.lit(v) for v in vals])
df = df.groupBy('col4', 'col5').agg(F.collect_set('col1').alias('set'))
df = df.filter(F.forall(arr, lambda x: F.array_contains('set', x)))
df = df.drop('set')

df.show()
# +----+----+
# |col4|col5|
# +----+----+
# |   1|  11|
# +----+----+
ZygD
  • Thank you for sharing. Yes, `reduce` isn't the right approach for me, as I don't want to create multiple dataframes. I opted for the second one. Thank you. – Temp Expt Jun 13 '22 at 14:20