Tags: matlab, machine-learning, deep-learning, loss-function, cross-entropy

How do I replace pixelClassificationLayer() with a custom loss function?


The official MathWorks documentation page for the pixelClassificationLayer() function says that I should replace it with a custom loss function, using the following code:

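function loss = modelLoss(Y,T)
  % NaN targets mark pixels with undefined labels; exclude them from the loss
  mask = ~isnan(T);
  % Zero the NaN targets so they do not propagate NaN into the loss
  T(isnan(T)) = 0;
  loss = crossentropy(Y,T,Mask=mask,NormalizationFactor="mask-included");
end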

netTrained = trainnet(images,net,@modelLoss,options); 

However, this code does not mention the 'Classes' or 'ClassWeights' arguments, which I currently use to define my pixel classification layer:

pixelClassificationLayer('Classes',classNames,'ClassWeights',classWeights)

Here, classNames is a string vector containing the name of each class, and classWeights is a vector containing the weight of each class, which I use to balance the classes when some of them are underrepresented in the training data.

How can I include these parameters in my custom loss function?


Solution

  • You need to account for these parameters explicitly within your custom loss function, for example by passing the class weights to it as an extra argument.

    Below is an example; adjust it to your data:

    function loss = modelLoss(Y, T, classWeights)
    
        % trainnet passes categorical targets one-hot encoded, so T already
        % has one channel per class; NaN marks pixels with undefined labels
        mask = ~isnan(T);
        T(isnan(T)) = 0;
    
        % Class-weighted cross-entropy: classWeights(k) scales the loss of
        % class k. WeightsFormat="C" tells crossentropy that the weights run
        % along the channel (class) dimension; (:) makes them a column vector.
        loss = crossentropy(Y, T, classWeights(:), WeightsFormat="C", ...
            Mask=mask, NormalizationFactor="mask-included");
    end
    
    
    % Define the classes and their weights. The weights should follow the
    % same order as the categories of the pixel labels (the channel order of T).
    classNames = [...];
    classWeights = [...]; % Example weights
    
    % Wrap modelLoss so that trainnet receives a function of (Y, T) only
    customLoss = @(Y, T) modelLoss(Y, T, classWeights);
    
    netTrained = trainnet(images, net, customLoss, options);
    
    Because T arrives already one-hot encoded (the documentation code relies on this when it calls crossentropy(Y,T) directly), the loss function does not need to build the encoding itself, and classNames is no longer used inside it. You still need classNames outside the loss function, for example to map the output channels of the trained network back to class names.
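    Conceptually, the class-weighted loss that this custom function aims to compute can be sketched as follows, where Ω is the set of labeled pixels in the mini-batch (those whose targets are not NaN), K is the number of classes, w_k is classWeights(k), T_{p,k} is the one-hot target and Y_{p,k} is the predicted probability of class k at pixel p:

    $$
    \mathcal{L} = -\frac{1}{|\Omega|} \sum_{p \in \Omega} \sum_{k=1}^{K} w_k \, T_{p,k} \ln Y_{p,k}
    $$

    Because T is one-hot, each labeled pixel contributes only the term of its true class, scaled by that class's weight, so giving underrepresented classes larger weights increases their share of the loss. The exact normalization constant that crossentropy applies with NormalizationFactor="mask-included" may differ from 1/|Ω| (for example, if it also counts mask entries along the channel dimension), so check the crossentropy documentation if the absolute scale of the loss matters.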