
torch Module causes model field variable to become None on reference right after assignment


How is it possible that a field variable becomes None right after the line of assignment?

In particular, the following code prints CLIPTokenizerFast NoneType CLIPTextModelWithProjection when _setup_txt_encoder runs

self.txt_encoder should refer to the same object as txt_encoder, but when it is accessed it is retrieved as NoneType

It appears that inheriting from torch.nn.Module causes the problem: if I remove it from the class's bases, the issue disappears.

from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, CLIPTextModelWithProjection
from typing import Optional
from torch.nn import Module

class Foo:
    def _setup_txt_encoder(self, clip_txt_model_name: str):
        print('loading text tokenizer and encoder')
        tokenizer = AutoTokenizer.from_pretrained(clip_txt_model_name, clean_up_tokenization_spaces=True)
        txt_encoder = CLIPTextModelWithProjection.from_pretrained(clip_txt_model_name).requires_grad_(False)
        self.tokenizer, self.txt_encoder = tokenizer, txt_encoder
        self.txt_encoder = txt_encoder
        print(type(self.tokenizer).__name__, type(self.txt_encoder).__name__, type(txt_encoder).__name__)
        return tokenizer, txt_encoder
    tokenizer: Optional[PreTrainedTokenizer|PreTrainedTokenizerFast] = None
    txt_encoder: Optional[CLIPTextModelWithProjection] = None
class Bar(Module, Foo):
    def __init__(self, clip_txt_model_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'):
        super().__init__()
        if self.tokenizer is None or self.txt_encoder is None:
            self._setup_txt_encoder(clip_txt_model_name)

test = Bar()

EDIT: some more information. It's really weird that the following two calls return different things

print(getattr(test,'txt_encoder'))
print(test.__getattr__('txt_encoder'))

outputs

None
CLIPTextModelWithProjection(
  (text_model): CLIPTextTransformer(
...

Solution

  • torch.nn.Module overrides __setattr__ and __getattr__

    When the attribute is assigned a Module object, the overridden __setattr__ adds it to the dict self.__dict__['_modules'] but does not call the base object.__setattr__, so nothing named txt_encoder is stored on the instance itself. On access, __getattr__ is only invoked when normal attribute lookup fails; here the lookup succeeds because it finds the class-level default txt_encoder = None defined on Foo. So getattr (and the syntax foo.bar) returns that stale class attribute instead of going through the overridden __getattr__, which is what reads _modules. This is not limited to None: if the class-level default were, say, the integer 1, self.txt_encoder would still print 1 after doing self.txt_encoder = module. The minimal sketch below reproduces this without transformers.
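
    For example, here is a minimal, self-contained sketch of the same behaviour (the class and attribute names WithDefault and sub are made up for illustration; only torch is required):

    from typing import Optional
    from torch import nn

    class WithDefault(nn.Module):
        # class-level default, playing the role of txt_encoder = None on Foo
        sub: Optional[nn.Module] = None

        def __init__(self):
            super().__init__()
            # nn.Module.__setattr__ intercepts this assignment and stores the layer
            # in self._modules instead of calling object.__setattr__
            self.sub = nn.Linear(2, 2)

    m = WithDefault()
    print(getattr(m, 'sub'))     # None: normal lookup finds the class attribute first
    print(m.__getattr__('sub'))  # Linear(...): nn.Module.__getattr__ reads self._modules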

    A solution is to also call the base object.__setattr__ so the value is stored on the instance itself, where normal lookup finds it before the class default. For the example given in the question, we have

            self.tokenizer, self.txt_encoder = tokenizer, txt_encoder
            object.__setattr__(self,'txt_encoder',txt_encoder)
            print(type(self.tokenizer).__name__, type(self.txt_encoder).__name__, type(txt_encoder).__name__)
    

    And now it should print out CLIPTokenizerFast CLIPTextModelWithProjection CLIPTextModelWithProjection

    An alternative solution is to not assign a class-level default of None and instead make sure the attribute is always assigned in __init__. So, rather than

    txt_encoder: Optional[CLIPTextModelWithProjection] = None
    

    We instead have

    txt_encoder: CLIPTextModelWithProjection
    

    This will ensure that the overridden __getattr__ is called: because there is no longer a class attribute for normal lookup to find, attribute access falls through to __getattr__, which returns the module stored in _modules. The short sketch below demonstrates this.
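
    As a minimal illustration (again with made-up names, NoDefault and sub, and only torch required), removing the class-level default lets attribute access reach nn.Module.__getattr__:

    from torch import nn

    class NoDefault(nn.Module):
        # annotation only, no class-level default value
        sub: nn.Module

        def __init__(self):
            super().__init__()
            self.sub = nn.Linear(2, 2)

    m = NoDefault()
    print(getattr(m, 'sub'))  # Linear(...): lookup falls through to nn.Module.__getattr__
    print(m.sub)              # same object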