Search code examples
python python-3.x sqlalchemy pydantic

SQLAlchemy relationships field to Pydantic : Validation Error


I have some models declared with SQLAlchemy's declarative base. Their fields represent IP addresses. When I try to convert instances of these models to a pydantic model via orm_mode, it fails with the following error:

E   pydantic.error_wrappers.ValidationError: 4 validation errors for IpSchema
E   ip_address -> 0
E     value is not a valid IPv4 address (type=value_error.ipv4address)
E   ip_address -> 0
E     value is not a valid IPv6 address (type=value_error.ipv6address)
E   ip_address -> 0
E     value is not a valid IPv4 or IPv6 address (type=value_error.ipvanyaddress)
E   ip_address -> 0
E     str type expected (type=type_error.str)

The following is the code. I have tried to check it with pytest, but it fails.

Can the orm_mode behavior be overridden?

from typing import List, Union

from pydantic import BaseModel, Field, IPvAnyAddress
from sqlalchemy import INTEGER, Column, ForeignKey, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship

Base = declarative_base()


class IpModel(Base):
    """Declarative model for the "ip_model" table; owns a collection of IpAddress rows."""
    __tablename__ = "ip_model"

    id = Column(INTEGER, primary_key=True, autoincrement=True, index=True)
    # Related IpAddress ORM objects (not plain address strings); the reverse
    # side of this relationship is IpAddress.ip_model.
    ip_address = relationship("IpAddress", back_populates="ip_model")


class IpAddress(Base):
    """Declarative model for the "ip" table; one stored address belonging to an IpModel."""
    __tablename__ = "ip"

    id = Column(INTEGER, primary_key=True, autoincrement=True, index=True)
    # The address itself, stored as text (up to 64 characters).
    address = Column(String(64), nullable=False)

    # Foreign key to the owning IpModel row, plus the ORM-level back reference.
    ip_model_id = Column(INTEGER, ForeignKey("ip_model.id"), nullable=False)
    ip_model = relationship("IpModel", back_populates="ip_address")


class IpSchema(BaseModel):
    # Pydantic schema populated from an ORM object (orm_mode is enabled below).
    # Every list item must validate as an IP address, but from_orm() reads
    # IpModel.ip_address, whose items are IpAddress ORM objects, not strings.
    # NOTE(review): IPv4Address / IPv6Address are not imported in the snippet
    # as shown -- presumably they come from the stdlib ``ipaddress`` module.
    ip_address: List[Union[IPv4Address, IPv6Address, IPvAnyAddress]] = Field()

    class Config:
        orm_mode = True


def test_ipv4():
    """Build an IpModel holding one IpAddress and convert it with IpSchema.from_orm."""
    ipv4: str = "192.168.1.1"

    ip = IpAddress(address=ipv4)
    m = IpModel(ip_address=[ip])
    # This conversion is where the ValidationError quoted in the question is
    # reported: the list item handed to pydantic is an IpAddress object.
    s = IpSchema.from_orm(m)

    assert str(s.ip_address[0]) == ipv4

How can I solve this problem?


Solution

  • Pydantic does not know how to map each ORM instance in the relationship to its address field. To do that, you need to add a pydantic validator with the pre=True argument, so that each ORM instance is mapped to its address field before pydantic validation runs.

    Here is what it should look like:

    class IpSchema(BaseModel):
        """Schema exposing the addresses of an ORM object's ``ip_address`` relationship.

        With ``orm_mode`` enabled, ``from_orm`` reads ``ip_address`` from the ORM
        object, which yields related ORM instances rather than address strings.
        """

        ip_address: List[Union[IPv4Address, IPv6Address, IPvAnyAddress]] = Field()

        class Config:
            orm_mode = True

        # Deliberately not named ``validate``: that name would shadow the
        # ``BaseModel.validate`` classmethod that pydantic itself calls when
        # this schema is used as a field of another model.
        @validator('ip_address', pre=True)
        def extract_addresses(cls, ip_address_relationship):
            """Map each related ORM instance to its ``address`` before validation.

            Runs before pydantic's own field validation (``pre=True``). Items
            without an ``address`` attribute (e.g. plain strings supplied when
            the schema is built from a dict) are passed through unchanged.
            """
            return [getattr(ip, 'address', ip) for ip in ip_address_relationship]
    

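    Please note that a validator with pre=True runs on the raw input every time the field is validated: when the model is built with from_orm, but also when it is built from data that has already been converted (for example from a dict). In your example it changes nothing, but if, for example, you want to transform the list of IPs into a single str, you need to check the type of the value first: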

    class IpSchema(BaseModel):
        """Schema exposing an ORM object's related addresses as one comma-separated string."""

        ip_address: str

        class Config:
            orm_mode = True

        # Deliberately not named ``validate``: that name would shadow the
        # ``BaseModel.validate`` classmethod that pydantic itself calls when
        # this schema is used as a field of another model.
        @validator('ip_address', pre=True)
        def join_addresses(cls, ip_address_relationship):
            """Join the ``address`` of each related ORM instance with commas.

            Runs before pydantic's own field validation (``pre=True``). A value
            that is already a string is returned unchanged, so the schema can
            also be built from data that has been converted before.
            """
            if isinstance(ip_address_relationship, str):
                return ip_address_relationship
            return ','.join(ip.address for ip in ip_address_relationship)
    

    And here is the whole reproducible (and working) example:

    from typing import List, Union
    
    from pydantic import BaseModel, Field, IPvAnyAddress
    from pydantic import validator
    from pydantic.schema import IPv4Address
    from pydantic.schema import IPv6Address
    from sqlalchemy import INTEGER, Column, ForeignKey, String
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import relationship
    
    Base = declarative_base()
    
    
    class IpModel(Base):
        """Declarative model for the "ip_model" table; owns a collection of IpAddress rows."""
        __tablename__ = "ip_model"
    
        id = Column(INTEGER, primary_key=True, autoincrement=True, index=True)
        # Related IpAddress ORM objects (not plain address strings); the reverse
        # side of this relationship is IpAddress.ip_model.
        ip_address = relationship("IpAddress", back_populates="ip_model")
    
    
    class IpAddress(Base):
        """Declarative model for the "ip" table; one stored address belonging to an IpModel."""
        __tablename__ = "ip"
    
        id = Column(INTEGER, primary_key=True, autoincrement=True, index=True)
        # The address itself, stored as text (up to 64 characters).
        address = Column(String(64), nullable=False)
    
        # Foreign key to the owning IpModel row, plus the ORM-level back reference.
        ip_model_id = Column(INTEGER, ForeignKey("ip_model.id"), nullable=False)
        ip_model = relationship("IpModel", back_populates="ip_address")
    
    
    class IpSchema(BaseModel):
        """Schema exposing the addresses of an ORM object's ``ip_address`` relationship.

        With ``orm_mode`` enabled, ``from_orm`` reads ``ip_address`` from the ORM
        object, which yields related ORM instances rather than address strings.
        """

        ip_address: List[Union[IPv4Address, IPv6Address, IPvAnyAddress]] = Field()

        class Config:
            orm_mode = True

        # Deliberately not named ``validate``: that name would shadow the
        # ``BaseModel.validate`` classmethod that pydantic itself calls when
        # this schema is used as a field of another model.
        @validator('ip_address', pre=True)
        def extract_addresses(cls, ip_address_relationship):
            """Map each related ORM instance to its ``address`` before validation.

            Runs before pydantic's own field validation (``pre=True``). Items
            without an ``address`` attribute (e.g. plain strings supplied when
            the schema is built from a dict) are passed through unchanged.
            """
            return [getattr(ip, 'address', ip) for ip in ip_address_relationship]
    
    
    def test_ipv4():
        """Check that an IPv4 address survives the ORM -> IpSchema conversion."""
        expected: str = "192.168.1.1"

        # One parent row holding a single related address row.
        orm_parent = IpModel(ip_address=[IpAddress(address=expected)])
        schema = IpSchema.from_orm(orm_parent)

        first_address = schema.ip_address[0]
        assert str(first_address) == expected
    # Allow running this example directly as a script, without pytest.
    if __name__ == '__main__':
        test_ipv4()