Search code examples
python sqlalchemy

Adding multiple foreign keys from the same model (FastAPI and SQLAlchemy)


I am trying to have two foreign keys that reference the User table inside the Ban table. Here is how I did it:

class Ban(Base):
    """Ban row holding two separate foreign keys into the ``user`` table."""

    __tablename__ = "ban"

    ban_id = Column(Integer, primary_key=True, index=True)
    poll_owner_id = Column(Integer)

    # First foreign key into ``user``: references ``user.username``.
    banned_by = Column(
        String,
        ForeignKey('user.username', ondelete='CASCADE'),
        unique=True,
    )
    # Second foreign key into ``user``: references ``user.user_id``.
    user_id = Column(
        Integer,
        ForeignKey('user.user_id', ondelete='CASCADE'),
    )

    updated_at = Column(DateTime)
    create_at = Column(DateTime)

    # Paired with ``User.user_to_ban``. No ``foreign_keys`` argument is passed
    # here even though two columns above reference ``user``.
    ban_to_user = relationship(
        "User",
        back_populates='user_to_ban',
        cascade='all, delete',
    )

And here is the User table:

class User(Base):
    """User row; ``user_to_ban`` is the reverse side of ``Ban.ban_to_user``."""

    __tablename__ = "user"

    user_id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True)
    email = Column(String)
    create_at = Column(DateTime)
    updated_at = Column(DateTime)

    # This is the relationship named in the "Could not determine join
    # condition" error: it is declared without a ``foreign_keys`` argument.
    user_to_ban = relationship(
        "Ban",
        back_populates='ban_to_user',
        cascade='all, delete',
    )

When I try to run a query to fetch all users like this:

@router.get('/all')
async def get_all_users(db: Session = Depends(get_db)):
    """Return every ``models.User`` row using the injected session ``db``."""
    user_query = db.query(models.User)
    return user_query.all()

I get this error:

sqlalchemy.exc.InvalidRequestError: One or more mappers failed to initialize - can't proceed with initialization of other mappers. Triggering mapper: 'mapped class User->user'. Original exception was: Could not determine join condition between parent/child tables on relationship User.user_to_ban - there are multiple foreign key paths linking the tables.  Specify the 'foreign_keys' argument, providing a list of those columns which should be counted as containing a foreign key reference to the parent table.

I defined the relationship between them as you can see, but the error states that there is a problem with it. If needed, I can show how I ran the migration for my database using Alembic, in case that is a possible cause. Or is there a cleaner and better way to do this?


Solution

  • You can have several foreign keys to a single table, as in your case with the banned user and the user who issued the ban.

    You just need to disambiguate which ForeignKey belongs to which relationship (docs):

    class Ban(Base):
        """Ban with two ``user.id`` foreign keys, one per relationship."""

        __tablename__ = "ban"

        id = Column(Integer, primary_key=True)

        # Column backing the ``banned_user`` relationship.
        banned_user_id = Column(Integer, ForeignKey("user.id"))
        # Column backing the ``banned_by`` relationship.
        banned_by_user_id = Column(Integer, ForeignKey("user.id"))

        # ``foreign_keys`` tells each relationship which of the two columns
        # above it should join on.
        banned_user = relationship(
            "User",
            foreign_keys=[banned_user_id],
            back_populates="bans",
        )
        banned_by = relationship(
            "User",
            foreign_keys=[banned_by_user_id],
        )
    

    Full demo:

    from sqlalchemy import (
        Column,
        ForeignKey,
        Integer,
        String,
        create_engine,
        select,
    )
    from sqlalchemy.orm import Session, declarative_base, relationship
    
    # Declarative base class that the User and Ban models in this demo subclass.
    Base = declarative_base()
    
    
    class User(Base):
        """User row; ``bans`` is the reverse side of ``Ban.banned_user``."""

        __tablename__ = "user"

        id = Column(Integer, primary_key=True)
        username = Column(String, unique=True)

        # ``foreign_keys`` is passed as a string (``Ban`` is declared after
        # this class) and selects ``banned_user_id`` for this relationship.
        bans = relationship(
            "Ban",
            foreign_keys="Ban.banned_user_id",
            back_populates="banned_user",
        )
    
    
    class Ban(Base):
        """Ban row linking the banned user and the user who issued the ban."""

        __tablename__ = "ban"

        id = Column(Integer, primary_key=True)

        # Both columns reference ``user.id``; each relationship below names
        # the one it joins on through ``foreign_keys``.
        banned_user_id = Column(Integer, ForeignKey("user.id"))
        banned_by_user_id = Column(Integer, ForeignKey("user.id"))

        banned_user = relationship(
            "User",
            foreign_keys=[banned_user_id],
            back_populates="bans",
        )
        banned_by = relationship(
            "User",
            foreign_keys=[banned_by_user_id],
        )
    
    
    # SQLite engine with SQL echoing switched on; create the tables on it.
    engine = create_engine("sqlite://", echo=True, future=True)
    Base.metadata.create_all(engine)

    # Two users and one ban that references both of them.
    spongebob = User(username="spongebob")
    patrick = User(username="patrickstarr")
    spongebob_bans_patrick = Ban(banned_by=spongebob, banned_user=patrick)

    # Persist all three objects in a single transaction.
    with Session(engine) as session:
        session.add_all([spongebob, patrick, spongebob_bans_patrick])
        session.commit()

    # Load the ban in a fresh session and follow both relationships.
    with Session(engine) as session:
        ban = session.scalars(select(Ban)).first()
        banned_name = ban.banned_user.username
        banner_name = ban.banned_by.username
        print("User:", banned_name, "was banned by User:", banner_name)

    # User: patrickstarr was banned by User: spongebob