Tags: python, neural-network, spacy, loss-function

Losses keep increasing within iteration


I am a little confused about the following. I am training a neural network and printing out the losses. I am training for 4 iterations just to try it out, and I use batches. I normally picture a loss curve as a parabola, where the loss decreases to a minimum point before increasing again. But my losses keep increasing as each iteration progresses.

For example, let's say there are 100 batches in each iteration. In iteration 0, losses started at 26.3 (batch 0) and went up to 1500.7 (batch 100). In iteration 1, it started at 2.4e-14 and went up to 80.8.

I am following an example from spaCy (https://spacy.io/usage/examples#training-ner). Should I be comparing the losses across iterations instead (i.e. if I take the points from all of the batch 0s, should they resemble a parabola)?


Solution

  • If you are using the exact same code as linked, this behaviour is to be expected.

    for itn in range(n_iter):
        random.shuffle(TRAIN_DATA)
        losses = {}
        # batch up the examples using spaCy's minibatch
        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            texts, annotations = zip(*batch)
            nlp.update(
                texts,  # batch of texts
                annotations,  # batch of annotations
                drop=0.5,  # dropout - make it harder to memorise data
                losses=losses,
            )
        print("Losses", losses)
    

    An "iteration" is one pass of the outer loop: for itn in range(n_iter). The sample code also shows that losses is reset to an empty dict at the start of every iteration. Each nlp.update call then adds the loss of the batch it processes to the running total stored in losses, rather than overwriting it.
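
    To make the accumulation concrete, here is a minimal sketch that does not use spaCy at all. The function fake_update is a hypothetical stand-in for nlp.update, the "ner" key is assumed for illustration, and the per-batch loss values are made up. It only shows that a running sum keeps growing even when each batch's own loss is shrinking.

```python
def fake_update(batch_loss, losses):
    """Hypothetical stand-in for nlp.update: adds this batch's loss to the running total."""
    losses["ner"] = losses.get("ner", 0.0) + batch_loss


# Made-up per-batch losses that shrink as training goes on (illustration only).
per_batch = [8.0, 4.0, 2.0, 1.0]

losses = {}  # reset once, at the start of the iteration
running_totals = []
for batch_loss in per_batch:
    fake_update(batch_loss, losses)
    running_totals.append(losses["ner"])  # what you would see if you printed inside the loop

print(running_totals)  # [8.0, 12.0, 14.0, 15.0]
```

    Each batch contributes less than the one before, yet the value you would print inside the inner loop only goes up.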

    So yes: the printed loss increases WITHIN an iteration with each batch that you process, because it is a running sum. To check whether your model is actually learning anything, compare the total loss across iterations, similar to how the print statement in the original snippet only prints after looping through all the batches, not during.
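
    If you still want to look at individual batches, one option is to take the difference between consecutive running totals. The sketch below assumes the same accumulating behaviour; simulated_update is a hypothetical stand-in for nlp.update, the "ner" key is assumed, and all numbers are invented rather than real training output.

```python
def simulated_update(batch_loss, losses):
    """Hypothetical stand-in for nlp.update: adds the batch loss to the running total."""
    losses["ner"] = losses.get("ner", 0.0) + batch_loss


# Made-up per-batch losses for three iterations (illustration only, not real output).
made_up_losses = [
    [9.0, 7.0, 6.0],  # iteration 0
    [5.0, 4.0, 3.0],  # iteration 1
    [2.0, 1.5, 1.0],  # iteration 2
]

iteration_totals = []
recovered_batch_losses = []
for itn, batch_values in enumerate(made_up_losses):
    losses = {}      # reset once per iteration, as in the training loop
    previous = 0.0   # running total seen after the previous batch
    for batch_loss in batch_values:
        simulated_update(batch_loss, losses)
        # The difference between consecutive running totals is this batch's own loss.
        recovered_batch_losses.append(losses["ner"] - previous)
        previous = losses["ner"]
    iteration_totals.append(losses["ner"])
    print("Iteration", itn, "total loss", losses["ner"])

# The model is improving if the per-iteration totals trend downwards.
print(iteration_totals)
```

    In this toy run the totals fall from one iteration to the next, which is the trend to look for in the real "Losses" printout.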

    Hope that helps!