
Convert cumsum() output to binary array in xarray


I have a 3D xarray DataArray for which I compute the cumulative sum over specific time periods, and I'd like to detect which time periods meet a certain condition (set to 1) and which do not (set to 0). I'll explain using the code below:

import pandas as pd
import xarray as xr
import numpy as np

# Create a demo DataArray with dims (time, lat, lon)
data = np.random.rand(20, 5, 5)
times = pd.date_range('2000-01-01', periods=20)
lats = np.arange(10, 0, -2)
lons = np.arange(0, 10, 2)
data = xr.DataArray(data, coords=[times, lats, lons], dims=['time', 'lat', 'lon'])
data.values[6:12] = 0 # Ensure some values are set to zero so that the cumsum can reset between valid time steps
data.values[18:] = 0

# Cumulative sum along time that restarts from 0 wherever data is zero:
# subtract the running total recorded at the most recent zero (forward-filled in time)
running = data.cumsum(dim='time')
cumulative = running - running.where(data == 0).ffill(dim='time').fillna(0)

print(cumulative[:,0,0])

>>> <xarray.DataArray (time: 20)>
array([0.13395 , 0.961934, 1.025337, 1.252985, 1.358501, 1.425393, 0.      ,
       0.      , 0.      , 0.      , 0.      , 0.      , 0.366988, 0.896463,
       1.728956, 2.000537, 2.316263, 2.922798, 0.      , 0.      ])
Coordinates:
  * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20
    lat      int64 10
    lon      int64 0
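The reset one-liner works by subtracting, at every step, the running total that was recorded at the most recent zero (carried forward along time). Below is a minimal numpy-only sketch of the same idea. reset_cumsum is a hypothetical helper name, and it assumes non-negative data (true for np.random.rand): then the running cumsum never decreases, so np.maximum.accumulate can stand in for the forward fill.

```python
import numpy as np

def reset_cumsum(x):
    """Cumulative sum along axis 0 that restarts at 0 wherever x == 0.

    Assumes x >= 0, so the plain cumsum is non-decreasing and a running
    maximum acts as a forward fill of the values recorded at the zeros.
    """
    cs = np.cumsum(x, axis=0)
    # cumsum value at each zero, 0 elsewhere (like .where(...).fillna(0))
    at_zeros = np.where(x == 0, cs, 0.0)
    # carry the most recent zero's cumsum forward in time (like .ffill)
    base = np.maximum.accumulate(at_zeros, axis=0)
    return cs - base

x = np.array([0.5, 1.0, 0.0, 0.0, 2.0, 0.5, 0.0])
print(reset_cumsum(x))  # values: 0.5, 1.5, 0, 0, 2.0, 2.5, 0
```

Under the non-negativity assumption, reset_cumsum(data.values) should match cumulative.values for the demo array, up to floating-point rounding.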

The output of print(cumulative[:,0,0]) shows that the cumulative sum resets each time a zero is encountered along the time dimension. I need a way to identify which of the two non-zero periods exceeds a value of 2, and convert the result to a binary array: every time step in a period whose cumulative sum goes above 2 should be 1, and every other step should be 0.

So my expected output would be (for this specific example):

<xarray.DataArray (time: 20)>
array([0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 1., 1.,
       1., 1., 1., 1., 0., 0.])
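Another way to state the requirement: a zero closes a period, and every step of a period is flagged 1 if that period's total exceeds the threshold. For non-negative data the period total is the last (largest) value of the reset cumsum, so this is the same as "the cumulative sum goes above 2 somewhere in the period". The sketch below uses a swapped-in technique (segment labels plus np.bincount), not the approach from the answer; flag_runs is a hypothetical name and it handles a single 1-D series.

```python
import numpy as np

def flag_runs(x, threshold):
    """Return 1 for every step of a non-zero run whose total exceeds threshold.

    Each zero starts a new segment label; np.bincount sums x per label and
    the run total is broadcast back to every step carrying that label.
    """
    x = np.asarray(x, dtype=float)
    labels = np.cumsum(x == 0)               # segment id per time step
    totals = np.bincount(labels, weights=x)  # sum of x within each segment
    run_total = totals[labels]               # each step sees its run's total
    # zeros themselves are never flagged
    return ((x != 0) & (run_total > threshold)).astype(int)

x = np.array([0.5, 0.4, 0.3, 0.0, 0.0, 0.9, 0.8, 0.7, 0.6, 0.0])
print(flag_runs(x, 2))  # [0 0 0 0 0 1 1 1 1 0]
```

For the 3D array, np.apply_along_axis(flag_runs, 0, data.values, 2) should apply it per grid cell along time, and data.copy(data=...) can wrap the result back into a DataArray with the original coordinates.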

Solution

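  • Solved this using some masking and xarray's backfill (bfill) functionality. Steps whose cumulative sum already exceeds the threshold are set to 1, zeros stay 0, and the remaining non-zero steps are set to NaN; backfilling along time then gives each NaN the value of the next decided step, which is 1 if its period goes on to exceed the threshold and 0 if the period ends at a zero first: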

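    # threshold from the question: flag periods whose cumulative sum exceeds 2
    threshold = 2

    # make something to put results in (same shape and coords as cumulative)
    out = xr.full_like(cumulative, fill_value=0.0)

    # steps that have already exceeded the threshold are definitely 1
    out.values[cumulative.values > threshold] = 1
    # other non-zero steps are undecided: set them to NaN so they can be filled
    out.values[(cumulative.values > 0) & (cumulative.values <= threshold)] = np.nan

    # backfill along time: each undecided step takes the next decided value,
    # i.e. 1 if its period later exceeds the threshold, 0 if it hits a zero first.
    # NaNs left at the very end belong to a period that never exceeded the
    # threshold, so they become 0.
    out_ds = out.bfill(dim='time').fillna(0)

    print('Cumulative array:')
    print(cumulative.values[:, 0, 0])
    print()
    print('Binary array:')
    print(out_ds.values[:, 0, 0])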
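The mask-then-backfill logic can be traced by hand on the single column printed in the question. This is a numpy-only sketch with a threshold of 2 and trailing NaNs filled with 0 (such a period never exceeded the threshold); bfill_1d and binary_from_cumulative are hypothetical helper names, and bfill_1d is only a stand-in for xarray's .bfill(dim='time') on one series.

```python
import numpy as np

def bfill_1d(a):
    """Backward-fill NaNs in a 1-D float array (numpy stand-in for .bfill)."""
    n = a.size
    # position of each decided (non-NaN) entry; NaNs point past the end
    idx = np.where(np.isnan(a), n, np.arange(n))
    # running minimum from the right = nearest decided position at or after each step
    idx = np.minimum.accumulate(idx[::-1])[::-1]
    # position n hits the appended NaN, so NaNs with nothing after them stay NaN
    return np.append(a, np.nan)[idx]

def binary_from_cumulative(cum, threshold):
    """Apply the mask-then-backfill logic to one time series."""
    out = np.zeros_like(cum, dtype=float)
    out[cum > threshold] = 1.0                    # already exceeded: 1
    out[(cum > 0) & (cum <= threshold)] = np.nan  # undecided: fill later
    # trailing NaNs belong to a period that never exceeded the threshold
    return np.nan_to_num(bfill_1d(out), nan=0.0)

# The cumulative[:, 0, 0] column printed in the question
cum = np.array([0.13395, 0.961934, 1.025337, 1.252985, 1.358501, 1.425393,
                0, 0, 0, 0, 0, 0,
                0.366988, 0.896463, 1.728956, 2.000537, 2.316263, 2.922798,
                0, 0])
print(binary_from_cumulative(cum, threshold=2))  # 1 at indices 12-17, 0 elsewhere
```

With a threshold of 3 instead, this column would come out all zeros, because the second period peaks at 2.922798.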