Search code examples
Tags: python, curve-fitting, scipy-optimize

Fit 3D coordinates to a parabola


I would like to predict a ball's trajectory by fitting its 3D coordinates to a parabola. My code is below, but instead of a parabola I get a straight line. If you have any idea why, please let me know. Thanks!

# Draw a 3D scatter plot of the measured ball coordinates.
fig = plt.figure()
ax = plt.axes(projection = '3d')
# Split the points into separate x, y and z lists.
# NOTE(review): rm_list is defined outside this snippet; the indexing below
# assumes each entry holds an (x, y, z) triple at position [0] — confirm.
x_list = []
y_list = []
z_list = []
for x in rm_list:  # note: the loop variable `x` is overwritten by the array below
    x_list.append(x[0][0])
    y_list.append(x[0][1])
    z_list.append(x[0][2])
# Only the positions x[0][0..2] are extracted here; if rm_list also carries
# timestamps they are not used, so the fit below cannot use time as the
# independent variable.
x = np.array(x_list)
y = np.array(y_list)
z = np.array(z_list)
ax.scatter(x, y, z)

# curve fit
def func(x, a, b, c, d):
    """Quadratic surface model z = a*x**2 + b*y**2 + c*x*y + d.

    ``x`` is indexed as a 2-row array: x[0] holds the x coordinates and
    x[1] the y coordinates. ``a``..``d`` are the coefficients fitted by
    curve_fit.

    NOTE(review): this models z as a surface over (x, y) rather than as a
    trajectory through time, and it has no terms linear in x or y.
    """
    return a * x[0]**2 + b * x[1]**2 + c * x[0] * x[1] + d

# Stack the points into an (N, 3) array with columns x, y, z.
data = np.column_stack([x_list, y_list, z_list])
# Fit z (column 2) as a function of (x, y): curve_fit passes the 2 x N array
# data[:, :2].T to func as its first argument.
popt, _ = curve_fit(func, data[:,:2].T, ydata=data[:,2])
a, b, c, d = popt
# NOTE: the printed label reads "y=" but the fitted quantity is z.
print('y= %.5f * x ^ 2 + %.5f * y ^ 2 + %.5f * x * y + %.5f' %(a, b, c, d))
# Sample the fitted surface for plotting.
# NOTE(review): both ranges are hard-coded, presumably to roughly match the
# data extent — confirm. Because x1 and y1 advance together, (x1, y1) traces a
# single straight segment in the xy-plane from (0.3, 0.02) to (0.4, 0.06).
# The green curve is the surface evaluated along that one segment, so it stays
# in the vertical plane above it instead of following the measured trajectory.
x1 = np.linspace(0.3, 0.4, 100)
y1 = np.linspace(0.02, 0.06, 100)
z1 = a * x1 ** 2 + b * y1 ** 2 + c * x1 * y1 + d
ax.plot(x1, y1, z1, color='green')

plt.show()

[Image: resulting plot — the fitted curve appears as a straight line]

Update 1: After changing `func` to $z = ax^2 + by^2 + cxy + dx + ey + f$, I got a parabola, but it does not fit the coordinates. [Image: resulting plot — a parabola that does not follow the data points]


Solution

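  • Since you have the underlying timestamps for your data, the fitting becomes easier. Instead of fitting z as a surface over (x, y), fit each coordinate as its own quadratic function of time — x(t), y(t) and z(t) — and fit all three at once: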

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.optimize import curve_fit
    from numpy.polynomial import Polynomial
    
    # Synthetic test data: draw n random parameter values t, push them through
    # known quadratics x(t), y(t), z(t) and add noise.
    # Replace this section with your own data (t plus x, y, z).
    np.random.seed(123)  # reproducible draws
    n = 40
    # Coefficients in increasing degree order: [c0, c1, c2] -> c0 + c1*t + c2*t**2
    x_param = [ 1, 21, -1]
    y_param = [12, -3,  0]
    z_param = [-3,  0, -2]
    px, py, pz = Polynomial(x_param), Polynomial(y_param), Polynomial(z_param)
    # 1000 evenly spaced candidates in [-6, 4]; pick n of them (with replacement).
    candidates = np.linspace(-3000, 2000, 1000) / 500
    t = np.random.choice(candidates, n)
    # np.random.random samples [0, 1), so this noise has mean 0.5; it shows up
    # as a shift of roughly +0.5 in each fitted constant term.
    x = px(t) + np.random.random(n)
    y = py(t) + np.random.random(n)
    z = pz(t) + np.random.random(n)
    
    
    # From here on the real calculations start.
    # Scatter plot of the raw (x, y, z) samples on a new 3D axes.
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(x, y, z, label="raw data")
    
    # curve fit function
    def func(t, x2, x1, x0, y2, y1, y0, z2, z1, z0):
        """Evaluate quadratics x(t), y(t), z(t) and stack the results.

        ``Polynomial`` takes coefficients in *increasing* degree order, so
        despite the names ``x2`` is the constant term, ``x1`` the linear one
        and ``x0`` the t**2 coefficient (likewise for y and z).

        Returns a 1-D array of length 3 * len(t): all x values, then all y
        values, then all z values. This matches the layout of the stacked
        observations given to curve_fit.
        """
        coefficient_sets = ([x2, x1, x0], [y2, y1, y0], [z2, z1, z0])
        return np.concatenate([Polynomial(coefs)(t) for coefs in coefficient_sets])
    
    
    # Fit all nine coefficients in a single curve_fit call. The observations
    # are stacked as [x..., y..., z...], the same layout func returns.
    # Start values are not necessary for this example,
    # but make it your rule to always provide start values for curve_fit.
    # Per axis they follow func's parameter order, i.e. numpy Polynomial
    # order: constant, linear, quadratic coefficient.
    start_vals = [
         1, 10,  1,   # x(t)
        10,  1,  1,   # y(t)
        -1, -1, -1,   # z(t)
    ]
    xyz = np.concatenate((x, y, z))
    popt, _ = curve_fit(func, t, xyz, p0=start_vals)
    print(popt)
    # Output (constant, linear, quadratic per axis; compare x_param, y_param, z_param):
    #[ 1.58003630e+00  2.10059868e+01 -1.00401965e+00  
    #  1.25895591e+01 -2.97374035e+00 -3.23358241e-03 
    # -2.44293562e+00  3.96407428e-02 -1.99671092e+00]
    
    # Evaluate the fitted model on 100 evenly spaced t values spanning the
    # observed range, then split the stacked result back into x, y, z rows.
    t_fit = np.linspace(min(t), max(t), 100)
    xyz_fit = func(t_fit, *popt).reshape(3, -1)
    x_fit, y_fit, z_fit = xyz_fit
    ax.plot(x_fit, y_fit, z_fit, color="green", label="fitted data")

    ax.legend()
    plt.show()
    

    Sample output:

    [Image: sample output — the raw data scatter with the fitted trajectory drawn in green]