Solving a simultaneous equation through code


This seems like an incredibly simple and silly question to ask, but everything I've found about it has been too complex for me to understand.

I have these two very basic simultaneous equations:

X = 2x + 2z
Y = z - x

Given that I know both X and Y, how would I go about finding x and z? It's very easy to do it by hand, but I have no idea how this would be done in code.


Solution

  • This seems like an incredibly simple and silly question to ask

    Not at all. This is a very good question, and unfortunately the answer is more involved than it looks. Let's solve the general 2x2 system

    a * x + b * y = u
    c * x + d * y = v
    

    I stick to the 2x2 case here. More complex cases will require you to use a library.

    The first thing to note is that Cramer's rule is not a good choice numerically. When you compute the determinant

    a * d - b * c
    

    as soon as a * d is close to b * c, the subtraction suffers from catastrophic cancellation: the leading digits cancel and what remains is mostly rounding error. This case is typical, and you must guard against it.
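To make the cancellation concrete, here is a small sketch with coefficients picked by hand (they are an assumption for illustration, not taken from the question). The determinant is evaluated once in double precision and once exactly with `fractions.Fraction`:

```python
from fractions import Fraction

# Coefficients chosen so that a*d and b*c agree to roughly 16 digits.
a = 1.0 + 2.0**-27
b = 1.0
c = 1.0 + 2.0**-26
d = 1.0 + 2.0**-27

# Determinant in double precision: a*d rounds to 1 + 2**-26, the same as b*c.
naive = a * d - b * c

# Same determinant in exact rational arithmetic (Fraction(float) is exact).
exact = Fraction(a) * Fraction(d) - Fraction(b) * Fraction(c)

print(naive)         # 0.0
print(float(exact))  # 2**-54, about 5.55e-17
```

The floating-point result says the matrix is singular, while the exact determinant is small but nonzero, so dividing by the computed determinant would fail here.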

    The best tradeoff between simplicity and stability is partial pivoting. Suppose that |a| > |c| and, for the derivation, that c is nonzero. Multiplying the first row by c/a, the system is equivalent to

    a * c/a * x + bc/a * y = uc/a
          c * x +    d * y = v 
    

    which is

    cx + bc/a * y = uc/a
    cx +       dy = v  
    

    and now subtracting the first equation from the second yields

    cx +       bc/a * y = uc/a
         (d - bc/a) * y = v - uc/a
    

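    which is now straightforward to solve: y = (v - uc/a) / (d - bc/a), and then x = (u - b * y) / a from the original first row. This is the same value as (uc/a - bc/a * y) / c, but it avoids dividing by c, which may be tiny or zero. Computing d - bc/a is more stable than computing ad - bc because we divide by the larger of the two coefficients, so the multiplier c/a never exceeds 1 in magnitude. This is not very obvious, but it holds: do the computation with very close coefficients and you'll see why it works.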

    If instead |c| >= |a|, you just swap the two rows and proceed in the same way.

    In code (please check the Python syntax):

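    def solve(a, b, c, d, u, v):
        """Solve a*x + b*y = u, c*x + d*y = v and return (x, y)."""
        if abs(a) > abs(c):
            # Pivot on a: scale row 1 by c/a and subtract it from row 2.
            f = u * c / a
            g = b * c / a
            y = (v - f) / (d - g)
            # Back-substitute into row 1; this also works when c == 0.
            return ((u - b * y) / a, y)
        else:
            # Pivot on c: scale row 2 by a/c and subtract it from row 1.
            f = v * a / c
            g = d * a / c
            y = (u - f) / (b - g)
            # Back-substitute into row 2; this also works when a == 0.
            return ((v - d * y) / c, y)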
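
For the specific system in the question (X = 2x + 2z, Y = z - x) the coefficients are fixed, so the elimination can be done once by hand and hard-coded. A minimal sketch, where the name `solve_question` is made up for illustration: substituting z = Y + x into the first equation gives X = 4x + 2Y.

```python
def solve_question(X, Y):
    """Solve X = 2x + 2z, Y = z - x for (x, z).

    Hypothetical helper for the question's system only.
    Substituting z = Y + x into X = 2x + 2z gives X = 4x + 2Y.
    """
    x = (X - 2 * Y) / 4
    z = Y + x
    return x, z


# Round-trip check: pick x and z, build X and Y, then recover x and z.
x, z = 3.0, 5.0
X, Y = 2 * x + 2 * z, z - x
print(solve_question(X, Y))  # (3.0, 5.0)
```

The determinant of this system is 2 * 1 - 2 * (-1) = 4, far from zero, so the cancellation problem described above does not arise for it.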
    

    You can use full pivoting (requires you to swap x and y so that the first division is always by the largest coefficient), but this is more cumbersome to write, and almost never required for the 2x2 case.
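
As a rough sketch of what full pivoting could look like for the 2x2 case (the function name and the row-list layout are assumptions for illustration): find the largest of the four coefficients, move it to the top-left by a row swap and, if needed, a column swap, then eliminate as before and undo the column swap on the result.

```python
def solve_full_pivot(a, b, c, d, u, v):
    """Solve a*x + b*y = u, c*x + d*y = v with full pivoting; return (x, y).

    Hypothetical sketch: rows and columns are swapped so that the
    largest-magnitude coefficient ends up in the top-left position.
    """
    rows = [[a, b, u], [c, d, v]]

    # Find the position (row r, column k) of the largest coefficient.
    r, k = max(((i, j) for i in range(2) for j in range(2)),
               key=lambda ij: abs(rows[ij[0]][ij[1]]))

    # Row swap: reorders the equations, the unknowns stay in place.
    if r == 1:
        rows.reverse()
    # Column swap: exchanges the roles of x and y, undone at the end.
    if k == 1:
        for row in rows:
            row[0], row[1] = row[1], row[0]

    (p, q, s), (e, f, t) = rows   # p is now the pivot
    m = e / p                     # |m| <= 1 because |p| is the largest
    second = (t - m * s) / (f - m * q)
    first = (s - q * second) / p

    # Undo the column swap so the result is always (x, y).
    return (second, first) if k == 1 else (first, second)


print(solve_full_pivot(1, 2, 3, 4, 5, 11))  # (1.0, 2.0)
```

The extra bookkeeping for the column swap is what makes this more cumbersome than partial pivoting.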

    For the n x n case, all the pivoting logic is encapsulated in the LU decomposition, and you should use a library for this.
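
To show roughly what such a library routine does internally, here is an illustrative pure-Python Gaussian elimination with partial pivoting. It is a teaching sketch with a made-up name, not a replacement for a library: in real code a tested routine such as `numpy.linalg.solve` is the better choice.

```python
def solve_linear(A, rhs):
    """Solve A x = rhs by Gaussian elimination with partial pivoting.

    Illustrative only (hypothetical helper): A is a list of n rows of
    n numbers, rhs a list of n numbers. The inputs are not modified.
    """
    n = len(A)
    # Work on an augmented copy [A | rhs].
    M = [list(row) + [r] for row, r in zip(A, rhs)]

    for col in range(n):
        # Partial pivoting: bring the largest entry of this column
        # (at or below the diagonal) onto the diagonal.
        pivot = max(range(col, n), key=lambda i: abs(M[i][col]))
        if M[pivot][col] == 0:
            raise ValueError("matrix is singular")
        M[col], M[pivot] = M[pivot], M[col]

        # Eliminate the column below the diagonal.
        for i in range(col + 1, n):
            m = M[i][col] / M[col][col]
            for j in range(col, n + 1):
                M[i][j] -= m * M[col][j]

    # Back substitution, from the last unknown to the first.
    x = [0.0] * n
    for i in reversed(range(n)):
        s = sum(M[i][j] * x[j] for j in range(i + 1, n))
        x[i] = (M[i][n] - s) / M[i][i]
    return x


# The question's system with X = 16, Y = 2: 2x + 2z = 16, -x + z = 2.
print(solve_linear([[2, 2], [-1, 1]], [16, 2]))  # [3.0, 5.0]
```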