numpy, missing-data, masked-array

How can I efficiently "stretch" present values in an array over absent ones?


Where 'absent' can mean either nan or np.masked, whichever is easiest to implement this with.

For instance:

>>> from numpy import nan
>>> do_it([1, nan, nan, 2, nan, 3, nan, nan, 4, 3, nan, 2, nan])
array([1, 1, 1, 2, 2, 3, 3, 3, 4, 3, 3, 2, 2])
# each nan is replaced with the first non-nan value before it
>>> do_it([nan, nan, 2, nan])
array([nan, nan, 2, 2])
# don't care too much about the outcome here, but this seems sensible

I can see how you'd do this with a for loop:

import numpy as np
from numpy import nan

def do_it(a):
    res = []
    last_val = nan
    for item in a:
        if not np.isnan(item):
            last_val = item
        res.append(last_val)
    return np.asarray(res)

Is there a faster way to vectorize it?


Solution

  • Working from @Benjamin's deleted solution, everything is great if you work with indices:

    import numpy as np

    def do_it(data, valid=None, axis=0):
        # normalize the inputs to match the question examples
        data = np.asarray(data)
        if valid is None:
            valid = ~np.isnan(data)

        # flat array of the data values
        data_flat = data.ravel()

        # array of indices such that data_flat[indices] == data
        indices = np.arange(data.size).reshape(data.shape)

        # zero out the indices of invalid entries, then take a running
        # maximum so each position points at the last valid index seen
        # (thanks to Benjamin for this step)
        stretched_indices = np.maximum.accumulate(valid * indices, axis=axis)
        return data_flat[stretched_indices]
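
    As a quick sanity check, the function reproduces both examples from the question (repeated here in self-contained form so it runs on its own):

```python
import numpy as np
from numpy import nan

def do_it(data, valid=None, axis=0):
    # forward-fill nans with the last valid value, via index stretching
    data = np.asarray(data)
    if valid is None:
        valid = ~np.isnan(data)
    data_flat = data.ravel()
    indices = np.arange(data.size).reshape(data.shape)
    # invalid positions get index 0; the running max carries the last
    # valid index forward over them
    stretched_indices = np.maximum.accumulate(valid * indices, axis=axis)
    return data_flat[stretched_indices]

# each nan is replaced with the first non-nan value before it
print(do_it([1, nan, nan, 2, nan, 3, nan, nan, 4, 3, nan, 2, nan]))
# leading nans have no earlier value to stretch over them, so they stay nan
print(do_it([nan, nan, 2, nan]))
```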
    

    Comparing solution runtime:

    >>> import numpy as np
    >>> data = np.random.rand(10000)
    
    >>> %timeit do_it_question(data)
    10000 loops, best of 3: 17.3 ms per loop
    >>> %timeit do_it_mine(data)
    10000 loops, best of 3: 179 µs per loop
    >>> %timeit do_it_user(data)
    10000 loops, best of 3: 182 µs per loop
    
    # with lots of nans
    >>> data[data > 0.25] = np.nan
    
    >>> %timeit do_it_question(data)
    10000 loops, best of 3: 18.9 ms per loop
    >>> %timeit do_it_mine(data)
    10000 loops, best of 3: 177 µs per loop
    >>> %timeit do_it_user(data)
    10000 loops, best of 3: 231 µs per loop
    

    So both this and @user2357112's solution blow the solution in the question out of the water. This one has a slight edge over @user2357112's when there is a high proportion of nans.
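
    The question also allows np.masked. As a sketch (not part of the original answer): nothing in the function actually requires nan as the sentinel, so a masked array can be handled by passing the complement of its mask as valid:

```python
import numpy as np

def do_it(data, valid=None, axis=0):
    # np.asarray on a masked array drops the mask, keeping the raw values
    data = np.asarray(data)
    if valid is None:
        valid = ~np.isnan(data)
    data_flat = data.ravel()
    indices = np.arange(data.size).reshape(data.shape)
    stretched_indices = np.maximum.accumulate(valid * indices, axis=axis)
    return data_flat[stretched_indices]

# masked entries play the role of nan; note this also works for integer
# dtypes, where nan is not representable
a = np.ma.masked_array([1, 0, 0, 2, 0, 3],
                       mask=[False, True, True, False, True, False])
filled = do_it(a, valid=~a.mask)
print(filled)  # [1 1 1 2 2 3]
```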