Tags: python, python-3.x, pytorch, python-attrs

PyTorch Module with attrs cannot get parameter list


The attrs package somehow breaks PyTorch's parameters() method for a module. I am wondering whether anyone has workarounds or solutions so that the two packages can integrate seamlessly.

If not, any advice on which GitHub repository to post the issue to? My instinct would be to post it on the attrs GitHub, but the stack trace points almost entirely into PyTorch's codebase.

Python 3.7.3
attrs==19.1.0
torch==1.1.0.post2
torchvision==0.3.0

import attr
import torch


class RegularModule(torch.nn.Module):
    pass

@attr.s
class AttrsModule(torch.nn.Module):
    pass


module = RegularModule()
print(list(module.parameters()))

module = AttrsModule()
print(list(module.parameters()))

The actual output is:

$ python attrs_pytorch.py
[]
Traceback (most recent call last):
  File "attrs_pytorch.py", line 18, in <module>
    print(list(module.parameters()))
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 814, in parameters
    for name, param in self.named_parameters(recurse=recurse):
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 840, in named_parameters
    for elem in gen:
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 784, in _named_members
    for module_prefix, module in modules:
  File "/usr/local/anaconda3/envs/bgg/lib/python3.7/site-packages/torch/nn/modules/module.py", line 975, in named_modules
    if self not in memo:
TypeError: unhashable type: 'AttrsModule'

The expected output is:

$ python attrs_pytorch.py
[]
[]
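The last frame of the traceback is the set-membership test "if self not in memo", so every module has to be hashable. With its defaults (cmp=True, hash=None, not frozen), attrs sets __hash__ to None, which makes instances unhashable. The standard-library dataclasses module applies the same rule, so the failure can be sketched without PyTorch or attrs. Base and in_memo below are hypothetical stand-ins for torch.nn.Module and the memo set used by named_modules():

```python
import dataclasses


class Base:
    """Hypothetical stand-in for torch.nn.Module: defines no __eq__ or
    __hash__, so it inherits object's id-based hashing."""


@dataclasses.dataclass  # eq=True by default, so __hash__ is set to None
class DefaultEq(Base):
    pass


@dataclasses.dataclass(eq=False)  # __hash__ is left untouched (inherited)
class NoEq(Base):
    pass


def in_memo(module):
    """Mimics the 'self not in memo' check; hashing happens here."""
    memo = set()
    return module in memo


print(DefaultEq.__hash__)  # None
try:
    in_memo(DefaultEq())
except TypeError as exc:
    print("TypeError:", exc)  # the message mentions an unhashable type
print(in_memo(NoEq()))  # False
```

With attrs 19.1 the analogous switch should be cmp=False; later attrs releases spell it eq=False.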

Solution

  • You can get it to work with a workaround using dataclasses, which have been in the standard library since Python 3.7 (the version you are using). That said, I find a plain __init__ more readable. Something similar is possible with the attrs library (by disabling its generated hashing); I just prefer a standard-library solution when one exists.

    Once the hashing error is handled, the remaining problem is that the generated __init__ never calls torch.nn.Module.__init__(), which is what creates the _parameters attribute and other framework-specific state.

    First, solving the hashing problem with dataclasses:

    import dataclasses
    import torch

    @dataclasses.dataclass(eq=False)
    class AttrsModule(torch.nn.Module):
        pass
    

    This solves the hashing issue. With the default eq=True (and frozen=False), dataclass() sets __hash__ to None, which makes instances unhashable; as the dataclasses documentation puts it in the section on __hash__ and eq:

    By default, dataclass() will not implicitly add a __hash__() method unless it is safe to do so.

    A hash is needed because named_modules() keeps the modules it has already visited in a set (the memo in the traceback above). Furthermore:

    If eq is false, __hash__() will be left untouched meaning the __hash__() method of the superclass will be used (if the superclass is object, this means it will fall back to id-based hashing).

    So with eq=False the class keeps the __hash__ it inherits through torch.nn.Module (refer to the dataclasses documentation if any further errors arise).

    This leaves you with the error:

    AttributeError: 'AttrsModule' object has no attribute '_parameters'
    

    This happens because the torch.nn.Module constructor is never called. A quick and dirty fix:

    import dataclasses
    import torch

    @dataclasses.dataclass(eq=False)
    class AttrsModule(torch.nn.Module):
        def __post_init__(self):
            super().__init__()
    

    __post_init__ is a method that the generated __init__ calls as its last step (who would have guessed), so it is a good place to initialize the torch-specific state.

    Still, I would advise against combining the two. For example, the generated __repr__ replaces PyTorch's own __repr__, so you should also pass repr=False to the dataclasses.dataclass decorator. That gives this final code (with the obvious collisions between the libraries eliminated, I hope):

    import dataclasses
    
    import torch
    
    
    class RegularModule(torch.nn.Module):
        pass
    
    
    @dataclasses.dataclass(eq=False, repr=False)
    class AttrsModule(torch.nn.Module):
        def __post_init__(self):
            super().__init__()
    
    
    module = RegularModule()
    print(list(module.parameters()))
    
    module = AttrsModule()
    print(list(module.parameters()))
    

    For more on attrs, see hynek's answer and his blog post.
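The __post_init__ fix and the repr=False point can be checked without PyTorch using a small stand-in. FakeModule below is hypothetical: it copies only the part of torch.nn.Module that matters here (a constructor that creates _parameters, a parameters() method, and a __repr__). Broken mirrors the dataclass without __post_init__, Fixed mirrors the final code, and hidden_size is an illustrative field, not something from the question:

```python
import dataclasses


class FakeModule:
    """Hypothetical stand-in for torch.nn.Module: its constructor creates
    the _parameters dict that parameters() reads later."""

    def __init__(self):
        self._parameters = {}

    def parameters(self):
        return iter(self._parameters.values())

    def __repr__(self):
        return type(self).__name__ + "()"


@dataclasses.dataclass(eq=False)
class Broken(FakeModule):
    # The generated __init__ replaces FakeModule.__init__ and only sets fields.
    hidden_size: int = 8


@dataclasses.dataclass(eq=False, repr=False)
class Fixed(FakeModule):
    hidden_size: int = 8

    def __post_init__(self):
        # Called at the end of the generated __init__: run the base constructor.
        super().__init__()


try:
    list(Broken().parameters())
except AttributeError as exc:
    print("AttributeError:", exc)  # mentions '_parameters'
print(list(Fixed().parameters()))  # []
print(repr(Broken()))  # generated repr, e.g. Broken(hidden_size=8)
print(repr(Fixed()))  # Fixed(), the base class __repr__ is kept
```

One caveat, worth verifying against your PyTorch version: the generated __init__ assigns fields before __post_init__ runs, so a field holding a submodule or an nn.Parameter would be assigned before torch.nn.Module.__init__() has run, which PyTorch rejects. Plain values such as ints are fine.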