
Avoiding the "complex division by zero" error in NumPy (related to cmath)


I would like to avoid ZeroDivisionError: complex division by zero in my calculation and get nan instead of the exception.

Let me lay out my question with a simple example:

from numpy import *
from cmath import *

def f(x):
    x = float64(x)   # float_ was an alias of float64 and was removed in NumPy 2.0
    return 1./(1.-x)

def g(x):
    x = float64(x)
    return sqrt(-1.)/(1.-x)

f(1.)   # This gives 'inf'.
g(1.)   # This gives 'ZeroDivisionError: complex division by zero'.
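The two calls behave differently because of which type performs the division. In f, the division happens on NumPy float64 values, which follow IEEE 754 and return inf. In g, sqrt(-1.) is cmath's sqrt, which returns a plain Python complex, and Python's complex division raises on a zero divisor even when that divisor is a NumPy float. A minimal sketch reproducing both cases outside the functions, assuming NumPy is installed:

```python
import cmath
import numpy as np

zero = np.float64(0.0)

# Real case (like f): float64 division follows IEEE 754 and returns inf.
# np.errstate only silences the RuntimeWarning NumPy would otherwise emit.
with np.errstate(divide="ignore"):
    real_result = np.float64(1.0) / zero

# Complex case (like g): cmath.sqrt returns a plain Python complex, and
# Python's complex division raises even though the divisor is a float64.
try:
    cmath.sqrt(-1.0) / zero
    complex_raised = False
except ZeroDivisionError:
    complex_raised = True

# Converting the numerator to a NumPy complex keeps the division inside
# NumPy, which returns a non-finite complex value instead of raising.
with np.errstate(divide="ignore", invalid="ignore"):
    numpy_complex_result = np.complex128(cmath.sqrt(-1.0)) / zero

print(real_result, complex_raised, numpy_complex_result)
```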

I want g(1.) to return nan, or at least anything other than an error that interrupts the calculation. First question: how can I do that?

Importantly, I would prefer not to modify the code inside the functions (for example, by inserting conditions to handle the exceptions, as is done in this answer), but rather keep it in its current form (or even delete the x = float_(x) line if possible, as I mention below). The reason is that I am working with a long program with dozens of functions, and I would like all of them to avoid ZeroDivisionError without having to make lots of changes.

I was forced to insert x = float_(x) to avoid a ZeroDivisionError in f(1.). Second question: is there a way to remove this line and still get f(1.) = inf, without modifying the code that defines f at all?
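On the second question: the cast matters because 1./(1.-x) with a plain Python float raises ZeroDivisionError, whereas NumPy scalars follow IEEE 754 and return inf with a RuntimeWarning. A minimal sketch, assuming the caller is free to pass NumPy values in, keeps f free of the cast and moves the conversion to the call site:

```python
import numpy as np

def f(x):
    # f as in the question, minus the x = float_(x) line
    return 1. / (1. - x)

# Passing NumPy values makes 1. - x a NumPy float64 (or array), so the
# division returns inf instead of raising ZeroDivisionError.
with np.errstate(divide="ignore"):  # hide the divide-by-zero RuntimeWarning
    scalar_result = f(np.float64(1.))
    array_result = f(np.array([0., 1., 2.]))

print(scalar_result, array_result)
```

Calling f(1.) with a plain Python float would still raise, so this only helps if every call site can be made to pass NumPy values.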


EDIT:

I have realized that using cmath (from cmath import *) is responsible for the error. Without it, I get g(1.) = nan, which is what I want. However, I need it in my code. So now the first question turns into the following: how can I avoid "complex division by zero" when using cmath?
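With from numpy import * followed by from cmath import *, the second star import rebinds the name sqrt to cmath.sqrt. The sketch below (module-qualified names instead of star imports, so both functions stay visible) compares what each sqrt returns for a negative float, and shows that the NumPy path yields nan rather than an exception when divided by a float64 zero, as g does:

```python
import cmath
import numpy as np

# cmath.sqrt (what a bare sqrt refers to after "from cmath import *")
# returns a plain Python complex for a negative float.
cmath_root = cmath.sqrt(-1.0)

# numpy.sqrt (what a bare sqrt refers to with only "from numpy import *")
# returns nan for a negative float64; errstate hides the RuntimeWarning.
with np.errstate(invalid="ignore"):
    numpy_root = np.sqrt(np.float64(-1.0))

# Dividing the NumPy result by a float64 zero propagates nan, no exception.
with np.errstate(divide="ignore", invalid="ignore"):
    numpy_quotient = numpy_root / np.float64(0.0)

print(type(cmath_root).__name__, numpy_root, numpy_quotient)
```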


EDIT 2:

After reading the answers, I made some changes and simplified the question, getting closer to the point:

import numpy as np
import cmath as cm

def g(x):
    x = np.float64(x)                # np.float_ in NumPy < 2.0
    return cm.sqrt(x+1.)/(x-1.)      # I want 'g' to be defined on R-{1},
                                     # so I have to use 'cm.sqrt'.

print('g(1.) =', g(1.))              # This gives 'ZeroDivisionError:
                                     # complex division by zero'.

Question: how can I avoid the ZeroDivisionError without modifying the code of my function g?
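One way to leave the body of g untouched is to wrap each function after it is defined so that a ZeroDivisionError becomes nan. The sketch below uses a hypothetical decorator, nan_on_zero_division (not part of NumPy or cmath); the same wrapper could be applied to each of the dozens of functions without editing them, at the cost of silently hiding every ZeroDivisionError raised anywhere inside them:

```python
import cmath as cm
import functools
import math
import numpy as np

def nan_on_zero_division(func):
    """Hypothetical helper: call func, returning nan instead of raising ZeroDivisionError."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ZeroDivisionError:
            return math.nan
    return wrapper

def g(x):
    # Body unchanged from the question (np.float64 replaces the removed np.float_)
    x = np.float64(x)
    return cm.sqrt(x + 1.) / (x - 1.)

# Wrap after the definition, so the body of g is not touched.
g = nan_on_zero_division(g)

print(g(1.), g(-3.))
```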


Solution

  • I still don't understand why you need to use cmath. Type-cast x to np.complex128 (np.complex_ in older NumPy) when you expect complex output, and then use np.sqrt.

    import numpy as np
    
    def f(x):
        x = np.float64(x)      # np.float_ in NumPy < 2.0
        return 1. / (1. - x)
    
    def g(x):
        x = np.complex128(x)   # np.complex_ in NumPy < 2.0
        return np.sqrt(x + 1.) / (x - 1.)
    

    This yields:

    >>> f(1.)
    /usr/local/bin/ipython3:3: RuntimeWarning: divide by zero encountered in double_scalars
      # -*- coding: utf-8 -*-
    Out[131]: inf
    
    >>> g(-3.)
    Out[132]: -0.35355339059327379j
    
    >>> g(1.)
    /usr/local/bin/ipython3:3: RuntimeWarning: divide by zero encountered in cdouble_scalars
      # -*- coding: utf-8 -*-
    /usr/local/bin/ipython3:3: RuntimeWarning: invalid value encountered in cdouble_scalars
      # -*- coding: utf-8 -*-
    Out[133]: (inf+nan*j)
    

    The drawback, of course, is that g will always return complex output. That could cause problems later if you feed its result back into f, for instance, because f type-casts its argument to float. Maybe you should just type-cast to complex everywhere, but that depends on what you need to achieve on a larger scale.
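If the complex output of g does need to flow back into f, one possible mitigation, assuming an imaginary part that is numerically zero can safely be dropped, is np.real_if_close: it returns the real part only when every imaginary part is within a small tolerance of zero, and otherwise returns the complex input unchanged. A minimal sketch reusing the two functions from this answer:

```python
import numpy as np

def f(x):
    x = np.float64(x)
    return 1. / (1. - x)

def g(x):
    x = np.complex128(x)
    return np.sqrt(x + 1.) / (x - 1.)

# g returns complex128 even when the imaginary part is zero: here 3/7 + 0j.
y = g(8.)
# Drop the (zero) imaginary part before handing the value to f.
y_real = np.real_if_close(y)
result = f(y_real)

print(y, y_real, result)
```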

    EDIT

    It turns out there is a way to make g return complex output only when it needs to: use numpy.emath, whose sqrt returns a complex result only for negative real inputs.

    import numpy as np
    
    def f(x):
        x = np.float64(x)      # np.float_ in NumPy < 2.0
        return 1. / (1. - x)
    
    def g(x):
        x = np.float64(x)
        return np.emath.sqrt(x + 1.) / (x - 1.)
    

    This now gives what you would expect, converting to complex only when necessary.

    >>> f(1)
    /usr/local/bin/ipython3:8: RuntimeWarning: divide by zero encountered in double_scalars
    Out[1]: inf
    
    >>> g(1)
    /usr/local/bin/ipython3:12: RuntimeWarning: divide by zero encountered in double_scalars
    Out[2]: inf
    
    >>> g(-3)
    Out[3]: -0.35355339059327379j