Search code examples
pythonscipycurve-fitting

Fittig a rice distribution using scipy


I'm trying to write a fitter for some rice distributed data that I have, but it is not working for some, probably stupid, reason.

The distribution gets created fine, and the fitting routing seems to work from what I am used to with Gaussians. However, when I fit the curve, I just get nonsense. Can't seem to see where I am going wrong.

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rice
from scipy.optimize import curve_fit

# Custom Rice PDF function
def rice_pdf(x, nu, amplitude, b, scale):
    return (x / b) * np.exp(-(x**2 + scale**2) / (2 * b**2)) * amplitude

# Function to fit a Rice distribution to a histogram using curve_fit
def fit_rice_distribution_to_histogram(hist_data, bins):
    # Calculate bin centers
    bin_centers = (bins[:-1] + bins[1:]) / 2

    # Initial guesses for the parameters (nu, amplitude, b, scale)
    initial_guess = [8.4, 1.0, 1.0, np.mean(bin_centers)]

    # Fit the Rice distribution to the histogram data using curve_fit
    params, covariance = curve_fit(rice_pdf, bin_centers, hist_data, p0=initial_guess)
    nu, amplitude, b, scale = params

    # Create the fitted Rice distribution
    fitted_distribution = rice(nu, loc=scale, scale=np.sqrt(b**2 + scale**2))

    return fitted_distribution, nu, amplitude, b, scale

# Example usage:
if __name__ == "__main__":
    # Parameters for the Rice distribution
    nu = 8.5
    sigma = 10.5
    sample_size = 100

    # Calculate b from nu and sigma
    b = nu / sigma

    # Generate random data points from the Rice distribution
    data = rice.rvs(b=b, scale=sigma, size=sample_size)

    # Create a histogram of the generated data
    hist_data, bins, _ = plt.hist(data, bins=20, density=True, alpha=0.5, label="Generated Data")
    plt.xlabel("Value")
    plt.ylabel("Probability Density")

    # Fit a Rice distribution to the histogram using curve_fit
    fitted_distribution, fitted_nu, amplitude, fitted_b, fitted_scale = fit_rice_distribution_to_histogram(hist_data, bins)

    # Plot the original histogram and the fitted distribution
    x = np.linspace(min(bins), max(bins), 1000)
    pdf_values = fitted_distribution.pdf(x)
    plt.plot(x, pdf_values, 'r', label="Fitted Rice Distribution")
    plt.legend()
    plt.show()

    # Print fitted parameters
    print("Fitted Nu:", fitted_nu)
    print("Fitted Amplitude:", amplitude)
    print("Fitted b:", fitted_b)
    print("Fitted Scale:", fitted_scale)

Solution

  • Based on the trial dataset you have provided:

    import numpy as np
    import matplotlib.pyplot as plt
    from scipy import stats, optimize
    
    nu = 8.5
    sigma = 10.5
    n = 100
    b = nu / sigma
    
    np.random.seed(12345)
    
    # Reference Distribution:
    X = stats.rice(b=b, scale=sigma)
    
    
    # Binned dataset:
    data = X.rvs(size=n)
    density, bins = np.histogram(data, density=1.)
    centers = (bins[:-1] + bins[1:]) / 2
    

    From there we can estimate parameters from histogram data:

    def model(x, b, loc, scale):
        return stats.rice.pdf(x, b=b, loc=loc, scale=scale)
    
    popt, pcov = optimize.curve_fit(model, centers, density)
    

    The model function significantly differ from yours as it misses the Bessel function term specified in the documentation.

    xlin = np.linspace(0, 50, 200)
    fig, axe = plt.subplots()
    axe.hist(data, alpha=0.5, density=1.0, label="Data")
    axe.plot(xlin, X.pdf(xlin), label="Model")
    axe.plot(xlin, model(xlin, *popt), label="Fit")
    axe.legend()
    axe.grid()
    

    enter image description here

    Increasing n=10000 and number of bins to 100 gives accurate results:

    enter image description here

    Update

    You have two options to get your parameters back. Do the math between variable changes applied by default in scipy. Or rewrite the function as it is defined.

    def rice(x, nu, sigma):
        return (x/sigma**2)*np.exp(-0.5*(x**2 + nu**2)/sigma**2)*special.i0(x*nu/sigma**2)
    

    Then you can directly fit your parameters:

    popt2, pcov2 = optimize.curve_fit(rice, centers, density, p0=[10, 10])
    
    #  array([ 8.86411171, 10.34217091]),
    #  array([[ 0.12896075, -0.06411305],
    #         [-0.06411305,  0.03441797]])
    

    enter image description here