
I would like to sample at most n rows from each group in the data, where the grouping is defined by a single column. There are many answers for selecting the top n rows, but I don't need any particular order and am not sure whether ordering would introduce unnecessary shuffling.

I have looked at

  • sampleBy(), but I need a maximum absolute number of rows per group rather than a fraction.
  • Window functions, but they always seem to imply ordering the values.
  • groupBy, but I was not able to construct anything suitable from the available aggregate functions.

Code example:

data = [('A',1), ('B',1), ('C',2)]
columns = ["field_1","field_2"]
df = spark.createDataFrame(data=data, schema = columns)

What I am looking for is a pandas-like

df.groupby('field_2').head(1)

I would also be happy with a suitable SQL expression.

Otherwise if there is no better performance than using

Window.partitionBy(df['field_2']).orderBy('field_1')...

then I'd also be happy to know that. Thanks!

malumno
  • Does this answer your question? [Retrieve top n in each group of a DataFrame in pyspark](https://stackoverflow.com/questions/38397796/retrieve-top-n-in-each-group-of-a-dataframe-in-pyspark) If you need randomness, you can add `df.orderBy(F.rand())`, but be aware of the performance. – Emma Jun 13 '22 at 23:27
  • I don't explicitly need randomness. But the answer you are suggesting includes an orderBy(some_column). Is this the optimum, performance wise, even if I don't need any specific order within each group? – malumno Jun 14 '22 at 07:13
  • afaik, if you need top N (>1), you'll need window functions (`row_number`/`rank`) and both of those window functions require the `orderBy`. If you want top 1, then you don't need order (`groupBy('field_2').agg(F.first('field_1'))`). – Emma Jun 14 '22 at 16:34
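For reference, a minimal sketch of the two approaches described in the comments, using the question's df and column names (the window variant keeps at most n rows per group but needs an ordering column; the aggregate variant only yields a single row per group):

import pyspark.sql.functions as F
from pyspark.sql import Window

n = 1  # at most n rows per group

# top N rows per group: window function, requires an orderBy
w = Window.partitionBy('field_2').orderBy('field_1')
top_n = (df
         .withColumn('rn', F.row_number().over(w))
         .filter(F.col('rn') <= n)
         .drop('rn'))

# single row per group: plain aggregation, no ordering required
top_1 = df.groupBy('field_2').agg(F.first('field_1').alias('field_1'))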

1 Answer


The below would work if a sort isn't required, and it uses RDD transformations.

For a dataframe like the following

sdf.show()

# +-----------+-------+--------+----+
# |bvdidnumber|dt_year|dt_rfrnc|goal|
# +-----------+-------+--------+----+
# |          1|   2020|  202006|   0|
# |          1|   2020|  202012|   1|
# |          1|   2020|  202012|   0|
# |          1|   2021|  202103|   0|
# |          1|   2021|  202106|   0|
# |          1|   2021|  202112|   1|
# |          2|   2020|  202006|   0|
# |          2|   2020|  202012|   0|
# |          2|   2020|  202012|   1|
# |          2|   2021|  202103|   0|
# |          2|   2021|  202106|   0|
# |          2|   2021|  202112|   1|
# +-----------+-------+--------+----+

I created a function that can be shipped to all executors, and then used with flatMapValues() in an RDD transformation.

# best to ship this function to all executors for optimum performance
def get_n_from_group(group, num_recs):
    """
    return at most `num_recs` sample records from a group iterable
    """
    res = []

    for i, rec in enumerate(group, start=1):
        res.append(rec)

        if i == num_recs:
            break

    return res

rdd = sdf.rdd. \
    groupBy(lambda x: x.bvdidnumber). \
    flatMapValues(lambda k: get_n_from_group(k, 2))  # at most 2 records per group

top_n_sdf = spark.createDataFrame(rdd.values(), schema=sdf.schema)

top_n_sdf.show()

# +-----------+-------+--------+----+
# |bvdidnumber|dt_year|dt_rfrnc|goal|
# +-----------+-------+--------+----+
# |          1|   2020|  202006|   0|
# |          1|   2020|  202012|   1|
# |          2|   2020|  202006|   0|
# |          2|   2020|  202012|   0|
# +-----------+-------+--------+----+
samkart
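Applied to the question's df, the same RDD pattern would look like this (a sketch, grouping on field_2 and keeping at most one row per group; which row survives within a group is not deterministic):

rdd = df.rdd. \
    groupBy(lambda x: x.field_2). \
    flatMapValues(lambda k: get_n_from_group(k, 1))  # at most 1 record per group

spark.createDataFrame(rdd.values(), schema=df.schema).show()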