Tags: python, serialization, deserialization, marshalling, marshmallow

Register a post_load hook after Schema definition?


I'm trying to implement a Schema that deserializes into an object whose class is not known at Schema-definition time. I assumed I could register a post_load function at runtime, but post_load appears to work only on methods defined in the class body.

It seems like I could make it work by either:

  • Updating Schema._hooks manually or by
  • Somehow creating a bound method at runtime and registering that.

Since both of these options are somewhat hack-ish, is there an official way to achieve the same result?


Solution

  • I don't think you need a metaclass.

    Define a base schema with a post-load method that just needs the class.

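    from marshmallow import Schema, post_load
    
    
    class CustomSchema(Schema):
    
        @post_load
        def make_obj(self, data, **kwargs):
            # marshmallow 3+ passes many/partial as keyword arguments to hooks
            return self.OBJ_CLS(**data)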
    

    If the class is known at import time (not your use case), this lets you factor out the instantiation: each schema only has to provide the class. Nice already.

    class PetSchema(CustomSchema):
    
        OBJ_CLS = Pet
    

    If the class is not known at import time, then it can be provided afterwards.

    class PetSchema(CustomSchema):
        pass
    
    
    PetSchema.OBJ_CLS = Pet
    

    If you need some more processing before instantiating, then you can override make_obj in any class, as you show in your answer.

    class PetSchema(CustomSchema):
        @post_load  # decorate the override too, so it is still registered as a hook
        def make_obj(self, data, **kwargs):
            data = my_func(data)
            return Pet(**data)
    

    More generally, this mechanism lets you define hooks in a base schema. It is a good way to work around a current limitation in marshmallow: when a schema has several post_load methods, the order in which they run is not guaranteed. Define a single post_load method in a base class that calls one overridable method per processing step, in a fixed order. (This contrived example doesn't really illustrate the point.)

    class CustomSchema(Schema):
    
        @post_load
        def post_load_steps(self, data, **kwargs):
            data = self.post_load_step_1(data)
            data = self.post_load_step_2(data)
            data = self.post_load_step_3(data)
            return data
    
        def post_load_step_1(self, data):
            return data
    
        def post_load_step_2(self, data):
            return data
    
        def post_load_step_3(self, data):
            return data