Fast(est) exponentiation of numpy 3D matrix


Q is a 3D matrix and could for example have the following shape:

(4000, 25, 25)

I want to raise Q to the power n for each n in {0, 1, ..., k-1} and sum the results. Basically, I want to calculate

\sum_{n=0}^{k-1} Q^n

I have the following function that works as expected:

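import numpy as np

def sum_of_powers(Q: np.ndarray, k: int) -> np.ndarray:
    # matrix_power acts on the last two axes, so each 25x25 matrix is raised separately.
    Qs = np.sum([
        np.linalg.matrix_power(Q, n) for n in range(k)
    ], axis=0)

    return Qs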

Is it possible to speed up my function or is there a faster method to obtain the same output?


Solution

  • We can perform this calculation in O(log k) matrix operations.

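    Let M(k) represent the k'th power of the input, and S(k) represent the sum of those powers from 0 to k inclusive, so S(k) = M(0) + M(1) + ... + M(k). Let I represent an appropriate identity matrix. Note that S(k) has one more term than the question's sum_of_powers(Q, k), which stops at M(k-1); to reproduce that function for k >= 1, call the functions below with k - 1.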

    Approach 1

    If you expand the product, you'll find that (M(1) - I) * S(k) = M(k+1) - I, because the sum telescopes. This only determines S(k) when M(1) - I is invertible, that is, when no matrix in Q has 1 as an eigenvalue. In that case we can compute M(k+1) using a standard matrix power (which takes O(log k) matrix multiplications), and compute S(k) by using numpy.linalg.solve to solve the equation (M(1) - I) * S(k) = M(k+1) - I:

    import numpy.linalg
    
    def option1(Q, k):
        identity = numpy.eye(Q.shape[-1])
        A = Q - identity
        B = numpy.linalg.matrix_power(Q, k+1) - identity
        return numpy.linalg.solve(A, B)
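
As a sanity check on the identity used in Approach 1, the following sketch builds S(k) term by term for a small random stack and confirms both the identity and the result of the solve. The helper name `brute_force_sum`, the shapes, and the 0.1 scaling factor are illustrative assumptions, not part of the original answer.

```python
import numpy as np

def brute_force_sum(Q, k):
    """Reference S(k) = I + Q + ... + Q^k, term by term (hypothetical helper)."""
    term = np.broadcast_to(np.eye(Q.shape[-1]), Q.shape).copy()  # Q^0 for every matrix
    total = term.copy()
    for _ in range(k):
        term = term @ Q   # next power of Q
        total += term
    return total

rng = np.random.default_rng(0)
# Small stack of matrices, scaled down so eigenvalues stay far from 1 (assumption).
Q = 0.1 * rng.standard_normal((5, 4, 4))
k = 6
identity = np.eye(Q.shape[-1])

S = brute_force_sum(Q, k)
lhs = (Q - identity) @ S
rhs = np.linalg.matrix_power(Q, k + 1) - identity
identity_holds = np.allclose(lhs, rhs)

# Solving the same equation recovers S(k), provided Q - I is invertible.
solved = np.linalg.solve(Q - identity, rhs)
solve_matches = np.allclose(solved, S)
```

If any matrix in the stack has an eigenvalue equal to 1 (Q = I is the simplest case), Q - I is singular and the solve cannot recover S(k); Approach 2 does not have this restriction.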
    

    Approach 2

    The standard exponentiation by squaring algorithm computes M(2*k) as M(k)*M(k) and M(2*k+1) as M(2*k)*M(1).

    We can alter the algorithm to track both S(k-1) and M(k), by computing S(2*k-1) as S(k-1)*M(k) + S(k-1) and S(2*k) as S(2*k-1) + M(2*k):

    import numpy
    
    def option2(Q, k):
        identity = numpy.eye(Q.shape[-1])
    
        if k == 0:
            res = numpy.empty_like(Q)
            res[:] = identity
            return res
    
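        # Invariant: for some m, power == M(m) and sum_of_powers == S(m-1); start at m = 1.
        power = Q
        sum_of_powers = identity
    
        # Looping over a string might look dumb, but it's actually the most efficient option,
        # as well as the simplest. (It wouldn't be the bottleneck even if it wasn't efficient.)
        # bin(k+1)[3:] drops the "0b" prefix and the leading 1 bit.
        for bit in bin(k+1)[3:]:
            # Doubling step: m -> 2*m.
            sum_of_powers = (sum_of_powers @ power) + sum_of_powers
            power = power @ power
            if bit == "1":
                # Increment step: m -> m + 1.
                sum_of_powers += power
                power = power @ Q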
        return sum_of_powers
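
To see why the loop in Approach 2 ends with S(k), it can help to track only the index m for which the loop holds the pair (S(m-1), M(m)). The sketch below does that with plain integers and, as a cross-check, replays the same updates on scalars, where S(k) is an ordinary geometric sum. The function names `trace_indices` and `scalar_sum_of_powers` are illustrative and not part of the original answer.

```python
def trace_indices(k):
    """Return the successive values of m, where the loop state is (S(m-1), M(m))."""
    m = 1                        # initial state: S(0) = I, M(1) = Q
    seen = [m]
    for bit in bin(k + 1)[3:]:   # bits of k+1 after the leading 1
        m = 2 * m                # doubling step: state becomes (S(2m-1), M(2m))
        if bit == "1":
            m += 1               # increment step: state becomes (S(2m), M(2m+1))
        seen.append(m)
    return seen

def scalar_sum_of_powers(q, k):
    """Same updates as the matrix version, applied to a scalar q."""
    total, power = 1, q
    for bit in bin(k + 1)[3:]:
        total = total * power + total
        power = power * power
        if bit == "1":
            total += power
            power = power * q
    return total

print(trace_indices(10))            # indices visited for k = 10
print(scalar_sum_of_powers(2, 10))  # compare with 2**11 - 1
```

The index always finishes at m = k + 1, so the returned sum is S(k), again one term more than the question's function for the same k.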