
How to create a dynamically imported strategy pattern in pydantic?


I would like to implement a strategy pattern in pydantic: take the strategy string, import the module it names, retrieve the model for that item from it, and continue model validation with the dynamically imported model.

For example:

data = {
  "shape": {
    "strategy": "mypkg.Cone",
    "radius": 5,
    "height": 10
  },
  "transformations": [
    {
      "strategy": "mypkg.Translate",
      "x": 10,
      "y": 10,
      "z": 10
    },
    {
      "strategy": "quaternions.Rotate",
      "quaternion": [0, 0, 0, 0]
    }
  ]
}
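The import step described above can be isolated in a small helper. This is a minimal sketch using only `importlib`; the name `resolve_strategy` is hypothetical, and it assumes each strategy string is a dotted path whose last component is a class defined in the module named by the rest (so "mypkg.Cone" means class `Cone` in module `mypkg`).

```python
import importlib


def resolve_strategy(path: str) -> type:
    """Hypothetical helper: turn a dotted path like "mypkg.Cone" into the class it names.

    Everything before the last dot is imported as a module; the last component
    is looked up as an attribute of that module.
    """
    module_name, sep, attr = path.rpartition(".")
    if not sep or not module_name:
        raise ValueError(f"strategy {path!r} must look like 'module.ClassName'")
    module = importlib.import_module(module_name)  # ModuleNotFoundError propagates
    obj = getattr(module, attr, None)
    if not isinstance(obj, type):
        raise ValueError(f"{path!r} does not name a class")
    return obj


# A stdlib class stands in for a model such as mypkg.Cone:
print(resolve_strategy("collections.OrderedDict"))  # <class 'collections.OrderedDict'>
```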

I'd like this to be contained within the model class and to be arbitrarily recursive, so that calling Model(**data) resolves all strategies and strategy-loaded models can have their own substrategies.

From this input I'd expect the following validated instance:

<__main__.Model(
  shape=<mypkg.Cone(
    radius=5,
    height=10
  )>,
  transformations=[
    <mypkg.Translate(x=10, y=10, z=10)>,
    <quaternions.Rotate(quaternion=[0,0,0,0])>
  ]
)>
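Independently of pydantic, the recursive resolution described here amounts to a depth-first walk over the input: every dict carrying a "strategy" key is replaced by an instance of the class it names, built from its already-resolved remaining keys. A minimal stdlib sketch, assuming dataclasses as stand-ins for the pydantic models and a plain dict lookup in place of the dynamic import; `build` and `REGISTRY` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Cone:
    radius: float
    height: float


@dataclass
class Translate:
    x: float
    y: float
    z: float


# Hypothetical stand-in for the dynamic import: maps strategy strings to classes.
REGISTRY: dict = {"mypkg.Cone": Cone, "mypkg.Translate": Translate}


def build(node: Any, resolve: Callable[[str], type]) -> Any:
    """Recursively replace every {"strategy": ..., **fields} dict by an instance."""
    if isinstance(node, list):
        return [build(item, resolve) for item in node]
    if isinstance(node, dict):
        # Children are resolved first, so substrategies exist before their parent is built.
        fields = {key: build(value, resolve) for key, value in node.items() if key != "strategy"}
        if "strategy" in node:
            return resolve(node["strategy"])(**fields)
        return fields
    return node


data = {
    "shape": {"strategy": "mypkg.Cone", "radius": 5, "height": 10},
    "transformations": [{"strategy": "mypkg.Translate", "x": 10, "y": 10, "z": 10}],
}
print(build(data, REGISTRY.__getitem__))
# {'shape': Cone(radius=5, height=10), 'transformations': [Translate(x=10, y=10, z=10)]}
```

Because the walk recurses before instantiating, a strategy-loaded class whose fields themselves contain "strategy" dicts is handled at any depth without extra passes.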

The closest I have gotten is to force a rebuild of the model during validation and dynamically add new members to the discriminated type unions, but the change only takes effect on the NEXT validation cycle of the model:

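from pydantic import BaseModel, model_validator

# Shape, Transformation and import_and_add_to_union are defined elsewhere
# in my code and are not shown here.


class Model(BaseModel):
    shape: Shape
    transformations: list[Transformation]

    @model_validator(mode="before")
    @classmethod
    def adjust_unions(cls, data):
        # Widen the "shape" union with the imported class, then rebuild the schema.
        cls.model_fields.update({"shape": import_and_add_to_union(cls, data, "shape")})
        cls.model_rebuild(force=True)
        return data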

import_and_add_to_union takes the existing FieldInfo, imports the module, retrieves the new union member, and produces a new FieldInfo whose annotation is a discriminated type union containing both the existing union members and the newly imported one. This works, but only takes effect on the NEXT validation cycle:

from pydantic import ValidationError

try:
    Model(**data)  # errors
except ValidationError:
    pass

try:
    # Now works, but would fail once more for every nested
    # level of substrategies any of the strategies may have.
    Model(**data)
except ValidationError:
    pass
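import_and_add_to_union itself is not shown in the question, so the following is only a hypothetical `extend_union` illustrating the union-widening step it is described as performing. It uses the standard `typing` helpers alone and leaves out the pydantic-specific parts (the discriminator and the FieldInfo wrapper):

```python
import types
from typing import Union, get_args, get_origin

# typing.Union[...] and PEP 604 "A | B" unions report different origins before
# Python 3.14; accept both where the interpreter has them.
_UNION_ORIGINS = {Union, getattr(types, "UnionType", Union)}


def extend_union(annotation: object, new_member: type) -> object:
    """Hypothetical helper: return a Union of annotation's members plus new_member."""
    if get_origin(annotation) in _UNION_ORIGINS:
        members = get_args(annotation)
    else:
        # A bare class is treated as a one-member union.
        members = (annotation,)
    # Subscripting Union with a tuple builds the flattened union; duplicates are dropped.
    return Union[members + (new_member,)]


print(extend_union(Union[int, str], float))
```

Inside pydantic the widened union would still need its discriminator, for example `Annotated[..., Field(discriminator="strategy")]`, before being wrapped in a new FieldInfo.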

On top of that, I would like the Shape model to be a self-contained strategy pattern that, when validated with strategy: cone, returns a validated instance of Cone instead. Currently, the Shape class requires its parent model to know that it is a strategy model, and the parent has to build the discriminated type union.

Is there any way to improve this?


Solution

  • A much simpler approach is to override __new__, so that instantiating Strategy returns an instance of the class named by the strategy keyword:

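    import importlib

    from pydantic import BaseModel


    class Strategy(BaseModel):
        strategy: str

        def __new__(cls, *args, **kwargs):
            path = kwargs.get("strategy")
            if path is None:
                # No strategy keyword given: fall back to building cls itself.
                return super().__new__(cls)
            module, _, attr = path.rpartition(".")
            child_cls = getattr(importlib.import_module(module), attr)
            if not issubclass(child_cls, cls):
                raise TypeError(f"{path!r} is not a subclass of {cls.__name__}")
            # Python then runs child_cls.__init__, validating kwargs against its fields.
            return super().__new__(child_cls)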
    

    Any class that inherits from Strategy will then be instantiated as the class named by its strategy attribute. For example, in a module called my.pkg:

    class Cone(Strategy):
        radius: float
        height: float
    

    and then from anywhere:

    >>> Strategy(strategy="my.pkg.Cone", radius=3, height=5)
    Cone(strategy='my.pkg.Cone', radius=3.0, height=5.0)
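Why this works comes down to plain Python object construction: `cls(...)` first calls `cls.__new__`, and if the returned object is an instance of `cls` (a subclass instance counts), that object's own `__init__` runs next, which for a pydantic model is where validation against the subclass's fields happens. If `__new__` returns something that is not an instance of `cls`, `__init__` is silently skipped. A stdlib-only sketch of the same dispatch, with a hypothetical `Base` class and a dict lookup in place of the import, shows both cases:

```python
class Base:
    # Hypothetical lookup standing in for importlib; filled in below.
    registry: dict = {}

    def __new__(cls, *args, **kwargs):
        strategy = kwargs.get("strategy")
        if strategy is None:
            # No dispatch requested: build the class that was called.
            return super().__new__(cls)
        target = cls.registry[strategy]
        # object.__new__ takes only the class; the keyword arguments go to __init__.
        return super().__new__(target)

    def __init__(self, **kwargs):
        self.fields = kwargs


class Cone(Base):
    def __init__(self, strategy, radius, height):
        super().__init__(strategy=strategy, radius=radius, height=height)


class Unrelated:
    pass


Base.registry.update({"Cone": Cone, "Unrelated": Unrelated})

cone = Base(strategy="Cone", radius=3, height=5)
print(type(cone).__name__, cone.fields)  # Cone {'strategy': 'Cone', 'radius': 3, 'height': 5}

odd = Base(strategy="Unrelated")
print(type(odd).__name__, hasattr(odd, "fields"))  # Unrelated False
```

The second case is the pitfall to guard against: a strategy string naming a class outside the hierarchy yields an uninitialised object rather than an error, so checking `issubclass(target, cls)` before constructing is worthwhile.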