
Grouping by resetting cumulative sum


I have several files in a dataframe, each with its file_size. I want to group the files into groups whose cumulative file_size does not exceed a file_size_threshold (in my example, the file_size_threshold is 3). As an example of the expected behaviour, this is my input dataframe:

import polars as pl

data = {
    "bucket": ["bucket1"] * 8,
    "full_path": ["file1.txt","file2.txt","file3.txt","file4.txt","file5.txt","file6.txt","file7.txt","file8.txt"],
    "ETag": ["11c9d17dc657860c447e655fa79172f","21c9d17dc657860c447e655fa79172g","31c9d17dc657860c447e655fa79172f","41c9d17dc657860c447e655fa79172f","51c9d17dc657860c447e655fa79172f","61c9d17dc657860c447e655fa79172f","71c9d17dc657860c447e655fa79172f","81c9d17dc657860c447e655fa79172f"],
    "file_size": [1, 2, 2, 2, 1, 3, 1, 2]
}
print(pl.DataFrame(data))
#shape: (8, 4)
#┌─────────┬───────────┬─────────────────────────────────┬───────────┐
#│ bucket  ┆ full_path ┆ ETag                            ┆ file_size │
#│ ---     ┆ ---       ┆ ---                             ┆ ---       │
#│ str     ┆ str       ┆ str                             ┆ i64       │
#╞═════════╪═══════════╪═════════════════════════════════╪═══════════╡
#│ bucket1 ┆ file1.txt ┆ 11c9d17dc657860c447e655fa79172f ┆ 1         │
#│ bucket1 ┆ file2.txt ┆ 21c9d17dc657860c447e655fa79172g ┆ 2         │
#│ bucket1 ┆ file3.txt ┆ 31c9d17dc657860c447e655fa79172f ┆ 2         │
#│ bucket1 ┆ file4.txt ┆ 41c9d17dc657860c447e655fa79172f ┆ 2         │
#│ bucket1 ┆ file5.txt ┆ 51c9d17dc657860c447e655fa79172f ┆ 1         │
#│ bucket1 ┆ file6.txt ┆ 61c9d17dc657860c447e655fa79172f ┆ 3         │
#│ bucket1 ┆ file7.txt ┆ 71c9d17dc657860c447e655fa79172f ┆ 1         │
#│ bucket1 ┆ file8.txt ┆ 81c9d17dc657860c447e655fa79172f ┆ 2         │
#└─────────┴───────────┴─────────────────────────────────┴───────────┘

And this is the expected result (for this particular case, where file_size_threshold=3):

shape: (8, 5)
┌─────────┬───────────┬─────────────────────────────────┬───────────┬──────────────┐
│ bucket  ┆ full_path ┆ ETag                            ┆ file_size ┆ group_number │
│ ---     ┆ ---       ┆ ---                             ┆ ---       ┆ ---          │
│ str     ┆ str       ┆ str                             ┆ i64       ┆ i64          │
╞═════════╪═══════════╪═════════════════════════════════╪═══════════╪══════════════╡
│ bucket1 ┆ file1.txt ┆ 11c9d17dc657860c447e655fa79172f ┆ 1         ┆ 1            │
│ bucket1 ┆ file2.txt ┆ 21c9d17dc657860c447e655fa79172g ┆ 2         ┆ 1            │
│ bucket1 ┆ file3.txt ┆ 31c9d17dc657860c447e655fa79172f ┆ 2         ┆ 2            │
│ bucket1 ┆ file4.txt ┆ 41c9d17dc657860c447e655fa79172f ┆ 2         ┆ 3            │
│ bucket1 ┆ file5.txt ┆ 51c9d17dc657860c447e655fa79172f ┆ 1         ┆ 3            │
│ bucket1 ┆ file6.txt ┆ 61c9d17dc657860c447e655fa79172f ┆ 3         ┆ 4            │
│ bucket1 ┆ file7.txt ┆ 71c9d17dc657860c447e655fa79172f ┆ 1         ┆ 5            │
│ bucket1 ┆ file8.txt ┆ 81c9d17dc657860c447e655fa79172f ┆ 2         ┆ 5            │
└─────────┴───────────┴─────────────────────────────────┴───────────┴──────────────┘

So, file1 and file2 are group number 1, file3 is group number 2, file4 and file5 are group number 3, file6 is group number 4, and file7 and file8 are group number 5.

Clarification note: file2 and file3 end up in different groups because file_size_threshold=3. file1 has size 1, so it starts the 1st group. At the second row the cumulative file_size is 1+2=3, which is less than or equal to the threshold, so file2 stays in the same group. Adding file3 would make the cumulative sum 1+2+2=5, which is greater than 3, so file3 must start a new group (and the cumulative sum resets to file3's size).
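The rule can be written out as a plain Python loop (the name assign_groups is just for illustration); this is the behaviour I want to reproduce with Polars:

```python
def assign_groups(file_sizes, threshold):
    """Greedy grouping: keep a running total, and whenever adding the
    next file would push it past the threshold, start a new group and
    reset the running total to that file's size."""
    groups = []
    group_id = 1
    running = 0
    for size in file_sizes:
        running += size
        if running > threshold:
            running = size  # reset the cumulative sum
            group_id += 1
        groups.append(group_id)
    return groups

print(assign_groups([1, 2, 2, 2, 1, 3, 1, 2], 3))
# [1, 1, 2, 3, 3, 4, 5, 5]
```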

Any idea how to solve it using Polars?

I tried the cum_sum() and mod() functions, but I could not find a way to solve the problem.


Solution

  • I don't think this works with cum_sum and mod or floordiv, because there's no obvious way to make the cumulative sum restart when a new group starts.

    One way to overcome this is to make a compiled ufunc with numba's guvectorize decorator.

    import numba as nb
    
    # Generalized ufunc: takes an int64 array and a scalar threshold,
    # and writes the group id for each element into `results`.
    @nb.guvectorize([(nb.int64[:], nb.int64, nb.int64[:])], '(n),()->(n)', nopython=True)
    def make_groups(file_sizes, threshold, results):
        cur_group_id = 1
        cur_group_size = 0
        for i in range(file_sizes.shape[0]):
            cur_group_size += file_sizes[i]
            if cur_group_size > threshold:
                # adding this file would exceed the threshold:
                # start a new group and reset the running total
                cur_group_size = file_sizes[i]
                cur_group_id += 1
            results[i] = cur_group_id
    

    then it's just

    df.with_columns(group=make_groups(pl.col('file_size'), 3))
    shape: (8, 5)
    ┌─────────┬───────────┬─────────────────────────────────┬───────────┬───────┐
    │ bucket  ┆ full_path ┆ ETag                            ┆ file_size ┆ group │
    │ ---     ┆ ---       ┆ ---                             ┆ ---       ┆ ---   │
    │ str     ┆ str       ┆ str                             ┆ i64       ┆ i64   │
    ╞═════════╪═══════════╪═════════════════════════════════╪═══════════╪═══════╡
    │ bucket1 ┆ file1.txt ┆ 11c9d17dc657860c447e655fa79172f ┆ 1         ┆ 1     │
    │ bucket1 ┆ file2.txt ┆ 21c9d17dc657860c447e655fa79172g ┆ 2         ┆ 1     │
    │ bucket1 ┆ file3.txt ┆ 31c9d17dc657860c447e655fa79172f ┆ 2         ┆ 2     │
    │ bucket1 ┆ file4.txt ┆ 41c9d17dc657860c447e655fa79172f ┆ 2         ┆ 3     │
    │ bucket1 ┆ file5.txt ┆ 51c9d17dc657860c447e655fa79172f ┆ 1         ┆ 3     │
    │ bucket1 ┆ file6.txt ┆ 61c9d17dc657860c447e655fa79172f ┆ 3         ┆ 4     │
    │ bucket1 ┆ file7.txt ┆ 71c9d17dc657860c447e655fa79172f ┆ 1         ┆ 5     │
    │ bucket1 ┆ file8.txt ┆ 81c9d17dc657860c447e655fa79172f ┆ 2         ┆ 5     │
    └─────────┴───────────┴─────────────────────────────────┴───────────┴───────┘
    

    You can make an expression wrapper and register a namespace, but I like monkey patching it onto pl.Expr so you can call it from the column:

    pl.Expr.make_groups = lambda self, threshold: (
        make_groups(self, threshold)
    )
    
    # now you can do
    df.with_columns(group=pl.col('file_size').make_groups(3))