
How to batch matrix-vector multiplication (one matrix, many vectors) in PyTorch without duplicating the matrix in memory


I have n vectors of size d and a single d x d matrix J. I'd like to compute the n matrix-vector multiplications of J with each of the n vectors.

For this, I'm using PyTorch's expand() to broadcast J, but it seems that when computing the matrix-vector product, PyTorch instantiates a full n x d x d tensor in memory. For example, the following code

import torch

device = torch.device("cuda:0")
n = 100_000_000
d = 10

x = torch.randn(n, d, dtype=torch.float32, device=device)
J = torch.randn(d, d, dtype=torch.float32, device=device).expand(n, d, d)
y = torch.sign(torch.matmul(J, x[..., None])[..., 0])

raises

RuntimeError: CUDA out of memory. Tried to allocate 37.25 GiB (GPU 0; 11.00 GiB total capacity; 3.73 GiB already allocated; 5.69 GiB free; 3.73 GiB reserved in total by PyTorch)

which means that PyTorch unnecessarily tries to allocate space for n copies of the matrix J (n x d x d float32 values = 100,000,000 x 10 x 10 x 4 bytes, about 37.25 GiB).
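As a rough sanity check of where that figure comes from, the sketch below recomputes the size of a fully materialized n x d x d float32 tensor and inspects what expand() returns. The view size of 1000 and the names J_view and same_storage are illustrative assumptions, not part of the original code. It suggests that expand() by itself is only a stride-0 view, so the large allocation presumably happens when the batched matmul needs a contiguous copy.

```python
import torch

n, d = 100_000_000, 10
bytes_per_float32 = 4

# Size of a fully materialized (n, d, d) float32 tensor, in GiB.
expanded_gib = n * d * d * bytes_per_float32 / 2**30
print(f"{expanded_gib:.2f} GiB")  # 37.25 GiB

# expand() alone only creates a view: the new leading dimension has stride 0
# and the view shares storage with the original (d, d) matrix.
J = torch.randn(d, d)
J_view = J.expand(1000, d, d)  # 1000 is an arbitrary illustrative batch size
print(J_view.stride())  # (0, 10, 1)
same_storage = J_view.data_ptr() == J.data_ptr()
print(same_storage)
```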

How can I perform this task in a vectorized way (the matrices are small, so I don't want to loop over each matrix-vector multiplication) without exhausting my GPU's memory?


Solution

  • I think this will solve it:

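    import torch

    n, d = 1000, 10  # small sizes for illustration (assumed); the question uses n = 100_000_000
    x = torch.randn(n, d)
    J = torch.randn(d, d)  # no need to expand
    
    # (d, d) @ (d, n) -> (d, n); transpose back to (n, d)
    y = torch.matmul(J, x.T).T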
    

    Verifying using your expression:

    Jex = J.expand(n, d, d)
    y1 = torch.matmul(Jex, x[..., None])[..., 0]
    y = torch.matmul(J, x.T).T
    
    torch.allclose(y1, y) # using allclose for float values
    # True (torch.allclose returns a Python bool, not a tensor)
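The transpose trick works because stacking the vectors as rows of x gives (J x^T)^T = x J^T, so the whole batch is one ordinary 2-D matrix product. A minimal sketch of a few equivalent spellings follows, run on CPU with small sizes chosen only for illustration; the names y_ref, y_t, y_mm, y_es, and ok are hypothetical, and the atol value is an assumed tolerance for float32 rounding differences.

```python
import torch

torch.manual_seed(0)
n, d = 1000, 10  # small illustrative sizes, not the sizes from the question

x = torch.randn(n, d)
J = torch.randn(d, d)

# Reference: one explicit matrix-vector product per row of x (slow, for checking only).
y_ref = torch.stack([J @ v for v in x])

# Three equivalent batched forms, none of which builds an (n, d, d) tensor explicitly.
y_t = torch.matmul(J, x.T).T             # form used in the answer
y_mm = x @ J.T                           # same product without the outer transpose
y_es = torch.einsum("ij,nj->ni", J, x)   # y[n, i] = sum_j J[i, j] * x[n, j]

ok = all(torch.allclose(y_ref, y, atol=1e-5) for y in (y_t, y_mm, y_es))
print(ok)

# The question applies sign() to the result; that is elementwise and works as before.
signs = torch.sign(y_mm)
print(signs.shape)
```

On the GPU the same expressions apply once x and J are created with device=device; the output y still needs n x d floats, the same as x.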