Tags: python, machine-learning, pytorch, torch

torch in-place operations to save memory (softmax)


Some operations in torch are executed in-place, for example shorthand operators like +=.

Is it possible to get in-place execution for other operations, such as softmax?

I'm currently working on language processing. The model produces a long sequence of probability distributions over a large vocabulary. This final output tensor is responsible for about 60% of the allocated memory, which is a huge problem because I need to calculate a softmax over it, and that doubles the required memory.

Here is an example of the problem. I am not interested in the tensor t, only in its softmax:

import numpy as np
import torch
import torch.nn.functional as F

t = torch.tensor(np.zeros((30000,30000))).cuda()  # allocates 6.71 GB of GPU memory (float64)
softmax = F.softmax(t, 1)  # out of memory error
del t  # too late, the program has already crashed

Even the following doesn't work:

F.softmax(torch.tensor(np.zeros((30000,30000))).cuda(), 1)
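
For reference, torch.cuda.memory_allocated can show how much the raw tensor alone takes up before the softmax is even attempted (a quick check, assuming a CUDA device):

import numpy as np
import torch

t = torch.tensor(np.zeros((30000, 30000))).cuda()
# currently allocated GPU memory in GiB; the softmax output would need
# roughly the same amount again
print(torch.cuda.memory_allocated() / 1024**3)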

Solution

  • I have created an in-place version of softmax:

    import numpy as np
    import torch
    import torch.nn.functional as F
    
    # in-place version
    t = torch.tensor(np.ones((100,200)))
    torch.exp(t, out=t)                         # exponentiate in place
    summed = torch.sum(t, dim=1, keepdim=True)  # row sums (small temporary)
    t /= summed                                 # normalize in place
    
    # original version
    t2 = torch.tensor(np.ones((100,200)))
    softmax = F.softmax(t2, 1)
    
    assert torch.allclose(t, softmax)
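
    One caveat: exponentiating the raw values directly can overflow for large logits. A possible numerically stabilized variant, still built only from in-place operations (subtract the per-row max before exponentiating), is sketched below:

    import numpy as np
    import torch

    t = torch.tensor(np.random.randn(100, 200))
    t -= t.max(dim=1, keepdim=True).values  # shift each row by its max, in place
    torch.exp(t, out=t)                     # exponentiate in place
    t /= t.sum(dim=1, keepdim=True)         # normalize in place

    The row maxima and sums are small (100, 1) temporaries, so the big (100, 200) tensor is never copied.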
    

    To answer my own question: if you want in-place functions, you have to build them yourself by plugging together low-level operations:

    • Many functions such as torch.exp can be given an optional out parameter.
    • Assignments like t[idx] = something are in-place.
    • Shorthand operators /=, *=, +=, -= are in-place (a quick check of all three primitives follows below).
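
    A quick way to verify that these primitives really reuse the same storage is to compare data_ptr() before and after (a small illustrative check):

    import torch

    t = torch.zeros(4)
    ptr = t.data_ptr()

    torch.exp(t, out=t)  # out= writes into the existing storage
    t[0] = 1.0           # index assignment modifies the tensor in place
    t += 1               # shorthand operators work in place
    t /= 2

    # the storage address never changed, so no new tensor was allocated
    assert t.data_ptr() == ptr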

    This requires careful debugging and can be non-intuitive:

    t = t / summed  # not in-place: allocates a new tensor
    t /= summed     # in-place: reuses the existing storage
    

    I've read that in-place operations can produce problems with gradients. I'll do some more testing with this code.
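
    As an illustration of that gradient problem (a minimal sketch, separate from the softmax code above): torch.exp saves its output for the backward pass, so modifying that output in place makes backward() fail.

    import torch

    x = torch.randn(3, requires_grad=True)
    y = torch.exp(x)  # exp's backward pass needs y itself
    y += 1            # in-place update bumps y's version counter
    try:
        y.sum().backward()
    except RuntimeError as e:
        # PyTorch reports that a variable needed for gradient computation
        # has been modified by an in-place operation
        print(e)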