python numpy select lazy-evaluation

Numpy select lazy version


Consider the code below:

 >>> import numpy as np
 >>> x = np.array([0, 0, 1, 1])
 >>> np.select([x==0, True], [x+1, 1/x])
 array([ 1.,  1.,  1.,  1.])

It suffers from two problems.

Firstly, it is not lazy. It eagerly evaluates both x+1 and 1/x even though some of the evaluated values are not required in the final result.

Secondly, NumPy issues a warning every time the code is run:

RuntimeWarning: divide by zero encountered in true_divide

which is related to the first point: 1/x is evaluated for every element of x, including the zeros, whose results are then discarded from the final answer.
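To make the eager evaluation concrete, here is a small sketch. The names `caught`, `choices`, `warned_before_select` and `result` are only illustrative. It shows that the warning is already raised while the argument list is being built, before np.select has selected anything:

```python
import warnings
import numpy as np

x = np.array([0, 0, 1, 1])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Python evaluates function arguments first, so 1/x is computed for
    # every element of x before np.select gets a chance to pick anything.
    choices = [x + 1, 1 / x]

# True if a RuntimeWarning was recorded while building the choices.
warned_before_select = any(
    issubclass(w.category, RuntimeWarning) for w in caught
)

# np.select only picks among values that already exist.
result = np.select([x == 0, True], choices)
```

So the problem is not inside np.select itself: by the time it runs, both candidate arrays have been fully computed.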

Is there a version of select that is lazy and does not suffer from the problems above?


Solution

  • You can avoid evaluating the unneeded values by explicitly masking your output vector for the two cases:

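    y = np.empty(x.shape, dtype=float)  # float output; x.copy() would keep x's integer dtype and truncate 1/x
    mask = (x == 0)  # parentheses only for readability
    y[mask] = x[mask] + 1    # evaluated only where x == 0
    y[~mask] = 1 / x[~mask]  # evaluated only where x != 0, so no division by zero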
    

    The masking approach above is what I strongly recommend. Keep reading only if you want a pointlessly contrived solution that actually tackles the "lazy evaluation" part of the question. I don't recommend using the snippet below in practice! You have been warned.

    I did eventually achieve actual lazy evaluation, although it's a bit messy and complicates the situation unnecessarily (well, at least this situation; I can imagine cases where it might come in handy). In the spirit of "anything worth doing is worth overdoing":

    xfun1 = [lambda xval=xval: xval + 1 for xval in x]
    xfun2 = [lambda xval=xval: 1 / xval for xval in x]
    [fun() for fun in np.select([x == 0, True], [xfun1, xfun2])]
    

    The idea is to protect the values of 1/x from being evaluated by hiding them behind lambda definitions. The auxiliary lists xfun1 and xfun2 hold one dummy lambda for each value of x: those in the first list return x+1, those in the second return 1/x. The default argument xval=xval binds each lambda to its own element. None of them is evaluated until you call it, e.g. xfun2[2]().

    So we use the select call to choose elements from the two lists of functions, which gives us an array of functions rather than numbers. To get numeric return values, we use a list comprehension to call each selected lambda.
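For completeness, the same "hide the computation behind a function" idea also works on whole sub-arrays instead of single elements, using np.piecewise. It takes callables and calls each one only on the elements its condition selects. A minimal sketch, assuming a float copy of the input is acceptable (the names `xf` and `z` are only illustrative, and I changed the last element to 2 so the division is visible):

```python
import numpy as np

x = np.array([0, 0, 1, 2])
xf = x.astype(float)  # np.piecewise returns an array with the input's dtype

# Turn a division by zero into an exception, so the run itself proves
# that 1/0 is never computed.
with np.errstate(divide="raise"):
    z = np.piecewise(
        xf,
        [xf == 0],            # one condition ...
        [lambda v: v + 1,     # ... applied where the condition holds
         lambda v: 1 / v],    # extra function: used everywhere else
    )
```

With one more function than conditions, the last function acts as the default branch, which plays the role of the True condition in the select call. If you pass several conditions, it is safest to keep them mutually exclusive.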