
I'm still new to PySpark, but I have a PySpark DataFrame that I'm trying to manipulate. The data consists of users logging into a device, and I'm creating sessions for each user. I want to reset the id for each element of the array of structs so that it starts at 0 and increments by 1. This is what I have, and here's what the schema looks like:

 |-- user: string (nullable = true)
 |-- logins: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- id: integer (nullable = true)   # ids are non-unique
 |    |    |-- start_time: timestamp (nullable = true)
 |    |    |-- end_time: timestamp (nullable = true)

def id_fix():
    return lambda x: transform(x, lambda item: struct(lit(0)).alias("id"))
df.withColumn("corrected", transform(col("logins"), id_fix()))

I can't even set the ids to 0 without getting a data type mismatch error: "parameter 1 requires "ARRAY" type, however, "namedlambdavariable()" is of "STRUCT" type".

Thank you!

meepmepp

1 Answer


You can use Spark's built-in function posexplode() for this case.

  • posexplode(col("logins")) -> explodes the logins array, keeping each element's 0-based position.

  • groupBy("user") -> rebuilds the array per user, using the position as the new id.

Example:

from pyspark.sql.functions import *

# sample data (renamed to avoid shadowing the stdlib json module)
json_str = """{"user":"u1","logins":[{"id":1,"start_time":12345689, "end_time":99999},{"id":2,"start_time":89, "end_time":99999}]}"""
df = spark.read.json(sc.parallelize([json_str]))

df.select(col("user"), posexplode(col("logins"))) \
  .select("user", "pos", "col.*") \
  .groupBy("user") \
  .agg(to_json(collect_list(struct(col("pos").alias("id"), col("start_time"), col("end_time")))).alias("logins")) \
  .show(10, False)
#+----+-------------------------------------------------------------------------------------------+
#|user|logins                                                                                     |
#+----+-------------------------------------------------------------------------------------------+
#|u1  |[{"id":0,"start_time":12345689,"end_time":99999},{"id":1,"start_time":89,"end_time":99999}]|
#+----+-------------------------------------------------------------------------------------------+

# return struct type instead of a JSON string
df.select(col("user"), posexplode(col("logins"))) \
  .select("user", "pos", "col.*") \
  .groupBy("user") \
  .agg(collect_list(struct(col("pos").alias("id"), col("start_time"), col("end_time"))).alias("logins")) \
  .show(10, False)
#+----+--------------------------------------+
#|user|logins                                |
#+----+--------------------------------------+
#|u1  |[{0, 12345689, 99999}, {1, 89, 99999}]|
#+----+--------------------------------------+
notNull