python numpy calculus

Cardano's formula not working with numpy?


--- using python 3 ---

Following the equations here, I tried to find all real roots of an arbitrary third-order polynomial. Unfortunately, my implementation does not yield the correct result and I cannot find the error. Maybe you can spot it at a glance and tell me.

(As you can see, only the roots of the green curve are wrong.)

With best regards

import numpy as np
def find_cubic_roots(a,b,c,d):
    # with ax³ + bx² + cx + d = 0
    a,b,c,d = a+0j, b+0j, c+0j, d+0j
    all_ = (a != np.pi)

    Q = (3*a*c - b**2)/ (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
    D = Q**3 + R**2
    S = (R + np.sqrt(D))**(1/3)
    T = (R - np.sqrt(D))**(1/3)

    result = np.zeros(tuple(list(a.shape) + [3])) + 0j
    result[all_,0] = - b / (3*a) + (S+T)
    result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
    result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)

    return result

An example where it does not work:

import matplotlib.pyplot as plt
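fig, ax = plt.subplots()
colors = ["blue", "red", "green"]  # not defined in the original post; assumed so that d = 8 is the green curve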
a = np.array([2.5])
b = np.array([-5])
c = np.array([0])

x = np.linspace(-2,3,100)
for i, d in enumerate([-8,0,8]):
    d = np.array(d)
    roots = find_cubic_roots(a,b,c,d)
    ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a,b,c,d), color = colors[i])
    print(roots)
    ax.plot(x, x*0)
    ax.scatter(roots,roots*0,  s = 80)
ax.legend(loc = 0)
ax.set_xlim(-2,3)
plt.show()

[Figure: Easy Example]

Output:

[[ 2.50852567+0.j        -0.25426283+1.1004545j -0.25426283-1.1004545j]]
[[ 2.+0.j  0.+0.j  0.-0.j]]
[[ 1.51400399+1.46763129j  1.02750817-1.1867528j  -0.54151216-0.28087849j]]

Solution

  • Here is my stab at the solution. Your code fails when R + np.sqrt(D) or R - np.sqrt(D) is a negative real number. The reason is in this post. Basically, because the inputs are cast to complex, a**(1/3) for a negative a returns the principal complex cube root rather than the real one (for example, (-8+0j)**(1/3) is about 1+1.732j, not -2). However, we in fact want S and T to be real in that case, since the cube root of a negative real number is simply a negative real number (let's ignore De Moivre's theorem for now and focus on the code and not the math). The workaround is to check whether R + np.sqrt(D) is real and, if so, take its real part and pass that to cbrt from scipy.special; otherwise keep the complex power. Do the same for T. Example code:

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy.special import cbrt

    def find_cubic_roots(a, b, c, d):
        # Roots of a*x**3 + b*x**2 + c*x + d = 0.
        # The scalar `if` tests below assume single-element arrays.
        a, b, c, d = a+0j, b+0j, c+0j, d+0j
        all_ = (a != np.pi)  # mask kept from the question's code

        Q = (3*a*c - b**2) / (9*a**2)
        R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54 * a**3)
        D = Q**3 + R**2

        # NEW CALCULATION FOR S: real cube root when the argument is real,
        # complex power (principal root) otherwise.
        s_arg = R + np.sqrt(D)
        if np.isreal(s_arg):
            S = cbrt(np.real(s_arg))
        else:
            S = s_arg**(1/3)

        # NEW CALCULATION FOR T: same treatment as S.
        t_arg = R - np.sqrt(D)
        if np.isreal(t_arg):
            T = cbrt(np.real(t_arg))
        else:
            T = t_arg**(1/3)

        result = np.zeros(tuple(list(a.shape) + [3])) + 0j
        result[all_,0] = - b / (3*a) + (S+T)
        result[all_,1] = - b / (3*a)  - (S+T) / 2 + 0.5j * np.sqrt(3) * (S - T)
        result[all_,2] = - b / (3*a)  - (S+T) / 2 -  0.5j * np.sqrt(3) * (S - T)
        return result


    fig, ax = plt.subplots()
    a = np.array([2.5])
    b = np.array([-5])
    c = np.array([0])
    x = np.linspace(-2,3,100)
    for i, d in enumerate([-8,0,8]):
        d = np.array(d)
        roots = find_cubic_roots(a,b,c,d)

        # .item() turns the single-element arrays into plain scalars for formatting
        label = "a = %.3f, b = %.3f, c = %.3f, d = %.3f"%(a.item(),b.item(),c.item(),d.item())
        ax.plot(x, a*x**3 + b*x**2 + c*x + d, label = label)
        print(roots)
        ax.plot(x, x*0)
        ax.scatter(roots,roots*0,  s = 80)
    ax.legend(loc = 0)
    ax.set_xlim(-2,3)
    plt.show()
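
To see the cube-root issue in isolation, here is a minimal sketch comparing the complex power with a real cube root. It uses np.cbrt, NumPy's own real cube root, in place of scipy.special.cbrt; the variable names are only illustrative.

```python
import numpy as np

# A negative real number stored as complex, as in find_cubic_roots
z = np.array([-8.0 + 0.0j])

# The complex power returns the principal cube root, which is not real here
principal = z ** (1 / 3)

# np.cbrt takes real input and returns the real cube root
real_root = np.cbrt(z.real)

print(principal)  # approximately [1.+1.732j]
print(real_root)  # [-2.]
```

Both values are valid cube roots of -8, but only the real one gives the real root of the cubic when plugged into the formulas for S and T.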
    

    DISCLAIMER: Running this prints a warning, which you can probably ignore; the printed roots are correct. However, the plot shows an extra root. This is likely due to the plotting code: the complex roots are passed to scatter as well, so they probably show up at their real parts. The printed roots look fine though.
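
If you need this for whole arrays of coefficients, or want to plot only the real roots, something along these lines might work. It is a sketch rather than a drop-in replacement: it assumes real coefficients with a != 0, swaps in np.cbrt for scipy.special.cbrt, and the names cubic_roots, real_roots and the tol threshold are made up for illustration.

```python
import numpy as np

def cubic_roots(a, b, c, d):
    """Roots of a*x**3 + b*x**2 + c*x + d = 0 (real coefficients, a != 0).

    Accepts scalars or broadcastable arrays; returns a complex array
    with a trailing axis of length 3.
    """
    a, b, c, d = (np.asarray(v, dtype=complex) for v in (a, b, c, d))
    Q = (3*a*c - b**2) / (9*a**2)
    R = (9*a*b*c - 27*a**2*d - 2*b**3) / (54*a**3)
    sqrt_D = np.sqrt(Q**3 + R**2)

    def cube_root(z):
        # Real cube root where z is real, principal complex root elsewhere
        return np.where(np.isreal(z), np.cbrt(z.real), z**(1/3))

    S, T = cube_root(R + sqrt_D), cube_root(R - sqrt_D)
    shift = -b / (3*a)
    imag_part = 0.5j * np.sqrt(3) * (S - T)
    return np.stack([shift + (S + T),
                     shift - (S + T)/2 + imag_part,
                     shift - (S + T)/2 - imag_part], axis=-1)

def real_roots(roots, tol=1e-6):
    """Real parts of the roots whose imaginary part is below tol (an assumed threshold)."""
    roots = np.ravel(roots)
    return np.sort(roots.real[np.abs(roots.imag) < tol])

for d in (-8, 0, 8):
    print(d, real_roots(cubic_roots(2.5, -5, 0, d)))
```

Passing real_roots(...) to ax.scatter instead of the full complex array should avoid the spurious markers, since the complex conjugate pair is dropped before plotting.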