Input PySpark DataFrame:

id   name  collection  student.1.price  student.2.price  student.3.price
111  aaa      1           100              999               232
222  bbb      2           200              888               656
333  ccc      1           300              777               454
444  ddd      3           400              666               787

Output PySpark DataFrame:

id   name  collection    price  
111  aaa      1           100           
222  bbb      2           888            
333  ccc      1           300             
444  ddd      3           787       

The correct price for each id is in the column selected by the value of the collection column: collection 1 means student.1.price, collection 2 means student.2.price, and so on.

Question

Using PySpark, how can I find the correct price for each id by dynamically building the column name student.{collection}.price?

Please let me know.

siva

1 Answer

It is a bit complicated, but you can do it this way.

The fields list holds the field names of the struct column student. You could also write these names out by hand; either way you end up with 1, 2, 3.

The first withColumn line then builds an array of the columns student.{i}.price for i in range(1, 4). Similarly, the second line builds an array of the literals {i}.

Now zip these two arrays into one array, such as

[('1', col('student.1.price')), ...]

and explode the array, so that each element becomes its own row:

('1', col('student.1.price'))
('2', col('student.2.price'))
('3', col('student.3.price'))

Since arrays_zip gives you an array of structs, each exploded value is a struct. Extract its values by using the struct keys as columns: zip.index for the index and zip.array for the price.

Finally, compare collection with index (which is really the field name inside the student struct column) and keep only the rows where they match.
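The steps above can be sketched in plain Python on a single row, with no Spark involved. This is only an illustration of the logic: the dict layout used for the row is an assumption, and the list operations stand in for arrays_zip, explode and filter.

```python
# One input row, with the nested student struct modelled as a plain dict
# (an assumed layout, used only to illustrate the logic).
row = {
    "id": 222,
    "name": "bbb",
    "collection": 2,
    "student": {"1": {"price": 200}, "2": {"price": 888}, "3": {"price": 656}},
}

# Field names of the student struct.
fields = list(row["student"])

# Stand-in for the 'array' column: one price per field.
prices = [row["student"][name]["price"] for name in fields]

# Stand-in for arrays_zip: pairs of (index, price).
zipped = list(zip(fields, prices))

# explode would emit one row per pair; the filter keeps the pair whose
# index equals the collection value (compared as strings here).
matches = [price for index, price in zipped if index == str(row["collection"])]
price = matches[0]

print(zipped)
print(price)
```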

import pyspark.sql.functions as f

# df is the input DataFrame; student is a struct column with fields 1, 2, 3.
fields = [field.name for field in df.schema['student'].dataType.fields]

(df
  # array of the prices student.1.price, student.2.price, ...
  .withColumn('array', f.array(*map(lambda x: 'student.' + x + '.price', fields)))
  # array of the matching field names as literals
  .withColumn('index', f.array(*map(lambda x: f.lit(x), fields)))
  # pair each field name with its price, then emit one row per pair
  .withColumn('zip', f.arrays_zip('index', 'array'))
  .withColumn('zip', f.explode('zip'))
  .withColumn('index', f.col('zip.index'))
  .withColumn('price', f.col('zip.array'))
  # keep only the pair whose field name equals the collection value
  .filter('collection = index')
  .select('id', 'name', 'collection', 'price')
  .show(10, False))

+---+----+----------+-----+
|id |name|collection|price|
+---+----+----------+-----+
|111|aaa |1         |100  |
|222|bbb |2         |888  |
|333|ccc |1         |300  |
|444|ddd |3         |787  |
+---+----+----------+-----+
Lamanus