Search code examples
python matrix scipy equation

How to solve a set of matrix equations where more than one vector is unknown


I am trying to solve a set of equations like the following with Python:

X2 - C1*X3 = X1
X3 = k*(X4-C2)*sqrt(X2*(1-X2))
AX4 = C3*b

where X2, X3 and X4 are unknown N-dimensional vectors (N is constant). C1, C2, C3 and k are known constant scalars. b and X1 are known constant N-dimensional vectors. A is a known N×N matrix. Multiplying a vector by a vector means the element-wise `*` in Python. Taking the square root of a vector means taking the square root of every element. Subtracting a scalar from a vector means subtracting it from every element.

I tried `fsolve` in SciPy, but it doesn't accept 2-D inputs. What method can I use to solve this problem?


Solution

  • Don't leave all of those variables as degrees of freedom. All but one of them can be written as a simple function of the others. Choosing x2 as your sole degree of freedom leads to a simpler optimisation problem:

    import numpy as np
    import numpy.typing
    from typing import Annotated, TypeVar, Literal
    
    from scipy.optimize import fsolve
    
    # Element type of the arrays; constrained to NumPy scalar types.
    DType = TypeVar('DType', bound=np.generic)
    # Array aliases. The Literal metadata records the intended shape ('N' or
    # 'N' x 'N') for readers; presumably descriptive only, since nothing in
    # this file inspects or checks it.
    Vector = Annotated[numpy.typing.NDArray[DType], Literal['N']]
    Matrix = Annotated[numpy.typing.NDArray[DType], Literal['N', 'N']]
    
    
    def implied_params(
        c1: float, c2: float, k: float,
        x1: Vector, x2: Vector,
    ) -> tuple[
        Vector,  # x3
        Vector,  # x4
    ]:
        """
        Recover x3 and x4 from a candidate x2 via the two explicit equations.

        Rearranges ``X2 - C1*X3 = X1`` for x3, then
        ``X3 = k*(X4 - C2)*sqrt(X2*(1 - X2))`` for x4. All vector operations
        are element-wise.

        :param c1: scalar coefficient of X3 in the first equation
        :param c2: scalar subtracted from X4 in the second equation
        :param k: scalar factor in the second equation
        :param x1: known vector X1
        :param x2: candidate vector X2
        :return: the pair (x3, x4) implied by x2
        """
        # First equation, solved for X3.
        delta = x2 - x1
        x3 = delta / c1

        # Second equation, solved for X4. The divisions are applied in the
        # order x3 / k / sqrt(...), then the offset c2 is added.
        root_term = np.sqrt(x2 * (1 - x2))
        x4 = x3 / k / root_term + c2

        return x3, x4
    
    
    def equations(
        x2: Vector,
        x1: Vector,
        A: Matrix, b: Vector,
        c1: float, c2: float, c3: float, k: float,
    ) -> Vector:
        """
        Residual of ``A@X4 = C3*b`` for a candidate x2.

        x4 is derived from x2 through implied_params, so this residual is a
        function of x2 alone. It is the callback handed to fsolve, with x2 as
        the unknown and the remaining parameters passed through ``args``.

        :param x2: candidate vector X2 (the sole degree of freedom)
        :param x1: known vector X1
        :param A: known matrix A
        :param b: known vector b
        :param c1: scalar coefficient of X3 in the first equation
        :param c2: scalar subtracted from X4 in the second equation
        :param c3: scalar multiplying b in the third equation
        :param k: scalar factor in the second equation
        :return: the residual vector ``A@x4 - c3*b``; all zeros at a solution
        """
        # Only x4 enters the remaining equation; x3 is not needed here.
        _, x4 = implied_params(c1=c1, c2=c2, k=k, x1=x1, x2=x2)

        # AX4 = C3 * b
        return A@x4 - c3*b
    
    
    def solve(
        A: Matrix, b: Vector,
        c1: float, c2: float, c3: float, k: float,
        x1: Vector,
    ) -> tuple[
        Vector,  # x2
        Vector,  # x3
        Vector,  # x4
    ]:
        """
        Solve the three-equation system for x2, x3 and x4.

        x2 is found numerically with fsolve as the root of ``equations``;
        x3 and x4 are then computed from it with implied_params. The
        convergence status of fsolve is not inspected here.

        :param A: known matrix A
        :param b: known vector b
        :param c1: scalar coefficient of X3 in the first equation
        :param c2: scalar subtracted from X4 in the second equation
        :param c3: scalar multiplying b in the third equation
        :param k: scalar factor in the second equation
        :param x1: known vector X1
        :return: the triple (x2, x3, x4)
        """
        x2 = fsolve(
            func=equations,
            args=(x1, A, b, c1, c2, c3, k),
            # Start every component at 0.5. The dtype is forced to float
            # because full_like otherwise inherits x1's dtype: an integer x1
            # would truncate 0.5 to 0, making sqrt(x2*(1 - x2)) zero and the
            # first residual evaluation divide by zero.
            x0=np.full_like(x1, fill_value=0.5, dtype=float),
        )
        x3, x4 = implied_params(c1=c1, c2=c2, k=k, x1=x1, x2=x2)
        return x2, x3, x4
    
    
    def demo() -> None:
        """
        Solve one seeded random 5-dimensional instance and print the results.

        Each known quantity is printed next to the value recomputed from the
        solution, so the three equations can be checked by eye.
        """
        dim = 5
        rng = np.random.default_rng(seed=0)

        # The draw order (A, then b and x1, then the scalars) determines which
        # numbers each quantity receives from the seeded generator.
        A = rng.random(size=(dim, dim))
        b, x1 = rng.random(size=(2, dim))
        c1, c2, c3, k = rng.random(size=4)

        x2, x3, x4 = solve(A=A, b=b, c1=c1, c2=c2, c3=c3, k=k, x1=x1)

        # Recompute the known/derived quantities from the solution.
        b_check = A@x4/c3
        x1_check = x2 - c1*x3
        x3_check = k*(x4 - c2)*np.sqrt(x2*(1 - x2))

        print('A =')
        print(A)

        print('b =', b)
        print('  = A@x4 / c3')
        print('  =', b_check)

        print(f'c = {c1:.6}, {c2:.6}, {c3:.6}')
        print(f'k = {k:.6}')

        print('x1 =', x1)
        print('   = x2 - c1*x3')
        print('   =', x1_check)

        print('x2 =', x2)

        print('x3 =', x3)
        print('   = k*(x4 - c2)*sqrt(x2(1 - x2))')
        print('   =', x3_check)

        print('x4 =', x4)
    
    
    # Run the demonstration only when this file is executed as a script.
    if __name__ == '__main__':
        demo()
    
    A =
    [[0.63696169 0.26978671 0.04097352 0.01652764 0.81327024]
     [0.91275558 0.60663578 0.72949656 0.54362499 0.93507242]
     [0.81585355 0.0027385  0.85740428 0.03358558 0.72965545]
     [0.17565562 0.86317892 0.54146122 0.29971189 0.42268722]
     [0.02831967 0.12428328 0.67062441 0.64718951 0.61538511]]
    b = [0.38367755 0.99720994 0.98083534 0.68554198 0.65045928]
      = A@x4 / c3
      = [0.38367755 0.99720994 0.98083534 0.68554198 0.65045928]
    c = 0.310242, 0.485835, 0.889488
    k = 0.934044
    x1 = [0.68844673 0.38892142 0.13509651 0.72148834 0.52535432]
       = x2 - c1*x3
       = [0.68844673 0.38892142 0.13509651 0.72148834 0.52535432]
    x2 = [0.63693755 0.34375446 0.1581322  0.63977285 0.49016752]
    x3 = [-0.16602911 -0.14558628  0.07425077 -0.26339285 -0.11341731]
       = k*(x4 - c2)*sqrt(x2(1 - x2))
       = [-0.16602911 -0.14558628  0.07425077 -0.26339285 -0.11341731]
    x4 = [ 0.11619616  0.15766755  0.70370751 -0.10156709  0.24293609]