How is it possible that a field variable becomes None right after the line of assignment?
In particular, the following code prints
CLIPTokenizerFast NoneType CLIPTextModelWithProjection
when _setup_txt_encoder runs.
self.txt_encoder should refer to the same object as txt_encoder, but when accessed it is retrieved as NoneType.
Inheriting from torch.nn.Module appears to cause the problem: if I remove it from the class inheritance, the problem goes away.
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, CLIPTextModelWithProjection
from typing import Optional
from torch.nn import Module

class Foo:
    def _setup_txt_encoder(self, clip_txt_model_name: str):
        print('loading text tokenizer and encoder')
        tokenizer = AutoTokenizer.from_pretrained(clip_txt_model_name, clean_up_tokenization_spaces=True)
        txt_encoder = CLIPTextModelWithProjection.from_pretrained(clip_txt_model_name).requires_grad_(False)
        self.tokenizer, self.txt_encoder = tokenizer, txt_encoder
        self.txt_encoder = txt_encoder
        print(type(self.tokenizer).__name__, type(self.txt_encoder).__name__, type(txt_encoder).__name__)
        return tokenizer, txt_encoder

    tokenizer: Optional[PreTrainedTokenizer|PreTrainedTokenizerFast] = None
    txt_encoder: Optional[CLIPTextModelWithProjection] = None

class Bar(Module, Foo):
    def __init__(self, clip_txt_model_name='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'):
        super().__init__()
        if self.tokenizer is None or self.txt_encoder is None:
            self._setup_txt_encoder(clip_txt_model_name)

test = Bar()
EDIT: Some more information. Strangely, the following two calls return different things:
print(getattr(test,'txt_encoder'))
print(test.__getattr__('txt_encoder'))
outputs
None
CLIPTextModelWithProjection(
(text_model): CLIPTextTransformer(
...
torch.nn.Module overrides __setattr__ and __getattr__. When an attribute is set to a Module object, the overridden __setattr__ adds it to the dict self.__dict__['_modules'], but the base __setattr__ is not called to store it on the object itself. When you call getattr or reference the attribute using the syntax foo.bar, Python only falls back to __getattr__ if normal attribute lookup fails. Here normal lookup succeeds, because txt_encoder exists as a class attribute on Foo (set to None), so getattr returns that class-level value rather than calling the overridden __getattr__ that would retrieve the value stored in _modules. This is not limited to None. If the class-level default were some other value, such as the integer 1, self.txt_encoder would still print 1 after doing self.txt_encoder = module.
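The lookup order described above can be reproduced without PyTorch. The following is a minimal sketch, not the real implementation: MiniModule, Defaults and Broken are hypothetical names, and MiniModule only imitates the part of torch.nn.Module that stores submodules in a _modules dict.

```python
class MiniModule:
    """Simplified stand-in for torch.nn.Module's attribute handling (illustrative only)."""

    def __init__(self):
        # Bypass our own __setattr__ to create the submodule registry.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            # Submodules go into the registry, not into the instance __dict__.
            self.__dict__.pop(name, None)
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup (instance __dict__, then classes) fails.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class Defaults:
    child = None  # class-level default, like txt_encoder in the question


class Broken(MiniModule, Defaults):
    def __init__(self):
        super().__init__()
        self.child = MiniModule()  # stored in _modules only


broken = Broken()
print(broken.child)                                # None (class attribute wins)
print(type(broken.__getattr__("child")).__name__)  # MiniModule
```

Normal lookup finds Defaults.child before __getattr__ is ever consulted, which matches the difference between getattr(test, 'txt_encoder') and test.__getattr__('txt_encoder') in the question.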
A solution is to call the base __setattr__ to assign the value to the object itself as well. For the example given in the question, we have
self.tokenizer, self.txt_encoder = tokenizer, txt_encoder
object.__setattr__(self,'txt_encoder',txt_encoder)
print(type(self.tokenizer).__name__, type(self.txt_encoder).__name__, type(txt_encoder).__name__)
And now it should print out CLIPTokenizerFast CLIPTextModelWithProjection CLIPTextModelWithProjection
An alternative solution is to not assign None at class level and to make sure the attribute is assigned in __init__. So, rather than
txt_encoder: Optional[CLIPTextModelWithProjection] = None
We instead have
txt_encoder: CLIPTextModelWithProjection
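This ensures that the overridden __getattr__ is called, because a bare annotation does not create a class attribute, so normal lookup fails and falls through to __getattr__.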
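Both fixes can be checked on a small stand-in class that does not need PyTorch. This is a sketch under stated assumptions: MiniModule, WithDefault, AnnotationOnly, FixedBySetattr and FixedByAnnotation are hypothetical names, and MiniModule only imitates how torch.nn.Module keeps submodules in _modules.

```python
class MiniModule:
    """Simplified stand-in for torch.nn.Module's attribute handling (illustrative only)."""

    def __init__(self):
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        if isinstance(value, MiniModule):
            # Submodules are registered in _modules instead of the instance __dict__.
            self.__dict__.pop(name, None)
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Fallback used only when normal attribute lookup fails.
        modules = self.__dict__.get("_modules", {})
        if name in modules:
            return modules[name]
        raise AttributeError(name)


class WithDefault:
    child = None  # class-level default that shadows the registry


class FixedBySetattr(MiniModule, WithDefault):
    def __init__(self):
        super().__init__()
        self.child = MiniModule()
        # Fix 1: also store the value on the instance, which shadows the class default.
        object.__setattr__(self, "child", self._modules["child"])


class AnnotationOnly:
    child: "MiniModule"  # Fix 2: annotation only, no class attribute is created


class FixedByAnnotation(MiniModule, AnnotationOnly):
    def __init__(self):
        super().__init__()
        self.child = MiniModule()


print(type(FixedBySetattr().child).__name__)     # MiniModule
print(type(FixedByAnnotation().child).__name__)  # MiniModule
print(hasattr(AnnotationOnly, "child"))          # False
```

The first fix keeps the class-level default but adds an instance attribute that takes precedence over it. The second removes the class attribute entirely, so the lookup falls through to __getattr__.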