I have the following PySpark DataFrame that contains two fields, ID and QUARTER:
import pandas as pd

pandas_df = pd.DataFrame({
    "ID":      [1, 2, 3, 4, 5, 3, 5, 6, 3, 7, 2, 6, 8, 9, 1, 7, 5, 1, 10],
    "QUARTER": [1, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 5, 5, 5, 5]
})
spark_df = spark.createDataFrame(pandas_df)
spark_df.createOrReplaceTempView('spark_df')
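For reference, a quick way to see which IDs appear in each quarter of this sample data (just a sanity check using collect_set, not part of the actual logic):

from pyspark.sql.functions import collect_set

# IDs present in each quarter of the sample data:
# quarter 1 -> {1, 2, 3, 4, 5}, quarter 2 -> {3, 5, 6}, quarter 3 -> {3, 7, 2, 6, 8},
# quarter 4 -> {9, 1}, quarter 5 -> {7, 5, 1, 10}
spark_df.groupBy('QUARTER').agg(collect_set('ID').alias('ids')).orderBy('QUARTER').show(truncate=False)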
I also have the following list that contains the number of entries I want from each of the 5 quarters:
numbers = [2, 1, 3, 1, 2]
I want to select from each quarter a number of rows equal to the corresponding value in the list 'numbers'. The IDs must be unique in the final result: if an ID has been selected in one quarter, it must not be selected again in another quarter.
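To make the goal concrete: with the sample data and numbers = [2, 1, 3, 1, 2] above, one valid result (among others) would be
- 2 rows from quarter 1, e.g. IDs 1 and 2
- 1 row from quarter 2, e.g. ID 3
- 3 rows from quarter 3, e.g. IDs 7, 6 and 8 (IDs 2 and 3 are excluded because they were already taken)
- 1 row from quarter 4, e.g. ID 9 (ID 1 is excluded)
- 2 rows from quarter 5, e.g. IDs 5 and 10
so 9 rows in total, all with distinct IDs.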
For that, I used the following PySpark code:
from pyspark.sql.functions import col, lit, row_number
from pyspark.sql.window import Window

quart = 1          # the first quarter
liste_unique = []  # an empty list that will contain the unique ID values to compare against

for i in range(0, len(numbers)):
    tmp = spark_df.where(spark_df.QUARTER == quart)   # keep only the rows of the chosen quarter
    tmp = tmp.where(~tmp.ID.isin(liste_unique))       # keep only IDs that were not selected before
    w = Window.partitionBy(lit('col_count0')).orderBy(lit('col_count0'))  # window over a dummy constant column
    df_final = tmp.withColumn("row_num", row_number().over(w)) \
                  .filter(col("row_num").between(1, numbers[i]))  # number of rows needed from the 'numbers' list
    df_final = df_final.drop(col("row_num"))                      # drop the row_num column
    liste_tempo = df_final.select(['ID']).rdd.map(lambda x: x[0]).collect()  # turn the selected IDs into a Python list
    liste_unique.extend(liste_tempo)  # extend the list of already-used IDs after each quarter
    df0 = df0.union(df_final)         # union the rows selected in each quarter into df0
    quart = quart + 1                 # increment the quarter
df0 is simply an empty DataFrame at the beginning. It will contain all the selected data at the end; it can be declared as follows:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.appName('Empty_Dataframe').getOrCreate()

# Create an empty schema
columns = StructType([
    StructField('ID', StringType(), True),
    StructField('QUARTER', StringType(), True)
])
df0 = spark.createDataFrame(data=[], schema=columns)
The code runs without errors, except that I can find duplicate IDs across different quarters, which is not correct. Another strange behavior: when I count the number of distinct IDs in the df0 DataFrame (in a new, separate cell),
print(df0.select('ID').distinct().count())
it gives a different value at each execution, even though the DataFrame is not touched by any other process (this is more visible with a larger dataset than the example above). I cannot understand this behavior. I tried clearing the cache and the temporary variables with unpersist(True), but nothing changed. I suspect that the union function is used incorrectly, but I did not find any alternative in PySpark.
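For completeness, a simple way to see the duplicate IDs is a plain groupBy on ID (this is roughly the check I run, nothing fancy):

# IDs that ended up in df0 more than once, i.e. selected in more than one quarter
df0.groupBy('ID').count().filter('count > 1').show()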