python-3.x, slice, python-polars

Why doesn't the slice expression return the expected indexes in a Polars DataFrame?


I have a Polars DataFrame that looks like this:

shape: (2_655_541, 4)
┌────────────┬────────────┬─────────────────┬─────────────────────┐
│ streamflow ┆ sm_surface ┆ basin_id        ┆ time                │
│ ---        ┆ ---        ┆ ---             ┆ ---                 │
│ f32        ┆ f32        ┆ str             ┆ datetime[μs]        │
╞════════════╪════════════╪═════════════════╪═════════════════════╡
│ null       ┆ null       ┆ camels_01022500 ┆ 2015-01-01 03:00:00 │
│ null       ┆ null       ┆ camels_01022500 ┆ 2015-01-01 06:00:00 │
│ null       ┆ null       ┆ camels_01022500 ┆ 2015-01-01 09:00:00 │
│ null       ┆ null       ┆ camels_01022500 ┆ 2015-01-01 12:00:00 │
│ null       ┆ null       ┆ camels_01022500 ┆ 2015-01-01 15:00:00 │
│ …          ┆ …          ┆ …               ┆ …                   │
│ 0.718293   ┆ 0.40595    ┆ HML_LOBO3       ┆ 2016-12-30 18:00:00 │
│ null       ┆ 0.40601    ┆ HML_LOBO3       ┆ 2016-12-30 21:00:00 │
│ null       ┆ 0.406075   ┆ HML_LOBO3       ┆ 2016-12-31 00:00:00 │
│ null       ┆ 0.406177   ┆ HML_LOBO3       ┆ 2016-12-31 03:00:00 │
│ null       ┆ 0.406333   ┆ HML_LOBO3       ┆ 2016-12-31 06:00:00 │
└────────────┴────────────┴─────────────────┴─────────────────────┘

Now I want to slice the data for every basin, so I run the code below:

    df1 = (
        valid_ds.y_origin
        .group_by('basin_id', maintain_order=True)
        .agg(pl.all().slice(0, 2865))
        .explode(pl.exclude('basin_id'))
    )

The result is:

shape: (2_604_285, 4)
┌─────────────────┬────────────┬────────────┬─────────────────────┐
│ basin_id        ┆ streamflow ┆ sm_surface ┆ time                │
│ ---             ┆ ---        ┆ ---        ┆ ---                 │
│ str             ┆ f32        ┆ f32        ┆ datetime[μs]        │
╞═════════════════╪════════════╪════════════╪═════════════════════╡
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 03:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 06:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 09:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 12:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 15:00:00 │
│ …               ┆ …          ┆ …          ┆ …                   │
│ HML_LOBO3       ┆ 0.898755   ┆ 0.424079   ┆ 2016-12-23 15:00:00 │
│ HML_LOBO3       ┆ 0.88542    ┆ 0.419914   ┆ 2016-12-23 18:00:00 │
│ HML_LOBO3       ┆ 0.868826   ┆ 0.417434   ┆ 2016-12-23 21:00:00 │
│ HML_LOBO3       ┆ 0.855195   ┆ 0.416104   ┆ 2016-12-24 00:00:00 │
│ HML_LOBO3       ┆ 0.848972   ┆ 0.415531   ┆ 2016-12-24 03:00:00 │
└─────────────────┴────────────┴────────────┴─────────────────────┘

However, when I change the slice from (0, 2865) to (1, 2865), the output becomes:

shape: (2_604_285, 4)
┌─────────────────┬────────────┬────────────┬─────────────────────┐
│ basin_id        ┆ streamflow ┆ sm_surface ┆ time                │
│ ---             ┆ ---        ┆ ---        ┆ ---                 │
│ str             ┆ f32        ┆ f32        ┆ datetime[μs]        │
╞═════════════════╪════════════╪════════════╪═════════════════════╡
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 06:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 09:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 12:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 15:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-01 18:00:00 │
│ …               ┆ …          ┆ …          ┆ …                   │
│ HML_LOBO3       ┆ 0.88542    ┆ 0.419914   ┆ 2016-12-23 18:00:00 │
│ HML_LOBO3       ┆ 0.868826   ┆ 0.417434   ┆ 2016-12-23 21:00:00 │
│ HML_LOBO3       ┆ 0.855195   ┆ 0.416104   ┆ 2016-12-24 00:00:00 │
│ HML_LOBO3       ┆ 0.848972   ┆ 0.415531   ┆ 2016-12-24 03:00:00 │
│ HML_LOBO3       ┆ 0.838897   ┆ 0.41831    ┆ 2016-12-24 06:00:00 │
└─────────────────┴────────────┴────────────┴─────────────────────┘

You can see that the first timestamp has changed, but the total length of the DataFrame has not.

And when I use slice(100, 2865), the length of the result becomes:

shape: (2_564_641, 4)
┌─────────────────┬────────────┬────────────┬─────────────────────┐
│ basin_id        ┆ streamflow ┆ sm_surface ┆ time                │
│ ---             ┆ ---        ┆ ---        ┆ ---                 │
│ str             ┆ f32        ┆ f32        ┆ datetime[μs]        │
╞═════════════════╪════════════╪════════════╪═════════════════════╡
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-13 15:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-13 18:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-13 21:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-14 00:00:00 │
│ camels_01022500 ┆ null       ┆ null       ┆ 2015-01-14 03:00:00 │
│ …               ┆ …          ┆ …          ┆ …                   │
│ HML_LOBO3       ┆ 0.718293   ┆ 0.40595    ┆ 2016-12-30 18:00:00 │
│ HML_LOBO3       ┆ null       ┆ 0.40601    ┆ 2016-12-30 21:00:00 │
│ HML_LOBO3       ┆ null       ┆ 0.406075   ┆ 2016-12-31 00:00:00 │
│ HML_LOBO3       ┆ null       ┆ 0.406177   ┆ 2016-12-31 03:00:00 │
│ HML_LOBO3       ┆ null       ┆ 0.406333   ┆ 2016-12-31 06:00:00 │
└─────────────────┴────────────┴────────────┴─────────────────────┘
There are 909 basins (len(df3['basin_id'].unique()) == 909), and 2564641 // 909 = 2821.

But 2865 - 2821 is not 100, so the result is not simply 100 rows shorter per basin.
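To see why the totals behave this way, here is a small pure-Python model of what slice(offset, length) keeps from each group. It is an illustrative sketch, not Polars' implementation: it assumes non-negative offsets, and the per-basin row counts in the hypothetical group_lens list are made up because the question does not show them. Note that 909 * 2865 = 2,604,285 exactly, so with offsets 0 and 1 every basin still contributes a full 2865 rows.

```python
# Illustrative model (an assumption, not Polars internals) of slice(offset, length)
# applied to one group: it keeps rows offset .. offset + length - 1 and
# silently stops at the end of the group.


def sliced_len(group_len: int, offset: int, length: int) -> int:
    """Number of rows slice(offset, length) keeps from a group of group_len rows."""
    start = min(offset, group_len)
    end = min(offset + length, group_len)
    return end - start


# Hypothetical per-basin row counts; the real ones are not given in the question.
group_lens = [2866, 2900, 2965, 3000]

for offset in (0, 1, 100):
    kept = [sliced_len(n, offset, 2865) for n in group_lens]
    # offset 0 and 1: every group still yields 2865 rows;
    # offset 100: groups shorter than 100 + 2865 rows are truncated.
    print(offset, kept, sum(kept))
```

With offset 100, only basins that have at least 2965 rows yield all 2865 rows; shorter basins lose fewer than 100 rows each, which matches the uneven drop seen in the question.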

So what is happening with the slice() expression, and how can I fix it?


Solution

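  • The confusion comes from the fact that the second argument to slice() is the length, not the end index. slice(100, 2865) takes up to 2865 rows starting at row 100 of each group (rows 100 to 2964), and it stops early at the end of any basin that has fewer than 2965 rows, which is why the result is not simply 100 rows shorter per basin. To take rows 100 to 2864, pass the length explicitly, e.g. slice(100, 2865 - 100). To get the slice from row 100 to the end of each group, you can omit the length, or use something like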

    pl.all().slice(100, pl.len() - 100)