Search code examples
Tags: python, postgis, pydantic, geoalchemy2

GeoAlchemy2 Geometry schema for Pydantic (FastAPI)


I want to use PostGIS with FastAPI, so I use GeoAlchemy2 together with Alembic to create the table in the database. However, I can't figure out how to declare the geometry field correctly in a Pydantic v2 schema.

My code looks as follows:

# auto-generated from env.py
from alembic import op
import sqlalchemy as sa
from geoalchemy2 import Geometry
from sqlalchemy.dialects import postgresql

...

def upgrade() -> None:
    """Create the ``solarpark`` table, a GiST index on ``geom`` and an index on ``id``.

    NOTE(review): ``create_geospatial_table`` / ``create_geospatial_index`` are
    not core Alembic operations; they appear to be the operations registered by
    GeoAlchemy2's alembic helpers (presumably wired up in env.py) — confirm.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_geospatial_table('solarpark',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name_of_model', sa.String(), nullable=True),
    sa.Column('comment', sa.String(), nullable=True),
    sa.Column('lat', sa.ARRAY(sa.Float()), nullable=True),
    sa.Column('lon', sa.ARRAY(sa.Float()), nullable=True),
    # spatial_index=False: the GiST index on this column is created explicitly
    # by the create_geospatial_index call below instead.
    sa.Column('geom', Geometry(geometry_type='POLYGON', srid=4326, spatial_index=False, from_text='ST_GeomFromEWKT', name='geometry'), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_geospatial_index('idx_solarpark_geom', 'solarpark', ['geom'], unique=False, postgresql_using='gist', postgresql_ops={})
    op.create_index(op.f('ix_solarpark_id'), 'solarpark', ['id'], unique=False)
    # ### end Alembic commands ###
# models.py

from geoalchemy2 import Geometry
from sqlalchemy import ARRAY, Column, Date, Float, Integer, String

from app.db.base_class import Base


class SolarPark(Base):
    """SQLAlchemy model for a solar park outline stored as a PostGIS polygon.

    No ``__tablename__`` is declared here; it is presumably derived by ``Base``
    (the migration creates a table named ``solarpark``) — confirm in
    ``app.db.base_class``.
    """

    id = Column(Integer, primary_key=True, index=True)
    name_of_model = Column(String)
    # The default is the literal string "None", not SQL NULL.
    comment = Column(String, default="None")
    # NOTE(review): in the schema's example values, ``lat`` holds the first (x)
    # and ``lon`` the second (y) coordinate of each polygon vertex, which is the
    # reverse of the usual lon = x / lat = y convention — confirm the naming.
    lat = Column(ARRAY(item_type=Float))
    lon = Column(ARRAY(item_type=Float))
    # NOTE(review): SRID 4326 is WGS84 longitude/latitude in degrees, but the
    # example coordinates used in the schemas (~600000, ~5570000) are far
    # outside +/-180 / +/-90 and look like projected metres (perhaps a UTM
    # zone) — confirm the intended SRID.
    geom = Column(Geometry("POLYGON", srid=4326))

# schemas.py

from typing import List

from pydantic import ConfigDict, BaseModel, Field

class SolarParkBase(BaseModel):
    """Shared Pydantic fields for a solar park (question version; ``geom`` type still open)."""

    # from_attributes lets the model be built from ORM objects such as SolarPark.
    model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

    name_of_model: str = Field("test-model")
    comment: str = "None"
    lat: List[float] = Field([599968.55, 599970.90, 599973.65, 599971.31, 599968.55])
    lon: List[float] = Field([5570202.63, 5570205.59, 5570203.42, 5570200.46, 5570202.63])
    # Placeholder from the question: the annotation for this WKT default is what is being asked.
    geom: [WHAT TO INSERT HERE?] = Field('POLYGON ((599968.55 5570202.63, 599970.90 5570205.59, 599973.65 5570203.42, 599971.31 5570200.46, 599968.55 5570202.63))')

I want the `geom` column to have a geometry type so that I can perform spatial operations on it. How do I declare that in Pydantic v2?

Thanks a lot in advance!


Solution

  • I found the answer:

    # schemas.py
    from typing import Any, List

    from geoalchemy2.elements import WKBElement, WKTElement
    from geoalchemy2.shape import to_shape
    from pydantic import BaseModel, BeforeValidator, ConfigDict, Field
    from typing_extensions import Annotated


    def _geometry_to_wkt(value: Any) -> Any:
        """Convert a GeoAlchemy2 geometry element to a WKT string.

        ``WKBElement`` (what GeoAlchemy2 loads from a ``Geometry`` column) and
        ``WKTElement`` values are converted through Shapely (``to_shape`` needs
        the shapely extra installed). Anything else, including plain strings,
        is returned unchanged so Pydantic's normal ``str`` validation applies.
        """
        if isinstance(value, (WKBElement, WKTElement)):
            return to_shape(value).wkt
        return value


    # A WKT string field that also accepts GeoAlchemy2 geometry elements, so a
    # model built with from_attributes=True directly from an ORM row validates
    # without the ``geom`` attribute having to be converted by hand first.
    WKTGeometry = Annotated[str, BeforeValidator(_geometry_to_wkt)]


    class SolarParkBase(BaseModel):
        """Shared Pydantic fields for a solar park; ``geom`` is exposed as WKT text."""

        # from_attributes lets the model be built from ORM objects. No field here
        # needs arbitrary_types_allowed any more; it is kept in case subclasses
        # rely on it.
        model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

        name_of_model: str = Field("test-model")
        comment: str = "None"
        lat: List[float] = Field([599968.55, 599970.90, 599973.65, 599971.31, 599968.55])
        lon: List[float] = Field([5570202.63, 5570205.59, 5570203.42, 5570200.46, 5570202.63])
        # Previously ``Annotated[str, WKBElement]``: Pydantic does not act on a
        # plain class used as Annotated metadata, so that was just ``str`` and
        # rejected the WKBElement loaded from the database.
        geom: WKTGeometry = Field('POLYGON ((599968.55 5570202.63, 599970.90 5570205.59, 599973.65 5570203.42, 599971.31 5570200.46, 599968.55 5570202.63))')
    

    You also need to change your CRUD functions so that they convert `geom` to a WKT string before returning the objects:

    # crud_solarpark.py
    from typing import Any, Dict, List, Optional, Union

    from sqlalchemy.orm.attributes import set_committed_value

    # NOTE(review): CRUDBase, SolarPark, SolarParkCreate, SolarParkUpdate, Session,
    # jsonable_encoder, WKTElement and to_shape are used below but imported
    # elsewhere (not shown in this snippet).


    class CRUDSolarPark(CRUDBase[SolarPark, SolarParkCreate, SolarParkUpdate]):
        """CRUD operations for ``SolarPark`` that return ``geom`` as a WKT string.

        GeoAlchemy2 loads the ``Geometry`` column as a geometry element, which a
        Pydantic schema declaring ``geom: str`` cannot validate, so every method
        that returns ORM objects replaces ``geom`` with its WKT text first.
        """

        @staticmethod
        def _geom_to_wkt(db_obj: SolarPark) -> None:
            """Replace ``db_obj.geom`` with its WKT representation, in place.

            The value is stored with ``set_committed_value`` so the session does
            not treat the attribute as modified: a plain assignment would mark the
            object dirty and make the next flush write the WKT text back to the
            database. A ``None`` geometry (the column is nullable) is left as is.
            """
            geom = db_obj.geom
            if geom is None:
                return
            if isinstance(geom, str):
                # e.g. the object was already converted earlier in this session.
                geom = WKTElement(geom)
            set_committed_value(db_obj, "geom", to_shape(geom).wkt)

        def get(self, db: Session, *, id: int) -> Optional[SolarPark]:
            """Return the solar park with primary key ``id``, or ``None`` if absent."""
            db_obj = db.query(SolarPark).filter(SolarPark.id == id).first()
            if db_obj is None:
                return None
            self._geom_to_wkt(db_obj)
            return db_obj

        def get_multi(
            self, db: Session, *, skip: int = 0, limit: int = 100
        ) -> List[SolarPark]:
            """Return up to ``limit`` solar parks after skipping the first ``skip`` rows."""
            # Query.all() always returns a list (possibly empty), never None.
            db_objs = db.query(SolarPark).offset(skip).limit(limit).all()
            for db_obj in db_objs:
                self._geom_to_wkt(db_obj)
            return db_objs

        def create(self, db: Session, *, obj_in: SolarParkCreate) -> SolarPark:
            """Insert a solar park built from ``obj_in``, commit, and return it."""
            obj_in_data = jsonable_encoder(obj_in)
            db_obj = SolarPark(**obj_in_data)  # type: ignore
            db.add(db_obj)
            db.commit()
            db.refresh(db_obj)
            self._geom_to_wkt(db_obj)
            return db_obj

        def update(
            self,
            db: Session,
            *,
            db_obj: SolarPark,
            obj_in: Union[SolarParkUpdate, Dict[str, Any]],
        ) -> SolarPark:
            """Apply the fields provided in ``obj_in`` to ``db_obj``, commit, and return it."""
            # Only the keys of obj_data are used. NOTE(review): jsonable_encoder also
            # encodes every value, which may fail if ``geom`` is still an unconverted
            # geometry element — consider iterating mapped column names instead.
            obj_data = jsonable_encoder(db_obj)
            if isinstance(obj_in, dict):
                update_data = obj_in
            else:
                # Pydantic v2 replacement for the deprecated ``.dict()``.
                update_data = obj_in.model_dump(exclude_unset=True)
            for field in obj_data:
                if field in update_data:
                    setattr(db_obj, field, update_data[field])
            db.add(db_obj)
            db.commit()
            db.refresh(db_obj)
            self._geom_to_wkt(db_obj)
            return db_obj
    

    I hope this helps anyone using FastAPI with PostGIS :)