
Polars: pandas equivalent of selecting column names from a list


I have two DataFrames in polars: one that describes the metadata and one that holds the actual data (LazyFrames are used because the actual data is larger):

import polars as pl
df = pl.LazyFrame(
    {
        "ID": ["CX1", "CX2", "CX3"],
        "Sample1": [1, 1, 1],
        "Sample2": [2, 2, 2],
        "Sample3": [4, 4, 4],
    }
)

df_meta = pl.LazyFrame(
    {
        "sample": ["Sample1", "Sample2", "Sample3", "Sample4"],
        "qc": ["pass", "pass", "fail", "pass"]
    }
)

I need to select the columns in df for samples that have passing qc using the information in df_meta. As you can see, df_meta has an additional sample, which of course we are not interested in as it's not part of our data.

In pandas, I'd do (not very elegant but does the job):

df.loc[:, df.columns.isin(df_meta.query("qc == 'pass'")["sample"])]

However, I'm not sure how to do this in polars. Reading through SO and the docs didn't give me a definite answer.

I've tried:

df.with_context(
   df_meta.filter(pl.col("qc") == "pass").select(pl.col("sample").alias("meta_ids"))
).with_columns(
    pl.all().is_in("meta_ids")
).collect()

Which however raises an exception:

InvalidOperationError: `is_in` cannot check for String values in Int64 data

I assume it's checking the content of the columns, but I'm interested in the column names.

I've also tried:

# df_meta is a LazyFrame, so it has to be collected before get_column
meta_ids = df_meta.filter(pl.col("qc") == "pass").collect().get_column("sample")
df.select(pl.col(meta_ids)).collect()

But as expected, an exception is raised because one of the passing samples is not present in the first DataFrame:

ColumnNotFoundError: Sample4

What would be the correct way to do this?


Solution

  • Just to build upon https://stackoverflow.com/a/78676922/ - I find require_all=False rather cryptic.

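    It is also possible to take the set intersection with the cs.all() selector:

    >>> import polars.selectors as cs
    >>> # df_meta is a LazyFrame, so collect it before extracting the column
    >>> meta_ids = df_meta.filter(pl.col("qc") == "pass").collect()["sample"]
    >>> cs.all() & cs.by_name(meta_ids)
    (cs.all() & cs.by_name('Sample1', 'Sample2', 'Sample4'))
    
    >>> df.select(cs.all() & cs.by_name(meta_ids)).collect()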
    
    shape: (3, 2)
    ┌─────────┬─────────┐
    │ Sample1 ┆ Sample2 │
    │ ---     ┆ ---     │
    │ i64     ┆ i64     │
    ╞═════════╪═════════╡
    │ 1       ┆ 2       │
    │ 1       ┆ 2       │
    │ 1       ┆ 2       │
    └─────────┴─────────┘