Tags: python, optimization, lmfit

Efficient stitching of datasets


I have multiple measurement datasets that I want to combine into a single dataset. While I have a working solution, it is terribly inefficient, and I would appreciate some tips on how to improve it.

Think of the measurements as multiple height maps of one object that I want to combine into a single height map. My measurements are not perfect and may have some tilt and height offset. Let's assume (for now) that we know the x-y positions perfectly. Here is an example:

import numpy as np
import matplotlib.pyplot as plt

def height_profile(x, y):
    radius = 100
    return np.sqrt(radius**2-x**2-y**2)-radius

np.random.seed(123)

datasets = {}

# DATASET 1
x = np.arange(-8, 2.01, 0.1)
y = np.arange(-3, 7.01, 0.1)

xx, yy = np.meshgrid(x, y)
# height is the actual profile + noise
zz = height_profile(xx, yy) + np.random.randn(*xx.shape)*0.001

datasets[1] = [xx, yy, zz]

plt.figure()
plt.pcolormesh(*datasets[1])
plt.colorbar()

# DATASET 2
x = np.arange(-2, 8.01, 0.1)
y = np.arange(-3, 7.01, 0.1)

xx, yy = np.meshgrid(x, y)
# height is the actual profile + noise + random offset + random tilt
zz = height_profile(xx, yy) + np.random.randn(*xx.shape)*0.001 + np.random.rand() + np.random.rand()*xx*0.1 + np.random.rand()*yy*0.1

datasets[2] = [xx, yy, zz]

plt.figure()
plt.pcolormesh(*datasets[2])
plt.colorbar()

# DATASET 3
x = np.arange(-5, 5.01, 0.1)
y = np.arange(-7, 3.01, 0.1)

xx, yy = np.meshgrid(x, y)
# height is the actual profile + noise + random offset + random tilt
zz = height_profile(xx, yy) + np.random.randn(*xx.shape)*0.001 + np.random.rand() + np.random.rand()*xx*0.1 + np.random.rand()*yy*0.1

datasets[3] = [xx, yy, zz]

plt.figure()
plt.pcolormesh(*datasets[3])
plt.colorbar()

To combine the three (or more) datasets, I use the following strategy: for each pair of datasets, interpolate one onto the x-y positions of the other to find the overlap, compute the height differences in the overlap region (residual_overlap), and minimize those differences over all pairs (residual) with lmfit, which minimizes the sum of their squares. A dedicated function applies the transformations (tilt, offset, etc.) to each dataset.

from lmfit import minimize, Parameters
from copy import deepcopy
from itertools import combinations
from scipy.interpolate import griddata

def data_transformation(dataset, idx, params):
    dataset = deepcopy(dataset)

    def param_value(name):
        # Parameters that were not added to `params` default to 0
        key = '{}_{}'.format(name, idx)
        return params[key].value if key in params else 0

    x_offset = param_value('x_offset')
    y_offset = param_value('y_offset')
    x_tilt = param_value('tilt_x')
    y_tilt = param_value('tilt_y')
    piston = param_value('piston')

    _x = dataset[0] - np.mean(dataset[0])
    _y = dataset[1] - np.mean(dataset[1])

    dataset[0] = dataset[0] + x_offset
    dataset[1] = dataset[1] + y_offset
    dataset[2] = dataset[2] + 2 * (x_tilt * _x + y_tilt * _y) + piston

    return dataset

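def residual_overlap(dataset_0, dataset_1):
    # Interpolate dataset_0's heights onto dataset_1's x-y positions and
    # subtract dataset_1's heights. griddata returns NaN for points outside
    # the convex hull of dataset_0, i.e. outside the overlap region.
    xy_0 = np.stack((dataset_0[0].flatten(), dataset_0[1].flatten()), axis=1)
    xy_1 = np.stack((dataset_1[0].flatten(), dataset_1[1].flatten()), axis=1)
    difference = griddata(xy_0, dataset_0[2].flatten(), xy_1) - \
                 dataset_1[2].flatten()

    return difference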

def residual(params, datasets):
    datasets = deepcopy(datasets)

    for idx in datasets:
        datasets[idx] = data_transformation(
            datasets[idx], idx, params)

    residuals = []

    for combination in combinations(list(datasets), 2):
        residuals.append(residual_overlap(
            datasets[combination[0]], datasets[combination[1]]))

    residuals = np.concatenate(residuals)
    # Points outside the overlap (NaN) contribute nothing to the fit
    residuals[np.isnan(residuals)] = 0

    return residuals

def minimize_datasets(params, datasets, **minimizer_kw):
    datasets = deepcopy(datasets)

    # residual already has the signature lmfit expects: residual(params, *args)
    min_result = minimize(residual, params,
                          args=(datasets, ), **minimizer_kw)

    return min_result
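Since every dataset in the example lies on a regular, axis-aligned grid (built with np.meshgrid), a regular-grid interpolator is one possible replacement for griddata: it interpolates bilinearly without building a Delaunay triangulation on every call. This is a minimal sketch, assuming the [xx, yy, zz] layout from above (x varies along columns, y along rows) and strictly ascending axes; residual_overlap_regular is a hypothetical name. Note that bilinear interpolation gives slightly different values than griddata's triangle-based 'linear' method.

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

def residual_overlap_regular(dataset_0, dataset_1):
    """Height difference of dataset_0 interpolated onto dataset_1's points.

    Assumes dataset_0 comes from np.meshgrid(x, y), i.e. x changes along
    axis 1 and y along axis 0. Points outside dataset_0's grid yield NaN.
    """
    xx0, yy0, zz0 = dataset_0
    x0 = xx0[0, :]   # 1-D x axis of the grid
    y0 = yy0[:, 0]   # 1-D y axis of the grid
    interp = RegularGridInterpolator((y0, x0), zz0, method='linear',
                                     bounds_error=False, fill_value=np.nan)
    # Query points are (y, x) pairs to match the axis order used above
    pts = np.column_stack((dataset_1[1].ravel(), dataset_1[0].ravel()))
    return interp(pts) - dataset_1[2].ravel()
```

The rest of the pipeline (the NaN-to-zero step in residual) would stay the same, since points outside the grid come back as NaN, just as with griddata.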

I run the "stitching" like this:

params = Parameters()
params.add('tilt_x_2', 0)
params.add('tilt_y_2', 0)
params.add('piston_2', 0)
params.add('tilt_x_3', 0)
params.add('tilt_y_3', 0)
params.add('piston_3', 0)

fit_result = minimize_datasets(params, datasets)

plt.figure()
plt.pcolormesh(*data_transformation(datasets[1], 1, fit_result.params), alpha=0.3, vmin=-0.5, vmax=0)
plt.pcolormesh(*data_transformation(datasets[2], 2, fit_result.params), alpha=0.3, vmin=-0.5, vmax=0)
plt.pcolormesh(*data_transformation(datasets[3], 3, fit_result.params), alpha=0.3, vmin=-0.5, vmax=0)
plt.colorbar()

As you can see, it works, but stitching takes about a minute even for these small datasets on my computer. In reality I have more and larger datasets.

Do you see a way to improve the stitching performance?
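One observation that may matter for performance: with fixed x-y positions, the tilt and piston correction is linear in its parameters, and linear interpolation reproduces planes exactly, so the fit is a linear least-squares problem that does not strictly need an iterative optimizer. The sketch below solves it for one pair only, assuming the overlap samples (x, y, the reference heights interpolated onto those points, and the other dataset's heights) were extracted once beforehand; fit_plane_correction is a hypothetical helper. It uses plain coefficients a, b, c: in terms of data_transformation, a = 2*tilt_x, b = 2*tilt_y, and c absorbs the piston plus the mean-centring term. For several datasets the same idea would lead to one larger stacked linear system.

```python
import numpy as np

def fit_plane_correction(x, y, z_ref, z_other):
    """Least-squares a, b, c such that z_other + a*x + b*y + c ~ z_ref.

    x, y, z_ref, z_other are 1-D arrays of overlap samples; NaN entries
    (e.g. points outside the reference dataset) are ignored.
    """
    valid = ~(np.isnan(z_ref) | np.isnan(z_other))
    # Design matrix: one column per unknown (a, b, c)
    A = np.column_stack((x[valid], y[valid], np.ones(valid.sum())))
    target = z_ref[valid] - z_other[valid]
    coeffs, *_ = np.linalg.lstsq(A, target, rcond=None)
    return coeffs  # array([a, b, c])
```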

Edit: As suggested, I ran a profiler, and it shows that 99.5% of the time is spent in griddata, which I use to interpolate data points from dataset_0 to the locations of dataset_1. If I switch to method='nearest', the execution time drops to about a second, but then no actual interpolation happens. Is there any way to speed up the interpolation?
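Because the x-y positions stay fixed during the fit (only heights change), the expensive part of griddata's 'linear' method, the Delaunay triangulation and locating each target point in it, is identical in every residual call and could be computed once per dataset pair. A sketch of that idea with scipy.spatial.Delaunay: precompute the enclosing triangle's vertex indices and barycentric weights for every target point, after which each interpolation is only a weighted sum. interp_weights and interpolate are hypothetical helper names, and the approach assumes x_offset/y_offset are not fitted, since shifting the points would invalidate the weights.

```python
import numpy as np
from scipy.spatial import Delaunay

def interp_weights(xy_src, xy_dst):
    """Precompute triangle vertices and barycentric weights (done once)."""
    tri = Delaunay(xy_src)
    simplex = tri.find_simplex(xy_dst)          # -1 where outside the hull
    vertices = tri.simplices[simplex]           # (n_dst, 3) vertex indices
    transform = tri.transform[simplex]          # (n_dst, 3, 2) affine maps
    delta = xy_dst - transform[:, 2]            # offset from reference vertex
    bary = np.einsum('njk,nk->nj', transform[:, :2, :], delta)
    weights = np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))
    return vertices, weights, simplex < 0

def interpolate(values, vertices, weights, outside):
    """Linear interpolation with precomputed weights (cheap, done per call)."""
    result = np.einsum('nj,nj->n', values[vertices], weights)
    result[outside] = np.nan                    # mimic griddata's default fill
    return result
```

In residual_overlap, the weights for xy_0 → xy_1 would be computed before calling minimize and passed in, and only `interpolate(z_0, vertices, weights, outside) - z_1` would run per iteration.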


Solution

  • Skimming through the code, I can't really see anywhere to improve, other than that you are calling deepcopy() over and over again.

    However, I would recommend profiling. If you are using PyCharm, you can profile a run with the clock-shaped profiler button next to the run button.

    I am sure other IDEs also have such capabilities. This way you can figure out which function is taking the most time.

    Whole graph: [screenshot of the profiler's call graph]

    When I zoom in on a few functions (I am showing Google Cloud functions): [screenshot of the zoomed-in call graph]

    You can see how many times each function was called, how long it took, etc.

    Long story short, you need a profiler!
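Without an IDE, Python's built-in cProfile and pstats modules give the same kind of per-function timing. A minimal sketch; slow_step, fast_step and workload are hypothetical stand-ins for your own functions (e.g. minimize_datasets).

```python
import cProfile
import io
import pstats

def slow_step(n):
    # Hypothetical stand-in for an expensive call such as griddata
    return sum(i * i for i in range(n))

def fast_step(n):
    # Hypothetical stand-in for a cheap helper
    return n * 2

def workload():
    for _ in range(20):
        slow_step(20_000)
        fast_step(10)

profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

# Sort by cumulative time so the most expensive call chains are listed first
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats('cumulative').print_stats(10)
report = stream.getvalue()
print(report)
```

To profile a whole script without modifying it, `python -m cProfile -s cumulative stitching.py` prints a similar table (stitching.py being a placeholder for your script's name).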