
Rolling mean continuing into next year (xarray)


I have an xarray DataArray with the dimensions (year: 5, lat: 90, lon: 180, month: 12). I can calculate the rolling mean over 3 months using my_xarray = my_xarray.rolling(month=3).mean()

The problem is that the rolling window does not continue from December of one year into January of the next (i.e. the plots for Jan and Feb of every year are blank, because the rolling window starts anew for each year).
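A minimal sketch of the behaviour described above, using a hypothetical 2-year toy array rather than the question's data. By default `rolling(...).mean()` requires a full window (`min_periods` equals the window size), so January and February come out as NaN in every year. Passing `min_periods=1` fills those months, but only from months of the same year, so on its own it does not make the window cross into the previous December.

```python
import numpy as np
import xarray as xr

# Hypothetical toy data: 2 years x 12 months holding 1..24 in calendar
# order, so every window mean is easy to check by hand.
da = xr.DataArray(
    np.arange(1, 25, dtype=float).reshape(2, 12),
    dims=("year", "month"),
    coords={"year": [2020, 2021], "month": np.arange(1, 13)},
)

# Default: min_periods equals the window size (3), so January and
# February of *every* year have an incomplete window and become NaN.
rolled = da.rolling(month=3).mean()
print(rolled.sel(month=[1, 2, 3]).values)  # NaN, NaN, then 2.0 (2020) / 14.0 (2021)

# min_periods=1 fills Jan/Feb, but only with months from the same year:
# Jan 2021 is just 13.0, not the mean of Nov 2020, Dec 2020 and Jan 2021 (12.0).
partial = da.rolling(month=3, min_periods=1).mean()
print(partial.sel(year=2021, month=1).item())
```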

Can I somehow specify that the window should continue into the next year (January, February, ...) when it reaches the end of the month dimension?

I hope it is clear what I want to achieve. Thanks for any help!

Edit: In case it helps, here are the results of the following:

  1. print(my_xarray.dims) gives ('year', 'lat', 'lon', 'month')

  2. print(my_xarray) before taking the rolling mean:

<xarray.DataArray (year: 5, lat: 90, lon: 180, month: 12)>
          -9.87300873e-02, -2.58998200e-03, -1.67404532e-01],
         [ 5.95971942e-04, -2.02189982e-01, -3.97106633e-03, ...,
          -9.64657962e-02, -3.48943099e-03, -1.64729238e-01],
         [ 3.09602171e-03, -2.09298491e-01, -1.11376867e-02, ...,
          -9.64361429e-02, -3.36983800e-03, -1.62733972e-01],
         ...,
         [-6.85611367e-03, -1.94556922e-01,  4.57027294e-02, ...,
          -8.56379271e-02, -4.38956916e-03, -1.74577653e-01],
         [-4.64860350e-03, -2.00546771e-01,  3.28682028e-02, ...,
          -8.63482431e-02, -5.57301566e-03, -1.73252046e-01],
         [-4.17149812e-03, -2.02498823e-01,  2.37097144e-02, ...,
          -8.98122042e-02, -4.10436466e-03, -1.72041461e-01]],

        [[-6.76314309e-02, -5.28460778e-02,  1.12987854e-01, ...,
          -1.75108999e-01,  1.14214182e-01, -9.38383192e-02],
         [-3.71367447e-02, -1.19695403e-02,  6.92197084e-02, ...,
          -1.66514024e-01,  1.31363243e-01, -1.02556169e-01],
         [-5.75000793e-03, -1.72003862e-02,  5.47835231e-02, ...,
          -1.55288070e-01,  1.24138020e-01, -1.03031531e-01],
...
          -2.58931130e-01,  8.03834945e-02, -1.80395544e-01],
         [ 3.55556488e-01, -7.68683434e-01,  3.21449339e-03, ...,
          -2.84671545e-01,  5.23177236e-02, -1.65052935e-01],
         [ 3.99193943e-01, -7.59860992e-01,  5.04764691e-02, ...,
          -2.98249483e-01,  3.26042697e-02, -1.58649802e-01]],

        [[ 3.25531572e-01, -4.28714514e-01, -1.47960767e-01, ...,
          -1.24289311e-01, -3.02775592e-01, -3.59893829e-01],
         [ 3.32164109e-01, -4.26804453e-01, -1.53042451e-01, ...,
          -1.20779485e-01, -3.07494372e-01, -3.57666224e-01],
         [ 3.45293462e-01, -4.26565051e-01, -1.55301645e-01, ...,
          -1.20180212e-01, -3.11209410e-01, -3.45913649e-01],
         ...,
         [ 2.99354017e-01, -4.30373788e-01, -1.71406969e-01, ...,
          -1.09746858e-01, -2.76240230e-01, -3.72962207e-01],
         [ 3.06181461e-01, -4.35510933e-01, -1.72495663e-01, ...,
          -1.13980271e-01, -2.79644579e-01, -3.66411239e-01],
         [ 3.18018258e-01, -4.34309036e-01, -1.64760321e-01, ...,
          -1.23182893e-01, -2.91709840e-01, -3.65398616e-01]]]],
      dtype=float32)
Coordinates:
  * lon      (lon) float64 0.0 2.0 4.0 6.0 8.0 ... 350.0 352.0 354.0 356.0 358.0
  * lat      (lat) float64 -89.0 -87.0 -85.0 -83.0 -81.0 ... 83.0 85.0 87.0 89.0
    height   float64 2.0
  * month    (month) int64 1 2 3 4 5 6 7 8 9 10 11 12
  * year     (year) int64 2020 2021 2022 2023 2024
  3. And after taking the rolling mean with my_xarray = my_xarray.rolling(month=3).mean(), print(my_xarray) gives:
<xarray.DataArray (year: 5, lat: 90, lon: 180, month: 12)>
array([[[[            nan,             nan, -6.64931387e-02, ...,
          -9.65834657e-02, -4.84402974e-02, -8.95748734e-02],
         [            nan,             nan, -6.85216933e-02, ...,
          -9.58202779e-02, -4.96433278e-02, -8.82281562e-02],
         [            nan,             nan, -7.24467238e-02, ...,
          -9.80513891e-02, -5.37225107e-02, -8.75133177e-02],
         ...,
         [            nan,             nan, -5.19034366e-02, ...,
          -9.29711560e-02, -3.84746144e-02, -8.82017215e-02],
         [            nan,             nan, -5.74423869e-02, ...,
          -9.49127277e-02, -4.14346159e-02, -8.83911053e-02],
         [            nan,             nan, -6.09868666e-02, ...,
          -9.67354774e-02, -4.46880311e-02, -8.86526704e-02]],

        [[            nan,             nan, -2.49655296e-03, ...,
          -3.19432567e-02, -3.28139116e-02, -5.15777121e-02],
         [            nan,             nan,  6.70447449e-03, ...,
          -2.96478843e-02, -2.62145599e-02, -4.59023168e-02],
         [            nan,             nan,  1.06110424e-02, ...,
          -2.02979098e-02, -2.67094250e-02, -4.47271963e-02],
...
         [            nan,             nan, -1.55030757e-01, ...,
          -9.92223521e-02, -8.67839058e-02, -1.19647721e-01],
         [            nan,             nan, -1.36637489e-01, ...,
          -1.22766892e-01, -1.13554617e-01, -1.32468919e-01],
         [            nan,             nan, -1.03396863e-01, ...,
          -1.32896582e-01, -1.27950917e-01, -1.41431669e-01]],

        [[            nan,             nan, -8.37145646e-02, ...,
          -6.00561102e-02, -1.46990995e-01, -2.62319565e-01],
         [            nan,             nan, -8.25609316e-02, ...,
          -5.84986111e-02, -1.46998684e-01, -2.61980017e-01],
         [            nan,             nan, -7.88577447e-02, ...,
          -5.79771499e-02, -1.48239036e-01, -2.59101093e-01],
         ...,
         [            nan,             nan, -1.00808918e-01, ...,
          -5.09810448e-02, -1.30277574e-01, -2.52983093e-01],
         [            nan,             nan, -1.00608379e-01, ...,
          -5.37393292e-02, -1.33528948e-01, -2.53345370e-01],
         [            nan,             nan, -9.36836998e-02, ...,
          -5.75257987e-02, -1.41069442e-01, -2.60097106e-01]]]])
Coordinates:
  * lon      (lon) float64 0.0 2.0 4.0 6.0 8.0 ... 350.0 352.0 354.0 356.0 358.0
  * lat      (lat) float64 -89.0 -87.0 -85.0 -83.0 -81.0 ... 83.0 85.0 87.0 89.0
    height   float64 2.0
  * month    (month) int64 1 2 3 4 5 6 7 8 9 10 11 12
  * year     (year) int64 2020 2021 2022 2023 2024

Solution

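  • OK, so the rolling window "restarts" because month is a separate dimension from year: rolling(month=3) slides along the 12 months of each year independently, so a window never reaches from January back into December of the previous year.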

    One way to do what you want is the following. First I created some dummy data similar to yours:

    import numpy as np
    import xarray as xr
    
    # Dummy data: 2 years x 12 months of random values
    da = xr.DataArray(
        np.random.random(size=(2, 12)),
        dims=("year", "month"),
        coords={"month": np.arange(1, 13),
                "year": [2000, 2001]},
    )
    print(da)
    

    Then I used the stack method to combine year and month into a single new dimension, z, and applied the rolling window along that dimension:

    my_xarray = da.stack(z=("year", "month")).rolling(z=3).mean()
    print(my_xarray)
    

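    It seems to give what you want: only the first two entries of the whole series are NaN, and January and February 2001 now include the preceding months of 2000: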

    <xarray.DataArray (z: 24)>
    array([       nan,        nan, 0.60642737, 0.67814489, 0.44616648,
           0.45587241, 0.36101104, 0.33491579, 0.39246105, 0.42972596,
           0.54526778, 0.55617721, 0.46796958, 0.46491759, 0.44476617,
           0.47922742, 0.58516182, 0.55660812, 0.4536117 , 0.33743334,
           0.27727016, 0.3451959 , 0.49314071, 0.63349366])
    Coordinates:
      * z        (z) MultiIndex
      - year     (z) int64 2000 2000 2000 ... 2001 2001 2001
      - month    (z) int64 1 2 3 4 5 6 7 ... 6 7 8 9 10 11 12
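The output above stays stacked along z. For data laid out like the question's (year, lat, lon, month), a sketch of the full round trip could look like the following: stack year and month (year first, so the new dimension runs in calendar order), apply the rolling mean, then `unstack` and `transpose` back to the original layout. The tiny 3 x 4 lat/lon grid, the random values and the dimension name `time` are placeholders chosen for illustration, not part of the original answer.

```python
import numpy as np
import xarray as xr

# Hypothetical stand-in for the question's data: same dimension order,
# but a tiny 3 x 4 lat/lon grid and random values.
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.random((5, 3, 4, 12)),
    dims=("year", "lat", "lon", "month"),
    coords={
        "year": np.arange(2020, 2025),
        "lat": [-1.0, 0.0, 1.0],
        "lon": [0.0, 2.0, 4.0, 6.0],
        "month": np.arange(1, 13),
    },
)

rolled = (
    da.stack(time=("year", "month"))  # year-major: Jan..Dec 2020, Jan..Dec 2021, ...
    .rolling(time=3)
    .mean()
    .unstack("time")                  # split "time" back into year and month
    .transpose(*da.dims)              # restore (year, lat, lon, month)
)

# January 2021 should now be the mean of Nov 2020, Dec 2020 and Jan 2021.
expected = (
    da.sel(year=2020, month=11).values
    + da.sel(year=2020, month=12).values
    + da.sel(year=2021, month=1).values
) / 3
print(np.allclose(rolled.sel(year=2021, month=1).values, expected))
```

The order of the names passed to stack matters: stacking as ("month", "year") would interleave the years (Jan 2020, Jan 2021, ...), so each window would average the same month across different years instead of consecutive months. Only January and February of the first year remain NaN, since there is no earlier data to fill their windows.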