Tags: python-3.x, pytorch, time-series, pytorch-forecasting

How to give TemporalFusionTransformer model a name?


I'm working on .py scripts that preprocess some data and then train a TemporalFusionTransformer model. After training, a function logs the evaluation metrics to a .txt file whose name should include the model's name. I have searched the docs, forums, and articles, but I cannot find a way to give my models a custom name. Any idea how to do this?

Edit: Can someone with more than 1500 reputation please add the tag temporalfusiontransformer to the tags section? Users below 1500 reputation (like me) cannot create new tags on the site.


Solution

  • You could create a subclass of TemporalFusionTransformer that requires a name argument and stores it as an attribute, on top of all the functionality the original model already provides, e.g.

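    from pytorch_forecasting import TemporalFusionTransformer

    class NamedTFT(TemporalFusionTransformer):
        def __init__(self, name: str, *args, **kwargs):
            super().__init__(*args, **kwargs)  # build the regular TFT
            self.name = name  # custom name, e.g. for naming log files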
    

    You can then read the name back afterwards through the model's name attribute, for example when building the metrics file name.
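As a sketch of how the stored name could feed into the metrics file described in the question, the helper below writes one "metric: value" line per metric to "<name>_metrics.txt". The function log_metrics, its parameters, and the filename pattern are hypothetical (the question's exact filename pattern is not known). The helper only assumes the model has a name attribute, as NamedTFT above does, and falls back to the class name otherwise.

```python
from collections.abc import Mapping
from pathlib import Path


def log_metrics(model, metrics: Mapping[str, float], out_dir: str = ".") -> Path:
    """Write metrics to '<model name>_metrics.txt' in out_dir and return the path.

    Hypothetical helper: assumes `model` exposes a `name` attribute
    (e.g. a NamedTFT instance); otherwise the class name is used.
    """
    # Use the custom name if present, else fall back to the class name.
    name = getattr(model, "name", None) or type(model).__name__
    path = Path(out_dir) / f"{name}_metrics.txt"
    with path.open("w") as f:
        # One "metric: value" line per entry, in insertion order.
        for key, value in metrics.items():
            f.write(f"{key}: {value}\n")
    return path
```

With a NamedTFT instance you would call something like log_metrics(tft, {"MAE": mae}, "logs"), where tft and mae stand for your trained model and a metric value you computed yourself; the output directory must already exist.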