ValueError with ReLU function in python


I defined a ReLU function like this:

def relu(x):
    return (x if x > 0 else 0)

and a ValueError occurred, with this traceback message:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

But if I rewrite the ReLU function with NumPy, it works:

import numpy as np

def relu_np(x):
    return np.maximum(0, x)

Why doesn't the first function, relu(x), work? I cannot understand it.

================================

Used code:

>>> x = np.arange(-5.0, 5.0, 0.1)
>>> y = relu(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "filename", line, in relu
    return (x if x > 0 else 0)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Solution

  • Keep in mind that x > 0 is an array of booleans, a mask if you like:

    array([False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False,
           False, False, False, False, False, False, False, False, False,
           False, False, False, False, False,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True,  True,  True,  True,  True,  True,  True,  True,  True,
            True])
    

    So writing if x > 0 does not make sense here: the condition is an array with many elements, some True and some False, and Python cannot reduce it to a single truth value. This is the source of your error.

    Your second implementation, the one using NumPy, is good! Another implementation (maybe clearer?) might be:

    def relu(x):
        return x * (x > 0)
    

    In this implementation, x (a range of values along the x axis) is multiplied elementwise by the mask x > 0: each element is multiplied by 0 where it is not positive and by 1 where it is positive.
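
To make the points above concrete, here is a minimal sketch that reproduces the error on a small array and then compares three vectorized versions. The sample values and the names relu_maximum, relu_mask and relu_where are only illustrative and do not come from the question; np.where is an extra alternative that the answer above does not mention.

```python
import numpy as np

# Small illustrative input (made-up values, not the array from the question).
x = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])

# The comparison is evaluated elementwise and yields a boolean mask.
mask = x > 0

# Using the mask where Python expects a single True/False raises the
# same ValueError as in the question.
try:
    bool(mask)
    raised = False
except ValueError:
    raised = True


def relu_maximum(x):
    # Elementwise maximum of 0 and x (the NumPy version from the question).
    return np.maximum(0, x)


def relu_mask(x):
    # Multiply each element by its mask value (False -> 0, True -> 1).
    return x * (x > 0)


def relu_where(x):
    # Keep x where the mask is True, otherwise use 0.0.
    return np.where(x > 0, x, 0.0)


expected = np.array([0.0, 0.0, 0.0, 1.5, 3.0])
print(raised)
print(np.array_equal(relu_maximum(x), expected))
print(np.array_equal(relu_mask(x), expected))
print(np.array_equal(relu_where(x), expected))
```

One small difference: for negative floating-point inputs the mask version produces -0.0 rather than 0.0. The two compare equal, so this rarely matters, but np.maximum avoids it if the sign of zero is relevant to you.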