python dataframe python-polars

Sample from each group in polars dataframe?


I'm looking for a function along the lines of

df.group_by('column').agg(sample(10))

so that I can take ten or so randomly selected rows from each group.

This is specifically so I can read in a LazyFrame and work with a small sample of each group as opposed to the entire dataframe.

Update:

One approximate solution is:

import polars as pl

df = lf.group_by('column').agg(
    pl.all().sample(fraction=0.001)
)
df = df.explode(df.columns[1:])

Update 2:

That approximate solution is effectively the same as sampling the whole dataframe and grouping afterwards, so it does not give a fixed number of rows per group. No good.


Solution

  • Let's start with some dummy data:

    import polars as pl

    n = 100
    seed = 0
    
    df = pl.DataFrame({
        "groups": (pl.int_range(n, eager=True) % 5).shuffle(seed=seed),
        "values": pl.int_range(n, eager=True).shuffle(seed=seed)
    })
    
    shape: (100, 2)
    ┌────────┬────────┐
    │ groups ┆ values │
    │ ---    ┆ ---    │
    │ i64    ┆ i64    │
    ╞════════╪════════╡
    │ 0      ┆ 55     │
    │ 0      ┆ 40     │
    │ 2      ┆ 57     │
    │ 4      ┆ 99     │
    │ 4      ┆ 4      │
    │ …      ┆ …      │
    │ 0      ┆ 90     │
    │ 2      ┆ 87     │
    │ 1      ┆ 96     │
    │ 3      ┆ 43     │
    │ 4      ┆ 44     │
    └────────┴────────┘
    

    This gives us 5 groups of 100 / 5 = 20 elements each. Let's verify that:

    df.group_by("groups").agg(pl.len())
    
    shape: (5, 2)
    ┌────────┬─────┐
    │ groups ┆ len │
    │ ---    ┆ --- │
    │ i64    ┆ u32 │
    ╞════════╪═════╡
    │ 0      ┆ 20  │
    │ 4      ┆ 20  │
    │ 2      ┆ 20  │
    │ 3      ┆ 20  │
    │ 1      ┆ 20  │
    └────────┴─────┘
    

    Sample our data

    Now we use a window function to sample up to 10 rows from each group.

    df.filter(
        pl.int_range(pl.len()).shuffle().over("groups") < 10
    )
    
    shape: (50, 2)
    ┌────────┬────────┐
    │ groups ┆ values │
    │ ---    ┆ ---    │
    │ i64    ┆ i64    │
    ╞════════╪════════╡
    │ 0      ┆ 55     │
    │ 2      ┆ 57     │
    │ 4      ┆ 99     │
    │ 4      ┆ 4      │
    │ 1      ┆ 81     │
    │ …      ┆ …      │
    │ 2      ┆ 22     │
    │ 1      ┆ 76     │
    │ 3      ┆ 98     │
    │ 0      ┆ 90     │
    │ 4      ┆ 44     │
    └────────┴────────┘
    

    For every group in over("groups"), the pl.int_range(pl.len()) expression creates a row index 0, 1, ..., len - 1. We then shuffle that range so that we take a random sample rather than a slice of the first rows. Finally, we keep only the rows whose shuffled index is lower than 10, which produces a boolean mask that we pass to the filter method. A group with 10 or fewer rows keeps all of its rows.