I am trying to introduce a mod/mixin for a problem. In particular, I am focusing here on a SpeechRecognitionProblem. I intend to modify this problem, and therefore I seek to do the following:
    class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):
        def hparams(self, defaults, model_hparams):
            SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
            vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
            p = defaults
            p.vocab_size['targets'] = vocab_size

        def feature_encoders(self, data_dir):
            # ...
So this one does not do much. It calls the hparams() function from the base class and then changes some values.
Now, there are already some ready-to-go problems, e.g. Librispeech:
    class Librispeech(speech_recognition.SpeechRecognitionProblem):
        # ..
However, in order to apply my modifications I am doing this:
    class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
        # ..
This should, if I am not mistaken, override everything (with identical signatures) in Librispeech and call the functions of SpeechRecognitionProblemMod instead.
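The expected resolution order can be checked with a small, self-contained sketch. The classes below are simplified stand-ins with the same names and inheritance shape as in the question, not the real tensor2tensor classes, and the `feature_encoders` override in `Librispeech` is a hypothetical addition for illustration only:

```python
# Stand-in classes that mirror only the inheritance structure from the question.
class SpeechRecognitionProblem:
    def hparams(self):
        return "base.hparams"

    def feature_encoders(self):
        return "base.feature_encoders"


class SpeechRecognitionProblemMod(SpeechRecognitionProblem):
    def hparams(self):
        return "mod.hparams"

    def feature_encoders(self):
        return "mod.feature_encoders"


class Librispeech(SpeechRecognitionProblem):
    # Hypothetical override, only here to show that the mod still wins.
    def feature_encoders(self):
        return "librispeech.feature_encoders"


class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
    pass


# Python looks up attributes by name along the method resolution order (MRO).
mro_names = [cls.__name__ for cls in LibrispeechMod.__mro__]
problem = LibrispeechMod()
hparams_result = problem.hparams()
encoders_result = problem.feature_encoders()
print(mro_names)
print(hparams_result, encoders_result)
```

Because `SpeechRecognitionProblemMod` is listed first among the bases, it precedes `Librispeech` in the MRO, so its methods are found first. Note that the lookup is by method name only; signatures play no role.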
Since I was able to train a model with this code I am assuming that it's working as intended so far.
Now here comes my problem:
After training I want to serialize the model. This usually works. However, it does not with my mod and I actually know why:
At a certain point hparams() gets called. Debugging to that point will show me the following:
    self                   # {LibrispeechMod}
    self.hparams           # <bound method SpeechRecognitionProblem.hparams of ..>
    self.feature_encoders  # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>
self.hparams should be <bound method SpeechRecognitionProblemMod.hparams of ..>! It would seem that for some reason hparams() of SpeechRecognitionProblem gets called directly instead of the one in SpeechRecognitionProblemMod. But please note that it's the correct type for feature_encoders().
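One way to narrow this down is to ask Python which class in the MRO actually supplies each attribute. The sketch below uses hypothetical stand-in classes (`Base`, `Mod`, `StaleMod`, `Concrete`) and a hypothetical helper `providing_class`; it also shows one situation, assumed purely for illustration, that produces exactly this pattern: a loaded version of the mod that defines `feature_encoders` but not `hparams`.

```python
def providing_class(obj, name):
    """Return the first class in the MRO of obj's type that defines `name`."""
    for cls in type(obj).__mro__:
        if name in vars(cls):
            return cls
    return None


class Base:
    def hparams(self):
        pass

    def feature_encoders(self):
        pass


class Mod(Base):
    # Current version of the mod: overrides both methods.
    def hparams(self):
        pass

    def feature_encoders(self):
        pass


class StaleMod(Base):
    # Assumed older version of the mod: hparams is not overridden yet.
    def feature_encoders(self):
        pass


class Concrete(Base):
    pass


class Combined(Mod, Concrete):
    pass


class StaleCombined(StaleMod, Concrete):
    pass


current = {n: providing_class(Combined(), n).__name__
           for n in ("hparams", "feature_encoders")}
stale = {n: providing_class(StaleCombined(), n).__name__
         for n in ("hparams", "feature_encoders")}
print(current)
print(stale)
```

If the class reported for `hparams` is the base class although the source on disk clearly overrides it, the code that was actually imported probably differs from the code being read.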
The thing is that I know this is working during training. I can see that the hyperparameters (hparams) are applied accordingly, simply because the model's graph node names change through my modifications.
There is one peculiarity I need to point out. tensor2tensor allows dynamically loading a t2t_usr_dir, i.e. additional Python modules that get loaded by import_usr_dir. I make use of that function in my serialization script as well:
    if usr_dir:
        logging.info('Loading user dir %s' % usr_dir)
This is the only possible culprit I can see at the moment, although I would not be able to tell why it may cause the problem.
If anybody sees something I do not, I'd be glad to get a hint about what I'm doing wrong here.
So what is the error you're getting?
For the sake of completeness, this is the result of the wrong hparams() method being called:
NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint
The name symbol_modality_256_256 is wrong. It should be symbol_modality_<vocab-size>_256, where <vocab-size> is a vocabulary size which gets set in SpeechRecognitionProblemMod.hparams.
So, this weird behavior came from the fact that I was remote debugging and that the source files of the usr_dir were not correctly synchronized. Everything works as intended, but the source files were not matching.
Case closed.
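For anyone hitting something similar, a mismatch between local and remote sources can be spotted by checking which file an object's class was loaded from and comparing a hash of that file on both machines. This is a minimal sketch using only the standard library; `file_sha256` and `loaded_source_path` are hypothetical helper names, not part of tensor2tensor:

```python
import hashlib
import sys


def file_sha256(path):
    """Return the SHA-256 hex digest of the file's bytes."""
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()


def loaded_source_path(obj):
    """Return the file of the module that defines obj's class, or None.

    Built-in modules and some interactive sessions have no __file__.
    """
    module = sys.modules.get(type(obj).__module__)
    return getattr(module, "__file__", None)
```

Running `file_sha256(loaded_source_path(problem))` in the remote process and `file_sha256` on the local copy of the same file should give identical digests when the two are in sync.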