
Can't find where addmm function is defined in Torch Lua code


I am trying to understand a neural network implemented in Torch (Lua). During the backward pass through the Linear layer, it calls a function named Linear:updateGradInput (https://github.com/torch/nn/blob/master/Linear.lua#L75):

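function Linear:updateGradInput(input, gradOutput)
   if self.gradInput then
      -- Resize the buffer to match the input; zero it if its size changed.
      local nElement = self.gradInput:nElement()
      self.gradInput:resizeAs(input)
      if self.gradInput:nElement() ~= nElement then
         self.gradInput:zero()
      end
      if input:dim() == 1 then
         -- Single sample: gradInput = 0 * gradInput + 1 * (weight^T * gradOutput)
         self.gradInput:addmv(0, 1, self.weight:t(), gradOutput)
      elseif input:dim() == 2 then
         -- Batch: gradInput = 0 * gradInput + 1 * (gradOutput * weight)
         self.gradInput:addmm(0, 1, gradOutput, self.weight)
      end
      return self.gradInput
   end
end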

In that function, it performs a basic matrix multiplication by calling a function named addmm (https://github.com/torch/nn/blob/master/Linear.lua#L86). I am not able to find where this addmm function is defined.
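For reference, here is a minimal sketch in plain C of what a call such as addmm(0, 1, gradOutput, self.weight) computes, assuming the usual meaning res = beta * res + alpha * (m1 x m2). The function name addmm_sketch and the row-major array layout are assumptions made for illustration; this is not the TH implementation.

```c
#include <stddef.h>

/* Hypothetical helper, not the TH code.
 * Computes res = beta * res + alpha * (m1 x m2) on row-major arrays.
 * m1 is n x k, m2 is k x m, res is n x m. */
static void addmm_sketch(double beta, double *res, double alpha,
                         const double *m1, const double *m2,
                         size_t n, size_t k, size_t m)
{
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < m; j++) {
            /* Dot product of row i of m1 with column j of m2. */
            double acc = 0.0;
            for (size_t p = 0; p < k; p++) {
                acc += m1[i * k + p] * m2[p * m + j];
            }
            /* With beta = 0 and finite values, the old contents of res
             * contribute nothing to the result. */
            res[i * m + j] = beta * res[i * m + j] + alpha * acc;
        }
    }
}
```

In the Linear layer, gradOutput has shape batch x outputSize and weight has shape outputSize x inputSize, so with beta = 0 and alpha = 1 the result is simply gradInput = gradOutput x weight, of shape batch x inputSize.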

There is an addmm function defined in the TH library (https://github.com/torch/torch7/blob/master/lib/TH/generic/THTensorMath.c#L1282), but I am not sure how the Lua code is connected to this C code.


Solution

  • I just figured out the connection between the Lua code and the C code. The call to addmm in the Lua code resolves to this definition (https://github.com/torch/torch7/blob/master/TensorMath.lua#L487-L510), which in turn calls the addmm function in the TH C library (https://github.com/torch/torch7/blob/master/lib/TH/generic/THTensorMath.c#L1282).

    It's tricky because the Lua code builds the call to the C function out of strings, which is why addmm does not show up as an ordinary function definition.
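    To make the "via strings" point more concrete, here is a rough sketch in plain C of the general idea: a type-specific C symbol is assembled from name pieces, and a method name (a string) is paired with a function pointer in a table, similar in shape to the name/function pairs that Lua's C API uses for registration. This is not Torch's actual code. Every identifier below (SkTensor_, sk_reg, sk_lookup, and the trivial scale function standing in for addmm) is hypothetical.

```c
#include <stddef.h>
#include <string.h>

/* Token pasting builds a type-specific symbol from pieces.
 * All names here are hypothetical, chosen for illustration only. */
#define SK_CONCAT4_(a, b, c, d) a##b##c##d
#define SK_CONCAT4(a, b, c, d) SK_CONCAT4_(a, b, c, d)
#define SkReal Double
#define SkTensor_(NAME) SK_CONCAT4(Sk, SkReal, Tensor_, NAME)

/* Expands to: static double SkDoubleTensor_scale(double x, double alpha).
 * A trivial stand-in for a real math routine such as addmm. */
static double SkTensor_(scale)(double x, double alpha)
{
    return alpha * x;
}

/* A name/function pair, so a method can be found by its string name. */
typedef double (*sk_fn)(double, double);
struct sk_reg {
    const char *name;
    sk_fn func;
};

/* Method table terminated by a NULL entry. */
static const struct sk_reg sk_methods[] = {
    { "scale", SkTensor_(scale) },
    { NULL, NULL }
};

/* Return the function registered under `name`, or NULL if there is none. */
static sk_fn sk_lookup(const char *name)
{
    for (const struct sk_reg *r = sk_methods; r->name != NULL; r++) {
        if (strcmp(r->name, name) == 0) {
            return r->func;
        }
    }
    return NULL;
}
```

    With a setup like this, a search for the literal symbol SkDoubleTensor_scale finds no definition in the source, because the name only exists after the macro is expanded. That is the same kind of obstacle as the one described above.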