Tags: python, sql, sqlalchemy, fastapi, sqlmodel

Insert many-to-many relationship objects using SQLModel when one side of the relationship already exists in the database


I am trying to insert records into a database using SQLModel, where the data looks like the following: a House object has a color and many locations, and each location can also be associated with many houses. The input is:

[
    {
        "color": "red",
        "locations": [
            {"type": "country", "name": "Netherlands"},
            {"type": "municipality", "name": "Amsterdam"},
        ],
    },
    {
        "color": "green",
        "locations": [
            {"type": "country", "name": "Netherlands"},
            {"type": "municipality", "name": "Amsterdam"},
        ],
    },
]

Here's a reproducible example of what I'm trying to do:

import asyncio
from typing import List

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, Relationship, SQLModel, UniqueConstraint
from sqlmodel.ext.asyncio.session import AsyncSession

DATABASE_URL = "sqlite+aiosqlite:///./database.db"


engine = create_async_engine(DATABASE_URL, echo=True, future=True)


async def init_db() -> None:
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)


SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


class HouseLocationLink(SQLModel, table=True):
    house_id: int = Field(foreign_key="house.id", nullable=False, primary_key=True)
    location_id: int = Field(
        foreign_key="location.id", nullable=False, primary_key=True
    )


class Location(SQLModel, table=True):
    id: int = Field(primary_key=True)
    type: str  # country, county, municipality, district, city, area, street, etc
    name: str  # Amsterdam, Germany, My Street, etc

    houses: List["House"] = Relationship(
        back_populates="locations",
        link_model=HouseLocationLink,
    )

    __table_args__ = (UniqueConstraint("type", "name"),)


class House(SQLModel, table=True):
    id: int = Field(primary_key=True)
    color: str = Field()
    locations: List["Location"] = Relationship(
        back_populates="houses",
        link_model=HouseLocationLink,
    )
    # other fields...


data = [
    {
        "color": "red",
        "locations": [
            {"type": "country", "name": "Netherlands"},
            {"type": "municipality", "name": "Amsterdam"},
        ],
    },
    {
        "color": "green",
        "locations": [
            {"type": "country", "name": "Netherlands"},
            {"type": "municipality", "name": "Amsterdam"},
        ],
    },
]


async def add_houses(payload) -> List[House]:
    result = []
    async with SessionLocal() as session:
        for item in payload:
            locations = []
            for location in item["locations"]:
                locations.append(Location(**location))
            house = House(color=item["color"], locations=locations)
            result.append(house)
        session.add_all(result)
        await session.commit()
    return result


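async def main() -> None:
    # Run both steps on one event loop instead of calling asyncio.run() twice
    # with the same async engine.
    await init_db()
    await add_houses(data)


asyncio.run(main())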

The problem is that when I run this code, it tries to insert a new Location row for every location entry, duplicates included, together with each House, which violates the unique constraint on (type, name). I'd love to keep using the relationship here because it makes accessing house.locations very easy.
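At the SQL level the failure can be reproduced with nothing but the standard library. This is a minimal sketch, assuming an SQLite table with the same UNIQUE (type, name) constraint as the Location model; `naive_insert` is a hypothetical helper that mimics creating one `Location(**location)` per entry. Through SQLAlchemy the same driver error should surface wrapped in `sqlalchemy.exc.IntegrityError`.

```python
import sqlite3

# In-memory table with the same UNIQUE (type, name) constraint as the Location model.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE location ("
    " id INTEGER PRIMARY KEY,"
    " type TEXT NOT NULL,"
    " name TEXT NOT NULL,"
    " UNIQUE (type, name))"
)


def naive_insert(conn, houses):
    """Hypothetical helper: insert one location row per entry, duplicates included."""
    for house in houses:
        for loc in house["locations"]:
            conn.execute(
                "INSERT INTO location (type, name) VALUES (?, ?)",
                (loc["type"], loc["name"]),
            )


locs = [
    {"type": "country", "name": "Netherlands"},
    {"type": "municipality", "name": "Amsterdam"},
]
data = [{"color": "red", "locations": locs}, {"color": "green", "locations": locs}]

error = None
try:
    naive_insert(conn, data)
except sqlite3.IntegrityError as exc:
    # The third insert (the second "Netherlands") hits the unique constraint.
    error = exc
```

The first house's two locations go in fine; the second house's first location is where it breaks.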

However, I have not been able to figure out how to keep it from trying to insert duplicate locations. Ideally, I'd have a mapper function that performs a get_or_create for each location.
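For reference, a get_or_create that leans on the unique constraint can be sketched in plain SQL. This is a stdlib `sqlite3` sketch, not SQLModel's API; `make_db`, `get_or_create_location` and `add_houses` are hypothetical names, and the schema mirrors the question's three tables.

```python
import sqlite3


def make_db():
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE house (id INTEGER PRIMARY KEY, color TEXT NOT NULL);
        CREATE TABLE location (
            id INTEGER PRIMARY KEY,
            type TEXT NOT NULL,
            name TEXT NOT NULL,
            UNIQUE (type, name)
        );
        CREATE TABLE houselocationlink (
            house_id INTEGER NOT NULL REFERENCES house (id),
            location_id INTEGER NOT NULL REFERENCES location (id),
            PRIMARY KEY (house_id, location_id)
        );
        """
    )
    return conn


def get_or_create_location(conn, type_, name):
    """Return the id of the (type, name) row, inserting it only if it is missing."""
    # OR IGNORE skips the insert when it would violate UNIQUE (type, name).
    conn.execute(
        "INSERT OR IGNORE INTO location (type, name) VALUES (?, ?)", (type_, name)
    )
    row = conn.execute(
        "SELECT id FROM location WHERE type = ? AND name = ?", (type_, name)
    ).fetchone()
    return row[0]


def add_houses(conn, payload):
    for item in payload:
        house_id = conn.execute(
            "INSERT INTO house (color) VALUES (?)", (item["color"],)
        ).lastrowid
        for loc in item["locations"]:
            location_id = get_or_create_location(conn, loc["type"], loc["name"])
            conn.execute(
                "INSERT INTO houselocationlink (house_id, location_id) VALUES (?, ?)",
                (house_id, location_id),
            )
    conn.commit()


locs = [
    {"type": "country", "name": "Netherlands"},
    {"type": "municipality", "name": "Amsterdam"},
]
data = [{"color": "red", "locations": locs}, {"color": "green", "locations": locs}]

conn = make_db()
add_houses(conn, data)
counts = {
    table: conn.execute(f"SELECT count(*) FROM {table}").fetchone()[0]
    for table in ("house", "location", "houselocationlink")
}
print(counts)  # {'house': 2, 'location': 2, 'houselocationlink': 4}
```

One design caveat: in SQLite, `OR IGNORE` also skips rows that violate NOT NULL or CHECK constraints, so on SQLite 3.24+ `INSERT ... ON CONFLICT (type, name) DO NOTHING` is the narrower choice.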

The closest I've seen to making this possible is SQLAlchemy's association proxy, but it looks like SQLModel doesn't support that.
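Roughly, an association proxy presents a collection of link objects as a collection of their targets: reads go through each link's target attribute, and appends wrap the new target in a link via a `creator`. The following is a plain-Python analogy of that idea, not SQLAlchemy's implementation; `ProxiedList` and the dataclasses are hypothetical stand-ins.

```python
from dataclasses import dataclass, field
from typing import Callable, Iterator, List


@dataclass
class Location:
    type: str
    name: str


@dataclass
class HouseLocationLink:
    location: Location


class ProxiedList:
    """Hypothetical, simplified view: reads through links, writes via `creator`."""

    def __init__(self, links: List[HouseLocationLink], attr: str, creator: Callable):
        self._links = links
        self._attr = attr
        self._creator = creator

    def __iter__(self) -> Iterator:
        return (getattr(link, self._attr) for link in self._links)

    def __len__(self) -> int:
        return len(self._links)

    def __getitem__(self, index: int):
        return getattr(self._links[index], self._attr)

    def append(self, value) -> None:
        # Wrap the target in a new link object, like a creator=... callable.
        self._links.append(self._creator(value))


@dataclass
class House:
    color: str
    location_associations: List[HouseLocationLink] = field(default_factory=list)

    @property
    def locations(self) -> ProxiedList:
        return ProxiedList(self.location_associations, "location", HouseLocationLink)


netherlands = Location("country", "Netherlands")
red, green = House("red"), House("green")
red.locations.append(netherlands)
green.locations.append(netherlands)  # both houses share one Location object
```

The point relevant to the question is the last line: if both houses are handed the same Location object, only one row needs inserting; each house just gets its own link.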

Does anybody have an idea how to achieve this? If you know how to do it using SQLAlchemy instead of SQLModel, I'd be interested in seeing your solution. I haven't started on this project yet, so I might as well use SQLAlchemy if it will make my life easier.

I've also tried tweaking sa_relationship_kwargs, for example:

sa_relationship_kwargs={
    "lazy": "selectin",
    "cascade": "none",
    "viewonly": True,
}

But that prevents the association rows from being added to the HouseLocationLink table, since a viewonly relationship is never persisted on flush.

Any pointers will be much appreciated, even if it means changing my approach altogether.

Thanks!


Solution

  • I am writing this solution because you mentioned you are open to using SQLAlchemy. As you suspected, you need an association proxy, but you also need the "Unique Object" recipe from the SQLAlchemy wiki. I have adapted that recipe to use asynchronous queries instead of synchronous ones, to match my own preferences, without changing its logic significantly.

    import asyncio
    from sqlalchemy import UniqueConstraint, ForeignKey, select, text, func
    from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.ext.associationproxy import AssociationProxy, association_proxy
    
    class Base(DeclarativeBase):
        pass
    
    class UniqueMixin:
        # The cache is stored on the session (session._cache), so each session gets its own.
    
        @classmethod
        async def as_unique(cls, session: AsyncSession, *args, **kwargs):
            cache = getattr(session, "_cache", None)
            if cache is None:
                session._cache = cache = {}
    
            key = cls, cls.unique_hash(*args, **kwargs)
            if key in cache:
                return cache[key]
            with session.no_autoflush:
                statement = select(cls).where(cls.unique_filter(*args, **kwargs)).limit(1)
                obj = (await session.scalars(statement)).first()
                if obj is None:
                    obj = cls(*args, **kwargs)
                    session.add(obj)
            cache[key] = obj
            return obj
    
        @classmethod
        def unique_hash(cls, *args, **kwargs):
            raise NotImplementedError("Implement this in subclass")
    
        @classmethod
        def unique_filter(cls, *args, **kwargs):
            raise NotImplementedError("Implement this in subclass")
    
    class Location(UniqueMixin, Base):
        __tablename__ = "location"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()
        type: Mapped[str] = mapped_column()
        house_associations: Mapped[list["HouseLocationLink"]] = relationship(back_populates="location")
        # Must be spelled __table_args__ (with trailing underscores) to be picked up.
        __table_args__ = (UniqueConstraint("type", "name"),)
    
        @classmethod
        def unique_hash(cls, name, type):
            # this is the key for the dict
            return type, name
    
        @classmethod
        def unique_filter(cls, name, type):
            # this is how you want to establish the uniqueness
            # the result of this filter will be the value in the dict
            return (cls.type == type) & (cls.name == name)
    
    class House(Base):
        __tablename__ = "house"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column()
        location_associations: Mapped[list["HouseLocationLink"]] = relationship(back_populates="house")
        locations: AssociationProxy[list[Location]] = association_proxy(
            "location_associations",
            "location",
            # you need this so you can directly add ``Location`` objects to ``House``
            creator=lambda location: HouseLocationLink(location=location),
        )
    
    class HouseLocationLink(Base):
        __tablename__ = "houselocationlink"
        house_id: Mapped[int] = mapped_column(ForeignKey(House.id), primary_key=True)
        location_id: Mapped[int] = mapped_column(ForeignKey(Location.id), primary_key=True)
        location: Mapped[Location] = relationship(back_populates="house_associations")
        house: Mapped[House] = relationship(back_populates="location_associations")
    
    engine = create_async_engine("sqlite+aiosqlite:///test.sqlite")
    
    async def main():
        data = [
            {
                "name": "red",
                "locations": [
                    {"type": "country", "name": "Netherlands"},
                    {"type": "municipality", "name": "Amsterdam"},
                ],
            },
            {
                "name": "green",
                "locations": [
                    {"type": "country", "name": "Netherlands"},
                    {"type": "municipality", "name": "Amsterdam"},
                ],
            },
        ]
    
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
    
        async with AsyncSession(engine) as session, session.begin():
            for item in data:
                house = House(
                    name=item["name"],
                    locations=[await Location.as_unique(session, **location) for location in item["locations"]]
                )
                session.add(house)
    
        async with AsyncSession(engine) as session:
            # SELECT count(*) FROM <table>; these counts assume a fresh, empty database.
            statement = select(func.count()).select_from(Location)
            assert await session.scalar(statement) == 2
    
            statement = select(func.count()).select_from(House)
            assert await session.scalar(statement) == 2
    
            statement = select(func.count()).select_from(HouseLocationLink)
            assert await session.scalar(statement) == 4
    
    
    asyncio.run(main())
    

    On the first run against an empty database, the asserts pass: there is no unique constraint violation and no duplicate Location inserts. I have left some inline comments that point out the key parts of the code. If you run the script again, only new House rows and their corresponding HouseLocationLink rows are added and no new Location rows are created, so the Location count stays at 2 while the House and HouseLocationLink counts (and their asserts) change. Within a session, at most one SELECT is issued per unique (type, name) key; repeated lookups are served from the session-level cache.
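The cache-then-query-then-create logic in `as_unique` can be modeled without a database to check the claims above. This is a stdlib-only sketch, not the recipe itself: `FakeSession`, `location_as_unique` and `load` are hypothetical stand-ins, a dict plays the committed `location` table, and autoflush is not modeled.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Location:
    type: str
    name: str


class FakeSession:
    """Hypothetical stand-in for an ORM session over a `location` table."""

    def __init__(self, table: Dict[Tuple[str, str], Location]):
        self.table = table                 # "committed" rows keyed by (type, name)
        self.pending: List[Location] = []  # added but not yet flushed
        self.unique_cache: dict = {}       # per-session cache, like session._cache
        self.selects = 0                   # lookups that reached the "database"

    def select_location(self, key: Tuple[str, str]):
        self.selects += 1
        return self.table.get(key)

    def add(self, obj: Location) -> None:
        self.pending.append(obj)

    def commit(self) -> None:
        for obj in self.pending:
            self.table[(obj.type, obj.name)] = obj
        self.pending.clear()


def location_as_unique(session: FakeSession, type_: str, name: str) -> Location:
    key = (Location, (type_, name))
    if key in session.unique_cache:  # 1. already seen in this session
        return session.unique_cache[key]
    obj = session.select_location((type_, name))  # 2. one lookup per key
    if obj is None:
        obj = Location(type=type_, name=name)  # 3. missing: create and stage it
        session.add(obj)
    session.unique_cache[key] = obj
    return obj


def load(table, payload):
    session = FakeSession(table)
    houses = []
    for item in payload:
        locations = [
            location_as_unique(session, loc["type"], loc["name"])
            for loc in item["locations"]
        ]
        houses.append((item["color"], locations))
    session.commit()
    return session, houses


locs = [
    {"type": "country", "name": "Netherlands"},
    {"type": "municipality", "name": "Amsterdam"},
]
data = [{"color": "red", "locations": locs}, {"color": "green", "locations": locs}]

table: Dict[Tuple[str, str], Location] = {}
first, houses1 = load(table, data)   # first run: creates both locations
second, houses2 = load(table, data)  # rerun: finds the committed ones
print(len(table), first.selects, second.selects)  # 2 2 2
```

Four location entries per run produce only two lookups, and the rerun reuses the existing objects rather than creating new ones.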