Tags: pandas, dataframe, python-polars, rust-polars

How to create a new column based on the common start word between two series in a Polars DataFrame?


I have a Polars DataFrame with two columns, 'foo' and 'bar', each holding lists of integers. I want to create a new column that is 1 if the start word (first element) of 'foo' equals the start word of 'bar', and 0 otherwise. This seems a bit more involved in Polars than in pandas.

Here is the example DataFrame I'm working with:

import polars as pl

df = pl.DataFrame({
    "foo": [[1, 3, 3, 3], [3, 5, 3, 4], [4, 7, 5, 3]],
    "bar": [[3, 345, 3, 4], [3, 4, 334, 2], [4, 52, 4, 2]],
})

Printing it shows:

shape: (3, 2)
┌─────────────┬───────────────┐
│ foo         ┆ bar           │
│ ---         ┆ ---           │
│ list[i64]   ┆ list[i64]     │
╞═════════════╪═══════════════╡
│ [1, 3, … 3] ┆ [3, 345, … 4] │
│ [3, 5, … 4] ┆ [3, 4, … 2]   │
│ [4, 7, … 3] ┆ [4, 52, … 2]  │
└─────────────┴───────────────┘

I would like to create a new column named 'common_start' that indicates whether the start word of 'foo' matches the start word of 'bar'. For this example, the expected values of 'common_start' are [0, 1, 1]:

shape: (3, 3)
┌─────────────┬───────────────┬──────────────┐
│ foo         ┆ bar           ┆ common_start │
│ ---         ┆ ---           ┆ ---          │
│ list[i64]   ┆ list[i64]     ┆ i64          │
╞═════════════╪═══════════════╪══════════════╡
│ [1, 3, … 3] ┆ [3, 345, … 4] ┆ 0            │
│ [3, 5, … 4] ┆ [3, 4, … 2]   ┆ 1            │
│ [4, 7, … 3] ┆ [4, 52, … 2]  ┆ 1            │
└─────────────┴───────────────┴──────────────┘

How can I achieve this? Any guidance or examples using Polars would be greatly appreciated.


Solution

  • The .list namespace contains the methods for working with list columns in Polars.

    You can use .list.get(0) or .list.first() to access the first item of each list.

    df.with_columns(common_start =
       pl.col('foo').list.first() == pl.col('bar').list.first()
    )
    
    shape: (3, 3)
    ┌─────────────┬───────────────┬──────────────┐
    │ foo         ┆ bar           ┆ common_start │
    │ ---         ┆ ---           ┆ ---          │
    │ list[i64]   ┆ list[i64]     ┆ bool         │
    ╞═════════════╪═══════════════╪══════════════╡
    │ [1, 3, … 3] ┆ [3, 345, … 4] ┆ false        │
    │ [3, 5, … 4] ┆ [3, 4, … 2]   ┆ true         │
    │ [4, 7, … 3] ┆ [4, 52, … 2]  ┆ true         │
    └─────────────┴───────────────┴──────────────┘
    

    A common way to generate 0/1 is to cast a bool to an int:

    df.with_columns(common_start =
       (pl.col('foo').list.first() == pl.col('bar').list.first()).cast(int)
    )
    
    shape: (3, 3)
    ┌─────────────┬───────────────┬──────────────┐
    │ foo         ┆ bar           ┆ common_start │
    │ ---         ┆ ---           ┆ ---          │
    │ list[i64]   ┆ list[i64]     ┆ i64          │
    ╞═════════════╪═══════════════╪══════════════╡
    │ [1, 3, … 3] ┆ [3, 345, … 4] ┆ 0            │
    │ [3, 5, … 4] ┆ [3, 4, … 2]   ┆ 1            │
    │ [4, 7, … 3] ┆ [4, 52, … 2]  ┆ 1            │
    └─────────────┴───────────────┴──────────────┘
    

    You can also use when/then/otherwise to choose the output values explicitly:

    df.with_columns(common_start = 
       pl.when(pl.col('foo').list.first() == pl.col('bar').list.first())
         .then(1)
         .otherwise(0)
    )