One way to do this is to use a udf to create a sampling column: each row gets a uniform random number multiplied by its weight. We then sort by the sampling column in descending order and take the top N rows.
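To see the mechanics outside of Spark first, here is the same idea in plain Python (a sketch with made-up data, not part of the Spark example below):
import random
rows = [('A', 1), ('B', 5), ('C', 10)]  # (label, weight) pairs
N = 2
# scale a uniform random number by each weight, sort descending, keep the top N
sample = sorted(rows, key=lambda r: random.random() * r[1], reverse=True)[:N]
Rows with larger weights get larger sort keys on average, so they are more likely to land in the top N.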
Consider the following illustrative example:
Create Dummy Data
import numpy as np
import string
import pyspark.sql.functions as f
from pyspark.sql.types import FloatType
index = range(100)
weights = [i%26 for i in index]
labels = [string.ascii_uppercase[w] for w in weights]
# sqlCtx is an existing SQLContext; on Spark 2+, spark.createDataFrame works the same way
df = sqlCtx.createDataFrame(
    list(zip(index, labels, weights)),  # list() because zip returns an iterator in Python 3
    ('index', 'label', 'weight')
)
df.show(n=5)
#+-----+-----+------+
#|index|label|weight|
#+-----+-----+------+
#| 0| A| 0|
#| 1| B| 1|
#| 2| C| 2|
#| 3| D| 3|
#| 4| E| 4|
#+-----+-----+------+
#only showing top 5 rows
Add Sampling Column
In this example, we want to sample the DataFrame weighted by the column weight. We define a udf that uses numpy.random.random() to generate a uniform random number and multiplies it by the weight. We then sort() on this column in descending order and use limit() to take the desired number of samples.
N = 10 # the number of samples
def get_sample_value(x):
    # cast to a plain Python float; returning numpy.float64 from a udf
    # declared as FloatType can yield nulls in some Spark versions
    return float(np.random.random() * x)

get_sample_value_udf = f.udf(get_sample_value, FloatType())
df_sample = df.withColumn('sampleVal', get_sample_value_udf(f.col('weight')))\
    .sort('sampleVal', ascending=False)\
    .select('index', 'label', 'weight')\
    .limit(N)
Result
As expected, the DataFrame df_sample has 10 rows, and its contents tend to have letters near the end of the alphabet (higher weights).
df_sample.count()
#10
df_sample.show()
#+-----+-----+------+
#|index|label|weight|
#+-----+-----+------+
#| 23| X| 23|
#| 73| V| 21|
#| 46| U| 20|
#| 25| Z| 25|
#| 19| T| 19|
#| 96| S| 18|
#| 75| X| 23|
#| 48| W| 22|
#| 51| Z| 25|
#| 69| R| 17|
#+-----+-----+------+
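As an aside, the same approach works without a Python udf at all: Spark's built-in rand() function can supply the uniform random numbers natively, which keeps the computation in the JVM and avoids Python serialization overhead. A minimal udf-free sketch of the same pipeline:
df_sample_native = df.withColumn('sampleVal', f.rand() * f.col('weight'))\
    .sort('sampleVal', ascending=False)\
    .select('index', 'label', 'weight')\
    .limit(N)
rand() lives in pyspark.sql.functions, so no extra imports are needed beyond those above.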