python pickle python-dataclasses

How to persist and load all attributes of a dataclass


I want to persist all attributes of an object that is an instance of a dataclass, and then load that object back from the files I persisted.

Here is an example that fulfills the task:

from dataclasses import dataclass
import pickle


@dataclass
class Circle:
    radius: float
    centre: tuple

    def save(self, path: str):
        name = ".".join(("radius", "pkl"))
        with open("/".join((path, name)), "wb") as f:
            pickle.dump(self.radius, f)
        name = ".".join(("centre", "pkl"))
        with open("/".join((path, name)), "wb") as f:
            pickle.dump(self.centre, f)

    @classmethod
    def load(cls, path):
        my_model = {}
        name = "radius"
        file_name = ".".join((name, "pkl"))
        # Use the same "/" separator as save() so both methods build identical paths.
        with open("/".join((path, file_name)), "rb") as f:
            my_model[name] = pickle.load(f)
        name = "centre"
        file_name = ".".join((name, "pkl"))
        with open("/".join((path, file_name)), "rb") as f:
            my_model[name] = pickle.load(f)
        return cls(**my_model)

>>> c = Circle(2, (0, 0))
>>> c.save(r".\Circle")
>>> c_loaded = Circle.load(r".\Circle")
>>> c_loaded == c
True

As you can see, I need to repeat the same code for every attribute. Is there a better way to do it?


Solution

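  • The save method can iterate over self.__dict__, which holds all attribute names and values as a dictionary. load is a classmethod, so there is no instance __dict__ at that stage. However, cls.__annotations__ maps the attribute names to their types, also as a dictionary. Note that cls.__annotations__ only lists the annotations declared on that class itself, not those inherited from a parent class.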

    Here is the end result:

    from dataclasses import dataclass
    import pickle
    
    @dataclass
    class Circle:
        radius: float
        centre: tuple
    
        def save(self, path):
            for name, attribute in self.__dict__.items():
                name = ".".join((name, "pkl"))
                with open("/".join((path, name)), "wb") as f:
                    pickle.dump(attribute, f)
    
        @classmethod
        def load(cls, path):
            my_model = {}
            for name in cls.__annotations__:
                file_name = ".".join((name, "pkl"))
                with open("/".join((path, file_name)), "rb") as f:
                    my_model[name] = pickle.load(f)
            return cls(**my_model)
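
As a possible variation on the answer above, the field names can be taken from dataclasses.fields() instead of __dict__ and __annotations__, so that save and load walk the same list and inherited fields are included. This is only a sketch under some assumptions that are not in the original: every field is accepted by the generated __init__ (no init=False fields), the target directory may be created by save, and pathlib and a temporary directory are used purely for illustration.

```python
from dataclasses import dataclass, fields
from pathlib import Path
import pickle
import tempfile


@dataclass
class Circle:
    radius: float
    centre: tuple

    def save(self, path):
        # Assumption: creating the directory on demand is acceptable.
        folder = Path(path)
        folder.mkdir(parents=True, exist_ok=True)
        # fields() returns every dataclass field, including inherited ones.
        for field in fields(self):
            with open(folder / f"{field.name}.pkl", "wb") as f:
                pickle.dump(getattr(self, field.name), f)

    @classmethod
    def load(cls, path):
        folder = Path(path)
        values = {}
        # fields() also accepts the class itself, so no instance is needed.
        for field in fields(cls):
            with open(folder / f"{field.name}.pkl", "rb") as f:
                values[field.name] = pickle.load(f)
        # Assumption: every field is a parameter of the generated __init__.
        return cls(**values)


# Round trip through a temporary directory (illustrative only).
with tempfile.TemporaryDirectory() as tmp:
    c = Circle(2, (0, 0))
    c.save(tmp)
    c_loaded = Circle.load(tmp)
    saved_files = sorted(p.name for p in Path(tmp).iterdir())
```

Because both methods iterate over the same fields() result, adding a new field to the dataclass needs no change to save or load.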