deep-learning, deeplearning4j

Vanishing/Exploding gradients in deeplearning4j


How can I check whether a network is suffering from vanishing or exploding gradients in deeplearning4j, specifically for recurrent neural networks? That is, which parameters should I look at, and which methods should I call to get their values?


Solution

  • As suggested above, you should take a look at the training GUI (introduction here).

    DL4J GUI: Overview Tab -> Update:Parameter Ratios

    The ratio of updates to parameters is specifically the ratio of the mean magnitudes of these values (i.e., log10(mean(abs(updates))/mean(abs(parameters)))).

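    So values far from roughly -3 (a ratio of about 1:1000, the rough guide given in the DL4J visualization documentation), for example above -2 or below -4, may suggest exploding or vanishing gradients respectively.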
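    To make the chart value concrete, here is a small plain-Java sketch of the same formula. It works on double[] arrays rather than DL4J's INDArray, and the class name and example values are hypothetical, chosen so that the updates are about 1/1000 of the parameters (which should give roughly -3).

    ```java
    public class UpdateParamRatio {

        // mean(abs(x)): the mean magnitude of the values
        public static double meanAbs(double[] x) {
            double sum = 0.0;
            for (double v : x) {
                sum += Math.abs(v);
            }
            return sum / x.length;
        }

        // log10(mean(abs(updates)) / mean(abs(params))), the quantity
        // plotted in the Update:Parameter Ratios chart
        public static double log10Ratio(double[] updates, double[] params) {
            return Math.log10(meanAbs(updates) / meanAbs(params));
        }

        public static void main(String[] args) {
            // Hypothetical values for one layer: each update is 1/1000 of its parameter
            double[] params = {0.5, -1.2, 0.8, -0.3};
            double[] updates = {0.0005, -0.0012, 0.0008, -0.0003};
            double r = log10Ratio(updates, params);
            System.out.printf("log10 update:parameter ratio = %.2f%n", r);
        }
    }
    ```

    Because the chart is on a log10 scale, each unit away from -3 means the updates are ten times larger or smaller relative to the parameters.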

    Programmatically

    At the end of each iteration, the gradients are stored in the gradient field of both ComputationGraph and MultiLayerNetwork. They can be accessed via the public gradient() method (a simple getter that does not change state), so you can analyze the gradients in your own code.

    Here's a small code snippet that outputs the minimum, mean and maximum absolute gradient value per variable, as well as the order of magnitude (log10) of the minimum:

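        // Assumes `net` is a MultiLayerNetwork or ComputationGraph that has already run
        // at least one training iteration (gradient() may be null before that),
        // and `log` is an SLF4J logger (e.g. from Lombok's @Slf4j).
        StringBuilder gradSummary = new StringBuilder("--- Gradients ---\n");
        net.gradient().gradientForVariable().forEach((var, grad) -> {
            // aminNumber/amaxNumber/ameanNumber operate on absolute values
            Number min = grad.aminNumber();
            Number max = grad.amaxNumber();
            Number mean = grad.ameanNumber();
            // The (int) cast truncates toward zero (log10(4.1E-11) = -10.4 -> -10);
            // an all-zero gradient would give log10(0) = -Infinity
            int order = (int) Math.log10(min.doubleValue());
            gradSummary.append(var).append(": ")
                .append(min).append(",")
                .append(mean).append(",")
                .append(max).append(", ")
                .append("magnitude: ").append(order).append('\n');
        });
        gradSummary.append("-----------------");
        log.info(gradSummary.toString());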
    

    It produces output like the following (notice that the variables are named after the layers, in the form layerName_parameter):

    2019-01-05 15:26:12 INFO  --- Gradients ---
    lstm-1_W: 4.1305625586574024E-11,2.102349571941886E-5,5.235217977315187E-4, magnitude: -10
    lstm-1_RW: 6.30961949354969E-11,1.7203132301801816E-5,1.335109118372202E-4, magnitude: -10
    lstm-1_b: 2.9782620813989524E-10,3.226526814614772E-6,3.882131932186894E-5, magnitude: -9
    lstm-2_W: 2.340811988688074E-10,2.496814886399079E-5,7.095998153090477E-4, magnitude: -9
    lstm-2_RW: 8.640199666842818E-11,4.6048542571952567E-5,0.0015051497612148523, magnitude: -10
    lstm-2_b: 6.85293555235944E-9,3.012867455254309E-5,4.262796137481928E-4, magnitude: -8
    lstm-3_W: 1.141415850725025E-10,5.7301283959532157E-5,0.0024848710745573044, magnitude: -9
    lstm-3_RW: 2.446540747769177E-10,3.4060700272675604E-5,0.002297096885740757, magnitude: -9
    lstm-3_b: 1.5003001507807312E-8,2.131067230948247E-5,2.356997865717858E-4, magnitude: -7
    norm-1_gamma: 4.6524661456714966E-8,2.8755117455148138E-5,1.543344114907086E-4, magnitude: -7
    norm-1_beta: 5.754080234510184E-7,1.0409040987724438E-4,3.460813604760915E-4, magnitude: -6
    norm-1_mean: 8.82148754044465E-7,0.0033756729681044817,0.048742543905973434, magnitude: -6
    norm-1_var: 3.0532873451782905E-10,2.6078732844325714E-6,1.6723810404073447E-4, magnitude: -9
    dense-1_W: 3.8744474295526743E-10,5.491946285474114E-5,6.59565266687423E-4, magnitude: -9
    dense-1_b: 4.4111070565122645E-6,1.4454024494625628E-4,4.0868428186513484E-4, magnitude: -5
    norm-2_gamma: 2.477656607879908E-6,9.73446512944065E-5,2.708708052523434E-4, magnitude: -5
    norm-2_beta: 3.106115855189273E-6,4.934889730066061E-4,0.0012065295595675707, magnitude: -5
    norm-2_mean: 2.7818930902867578E-5,0.004300051834434271,0.01411475520581007, magnitude: -4
    norm-2_var: 1.806318869057577E-5,0.007471780758351088,0.020012110471725464, magnitude: -4
    output_W: 7.830021786503494E-8,1.4970696065574884E-4,4.896917380392551E-4, magnitude: -7
    output_b: 3.1583107193000615E-4,6.765704602003098E-4,0.0011031415779143572, magnitude: -3
    -----------------
    

    You can even wrap this code in an iteration listener and output the summary once every N iterations to help babysit your training process.
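    To illustrate the "once every N iterations" idea without pulling in DL4J, here is a plain-Java sketch. The class PeriodicGradientLogger and its iterationDone(int, Map<String, double[]>) method are hypothetical: the map stands in for what net.gradient().gradientForVariable() returns (with double[] instead of INDArray), and the summary mirrors the snippet above, including min/mean/max of absolute values and the truncating (int) cast of log10(min). In DL4J itself, the equivalent would presumably be a custom training listener (for example, extending BaseTrainingListener and overriding iterationDone) registered via net.setListeners(...); check the exact signature for your DL4J version.

    ```java
    import java.util.LinkedHashMap;
    import java.util.Map;

    public class PeriodicGradientLogger {

        private final int frequency; // log once every `frequency` iterations

        public PeriodicGradientLogger(int frequency) {
            if (frequency <= 0) {
                throw new IllegalArgumentException("frequency must be > 0");
            }
            this.frequency = frequency;
        }

        // Returns a summary on every `frequency`-th iteration, otherwise null.
        public String iterationDone(int iteration, Map<String, double[]> gradients) {
            if (iteration % frequency != 0) {
                return null;
            }
            StringBuilder sb = new StringBuilder("--- Gradients (iteration " + iteration + ") ---\n");
            gradients.forEach((name, grad) -> {
                if (grad.length == 0) {
                    return; // nothing to summarize for an empty array
                }
                double min = Double.POSITIVE_INFINITY, max = 0.0, sum = 0.0;
                for (double g : grad) {
                    double a = Math.abs(g); // absolute values, like amin/amax/amean
                    min = Math.min(min, a);
                    max = Math.max(max, a);
                    sum += a;
                }
                double mean = sum / grad.length;
                // Same truncating cast as the snippet, but guard log10(0) = -Infinity
                String order = min > 0 ? Integer.toString((int) Math.log10(min)) : "-inf";
                sb.append(name).append(": ").append(min).append(',')
                  .append(mean).append(',').append(max)
                  .append(", magnitude: ").append(order).append('\n');
            });
            sb.append("-----------------");
            return sb.toString();
        }

        public static void main(String[] args) {
            // Hypothetical gradients for two variables
            Map<String, double[]> grads = new LinkedHashMap<>();
            grads.put("dense-1_W", new double[] {-2.0e-5, 4.0e-4, 1.0e-3});
            grads.put("dense-1_b", new double[] {0.0, 0.0});
            PeriodicGradientLogger logger = new PeriodicGradientLogger(10);
            for (int it = 0; it < 25; it++) {
                String summary = logger.iterationDone(it, grads);
                if (summary != null) {
                    System.out.println(summary); // printed at iterations 0, 10 and 20
                }
            }
        }
    }
    ```

    Returning the summary instead of logging it directly keeps the sketch easy to test; a real listener would hand it to a logger.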