Tags: python numpy histogram histogram2d

How do I find the bin with the highest count using np.histogram2d()?


Is there a way to find the bin with the highest count from np.histogram2d()? My code so far is:

counts, xedges, yedges = np.histogram2d(x, y, bins=100)  # x and y are two lists of numbers
print(len(counts), len(xedges), len(yedges))  # 100 101 101

I managed to get the counts, but I am struggling to relate them to the x and y edges.

Thank you.

Update:

I worked it out; any neater solutions are welcome.


Solution

  • To get the maximum count, use counts.max(). To get the indices of the maximum, use np.argmax followed by np.unravel_index, as in np.unravel_index(np.argmax(counts), counts.shape). The first index refers to x and the second to y: the bin at index (i, j) spans xedges[i] to xedges[i+1] in x and yedges[j] to yedges[j+1] in y.

    Here is an example, together with a visualization that shows how everything fits together and lets you check the result. Note that bins=100 generates 100 x 100 = 10000 bins; the example uses just 10 bins in each direction to keep the plot readable.

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    
    N = 200
    x = np.random.uniform(0, 80, N)
    y = np.random.uniform(0, 40, N)
    
    counts, xedges, yedges = np.histogram2d(x, y, bins=(10, 10))
    
    # argmax works on the flattened array; unravel_index converts that flat
    # position back into an (x bin, y bin) index pair
    x_ind, y_ind = np.unravel_index(np.argmax(counts), counts.shape)
    print(f'The maximum count is {counts[x_ind, y_ind]:.0f} at index ({x_ind}, {y_ind})')
    print(f'Between x values {xedges[x_ind]} and {xedges[x_ind+1]}')
    print(f'and between y values {yedges[y_ind]} and {yedges[y_ind+1]}')
    
    fig, (ax1, ax2) = plt.subplots(ncols=2)
    
    ax1.scatter(x,y,marker='.',s=20,lw=0)
    rect = Rectangle((xedges[x_ind], yedges[y_ind]), xedges[x_ind+1] - xedges[x_ind], yedges[y_ind+1] - yedges[y_ind],
                     linewidth=1,edgecolor='crimson',facecolor='none')
    ax1.add_patch(rect)
    ax1.set_title(f'max count: {counts[x_ind, y_ind]:.0f}')
    
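    # counts has x along its first axis; transpose so x runs horizontally,
    # and use origin='lower' so y increases upwards
    ax2.imshow(counts.T, origin='lower')
    ax2.plot(x_ind, y_ind, 'or')  # mark the maximum bin in index coordinates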
    ax2.set_title('heatmap')
    
    plt.show()
    

    [Image: resulting plot]
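If the lookup is needed repeatedly, the same argmax/unravel_index idea can be wrapped in a small function. This is a sketch under assumptions: `max_bin` and `all_max_bins` are hypothetical helper names, and the sample points are made up for illustration. Note that np.argmax returns only the first maximum it finds, so ties are silently dropped; `np.argwhere(counts == counts.max())` lists every bin that reaches the maximum instead.

```python
import numpy as np


def max_bin(x, y, bins=10):
    """Hypothetical helper: return (count, (x_lo, x_hi), (y_lo, y_hi))
    for the most populated bin of np.histogram2d(x, y, bins=bins).

    If several bins share the maximum, the first one in row-major
    (C) order wins, because np.argmax returns the first occurrence.
    """
    counts, xedges, yedges = np.histogram2d(x, y, bins=bins)
    # counts[i, j] counts points with xedges[i] <= x < xedges[i+1]
    # and yedges[j] <= y < yedges[j+1] (the last bin in each
    # direction also includes its right edge).
    i, j = np.unravel_index(np.argmax(counts), counts.shape)
    return counts[i, j], (xedges[i], xedges[i + 1]), (yedges[j], yedges[j + 1])


def all_max_bins(x, y, bins=10):
    """Hypothetical helper: return the (i, j) indices of every bin
    that attains the maximum count, so ties are not lost."""
    counts, _, _ = np.histogram2d(x, y, bins=bins)
    return np.argwhere(counts == counts.max())


# Made-up sample: three of the five points share one cell.
x = [0.1, 0.15, 0.12, 0.9, 0.5]
y = [0.2, 0.25, 0.22, 0.8, 0.5]

count, (x_lo, x_hi), (y_lo, y_hi) = max_bin(x, y, bins=4)
print(f'max count {count:.0f} in x [{x_lo:.2f}, {x_hi:.2f}], y [{y_lo:.2f}, {y_hi:.2f}]')
print(all_max_bins(x, y, bins=4))
```

Returning the edges directly, rather than the indices, avoids having to remember that counts has one fewer entry per axis than xedges and yedges.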