Specifically, I would like nn.LogSoftMax to not use OMP when the input tensor is small. I have a small script to test the run time.
require 'nn'

-- naive log-softmax over a 1D tensor, computed element by element in plain Lua
local my_lsm = function(t)
  local o = torch.zeros((#t)[1])
  local sum = 0.0
  for i = 1, (#t)[1] do
    o[i] = torch.exp(t[i])
    sum = sum + o[i]
  end
  o = o / sum            -- softmax
  return torch.log(o)    -- log of the softmax
end
local ii = torch.randn(tonumber(arg[1]))  -- tensor size taken from the command line
local m = nn.LogSoftMax()

local timer = torch.Timer()

-- time the plain Lua implementation
timer:stop()
timer:reset()
timer:resume()
my_lsm(ii)
print(timer:time().real)

-- time nn.LogSoftMax (which uses OMP)
timer:stop()
timer:reset()
timer:resume()
m:forward(ii)
print(timer:time().real)
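(The script is invoked with the tensor size as its first command-line argument, e.g. th bench.lua 10; the file name here is just an example.)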
If arg[1] is 10, then my basic log softmax function runs much faster:
0.00021696090698242
0.033425092697144
But once arg[1] is 10,000,000, OMP really helps a lot:
29.561321973801
0.11547803878784
So I suspect that the OMP overhead is very high. If my code has to call log softmax many times with small inputs (say the tensor size is only 3), it will cost too much time. Is there a way to manually disable OMP usage in some cases (but not always)?
If you really want to do that, one possibility is to use torch.setnumthreads and torch.getnumthreads, like this:
local nth = torch.getnumthreads()  -- remember the current thread count
torch.setnumthreads(1)             -- force single-threaded execution (no OMP)
-- do something
torch.setnumthreads(nth)           -- restore the previous thread count
So you can monkey-patch nn.LogSoftMax as follows:
nn.LogSoftMax.updateOutput = function(self, input)
  local nth = torch.getnumthreads()
  torch.setnumthreads(1)  -- run the C implementation on a single thread
  local out = input.nn.LogSoftMax_updateOutput(self, input)
  torch.setnumthreads(nth)  -- restore the previous thread count
  return out
end
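Since you only want to avoid OMP for small inputs, you could also make the patch conditional on the input size. This is just a sketch of that idea; the threshold of 1000 elements is an arbitrary assumption and should be tuned on your machine:

local SMALL = 1000  -- arbitrary size threshold, tune for your hardware

nn.LogSoftMax.updateOutput = function(self, input)
  if input:nElement() >= SMALL then
    -- large input: let OMP use all available threads
    return input.nn.LogSoftMax_updateOutput(self, input)
  end
  -- small input: temporarily go single-threaded to avoid the OMP overhead
  local nth = torch.getnumthreads()
  torch.setnumthreads(1)
  local out = input.nn.LogSoftMax_updateOutput(self, input)
  torch.setnumthreads(nth)
  return out
end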