
Flatten nested Pydantic model


from typing import Union
from pydantic import BaseModel, Field


class Category(BaseModel):
    name: str = Field(alias="name")


class OrderItems(BaseModel):
    name: str = Field(alias="name")
    category: Category = Field(alias="category")
    unit: Union[str, None] = Field(alias="unit")
    quantity: int = Field(alias="quantity")

When instantiated like this:

OrderItems(**{'name': 'Test','category':{'name': 'Test Cat'}, 'unit': 'kg', 'quantity': 10})

It returns data like this:

OrderItems(name='Test', category=Category(name='Test Cat'), unit='kg', quantity=10)

But I want the output like this:

OrderItems(name='Test', category='Test Cat', unit='kg', quantity=10)

How can I achieve this?


Solution

  • You should try as much as possible to define your schema the way you actually want the data to look in the end, not the way you might receive it from somewhere else.


    UPDATE: Generalized solution (one nested field or more)

    To generalize this problem, let's assume you have the following models:

    from pydantic import BaseModel
    
    
    class Foo(BaseModel):
        x: bool
        y: str
        z: int
    
    
    class _BarBase(BaseModel):
        a: str
        b: float
    
        class Config:
            orm_mode = True
    
    
    class BarNested(_BarBase):
        foo: Foo
    
    
    class BarFlat(_BarBase):
        foo_x: bool
        foo_y: str
    

    Problem: You want to be able to initialize BarFlat with a foo argument just like BarNested, but have the data end up in the flat schema, where the fields foo_x and foo_y correspond to x and y on the Foo model (and you are not interested in z).

    Solution: Define a custom root_validator with pre=True that checks if a foo key/attribute is present in the data. If it is, it validates the corresponding object against the Foo model, grabs its x and y values and then uses them to extend the given data with foo_x and foo_y keys:

    from pydantic import BaseModel, root_validator
    from pydantic.utils import GetterDict
    
    ...
    
    class BarFlat(_BarBase):
        foo_x: bool
        foo_y: str
    
        @root_validator(pre=True)
        def flatten_foo(cls, values: GetterDict) -> GetterDict | dict[str, object]:
            foo = values.get("foo")
            if foo is None:
                return values
            # Assume `foo` must be valid `Foo` data:
            foo = Foo.validate(foo)
            return {
                "foo_x": foo.x,
                "foo_y": foo.y,
            } | dict(values)
    

    Note that we need to be a bit more careful inside a root validator with pre=True, because when the model is created via from_orm, the values are passed in the form of a GetterDict, which is an immutable mapping-like object (for plain dictionary input, e.g. via parse_obj, they arrive as a regular dict). So we cannot simply assign the new values foo_x/foo_y to it as we would with a dictionary. But nothing stops us from returning the cleaned-up data as a regular old dict.
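    A plain-Python sketch of those two points. It uses the standard library's types.MappingProxyType as a stand-in for a read-only mapping (an assumption: it is not pydantic's GetterDict, but it rejects item assignment in the same way). It also shows that in `{...} | dict(values)` the right-hand operand wins, so foo_x/foo_y passed explicitly would override the values derived from foo. The dict `|` operator requires Python 3.9+.

```python
from types import MappingProxyType

# Read-only stand-in for the incoming values (assumption: not pydantic's GetterDict).
values = MappingProxyType({"a": "spam", "foo": {"x": True, "y": "."}})

# Item assignment is rejected on a read-only mapping.
try:
    values["foo_x"] = True
    assignment_allowed = True
except TypeError:
    assignment_allowed = False

# Instead, build and return a new plain dict, as the validator does.
foo = values["foo"]
flat = {"foo_x": foo["x"], "foo_y": foo["y"]} | dict(values)

# Keys from the right-hand operand win: an explicit foo_x overrides the derived one.
explicit = MappingProxyType({"foo": {"x": True, "y": "."}, "foo_x": False})
merged = {"foo_x": explicit["foo"]["x"]} | dict(explicit)

print(assignment_allowed)  # False
print(flat["foo_x"], flat["foo_y"])  # True .
print(merged["foo_x"])  # False
```

    Putting dict(values) on the right means the derived keys only act as defaults; swapping the operands would make foo take priority over explicitly passed flat fields.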

    To demonstrate, we can throw some test data at it:

    test_dict = {"a": "spam", "b": 3.14, "foo": {"x": True, "y": ".", "z": 0}}
    test_orm = BarNested(a="eggs", b=-1, foo=Foo(x=False, y="..", z=1))
    test_flat = '{"a": "beans", "b": 0, "foo_x": true, "foo_y": ""}'
    bar1 = BarFlat.parse_obj(test_dict)
    bar2 = BarFlat.from_orm(test_orm)
    bar3 = BarFlat.parse_raw(test_flat)
    print(bar1.json(indent=4))
    print(bar2.json(indent=4))
    print(bar3.json(indent=4))
    

    The output:

    {
        "a": "spam",
        "b": 3.14,
        "foo_x": true,
        "foo_y": "."
    }
    
    {
        "a": "eggs",
        "b": -1.0,
        "foo_x": false,
        "foo_y": ".."
    }
    
    {
        "a": "beans",
        "b": 0.0,
        "foo_x": true,
        "foo_y": ""
    }
    

    The first example simulates a common situation, where the data is passed to us in the form of a nested dictionary. The second example is the typical database ORM object situation, where BarNested represents the schema we find in a database. The third is just to show that we can still correctly initialize BarFlat without a foo argument.

    One caveat to note is that the validator does not remove the foo key if it finds it in the values. If your model is configured with Extra.forbid, that will lead to an error. In that case, add one extra line that coerces the original GetterDict to a dict first, and then pop the "foo" key instead of getting it.
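    A minimal sketch of that variant, assuming the pydantic v1-style API used throughout this answer (root_validator, class Config). The model name BarFlatStrict is hypothetical, the orm_mode base class is left out to keep it short, and it is only exercised here with dictionary input:

```python
from pydantic import BaseModel, ValidationError, root_validator


class Foo(BaseModel):
    x: bool
    y: str
    z: int


class BarFlatStrict(BaseModel):  # hypothetical name
    a: str
    b: float
    foo_x: bool
    foo_y: str

    class Config:
        extra = "forbid"  # a leftover "foo" key would be rejected

    @root_validator(pre=True)
    def flatten_foo(cls, values):
        # Copy into a plain dict first so that "foo" can be popped, not just read.
        data = dict(values)
        foo = data.pop("foo", None)
        if foo is None:
            return data
        foo = Foo.validate(foo)  # validate the nested data against Foo
        return {"foo_x": foo.x, "foo_y": foo.y, **data}


nested = BarFlatStrict(a="spam", b=3.14, foo={"x": True, "y": ".", "z": 0})
flat = BarFlatStrict(a="beans", b=0, foo_x=True, foo_y="")
print(nested.foo_x, nested.foo_y)  # True .

try:
    BarFlatStrict(a="eggs", b=1, foo_x=False, foo_y="..", unexpected=1)
except ValidationError:
    print("extra key rejected")
```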


    Original post (flatten single field)

    If you need the nested Category model for database insertion, but you want a "flat" order model with category being just a string in the response, you should split that up into two separate models.

    Then, in the response model, you can define a custom validator with pre=True to handle the case where you initialize it with a Category instance or a dict for category.

    Here is what I suggest:

    from pydantic import BaseModel, validator
    
    
    class Category(BaseModel):
        name: str
    
    
    class OrderItemBase(BaseModel):
        name: str
        unit: str | None
        quantity: int
    
    
    class OrderItemCreate(OrderItemBase):
        category: Category
    
    
    class OrderItemResponse(OrderItemBase):
        category: str
    
        @validator("category", pre=True)
        def handle_category_model(cls, v: object) -> object:
            if isinstance(v, Category):
                return v.name
            if isinstance(v, dict) and "name" in v:
                return v["name"]
            return v
    

    Here is a demo:

    if __name__ == "__main__":
        insert_data = '{"name": "foo", "category": {"name": "bar"}, "quantity": 1}'
        insert_obj = OrderItemCreate.parse_raw(insert_data)
        print(insert_obj.json(indent=2))
        ...  # insert into DB
        response_obj = OrderItemResponse.parse_obj(insert_obj.dict())
        print(response_obj.json(indent=2))
    

    Here is the output:

    {
      "name": "foo",
      "unit": null,
      "quantity": 1,
      "category": {
        "name": "bar"
      }
    }
    
    {
      "name": "foo",
      "unit": null,
      "quantity": 1,
      "category": "bar"
    }
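    The demo above only goes through the dict branch of handle_category_model. Here is a sketch that also hits the Category-instance and plain-string branches by instantiating OrderItemResponse directly. As an assumption made for portability, unit is declared as Optional[str] = None instead of str | None, so the snippet also runs before Python 3.10 and unit is explicitly optional:

```python
from typing import Optional

from pydantic import BaseModel, validator


class Category(BaseModel):
    name: str


class OrderItemBase(BaseModel):
    name: str
    unit: Optional[str] = None  # explicit default (see lead-in)
    quantity: int


class OrderItemResponse(OrderItemBase):
    category: str

    @validator("category", pre=True)
    def handle_category_model(cls, v: object) -> object:
        # Reduce a Category instance or a {"name": ...} dict to its name.
        if isinstance(v, Category):
            return v.name
        if isinstance(v, dict) and "name" in v:
            return v["name"]
        return v  # anything else is left for the str field to validate


from_model = OrderItemResponse(name="foo", category=Category(name="bar"), quantity=1)
from_dict = OrderItemResponse(name="foo", category={"name": "bar"}, quantity=1)
from_str = OrderItemResponse(name="foo", category="bar", quantity=1)
print(from_model.category, from_dict.category, from_str.category)  # bar bar bar
```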
    

    One of the benefits of this approach is that the JSON Schema stays consistent with what you have on the model. If you use this in FastAPI, that means the Swagger documentation will actually reflect what the consumer of that endpoint receives. You could of course override and customize schema creation, but... why? Just define the model correctly in the first place and avoid headaches in the future.
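    One way to check that claim, as a sketch: compare the generated schemas of the two models. The models are re-declared in condensed form so the snippet is self-contained (the validator is omitted because the schema is derived from the field types), and .schema() is the pydantic v1-style method matching the rest of this answer. The response schema should describe category as a plain string, while the create schema should reference the Category model:

```python
from typing import Optional

from pydantic import BaseModel


class Category(BaseModel):
    name: str


class OrderItemBase(BaseModel):
    name: str
    unit: Optional[str] = None
    quantity: int


class OrderItemCreate(OrderItemBase):
    category: Category


class OrderItemResponse(OrderItemBase):
    category: str


# Pull out just the "category" property of each generated JSON Schema.
create_category = OrderItemCreate.schema()["properties"]["category"]
response_category = OrderItemResponse.schema()["properties"]["category"]
print(create_category)
print(response_category)
```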