Tags: python, numpy, matplotlib, heatmap

How do I locate local minima in a 2D array?


I have code that plots a heatmap for a set of data given as (x, y, f(x, y)) triples, and I want to find the local minimum points.

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata

# delimiter=None (the default) splits on any run of whitespace, so the file
# may use tabs or spaces; skip_header=1 skips the file's header line
data = np.genfromtxt('data.dat', skip_header=1)

# column 1 is x, column 0 is y; convert both from radians to degrees
x, y, z = data[:, 1], data[:, 0], data[:, 2]
x, y, z = np.degrees(x), np.degrees(y), z - z.min()

# regular 1000 x 1000 grid spanning the data
xi, yi = np.linspace(x.max(), x.min(), 1000), np.linspace(y.max(), y.min(), 1000)
xi, yi = np.meshgrid(xi, yi)

# linear interpolation onto the grid (NaN outside the data's convex hull)
zi = griddata((x, y), z, (xi, yi), method='linear')

plt.figure(figsize=(10, 5))
plt.pcolormesh(xi, yi, zi, shading='auto', cmap='jet')
plt.colorbar(label='Heatmap')
plt.gca().invert_yaxis()

plt.show()

Here's code to generate some fake data:

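import math

# write a 20 x 20 grid of samples as tab-separated "x y z" rows
with open('data.dat', 'w') as arquivo:
    arquivo.write("x\t\ty\t\tz\n")  # header line, skipped by skip_header=1

    for x in range(20):
        for y in range(20):
            z = -math.exp(math.sin(x*y) - math.cos(y))
            arquivo.write(f"{x}\t\t{y}\t\t{z}\n")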

[Image: heatmap example with the minimum points circled]

I tried to use np.gradient, thinking that by taking two derivatives I could identify the local minimum points (zero first derivative and positive second derivative), but I could not get any of it to work.
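For a smooth function of two variables, the derivative test needs the mixed partial derivative as well, not just two separate second derivatives. A point is a strict local minimum if:

$$
\frac{\partial f}{\partial x} = \frac{\partial f}{\partial y} = 0, \qquad
f_{xx} > 0, \qquad
f_{xx}\, f_{yy} - f_{xy}^{2} > 0
$$

On sampled data such as `zi`, the finite-difference gradient returned by `np.gradient` is almost never exactly zero at a grid point, which is one reason this test is hard to apply directly; comparing each point with its neighbors sidesteps that problem.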


Solution

  • To find a local minimum, one usually uses a gradient-based optimizer such as gradient descent. However, finding all local minima that way is hard unless you restart it many times from different starting points (usually people are happy with a single local minimum). A straightforward approach for your problem is to scan the grid: a point that is strictly smaller than every neighbor around it is a local minimum. The code snippet is below:

    # zi, xi and yi are the gridded arrays from the question's griddata call

    # Return the in-bounds 8-connected neighbors of grid point (i, j)
    def get_neighbors(i, j, shape):
        neighbors = []
    
        for di in [-1, 0, 1]:
            for dj in [-1, 0, 1]:
                ni, nj = i + di, j + dj
    
                if (0 <= ni < shape[0]) and (0 <= nj < shape[1]) and (di, dj) != (0, 0):
                    neighbors.append((ni, nj))
    
        return neighbors
    
    local_minima = []
    
    # Iterate over the 2D grid (pure Python, so this is slow on a 1000 x 1000 grid)
    for i in range(zi.shape[0]):
        for j in range(zi.shape[1]):
            current_value = zi[i, j]
            neighbors = get_neighbors(i, j, zi.shape)
    
            # Local minimum: strictly less than all its neighbors. Comparisons with
            # NaN (outside griddata's convex hull) are False, so NaN cells and their
            # neighbors are never reported; flat minima spanning several equal
            # cells are also missed because of the strict "<".
            if all(current_value < zi[n[0], n[1]] for n in neighbors):
                local_minima.append((xi[i, j], yi[i, j], current_value))

    # Print the local minima
    for loc in local_minima:
        print(f"Local minimum value {loc[2]} at location ({loc[0]}, {loc[1]}).")
    

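    Then mark the local minima on the plot. Run this before the plt.show() call in the question's code; after plt.show() returns, the figure is closed and plt.scatter would draw on a new, empty figure: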

    # Marking all the local minima on the plot
    for loc in local_minima:
        plt.scatter(loc[0], loc[1], color='red', s=100, marker='x')
    

    [Image: sample output, the heatmap with the detected local minima marked]
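The double loop in the answer compares every cell with its neighbors in pure Python, which is slow on a 1000 x 1000 grid. A vectorized sketch of the same strict 8-neighbor test follows; it swaps in SciPy's `scipy.ndimage.minimum_filter` (not part of the original answer), and the helper name `local_minima_mask` is hypothetical:

```python
import numpy as np
from scipy.ndimage import minimum_filter


def local_minima_mask(z):
    """Hypothetical helper: True where a cell is strictly below all 8 neighbors."""
    # Assumption: NaNs (e.g. griddata output outside the convex hull) are
    # treated as +inf, so they are never minima and never block a neighbor.
    filled = np.where(np.isnan(z), np.inf, z)

    # 3x3 footprint that excludes the centre cell itself
    footprint = np.ones((3, 3), dtype=bool)
    footprint[1, 1] = False

    # Smallest neighbor of every cell; out-of-bounds neighbors are +inf,
    # so edge cells are compared only with the neighbors that exist
    neighbor_min = minimum_filter(filled, footprint=footprint,
                                  mode='constant', cval=np.inf)

    # Strict "<", matching the loop: equal-valued plateaus are not reported
    return filled < neighbor_min
```

With the question's arrays, `i, j = np.nonzero(local_minima_mask(zi))` gives the row and column indices of the minima, and `xi[i, j]`, `yi[i, j]`, `zi[i, j]` their coordinates and values, which can be passed to a single `plt.scatter` call. One deliberate difference from the loop: here a cell next to the NaN region can still be a minimum, whereas the loop rejects it because any comparison with NaN is False.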