I'm trying to manually calculate the gradient of a matrix, and I can do it using NumPy, but I don't know how to do the same thing in PyTorch. The equation in NumPy is:
import numpy as np

def grad(A, W0, W1, X):
    dim = A.shape
    assert len(dim) == 2
    A_rows = dim[0]
    A_cols = dim[1]
    gradient = (np.einsum('ik, jl', np.eye(A_cols, A_rows), A.dot(X).dot(W0).dot(W1).T)
                + np.einsum('ik, jl', A, X.dot(W0).dot(W1).T))
    return gradient
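This works for me as long as A is square (as far as I can tell, the two einsum terms only have matching shapes in that case). For example, with placeholder shapes:

# Placeholder shapes just to illustrate the call
A = np.random.randn(3, 3)
X = np.random.randn(3, 4)
W0 = np.random.randn(4, 5)
W1 = np.random.randn(5, 6)

print(grad(A, W0, W1, X).shape)  # (3, 6, 3, 3), a 4-D gradient array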
I wrote a function in PyTorch, but it's giving me an error: RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2. The function I wrote using PyTorch is:
import torch

def torch_grad(A, W0, W1, X):
    dim = A.shape
    A_rows = dim[0]
    A_cols = dim[1]
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    print(torch.eye(A_cols, A_rows).shape, torch.t(AXW0W1).shape)
    e1 = torch.einsum('ik jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
    e2 = torch.einsum('ik, jl', A, torch.t(XW0W1))
    return e1 + e2
I would appreciate it if someone could show me how to implement the NumPy code in PyTorch. Thanks!
You are missing a comma in the first torch.einsum call:
e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
Besides the typo, that's not how the gradient with respect to A is calculated, and it fails when A and AX have different sizes, which would otherwise be valid for the forward pass. For e1, that should be a matrix multiplication; maybe that was your intention, in which case the torch.einsum equation should be 'ik, kl', but that's just an overly complicated way to perform a matrix multiplication, and using torch.mm is simpler and more efficient. And e2 is not involved in any calculation that was performed with respect to A, therefore it is not part of the gradient.
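To make that concrete, here is a quick check with throwaway tensors (the names are placeholders) showing that the 'ik, kl' einsum is just an ordinary matrix multiplication:

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
# 'ik, kl' contracts over the shared index k, which is exactly what torch.mm does
print(torch.allclose(torch.einsum('ik, kl', a, b), torch.mm(a, b)))  # True

With that out of the way, the gradient with respect to A can be calculated like this, together with an autograd version to verify it: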
def torch_grad(A, W0, W1, X):
    # Forward
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    # Backward / Gradients
    rows, cols = AXW0W1.size()
    grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
    grad_A = torch.mm(grad_AX, X.t())
    return grad_A
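# For reference: with Y = A @ X @ W0 @ W1 and the identity matrix as the upstream
# gradient G (the same seed used for the autograd version below), the chain rule gives
#   grad_A = G @ (X @ W0 @ W1).T = (G @ (W0 @ W1).T) @ X.T
# which is exactly what torch_grad computes above.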
# Autograd version to verify that the gradients are correct
def torch_autograd(A, W0, W1, X):
    # Forward
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    # Backward / Gradients
    rows, cols = AXW0W1.size()
    AXW0W1.backward(torch.eye(rows, cols))
    return A.grad
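# Note: backward() accumulates into A.grad, so if torch_autograd is called more
# than once, reset the gradient first (e.g. A.grad = None) before the next call.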
# requires_grad=True for the autograd version to track
# gradients with respect to A
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)
grad_result = torch_grad(A, W0, W1, X)
autograd_result = torch_autograd(A, W0, W1, X)
torch.equal(grad_result, autograd_result) # => True
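One caveat: torch.equal requires exact element-wise equality, which appears to hold here because both versions perform the same matrix multiplications. If the comparison ever fails due to floating point round-off, the more tolerant check is:

torch.allclose(grad_result, autograd_result)  # True, up to a small tolerance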