Tags: python, django, django-models, django-orm, pydantic

Pydantic from_orm to load Django model with related list field


I have the following Django models:

from django.db import models


class Foo(models.Model):
    id: int
    name = models.TextField(null=False)


class Bar(models.Model):
    id: int
    foo = models.ForeignKey(
        Foo,
        on_delete=models.CASCADE,
        null=False,
        related_name="bars",
    )

And Pydantic models (with orm_mode set to True):

from pydantic import BaseModel


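class BarPy(BaseModel):
    id: int
    foo_id: int

    class Config:
        orm_mode = True  # lets from_orm read attributes from ORM objects


class FooPy(BaseModel):
    id: int
    name: str
    bars: list[BarPy]

    class Config:
        orm_mode = True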

Now I want to perform a query on the model Foo and load the result into FooPy, so I wrote this query:

# from_orm expects a single model instance, not a QuerySet
foo_db = Foo.objects.prefetch_related("bars").first()
pydantic_model = FooPy.from_orm(foo_db)

But it gives me this error:

pydantic.error_wrappers.ValidationError: 1 validation error for FooPy
  bars
    value is not a valid list (type=type_error.list)

I am able to do it by explicitly calling the FooPy constructor and assigning the values manually, but I want to use from_orm.


Solution

  • The bars attribute on your Foo model is a ReverseManyToOneDescriptor that just returns a RelatedManager for the Bar model. As with any manager in Django, to get a queryset of all the instances managed by it, you need to call the all method on it. Typically you would do something like foo.bars.all().

    You can add your own custom validator to FooPy and make it pre=True to grab all the related Bar instances and pass a sequence of them along to the default validators:

    from django.db.models.manager import BaseManager
    from pydantic import BaseModel, validator
    
    ...
    
    class FooPy(BaseModel):
        id: int
        name: str
        bars: list[BarPy]
    
        @validator("bars", pre=True)
        def get_all_from_manager(cls, v: object) -> object:
            if isinstance(v, BaseManager):
                return list(v.all())
            return v
    

    Note that it is not enough to just do .all() because that will return a queryset, which will not pass the default sequence validator built into Pydantic models. You would get the same error.

    You need to give it an actual sequence (e.g. a list or tuple). A QuerySet is iterable, but it is not a sequence. You can consume it and turn it into one, for example by calling list on it.
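To make the manager / queryset / list distinction concrete without a database, here is a minimal sketch built on hypothetical stand-ins. FakeManager and FakeQuerySet are not Django classes, and to_list_if_manager is an illustrative helper that only mirrors the branching of the validator above; the assumption is that the real RelatedManager and QuerySet behave analogously for this purpose.

```python
from collections.abc import Iterable, Iterator


class FakeQuerySet:
    """Hypothetical stand-in for a QuerySet: iterable, but not a list."""

    def __init__(self, rows: Iterable[str]) -> None:
        self._rows = tuple(rows)

    def __iter__(self) -> Iterator[str]:
        return iter(self._rows)


class FakeManager:
    """Hypothetical stand-in for a RelatedManager: exposes .all() only."""

    def __init__(self, rows: Iterable[str]) -> None:
        self._rows = tuple(rows)

    def all(self) -> FakeQuerySet:
        return FakeQuerySet(self._rows)


def to_list_if_manager(value: object) -> object:
    """Unwrap a manager into a list; pass every other value through unchanged."""
    if isinstance(value, FakeManager):
        return list(value.all())
    return value


manager = FakeManager(["bar-1", "bar-2"])
queryset = manager.all()

print(isinstance(manager, list))     # False: the manager is not a list
print(isinstance(queryset, list))    # False: neither is the queryset
print(to_list_if_manager(manager))   # ['bar-1', 'bar-2']
print(to_list_if_manager([1, 2]))    # [1, 2]
```

Only the last step, list(...), produces something a plain list check accepts, which is why the validator has to both call .all() and consume the result.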


    More generalized version

    You could make an attempt at generalizing that validator and add it to your own (Pydantic) base model. Something like this should work on any field you annotate as list[Model], with Model being a subclass of pydantic.BaseModel:

    from django.db.models.manager import BaseManager
    from pydantic import BaseModel, validator
    from pydantic.fields import ModelField, SHAPE_LIST
    
    ...
    
    class CustomBaseModel(BaseModel):
    
        @validator("*", pre=True)
        def get_all_from_manager(cls, v: object, field: ModelField) -> object:
            if not (isinstance(field.type_, type) and issubclass(field.type_, BaseModel)):
                return v
            # SHAPE_LIST is an int constant, so compare by value rather than identity
            if field.shape == SHAPE_LIST and isinstance(v, BaseManager):
                return list(v.all())
            return v
    

    I have not thoroughly tested this, but I think you get the idea.


    Side note

    It is worth mentioning that prefetch_related has nothing to do with the problem. The problem and its solution are the same whether you use it or not. The difference is that without prefetch_related, calling from_orm triggers additional database queries, because the validator consumes the queryset returned by .bars.all().
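The effect on query counts can be illustrated with a rough simulation. Everything here is hypothetical: QueryLog, FakeRelatedManager and convert_all are stand-ins invented for this sketch, not Django or Pydantic APIs, and the "queries" are just a counter. It assumes that a prefetched relation is served from already loaded rows, while a non-prefetched one costs one query per parent object.

```python
from dataclasses import dataclass


@dataclass
class QueryLog:
    """Counts simulated database round trips."""

    count: int = 0


@dataclass
class FakeRelatedManager:
    """Hypothetical stand-in for foo.bars; not Django's real RelatedManager."""

    rows: list[str]
    log: QueryLog
    prefetched: bool = False

    def all(self) -> list[str]:
        # Simplification: a real QuerySet is lazy and only hits the database
        # when it is consumed. Here the simulated hit is counted immediately.
        if not self.prefetched:
            self.log.count += 1
        return list(self.rows)


def convert_all(managers: list[FakeRelatedManager]) -> list[list[str]]:
    """Mimic what the pre-validator does once per Foo instance."""
    return [list(m.all()) for m in managers]


lazy_log = QueryLog()
lazy = [FakeRelatedManager([f"bar-{i}"], lazy_log) for i in range(3)]
convert_all(lazy)

eager_log = QueryLog()
eager = [
    FakeRelatedManager([f"bar-{i}"], eager_log, prefetched=True) for i in range(3)
]
convert_all(eager)

print(lazy_log.count)   # 3: one simulated query per Foo
print(eager_log.count)  # 0: rows were already loaded up front
```

In the simulation the prefetched case shows zero queries during conversion because the cost was paid earlier; with real Django, prefetch_related issues its own lookup when the outer queryset is evaluated.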