I have a `train_test_split` function in Polars that can handle an eager DataFrame. I want to write an equivalent function that takes a LazyFrame as input and returns two LazyFrames without evaluating them.
My function is as follows. It shuffles all rows, then splits them by row index based on the height of the full frame.
import polars as pl

def train_test_split(
    df: pl.DataFrame, train_fraction: float = 0.75
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Split a Polars DataFrame into two sets.

    Args:
        df (pl.DataFrame): DataFrame to split.
        train_fraction (float, optional): Fraction that goes to train. Defaults to 0.75.

    Returns:
        tuple[pl.DataFrame, pl.DataFrame]: Tuple of train and test DataFrames.
    """
    df = df.with_columns(pl.all().shuffle(seed=1))
    split_index = int(train_fraction * df.height)
    df_train = df[:split_index]
    df_test = df[split_index:]
    return df_train, df_test
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [4, 3, 2, 1]})
train, test = train_test_split(df)
# this is what the above looks like:
train = pl.DataFrame({'a': [2, 3, 4], 'b': [3, 2, 1]})
test = pl.DataFrame({'a': [1], 'b': [4]})
LazyFrames, however, have unknown height, so we have to do this another way. I have a few ideas, but run into issues with each:

- Use `df.sample(frac=train_fraction, with_replacement=False, shuffle=False)`. This way I could get the train part, but wouldn't be able to get the test part.
- Add a "random" column, where each row gets assigned a random value between 0 and 1. Then I can filter on values below `train_fraction` and above `train_fraction`, and assign those rows to my train and test sets respectively. But since I don't know the length of the DataFrame beforehand, and (afaik) Polars doesn't have a native way of creating such a column, I would need to `.apply` an equivalent of `np.random.uniform` to each row, which would be very time consuming.
- Add a `.with_row_count()` and filter on rows larger than some fraction of the total, but here I also need the height, and creating the row count might be expensive.
Finally, I might be going about this the wrong way: I could count the total number of rows beforehand, but I don't know how expensive that is.
Here's a big DataFrame to test on (running my function on it eagerly takes ~1 sec):
N = 50_000_000
df_big = pl.DataFrame(
    [
        pl.arange(0, N, eager=True),
        pl.arange(0, N, eager=True),
        pl.arange(0, N, eager=True),
        pl.arange(0, N, eager=True),
        pl.arange(0, N, eager=True),
    ],
    schema=["a", "b", "c", "d", "e"],
)