
Tensorflow: How to get the value of the k-th diagonal


In PyTorch, torch.diag() returns the values on the k-th diagonal of a 2-D tensor.

For example, a.diag(diagonal=1) returns the first superdiagonal. Unfortunately, diag_part() doesn't appear to support this in TensorFlow:

import torch

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
a.diag(diagonal=1)  # first superdiagonal
tensor([2, 6])
a.diag(diagonal=2)  # second superdiagonal
tensor([3])

Is there an equivalent function?


Solution

  • TensorFlow 2 >= v2.2

    You can use tf.linalg.diag_part with its k argument:

    >>> a = tf.reshape(tf.range(1,10),(3,3))
    >>> a
    <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]], dtype=int32)>
    >>> tf.linalg.diag_part(a,k=1)
    <tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 6], dtype=int32)>
    >>> tf.linalg.diag_part(a,k=2)
    <tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>
    

  • TensorFlow 1.x and TensorFlow 2 <= v2.1

    2020-11-26: As of TF 1.15 and TF 2.1, the code in tf.linalg.diag_part that produces superdiagonals and subdiagonals seems to be disabled. As a workaround, you can call matrix_diag_part_v2 directly to get the desired behaviour:

    import tensorflow as tf
    from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
    
    a = tf.reshape(tf.range(1,10),(3,3))
    superdiag = matrix_diag_part_v2(a,k=1,padding_value=0)
    superdiag2 = matrix_diag_part_v2(a,k=2,padding_value=0)
    
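    # tf.Session is the TF 1.x API; under TF 2.0/2.1 eager execution, print the tensors directly instead.
    with tf.Session() as sess: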
        print(f"Matrix A : {sess.run(a)}")
        print(f"Superdiagonal 1 : {sess.run(superdiag)}")
        print(f"Superdiagonal 2 : {sess.run(superdiag2)}")
    

    This prints:

    Matrix A : [[1 2 3]
     [4 5 6]
     [7 8 9]]
    Superdiagonal 1 : [2 6]
    Superdiagonal 2 : [3]
    

    2021-01-08: The bug in TF 1.15 is not considered high priority, and no fix is planned. Source:

    Yes. This is clearly a bug in 1.15. But It's definitely not something significant enough that we'd make a patch release for it, we only do patch releases for major bugs or security fixes.

    2021-01-08: Thanks to Krzysztof for pointing out that the issue found in TF1 also arises in TF versions <= 2.1. The matrix_diag_part_v2 workaround also works for TF 2.0 and TF 2.1.
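
The two cases above can be folded into one helper that picks the call based on the installed version. This is only a minimal sketch: kth_diagonal is a hypothetical name, not part of TensorFlow, and the version cut-off (2.2) is taken from this answer rather than checked against every release. A negative k selects a subdiagonal.

```python
import tensorflow as tf


def kth_diagonal(matrix, k=0):
    """Return the k-th diagonal of `matrix` (hypothetical helper, not a TensorFlow API)."""
    # Compare only the major and minor version numbers, e.g. "2.1.0" -> (2, 1).
    major, minor = (int(part) for part in tf.__version__.split(".")[:2])
    if (major, minor) >= (2, 2):
        # Public API; per the answer above, k is honoured from v2.2 onwards.
        return tf.linalg.diag_part(matrix, k=k)
    # Older releases: fall back to the generated op used in the workaround.
    from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
    return matrix_diag_part_v2(matrix, k=k, padding_value=0)


a = tf.reshape(tf.range(1, 10), (3, 3))
superdiag = kth_diagonal(a, k=1)  # first superdiagonal: values 2 and 6
subdiag = kth_diagonal(a, k=-1)   # first subdiagonal: values 4 and 8
```

Under TF 2 the results are eager tensors and can be printed directly; under TF 1.x they are graph tensors and still need to be evaluated in a session, as in the workaround above.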