Tags: python, numpy, floating-point, k-means, quantization

Quantizing normally distributed floats in Python and NumPy


Let the values in the array A be sampled from a Gaussian distribution. I want to replace every value in A with one of n_R "representatives" in R so that the total quantization error is minimized.

Here is NumPy code that does linear quantization:

import numpy as np

n_A, n_R = 1_000_000, 256
mu, sig = 500, 250
A = np.random.normal(mu, sig, size = n_A)
lo, hi = np.min(A), np.max(A)
R = np.linspace(lo, hi, n_R)
# Index of the nearest (linearly spaced) representative for each value in A
I = np.round((A - lo) * (n_R - 1) / (hi - lo)).astype(np.uint32)

L = np.mean(np.abs(A - R[I]))
print('Linear loss:', L)
-> Linear loss: 2.3303939600700603

While this works, the quantization error is large. Is there a smarter way to do it? I'm thinking that one could take advantage of A being normally distributed or perhaps use an iterative process that minimizes the "loss" function.
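
To make that last idea concrete, here is a rough sketch of the sort of iterative refinement I have in mind (a Lloyd-style pass; I haven't benchmarked it, and note that the mean update below minimizes squared error, whereas a per-cell median would minimize the absolute-error loss used above):

def lloyd_refine(A, R, n_iter=10):
    # Rough sketch: assign each value to its nearest representative, then
    # move each representative to the mean of the values assigned to it.
    R = np.sort(R).astype(float)
    for _ in range(n_iter):
        edges = (R[:-1] + R[1:]) / 2              # cell boundaries (midpoints)
        I = np.searchsorted(edges, A)             # nearest-representative index
        sums = np.bincount(I, weights=A, minlength=len(R))
        counts = np.bincount(I, minlength=len(R))
        nonempty = counts > 0
        R[nonempty] = sums[nonempty] / counts[nonempty]
    return R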

Update: While researching this question, I found a related question about "weighting" the quantization. Adapting their method sometimes gives better quantization results:

from scipy.stats import norm

# Place representatives at equally spaced quantiles between mu - 3*sig and mu + 3*sig
dist = norm(loc = mu, scale = sig)
bounds = dist.cdf([mu - 3*sig, mu + 3*sig])
pp = np.linspace(*bounds, n_R)
R = dist.ppf(pp)

# Find closest matches: lhits is the representative at or above each value,
# rhits the one at or below; keep whichever is closer.
lhits = np.clip(np.searchsorted(R, A, 'left'), 0, n_R - 1)
rhits = np.clip(np.searchsorted(R, A, 'right') - 1, 0, n_R - 1)

ldiff = R[lhits] - A
rdiff = A - R[rhits]
I = lhits
idx = np.where(rdiff < ldiff)[0]
I[idx] = rhits[idx]

L = np.mean(np.abs(A - R[I]))
print('Gaussian loss:', L)
-> Gaussian loss: 1.6521974945326285

K-means clustering might be better but seems to be too slow to be practical on large arrays.


Solution

  • K-means

    K-means clustering might be better but seems to be too slow to be practical on large arrays.

    For the 1D clustering case, there are algorithms faster than K-means. See https://stats.stackexchange.com/questions/40454/determine-different-clusters-of-1d-data-from-database

    I picked one of those algorithms, Jenks Natural Breaks, and ran it on a random sub-sample of your dataset:

    import jenkspy

    A_samp = np.random.choice(A, size=10000)
    breaks = np.array(jenkspy.jenks_breaks(A_samp, n_classes=n_R))
    # Use the midpoint of each class as its representative
    R = (breaks[:-1] + breaks[1:]) / 2
    

    This is pretty fast, and gets a quantization loss for the full dataset of about 1.28.
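
    The full-dataset loss can be computed by assigning every value to its nearest representative, for example with a helper like this (a sketch; quantization_loss is just an illustrative name, mirroring the searchsorted approach from the question):

    def quantization_loss(A, R):
        # Assign each value to its nearest representative via the midpoints
        # between consecutive (sorted) representatives, then take the mean
        # absolute error, matching the loss used in the question.
        R = np.sort(R)
        edges = (R[:-1] + R[1:]) / 2
        I = np.searchsorted(edges, A)
        return np.mean(np.abs(A - R[I]))

    print('Jenks loss:', quantization_loss(A, R))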

    To visualize what each of these methods is doing, I plotted the CDF of the breaks that each of them comes up with against the index of the break within R.

    QQ plot

    Gaussian is a straight line, by definition, which means it places an equal number of breaks at every percentile of the distribution. The linear method spends very few of its breaks in the middle of the distribution and uses most of them at the tails. Jenks finds a compromise between the two.
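
    If you want to reproduce the chart, something along these lines works (a sketch assuming matplotlib and the dist object from the question; R at this point holds the Jenks representatives):

    import matplotlib.pyplot as plt

    # Plot the CDF of each method's breaks against their index within R
    R_linear = np.linspace(np.min(A), np.max(A), n_R)
    R_gauss = dist.ppf(np.linspace(*dist.cdf([mu - 3*sig, mu + 3*sig]), n_R))
    R_jenks = R

    for name, reps in [('linear', R_linear), ('gaussian', R_gauss), ('jenks', R_jenks)]:
        plt.plot(np.arange(n_R), dist.cdf(reps), label=name)
    plt.xlabel('index within R')
    plt.ylabel('cdf of break')
    plt.legend()
    plt.show()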

    Automatically searching for lower loss

    Looking at the chart above, I had an idea: when plotted in the quantile domain, all of these methods of choosing breaks trace out sigmoid-shaped curves of various sorts. (The Gaussian one sort of fits if you think of it as a really stretched-out sigmoid.)

    I wrote a function that parameterized each of those curves with a single variable, strength, which controls how sharply the sigmoid curves. Once I had that, I used scipy.optimize.minimize to automatically search for a curve that minimized the loss.

    It turns out that if you let Scipy optimize this, it picks a curve strength really close to Jenks, and the curve it finds is slightly worse than the Jenks one, with a loss of about 1.33.

    You can see the notebook with this failed approach here.
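
    The notebook has the exact parameterization; purely to illustrate the idea, the search boils down to something like this (a sketch reusing dist, A_samp and the quantization_loss helper from above; breaks_from_strength and the logistic form are just one way to write it):

    from scipy.optimize import minimize_scalar
    from scipy.special import expit

    def breaks_from_strength(strength, n_R):
        # Place breaks so their quantile position follows a logistic curve of
        # the index; `strength` controls how sharply the curve bends.
        t = np.linspace(-1, 1, n_R)
        q = expit(strength * t)
        q = (q - q[0]) / (q[-1] - q[0])            # normalise to [0, 1]
        q = np.clip(q, 1e-6, 1 - 1e-6)             # keep ppf finite
        return dist.ppf(q)

    def loss_for_strength(strength):
        R_try = breaks_from_strength(strength, n_R)
        return quantization_loss(A_samp, R_try)    # sub-sample keeps the search fast

    res = minimize_scalar(loss_for_strength, bounds=(0.01, 20), method='bounded')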

    Quantizing with 2^16 floats

    In the case where you need to create 2^16 different representatives, it's computationally infeasible to use Jenks. However, you can do something that's pretty close: Jenks with a small number of classes plus linear interpolation.

    Here's the code for this:

    import itertools
    
    
    def pairwise(iterable):
        "s -> (s0, s1), (s1, s2), (s2, s3), ..."
        a, b = itertools.tee(iterable)
        next(b, None)
        return zip(a, b)
    
    
    def linspace_jenks(A, n_R, jenks_classes, dist_lo, dist_hi):
        assert n_R % jenks_classes == 0, "jenks_classes must be divisor of n_R"
        simplify_factor = n_R // jenks_classes
        assert jenks_classes ** 2 <= len(A), "Need more data to estimate"
        breaks = jenkspy.jenks_breaks(A, n_classes=jenks_classes)
        # Adjust the lowest and highest break to match the lowest/highest observed value
        breaks[0] = dist_lo
        breaks[-1] = dist_hi
        linspace_classes = []
        for lo, hi in pairwise(breaks):
            linspace_classes.append(np.linspace(lo, hi, simplify_factor, endpoint=False))
        linspace_classes = np.hstack(linspace_classes)
        assert len(linspace_classes) == n_R
        return linspace_classes
    

    Example call:

    n_R = 2**16
    A_samp = np.random.choice(A, size = 2**16)
    jenks_R = linspace_jenks(A_samp, n_R, 128, np.min(A), np.max(A))
    

    How does the performance compare to the linear method? On my system, I get a loss of 0.009421 for linear with n_R=2^16. The following graph shows the losses that the linspace_jenks method gets for each value of jenks_classes.

    linspace jenks loss

    With just 32 Jenks classes, and filling the rest in with linear interpolation, the loss goes down to 0.005031.
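
    To reproduce that comparison yourself, the same nearest-representative scoring applies (a sketch, reusing n_R = 2**16 and the quantization_loss helper from above):

    R_linear = np.linspace(np.min(A), np.max(A), n_R)
    jenks32_R = linspace_jenks(np.random.choice(A, size=2**16), n_R, 32, np.min(A), np.max(A))

    print('Linear loss:', quantization_loss(A, R_linear))
    print('Jenks + linspace loss:', quantization_loss(A, jenks32_R))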