Search code examples
pythonsqlalchemyalembic

sqlalchemy - alembic run update query without specifying model to avoid later migrations clash


I am adding a field to my table using alembic.
The new field is last_name, and I fill it with data using a do_some_processing function, which loads the data for the field from some other source.

This is the table model, with the field last_name already added:

class MyTable(db.Model):
    """Model for the ``my_table`` table, including the new ``last_name`` field."""

    __tablename__ = "my_table"

    # Primary key; note the column is literally named "index".
    index = db.Column(db.Integer, primary_key=True, nullable=False)
    age = db.Column(db.Integer(), default=0)
    first_name = db.Column(db.String(100), nullable=False)
    # Declared NOT NULL on the model, whereas the migration adds the column
    # with nullable=True.
    last_name = db.Column(db.String(100), nullable=False)

Here is my migration which works well

# migration_add_last_name_field
# Add the new column as nullable; its values are filled in by the loop below.
op.add_column('my_table', sa.Column('last_name', sa.String(length=100), nullable=True)) 
# Querying the whole MyTable entity selects every column the model declares
# at the time the migration runs, not only the columns that exist at this
# revision.
values = session.query(MyTable).filter(MyTable.age == 5).all()

# Derive each selected row's last_name from its first_name.
for value in values:
    first_name = value.first_name
    value.last_name = do_some_processing(first_name)
# Persist all the updated last_name values in one commit.
session.commit()

The issue is that using session.query(MyTable) causes problems in future migrations.

For example, suppose I later add a migration that adds a field foo to the table, and also add that field to class MyTable. On an environment that has not been updated yet, migration_add_last_name_field runs first and fails with:

sqlalchemy.exc.OperationalError: (MySQLdb._exceptions.OperationalError) 
(1054, "Unknown column 'my_table.foo' in 'field list'")

[SQL: SELECT my_table.`index` AS my_table_index, my_table.first_name AS my_table_first_name, 
  my_table.last_name AS my_table_last_name, my_table.foo AS my_table_foo
FROM my_table 
WHERE my_table.age = %s]

[parameters: (0,)]
(Background on this error at: http://sqlalche.me/e/13/e3q8)

This happens because the migration that adds foo runs only afterwards, while session.query(MyTable) selects all the fields of the MyTable model, including foo.

I am trying to do the update without selecting all fields to avoid selecting fields that were not created yet, like this:

# Second attempt: select only the two needed columns instead of the whole model.
op.add_column('my_table', sa.Column('last_name', sa.String(length=100), nullable=True)) 


# NOTE(review): a query over individual columns presumably returns plain result
# rows rather than MyTable instances - confirm against the SQLAlchemy docs.
values = session.query(MyTable.last_name, MyTable.first_name).filter(MyTable.age == 0).all()


for value in values:
    first_name = value.first_name
    # This assignment is the statement that fails with "can't set attribute".
    value.last_name = do_some_processing(first_name)
session.commit()

But this results in an error: can't set attribute

I also tried different variations of select *, also with no success.
What is the correct solution?


Solution

  • Here is a solution written by Oren S.
    This is how it is used:

    from alembic_custom_ops import visit
    
    def upgrade():
        """Rewrite every row's ``id`` based on its ``customer_id`` via ``op.visit_rows``."""
        # ``index_field`` is not passed, so visit_rows falls back to its default
        # of "index"; the table therefore needs a column with that name.
        for row in op.visit_rows('my_table', ['id', 'customer_id']):
            # Assigning to a requested field records a pending change for this
            # row; it is written to the database when the loop advances.
            row['id'] = str(_make_unique_id(row["customer_id"]))
    

    And here is the utility module (alembic_custom_ops) you should have in your code:

    from collections import ChainMap
    from sqlalchemy import MetaData, update
    from sqlalchemy.ext.automap import automap_base
    from sqlalchemy.orm import sessionmaker
    from typing import List
    from alembic.operations import Operations, MigrateOperation
    from sqlalchemy import MetaData
    from sqlalchemy.ext.automap import automap_base
    
    class VisitException(Exception):
        """Raised when a ``visit_rows`` pass is misused or hits an internal inconsistency."""
    
    @Operations.register_operation("visit_rows")
    class VisitOp(MigrateOperation):
        """Parameters of one ``visit_rows`` pass over a table.

        The class is registered with Alembic under the name ``visit_rows``;
        an instance only carries the arguments of the call.
        """

        def __init__(
            self,
            table_name: str,
            field_names: List[str],
            index_field: str,
            commit_every: int,
        ):
            """Store the arguments unchanged.

            Args:
                table_name: Name of the table to walk over.
                field_names: Fields to read for each row.
                index_field: Field whose value identifies a row.
                commit_every: Commit interval in rows; 0 means at end only.
            """
            self.table_name, self.field_names = table_name, field_names
            self.index_field, self.commit_every = index_field, commit_every

        @classmethod
        def visit_rows(
            cls,
            operations,
            table_name: str,
            field_names: List[str],
            index_field: str = "index",
            commit_every: int = 0,  # 0 means at end only
        ):
            """Build a ``VisitOp`` from the arguments and invoke it.

            Returns:
                Whatever ``operations.invoke`` returns for the new operation.
            """
            return operations.invoke(
                VisitOp(table_name, field_names, index_field, commit_every)
            )
    
    @Operations.implementation_for(VisitOp)
    def visit(operations, operation: VisitOp):
        """Yield one editable mapping per table row and write edits back.

        Only the table named in ``operation`` is reflected from the database,
        and only the index field plus the requested fields are selected, so
        application models are never consulted.

        For each row a ``ChainMap`` is yielded whose first map is an empty dict
        and whose second map holds the selected values. Item assignments made
        by the consumer therefore land in the first map; when the generator is
        resumed, a non-empty first map is turned into an ``UPDATE`` of the row
        whose index field equals the selected index value.

        Args:
            operations: The Alembic operations object; its ``get_bind()``
                supplies the database bind.
            operation: The ``VisitOp`` describing table, fields, index field
                and commit interval.

        Yields:
            ChainMap: ``ChainMap(changes, db_values)`` for each selected row.

        Raises:
            VisitException: If the consumer sets a key that is not one of the
                requested fields, sets the index field, or if a selected row
                does not have the expected number of values.
        """
        bind = operations.get_bind()
        # Hand the bind to ``reflect`` rather than to the ``MetaData``
        # constructor: ``MetaData(bind=...)`` no longer exists in SQLAlchemy 2.0,
        # while ``reflect(bind=...)`` is accepted by 1.x and 2.0 alike.
        meta = MetaData()
        meta.reflect(bind=bind, only=(operation.table_name,))
        base = automap_base(metadata=meta)
        base.prepare()
        table = getattr(base.classes, operation.table_name)

        field_names_set = frozenset(operation.field_names)
        # Unpacking accepts any sequence of names, not only a list.
        all_fields_names = [operation.index_field, *operation.field_names]
        columns = [getattr(table, field_name) for field_name in all_fields_names]
        index_column = columns[0]

        session = sessionmaker(bind=bind)()
        try:
            for running_count, row in enumerate(session.query(*columns), start=1):
                if len(row) != len(all_fields_names):
                    raise VisitException("Internal error: lists' lengths should be equal")
                index_value = row[0]
                db_values = dict(zip(operation.field_names, row[1:]))
                changes = {}
                # Writes to the ChainMap go to ``changes``; reads fall through
                # to ``db_values``.
                yield ChainMap(changes, db_values)
                if changes:
                    if changes.keys() - field_names_set:
                        raise VisitException("Only requested fields may be updated")
                    if operation.index_field in changes:
                        raise VisitException("Can't rewrite the selected index field")
                    session.execute(
                        update(table)
                        .where(index_column == index_value)
                        .values(**changes)
                    )
                if operation.commit_every and ((running_count % operation.commit_every) == 0):
                    session.commit()
            session.commit()
        finally:
            # Release the session on every exit path: normal completion, an
            # exception, or the consumer abandoning the generator early.
            session.close()