I am working on a piecewise linear interpolation problem in Python where I optimize the placement of support points (in x- and y-direction) to fit some data. The 1D version of my code works as expected, distributing the support points nicely to fit the data, especially placing more support points in areas with smaller bending radii.
import numpy as np
from scipy.interpolate import interpn
from scipy.optimize import minimize
import matplotlib.pyplot as plt
# Generate test data: a sine curve sampled symmetrically around the origin
# (roughly one half period, from about -pi/2 to +pi/2).
x_data = np.linspace(start=-3.15 / 2, stop=3.15 / 2, num=800)
y_data = np.sin(x_data)
N = 8  # number of support points
def piecewise_linear_interp_error(support_points, x_data, y_data):
    """Return the RMS error of a piecewise-linear fit through the support points.

    Parameters
    ----------
    support_points : array-like, shape (2 * n,)
        Flat optimisation vector: the first half holds the x-coordinates of
        the support points, the second half the matching y-coordinates.
        The number of support points is taken from the vector itself, so the
        function does not depend on any module-level constant.
    x_data, y_data : array-like, shape (m,)
        Samples of the curve that should be approximated.

    Returns
    -------
    float
        Root-mean-square difference between ``y_data`` and the interpolant
        evaluated at ``x_data``.

    Raises
    ------
    ValueError
        If ``support_points`` is not one-dimensional with an even length.
    """
    support_points = np.asarray(support_points, dtype=float)
    if support_points.ndim != 1 or support_points.size % 2:
        raise ValueError(
            "support_points must be a flat vector of x-coordinates followed "
            "by the same number of y-coordinates"
        )
    n_support = support_points.size // 2
    x_support = support_points[:n_support]
    y_support = support_points[n_support:]

    # The optimiser is free to move points past each other, but interpn needs
    # an ordered grid, so sort x and carry the y-values along.
    order = np.argsort(x_support)
    grid = (x_support[order],)
    points = np.asarray(x_data)[:, None]

    # fill_value=None makes interpn extrapolate outside the support range
    # instead of returning NaN.
    y_interp = interpn(grid, y_support[order], points, method='linear',
                       bounds_error=False, fill_value=None).flatten()

    # Root-mean-square error between data and interpolant.
    return np.sqrt(np.mean((np.asarray(y_data) - y_interp) ** 2))
# Initial guess for the optimisation: evenly spaced support points that lie
# on the data.
initial_x_support = np.linspace(np.min(x_data), np.max(x_data), N)
initial_y_support = np.interp(initial_x_support, x_data, y_data)
initial_support_points = np.concatenate([initial_x_support, initial_y_support])

result = minimize(piecewise_linear_interp_error, initial_support_points,
                  args=(x_data, y_data), method='SLSQP')

# The objective sorts the support points internally, so the optimiser may
# return them in any order. Sort them here so the connected line plot and the
# printed list follow increasing x.
support_order = np.argsort(result.x[:N])
optimized_x_support = result.x[:N][support_order]
optimized_y_support = result.x[N:][support_order]

plt.plot(x_data, y_data, label='Original Data')
plt.plot(optimized_x_support, optimized_y_support, 'ro-', label='Optimized Support Points')
plt.legend()
plt.show()

print("Optimized Support Points:")
for x, y in zip(optimized_x_support, optimized_y_support):
    print(f"({x:.2f}, {y:.2f})")
I modeled the 1-D and n-D versions to be similar. Most of the extra code is for piecing together and splitting the optimization vector, and for handling the grid in the multidimensional case.
import numpy as np
from scipy.interpolate import interpn
from scipy.optimize import minimize
import matplotlib.pyplot as plt
def display_support_points(src_grid, target_vec):
    """Print every support point as ``(x1, ..., xn, value)``.

    Parameters
    ----------
    src_grid : array-like, shape (n_points, n_dimensions)
        Support point coordinates, one row per point. The rows are expected
        in the order produced by
        ``PiecewiseLinearInterpolation.optimize_support_points``, i.e. with
        the first dimension varying fastest.
    target_vec : array-like
        Support values, either flat or shaped like the support grid
        (one axis per dimension).

    Raises
    ------
    ValueError
        If the number of values differs from the number of grid rows.
    """
    src_grid = np.asarray(src_grid)
    # Flatten with the first axis varying fastest so that entry k belongs to
    # grid row k. For 1-D input this is a no-op.
    flat_targets = np.ravel(target_vec, order='F')
    if len(src_grid) != flat_targets.size:
        raise ValueError(
            f"src_grid has {len(src_grid)} points but target_vec holds "
            f"{flat_targets.size} values"
        )

    print("Optimized Support Points:")
    for coords, value in zip(src_grid, flat_targets):
        formatted_points = ', '.join(f'{p:.2f}' for p in coords)
        print(f"({formatted_points}, {value})")
class PiecewiseLinearInterpolation:
    """Fit an n-dimensional piecewise-linear interpolant with movable support points.

    The interpolant lives on a rectilinear grid. Both the grid coordinates
    along every dimension and the values at the grid nodes are free
    parameters; they are packed into one flat "support vector" laid out as::

        [axis_0 coords | axis_1 coords | ... | node values (C order)]

    Parameters
    ----------
    source_values : ndarray, shape (n_samples, n_dimensions)
        Coordinates of the data samples. The samples must cover a full
        rectilinear grid (every combination of the per-dimension coordinates
        occurs exactly once); their order does not matter.
    target_values : ndarray, shape (n_samples,)
        Data value for each sample.
    source_resolution : sequence of int
        Number of support points to use along each dimension.
    """

    def __init__(self, source_values, target_values, source_resolution):
        self.n_dimensions = source_values.shape[1]
        self.src_vals = source_values
        self.target_values = target_values
        self.src_resolution = source_resolution  # support points per dimension
        self.src_vec_shape = None  # data grid shape, set while building the start vector
        self.initial_support_vector = self.generate_inital_support_vector()

    def generate_inital_support_vector(self):
        """Build the start vector: evenly spaced axes with values taken from the data.

        Returns
        -------
        ndarray
            Flat support vector in the layout expected by
            ``split_support_vector``.

        Raises
        ------
        ValueError
            If the samples do not form a complete rectilinear grid.
        """
        # Evenly spaced support coordinates spanning the data range.
        support_axes = [
            np.linspace(np.min(column), np.max(column), res)
            for column, res in zip(self.src_vals.T, self.src_resolution)
        ]

        # Recover the grid underlying the samples. ``inverse`` gives, for each
        # sample, its index along that dimension, so the values can be placed
        # on the grid regardless of the order the samples arrive in.
        data_axes = []
        grid_index = []
        for column in self.src_vals.T:
            axis_values, inverse = np.unique(column, return_inverse=True)
            data_axes.append(axis_values)
            grid_index.append(np.ravel(inverse))
        self.src_vec_shape = [len(axis) for axis in data_axes]

        targets = np.asarray(self.target_values, dtype=float).ravel()
        gridded_targets = np.full(self.src_vec_shape, np.nan)
        if targets.size == gridded_targets.size:
            gridded_targets[tuple(grid_index)] = targets
        if targets.size != gridded_targets.size or np.isnan(gridded_targets).any():
            raise ValueError(
                "source_values/target_values must cover a complete rectilinear "
                f"grid of shape {tuple(self.src_vec_shape)}"
            )

        # Query points in C order (last dimension varies fastest) so that the
        # flattened result matches the reshape done in split_support_vector.
        mesh = np.meshgrid(*support_axes, indexing='ij')
        query_points = np.stack(mesh, axis=-1).reshape(-1, self.n_dimensions)

        initial_target_support = interpn(
            data_axes,            # grid coordinates of the data
            gridded_targets,      # data values on that grid
            query_points,         # nodes of the initial support grid
            method='linear',
            bounds_error=False,   # do not raise for points outside the grid
            fill_value=None,      # None means: extrapolate instead of NaN
        ).ravel()
        return np.concatenate(support_axes + [initial_target_support])

    def calc_interpolation_error(self, support_vector):
        """Objective function: root of the summed squared residuals.

        The interpolant described by ``support_vector`` is evaluated at every
        data sample and compared with ``target_values``.

        NOTE(review): the support coordinates are sorted before use, but
        coinciding coordinates are not handled here; check interpn's
        requirements on the grid if the optimiser collapses two points.
        """
        src_vec, target_vec = self.split_support_vector(support_vector)
        src_sorted, target_sorted = self.reorder_support_vectors(src_vec, target_vec)
        points = np.asarray(self.src_vals).reshape(-1, self.n_dimensions)
        target_interp = interpn(src_sorted, target_sorted, points, method='linear',
                                bounds_error=False, fill_value=None)
        return np.sqrt(np.sum((self.target_values - target_interp) ** 2))

    def reorder_support_vectors(self, src_vec, target_vec):
        """Sort each axis ascending and permute the node values accordingly.

        Parameters
        ----------
        src_vec : sequence of 1-D ndarray
            Support coordinates per dimension.
        target_vec : ndarray
            Node values with one axis per dimension.

        Returns
        -------
        (list of ndarray, ndarray)
            Sorted copies; the inputs are left untouched.
        """
        sorted_src = []
        for axis, coords in enumerate(src_vec):
            order = np.argsort(coords)
            sorted_src.append(coords[order])
            target_vec = np.take(target_vec, order, axis=axis)
        return sorted_src, target_vec

    def split_support_vector(self, support_vec):
        """Split a flat support vector into per-axis coordinates and node values.

        Returns
        -------
        (list of ndarray, ndarray)
            One coordinate array per dimension, and the node values reshaped
            (C order) to ``tuple(src_resolution)``.
        """
        boundaries = np.cumsum(self.src_resolution)
        *src_vec, target_vec = np.split(np.asarray(support_vec), boundaries)
        return src_vec, target_vec.reshape(tuple(self.src_resolution))

    def optimize_support_points(self):
        """Minimise the interpolation error with SLSQP.

        Returns
        -------
        result_grid : ndarray, shape (n_points, n_dimensions)
            Coordinates of all support grid nodes. Rows are ordered with the
            FIRST dimension varying fastest.
        target : ndarray, shape ``tuple(src_resolution)``
            Node values indexed by (axis_0, axis_1, ...). To pair them with
            the rows of ``result_grid`` flatten with
            ``target.ravel(order='F')``.
        """
        result = minimize(self.calc_interpolation_error, self.initial_support_vector,
                          method='SLSQP')
        result_src_vec, result_target_vec = self.split_support_vector(result.x)
        res_src_vec_sorted, res_target_vec_sorted = self.reorder_support_vectors(
            result_src_vec, result_target_vec)
        # .T reverses all axes of the stacked meshgrid, which is what makes
        # the first dimension the fastest-varying one in the row order.
        result_grid = np.array(
            np.meshgrid(*res_src_vec_sorted, indexing='ij')
        ).T.reshape(-1, self.n_dimensions)
        return result_grid, res_target_vec_sorted

    def display_results(self, src_grid, target_vec):
        """Plot the optimised support points (and, in 2-D, the original data).

        Parameters
        ----------
        src_grid, target_vec
            The pair returned by ``optimize_support_points``.

        Raises
        ------
        ValueError
            If the grid has more than two dimensions.
        """
        plt.figure(figsize=(10, 8))
        # Flatten the node values in the same order as the grid rows
        # (first dimension fastest); otherwise values end up at wrong nodes.
        flat_targets = np.ravel(target_vec, order='F')
        if src_grid.shape[1] == 1:
            plt.plot(src_grid[:, 0], flat_targets, 'ro-', label='Optimized Support Points')
        elif src_grid.shape[1] == 2:
            ax = plt.axes(projection='3d')
            ax.scatter(self.src_vals[:, 0], self.src_vals[:, 1], self.target_values,
                       label='Original Data')
            ax.scatter(src_grid[:, 0], src_grid[:, 1], flat_targets,
                       color='red', label='Optimized Support Points')
        else:
            raise ValueError("Plotting for dimensions higher than 2 is not supported.")
        plt.legend()
        plt.show()
# Example usage of PiecewiseLinearInterpolation.
def main():
    """Fit a 5 x 10 support grid to sin(x) + sin(y) sampled on a 20 x 20 grid and plot it."""
    # Two-dimensional sample coordinates covering roughly [-pi/2, pi/2]^2.
    half_span = 3.15 / 2
    x = np.linspace(-half_span, half_span, 20)
    y = np.linspace(-half_span, half_span, 20)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(X) + np.sin(Y)

    # One row per sample: (x, y) coordinates and the matching value.
    raw_src_vec = np.stack([X.ravel(), Y.ravel()], axis=1)
    raw_target_vec = Z.ravel()

    # Number of support points along each dimension.
    x_resolutions = [5, 10]

    interpolator = PiecewiseLinearInterpolation(raw_src_vec, raw_target_vec, x_resolutions)
    src_grid, target_vec = interpolator.optimize_support_points()
    interpolator.display_results(src_grid, target_vec)
    # The support points can also be printed with display_support_points.


if __name__ == "__main__":
    main()
To me, it seems as if the optimizer messed it up; however, it is more likely that I messed up with the multi-dimensional stuff. I wrote tests and verified pieces that seemed suspicious, and those were correct. How can I streamline the code to be more straightforward? Where is my data corruption?
There are two problems.
The first is that you're misunderstanding the meaning of the data (or using the wrong data to plot).
Take another look at the error function which you are trying to minimize.
def calc_interpolation_error(self, support_vector):
    """Return the root of the summed squared interpolation residuals.

    The flat ``support_vector`` is split into grid axes and node values,
    the axes are sorted, and the resulting interpolant is evaluated at every
    sample in ``self.src_vals`` and compared with ``self.target_values``.
    """
    axes, values = self.split_support_vector(support_vector)
    axes, values = self.reorder_support_vectors(axes, values)

    # One row per sample, one column per dimension.
    sample_points = np.array(self.src_vals).reshape(-1, self.n_dimensions)

    predicted = interpn(
        axes,
        values,
        sample_points,
        method='linear',
        bounds_error=False,
        fill_value=None,  # extrapolate outside the support grid
    )
    residuals = self.target_values - predicted
    return np.sqrt(np.sum(residuals ** 2))
This function calculates the error between the target values and the output of the `interpn` function. In other words, `support_vector` is just a set of parameters that lets `interpn` predict the targets correctly; there is no guarantee that the support points themselves will line up with the original data.
The following plots the output of the `interpn` function for randomly generated points.
def display_results(self, src_grid, target_vec):
    """Plot the support points and, for 2-D input, the data and interpolant samples.

    In the 2-D case three scatter series are drawn: the original data, the
    support points, and the interpolant evaluated at 1000 pseudo-random
    points (fixed seed 0) inside the bounding box of the data.

    Raises ValueError when the grid has more than two dimensions.
    """
    plt.figure(figsize=(10, 8))
    n_dims = src_grid.shape[1]

    if n_dims == 1:
        plt.plot(src_grid[:, 0], target_vec, "ro-", label="Optimized Support Points")
    elif n_dims == 2:
        ax = plt.axes(projection="3d")
        data_x = self.src_vals[:, 0]
        data_y = self.src_vals[:, 1]
        ax.scatter(data_x, data_y, self.target_values, label="Original Data")
        # NOTE(review): this pairs the grid rows with a plain C-order ravel of
        # target_vec; the accompanying text says the values need a transpose
        # to line up with the grid rows.
        ax.scatter(src_grid[:, 0], src_grid[:, 1], target_vec.ravel(),
                   color="red", label="Optimized Support Points")

        # Random evaluation points inside the bounding box of the data.
        lower = np.array([data_x.min(), data_y.min()])
        upper = np.array([data_x.max(), data_y.max()])
        n = 1000
        points = np.random.default_rng(0).random((n, 2)) * (upper - lower) + lower

        # Rebuild the per-axis support coordinates from the flat grid: the
        # first block of rows walks along dimension 0, and every
        # src_resolution[0]-th row steps along dimension 1.
        n_first = self.src_resolution[0]
        support_axes = (src_grid[:n_first, 0], src_grid[::n_first, 1])
        predicted_data = interpn(support_axes, target_vec, points, method="linear",
                                 bounds_error=False, fill_value=None)

        # Show what the interpolant actually predicts.
        ax.scatter(points[:, 0], points[:, 1], predicted_data.ravel(),
                   color="orange", label="Predicted Data")
    else:
        raise ValueError("Plotting for dimensions higher than 2 is not supported.")

    plt.legend()
    plt.show()
Here is the result. The orange dots are predicted data using support points. As I will explain later, please ignore the support points for now.
As you can see, the results are fairly good.
The second problem is much simpler: there is a bug in the plot. `target_vec` should be transposed.
# The rows of src_grid enumerate the first dimension fastest, while
# target_vec is indexed as (dim 0, dim 1); transposing before ravel() puts
# the values in the same order as the grid rows.
ax.scatter(src_grid[:, 0], src_grid[:, 1], target_vec.T.ravel(),
           color='red', label='Optimized Support Points')
I believe this is (roughly) what you were expecting.
Note that if you analyze the result carefully, you will find that half of the support points are not being used because they are outside the area. This may be solved by limiting the range of the support points or by allowing the error function to evaluate points outside the source grid. But this is something you will have to experiment with, so I'll leave it at that.
As a side note, using `nearest` interpolation instead of `linear` made it easier to understand what is happening, and easier to debug.
This may also help your further development :)