Tags: numpy, tensorflow, gradient, linear-algebra, determinants

Calculate log of determinant in TensorFlow when determinant overflows/underflows


My cost function involves the calculation of log(det(A)) (assuming det(A) is positive, so the log makes sense, but A is not Hermitian, so the Cholesky decomposition is not applicable here). When det(A) is very large or very small, a direct call to det(A) will overflow or underflow. To circumvent this, one can use the mathematical fact that

log(det(A)) = Tr(log(A)),

where the latter can be evaluated using an LU decomposition (which is more efficient than eigenvalue/SVD-based methods). This algorithm has been implemented in numpy as numpy.linalg.slogdet, so the problem reduces to how to call numpy from TensorFlow.
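
For example, the determinant of even a moderately sized, well-scaled matrix can overflow float64, while slogdet stays finite (a quick numpy-only illustration):

import numpy as np

a = np.eye(500) * 10.    # det(a) = 10^500, far beyond float64 range
print(np.linalg.det(a))  # inf -- overflow

sign, logabsdet = np.linalg.slogdet(a)
print(sign, logabsdet)   # 1.0 and 500*log(10), about 1151.29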


Here is what I tried:

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import function

def logdet_np(a):
    # numpy computes log(abs(det(a))) robustly via an LU factorization
    _, l = np.linalg.slogdet(a)
    return l

def logdet1(a):
    # plain py_func wrapper
    return tf.py_func(logdet_np, [a], tf.float64)

@function.Defun(tf.float64, func_name='LogDet')
def logdet2(a):
    # the same wrapper, placed inside a Defun
    return tf.py_func(logdet_np, [a], tf.float64)

with tf.Session() as sess:
    a = tf.constant(np.eye(500)*10.)
    #print(sess.run(logdet1(a)))
    print(sess.run(logdet2(a)))

I first define a python function, logdet_np, that passes out the numpy result. I then define two logdet functions using tf.py_func. The second one is decorated with function.Defun, which is meant to be used later to define a TensorFlow function together with its gradient. When I tested them, I found that the first function, logdet1, works and gives the correct result, but the second function, logdet2, raises a KeyError:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/script_ops.py in __call__(self, token, args)
     77   def __call__(self, token, args):
     78     """Calls the registered function for `token` with args."""
---> 79     func = self._funcs[token]
     80     if func is None:
     81       raise ValueError("callback %s is not found" % token)

KeyError: 'pyfunc_0'

My question is: what is wrong with the Defun decorator? Why does it conflict with py_func? And how can I wrap numpy functions in TensorFlow correctly?


The remaining part, defining the gradient for logdet, is related to the question matrix determinant differentiation in tensorflow. Following the solution in that question, one can attempt to write:

@function.Defun(tf.float64, tf.float64, func_name='LogDet_Gradient')
def logdet_grad(a, grad):
    a_adj_inv = tf.matrix_inverse(a, adjoint=True)
    out_shape = tf.concat([tf.shape(a)[:-2], [1, 1]], axis=0)
    return tf.reshape(grad, out_shape) * a_adj_inv
@function.Defun(tf.float64, func_name='LogDet', grad_func=logdet_grad)
def logdet(a):
    return tf.py_func(logdet_np, [a], tf.float64, stateful=False, name='LogDet')

The above code would work if one could resolve the conflict between Defun and py_func, which is the key question raised above.
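
Independently of the TensorFlow plumbing, the gradient formula used above, grad log(det(A)) = inv(A)^T, can be sanity-checked against central finite differences in plain numpy (a minimal sketch with an arbitrary well-conditioned test matrix):

import numpy as np

rng = np.random.RandomState(0)
a = rng.rand(4, 4) + 4. * np.eye(4)  # diagonally dominant, so det(a) > 0

analytic = np.linalg.inv(a).T        # the claimed gradient of log(det(a))

eps = 1e-6
numeric = np.zeros_like(a)
for i in range(4):
    for j in range(4):
        da = np.zeros_like(a)
        da[i, j] = eps
        numeric[i, j] = (np.linalg.slogdet(a + da)[1]
                         - np.linalg.slogdet(a - da)[1]) / (2. * eps)

print(np.max(np.abs(numeric - analytic)))  # agreement to ~1e-10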


Solution

  • With the help of @MaxB, I post here the code that defines the function logdet for log(abs(det(A))) together with its gradient.

    • logdet calls the numpy function numpy.linalg.slogdet to compute the log of the determinant, using the identity log(det(A)) = Tr(log(A)), which is robust against overflow/underflow of the determinant. It is based on the LU decomposition, which is more efficient than eigenvalue/SVD-based methods.

    • The numpy function slogdet returns a tuple containing both the sign of the determinant and log(abs(det(A))). The sign is neglected here, since it does not contribute to the gradient signal in the optimization.

    • The gradient of logdet is computed by matrix inversion, according to grad log(det(A)) = inv(A)^T. It is based on TensorFlow's code for _MatrixDeterminantGrad, with slight modifications.


    import numpy as np
    import tensorflow as tf
    # from https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342
    # Define custom py_func which takes also a grad op as argument:
    def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
        # Need to generate a unique name to avoid duplicates:
        rnd_name = 'PyFuncGrad' + str(np.random.randint(0, int(1e8)))
        tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
        g = tf.get_default_graph()
        with g.gradient_override_map({"PyFunc": rnd_name}):
            return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
    # from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py
    # Gradient for logdet
    def logdet_grad(op, grad):
        a = op.inputs[0]
        a_adj_inv = tf.matrix_inverse(a, adjoint=True)
        out_shape = tf.concat([tf.shape(a)[:-2], [1, 1]], axis=0)
        return tf.reshape(grad, out_shape) * a_adj_inv
    # define logdet by calling numpy.linalg.slogdet
    def logdet(a, name=None):
        with tf.name_scope(name, 'LogDet', [a]) as name:
            res = py_func(lambda a: np.linalg.slogdet(a)[1], 
                          [a], 
                          tf.float64, 
                          name=name, 
                          grad=logdet_grad) # set the gradient
            return res
    

    One can test that logdet works for very large or very small determinants, and that its gradient is also correct:

    i = tf.constant(np.eye(500))
    x = tf.Variable(np.array([10.]))
    y = logdet(x*i)
    dy = tf.gradients(y, [x])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([y, dy]))
    

    Result: [1151.2925464970251, [array([ 50.])]]
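
    Since logdet_grad builds out_shape from tf.shape(a)[:-2], the same wrapper should also work on stacked matrices of shape [..., N, N], because np.linalg.slogdet and tf.matrix_inverse both broadcast over leading batch dimensions. A small sketch, with values chosen to be easy to check by hand:

    b = tf.constant(np.stack([np.eye(3) * 2., np.eye(3) * 0.5]))
    yb = logdet(b)  # one log-determinant per matrix in the batch
    with tf.Session() as sess:
        print(sess.run(yb))  # [3*log(2), -3*log(2)] ~ [2.079, -2.079]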