Tags: python, sqlalchemy, flask-sqlalchemy

Implementing IS-A (inheritance) relationships in Flask-SQLAlchemy models and Marshmallow schemas


Say I have a table called User which holds general information for all the users registered in the system:

class User(db.Model):
    __tablename__ = 'User'

    id = db.Column(db.Integer, primary_key=True)
    role = db.Column(db.String(20), nullable=False)
    email = db.Column(db.String(256), unique=True, nullable=False)
    password = db.Column(db.String(256), nullable=False)
    sign_up_date = db.Column(db.DateTime, nullable=False)
    archived = db.Column(db.Integer, nullable=False, default=0)

Then I have several sub-types of user, each containing all of the above columns plus some type-specific columns. For example:

Class Customer:

class Customer(User):
    __tablename__ = "Customer"
    # USER is a CUSTOMER
    address = db.Column(db.String(256), nullable=False)
    name = db.Column(db.String(64), nullable=False)

Class Driver:

class Driver(User):
    __tablename__ = "Driver"
    # USER is a DRIVER
    first_name = db.Column(db.String(64), nullable=False)
    last_name = db.Column(db.String(64), nullable=False)
    profile_picture_path = db.Column(db.String(256), nullable=False)

Class WarehouseManager:

class WarehouseManager(User):
    __tablename__ = "WarehouseManager"
    # USER is a WAREHOUSE MANAGER
    first_name = db.Column(db.String(64), nullable=False)
    last_name = db.Column(db.String(64), nullable=False)
    profile_picture_path = db.Column(db.String(256), nullable=False)
    warehouse_name = db.Column(db.String(64), nullable=False)

and so on.

The problem is that whenever I try to create the tables with db.create_all(), I get the following error:

sqlalchemy.exc.ArgumentError: Column 'first_name' on class WarehouseManager conflicts with existing column 'User.first_name'.  If using Declarative, consider using the use_existing_column parameter of mapped_column() to resolve conflicts.

I also don't know how I would create Marshmallow schemas for this.

To be clear, I want the Driver table in the database to contain the user_id and all the additional fields (first_name, last_name, profile_picture_path), but when working with it through SQLAlchemy I want to pass the User columns directly to the Driver constructor, i.e.:

driver = Driver(email="driver@example.com", first_name="Driver Name"...)

instead of:

user = User(email="user@example.com"...)
driver = Driver(user_id=user.id, first_name="Driver Name"...)

I see the latter as repetitive and bad practice for a quality ORM, which should handle such common cases with ease. This is not the first time I have been disappointed by ORMs' limited capabilities, especially when handling temporary tables, materialized views, triggers and so on.

Edit: I have tried passing use_existing_column to db.Column for first_name and last_name, but I still get the same error.

Edit 2: Here is a sandbox which reproduces the problem.


Solution

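  • The error comes from how Flask-SQLAlchemy treats subclasses: a model subclass that has no primary key column of its own is mapped with single-table inheritance, and its __tablename__ is ignored, so its columns are added to the User table. Driver adds first_name to User, and WarehouseManager then tries to add it again, hence the conflict. Giving each subclass its own primary key that is also a foreign key to User.id switches to joined-table inheritance, with one table per subtype: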

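    from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
    # sqlalchemy.orm.declarative_base replaces the deprecated sqlalchemy.ext.declarative one
    from sqlalchemy.orm import declarative_base
    
    Base = declarative_base()
    
    class User(Base):
        __tablename__ = 'User'
    
        id = Column(Integer, primary_key=True)
        role = Column(String(20), nullable=False)
        email = Column(String(256), unique=True, nullable=False)
        password = Column(String(256), nullable=False)
        sign_up_date = Column(DateTime, nullable=False)
        archived = Column(Integer, nullable=False, default=0)
    
        # role stores the subtype; each subclass's polymorphic_identity is written
        # there automatically, and queries on User return the matching subclass
        __mapper_args__ = {'polymorphic_on': role, 'polymorphic_identity': 'user'}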
    
    class Customer(User):
        __tablename__ = "Customer"
        __mapper_args__ = {'polymorphic_identity': 'customer'}
        
        id = Column(Integer, ForeignKey('User.id'), primary_key=True)
        address = Column(String(256), nullable=False)
        name = Column(String(64), nullable=False)
    
    class Driver(User):
        __tablename__ = "Driver"
        __mapper_args__ = {'polymorphic_identity': 'driver'}
        
        id = Column(Integer, ForeignKey('User.id'), primary_key=True)
        first_name = Column(String(64), nullable=False)
        last_name = Column(String(64), nullable=False)
        profile_picture_path = Column(String(256), nullable=False)
    
    class WarehouseManager(User):
        __tablename__ = "WarehouseManager"
        __mapper_args__ = {'polymorphic_identity': 'warehouse_manager'}
        
        id = Column(Integer, ForeignKey('User.id'), primary_key=True)
        first_name = Column(String(64), nullable=False)
        last_name = Column(String(64), nullable=False)
        profile_picture_path = Column(String(256), nullable=False)
        warehouse_name = Column(String(64), nullable=False)
    

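    This is joined-table inheritance: the User table holds the common columns, and each subtype table (Customer, Driver, WarehouseManager) holds only its own columns plus an id that is both its primary key and a foreign key to User.id. Creating a Driver inserts one row into User and one into Driver, so all the columns can be passed to a single Driver(...) call. The polymorphic_identity in each subclass's __mapper_args__ is the value SQLAlchemy stores in the discriminator column named by polymorphic_on on the base class (role here) to tell the subtypes apart. With Flask-SQLAlchemy, the same classes work with db.Model, db.Column and db.ForeignKey in place of Base, Column and ForeignKey.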

    Regarding Marshmallow, you can create one schema per type. Because each subtype schema subclasses a common UserSchema, it inherits all of its fields, so the common fields are declared only once:

    from marshmallow import Schema, fields
    
    class UserSchema(Schema):
        id = fields.Integer(dump_only=True)  # generated by the database
        role = fields.String()
        email = fields.String()
        password = fields.String(load_only=True)  # accepted on input, never serialized
        sign_up_date = fields.DateTime()
        archived = fields.Integer()
    
    class CustomerSchema(UserSchema):
        address = fields.String()
        name = fields.String()
    
    class DriverSchema(UserSchema):
        first_name = fields.String()
        last_name = fields.String()
        profile_picture_path = fields.String()
    
    class WarehouseManagerSchema(UserSchema):
        first_name = fields.String()
        last_name = fields.String()
        profile_picture_path = fields.String()
        warehouse_name = fields.String()
    

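    Each schema's dump() outputs the common User fields plus the subtype's own fields. Note that load() returns a plain dict rather than a model instance, so pass the result to the model constructor, e.g. Driver(**DriverSchema().load(data)).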