
Why does the output of numpy KDE not map easily to the input?


I want to use KDE to estimate the cluster density across a list of XY points I have detected in my microscopy images (that's a completely different process). I'm trying to adapt the code in this answer: https://stackoverflow.com/a/64499779/2009558

Why doesn't the output of the KDE map to the input dimensions? I don't get what the need is to map the KDE output to a grid. Nor why the dimensions of the grid don't match input data. What is the value of "128j" in this line?

gx, gy = np.mgrid[x.min():x.max():128j, y.min():y.max():128j]

What sort of python object is that? It's got both numbers and letters, but it's not a string? I tried googling this but couldn't find an answer. Numpy is so unpythonic sometimes, it drives me nuts.

Here's where I'm at so far. The data's just a pandas df with X and Y coordinates as floats.

import numpy as np
import plotly.express as px
import plotly.graph_objects as go  # needed for go.Scatter below
import pandas as pd
from scipy.stats import gaussian_kde

xx = df['X']
yy = df['Y']
xy = np.vstack((xx, yy))
kde = gaussian_kde(xy)

gx, gy = np.mgrid[xx.min():xx.max():128j, yy.min():yy.max():128j]
gxy = np.dstack((gx, gy))
# print(gxy[0])
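z = np.apply_along_axis(kde, 2, gxy)  # kde returns a length-1 array per point, so z has shape (128, 128, 1)
z = z.reshape(128, 128)               # drop the trailing length-1 axis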

fig = px.imshow(z)
fig.add_trace(go.Scatter(x = xx, y = yy, mode='markers', marker = dict(color='green', size=1)))
fig.show()

This produces most of the plot I want, the density plot with the points overlaid on it, but the density data is 128 x 128 instead of matching the range of the input coordinates.

[Image: KDE fail]

When I try substituting the real dimensions in the reshaping like this

z = z.reshape(ceil(xx.max()-xx.min()), ceil(yy.max()-yy.min()))

I just get errors.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_19840/2556669395.py in <module>
     12 z = np.apply_along_axis(kde, 2, gxy)
     13 # z = z.reshape(128, 128)
---> 14 z = z.reshape(ceil(xx.max()-xx.min()), ceil(yy.max()-yy.min()))
     15 
     16 fig = px.imshow(z)

ValueError: cannot reshape array of size 16384 into shape (393,464)

Solution

  • Answering your questions first:

    I don't get what the need is to map the KDE output to a grid.

    The need comes from wanting to plot the result as an image, that is, as an array of pixels. gaussian_kde gives you a continuous density function that can be evaluated at any point, so you have to choose where to evaluate it. The grid samples your data space at regular steps in both dimensions (the step can differ between x and y), and evaluating the KDE at each grid node gives you the color of one pixel.
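
    To make the distinction concrete, here is a minimal sketch. It assumes made-up random points in place of your df (the names rng, density_at_points and grid_points are only for illustration). gaussian_kde returns a callable that accepts points as a (2, N) array: evaluated at your own detections it gives one density per point, evaluated on a grid it gives one value per pixel.

    ```python
    import numpy as np
    from scipy.stats import gaussian_kde

    # Hypothetical stand-in for df['X'] and df['Y']: 500 random points
    rng = np.random.default_rng(0)
    xx = rng.normal(200.0, 40.0, size=500)
    yy = rng.normal(250.0, 60.0, size=500)

    xy = np.vstack((xx, yy))  # shape (2, 500): one column per point
    kde = gaussian_kde(xy)

    # 1) Density at each input point: one value per detection, no grid involved
    density_at_points = kde(xy)  # shape (500,)

    # 2) Density on a regular 128 x 128 grid covering the data range, for an image
    gx, gy = np.mgrid[xx.min():xx.max():128j, yy.min():yy.max():128j]
    grid_points = np.vstack((gx.ravel(), gy.ravel()))  # shape (2, 16384)
    z = kde(grid_points).reshape(gx.shape)  # z[i, j] is the density at (gx[i, j], gy[i, j])

    print(density_at_points.shape, z.shape)
    ```

    Calling kde once on all grid points should give the same values as the np.apply_along_axis call in the question, without the extra trailing axis that made the reshape(128, 128) necessary.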

    Why doesn't the output of the KDE map to the input dimensions?

    It does, but px.imshow(z) only knows the matrix z, so both axes show matrix indices (0 to 127) instead of your data coordinates, which is what makes it confusing here.
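
    If you want to read data coordinates off those index axes, the conversion is linear: because 128j places 128 equally spaced samples with both ends included, row i of z sits at x = xmin + i * (xmax - xmin) / (n - 1), and the same holds for columns and y. A small sketch, with hypothetical bounds and a hypothetical helper index_to_x:

    ```python
    import numpy as np

    # Hypothetical data range and grid size (not taken from your data)
    xmin, xmax, n = 10.0, 403.0, 128

    # Row coordinates of the grid, as np.mgrid[xmin:xmax:128j] creates them
    xs = np.mgrid[xmin:xmax:n * 1j]

    def index_to_x(i):
        """Map a row index of z (0 .. n-1) back to its x coordinate in data units."""
        return xmin + i * (xmax - xmin) / (n - 1)

    print(index_to_x(0), index_to_x(n - 1))  # first and last rows land on xmin and xmax
    ```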

    Nor why the dimensions of the grid don't match input data.

    It is an arbitrary choice that will define the resolution of your image. It is 128 x 128 here because your data space was divided that way:

    gx, gy = np.mgrid[xx.min():xx.max():128j, yy.min():yy.max():128j]

    but you could choose anything else.

    What is the value of "128j" in this line?

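    As slothrop and Daraan pointed out in the comments, 128j is an ordinary Python complex-number literal (a number with imaginary part 128), not a string, so the object itself is perfectly Pythonic. What is NumPy-specific is how np.mgrid interprets it: when the step of a slice is a complex number, the integer part of its magnitude is taken as the number of points to create between start and stop, and the stop value is included, like np.linspace. With a real step, the step is the spacing between points and the stop value is excluded, like np.arange.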
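
    A quick way to see both behaviors side by side (plain NumPy, nothing assumed beyond toy ranges chosen for readability):

    ```python
    import numpy as np

    step = 128j
    print(type(step))  # <class 'complex'>: an ordinary Python complex literal
    print(step.imag)   # 128.0

    # Complex step: number of points, stop included (like np.linspace)
    a = np.mgrid[0:1:5j]    # -> [0, 0.25, 0.5, 0.75, 1]

    # Real step: spacing between points, stop excluded (like np.arange)
    b = np.mgrid[0:1:0.25]  # -> [0, 0.25, 0.5, 0.75]

    # Two slices give two 2-D coordinate arrays, one per dimension
    gx, gy = np.mgrid[0:1:3j, 0:2:4j]
    print(gx.shape, gy.shape)  # (3, 4) (3, 4)
    ```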

    Concerning your error: z holds 128 x 128 = 16384 values (because of the way your space was divided earlier), and you are asking reshape to reorganize them into a 393 x 464 matrix, which would need 182352 values. reshape can only rearrange existing values; it cannot add or interpolate new ones.
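
    If you really want roughly one grid node per coordinate unit, the number of points has to be requested when the grid is built, not afterwards in reshape. A sketch under stated assumptions: the points are synthetic stand-ins for your df, and nx and ny are illustrative names for the ceil() sizes you tried.

    ```python
    import math
    import numpy as np
    from scipy.stats import gaussian_kde

    # Hypothetical synthetic points standing in for df['X'] and df['Y']
    rng = np.random.default_rng(1)
    xx = rng.uniform(0.0, 392.5, size=200)
    yy = rng.uniform(0.0, 463.5, size=200)
    kde = gaussian_kde(np.vstack((xx, yy)))

    # About one grid node per coordinate unit, as attempted with ceil() in the question
    nx = math.ceil(xx.max() - xx.min())
    ny = math.ceil(yy.max() - yy.min())

    # Ask mgrid for nx and ny points (a complex step, hence the * 1j)
    gx, gy = np.mgrid[xx.min():xx.max():nx * 1j, yy.min():yy.max():ny * 1j]

    # Now the number of evaluated values matches the target shape, so reshape works
    z = kde(np.vstack((gx.ravel(), gy.ravel()))).reshape(nx, ny)
    print(z.shape)
    ```

    Higher resolution means more KDE evaluations (nx * ny of them), so 128 x 128 is often a reasonable compromise for display.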

    Suggested solution:

    The root problem is that px.imshow(z) has no information about x and y. To fix that, we can use the fact that plotly supports xarray: an xarray DataArray lets us attach the (x, y) coordinates to the values in z.

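    import xarray as xr

    # px.imshow puts the first dimension of a DataArray on the vertical axis and the
    # second on the horizontal axis, so transpose z (indexed [x, y] by np.mgrid)
    da = xr.DataArray(
        data=z.T,
        dims=["y", "x"],
        coords=dict(
            x=np.linspace(xx.min(), xx.max(), 128),  # same nodes as xx.min():xx.max():128j
            y=np.linspace(yy.min(), yy.max(), 128),
        ),
    )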
    

    Then pass da instead of z to px.imshow: the axes will show your data coordinates, and the scatter overlay will line up with the density.
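
    If you would rather not add an xarray dependency, px.imshow also accepts x and y coordinate arrays directly. The end-to-end sketch below uses that instead of xarray; it assumes synthetic points in place of your df, and uses np.meshgrid with its default "xy" indexing so that rows of z already follow y and no transpose is needed. origin="lower" puts the smallest y at the bottom; leave it out if you prefer the image-style orientation (y growing downward) that px.imshow uses by default.

    ```python
    import numpy as np
    import plotly.express as px
    import plotly.graph_objects as go
    from scipy.stats import gaussian_kde

    # Hypothetical synthetic points standing in for df['X'] and df['Y']
    rng = np.random.default_rng(2)
    xx = rng.normal(200.0, 40.0, size=400)
    yy = rng.normal(250.0, 60.0, size=400)
    kde = gaussian_kde(np.vstack((xx, yy)))

    n = 128
    xs = np.linspace(xx.min(), xx.max(), n)  # same nodes as xx.min():xx.max():128j
    ys = np.linspace(yy.min(), yy.max(), n)
    gx, gy = np.meshgrid(xs, ys)  # "xy" indexing: rows follow y, columns follow x
    z = kde(np.vstack((gx.ravel(), gy.ravel()))).reshape(gx.shape)

    # Pass the coordinates explicitly so both axes are in data units
    fig = px.imshow(z, x=xs, y=ys, origin="lower")
    fig.add_trace(go.Scatter(x=xx, y=yy, mode="markers", marker=dict(color="green", size=1)))
    # fig.show()
    ```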