How to disable OMP in the Torch nn package?


Specifically, I would like nn.LogSoftMax not to use OMP when the input tensor is small. Here is a small script to test the run time.

require 'nn'

-- reference implementation: a plain Lua log-softmax that never touches OMP
my_lsm = function(t)
    local n = (#t)[1]          -- number of elements in the 1D input
    local o = torch.zeros(n)
    local sum = 0.0
    for i = 1, n do
        o[i] = torch.exp(t[i])
        sum = sum + o[i]
    end
    o = o / sum                -- normalize to get the softmax
    return torch.log(o)        -- then take the log
end

ii = torch.randn(tonumber(arg[1]))  -- tensor size taken from the command line
m = nn.LogSoftMax()

-- time the plain Lua version
timer = torch.Timer()
timer:stop()
timer:reset()
timer:resume()
my_lsm(ii)
print(timer:time().real)

-- time nn.LogSoftMax:forward, which uses OMP
timer:stop()
timer:reset()
timer:resume()
m:forward(ii)
print(timer:time().real)

If arg[1] is 10, then my basic log-softmax function runs much faster:

0.00021696090698242
0.033425092697144

But once arg[1] is 10,000,000, OMP really helps a lot:

29.561321973801 
0.11547803878784

So I suspect that the OMP overhead is very high. If my code has to call log softmax many times with small inputs (say a tensor of size 3), it will cost too much time. Is there a way to manually disable OMP usage in some cases (but not always)?


Solution

  • Is there a way to manually disable OMP usage in some cases (but not always)?

    If you really want to do that, one possibility is to use torch.setnumthreads and torch.getnumthreads like this:

    local nth = torch.getnumthreads()  -- save the current thread count
    torch.setnumthreads(1)             -- force single-threaded execution
    -- do something
    torch.setnumthreads(nth)           -- restore the previous thread count
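
    For instance, applied to the benchmark from the question (a minimal sketch reusing the m and ii defined there), the forward pass could be wrapped like this:

    local nth = torch.getnumthreads()
    torch.setnumthreads(1)    -- run this forward pass single-threaded
    m:forward(ii)
    torch.setnumthreads(nth)  -- the rest of the program keeps its OMP threads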
    

    To avoid wrapping every call site like this, you can monkey-patch nn.LogSoftMax as follows:

    nn.LogSoftMax.updateOutput = function(self, input)
      local nth = torch.getnumthreads()
      torch.setnumthreads(1)                                     -- run the C kernel single-threaded
      local out = input.nn.LogSoftMax_updateOutput(self, input)  -- call the underlying C implementation
      torch.setnumthreads(nth)                                   -- restore the previous thread count
      return out
    end
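
    Since the question asks for disabling OMP only in some cases, the same patch could also branch on the input size instead of always forcing a single thread. The following is a minimal sketch along the same lines; the 1000-element threshold is an arbitrary assumption and would need to be tuned against your own timings:

    nn.LogSoftMax.updateOutput = function(self, input)
      local run = input.nn.LogSoftMax_updateOutput   -- the underlying C implementation
      if input:nElement() < 1000 then                -- hypothetical cutoff, tune for your workload
        local nth = torch.getnumthreads()
        torch.setnumthreads(1)                       -- small input: skip the OMP overhead
        local out = run(self, input)
        torch.setnumthreads(nth)
        return out
      end
      return run(self, input)                        -- large input: keep OMP enabled
    end

    With a patch like this, small inputs should avoid the OMP startup cost while large inputs should still get the parallel speedup seen in the benchmark.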