python-polars

polars - take a window of N rows surrounding a row fulfilling a condition


Consider the following dataframe:

import polars as pl

df = pl.DataFrame({
    "letters": ["A", "B", "C", "D", "E", "F", "G", "H"],
    "values": ["aa", "bb", "cc", "dd", "ee", "ff", "gg", "hh"]
})

print(df)
shape: (8, 2)
┌─────────┬────────┐
│ letters ┆ values │
│ ---     ┆ ---    │
│ str     ┆ str    │
╞═════════╪════════╡
│ A       ┆ aa     │
│ B       ┆ bb     │
│ C       ┆ cc     │
│ D       ┆ dd     │
│ E       ┆ ee     │
│ F       ┆ ff     │
│ G       ┆ gg     │
│ H       ┆ hh     │
└─────────┴────────┘

How do I take a window of size +/- N around any row that satisfies a given condition? For example, if the condition is pl.col("letters").str.contains("D|F") and N = 2, the output should be:

┌─────────┬────────────────────────────────┐
│ letters ┆ output                         │
│ ---     ┆ ---                            │
│ str     ┆ list[str]                      │
╞═════════╪════════════════════════════════╡
│ D       ┆ ["bb", "cc", "dd", "ee", "ff"] │
│ F       ┆ ["dd", "ee", "ff", "gg", "hh"] │
└─────────┴────────────────────────────────┘

Note that the windows overlap in this case (the F window also contains dd and the D window also contains ff). Also, N = 2 here for the sake of simplicity but, in reality, it'll be larger (~10-20). And the dataset is relatively large, so I'd like to do this as efficiently as possible without exploding memory usage.


EDIT: To make the ask more explicit, here's the query in DuckDB's SQL syntax that gives the right answer (and I'd like to know how to translate it to Polars):

import duckdb

df_table = df.to_arrow()
con = duckdb.connect()
query = """
SELECT
    letters,
    list(values) OVER (
        ROWS BETWEEN 2 PRECEDING
                 AND 2 FOLLOWING
    ) as combined
FROM df_table
QUALIFY letters in ('D', 'F')
"""
print(pl.from_arrow(con.execute(query).arrow()))

shape: (2, 2)
┌─────────┬────────────────────────┐
│ letters ┆ combined               │
│ ---     ┆ ---                    │
│ str     ┆ list[str]              │
╞═════════╪════════════════════════╡
│ D       ┆ ["bb", "cc", ... "ff"] │
│ F       ┆ ["dd", "ee", ... "hh"] │
└─────────┴────────────────────────┘

Benchmarks of suggested solutions

I ran the suggested solutions in a Jupyter notebook on one of Amazon's ml.c5.xlarge machines. While the notebook was running, I also kept htop open in a terminal to observe CPU and memory use. The dataset had 12M+ rows.

I ran both solutions via both the eager and lazy APIs. For good measure, I also tried a simple Python for loop that extracts the slices after identifying the rows of interest, as well as DuckDB.

Summary Table

Polars showed really robust performance and judicious memory use (with @jqurious' method) because of the clever, no-copy implementation of .shift(). Surprisingly, a well-thought-out Python for loop did just as well. DuckDB performed rather poorly in both speed and memory use.

Neither Polars nor DuckDB uses more than one core for the operation. I'm not sure if that's due to a lack of optimization or if this problem is simply not amenable to parallelization. I suppose we're only filtering over one column and then taking slices of that same column, so there isn't much for multiple threads to do.

method              cpu use      memory use        time
ΩΠΟΚΕΚΡΥΜΜΕΝΟΣ      single core  explosion         N/A (kernel crash)
jqurious            single core  2.53G -> 2.53G    4.63 s
(smart) for loop    single core  2.53G -> 2.58G    4.91 s
DuckDB              single core  1.62G -> 6.13G    38.6 s
  • cpu use shows whether multiple cores were taxed during the operation.
  • memory use shows how much memory was in use before the operation and the maximum memory use during the operation.

@ΩΠΟΚΕΚΡΥΜΜΕΝΟΣ's solution:

preceding = 2
following = 2

# one shifted copy of the column for every offset in the window (-N .. +N)
look_around = [pl.col("body").shift(-i)
               for i in range(-preceding, following + 1)]

(
    df
    .with_columns(
        pl.when(pl.col('body').str.contains(regex))
        .then(pl.concat_list(look_around))
        .alias('combined')
    )
    .filter(pl.col('combined').is_not_null())
)

Unfortunately, on my rather large dataset, this solution made memory use explode and crashed the kernel with both the eager and lazy APIs, most likely because the list column built by pl.concat_list is materialized for every row before the filter removes the non-matching ones.
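
For reference, here is roughly what this approach looks like on the small dataframe from the question (N = 2, condition "D|F"); the column names follow the question rather than the benchmark dataset:

import polars as pl

df = pl.DataFrame({
    "letters": ["A", "B", "C", "D", "E", "F", "G", "H"],
    "values": ["aa", "bb", "cc", "dd", "ee", "ff", "gg", "hh"],
})

preceding = 2
following = 2

# one shifted copy of the column per window offset, in ascending row order
look_around = [pl.col("values").shift(-i)
               for i in range(-preceding, following + 1)]

out = (
    df.with_columns(
        pl.when(pl.col("letters").str.contains("D|F"))
        .then(pl.concat_list(look_around))
        .alias("combined")
    )
    .filter(pl.col("combined").is_not_null())
)
print(out)  # rows D and F, with combined = ["bb", ..., "ff"] and ["dd", ..., "hh"]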

@jqurious' solution

preceding = 2
following = 2

look_around = [
    pl.col("body").shift(-i).alias(f"lag_{i}") for i in range(-preceding, following + 1)
]

(
    df
    .with_columns(look_around)
    .filter(pl.col("body").str.contains(regex))
    .select(
        pl.col("body"),
        pl.concat_list(
            [f"lag_{i}" for i in range(-preceding, following + 1)]
        ).alias("output"),
    )
)
  • eager:

    • cpu use: single-core
    • memory use: 2.53G -> 2.53G
    • time: 4.63 s ± 6.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • lazy:

    • cpu use: single-core
    • memory use: 2.53G -> 2.53G
    • time: 4.63 s ± 3.85 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
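
The lazy form of the same query only differs in starting from df.lazy() and ending with .collect(); a minimal sketch, reusing the same look_around, regex, preceding and following as above:

result = (
    df.lazy()
    .with_columns(look_around)
    .filter(pl.col("body").str.contains(regex))
    .select(
        pl.col("body"),
        pl.concat_list(
            [f"lag_{i}" for i in range(-preceding, following + 1)]
        ).alias("output"),
    )
    .collect()
)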

(Smart) Python for loop

preceding = 2
following = 2

output = []

# row indices of the rows that match the condition
indices = df.with_row_index().select(
    pl.col("index").filter(pl.col("body").str.contains(regex))
)["index"]

for x in indices:
    offset = max(0, x - preceding)
    length = preceding + following + 1
    output.append(df["body"].slice(offset, length))
  • cpu use: single-core
  • memory use: 2.53G -> 2.58G
  • time: 4.91 s ± 24.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
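
The loop above only collects the raw window slices. If a frame shaped like the other solutions' output is wanted, the pieces can be stitched together afterwards; a minimal sketch (it re-applies the filter to recover the matching rows, and assumes the slices fit comfortably in memory):

# combine the matched rows and their window slices into a single frame,
# mirroring the output shape of the other solutions
matched = df.filter(pl.col("body").str.contains(regex))["body"]
result = pl.DataFrame({
    "body": matched,
    "output": [s.to_list() for s in output],
})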

DuckDB

Note that I first converted the df to an Arrow Table before running the query so DuckDB could act on it directly. Also, I'm not sure whether converting the result back to Arrow adds a large amount of computation and makes the comparison unfair to DuckDB.

preceding = 2
following = 2

query = f"""
SELECT
    body,
    list(body) OVER (
        ROWS BETWEEN {preceding} PRECEDING
                 AND {following} FOLLOWING
    ) as combined
FROM df_table
QUALIFY regexp_matches(body, '{regex}')
"""

result = con.execute(query).arrow()

With DuckDB, my first attempt to run the computation crashed. For the retry, I read the data into an Arrow Table directly, without going through Polars (this saved about 1GB of memory), to give DuckDB more memory to work with.

  • first try:

    • cpu: single-core
    • memory: 2.53G -> 6.93G -> crash!
    • time: NA
  • second try:

    • cpu: single-core
    • memory: 1.62G -> 6.13G
    • time: 38.6 s ± 311 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
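
For completeness, the "read into an Arrow Table directly" step on the second try could look roughly like this; the file name and Parquet format are assumptions for illustration, not something stated above:

import duckdb
import pyarrow.parquet as pq

# hypothetical: load the dataset straight into a pyarrow Table, skipping Polars;
# adjust the reader to whatever format the data is actually stored in
df_table = pq.read_table("data.parquet")

con = duckdb.connect()
# DuckDB's replacement scan picks up the local `df_table` referenced in the query above
result = con.execute(query).arrow()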

Solution

  • A modification of the approach from "Use the rolling function of polars to get a list of all values in the rolling windows":

    (
        df.with_columns(
            pl.col("values").shift(i).alias(f"lag_{i}") for i in range(-2, 3)
        )
        .filter(pl.col("letters").str.contains("D|F"))
        .select(
            pl.col("letters"),
            pl.concat_list(reversed([f"lag_{i}" for i in range(-2, 3)]))
              .alias("output")
        )
    )
    
    shape: (2, 2)
    ┌─────────┬────────────────────────────────┐
    │ letters ┆ output                         │
    │ ---     ┆ ---                            │
    │ str     ┆ list[str]                      │
    ╞═════════╪════════════════════════════════╡
    │ D       ┆ ["bb", "cc", "dd", "ee", "ff"] │
    │ F       ┆ ["dd", "ee", "ff", "gg", "hh"] │
    └─────────┴────────────────────────────────┘
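
Since N will actually be around 10-20 rather than 2, the hard-coded range(-2, 3) can be replaced with parameters; a sketch of the same solution with preceding/following variables:

preceding = 2   # in practice more like 10-20
following = 2
offsets = range(-preceding, following + 1)

(
    df.with_columns(
        pl.col("values").shift(i).alias(f"lag_{i}") for i in offsets
    )
    .filter(pl.col("letters").str.contains("D|F"))
    .select(
        pl.col("letters"),
        pl.concat_list(list(reversed([f"lag_{i}" for i in offsets]))).alias("output"),
    )
)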