Consider the following dataframe:
import polars as pl

df = pl.DataFrame({
    "letters": ["A", "B", "C", "D", "E", "F", "G", "H"],
    "values": ["aa", "bb", "cc", "dd", "ee", "ff", "gg", "hh"],
})
print(df)
shape: (8, 2)
┌─────────┬────────┐
│ letters ┆ values │
│ ---     ┆ ---    │
│ str     ┆ str    │
╞═════════╪════════╡
│ A       ┆ aa     │
│ B       ┆ bb     │
│ C       ┆ cc     │
│ D       ┆ dd     │
│ E       ┆ ee     │
│ F       ┆ ff     │
│ G       ┆ gg     │
│ H       ┆ hh     │
└─────────┴────────┘
How do I take a window of size +/- N around any row that satisfies a given condition? For example, if the condition is pl.col("letters").str.contains("D|F") and N = 2, then the output should be:
┌─────────┬────────────────────────────────┐
│ letters ┆ output                         │
│ ---     ┆ ---                            │
│ str     ┆ list[str]                      │
╞═════════╪════════════════════════════════╡
│ D       ┆ ["bb", "cc", "dd", "ee", "ff"] │
│ F       ┆ ["dd", "ee", "ff", "gg", "hh"] │
└─────────┴────────────────────────────────┘
Note that the windows overlap in this case (the F window also contains dd, and the D window also contains ff). Also note that N = 2 here for the sake of simplicity; in reality it will be larger (~10-20). And the dataset is relatively large, so I'd like to do this as efficiently as possible without exploding memory usage.
EDIT: To make the ask more explicit, here's the query in DuckDB's SQL syntax that gives the right answer (and I'd like to know how to translate it to Polars):
import duckdb

df_table = df.to_arrow()
con = duckdb.connect()

query = """
SELECT
    letters,
    list("values") OVER (
        ROWS BETWEEN 2 PRECEDING
        AND 2 FOLLOWING
    ) AS combined
FROM df_table
QUALIFY letters IN ('D', 'F')
"""
print(pl.from_arrow(con.execute(query).arrow()))
shape: (2, 2)
┌─────────┬────────────────────────┐
│ letters ┆ combined               │
│ ---     ┆ ---                    │
│ str     ┆ list[str]              │
╞═════════╪════════════════════════╡
│ D       ┆ ["bb", "cc", ... "ff"] │
│ F       ┆ ["dd", "ee", ... "hh"] │
└─────────┴────────────────────────┘
I ran the suggested solutions in a Jupyter notebook on one of Amazon's ml.c5.xlarge machines. While the notebook was running, I kept htop open in a terminal to observe CPU and memory use. The dataset had 12M+ rows.

I ran both solutions via both the eager and lazy APIs. For good measure, I also tried a simple Python for loop that extracts the slices after identifying the rows of interest, as well as DuckDB.
Polars showed really robust performance and judicious memory use (with @jqurious' method), thanks to the clever no-copy implementation of .shift(). Surprisingly, a well-thought-out Python for loop did just as well. DuckDB performed rather poorly in both speed and memory use.

Neither Polars nor DuckDB used more than one core for the operation. I'm not sure if that's due to a lack of optimization or if this problem is just not amenable to parallelization. I suppose we're only filtering over one column and then taking slices of that same column, so there's not much for multiple threads to do.
| method           | CPU use     | memory use     | time   |
|------------------|-------------|----------------|--------|
| ΩΠΟΚΕΚΡΥΜΜΕΝΟΣ   | single core | explosion      |        |
| jqurious         | single core | 2.53G to 2.53G | 4.63 s |
| (smart) for loop | single core | 2.53G to 2.58G | 4.91 s |
| DuckDB           | single core | 1.62G to 6.13G | 38.6 s |
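For reference, the method snippets below operate on a frame with a single body string column plus a regex pattern, matching my real dataset's schema. Here's a minimal stand-in (the values are made up), along with a quick look at the .shift() mechanics the fastest methods rely on:

import polars as pl

# stand-in for the real 12M-row dataset: one string column named "body"
# and the regex used to mark the rows of interest
df = pl.DataFrame({"body": ["aa", "bb", "cc", "dd", "ee", "ff", "gg", "hh"]})
regex = "dd|ff"

# shift() moves a column down (positive) or up (negative), filling with nulls
print(df["body"].shift(1).to_list())
# [None, 'aa', 'bb', 'cc', 'dd', 'ee', 'ff', 'gg']
print(df["body"].shift(-1).to_list())
# ['bb', 'cc', 'dd', 'ee', 'ff', 'gg', 'hh', None]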
@ΩΠΟΚΕΚΡΥΜΜΕΝΟΣ's method:

preceding = 2
following = 2

# one shifted copy of the column per offset in the window
look_around = [pl.col("body").shift(-i) for i in range(-preceding, following + 1)]

(
    df
    .with_columns(
        pl.when(pl.col("body").str.contains(regex))
        .then(pl.concat_list(look_around))
        .alias("combined")
    )
    .filter(pl.col("combined").is_not_null())
)
Unfortunately, on my rather large dataset, this solution caused memory use to explode and the kernel to crash with both the eager and lazy APIs, presumably because the full list column is materialized for every row before the null filter can prune it.
@jqurious' method:

preceding = 2
following = 2

look_around = [
    pl.col("body").shift(-i).alias(f"lag_{i}")
    for i in range(-preceding, following + 1)
]

(
    df
    .with_columns(look_around)  # plain shifted columns; no lists built yet
    .filter(pl.col("body").str.contains(regex))  # prune rows first
    .select(
        pl.col("body"),
        pl.concat_list([f"lag_{i}" for i in range(-preceding, following + 1)])
        .alias("output")
    )
)
(htop screenshots of the eager and lazy runs omitted.)
(smart) for loop:

preceding = 2
following = 2

output = []

# row indices of all rows that satisfy the condition
indices = df.with_row_index().select(
    pl.col("index").filter(pl.col("body").str.contains(regex))
)["index"]

# slice each window out of the column directly; .slice() is cheap in Polars
for x in indices:
    offset = max(0, x - preceding)
    length = preceding + following + 1
    output.append(df["body"].slice(offset, length))
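For completeness, the collected slices can then be assembled into a frame of the same shape as the Polars output (a sketch, assuming the stand-in df and regex above; Series.gather takes rows by index):

# pair each matching row with its window, mirroring the Polars output
result = pl.DataFrame({
    "body": df["body"].gather(indices),
    "combined": [s.to_list() for s in output],
})
print(result)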
DuckDB:

Note that I first converted the df to an Arrow Table before running the query so DuckDB could act on it directly. Also, I'm not sure whether converting the result back to Arrow adds significant overhead, which would make the comparison unfair to DuckDB.
preceding = 2
following = 2

query = f"""
SELECT
    body,
    list(body) OVER (
        ROWS BETWEEN {preceding} PRECEDING
        AND {following} FOLLOWING
    ) AS combined
FROM df_table
QUALIFY regexp_matches(body, '{regex}')
"""

result = con.execute(query).arrow()
With DuckDB, my first attempt to run the computation crashed. I had to retry by reading the data into an Arrow Table directly, without going through Polars (this saved about 1GB of memory), to give DuckDB more memory to work with.
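The second attempt looked something like this (a sketch; the Parquet file path is a stand-in for my actual data source):

import duckdb
import pyarrow.parquet as pq

# read straight into an Arrow table, skipping the Polars copy,
# so DuckDB has more memory headroom
df_table = pq.read_table("data.parquet")  # hypothetical path
con = duckdb.connect()
result = con.execute(query).arrow()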
(htop screenshots of the first and second tries omitted.)
A modification of "Use the rolling function of polars to get a list of all values in the rolling windows":
(
    df.with_columns(
        # shift(i) with positive i looks backward, negative i looks forward
        pl.col("values").shift(i).alias(f"lag_{i}") for i in range(-2, 3)
    )
    .filter(pl.col("letters").str.contains("D|F"))
    .select(
        pl.col("letters"),
        # lag_2 holds the value from 2 rows back and lag_-2 the value from
        # 2 rows ahead, so reverse the list to get chronological order
        pl.concat_list(reversed([f"lag_{i}" for i in range(-2, 3)]))
        .alias("output")
    )
)
shape: (2, 2)
┌─────────┬────────────────────────────────┐
│ letters ┆ output                         │
│ ---     ┆ ---                            │
│ str     ┆ list[str]                      │
╞═════════╪════════════════════════════════╡
│ D       ┆ ["bb", "cc", "dd", "ee", "ff"] │
│ F       ┆ ["dd", "ee", "ff", "gg", "hh"] │
└─────────┴────────────────────────────────┘
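One caveat: if a matching row sits within N rows of either edge of the frame, its window will contain nulls from the shifts. If that's unwanted, the nulls can be trimmed per row; a sketch, assuming a Polars version that has list.drop_nulls:

(
    df.with_columns(
        pl.col("values").shift(i).alias(f"lag_{i}") for i in range(-2, 3)
    )
    .filter(pl.col("letters").str.contains("D|F"))
    .select(
        pl.col("letters"),
        pl.concat_list(reversed([f"lag_{i}" for i in range(-2, 3)]))
        .list.drop_nulls()  # drop the edge nulls introduced by shift()
        .alias("output")
    )
)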