
Polars list nested-eval, how can I achieve something like Spark's transform


Let's say I have a column named structures, which is a list of structs each containing fields "a" and "b". I want to avoid explode and the like, as my data structure is quite complex (and quite big); I want to work as much as possible with the lists alone.


    import polars as pl

    # Create a list of lists of structs
    list_of_structs = [
        [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 4}]
    ]

    # Create a DataFrame from the list of lists
    df = pl.DataFrame({"structures": list_of_structs})

My expected result would be a list of lists grouped by "b":

    [[(a1, b2)], [(a3, b4), (a5, b4)]]

I want to perform some kind of "group all elements by b" (and then aggregate using concat_list). In Spark, the code looks like this (note how I can reference both x and y):


    from pyspark.sql import functions as F

    arrays_grouped = F.array_distinct(
        F.transform(
            F.col("structures"),
            lambda x: F.filter(
                F.col("structures"),
                lambda y: x["b"] == y["b"],
            ),
        )
    )

However, in Polars the only comparable tool I can find is the eval operator: https://docs.pola.rs/py-polars/html/reference/expressions/api/polars.Expr.list.eval.html#polars.Expr.list.eval. Inside it I can't reference anything from outside apart from pl.element(), so I'm quite stuck.
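As far as I can tell, this much does work: list.eval maps an expression over each element, with pl.element() as the only handle on the data:

    # Extract the "b" field of every struct in each list
    df.with_columns(
        bs=pl.col("structures").list.eval(pl.element().struct.field("b"))
    )
    # "bs" becomes [[2, 4, 4]], but there is no way to refer back to an
    # outer element the way Spark's nested lambdas can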

Do I need to implement this as a plugin, or is there a way with the provided API? Below is my existing approach, which does not work (probably influenced by my having worked with Spark functions for quite a long time).

    df = df.with_columns(
        # Get all unique "b" values
        pl.col("structures").list.eval(
            pl.element().struct.field("b")
        # Try to filter the structures by each unique "b"
        ).list.unique().list.eval(
            pl.struct(
                base := pl.element(),
                df.get_column("structures").list.eval(
                    # I don't know what this base value is, but it's not filtering;
                    # if I replace it with a hardcoded 4, it does filter the 4s correctly
                    pl.element().filter(pl.element().struct.field("b") == base)
                )
            )
        )
    )

I think something might be doable with gather, but I'm struggling to even start using it, as I can't find the indices to gather (a similar issue to eval).
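As a rough sketch of where I get stuck: gathering hardcoded positions works, but I don't see how to compute the per-"b" index groups to feed it:

    # list.gather works when the positions are already known...
    df.with_columns(
        pairs=pl.col("structures").list.gather([1, 2])
    )
    # ...this picks out the two b=4 structs, but only because I hardcoded
    # the indices; I can't see how to derive them inside the expression API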


Solution

Add a row index so each row's list can be rebuilt, explode and unnest the structs, group by "b" within each row, then collect everything back per row:

    list_of_structs = [
        [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 4}],
        [{"a": 2, "b": 5}, {"a": 3, "b": 5}, {"a": 4, "b": 3}]
    ]

    df = pl.DataFrame({"structures": list_of_structs})
    
    (df.select("structures")
       .with_row_index()                 # remember which row each struct came from
       .explode("structures")            # one struct per row
       .unnest("structures")             # struct fields become columns "a" and "b"
       .group_by("index", "b", maintain_order=True)
       .agg(
          pl.struct("a", "b")            # rebuild the structs, grouped per "b"
       )
       .drop("b")
       .group_by("index", maintain_order=True)
       .all()                            # collect the groups back into one list per row
    )
    
    shape: (2, 2)
    ┌───────┬───────────────────────────┐
    │ index ┆ a                         │
    │ ---   ┆ ---                       │
    │ u32   ┆ list[list[struct[2]]]     │
    ╞═══════╪═══════════════════════════╡
    │ 0     ┆ [[{1,2}], [{3,4}, {5,4}]] │
    │ 1     ┆ [[{2,5}, {3,5}], [{4,3}]] │
    └───────┴───────────────────────────┘
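To see what the pipeline is doing, this is the intermediate state after the explode/unnest steps (shown on the same toy data):

    (df.select("structures")
       .with_row_index()
       .explode("structures")
       .unnest("structures")
    )

    shape: (6, 3)
    ┌───────┬─────┬─────┐
    │ index ┆ a   ┆ b   │
    │ ---   ┆ --- ┆ --- │
    │ u32   ┆ i64 ┆ i64 │
    ╞═══════╪═════╪═════╡
    │ 0     ┆ 1   ┆ 2   │
    │ 0     ┆ 3   ┆ 4   │
    │ 0     ┆ 5   ┆ 4   │
    │ 1     ┆ 2   ┆ 5   │
    │ 1     ┆ 3   ┆ 5   │
    │ 1     ┆ 4   ┆ 3   │
    └───────┴─────┴─────┘

The two group_bys then regroup the structs by "b" within each index, and collect those per-"b" lists back into one list per original row.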
    

Performance

If I increase the number of rows, for example:

    df_big = df.sample(100_000, with_replacement=True)
    

And repeat the above approach:

  • Elapsed time: 0.45117 seconds
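For reproducibility, the timings can be measured with a plain time.perf_counter harness like this (exact numbers will vary by machine):

    import time

    start = time.perf_counter()
    (df_big.select("structures")
       .with_row_index()
       .explode("structures")
       .unnest("structures")
       .group_by("index", "b", maintain_order=True)
       .agg(pl.struct("a", "b"))
       .drop("b")
       .group_by("index", maintain_order=True)
       .all()
    )
    print(f"Elapsed time: {time.perf_counter() - start:.5f} seconds")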

For comparison, if we perform the first list.eval step in your example:

    df_big.with_columns(unique =
       pl.col("structures").list.eval(pl.element().struct["b"])
         .list.unique()
    )
    
  • Elapsed time: 1.99926 seconds

Incidentally, it seems list.unique() is the major slowdown here, which can be avoided:

    df_big.with_columns(unique =
       pl.col("structures").list.eval(pl.element().struct["b"].unique())
    )
    
  • Elapsed time: 0.23198 seconds

That is an improvement over the list.unique() version, but even then it covers just a single piece of the required functionality.

Adding the remaining steps would likely end up much slower than the explode/group_by approach.