Numpy/Scipy: Efficient Determinant of Gram Matrix


I need to compute the (log of the) determinant of the Gram matrix J J^T of a matrix J, and I was wondering whether there is an efficient and numerically stable way to do this in NumPy/SciPy.

import numpy as np
m, n = 100, 150
J = np.random.randn(m, n)
np.log(np.linalg.det(J.dot(J.T)))

Is there some LAPACK routine or mathematical trick I could use to speed this up and make it more stable?


Solution

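  • For better numerical stability, I would suggest using np.linalg.slogdet: it returns the log of the determinant directly, which is what you want anyway, and avoids overflow or underflow in the intermediate determinant. There may also be a very small gain from using np.inner(J, J) instead of J.dot(J.T) (for 2-D arrays both compute J J^T). For a real speedup, I would recommend jax.numpy.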

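    import numpy as np
    import jax
    import jax.numpy as jnp
    
    m, n = 100, 150
    J = np.random.randn(m, n)
    
    def a(J):
        # naive version: det can overflow/underflow before the log is taken
        return np.log(np.linalg.det(J.dot(J.T)))
    
    def b(J):
        # slogdet returns (sign, log|det|); keep only the log part
        return np.linalg.slogdet(np.inner(J, J))[1]
    
    def c(J):
        # same as b, but with jax.numpy (JAX computes in float32 by default
        # unless jax_enable_x64 is enabled)
        return jnp.linalg.slogdet(jnp.inner(J, J))[1]
    
    # JIT-compile; the first call triggers the compilation
    d = jax.jit(c)
    d(J)
    
    # check correctness
    print(np.allclose(a(J), b(J))) # True
    print(np.allclose(a(J), c(J))) # True
    print(np.allclose(a(J), d(J))) # True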
    

    Checking run times on Google Colab:

    %timeit -n 1000 -r 10 a(J)
    # 240 µs ± 16.2 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
    
    %timeit -n 1000 -r 10 b(J)
    # 227 µs ± 10.2 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
    
    J_dev = jax.device_put(J)
    
    %timeit -n 1000 -r 10 c(J_dev).block_until_ready()
    # 112 µs ± 4.46 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
    
    %timeit -n 1000 -r 10 d(J_dev).block_until_ready()
    # 96.2 µs ± 4.23 µs per loop (mean ± std. dev. of 10 runs, 1000 loops each)
    

    So the jitted JAX version gives roughly a 2x speedup over the NumPy versions here.
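If you want to stay in plain NumPy, the structure of the Gram matrix also offers a math trick. Assuming J has full row rank (m <= n, which holds almost surely for a random J like the one above), G = J J^T is symmetric positive definite, so a Cholesky factorization G = L L^T gives log det G = 2 * sum(log(diag(L))). In flop count, Cholesky (LAPACK potrf) needs roughly half the work of the LU factorization behind det/slogdet. Alternatively, a reduced QR factorization J^T = Q R gives J J^T = R^T R, so log det(J J^T) = 2 * sum(log|diag(R)|) without ever forming J J^T, whose condition number is the square of J's. The helper names below are hypothetical, and this is a sketch that was not part of the timings above.

```python
import numpy as np


def logdet_gram_cholesky(J):
    """log det(J @ J.T) via a Cholesky factorization of the Gram matrix.

    Assumes J (m x n) has full row rank, so G = J J^T is symmetric positive
    definite; otherwise np.linalg.cholesky raises LinAlgError.
    """
    G = J @ J.T
    L = np.linalg.cholesky(G)  # lower-triangular, G = L @ L.T
    # det(G) = prod(diag(L))**2, hence log det(G) = 2 * sum(log(diag(L)))
    return 2.0 * np.sum(np.log(np.diag(L)))


def logdet_gram_qr(J):
    """log det(J @ J.T) via a QR factorization of J.T, without forming J @ J.T.

    Assumes m <= n. With J.T = Q R (reduced QR, R is m x m):
    J J^T = R^T Q^T Q R = R^T R, so log det = 2 * sum(log|diag(R)|).
    """
    R = np.linalg.qr(J.T, mode="r")  # only R is returned
    return 2.0 * np.sum(np.log(np.abs(np.diag(R))))


rng = np.random.default_rng(0)
J = rng.standard_normal((100, 150))
ref = np.linalg.slogdet(J @ J.T)[1]
print(np.allclose(ref, logdet_gram_cholesky(J)))  # True
print(np.allclose(ref, logdet_gram_qr(J)))        # True

# For a larger Gram matrix, det itself overflows float64,
# while the log-domain versions stay finite.
J_big = rng.standard_normal((500, 600))
print(np.log(np.linalg.det(J_big @ J_big.T)))  # inf
print(np.isfinite(logdet_gram_qr(J_big)))      # True
```

The QR route is typically somewhat more expensive than forming J J^T and factoring it with Cholesky, but it is the safer choice when J itself is ill-conditioned.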