I have a scheme where I store a matrix with zeros on the diagonal as a vector. I want to optimize over that vector later on, so I need gradient tracking. My challenge is to reshape between the two forms.
For domain-specific reasons, I want to keep the order of data in the matrix such that transposed elements of the W matrix sit next to each other in the vector form.
The size of the W matrix is subject to change, so I start by enumerating items in the top-left part of the matrix and continue outwards.
I have come up with two ways to do this. See the code snippet below.
import torch

w = torch.tensor([10, 11, 12, 13, 14, 15], requires_grad=True, dtype=torch.float)
# Each row is [matrix_row, matrix_col, vector_index].
i = torch.tensor([
    [0, 1, 0],
    [1, 0, 1],
    [0, 2, 2],
    [2, 0, 3],
    [1, 2, 4],
    [2, 1, 5],
], dtype=torch.long)
v = torch.ones(6)
reshaper = torch.sparse_coo_tensor(i.t(), v, (3, 3, 6)).to_dense()
W_mat_with_reshaper = reshaper @ w
W_mat_directly = torch.tensor([
    [0,    w[0], w[2]],
    [w[1], 0,    w[4]],
    [w[3], w[5], 0],
])
print(W_mat_with_reshaper)
print(W_mat_directly)
and this gives output
tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]], grad_fn=<UnsafeViewBackward>)
tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]])
As you can see, the direct way of reshaping the vector into a matrix does not have a grad function, but the multiply-with-a-reshaper-tensor version does. Creating the reshaper tensor seems like it will be a hassle, but on the other hand, writing the matrix out by hand is also infeasible once the size changes.
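For what it's worth, the index list for the reshaper does not have to be written by hand. Here is a minimal sketch of how I would generate it for any n (the helper name reshaper_indices is my own, not a library function):

def reshaper_indices(n):
    # One row per nonzero entry: [matrix_row, matrix_col, vector_index],
    # keeping each transposed pair adjacent, as in the hard-coded example.
    entries, k = [], 0
    for j in range(1, n):
        for i in range(j):
            entries += [[i, j, k], [j, i, k + 1]]
            k += 2
    return torch.tensor(entries, dtype=torch.long)

idx = reshaper_indices(3)          # reproduces the tensor i above
vals = torch.ones(idx.shape[0])
reshaper = torch.sparse_coo_tensor(idx.t(), vals, (3, 3, 6)).to_dense()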
Is there a way to do arbitrary reshapes in PyTorch that keeps track of gradients?
Instead of constructing W_mat_directly from the elements of w, try assigning w into W_mat_directly:
# Assign w into a zero matrix; the indexing assignment keeps gradient tracking.
# The index pairs follow your ordering, i.e. transposed elements adjacent in w.
W_mat_directly = torch.zeros((3, 3), dtype=w.dtype)
W_mat_directly[(0, 1, 0, 2, 1, 2), (1, 0, 2, 0, 2, 1)] = w
You'll get
tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]], grad_fn=<IndexPutBackward>)
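If W grows, the same assignment generalizes. A short sketch (the helper off_diagonal_indices is mine, not part of PyTorch), reusing the pair-adjacent ordering and checking that gradients actually flow:

def off_diagonal_indices(n):
    # Enumerate off-diagonal positions so that (i, j) and (j, i)
    # end up next to each other, starting from the top-left block.
    rows, cols = [], []
    for j in range(1, n):
        for i in range(j):
            rows += [i, j]
            cols += [j, i]
    return rows, cols

n = 3
W = torch.zeros((n, n), dtype=w.dtype)
W[off_diagonal_indices(n)] = w
W.sum().backward()
print(w.grad)  # tensor([1., 1., 1., 1., 1., 1.]) -- gradients reach w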