python-polars

Splitting a lazyframe into two frames by fraction of rows to make a train-test split


I have a train_test_split function in Polars that can handle an eager DataFrame. I wish to write an equivalent function that can take a LazyFrame as input and return two LazyFrames without evaluating them.

My function is as follows. It shuffles all rows, then splits them by row index based on the height of the full frame.

import polars as pl


def train_test_split(
    df: pl.DataFrame, train_fraction: float = 0.75
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Split polars dataframe into two sets.
    Args:
        df (pl.DataFrame): Dataframe to split
        train_fraction (float, optional): Fraction that goes to train. Defaults to 0.75.
    Returns:
        Tuple[pl.DataFrame, pl.DataFrame]: Tuple of train and test dataframes
    """
    df = df.with_columns(pl.all().shuffle(seed=1))
    split_index = int(train_fraction * df.height)
    df_train = df[:split_index]
    df_test = df[split_index:]
    return df_train, df_test


df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [4, 3, 2, 1]})
train, test = train_test_split(df)

# this is what the above looks like:
train = pl.DataFrame({'a': [2, 3, 4], 'b': [3, 2, 1]})
test = pl.DataFrame({'a': [1], 'b': [4]})

LazyFrames, however, have unknown height, so we have to do this another way. I have three ideas, but run into issues with each:

  1. Use df.sample(fraction=train_fraction, with_replacement=False, shuffle=False). This way I could get the train part, but wouldn't be able to get the test part.
  2. Add a "random" column, where each row gets assigned a random value between 0 and 1. Then I can filter on values below and above train_fraction, and assign these rows to my train and test datasets respectively. But since I don't know the length of the dataframe beforehand, and (afaik) Polars doesn't have a native way of creating such a column, I would need to .map_elements an equivalent of np.random.uniform on each row, which would be very time-consuming.
  3. Add a .with_row_index() and filter on rows larger than some fraction of the total, but here I also need the height, and creating the row count might be expensive.

Finally, I might be going about this the wrong way: I could count the total number of rows beforehand, but I don't know how expensive this is considered.

Here's a big dataframe to test on (running my function on it eagerly takes ~1 sec):

N = 50_000_000
df_big = pl.DataFrame(
    [
        pl.int_range(N, eager=True),
        pl.int_range(N, eager=True),
        pl.int_range(N, eager=True),
        pl.int_range(N, eager=True),
        pl.int_range(N, eager=True),
    ],
    schema=["a", "b", "c", "d", "e"],
)


Solution

  • Here is one way to do it in lazy mode, using Polars' with_row_index:

    def train_test_split_lazy(
        df: pl.LazyFrame, train_fraction: float = 0.75
    ) -> tuple[pl.LazyFrame, pl.LazyFrame]:
        """Split polars lazyframe into two sets.
        Args:
            df (pl.LazyFrame): LazyFrame to split
            train_fraction (float, optional): Fraction that goes to train. Defaults to 0.75.
        Returns:
            Tuple[pl.LazyFrame, pl.LazyFrame]: Tuple of train and test lazyframes
        """
        df = df.with_columns(pl.all().shuffle(seed=1)).with_row_index()
        df_train = df.filter(pl.col("index") < pl.col("index").max() * train_fraction)
        df_test = df.filter(pl.col("index") >= pl.col("index").max() * train_fraction)
        return df_train, df_test
    

    Then:

    df_big = pl.DataFrame(
        [
            pl.int_range(N, eager=True),
            pl.int_range(N, eager=True),
            pl.int_range(N, eager=True),
            pl.int_range(N, eager=True),
            pl.int_range(N, eager=True),
        ],
        schema=["a", "b", "c", "d", "e"],
    ).lazy()
    train, test = train_test_split_lazy(df_big)
    
    print(train.collect())
    print(test.collect())
    
    shape: (37_500_000, 6)
    ┌──────────┬──────────┬──────────┬──────────┬──────────┬──────────┐
    │ index    ┆ a        ┆ b        ┆ c        ┆ d        ┆ e        │
    │ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
    │ u32      ┆ i64      ┆ i64      ┆ i64      ┆ i64      ┆ i64      │
    ╞══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡
    │ 0        ┆ 27454110 ┆ 27454110 ┆ 27454110 ┆ 27454110 ┆ 27454110 │
    │ 1        ┆ 2309916  ┆ 2309916  ┆ 2309916  ┆ 2309916  ┆ 2309916  │
    │ 2        ┆ 15065100 ┆ 15065100 ┆ 15065100 ┆ 15065100 ┆ 15065100 │
    │ 3        ┆ 12766444 ┆ 12766444 ┆ 12766444 ┆ 12766444 ┆ 12766444 │
    │ …        ┆ …        ┆ …        ┆ …        ┆ …        ┆ …        │
    │ 37499996 ┆ 40732880 ┆ 40732880 ┆ 40732880 ┆ 40732880 ┆ 40732880 │
    │ 37499997 ┆ 32447037 ┆ 32447037 ┆ 32447037 ┆ 32447037 ┆ 32447037 │
    │ 37499998 ┆ 41754221 ┆ 41754221 ┆ 41754221 ┆ 41754221 ┆ 41754221 │
    │ 37499999 ┆ 7019133  ┆ 7019133  ┆ 7019133  ┆ 7019133  ┆ 7019133  │
    └──────────┴──────────┴──────────┴──────────┴──────────┴──────────┘
    shape: (12_500_000, 6)
    ┌──────────┬──────────┬──────────┬──────────┬──────────┬──────────┐
    │ index    ┆ a        ┆ b        ┆ c        ┆ d        ┆ e        │
    │ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---      ┆ ---      │
    │ u32      ┆ i64      ┆ i64      ┆ i64      ┆ i64      ┆ i64      │
    ╞══════════╪══════════╪══════════╪══════════╪══════════╪══════════╡
    │ 37500000 ┆ 29107559 ┆ 29107559 ┆ 29107559 ┆ 29107559 ┆ 29107559 │
    │ 37500001 ┆ 26750366 ┆ 26750366 ┆ 26750366 ┆ 26750366 ┆ 26750366 │
    │ 37500002 ┆ 17450938 ┆ 17450938 ┆ 17450938 ┆ 17450938 ┆ 17450938 │
    │ 37500003 ┆ 30333846 ┆ 30333846 ┆ 30333846 ┆ 30333846 ┆ 30333846 │
    │ …        ┆ …        ┆ …        ┆ …        ┆ …        ┆ …        │
    │ 49999996 ┆ 17167194 ┆ 17167194 ┆ 17167194 ┆ 17167194 ┆ 17167194 │
    │ 49999997 ┆ 9092583  ┆ 9092583  ┆ 9092583  ┆ 9092583  ┆ 9092583  │
    │ 49999998 ┆ 1929693  ┆ 1929693  ┆ 1929693  ┆ 1929693  ┆ 1929693  │
    │ 49999999 ┆ 35668469 ┆ 35668469 ┆ 35668469 ┆ 35668469 ┆ 35668469 │
    └──────────┴──────────┴──────────┴──────────┴──────────┴──────────┘

    On my machine, I get this output in 0.455 seconds, averaged over 100 runs.

    If I cheat and replace df.height with 50_000_000 in your version of train_test_split, and then run it in lazy mode, I get the same output in 0.446 seconds, averaged over 100 runs, so the two are equivalent in terms of performance.