I have code that plots a heatmap for a set of data points of the form (x, y, f(x, y)), and I want to find the local minima.
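```python
import numpy as np
import math
import matplotlib.pyplot as plt
import matplotlib as mpl
from scipy.interpolate import griddata

# delimiter=None splits on any run of whitespace, so space- and tab-separated files both parse
data = np.genfromtxt('data.dat',
                     skip_header=1,  # skips one header line; use 0 for the fake data below, which has none
                     delimiter=None)
x, y, z = data[:, 1], data[:, 0], data[:, 2]
# Convert radians to degrees and shift z so that its minimum is 0
x, y, z = x*180/math.pi, y*180/math.pi, z - min(z)

# Interpolate the scattered points onto a regular 1000 x 1000 grid
xi, yi = np.linspace(max(x), min(x), 1000), np.linspace(max(y), min(y), 1000)
xi, yi = np.meshgrid(xi, yi)
zi = griddata((x, y), z, (xi, yi), method='linear')  # NaN outside the convex hull of the data

plt.figure(figsize=(10, 5))
plt.pcolormesh(xi, yi, zi, shading='auto', cmap='jet')
plt.colorbar(label='Heatmap')
plt.gca().invert_yaxis()
plt.show()
```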
Here's code to generate some fake data:
```python
import math

with open('data.dat', 'w') as arquivo:
    for x in range(20):
        for y in range(20):
            z = -math.exp(math.sin(x*y) - math.cos(y))
            arquivo.write(f"{x}\t\t{y}\t\t{z}\n")
```
[Figure: heatmap example with the minimum points circled]
I tried to use np.gradient, thinking that by taking two derivatives I could locate the local minima (first derivative zero and second derivative positive), but I was not able to get any of it to work.
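For a function of two variables, the derivative test described above is the standard second-derivative test: a point where both first partial derivatives vanish is a local minimum when the Hessian is positive definite, i.e. when the conditions below hold. On a sampled grid, np.gradient only returns finite-difference approximations, so the first derivatives are rarely exactly zero at any grid point; that is one plausible reason the approach is awkward in practice without a tolerance.

```latex
\frac{\partial f}{\partial x} = \frac{\partial f}{\partial y} = 0, \qquad
\frac{\partial^2 f}{\partial x^2} > 0, \qquad
\frac{\partial^2 f}{\partial x^2}\,\frac{\partial^2 f}{\partial y^2}
  - \left(\frac{\partial^2 f}{\partial x\,\partial y}\right)^2 > 0
```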
To find local minima, we usually use gradient-based optimization such as gradient descent. However, it is hard to find all local minima this way unless you do many restarts from different starting points (in general, people are happy with a single local minimum). One straightforward approach to your problem is a grid search: a grid point is a local minimum if its value is lower than the values of all of its neighbors. Here is a code snippet:
```python
# Return the in-bounds neighbors (up to 8) of grid cell (i, j)
def get_neighbors(i, j, shape):
    neighbors = []
    for di in [-1, 0, 1]:
        for dj in [-1, 0, 1]:
            ni, nj = i + di, j + dj
            # Skip the cell itself and anything outside the array
            if (0 <= ni < shape[0]) and (0 <= nj < shape[1]) and (di, dj) != (0, 0):
                neighbors.append((ni, nj))
    return neighbors

local_minima = []
# Iterate over the 2D grid (xi, yi, zi come from the question's code)
for i in range(zi.shape[0]):
    for j in range(zi.shape[1]):
        current_value = zi[i, j]
        neighbors = get_neighbors(i, j, zi.shape)
        # Check if the current point is strictly less than all its neighbors
        # (any comparison involving a NaN cell outside the data's hull is False)
        if all(current_value < zi[n[0], n[1]] for n in neighbors):
            local_minima.append((xi[i, j], yi[i, j], current_value))

# Print the local minima
for loc in local_minima:
    print(f"Local minimum value {loc[2]} at location ({loc[0]}, {loc[1]}).")
```
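On a 1000 x 1000 grid the pure-Python double loop can be slow. A vectorized sketch of the same neighbor comparison (a substitution on my part, not the loop above) uses scipy.ndimage.minimum_filter with a 3 x 3 footprint that excludes the center cell, so each cell is compared against the minimum of its eight neighbors. The helper name local_minima_mask is hypothetical. One assumption differs slightly from the loop: NaN cells are treated as +inf, so a finite cell bordering the NaN region is compared only against its finite neighbors.

```python
import numpy as np
from scipy.ndimage import minimum_filter


def local_minima_mask(z):
    """Hypothetical helper: mask of cells strictly lower than all in-bounds neighbors."""
    z = np.where(np.isnan(z), np.inf, z)  # treat NaN cells (outside the data hull) as +inf
    footprint = np.ones((3, 3), dtype=bool)
    footprint[1, 1] = False  # compare against the 8 neighbors, not the cell itself
    # Cells beyond the array edge are padded with +inf, so they never win the comparison
    neighbor_min = minimum_filter(z, footprint=footprint, mode="constant", cval=np.inf)
    return np.isfinite(z) & (z < neighbor_min)


# Small check: two isolated dips at (1, 1) and (3, 3) in an otherwise flat grid
demo = np.full((5, 5), 5.0)
demo[1, 1], demo[3, 3] = 1.0, 0.0
print(np.argwhere(local_minima_mask(demo)))
```

With the question's arrays, mask = local_minima_mask(zi) followed by local_minima = list(zip(xi[mask], yi[mask], zi[mask])) would produce a list of (x, y, value) tuples in the same shape as the loop's output.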
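Then mark the local minima on the plot. Put these calls before plt.show() in the question's code; otherwise they end up on a new, empty figure.
```python
# Mark all the local minima on the heatmap (place this before plt.show())
for loc in local_minima:
    plt.scatter(loc[0], loc[1], color='red', s=100, marker='x')
```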
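For comparison, the gradient-based route with restarts mentioned at the start of this answer can be sketched with scipy.optimize.minimize, run from several starting points, keeping the distinct converged points. This is a minimal sketch under stated assumptions: the toy function f, its gradient grad_f, the helper multistart_minima and the hand-picked starts are all hypothetical stand-ins; on real data f would have to be a smooth interpolant of the gridded values.

```python
import numpy as np
from scipy.optimize import minimize


def f(p):
    """Hypothetical toy function with two local minima, at (-1, 0) and (1, 0)."""
    x, y = p
    return (x**2 - 1)**2 + y**2


def grad_f(p):
    """Analytic gradient of the toy function f."""
    x, y = p
    return np.array([4 * x * (x**2 - 1), 2 * y])


def multistart_minima(func, jac, starts, decimals=3):
    """Run BFGS from each start and collect the distinct local minima it reaches."""
    found = set()
    for x0 in starts:
        res = minimize(func, x0, jac=jac, method="BFGS")
        if res.success:
            # Round so runs that converge into the same basin collapse to one point
            found.add(tuple(round(float(v), decimals) for v in res.x))
    return found


# Hypothetical starting points spread over both basins
starts = [(-1.5, -0.5), (-0.5, -1.0), (0.5, 1.0), (1.5, 0.5)]
print(sorted(multistart_minima(f, grad_f, starts)))
```

Each run only finds the minimum of the basin it starts in, which is why the number and placement of restarts decide how many minima you recover; the grid search above avoids that choice by checking every cell.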