
Torch7 ClassNLLCriterion()


I've been trying for a whole day to get my code to work, but it fails despite the inputs and outputs being consistent.

Someone mentioned somewhere that ClassNLLCriterion does not accept values less than or equal to zero.

How am I supposed to go about training this network? Here is part of my code; I suppose it fails in backward(), where the model's output may contain negative values. However, when I switch to the mean squared error criterion (nn.MSECriterion), the code works just fine.

ninputs = 22; noutputs = 3
hidden = 22


model = nn.Sequential()
model:add(nn.Linear(ninputs, hidden))   -- input-to-hidden layer
model:add(nn.Tanh())
model:add(nn.Linear(hidden, noutputs))  -- hidden-to-output layer
model:add(nn.LogSoftMax())              -- log-probabilities for ClassNLLCriterion
----------------------------------------------------------------------
-- 3. Define a loss function, to be minimized.

-- In this example, we minimize the negative log-likelihood (NLL) between
-- the predictions of our model and the ground-truth class labels available
-- in the dataset.

-- Torch provides many common criterions to train neural networks.

criterion = nn.ClassNLLCriterion()


----------------------------------------------------------------------
-- 4. Train the model
i=1
mean = {}
std = {}





-- To minimize the loss defined above, using the linear model defined
-- in 'model', we follow a stochastic gradient descent procedure (SGD).

-- SGD is a good optimization algorithm when the amount of training data
-- is large, and estimating the gradient of the loss function over the 
-- entire training set is too costly.

-- Given an arbitrarily complex model, we can retrieve its trainable
-- parameters, and the gradients of our loss function wrt these 
-- parameters by doing so:

x, dl_dx = model:getParameters()

-- In the following code, we define a closure, feval, which computes
-- the value of the loss function at a given point x, and the gradient of
-- that function with respect to x. x is the vector of trainable weights,
-- which, in this example, are all the weights and biases of the two
-- linear layers of our model.

feval = function(x_new)
   -- set x to x_new, if different
   -- (in this simple example, x_new will typically always point to x,
   -- so the copy is really useless)
   if x ~= x_new then
      x:copy(x_new)
   end

   -- select a new training sample
   _nidx_ = (_nidx_ or 0) + 1
   if _nidx_ > (#csv_tensor)[1] then _nidx_ = 1 end

   local sample = csv_tensor[_nidx_]
   local target = sample[{ {23,25} }]
   local inputs = sample[{ {1,22} }]    -- slicing of arrays.

   -- reset gradients (gradients are always accumulated, to accommodate 
   -- batch methods)
   dl_dx:zero()

   -- evaluate the loss function and its derivative wrt x, for that sample
   local loss_x = criterion:forward(model:forward(inputs), target)
   model:backward(inputs, criterion:backward(model.output, target))

   -- return loss(x) and dloss/dx
   return loss_x, dl_dx
end
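
The training loop itself is not included above; the stack trace below shows feval being passed to optim.sgd. A rough sketch of such a loop (the sgd_params values and the epoch count here are placeholders, not the actual settings) looks like:

require 'optim'

sgd_params = {
   learningRate = 1e-3,
   learningRateDecay = 1e-4,
   weightDecay = 0,
   momentum = 0
}

for epoch = 1, 100 do
   local current_loss = 0
   for i = 1, (#csv_tensor)[1] do
      -- optim.sgd evaluates feval at the current parameters x,
      -- updates x in place, and returns the loss before the update
      local _, fs = optim.sgd(feval, x, sgd_params)
      current_loss = current_loss + fs[1]
   end
   print(string.format('epoch %d, average loss = %.6f',
                       epoch, current_loss / (#csv_tensor)[1]))
end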

The error received is

/home/stormy/torch/install/bin/luajit: /home/stormy/torch/install/share/lua/5.1/nn/THNN.lua:110: Assertion `cur_target >= 0 && cur_target < n_classes' failed. at /home/stormy/torch/extra/nn/lib/THNN/generic/ClassNLLCriterion.c:45
stack traceback:
    [C]: in function 'v'
    /home/stormy/torch/install/share/lua/5.1/nn/THNN.lua:110: in function 'ClassNLLCriterion_updateOutput'
    ...rmy/torch/install/share/lua/5.1/nn/ClassNLLCriterion.lua:43: in function 'forward'
    nn.lua:178: in function 'opfunc'
    /home/stormy/torch/install/share/lua/5.1/optim/sgd.lua:44: in function 'sgd'
    nn.lua:222: in main chunk
    [C]: in function 'dofile'
    ...ormy/torch/install/lib/luarocks/rocks/trepl/scm-1/bin/th:145: in main chunk
    [C]: at 0x00405d50


Solution

  • The error message results from passing in targets that are out of bounds. For example:

    m = nn.ClassNLLCriterion()
    nClasses = 3
    nBatch = 10
    net_output = torch.randn(nBatch, nClasses)
    targets = torch.Tensor(nBatch):random(1, nClasses) -- targets are between 1 and 3
    m:forward(net_output, targets)
    m:backward(net_output, targets)

    Now, see the bad example (the one you are running into):

    targets[5] = 13 -- a class index above nClasses (out of bounds)
    targets[4] = 0  -- a class index below 1 (out of bounds)
    -- these lines below will error
    m:forward(net_output, targets)
    m:backward(net_output, targets)
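
  • In the feval above, target = sample[{ {23,25} }] is a 3-element slice rather than a single class index, which is why ClassNLLCriterion ends up seeing values outside 1..3. Assuming those three columns hold a one-hot encoding of the class (an assumption based on the slicing, not something stated in the question), one way to fix the closure is to convert the slice into a class index:

    -- columns 23-25 assumed to hold a one-hot encoding of the 3 classes:
    -- take the index of the maximum entry as the class label (1, 2 or 3)
    local sample = csv_tensor[_nidx_]
    local inputs = sample[{ {1, 22} }]
    local _, target = torch.max(sample[{ {23, 25} }], 1)
    target = target[1]   -- ClassNLLCriterion expects a class index, not a vector

    local loss_x = criterion:forward(model:forward(inputs), target)
    model:backward(inputs, criterion:backward(model.output, target))

    Alternatively, store the label as a single column containing 1, 2, or 3 and pass that value directly as the target; either way the targets stay inside the 1..nClasses range that the assertion checks.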