Tags: python, numpy, pytorch, gradient, numpy-einsum

torch.einsum 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'


I'm trying to calculate the gradient of a matrix manually. I can do it with NumPy, but I don't know how to do the same thing in PyTorch. The equation in NumPy is:

import numpy as np

def grad(A, W0, W1, X):
    dim = A.shape
    assert len(dim) == 2
    A_rows = dim[0]
    A_cols = dim[1]    
    gradient = (np.einsum('ik, jl', np.eye(A_cols, A_rows), (((A).dot(X)).dot(W0)).dot(W1).T) + np.einsum('ik, jl', A, ((X).dot(W0)).dot(W1).T))
    return gradient
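Because no label in 'ik, jl' is repeated and no '->' output is given, einsum runs in implicit mode: nothing is summed, and the output axes are the labels sorted alphabetically. Each term in grad is therefore a four-dimensional outer product rather than a matrix. A minimal sketch with small, arbitrarily chosen arrays (P and Q are illustrative names, not from the question):

```python
import numpy as np

P = np.arange(6).reshape(2, 3)   # subscripts i, k -> shape (2, 3)
Q = np.arange(20).reshape(4, 5)  # subscripts j, l -> shape (4, 5)

# No label repeats, so nothing is summed: every element of P is multiplied
# by every element of Q, and the implicit output labels are sorted: i, j, k, l.
out = np.einsum('ik, jl', P, Q)
print(out.shape)  # (2, 4, 3, 5)

# The same result with the output order written out explicitly.
explicit = np.einsum('ik,jl->ijkl', P, Q)
print(np.array_equal(out, explicit))  # True

# out[i, j, k, l] == P[i, k] * Q[j, l]
print(out[1, 2, 0, 3], P[1, 0] * Q[2, 3])  # 39 39
```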

I wrote a function in PyTorch, but it gives me the error 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'.

The function I wrote using PyTorch is:

import torch

def torch_grad(A, W0, W1, X):
    dim = A.shape
    A_rows = dim[0]
    A_cols = dim[1]
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    print(torch.eye(A_cols, A_rows).shape, torch.t(AXW0W1).shape)
    e1 = torch.einsum('ik jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
    e2 = torch.einsum('ik, jl', A, torch.t(XW0W1))
    return e1 + e2

I would appreciate it if someone could show me how to implement the NumPy code in PyTorch. Thanks!


Solution

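  • You are missing a comma in the first torch.einsum call. Without it, whitespace is ignored and 'ik jl' is read as the subscripts of a single four-dimensional operand, which is why the error reports 4 subscripts in the equation for a 2-D tensor.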

    e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
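A quick way to see the difference the comma makes, using small tensors with arbitrary shapes (eye and other are illustrative names). The exact wording of the error message varies between PyTorch versions, so this sketch only prints it rather than matching it:

```python
import torch

eye = torch.eye(3, 4)      # 2-D operand, shape (3, 4)
other = torch.randn(5, 6)  # 2-D operand, shape (5, 6)

# Without the comma, whitespace is ignored and 'ik jl' describes ONE
# operand with four subscripts, yet two 2-D tensors are passed.
caught = None
try:
    torch.einsum('ik jl', eye, other)
except RuntimeError as err:
    caught = err
    print("RuntimeError:", err)

# With the comma there are two 2-D operands; in implicit mode the result
# is their outer product, with output labels sorted as i, j, k, l.
out = torch.einsum('ik, jl', eye, other)
print(out.shape)  # torch.Size([3, 5, 4, 6])
```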
    

    Besides the typo, that is not how the gradient with respect to A is calculated, and the approach fails when A and AX have different shapes, even though such shapes are perfectly valid for the forward pass. e1 should be a matrix multiplication. If that was your intention, the torch.einsum equation would be 'ik, kl' (the repeated index k is summed over), but that is just an overly complicated way to write a matrix multiplication, and torch.mm is simpler and more efficient. e2 is not part of the gradient at all: XW0W1 is computed without A, so it does not depend on A and contributes nothing to the gradient with respect to A.
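Both points can be checked directly. This sketch reuses the answer's example shapes: with a repeated k, 'ik, kl' reduces to a matrix product, and a tensor built without A carries no autograd history. The comparison uses torch.allclose because einsum and torch.mm are not guaranteed to sum in the same order:

```python
import torch

torch.manual_seed(0)
# Same example shapes as the answer's code.
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)
W0W1 = torch.mm(W0, W1)

# AXW0W1 is connected to A through the autograd graph.
AXW0W1 = torch.mm(torch.mm(A, X), W0W1)
print(AXW0W1.requires_grad)  # True

# 'ik, kl': k appears in both operands, so it is summed over and the
# implicit output is 'il', i.e. an ordinary matrix product.
rows, cols = AXW0W1.size()
upstream = torch.eye(rows, cols)
via_einsum = torch.einsum('ik, kl', upstream, W0W1.t())
via_mm = torch.mm(upstream, W0W1.t())
print(torch.allclose(via_einsum, via_mm, atol=1e-5))  # True

# XW0W1 is built only from tensors that do not require grad,
# so autograd records no history for it: it cannot depend on A.
XW0W1 = torch.mm(X, W0W1)
print(XW0W1.requires_grad, XW0W1.grad_fn)  # False None
```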

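    import torch
    
    def torch_grad(A, W0, W1, X):
        # Forward
        W0W1 = torch.mm(W0, W1)
        AX = torch.mm(A, X)
        AXW0W1 = torch.mm(AX, W0W1)
        XW0W1 = torch.mm(X, W0W1)  # does not depend on A, so it plays no part in the gradient
    
        # Backward / Gradients
        # The upstream gradient dL/d(AXW0W1) is taken to be torch.eye(rows, cols)
        rows, cols = AXW0W1.size()
        # AXW0W1 = AX @ W0W1  =>  dL/d(AX) = upstream @ W0W1^T
        grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
        # AX = A @ X  =>  dL/dA = dL/d(AX) @ X^T
        grad_A = torch.mm(grad_AX, X.t())
        return grad_A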
    
    # Autograd version to verify that the gradients are correct
    def torch_autograd(A, W0, W1, X):
        # Forward
        W0W1 = torch.mm(W0, W1)
        AX = torch.mm(A, X)
        AXW0W1 = torch.mm(AX, W0W1)
        XW0W1 = torch.mm(X, W0W1)  # not needed for the gradient with respect to A
        
        # Backward / Gradients
        rows, cols = AXW0W1.size()
        AXW0W1.backward(torch.eye(rows, cols))
        return A.grad
    
    # requires_grad=True for the autograd version to track
    # gradients with respect to A
    A = torch.randn(3, 4, requires_grad=True)
    X = torch.randn(4, 5)
    W0 = torch.randn(5, 6)
    W1 = torch.randn(6, 5)
    
    grad_result = torch_grad(A, W0, W1, X)
    autograd_result = torch_autograd(A, W0, W1, X)
    
    torch.equal(grad_result, autograd_result) # => True
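The identity passed to backward acts as the upstream gradient G = dL/d(AXW0W1). By the chain rule, for any G with the shape of AXW0W1 the gradient is dL/dA = G @ (W0W1)^T @ X^T, which is what torch_grad computes for G = torch.eye(rows, cols). A sketch that checks this for a random G against torch.autograd.grad; manual_grad_A is a hypothetical helper name, and float64 with torch.allclose is used because the two paths multiply the matrices in a different order:

```python
import torch

def manual_grad_A(G, W0, W1, X):
    # Hypothetical helper: gradient with respect to A of a loss whose
    # upstream gradient at Y = A @ X @ W0 @ W1 is G (G must have Y's shape).
    W0W1 = W0 @ W1
    grad_AX = G @ W0W1.t()   # Y = AX @ W0W1  ->  dL/d(AX) = G @ W0W1^T
    return grad_AX @ X.t()   # AX = A @ X     ->  dL/dA = dL/d(AX) @ X^T

torch.manual_seed(0)
dtype = torch.float64
A = torch.randn(3, 4, dtype=dtype, requires_grad=True)
X = torch.randn(4, 5, dtype=dtype)
W0 = torch.randn(5, 6, dtype=dtype)
W1 = torch.randn(6, 5, dtype=dtype)

Y = A @ X @ W0 @ W1        # forward pass, shape (3, 5)
G = torch.randn_like(Y)    # arbitrary upstream gradient, not just an identity

# torch.autograd.grad returns a tuple with one entry per input.
(auto_grad_A,) = torch.autograd.grad(Y, A, grad_outputs=G)
manual = manual_grad_A(G, W0, W1, X)

print(manual.shape)                          # torch.Size([3, 4])
print(torch.allclose(manual, auto_grad_A))   # True
```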