
Strange behaviour of numpy.where


Consider the following code:

import numpy as np

np.where(1<=0, np.sqrt(-1), 0)

When I run this, I get:

RuntimeWarning: invalid value encountered in sqrt
array(0.)

Why is Python evaluating the first branch at all? The code above may look contrived, but I only investigated it after getting the same warning from this line in my program:

np.where(x<= 1, np.sqrt(1-x**2), 1)

The x array contains only positive elements (at least if my other code is correct). It wasn't obvious whether my x array contained a mistake, because I usually use torch.where, which does not emit this warning.
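Note that "positive" is not enough here: np.where computes np.sqrt(1-x**2) for every element before selecting, so any element with x > 1 gives a negative radicand and triggers the warning, even though the condition x <= 1 later discards that result. A minimal sketch for checking this, assuming x is a 1-D float array (the sample values and the helper name out_of_domain are made up for illustration):

```python
import numpy as np

def out_of_domain(x):
    """Hypothetical helper: return the elements of x where sqrt(1 - x**2) is undefined."""
    # 1 - x**2 < 0 exactly when |x| > 1; these are the elements that make
    # np.sqrt emit "invalid value encountered in sqrt".
    return x[1 - x**2 < 0]

x = np.array([0.2, 0.5, 1.0, 1.3])  # made-up sample values
print(out_of_domain(x))
```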


Solution

  • There are no branches (i.e. if statements or conditional expressions) in your code. All arguments of a function call are fully evaluated before the function itself is called, so np.sqrt(-1) runs (and warns) before np.where ever looks at the condition. Seeing an equivalent piece of code may make this clearer:

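    import numpy as np

    # Each argument is evaluated, in order, before np.where is called
    _temp0 = 1 <= 0         # condition: False
    _temp1 = np.sqrt(-1)    # nan, and this is where the RuntimeWarning is emitted
    _temp2 = 0
    np.where(_temp0, _temp1, _temp2)  # selects _temp2 -> array(0.)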
    

    Python has a strict and very straightforward evaluation strategy: every argument expression is evaluated before the function body runs. It sounds like you were expecting a non-strict evaluation strategy like call by name.
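If the goal is to avoid computing the square root on invalid inputs at all, one option is the where= and out= arguments that NumPy ufuncs such as np.sqrt accept: the ufunc only writes results where the mask is True, and masked-out elements are not passed to sqrt. A minimal sketch under that assumption (the helper name sqrt_branch is hypothetical, and the "else" value 1 is taken from the question's expression):

```python
import numpy as np

def sqrt_branch(x):
    """Hypothetical helper: sqrt(1 - x**2) where x <= 1, and 1 elsewhere."""
    x = np.asarray(x, dtype=float)
    mask = x <= 1
    # Pre-fill the output with the "else" value everywhere.
    out = np.ones_like(x)
    # sqrt only writes into out where mask is True; elements where mask is
    # False keep the pre-filled 1 and are never handed to sqrt, so no
    # "invalid value" warning is raised for them.
    np.sqrt(1 - x**2, out=out, where=mask)
    return out

print(sqrt_branch([0.0, 0.6, 1.0, 2.0]))
```

Alternatively, keeping np.where as it is and wrapping the call in `with np.errstate(invalid="ignore"):` silences the warning, but then the invalid square roots are still computed and simply discarded.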