
Add my custom loss function to torch


I want to add a loss function to torch that calculates the edit distance between predicted and target values. Is there an easy way to implement this idea? Or do I have to write my own class with backward and forward functions?
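For reference, the edit (Levenshtein) distance itself can be computed with the standard dynamic-programming recurrence. The sketch below is plain Lua with no torch dependency. The function name `editDistance` is hypothetical, and it assumes predictions and targets have already been turned into discrete symbol sequences (Lua strings or 1-based arrays). Because the value depends only on discrete symbols, it does not by itself supply a gradient for `backward`.

```lua
-- Hypothetical helper: Levenshtein (edit) distance between two sequences,
-- given as Lua strings or as 1-based arrays of symbols.
local function editDistance(a, b)
  -- Element accessor that works for both strings and arrays.
  local function at(s, i)
    if type(s) == "string" then return s:sub(i, i) end
    return s[i]
  end

  local n, m = #a, #b
  -- prev[j + 1] = distance between the first i-1 symbols of a and the first j of b.
  local prev = {}
  for j = 0, m do prev[j + 1] = j end

  for i = 1, n do
    local cur = { i }  -- distance from the first i symbols of a to the empty prefix of b
    for j = 1, m do
      local cost = (at(a, i) == at(b, j)) and 0 or 1
      cur[j + 1] = math.min(
        prev[j + 1] + 1,  -- deletion
        cur[j] + 1,       -- insertion
        prev[j] + cost)   -- substitution or match
    end
    prev = cur
  end
  return prev[m + 1]
end

print(editDistance("kitten", "sitting"))  -- 3
print(editDistance({1, 2, 3}, {1, 3}))    -- 1
```

Only two rows of the table are kept, so memory is proportional to the length of `b`.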


Solution

  • If your criterion can be represented as a composition of existing modules and criteria, it's a good idea to simply construct such a composition using containers. The only problem is that the standard containers are designed to work with modules only, not with criteria. The difference is in the signature of the :forward method:

    module:forward(input)
    criterion:forward(input, target)
    

    Luckily, we are free to define our own container that is able to work with criteria too. For example, a sequential one:

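    require 'nn'

    -- A Sequential container whose forward() also accepts a target,
    -- so that its last element may be a criterion.
    local GeneralizedSequential, _ = torch.class('nn.GeneralizedSequential', 'nn.Sequential')
    
    function GeneralizedSequential:forward(input, target)
        return self:updateOutput(input, target)
    end
    
    function GeneralizedSequential:updateOutput(input, target)
        local currentOutput = input
        for i=1,#self.modules do
            -- Modules ignore the extra target argument; criteria use it.
            currentOutput = self.modules[i]:updateOutput(currentOutput, target)
        end
        self.output = currentOutput
        return currentOutput
    end

    -- backward(input, target) needs no override: nn.Sequential:backward passes
    -- its second argument (here the target) to the last element, and a
    -- criterion's backward(input, target) interprets it as the target.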
    

    Below is an illustration of how to reimplement nn.CrossEntropyCriterion using this generalized sequential container:

    -- Log-softmax followed by negative log-likelihood, i.e. cross-entropy.
    function MyCrossEntropyCriterion(weights)
        local criterion = nn.GeneralizedSequential()
        criterion:add(nn.LogSoftMax())
        criterion:add(nn.ClassNLLCriterion(weights))
        return criterion
    end
    

    Check that it produces the same loss and gradient as the built-in nn.CrossEntropyCriterion:

    output = torch.rand(3,3)            -- scores for 3 samples over 3 classes
    target = torch.Tensor({1, 2, 3})    -- 1-based class index for each sample
    
    mycrit = MyCrossEntropyCriterion()
    -- print(mycrit)
    print(mycrit:forward(output, target))
    print(mycrit:backward(output, target))
    
    -- Reference: the built-in criterion should print the same values.
    crit = nn.CrossEntropyCriterion()
    -- print(crit)
    print(crit:forward(output, target))
    print(crit:backward(output, target))
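If the loss cannot be built from existing modules and criteria, the alternative raised in the question is a custom criterion class. In nn you typically subclass nn.Criterion and implement only updateOutput and updateGradInput; the inherited forward and backward call them. The sketch below assumes a hypothetical class name, nn.MySquaredErrorCriterion, and uses mean squared error only because it can be checked against the built-in nn.MSECriterion, in the same spirit as the cross-entropy check above.

```lua
require 'nn'

-- Hypothetical class name, for illustration only.
local MySquaredErrorCriterion, parent =
   torch.class('nn.MySquaredErrorCriterion', 'nn.Criterion')

function MySquaredErrorCriterion:__init()
   parent.__init(self)  -- sets up self.output and self.gradInput
end

-- Forward pass: mean of (input - target)^2 over all elements.
function MySquaredErrorCriterion:updateOutput(input, target)
   self.output = (input - target):pow(2):mean()
   return self.output
end

-- Backward pass: derivative of the mean squared error, 2 * (input - target) / N.
function MySquaredErrorCriterion:updateGradInput(input, target)
   self.gradInput = (input - target):mul(2 / input:nElement())
   return self.gradInput
end

local input = torch.rand(4, 5)
local target = torch.rand(4, 5)

local mine = nn.MySquaredErrorCriterion()
local ref = nn.MSECriterion()
print(mine:forward(input, target), ref:forward(input, target))
print((mine:backward(input, target) - ref:backward(input, target)):abs():max())
```

The two losses should agree, and the printed maximum gradient difference should be zero up to floating-point rounding. The in-place calls (pow, mul) are safe here because `input - target` always produces a fresh tensor.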