Some operations in torch are executed in-place, for example shorthand operators like +=.
Is it possible to get in-place execution for other operations, such as softmax?
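By in-place I mean that the result overwrites the tensor's existing storage instead of a new tensor being allocated. One way to check this is to compare data_ptr() before and after the operation (a quick sketch):

import torch

a = torch.ones(5)
ptr = a.data_ptr()

a += 1                       # shorthand operator: writes into the same storage
print(a.data_ptr() == ptr)   # True

a = a + 1                    # regular operator: allocates a new tensor
print(a.data_ptr() == ptr)   # False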
I'm currently working on language processing. The model produces a long sequence of probability distributions over a large vocabulary, and this final output tensor is responsible for about 60% of the allocated memory. That is a huge problem, because I need to calculate a softmax over it, and that doubles the required memory.
Here is an example of the problem. I am not interested in the tensor t, only in its softmax:
import numpy as np
import torch
import torch.nn.functional as F
t = torch.tensor(np.zeros((30000,30000))).cuda() # allocates 6.71 GB of GPU memory
softmax = F.softmax(t, 1) # out of memory error
del t # too late, the program already crashed
Even the following doesn't work:
F.softmax(torch.tensor(np.zeros((30000,30000))).cuda(), 1)
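The doubling itself is easy to see on a smaller tensor by reading torch.cuda.memory_allocated before and after the softmax (a rough sketch, assuming a CUDA device is available; exact numbers vary with allocator rounding):

import torch
import torch.nn.functional as F

t = torch.zeros((10000, 10000), dtype=torch.float64, device="cuda")
print(torch.cuda.memory_allocated() / 1e9)   # ~0.8 GB for the input

softmax = F.softmax(t, 1)                    # allocates a second tensor of the same size
print(torch.cuda.memory_allocated() / 1e9)   # ~1.6 GB: input and output both alive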
I have created an in-place version of softmax:
import numpy as np
import torch
import torch.nn.functional as F
# in-place version
t = torch.tensor(np.ones((100,200)))
torch.exp(t, out=t)                         # exponentiate, writing the result back into t
summed = torch.sum(t, dim=1, keepdim=True)  # row sums, only a (100, 1) tensor
t /= summed                                 # normalize each row in place
# original version
t2 = torch.tensor(np.ones((100,200)))
softmax = F.softmax(t2, 1)
assert torch.allclose(t, softmax)
To answer my question: If you want in-place functions, you have to create them yourself by plugging together low-level operations:
- torch.exp can be given an optional out parameter (an alternative using the trailing-underscore methods is sketched after the example below).
- Indexed assignment t[idx] = something is in-place.
- The shorthand operators /=, *=, +=, -= are in-place.

This requires careful debugging and can be non-intuitive:
t = t / summed #not in-place
t /= summed #in-place
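The same steps can also be written with PyTorch's trailing-underscore in-place methods; exp_ and div_ are the in-place counterparts of exp and division (a minimal sketch, equivalent to the exp/out version above):

import numpy as np
import torch

t = torch.tensor(np.ones((100, 200)))
t.exp_()                                   # in-place exp, like torch.exp(t, out=t)
t.div_(torch.sum(t, dim=1, keepdim=True))  # in-place division by the row sums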
I've read that in-place operations can cause problems with gradient computation, so I'll do some more testing with this code.
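Concretely, autograd saves some tensors for the backward pass and refuses to run backward if one of them has been modified in-place in the meantime; a minimal sketch of that failure mode:

import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()          # autograd saves y, since the gradient of exp(x) is exp(x) = y
y.div_(y.sum())      # in-place change to a tensor that backward still needs
y.sum().backward()   # raises a RuntimeError about an in-place modification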