I have a 3D xarray DataArray holding a cumulative sum that resets between time periods. I'd like to detect which periods meet a certain condition (set those to 1) and which do not (set those to 0). I'll explain using the code below:
import pandas as pd
import xarray as xr
import numpy as np
# Create a demo DataArray with dimensions (time, lat, lon)
data = np.random.rand(20, 5, 5)
times = pd.date_range('2000-01-01', periods=20)
lats = np.arange(10, 0, -2)
lons = np.arange(0, 10, 2)
data = xr.DataArray(data, coords=[times, lats, lons], dims=['time', 'lat', 'lon'])
data.values[6:12] = 0 # Ensure some values are set to zero so that the cumsum can reset between valid time steps
data.values[18:] = 0
# Cumulative sum along time that resets to zero each time a zero value is found in data
cumulative = data.cumsum(dim='time')-data.cumsum(dim='time').where(data.values == 0).ffill(dim='time').fillna(0)
print(cumulative[:,0,0])
>>> <xarray.DataArray (time: 20)>
array([0.13395 , 0.961934, 1.025337, 1.252985, 1.358501, 1.425393, 0. ,
0. , 0. , 0. , 0. , 0. , 0.366988, 0.896463,
1.728956, 2.000537, 2.316263, 2.922798, 0. , 0. ])
Coordinates:
* time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-20
lat int64 10
lon int64 0
The print statement shows that the cumulative sum resets each time a zero is encountered along the time dimension. I need a way to identify which of the two periods exceeds a value of 2, and to produce a binary array that is 1 for every time step of a period that meets the condition and 0 everywhere else.
So my expected output would be (for this specific example):
<xarray.DataArray (time: 20)>
array([0. , 0. , 0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. , 1. , 1. ,
1. , 1. , 1. , 1. , 0. , 0. ])
Solved this using some masking and the backfill functionality:
# make something to put results in
out = xr.full_like(cumulative, fill_value=0.0)
# mark the points that have exceeded the threshold of 2
out.values[cumulative.values > 2] = 1
# set the other points inside a period (above 0 but not above 2) to NaN so they can be filled
out.values[(cumulative.values > 0) & (cumulative.values <= 2)] = np.nan
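# backfill along time: a period that never exceeds 2 picks up the 0 that follows it,
# and a period that does exceed 2 picks up the 1 from its later steps
# any NaN left at the end of the record belongs to a period that has not exceeded 2 yet, so it becomes 0
out_ds = out.bfill(dim='time').fillna(0)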
print('Cumulative array:')
print(cumulative.values[:, 0, 0])
print()
print('Binary array')
print(out_ds.values[:, 0, 0])
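As a cross-check on the backfill approach, here is a minimal NumPy-only sketch of the same idea. It is not part of the original answer: the function name flag_periods is hypothetical, and it assumes time is the first axis, the input values are non-negative, and a cumulative value of 0 marks a gap between periods. It reuses the reset trick from the question, but walks backwards in time and counts threshold exceedances instead of summing values.

```python
import numpy as np

def flag_periods(cum, threshold):
    """Return 1.0 for every step of a period whose running total exceeds threshold.

    Hypothetical helper: assumes time is axis 0 and that cum == 0 marks a gap.
    """
    cum = np.asarray(cum, dtype=float)
    rev = cum[::-1]                            # walk backwards in time
    active = rev > 0                           # True inside a period
    hits = np.cumsum(rev > threshold, axis=0)  # exceedances seen so far (non-decreasing)
    # value of the counter at the most recent gap step, carried forward
    at_gap = np.maximum.accumulate(np.where(active, 0, hits), axis=0)
    # a step is flagged if an exceedance occurred since the last gap
    flagged = active & (hits > at_gap)
    return flagged[::-1].astype(float)

demo = [0.5, 1.5, 2.5, 0, 0, 0.4, 0.9, 0, 1.0, 2.1]
print(flag_periods(demo, 2))
# [1. 1. 1. 0. 0. 0. 0. 0. 1. 1.]
```

To get the result back as a DataArray with the original coordinates, something like cumulative.copy(data=flag_periods(cumulative.values, 2)) should work. Note that the last period in the demo is flagged even though no zero follows it, because it does exceed the threshold before the record ends.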