
Overwriting methods via mixin pattern does not work as intended


I am trying to introduce a mod/mixin for a problem. In particular, I am focusing here on SpeechRecognitionProblem. I intend to modify this problem, and therefore I do the following:

class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):

    def hparams(self, defaults, model_hparams):
        # Run the base implementation first, then adjust the target vocab size.
        speech_recognition.SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
        vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
        p = defaults
        p.vocab_size['targets'] = vocab_size

    def feature_encoders(self, data_dir):
        # ...

So this one does not do much: it calls the hparams() function of the base class and then overrides the target vocabulary size.

Now, there are already some ready-to-go problems, e.g. LibriSpeech:

@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
    # ..

However, in order to apply my modifications I am doing this:

@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
    # ..

This should, if I am not mistaken, override everything (with identical signatures) in Librispeech and call the methods of SpeechRecognitionProblemMod instead.
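
To double-check that assumption about Python's method resolution order (MRO), here is a minimal, self-contained sketch with hypothetical stand-in classes that mirror the same inheritance pattern:

class Base(object):
    def hparams(self):
        return 'Base'

class Mixin(Base):
    def hparams(self):
        # call the base implementation, then modify its result
        return Base.hparams(self) + '+Mixin'

class Concrete(Base):
    def hparams(self):
        return 'Concrete'

class ConcreteMod(Mixin, Concrete):
    pass

print([c.__name__ for c in ConcreteMod.__mro__])
# ['ConcreteMod', 'Mixin', 'Concrete', 'Base', 'object']
print(ConcreteMod().hparams())
# 'Base+Mixin' -- the mixin's method shadows Concrete's

Because the mixin comes first in the MRO, its overrides win over anything defined in Concrete (here standing in for Librispeech).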

Since I was able to train a model with this code, I am assuming that it's working as intended so far.

Now here comes my problem:

After training I want to serialize the model. This usually works, but it does not with my mod, and I actually know why:

At a certain point hparams() gets called. Debugging to that point shows me the following:

self                  # {LibrispeechMod}
self.hparams          # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>

self.hparams should be <bound method SpeechRecognitionProblemMod.hparams of ..>! It seems that, for some reason, hparams() of SpeechRecognitionProblem gets called directly instead of the override in SpeechRecognitionProblemMod. Note, however, that feature_encoders() resolves to the correct class!
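
To see where the methods actually resolve, it also helps to look them up on the class rather than the instance at the same breakpoint. A small check, assuming self is the LibrispeechMod instance from the debug output above:

# The classes Python searches, in order (the MRO):
print([c.__name__ for c in type(self).__mro__])
# expected: ['LibrispeechMod', 'SpeechRecognitionProblemMod', 'Librispeech',
#            'SpeechRecognitionProblem', ...]

# Looking the methods up on the class shows which definition is resolved:
print(type(self).hparams)
print(type(self).feature_encoders)

If hparams resolves to SpeechRecognitionProblemMod here, the class hierarchy itself is fine and something else is going on.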

The thing is that I know this works during training. I can see that the hyper-parameters (hparams) are applied accordingly, simply because my modifications change the model's graph node names.

There is one specialty I need to point out: tensor2tensor allows you to dynamically load a t2t_usr_dir, i.e. additional Python modules which get loaded by import_usr_dir. I make use of that function in my serialization script as well:

# import_usr_dir lives in tensor2tensor.utils.usr_dir
from tensor2tensor.utils.usr_dir import import_usr_dir

if usr_dir:
    logging.info('Loading user dir %s' % usr_dir)
    import_usr_dir(usr_dir)

This is the only possible culprit I can see at the moment, although I cannot tell why it would cause the problem.

If anybody sees something I don't, I'd be glad to get a hint about what I'm doing wrong here.


So what is the error you're getting?

For the sake of completeness, this is the result of the wrong hparams() method being called:

NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint

symbol_modality_256_256 is wrong. It should be symbol_modality_<vocab-size>_256, where <vocab-size> is the vocabulary size which gets set in SpeechRecognitionProblemMod.hparams.
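
One way to confirm what was actually written at training time is to list the variables stored in the checkpoint with TensorFlow's checkpoint utilities (the checkpoint path below is just a placeholder):

import tensorflow as tf

# Print every variable name (and shape) stored in the trained checkpoint.
for name, shape in tf.train.list_variables('/path/to/model.ckpt'):
    print(name, shape)

If the checkpoint contains symbol_modality_<vocab-size>_256 while the freshly built graph asks for symbol_modality_256_256, then the graph used for serialization was built with the wrong (unmodified) hparams.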


Solution

  • So, this weird behavior came from the fact that I was remote debugging and that the source files of the usr_dir were not correctly synchronized. Everything works as intended, but the source files were not matching (a quick way to verify this is sketched below).

    Case closed.
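
For anyone hitting something similar: a quick way to verify that the code Python actually imported matches what the debugger displays is to ask for the source of the resolved method directly (a sketch, assuming problem is the registered LibrispeechMod instance):

import inspect

# File the resolved hparams implementation was loaded from:
print(inspect.getsourcefile(type(problem).hparams))

# The exact source being executed -- compare it with what the remote debugger shows:
print(inspect.getsource(type(problem).hparams))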