
Numerically finding a curve that satisfies an underdetermined set of equations


[Partial cross-post from Mathematica Stack Exchange. I am opening this here so that I can, in principle, solicit non-Mathematica solutions, ideally in the NumPy/SciPy ecosystem.]

I am trying to determine a phase diagram along the lines of Fig. 3 or Fig. 4 of Deviri and Safran (2021):

[Figure: Fig. 3 of Deviri and Safran (2021)]

[Figure: Fig. 4 of Deviri and Safran (2021)]

The curves plotted are binodals. They are determined by the conditions that the chemical potential of each component is equal in all phases and that the osmotic pressure is equal in all phases (Section 4 of the SI is a good primer).

As an example, I use a free energy function similar to the one defined in the paper. The one I am actually interested in is a bit more complex, so ideally I can come up with a solution that is robust and generally applicable. This is still a good starting point, because my system does reduce to this one under certain limits.

We narrow our attention to the case of two solutes (plus one solvent) and two phases. We then wish to find the locus of points in $(\phi_a,\phi_b)$ for which the following conditions hold:

$$\mu_a^{(1)} (\phi^{(1)}_a,\phi^{(1)}_b) = \mu_a^{(2)} (\phi^{(2)}_a,\phi^{(2)}_b)$$

$$\mu_b^{(1)} (\phi^{(1)}_a,\phi^{(1)}_b) = \mu_b^{(2)} (\phi^{(2)}_a,\phi^{(2)}_b)$$

$$\Pi^{(1)} (\phi^{(1)}_a,\phi^{(1)}_b) = \Pi^{(2)} (\phi^{(2)}_a,\phi^{(2)}_b)$$

Essentially, we have 4 unknowns, $\phi^{(1)}_a,\phi^{(1)}_b,\phi^{(2)}_a,\phi^{(2)}_b$, and 3 equations, so the solution set is (generically) a one-parameter family: a curve traced out by the pair of coexisting points. The solutions should look like the curves in Fig. 3 or Fig. 4 of the linked publication (depending on parameter values). Subscripts in my notation label the components and superscripts label the phase. The conditions above should be equivalent to equations (39)-(41) in the SI.
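One practical consequence of the 3-equations/4-unknowns count: pinning any one unknown, say $\phi^{(1)}_a$, leaves a square 3×3 system that a standard root finder can solve. The sketch below assumes the toy F with uaa, ubb, uab = 0, 0, 4.36 and na, nb = 10, 6, uses closed-form $\mu$ and $\Pi$ that I derived by hand from F (not taken from the paper), and starts from a guess chosen near a coexisting pair printed later in this thread. Note that every "pair" with both phases identical also satisfies all three equations, so the starting guess has to be kept away from that trivial branch.

```python
import numpy as np
from scipy.optimize import fsolve

# Toy-model parameters (Fig. 4-like case from the question).
UAA, UBB, UAB, NA, NB = 0.0, 0.0, 4.36, 10, 6

def mu(a, b):
    """Analytic chemical potentials (dF/da, dF/db) of the toy free energy."""
    ln_s = np.log(1 - a - b)
    mu_a = (np.log(a) + 1) / NA - ln_s - 1 - UAA * a - UAB * b
    mu_b = (np.log(b) + 1) / NB - ln_s - 1 - UBB * b - UAB * a
    return mu_a, mu_b

def Pi(a, b):
    """Osmotic pressure a*mu_a + b*mu_b - F, simplified by hand."""
    return (-np.log(1 - a - b) + a / NA + b / NB - a - b
            - 0.5 * UAA * a**2 - 0.5 * UBB * b**2 - UAB * a * b)

def residual(a1, b1, a2, b2):
    """The three coexistence conditions as a length-3 vector."""
    mua1, mub1 = mu(a1, b1)
    mua2, mub2 = mu(a2, b2)
    return np.array([mua1 - mua2, mub1 - mub2, Pi(a1, b1) - Pi(a2, b2)])

# Pin phi_a^(1); the remaining unknowns (b1, a2, b2) now match the 3 equations.
a1_fixed = 0.1212
guess = [0.02, 0.36, 0.19]  # assumed starting point, far from the trivial branch
b1, a2, b2 = fsolve(lambda v: residual(a1_fixed, *v), guess, xtol=1e-12)
print(f"phase 1: ({a1_fixed}, {b1:.6f})  phase 2: ({a2:.6f}, {b2:.6f})")
```

Stepping `a1_fixed` and reusing each solution as the next guess turns this single solve into a curve tracer.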

The quantities in the above equations can be obtained from derivatives of the free energy, for example in Wolfram Mathematica:

    F[a_, b_, uaa_, ubb_, uab_, na_, nb_] :=
      a/na*Log[a] + b/nb*Log[b] + (1 - a - b)*Log[1 - a - b + 10^-14] -
       1/2*uaa*a^2 - 1/2*ubb*b^2 - uab*b*a

    det[a_, b_, uaa_, ubb_, uab_, na_, nb_] :=
      Det[D[F[a, b, uaa, ubb, uab, na, nb], {{a, b}, 2}]] // Evaluate
    \[Mu]a[a_, b_, uaa_, ubb_, uab_, na_, nb_] :=
      D[F[a, b, uaa, ubb, uab, na, nb], a] // Evaluate
    \[Mu]b[a_, b_, uaa_, ubb_, uab_, na_, nb_] :=
      D[F[a, b, uaa, ubb, uab, na, nb], b] // Evaluate
    \[CapitalPi]F[a_, b_, uaa_, ubb_, uab_, na_, nb_] :=
      \[Mu]a[a, b, uaa, ubb, uab, na, nb]*a + \[Mu]b[a, b, uaa, ubb, uab, na, nb]*b -
       F[a, b, uaa, ubb, uab, na, nb] // Evaluate
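The derivatives Mathematica computes symbolically can also be written out by hand for this toy F: $\mu_a = (\ln\phi_a + 1)/n_a - \ln(1-\phi_a-\phi_b) - 1 - u_{aa}\phi_a - u_{ab}\phi_b$ (symmetrically for $\mu_b$), and $\Pi = \phi_a\mu_a + \phi_b\mu_b - F$ simplifies to $-\ln(1-\phi_a-\phi_b) + \phi_a/n_a + \phi_b/n_b - \phi_a - \phi_b - \tfrac12 u_{aa}\phi_a^2 - \tfrac12 u_{bb}\phi_b^2 - u_{ab}\phi_a\phi_b$. A minimal Python sketch of these closed forms, cross-checked against central differences; the parameter values and test point are arbitrary, and the 1e-14 regularisation is dropped, so it must be evaluated strictly inside the simplex.

```python
import numpy as np

def F(a, b, uaa, ubb, uab, na, nb):
    """Toy Flory-Huggins-like free energy (same form as the Mathematica F)."""
    s = 1 - a - b
    return (a / na * np.log(a) + b / nb * np.log(b) + s * np.log(s)
            - 0.5 * uaa * a**2 - 0.5 * ubb * b**2 - uab * a * b)

def mu_a(a, b, uaa, ubb, uab, na, nb):
    """dF/da in closed form."""
    return (np.log(a) + 1) / na - np.log(1 - a - b) - 1 - uaa * a - uab * b

def mu_b(a, b, uaa, ubb, uab, na, nb):
    """dF/db in closed form."""
    return (np.log(b) + 1) / nb - np.log(1 - a - b) - 1 - ubb * b - uab * a

def Pi(a, b, uaa, ubb, uab, na, nb):
    """a*mu_a + b*mu_b - F, simplified by hand."""
    return (-np.log(1 - a - b) + a / na + b / nb - a - b
            - 0.5 * uaa * a**2 - 0.5 * ubb * b**2 - uab * a * b)

# Cross-check against central differences and against the defining formula for Pi.
p = (1.8, 0.0, 0.0, 10, 6)   # (uaa, ubb, uab, na, nb), Fig. 3-like values
a, b, h = 0.2, 0.3, 1e-6
fd_a = (F(a + h, b, *p) - F(a - h, b, *p)) / (2 * h)
fd_b = (F(a, b + h, *p) - F(a, b - h, *p)) / (2 * h)
pi_def = a * mu_a(a, b, *p) + b * mu_b(a, b, *p) - F(a, b, *p)
print(abs(fd_a - mu_a(a, b, *p)), abs(fd_b - mu_b(a, b, *p)), abs(pi_def - Pi(a, b, *p)))
```

All three functions work elementwise on NumPy arrays, so they can replace the finite-difference `mu_a`, `mu_b` and `Pi` below without changing the calling code.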

My janky brute-force solution attempt in numpy/python, since I was running low on clever ideas.

import numpy as np
from numpy import log as log
import matplotlib.pyplot as plt

Some helper functions.

def approx_first_derivative(f, x0, y0, h=1e-7):
    '''
    chemical potentials are first derivs of  F. 
    Quick and dirty finite-diff based first derivative.
    '''

    fx = (f(x0 + h, y0) - f(x0 - h, y0)) / (2 * h)
    fy = (f(x0, y0 + h) - f(x0, y0 - h)) / (2 * h)
    
    return fx, fy


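def approx_second_derivative(f, x0, y0, h=1e-4):
    '''
    The spinodal line is the stability boundary of F; it can be
    obtained from det(Hessian(F)) = 0.
    Quick and dirty finite-difference second derivatives. The step h is
    larger than for first derivatives because the h**2 denominator amplifies
    rounding error (h ~ 1e-4 balances it against truncation error).
    '''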
    # Approximate second order partial derivatives using finite differences
    fxx = (f(x0 + h, y0) - 2 * f(x0, y0) + f(x0 - h, y0)) / h**2
    fyy = (f(x0, y0 + h) - 2 * f(x0, y0) + f(x0, y0 - h)) / h**2
    fxy = (f(x0 + h, y0 + h) - f(x0 + h, y0 - h) - f(x0 - h, y0 + h) + f(x0 - h, y0 - h)) / (4 * h**2)
    
    return np.array([[fxx, fxy], [fxy, fyy]])
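Second differences divide by $h^2$, so their total error behaves roughly like $C_1 h^2$ (truncation) plus $C_2\,\varepsilon/h^2$ (rounding, with $\varepsilon \approx 2.2\times10^{-16}$), which is smallest near $h \sim \varepsilon^{1/4} \approx 10^{-4}$. A quick one-dimensional check on a single entropic term of F (an illustrative choice), where the exact second derivative is known:

```python
import numpy as np

def f(x):
    # One entropic term of the toy F with na = 10; exactly f''(x) = 1 / (10 x).
    return x / 10 * np.log(x)

x0 = 0.2
exact = 1 / (10 * x0)
errors = {}
for h in (1e-7, 1e-5, 1e-4, 1e-3):
    approx = (f(x0 + h) - 2 * f(x0) + f(x0 - h)) / h**2
    errors[h] = abs(approx - exact) / exact
    print(f"h = {h:.0e}: relative error {errors[h]:.1e}")
```

The same reasoning applies to the mixed derivative `fxy`, whose denominator is also of order $h^2$.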

Define physical quantities

def F(a,b,uaa,ubb,uab,na,nb):
    entropy = (a/na)*log(a+1e-14)+(b/nb)*log(b+1e-14)+(1-a-b)*log(1-a-b+1e-14)
    energy = -0.5*uaa*a**2 - 0.5*ubb*b**2 -uab*a*b
    return entropy+energy

def mu_a(a, b, uaa, ubb, uab, na, nb):
    fx, _ = approx_first_derivative(lambda x, y: F(x, y, uaa, ubb, uab, na, nb), a, b)
    return fx

def mu_b(a, b, uaa, ubb, uab, na, nb):
    _, fy = approx_first_derivative(lambda x, y: F(x, y, uaa, ubb, uab, na, nb), a, b)
    return fy

def Pi(a, b, uaa, ubb, uab, na=10, nb=6):
    return mu_a(a, b, uaa, ubb, uab, na, nb) * a + mu_b(a, b, uaa, ubb, uab, na, nb) * b - F(a, b, uaa, ubb, uab, na, nb)

def det_approx(a, b, uaa, ubb, uab, na, nb):
    H = approx_second_derivative(lambda x, y: F(x, y, uaa, ubb, uab, na, nb), a, b)
    return np.linalg.det(H)
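Because F is known in closed form, its Hessian can also be written by hand: $F_{aa} = 1/(n_a\phi_a) + 1/(1-\phi_a-\phi_b) - u_{aa}$, $F_{bb} = 1/(n_b\phi_b) + 1/(1-\phi_a-\phi_b) - u_{bb}$ and $F_{ab} = 1/(1-\phi_a-\phi_b) - u_{ab}$. A vectorised sketch of the determinant (the grid limits and resolution are arbitrary choices) that avoids both finite-difference noise and the nested Python loops used to build `Z_approx`:

```python
import numpy as np

def det_hessian(a, b, uaa, ubb, uab, na, nb):
    """Analytic det of the Hessian of the toy F; works elementwise on arrays."""
    inv_s = 1 / (1 - a - b)
    f_aa = 1 / (na * a) + inv_s - uaa
    f_bb = 1 / (nb * b) + inv_s - ubb
    f_ab = inv_s - uab
    return f_aa * f_bb - f_ab**2

params = (0.0, 0.0, 4.36, 10, 6)   # uaa, ubb, uab, na, nb
# Open grid that avoids a = 0, b = 0 and a + b = 1, where F is singular.
a, b = np.meshgrid(np.linspace(1e-3, 0.45, 200), np.linspace(1e-3, 0.45, 200))
Z = det_hessian(a, b, *params)     # shape (200, 200), no Python-level loops
# The spinodal is the zero contour, e.g. plt.contour(a, b, Z, levels=[0]).
print(det_hessian(0.05, 0.05, *params), det_hessian(0.2, 0.2, *params))
```

A positive determinant at the dilute point and a negative one at (0.2, 0.2) place the spinodal between them for these parameters.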

So essentially we seek pairs of points in the $(\phi_a,\phi_b)$ plane at which mu_a, mu_b and Pi, as defined above, are all equal, i.e. for two points on this curve:

$$\mu_a^{(1)} (\phi^{(1)}_a,\phi^{(1)}_b) = \mu_a^{(2)} (\phi^{(2)}_a,\phi^{(2)}_b)$$

$$\mu_b^{(1)} (\phi^{(1)}_a,\phi^{(1)}_b) = \mu_b^{(2)} (\phi^{(2)}_a,\phi^{(2)}_b)$$

$$\Pi^{(1)} (\phi^{(1)}_a,\phi^{(1)}_b) = \Pi^{(2)} (\phi^{(2)}_a,\phi^{(2)}_b)$$

I simply tried a brute-force search for such pairs. My attempts are below.

"""
Very janky brute force grid search.
Most likely not the right way to do this.
"""
epsilon = 1e-1  
uaa,ubb,uab=0,0,4.36


satisfying_points= []
mypoints = set()
# Check each pair of points
for i in np.arange(0,1,epsilon):
    for j in np.arange(0,1,epsilon):
        for k in np.arange(i+epsilon,1,epsilon):
            for l in np.arange(j+epsilon,1,epsilon):
                #print((i,j),(k,l))
                if abs(mu_a(i,j,uaa,ubb,uab,10,6) - mu_a(k,l,uaa,ubb,uab,10,6)) <= epsilon and \
                   abs(mu_b(i,j,uaa,ubb,uab,10,6) - mu_b(k,l,uaa,ubb,uab,10,6)) <= epsilon and \
                   abs(Pi(i,j,uaa,ubb,uab,10,6) - Pi(k,l,uaa,ubb,uab,10,6)) <= epsilon:
                    mypoints.add((i, j))
                    mypoints.add((k, l))
                    #break  

satisfying_points = np.array([point for point in mypoints])

plt.scatter(satisfying_points[:, 0], satisfying_points[:, 1], s=10, color='blue')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(0,1)
plt.ylim(0,1)
plt.title('Points satisfying the conditions')
plt.grid(True)
plt.show()
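Equal $\mu_a$, $\mu_b$ and $\Pi$ at two points is the same as the two points sharing a tangent plane of $F(\phi_a,\phi_b)$ (the tangent plane at point 1 is $F^{(1)} + \mu^{(1)}\cdot(\phi-\phi^{(1)})$, and it passes through point 2 exactly when the slopes match and $\Pi^{(1)}=\Pi^{(2)}$). So instead of comparing all pairs, a grid-based alternative is to take the lower convex hull of F sampled on a grid (the common tangent construction). A rough sketch with `scipy.spatial.ConvexHull`, assuming the Fig. 4-like parameters; the grid spacing and the "5 grid cells" edge-length threshold are arbitrary heuristics, and this is a different technique from anything else in this thread.

```python
import numpy as np
from scipy.spatial import ConvexHull

def F(a, b, uaa, ubb, uab, na, nb):
    """Toy free energy (no log regularisation; stay inside the simplex)."""
    s = 1 - a - b
    return (a / na * np.log(a) + b / nb * np.log(b) + s * np.log(s)
            - 0.5 * uaa * a**2 - 0.5 * ubb * b**2 - uab * a * b)

params = (0.0, 0.0, 4.36, 10, 6)   # uaa, ubb, uab, na, nb (Fig. 4-like case)
step = 0.01                        # grid spacing (arbitrary choice)
g = np.arange(step, 0.9, step)
A, B = np.meshgrid(g, g)
keep = A + B < 0.95                # keep away from the a + b = 1 edge
a, b = A[keep], B[keep]
pts = np.column_stack([a, b, F(a, b, *params)])

hull = ConvexHull(pts)
# Each row of hull.equations is (nx, ny, nz, offset) with an outward normal;
# facets whose normal points downwards (nz < 0) make up the lower envelope.
lower = hull.simplices[hull.equations[:, 2] < -1e-12]
on_envelope = np.zeros(len(pts), dtype=bool)
on_envelope[lower.ravel()] = True  # grid points that touch the envelope

# Edges of lower facets spanning many grid cells approximate tie-lines.
edges = np.vstack([lower[:, [0, 1]], lower[:, [1, 2]], lower[:, [0, 2]]])
lengths = np.hypot(*(pts[edges[:, 0], :2] - pts[edges[:, 1], :2]).T)
tie_lines = edges[lengths > 5 * step]  # threshold is a heuristic

print(f"{(~on_envelope).sum()} of {len(pts)} grid points lie above the envelope")
print(f"{len(tie_lines)} long edges (candidate tie-lines) found")
```

Grid points off the envelope lie inside a coexistence region, and the endpoints of long edges approximate coexisting compositions to within roughly the grid spacing, so they are best used as starting guesses for a root finder. A lower facet whose three edges are all long would indicate three-phase coexistence, which may matter for the Fig. 3-like case.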

And a second, partially vectorised attempt:

"""
Attempting to vectorise the janky brute force solution.
Most likely not the right way to do this.
"""

eps = 1e-2
thresh = 1e-2

x_vals = np.arange(0, 0.4, eps)
y_vals = np.arange(0, 0.4, eps)

satisfying_points_x = []
satisfying_points_y = []
uaa,ubb,uab=0,0,4.36

# Compute the mu_a, mu_b, and Pi values for all points in the grid
mu_a_values = mu_a(x_vals[:, np.newaxis], y_vals, uaa, ubb, uab, 10, 6)
mu_b_values = mu_b(x_vals[:, np.newaxis], y_vals, uaa, ubb, uab, 10, 6)
Pi_values = Pi(x_vals[:, np.newaxis], y_vals, uaa, ubb, uab, 10, 6)

# Loop over the grid
for idx_i, i in enumerate(x_vals):
    for idx_j, j in enumerate(y_vals):
        
        # Compare current point's values with all subsequent points using vectorization
        diff_mu_a = np.abs(mu_a_values[idx_i, idx_j] - mu_a_values[idx_i + 1:, idx_j + 1:])
        diff_mu_b = np.abs(mu_b_values[idx_i, idx_j] - mu_b_values[idx_i + 1:, idx_j + 1:])
        diff_Pi = np.abs(Pi_values[idx_i, idx_j] - Pi_values[idx_i + 1:, idx_j + 1:])
        
        # Find indices where all conditions are satisfied
        match_indices = np.where((diff_mu_a <= thresh) & (diff_mu_b <= thresh) & (diff_Pi <= thresh))
        
        if match_indices[0].size > 0:
            # If we found matches, add the points to our list
            satisfying_points_x.extend([i, x_vals[match_indices[0][0] + idx_i + 1]])
            satisfying_points_y.extend([j, y_vals[match_indices[1][0] + idx_j + 1]])
            #break

# Convert the lists to arrays for plotting
satisfying_points_x = np.array(satisfying_points_x)
satisfying_points_y = np.array(satisfying_points_y)

x = np.linspace(0, 1, 100)
y = np.linspace(0, 1, 100)
X, Y = np.meshgrid(x, y)

Z_approx = np.array([[det_approx(x_ij, y_ij, uaa,ubb,uab, 10, 6) for x_ij, y_ij in zip(x_row, y_row)] for x_row, y_row in zip(X, Y)])


plt.contour(X, Y, Z_approx, levels=[0], colors='red')
plt.scatter(satisfying_points_x, satisfying_points_y, s=10, color='blue')
plt.xlabel('x')
plt.ylabel('y')
plt.xlim(0,1)
plt.ylim(0,1)
plt.grid(True)
plt.show()

For some parameter values, e.g. uaa,ubb,uab=0,0,4.36, this gets me a large number of points in the neighborhood of the correct curve (which should be similar to Fig. 4 of the paper). It fails completely for uaa,ubb,uab=1.8,0,0 (which should be similar to Fig. 3 of the paper).

I attach a quick visualisation from my code. In red is the spinodal; the binodal curve I am interested in should lie around it, touching it at the critical points. In blue are the points I have been able to identify with this grid-search approach.


EDIT: Second Approach

I was quite inspired by @Reinderien's tremendous effort. His algorithm and approach seem pretty good; however, a solution curve with a critical point in the interior of the spinodal (the det(Hessian(F)) = 0 line) is decidedly unphysical in my understanding of the problem. It is also close to, but not exactly the same as, the published result.

My second attempt phrases the problem as an optimisation problem.

import numpy as np
from scipy.optimize import minimize
from scipy.optimize import NonlinearConstraint
uaa, ubb, uab = 0, 0, 4.36
na, nb = 10, 6

def objective_fun(x):
    x1, y1, x2, y2 = x
    term1 = mu_a(x1, y1, uaa, ubb, uab, na, nb) - mu_a(x2, y2, uaa, ubb, uab, na, nb)
    term2 = mu_b(x1, y1, uaa, ubb, uab, na, nb) - mu_b(x2, y2, uaa, ubb, uab, na, nb)
    term3 = Pi(x1, y1, uaa, ubb, uab, na, nb) - Pi(x2, y2, uaa, ubb, uab, na, nb)
    return term1**2 + term2**2 + term3**2

# Define a nonlinear constraint to ensure x1 != x2 and y1 != y2
#def constraint(x):
#    return (x[0] - x[2])**2 + (x[1] - x[3])**2 - 1e-4  # Should be greater than 0

#nonlinear_constraint = NonlinearConstraint(constraint, 1e-4, np.inf)

# Create a grid of initial guesses
x1_values = np.linspace(0, 0.4, 20)  # Adjust the range and number of points as needed
y1_values = np.linspace(0, 0.4, 20)
x2_values = np.linspace(0, 0.4, 20)
y2_values = np.linspace(0, 0.4, 20)

solutions = []

for x1 in x1_values:
    for y1 in y1_values:
        for x2 in x2_values:
            for y2 in y2_values:
                initial_guess = [x1, y1, x2, y2]
                
                # Perform the optimization
                #sol = minimize(objective_fun, initial_guess, constraints=[nonlinear_constraint], bounds=[(0, 1), (0, 1), (0, 1), (0, 1)])
                sol = minimize(objective_fun, initial_guess, bounds=[(0, 1), (0, 1), (0, 1), (0, 1)])
    
                # Extract the solution
                x1_sol, y1_sol, x2_sol, y2_sol = sol.x
                
                # Check if the solver was successful
                if sol.success:
                    # Check if x1 != x2 and y1 != y2
                    if np.abs(x1_sol - x2_sol) > 1e-4 and np.abs(y1_sol - y2_sol) > 1e-4:
                        # Check if the chemical potentials are close to each other
                        if np.abs(mu_a(x1_sol, y1_sol, uaa, ubb, uab, na, nb) - mu_a(x2_sol, y2_sol, uaa, ubb, uab, na, nb)) <= 1e-4 and \
                           np.abs(mu_b(x1_sol, y1_sol, uaa, ubb, uab, na, nb) - mu_b(x2_sol, y2_sol, uaa, ubb, uab, na, nb)) <= 1e-4 and \
                           np.abs(Pi(x1_sol, y1_sol, uaa, ubb, uab, na, nb) - Pi(x2_sol, y2_sol, uaa, ubb, uab, na, nb)) <= 1e-4:
                            print(f"Found a solution at {sol.x} with initial guess {initial_guess}")
                            solutions.append(sol)
                else:
                    continue
                    #print(f"No with initial guess {initial_guess}")
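The sum-of-squares objective in `objective_fun` is exactly zero whenever the two points coincide, so in the 4-D search space it has a whole 2-D plane of trivial global minima, against a 1-D curve of physical solutions. Many of the grid starts can therefore be expected to slide onto the trivial plane, which is why the `x1 != x2` filter is needed. A small check, using hand-derived analytic $\mu$ and $\Pi$ for the toy F with uaa, ubb, uab = 0, 0, 4.36 (an assumption for illustration, not the finite-difference helpers above) and a coexisting pair copied from the printed output of the answer below:

```python
import numpy as np

UAA, UBB, UAB, NA, NB = 0.0, 0.0, 4.36, 10, 6

def mu_and_pi(a, b):
    """Analytic (mu_a, mu_b, Pi) for the toy F."""
    ln_s = np.log(1 - a - b)
    mu_a = (np.log(a) + 1) / NA - ln_s - 1 - UAA * a - UAB * b
    mu_b = (np.log(b) + 1) / NB - ln_s - 1 - UBB * b - UAB * a
    pi = (-ln_s + a / NA + b / NB - a - b
          - 0.5 * UAA * a**2 - 0.5 * UBB * b**2 - UAB * a * b)
    return np.array([mu_a, mu_b, pi])

def objective(x):
    """Same sum-of-squares form as objective_fun, with analytic derivatives."""
    x1, y1, x2, y2 = x
    return float(np.sum((mu_and_pi(x1, y1) - mu_and_pi(x2, y2))**2))

# Any pair of identical points is an exact global minimum ...
print(objective([0.1, 0.3, 0.1, 0.3]))  # 0.0
# ... and so is a genuine coexisting pair (a1, b1, a2, b2).
print(objective([0.12118504, 0.0191144, 0.35659973, 0.19124418]))
```

The answer's `LinearConstraint`, which keeps the two points at least 0.05 apart in each coordinate, excludes the trivial plane from the search.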

And to produce a plot:

import matplotlib.pyplot as plt

# Initialize lists to store x1, y1, x2, y2 values
x1_vals = []
y1_vals = []
x2_vals = []
y2_vals = []

# Extract the values from the solutions
for sol in solutions:
    x1, y1, x2, y2 = sol.x
    x1_vals.append(x1)
    y1_vals.append(y1)
    x2_vals.append(x2)
    y2_vals.append(y2)

Z_approx = np.array([[det_approx(x_ij, y_ij, 0,0, 4.36, 10, 6) for x_ij, y_ij in zip(x_row, y_row)] for x_row, y_row in zip(X, Y)])

# Create the scatter plot
plt.figure(figsize=(10, 10))
plt.scatter(x1_vals, y1_vals, c='blue', label='(x1, y1)', alpha=0.6)
plt.scatter(x2_vals, y2_vals, c='red', label='(x2, y2)', alpha=0.6)
plt.contour(X, Y, Z_approx, levels=[0], colors='blue')
# Add labels and title
plt.xlabel('x')
plt.ylabel('y')
plt.title('Scatter Plot of Solutions')
plt.legend()

# Show the plot
plt.show()

This is again close to, but not identical to, Fig. 4 of the linked publication.


The spinodal is plotted as the solid blue line. The two critical points (marked X in Fig. 4 of the publication) occur close to where the binodal is supposed to touch the spinodal, which is encouraging. However, the solution curve along the direction x = y is not very good.

Q1: What is the correct way to solve this problem, i.e. to find a curve that satisfies an underdetermined set of nonlinear equations? Ideally the numerical method is efficient, robust, and applicable to free energies similar, but not identical, to the toy version described here. The ideal solution may use any standard packages, algorithms or canned approaches; I am just not aware of them.


Solution

  • This was difficult, and I'm not confident that my solution is the most efficient; but it does basically work. Roughly,

    1. Within the centre of the radial surface pi, find the location of its minimum.
    2. Start a broad search referenced off of that minimum location for any pair with low error, using one call to minimize.
    3. From the origin of that pair, branch left and right with a bounded-follow routine. Terminate if any of these becomes true:
    • Curve runs out of bounds.
    • Error between components exceeds tolerance.
    • Pair converges to a critical point.

    from typing import NamedTuple
    
    import pandas as pd
    import scipy.optimize
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    class Params(NamedTuple):
        uaa: float
        ubb: float
        uab: float
        na: int
        nb: int
    
    
    def F(
        phi: np.ndarray,
        uaa: float, ubb: float, uab: float, na: int, nb: int,
    ) -> np.ndarray:
        """
        Helmholtz-free energy for a Flory-Huggins-like model
    
        :param phi: (2xAxB) array of phi independent coordinates
        :return: AxB array of function values
        """
        coef_shape = (2,) + (1,)*(len(phi.shape) - 1)
        nanb = np.array((na, nb)).reshape(coef_shape)
        uaub = np.array((uaa, ubb)).reshape(coef_shape)
        tot = phi.sum(axis=0)
    
        entropy = (
            (phi/nanb * np.log(phi)).sum(axis=0)
            + (1 - tot)*np.log(1 - tot)
        )
        energy = (
            phi**2 * (-0.5 * uaub)
        ).sum(axis=0) - uab * phi.prod(axis=0)
        return entropy + energy
    
    
    def gradF(
        phi: np.ndarray,
        uaa: float, ubb: float, uab: float, na: int, nb: int,
    ) -> np.ndarray:
        """
        First-order Jacobian of F in phi.
        :param phi: (2xAxB) array of phi independent coordinates
        :return: (2xAxB) array of gradient vectors
        """
        coef_shape = (2,) + (1,)*(len(phi.shape) - 1)
        nanb = np.array((na, nb)).reshape(coef_shape)
        uaub = np.array((uaa, ubb)).reshape(coef_shape)
        return (
            (np.log(phi) + 1) / nanb
            - np.log(1 - phi.sum(axis=0))
            - phi * uaub
            - phi[::-1] * uab
            - 1
        )
    
    
    def hessianF(
        phi: np.ndarray,
        uaa: float, ubb: float, uab: float, na: int, nb: int,
    ) -> np.ndarray:
        """
        Hessian of F in phi.
        :param phi: (2xAxB) array of phi independent coordinates
        :return: (2x2xAxB) array of Hessian matrices
        """
        coef_shape = (2,) + (1,)*(len(phi.shape) - 1)
        nanb = np.array((na, nb)).reshape(coef_shape)
        uaub = np.array((uaa, ubb)).reshape(coef_shape)
    
        diag = 1/phi/nanb - uaub
        antidiag = -uab
        hessian = np.empty(shape=(2, 2, *phi.shape[1:]), dtype=phi.dtype)
        hessian[[0, 1], [0, 1], ...] = diag
        hessian[[0, 1], [1, 0], ...] = antidiag
        hessian += 1/(1 - phi.sum(axis=0))
    
        return hessian
    
    
    def test_grad(params: Params) -> None:
        """
        Sanity test to ensure that analytic gradients are correct.
        """
        xk = np.array((0.2, 0.3))
        assert scipy.optimize.check_grad(F, gradF, xk, *params) < 1e-6
    
        estimated = scipy.optimize.approx_fprime(xk, gradF, 1e-9, *params)
        analytic = hessianF(xk, *params)
        assert np.allclose(estimated, analytic, rtol=0, atol=1e-6)
    
    
    def mu(
        phi: np.ndarray,
        uaa: float, ubb: float, uab: float, na: int, nb: int,
    ) -> np.ndarray:
        """
        Chemical potential (first-order Jacobian of F)
        """
        return gradF(phi, uaa, ubb, uab, na, nb)
    
    
    def Pi(phi: np.ndarray, params: Params) -> np.ndarray:
        """Osmotic pressure"""
        mu_ab = mu(phi, *params)
        return (mu_ab * phi).sum(axis=0) - F(phi, *params)
    
    
    def plot_isodims(
        phi: np.ndarray,
        mua: np.ndarray,
        mub: np.ndarray,
        pi: np.ndarray,
    ) -> plt.Figure:
        """
        Contour plot of individual mua/b/pi series, all over phi_a and phi_b
        """
        fig: plt.Figure
        axes: tuple[plt.Axes, ...]
        fig, axes = plt.subplots(nrows=1, ncols=3)
        fig.suptitle('Individual series')
    
        (ax_mua, ax_mub, ax_pi) = axes
        for ax in axes:
            ax.set_xlabel('phi a')
            ax.grid(True)
        ax_mua.set_ylabel('phi b')
        ax_mub.set_yticklabels(())
        ax_pi.set_yticklabels(())
    
        ax_mua.contour(*phi, mua)
        ax_mua.set_title('mu a')
        ax_mub.contour(*phi, mub)
        ax_mub.set_title('mu b')
        ax_pi.contour(*phi, pi)
        ax_pi.set_title('pi')
    
        return fig
    
    
    def plot_example_proximity(
        phi: np.ndarray,  # 2x500x500
        mua: np.ndarray,  # 500x500
        mub: np.ndarray,  # 500x500
        pi: np.ndarray,   # 500x500
        params: Params,
        phi_pimin: np.ndarray,      # 2x
        phi_endpoints: np.ndarray,  # 2x2
        zapprox: np.ndarray,        # 500x500
        followed_phi: np.ndarray,   # 2x2x343-ish
    ) -> plt.Figure:
        """
        Plot, for the initial-fit origin point (1) and its pair (2), all three isocurves
        Include the Hessian determinant estimation (zapprox)
        """
        phi0 = phi_endpoints[:, 0]  # origin point (1)
        phi1 = phi_endpoints[:, 1]  # next point (2)
        mua0, mub0 = mu(phi0, *params)   # mu values at origin
        pi0 = Pi(phi0, params)  # pi value at origin
    
        # Array indices where mu closely matches that at the origin
        mua_close_y, mua_close_x = (np.abs(mua - mua0) < 1e-3).nonzero()
        mub_close_y, mub_close_x = (np.abs(mub - mub0) < 1e-3).nonzero()
    
        # Error from the pi value seen at the origin
        pi_close = pi - pi0
    
        fig: plt.Figure
        ax: plt.Axes
        fig, ax = plt.subplots()
        ax.set_title('Isocurves from initial fit origin to next,'
                     '\nwith Hessian determinant')
    
        ax.scatter(*phi0, c='black', marker='+', s=100, label='origin (1)')
        ax.scatter(*phi1, c='pink', marker='+', s=100, label='next (2)')
        ax.scatter(*phi_pimin, c='blue', marker='+', s=100, label='pi_min')
        ax.plot(phi[0, 0, mua_close_x], phi[0, 0, mua_close_y], label='mua')
        ax.plot(phi[1, mub_close_x, 0], phi[1, mub_close_y, 0], label='mub')
        ax.contour(*phi, pi_close, levels=[0], colors='purple')  # Pi isocurve
        ax.contour(*phi, zapprox, levels=[0], colors='brown')  # Hessian determinant
        ax.plot(*followed_phi[:, 0, :], label='follow1')
        ax.plot(*followed_phi[:, 1, :], label='follow2')
    
        ax.legend()
    
        return fig
    
    
    def pairwise_error(phi: np.ndarray, params: Params) -> float:
        """
        For a phi pair, calculate the error between pair components.
        :param phi: 2x2, a1 a2 b1 b2
        :return: Least-squares error between mu_a, mu_b and pi
        """
        mua, mub = mu(phi, *params)
        pi = Pi(phi, params)
        return (
            (mua[0] - mua[1])**2 +
            (mub[0] - mub[1])**2 +
            10 * (pi[0] - pi[1])**2
        )
    
    
    def least_sq_difference(phi: np.ndarray, params: Params) -> float:
        """
        From given phi endpoints, calculate exact (not estimated) values for mu and pi, and return
        their least-squared distance as the objective cost
        """
        phi = phi.reshape((2, 2))
        return pairwise_error(phi, params)
    
    
    def initial_pair_estimate(
        phi_min: float, phi_max: float,
        phi_pimin: np.ndarray,  # 2x array
        params: Params,
    ) -> np.ndarray:
        """
        Perform an initial fit to find any pair on the target curve
    
        :param phi_min: Minimum search space bound of phi in both axes
        :param phi_max: Maximum search space bound of phi in both axes
        :param phi_pimin: Phi coordinates of minimum of pi; the search space is centred here
        :param params: uaa, ubb, uab, na, nb
        :return: endpoint coordinates in phi (2x2)
        """
        phia_pimin, phib_pimin = phi_pimin
    
        # As in original problem construction, the first point must be below-left of the second point;
        # and the distance must be greater than 0 to avoid degeneracy (superimposed pair).
        nondegenerate = scipy.optimize.LinearConstraint(
            A=[
                [-1, 1,  0, 0],
                [ 0, 0, -1, 1],
            ],
            lb=0.05,
        )
    
        result = scipy.optimize.minimize(
            fun=least_sq_difference, args=(params,),
            x0=(
                phia_pimin,   phia_pimin*3/2,
                phib_pimin/2, phib_pimin,
            ),
            # The origin is below-left of the minimum location of pi
            bounds=scipy.optimize.Bounds(
                lb=[phi_min,        phi_min,    phi_min,    phi_min],
                ub=[phia_pimin*0.5, phi_max,    phib_pimin, phi_max],
            ),
            constraints=nondegenerate,
            tol=1e-12,
        )
        if not result.success:
            raise ValueError(result.message)
    
        endpoints = result.x.reshape((2, 2))
        print(f'Initial pair found with error={result.fun:.2e}:\n{endpoints}')
        return endpoints
    
    
    def characterise_initial(
        params: Params,
        phi_min: float = 0,
        phi_max: float = 1,
    ) -> tuple[
        np.ndarray,  # mesh-like phi coordinate space, 2xAxB
        np.ndarray,  # mua, AxB
        np.ndarray,  # mub, AxB
        np.ndarray,  # pi, AxB
        np.ndarray,  # Hessian determinant estimate, AxB
    ]:
        """
        Fill out a bunch of initial dimension information
        """
        phi_a = phi_b = np.linspace(phi_min, phi_max, 500)
        phi = np.stack(np.meshgrid(phi_a, phi_b))
        mua, mub = mu(phi, *params)
        pi = Pi(phi, params)
    
        hess = hessianF(phi, *params)
        zapprox = np.linalg.det(hess.transpose(2, 3, 0, 1))
    
        return phi, mua, mub, pi, zapprox
    
    
    def estimate_pimin(phi: np.ndarray, pi: np.ndarray) -> np.ndarray:
        """
        Start referenced from the (estimated) minimum of pi, in the middle of the region of interest
        The Hessian estimate runs through this point.
        """
        ijmin = np.unravel_index(pi.argmin(), pi.shape)
        phi_pimin = phi[:, ijmin[0], ijmin[1]]
        print(f'Pi minimum point of {pi[ijmin]:.5f} at {phi_pimin}')
        return phi_pimin
    
    
    def update_bounds(
        phi_old: np.ndarray,
        phi_new: np.ndarray,
        step: float,
    ) -> tuple[
        np.ndarray,  # new solution estimate x0
        np.ndarray,  # new lower bounds
        np.ndarray,  # new upper bounds
    ]:
        """
        After a follow-solve step, move the bounds. These bounds constitute a small moving rectangle
        with some wiggle room centered on the next place we guess a solution point will appear (using
        linear extrapolation). This makes it so the inner minimizer doesn't need to work as hard.
        :param phi_old: Previous phi, 2x2
        :param phi_new: New phi, 2x2
        :param step: Approximate inter-point distance
        :return: Updated guess for next solution, and bounds for next minimize call
        """
        delta = phi_new - phi_old
    
        # Determine which is the major axis of motion; assign the other to have more freedom
        if abs(delta[0, 0]) > abs(delta[0, 1]):
            wide_axis = 1
        else:
            wide_axis = 0
    
        # Put the next guess roughly at distance 'step' in direction 'delta'
        scale = step / np.sqrt(delta[0, 0]**2 + delta[1, 0]**2)
        x0 = phi_new + delta * scale
    
        # One new bound is halfway between points 'new' and 'x0'. The other is past x0.
        boundp0 = (x0[:, 0] + phi_new[:, 0])/2
        boundq0 = boundp0 + delta[:, 0]*scale
    
        lobound = np.array((
            (min(boundp0[0], boundq0[0]), x0[0, 1] - 10*step),
            (min(boundp0[1], boundq0[1]), x0[1, 1] - 10*step),
        ))
        hibound = np.array((
            (max(boundp0[0], boundq0[0]), x0[0, 1] + 10*step),
            (max(boundp0[1], boundq0[1]), x0[1, 1] + 10*step),
        ))
    
        # For the off-axis (not the major axis of motion), bound width is wider
        lobound[wide_axis, 0] -= 10*step
        hibound[wide_axis, 0] += 10*step
    
        # Keep step bounds within overall bounds
        lobound = lobound.clip(min=0)
        hibound = hibound.clip(max=1)
    
        return x0, lobound, hibound
    
    
    def check_critical(
        direction: int,
        phi: np.ndarray,
        step: float,
        params: Params,
    ) -> bool:
        """
        If the two solution points have converged on one location, this is a critical point; report it
        and tell the outer optimizer to terminate this branch.
        :param direction: 1 or -1 (right or left)
        :param phi: New solution point, 2x2
        :param step: Approximate inter-point distance
        :param params: Passed to the chemical routines
        :return: True if this is a critical point.
        """
        if not np.all(
            np.abs(phi[:, 0] - phi[:, 1]) < step
        ):
            return False
    
        print(f'Convergence in direction {direction}: critical point at')
        df = pd.DataFrame(
            {
                'phia': phi[0],
                'phib': phi[1],
                'mua': mu(phi, *params)[0],
                'mub': mu(phi, *params)[1],
                'pi': Pi(phi, params),
            },
        )
        print(df.T)
        return True
    
    
    def follow(
        phi_endpoints: np.ndarray,
        params: Params,
        step: float = 1e-3,
        tol: float = 1e-9,
        maxiter: int = 1000,
    ) -> np.ndarray:
        """
        Multidimensional, stepped following algorithm.
        :param phi_endpoints: Initial phi search point; will branch in two directions from here
        :param params: Passed to chemical routines
        :param step: Roughly, distance between points. Actual distance will vary
        :param tol: If we follow the curve to a point where the error between components exceeds this
                    tolerance, we bail.
        :param maxiter: Upper limit on the number of points per branch
        :return: Array of phi values, separated by branch.
        """
    
        phi_halves = []
        error_min = np.inf
        error_max = -np.inf
    
        # Search left and right from initial-fit origin
        for direction in (-1, +1):
            phi_old = phi_endpoints.copy()
            phi_built = [phi_old]
            phi_halves.append(phi_built)
            off0 = direction*step
            off1 = off0/2
            lobound = phi_endpoints + (
                (min(off0, off1), -step),
                (-step,           -step),
            )
            hibound = phi_endpoints + (
                (max(off0, off1), step),
                (step,            step),
            )
            x0 = (lobound + hibound)/2
    
            for _ in range(maxiter):
                result = scipy.optimize.minimize(
                    fun=least_sq_difference, args=(params,),
                    x0=x0.ravel(), tol=1e-12,
                    bounds=scipy.optimize.Bounds(lb=lobound.ravel(), ub=hibound.ravel()),
                )
                phi_new = result.x.reshape(2,2)
                if not result.success:
                    # This is not fatal in many cases.
                    if result.message != 'ABNORMAL_TERMINATION_IN_LNSRCH':
                        print(result.message)
                        break
                if result.fun > tol:
                    print(f'Search direction {direction}: out of tol; terminating')
                    break
    
                error_max = max(result.fun, error_max)
                error_min = min(result.fun, error_min)
                phi_built.append(phi_new)
    
                if check_critical(direction, phi_new, step, params):
                    break
    
                x0, lobound, hibound = update_bounds(phi_old, phi_new, step)
                if np.any(hibound - lobound <= 0):
                    print(f'Search direction {direction}: out of bounds; terminating')
                    break
    
                phi_old = phi_new
    
        print(f'Follow complete; error within {error_min:.2e} and {error_max:.2e}')
    
        phi_series = np.stack(
            phi_halves[0][::-1] + phi_halves[1],
            axis=-1,
        )
        return phi_series
    
    
    def main() -> None:
        params = Params(uaa=0, ubb=0, uab=4.36, na=10, nb=6)
        phi_min, phi_max = 1e-6, 0.4
    
        test_grad(params)
    
        phi, mua, mub, pi, zapprox = characterise_initial(params, phi_min, phi_max)
        phi_pimin = estimate_pimin(phi, pi)
        phi_endpoints = initial_pair_estimate(phi_min, phi_max, phi_pimin, params)
    
        followed_phi = follow(phi_endpoints, params)
    
        # plot_isodims(phi, mua, mub, pi)
        plot_example_proximity(
            phi, mua, mub, pi, params, phi_pimin,
            phi_endpoints, zapprox, followed_phi,
        )
        plt.show()
    
    
    if __name__ == '__main__':
        main()
    
    Pi minimum point of -0.01350 at [0.24448937 0.22925894]
    Initial pair found with error=7.35e-13:
    [[0.12118504 0.35659973]
     [0.0191144  0.19124418]]
    Convergence in direction -1: critical point at
                 0         1
    phia  0.058309  0.058303
    phib  0.340558  0.340551
    mua  -2.160096 -2.160096
    mub  -0.758148 -0.758148
    pi    0.086084  0.086084
    Convergence in direction 1: critical point at
                 0         1
    phia  0.302626  0.302102
    phib  0.069781  0.069417
    mua  -0.857908 -0.857908
    mub  -2.130651 -2.130651
    pi    0.043277  0.043276
    Follow complete; error within 3.12e-17 and 4.33e-10
    

    Simple plot of all three components to eyeball how they work:

    [Figure: component plot of mu a, mu b and pi]

    This one is more complicated:

    [Figure: solution]

    • Black cross: initial fit, origin half of pair
    • Pink cross: initial fit, second half of pair
    • Blue cross: location of minimum pi
    • Orange and blue lines: isocurves for the initial-fit pair, mu_a and mu_b
    • Purple curve: Isocurve for the initial-fit pair, pi
    • Brown loop: Hessian determinant
    • Green: origin-half of followed curve, running from cross branching left and right to critical points converging with red curve
    • Red: second half of followed curve, running from cross branching left and right to critical points converging with green

    The most recent edit fixed the remaining bugs:

    • Step-bound for follow solution needs to be sliver-shaped, with a wide tolerance on the axis orthogonal to the major axis of motion
    • Hessian approximation had a transposition error, making things appear non-physical
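    For comparison with the bounded `follow` routine, a more compact (but less general) tracer is natural-parameter continuation: pin $\phi^{(1)}_a$, solve the remaining 3×3 system with `scipy.optimize.fsolve`, then step $\phi^{(1)}_a$ and reuse the previous solution as the next guess. This is a minimal sketch under stated assumptions: hand-derived analytic $\mu$ and $\Pi$ for the toy F with uaa, ubb, uab = 0, 0, 4.36, a seed near the initial pair printed above, and arbitrary step size and tolerances. It stops wherever $\phi^{(1)}_a$ stops being a valid parameter (a turning point of the curve) or the two phases merge at a critical point; pseudo-arclength continuation, which parameterises by arclength along the curve instead, is the usual remedy for turning points.

```python
import numpy as np
from scipy.optimize import fsolve

UAA, UBB, UAB, NA, NB = 0.0, 0.0, 4.36, 10, 6

def mu_and_pi(a, b):
    """Analytic (mu_a, mu_b, Pi) for the toy F."""
    ln_s = np.log(1 - a - b)
    mu_a = (np.log(a) + 1) / NA - ln_s - 1 - UAA * a - UAB * b
    mu_b = (np.log(b) + 1) / NB - ln_s - 1 - UBB * b - UAB * a
    pi = (-ln_s + a / NA + b / NB - a - b
          - 0.5 * UAA * a**2 - 0.5 * UBB * b**2 - UAB * a * b)
    return np.array([mu_a, mu_b, pi])

def trace(a1_values, guess, tol=1e-10):
    """Natural-parameter continuation in phi_a^(1); returns rows (a1, b1, a2, b2)."""
    branch = []
    for a1 in a1_values:
        def g(v):
            # Phase 1 is (a1, v[0]); phase 2 is (v[1], v[2]).
            return mu_and_pi(a1, v[0]) - mu_and_pi(v[1], v[2])
        sol, _, ier, _ = fsolve(g, guess, full_output=True, xtol=1e-12)
        converged = ier == 1 and np.abs(g(sol)).max() < tol
        distinct = np.hypot(sol[1] - a1, sol[2] - sol[0]) > 1e-3
        if not (converged and distinct):
            break  # solver failure, turning point or merged phases: stop here
        branch.append((a1, *sol))
        guess = sol  # previous solution seeds the next step
    return np.array(branch)

# Seed near the initial pair printed above, then step phi_a^(1) downwards.
curve = trace(np.arange(0.1212, 0.0, -0.002), guess=[0.0191, 0.3566, 0.1912])
print(curve.shape)
```

    Tracing the other direction is a second call with increasing $\phi^{(1)}_a$; plotting columns (a1, b1) and (a2, b2) gives the two halves of the binodal covered by this branch.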