Tags: python, python-3.x, numpy, python-polars

How to convert a Polars DataFrame to a NumPy array with a specific shape?


I have a Polars DataFrame covering 300 basins. Each basin has 100,000 time records and each record has 40 variables, so the frame has 30 million rows and 40 variable columns. How can I reshape it into a NumPy array of shape (300, 100000, 40) while keeping every row aligned with its basin and time step?

Example:

shape: (10, 7)
┌──────────────┬─────────────┬─────────────┬─────────────┬─────────────┬─────────────┬─────────────┐
│ HQprecipitat ┆ IRprecipita ┆ precipitati ┆ precipitati ┆ randomError ┆ basin_id    ┆ time        │
│ ion          ┆ tion        ┆ onCal       ┆ onUncal     ┆ ---         ┆ ---         ┆ ---         │
│ ---          ┆ ---         ┆ ---         ┆ ---         ┆ f32         ┆ str         ┆ datetime[μs │
│ f32          ┆ f32         ┆ f32         ┆ f32         ┆             ┆             ┆ ]           │
╞══════════════╪═════════════╪═════════════╪═════════════╪═════════════╪═════════════╪═════════════╡
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-01  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 09:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-01  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 12:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-01  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 15:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-01  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 18:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-01  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 21:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-02  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 00:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-02  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 03:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-02  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 06:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-02  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 09:00:00    │
│ null         ┆ null        ┆ null        ┆ null        ┆ null        ┆ anhui_62909 ┆ 1980-01-02  │
│              ┆             ┆             ┆             ┆             ┆ 400         ┆ 12:00:00    │
└──────────────┴─────────────┴─────────────┴─────────────┴─────────────┴─────────────┴─────────────┘
# This should be reshaped into a NumPy array of shape (1, 10, 7):
# 1 is the number of basins, 10 the number of time steps, and 7 the width (number of variables).

Solution

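  • For now I do this with group_by and slice, truncating every basin to the same number of rows:

    station_len = self.x['basin_id'].n_unique()  # number of distinct basins (pl is polars)
    # Keep the first len(self.x) // station_len rows per basin, then flatten the lists back into rows.
    x_truncated = (self.x.group_by('basin_id', maintain_order=True)
                   .agg(pl.all().slice(0, len(self.x) // station_len))
                   .explode(pl.exclude('basin_id')))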