Loss does not decrease during training (Word2Vec, Gensim)


What can cause the loss reported by model.get_latest_training_loss() to increase on each epoch?

The code used for training:

import os
import multiprocessing

from gensim.models import Word2Vec
from gensim.models.callbacks import CallbackAny2Vec


class EpochSaver(CallbackAny2Vec):
    '''Callback to save the model after each epoch and show training parameters.'''

    def __init__(self, savedir):
        self.savedir = savedir
        self.epoch = 0
        os.makedirs(self.savedir, exist_ok=True)

    def on_epoch_end(self, model):
        savepath = os.path.join(self.savedir, "model_neg{}_epoch.gz".format(self.epoch))
        model.save(savepath)
        print(
            "Epoch saved: {}".format(self.epoch + 1),
            "Start next epoch ...", sep="\n"
        )
        previous = os.path.join(self.savedir, "model_neg{}_epoch.gz".format(self.epoch - 1))
        if os.path.isfile(previous):
            print("Previous model deleted")
            os.remove(previous)
        self.epoch += 1
        print("Model loss:", model.get_latest_training_loss())


def train():
    ### Initialize model ###
    print("Start training Word2Vec model")

    # cpu_count() returns an int; use integer division so `workers` stays an int
    workers = multiprocessing.cpu_count() // 2

    model = Word2Vec(
        DocIter(),
        size=300, alpha=0.03, min_alpha=0.00025, iter=20,
        min_count=10, hs=0, negative=10, workers=workers,
        window=10, callbacks=[EpochSaver("./checkpoints")],
        compute_loss=True
    )

Output:

Losses from epochs (1 to 20):

Model loss: 745896.8125
Model loss: 1403872.0
Model loss: 2022238.875
Model loss: 2552509.0
Model loss: 3065454.0
Model loss: 3549122.0
Model loss: 4096209.75
Model loss: 4615430.0
Model loss: 5103492.5
Model loss: 5570137.5
Model loss: 5955891.0
Model loss: 6395258.0
Model loss: 6845765.0
Model loss: 7260698.5
Model loss: 7712688.0
Model loss: 8144109.5
Model loss: 8542560.0
Model loss: 8903244.0
Model loss: 9280568.0
Model loss: 9676936.0

What am I doing wrong?

The language is Arabic. DocIter yields lists of tokens as input.


Solution

  • Up through gensim 3.6.0, the reported loss value may not be very sensible: the tally is reset only on each call to train(), rather than on each internal epoch. Some fixes are forthcoming in this issue:

    https://github.com/RaRe-Technologies/gensim/pull/2135

    In the meantime, the difference between the previous value and the latest may be more meaningful. In that case, your data suggest the 1st epoch had a total loss of 745896.8125, while the last had (9676936 - 9280568 =) 396368, which may indicate exactly the kind of progress hoped for.
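The delta workaround can be done inside the callback itself by remembering the previous cumulative tally. Below is a minimal sketch of that bookkeeping; the `LossLogger` name is illustrative, and a stand-in model object is used so the logic runs without gensim installed (in real use you would subclass `CallbackAny2Vec` and pass the logger via `callbacks=`):

```python
class LossLogger:
    """Track per-epoch loss by differencing the cumulative tally that
    get_latest_training_loss() returns (which, up to gensim 3.6.0, is
    reset only per train() call, not per internal epoch)."""

    def __init__(self):
        self.previous = 0.0
        self.per_epoch = []

    def on_epoch_end(self, model):
        cumulative = model.get_latest_training_loss()
        self.per_epoch.append(cumulative - self.previous)
        self.previous = cumulative


class FakeModel:
    """Stand-in for a gensim model, so the delta logic can be
    demonstrated without the library installed."""

    def __init__(self, values):
        self._values = iter(values)

    def get_latest_training_loss(self):
        return next(self._values)


# Feed in the first three cumulative values from the question's output.
logger = LossLogger()
fake = FakeModel([745896.8125, 1403872.0, 2022238.875])
for _ in range(3):
    logger.on_epoch_end(fake)
print(logger.per_epoch)  # per-epoch deltas instead of a growing tally
```

With the real library, the same `on_epoch_end` body would go into a `CallbackAny2Vec` subclass, alongside (or merged into) the EpochSaver shown in the question.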