In PyTorch, the function torch.diag()
gets the values of the k-th diagonal of a tensor.
For example, a.diag(diagonal=1)
gets the values of the first superdiagonal (offset 1 above the main diagonal). Unfortunately, diag_part()
doesn't appear to support this in TensorFlow. In PyTorch:
>>> import torch
>>> a = torch.tensor([[1, 2, 3],
...                   [4, 5, 6],
...                   [7, 8, 9]])
>>> a.diag(diagonal=1)
tensor([2, 6])
>>> a.diag(diagonal=2)
tensor([3])
Is there an equivalent function?
TensorFlow 2 >= v2.2
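You can use tf.linalg.diag_part with its k argument, which sets the diagonal offset (k > 0 for superdiagonals, k < 0 for subdiagonals):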
>>> a = tf.reshape(tf.range(1,10),(3,3))
>>> a
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=int32)>
>>> tf.linalg.diag_part(a,k=1)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([2, 6], dtype=int32)>
>>> tf.linalg.diag_part(a,k=2)
<tf.Tensor: shape=(1,), dtype=int32, numpy=array([3], dtype=int32)>
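For reference, the indexing rule that both torch.diag(diagonal=k) and tf.linalg.diag_part(k=k) follow can be sketched in plain Python. This is an illustrative helper only: diagonal_reference is a made-up name, and it assumes a rectangular matrix given as nested lists. It can be handy for checking what a given k should return, including negative offsets for subdiagonals:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def diagonal_reference(matrix: Sequence[Sequence[T]], k: int = 0) -> List[T]:
    """Hypothetical helper: return the k-th diagonal of a rectangular nested-list matrix.

    k > 0: superdiagonal, k == 0: main diagonal, k < 0: subdiagonal.
    An offset that falls outside the matrix yields an empty list.
    """
    rows = len(matrix)
    cols = len(matrix[0]) if rows else 0
    # Superdiagonals start at row 0, column k; subdiagonals at row -k, column 0.
    i, j = (0, k) if k >= 0 else (-k, 0)
    out: List[T] = []
    # Walk down and to the right until we leave the matrix.
    while i < rows and j < cols:
        out.append(matrix[i][j])
        i += 1
        j += 1
    return out


a = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
print(diagonal_reference(a, 1))   # [2, 6]
print(diagonal_reference(a, 2))   # [3]
print(diagonal_reference(a, -1))  # [4, 8]
```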
TensorFlow 1.x and TensorFlow 2 <= v2.1
2020-11-26: As of TF 1.15 and TF 2.1, the code in tf.linalg.diag_part
that produces superdiagonals and subdiagonals seems to be disabled. As a workaround, you can call matrix_diag_part_v2
directly to get the desired behaviour:
import tensorflow as tf
# Generated low-level op that, unlike tf.linalg.diag_part in these versions, honours k
from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2

a = tf.reshape(tf.range(1, 10), (3, 3))
superdiag = matrix_diag_part_v2(a, k=1, padding_value=0)   # first superdiagonal
superdiag2 = matrix_diag_part_v2(a, k=2, padding_value=0)  # second superdiagonal

# TF 1.x graph mode: tensors are evaluated inside a Session
with tf.Session() as sess:
    print(f"Matrix A : {sess.run(a)}")
    print(f"Superdiagonal 1 : {sess.run(superdiag)}")
    print(f"Superdiagonal 2 : {sess.run(superdiag2)}")
This prints:
Matrix A : [[1 2 3]
 [4 5 6]
 [7 8 9]]
Superdiagonal 1 : [2 6]
Superdiagonal 2 : [3]
2021-01-08: The bug in TF 1.15 is not high priority, and no fix is planned. Source:
Yes. This is clearly a bug in 1.15. But It's definitely not something significant enough that we'd make a patch release for it, we only do patch releases for major bugs or security fixes.
2021-01-08: Thanks to Krzysztof for pointing out that the same issue found in TF1 also arises in TF versions <= 2.1. The matrix_diag_part_v2
workaround also works for TF 2.1 and TF 2.0.
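If the same code has to run on both sides of the 2.1/2.2 boundary, one option is a small wrapper that picks the call by version. This is a hypothetical sketch: the names needs_matrix_diag_part_v2 and tf_diagonal are made up, and it assumes, as reported above, that tf.linalg.diag_part ignores k up to 2.1 and that matrix_diag_part_v2 is available in 1.15, 2.0 and 2.1. Releases older than 1.15 are not covered, and in TF 1.x the returned tensor still has to be evaluated in a Session.

```python
def needs_matrix_diag_part_v2(version: str) -> bool:
    """Hypothetical helper: True for TF releases <= 2.1, where (per the answer
    above) tf.linalg.diag_part does not honour k."""
    # Only major.minor matter; suffixes such as "-rc1" sit in the patch field.
    major, minor = (int(part) for part in version.split(".")[:2])
    return (major, minor) <= (2, 1)


def tf_diagonal(x, k: int = 0):
    """Hypothetical wrapper returning the k-th diagonal of x on TF 1.15 and 2.x."""
    import tensorflow as tf  # imported lazily so the version check works without TF

    if needs_matrix_diag_part_v2(tf.__version__):
        # Workaround from the answer: call the generated op directly.
        from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
        return matrix_diag_part_v2(x, k=k, padding_value=0)
    # TF >= 2.2: the public API supports k.
    return tf.linalg.diag_part(x, k=k)
```

On TF 2.x in eager mode, tf_diagonal(tf.reshape(tf.range(1, 10), (3, 3)), k=1) should give the same [2, 6] as the examples above.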