Tags: matlab, plot, neural-network, customization, matlab-figure

How to make neural network training charts have a logarithmic vertical axis?


When using MATLAB's NN training tool (trainNetwork), we get charts that have a linear vertical axis, as shown below:

[Figure: current situation, linear vertical axis]

This chart is meant to give graphical feedback on the training progress, and it perhaps does for classification problems (where the y-axis represents "Accuracy (%)"). In regression problems, however, the RMSE values can span several orders of magnitude as training progresses, which makes everything after the initial drop indistinguishable and of little use.
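
To make the problem concrete, here is a small sketch with made-up RMSE values (hypothetical, not taken from an actual training run). It computes where each point would sit as a fraction of the axis height, first on a linear scale and then on a logarithmic one:

```matlab
% Hypothetical RMSE values spanning several orders of magnitude
rmse = [1e3 1e2 1e1 1e0 1e-1];

% Relative vertical position of each point on a linear axis (0 = bottom, 1 = top)
linPos = (rmse - min(rmse)) ./ (max(rmse) - min(rmse));

% Relative vertical position of each point on a logarithmic axis
logRmse = log10(rmse);
logPos  = (logRmse - min(logRmse)) ./ (max(logRmse) - min(logRmse));

% First row: linear positions, second row: logarithmic positions
disp([linPos; logPos]);
```

On the linear axis every point after the first lands in the bottom 10% of the plot, whereas on the logarithmic axis the points are evenly spaced.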

What I'd like to do is to convert the vertical axis to logarithmic, yielding the following result:

[Figure: desired situation, logarithmic vertical axis]

(I don't mind that some graphical elements move around or get lost in the process, as the curve is what's important to me.)

The way I'm doing it now is by pausing the training process, and manually running

set(findall(findall(0,'type','figure'),'type','Axes',...
  'Tag','NNET_CNN_TRAININGPLOT_AXESVIEW_AXES_REGRESSION_RMSE'),'YScale','log');

(or some variation thereof, depending on the open figures etc.).

I'm looking for a way to change the scale without user intervention, and do so as close as possible to the start of the training. Also, it would be great if I could choose which chart to rescale (RMSE and/or Loss).

I'm using R2018a.


Minimal code required to generate such a figure (based on the MATLAB docs titled "Train Network for Image Regression"):

[XTrain,~,YTrain] = digitTrain4DArrayData;

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'MaxEpochs',5, ...
    'Plots','training-progress');

net = trainNetwork(XTrain,YTrain,layers,options);

Solution

  • We can piggyback on the 'OutputFcn' mechanism available in the training options and supply a function that does the rescaling. The variable whichAx controls which axes get rescaled.

    Full code:

    function net = q51762507()
    [XTrain,~,YTrain] = digitTrain4DArrayData;
    
    layers = [ ...
        imageInputLayer([28 28 1])
        convolution2dLayer(12,25)
        reluLayer
        fullyConnectedLayer(1)
        regressionLayer];
    
    whichAx = [false, true]; % [bottom, top]
    
    options = trainingOptions('sgdm', ...
        'InitialLearnRate',0.001, ...
        'Verbose',false, ...
        'MaxEpochs',5, ...
        'Plots','training-progress',...
        'OutputFcn', @(x)makeLogVertAx(x,whichAx) );
    
    net = trainNetwork(XTrain,YTrain,layers,options);
    
    function stop = makeLogVertAx(state, whichAx)
    stop = false; % The function has to return a value.
    % Only do this once, following the 1st iteration
    if state.Iteration == 1
      % Get handles to "Training Progress" figures:
      hF  = findall(0,'type','figure','Tag','NNET_CNN_TRAININGPLOT_FIGURE');
      % Nothing to rescale if no "Training Progress" figure is open:
      if isempty(hF), return; end
      % Assume the latest figure (first result) is the one we want, and get its axes:
      hAx = findall(hF(1),'type','Axes');
      % Remove all irrelevant entries (identified by having an empty "Tag", R2018a)
      hAx = hAx(~cellfun(@isempty,{hAx.Tag}));
      set(hAx(whichAx),'YScale','log');
    end
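
A possible variation, offered only as a sketch: since the manual command in the question already locates the RMSE axes by its Tag, the callback could select the axes by Tag rather than by its position in the findall result. The function name makeLogVertAxByTag and its axTag parameter are hypothetical. The figure tag and the RMSE axes tag are the ones quoted above for R2018a and may differ in other releases. The Tag of the Loss axes is not given here, so it would have to be looked up first, for example by inspecting {hAx.Tag} in the full code above.

```matlab
function stop = makeLogVertAxByTag(state, axTag)
% Sketch: switch the training-progress axes whose Tag equals axTag to a
% logarithmic vertical scale. Intended for use as an 'OutputFcn' callback.
stop = false; % Never request that training stop.
% Only act once, following the 1st iteration:
if state.Iteration == 1
    % Get handles to "Training Progress" figures (tag as observed in R2018a):
    hF = findall(0,'Tag','NNET_CNN_TRAININGPLOT_FIGURE');
    if isempty(hF)
        return; % No training-progress figure is open.
    end
    % Look up the requested axes by Tag within the first figure found:
    hAx = findall(hF(1),'Tag',axTag);
    if ~isempty(hAx)
        set(hAx,'YScale','log');
    end
end
end
```

Under these assumptions, it would be wired into the training options as 'OutputFcn', @(info)makeLogVertAxByTag(info,'NNET_CNN_TRAININGPLOT_AXESVIEW_AXES_REGRESSION_RMSE').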