I am adding a field to my table using alembic.
I am adding the field last_name
, and filling it with data using do_some_processing
function which loads data for the field from some other source.
This is the table model, I added the field last_name
to the model
class MyTable(db.Model):
__tablename__ = "my_table"
index = db.Column(db.Integer, primary_key=True, nullable=False)
age = db.Column(db.Integer(), default=0)
first_name = db.Column(db.String(100), nullable=False)
last_name = db.Column(db.String(100), nullable=False)
Here is my migration which works well
# migration_add_last_name_field
op.add_column('my_table', sa.Column('last_name', sa.String(length=100), nullable=True))
values = session.query(MyTable).filter(MyTable.age == 5).all()
for value in values:
first_name = value.first_name
value.last_name = do_some_processing(first_name)
session.commit()
The issue is, that using session.query(MyTable)
causes issues in future migrations.
For example, if I add in the future a migration which adds a field foo
to the table, and add the field to class MyTable
,
If I have unupdated environment, it will run migration_add_last_name_field
and it fails
sqlalchemy.exc.OperationalError: (MySQLdb._exceptions.OperationalError)
(1054, "Unknown column 'my_table.foo' in 'field list'")
[SQL: SELECT my_table.`index` AS my_table_index, my_table.first_name AS my_table_first_name,
my_table.last_name AS my_table_last_name, my_table.foo AS my_table_foo
FROM my_table
WHERE my_table.age = %s]
[parameters: (0,)]
(Background on this error at: http://sqlalche.me/e/13/e3q8)
since the migration that adds foo
runs only after, but session.query(MyTable)
takes all the fields in MyTable
model including foo
.
I am trying to do the update without selecting all fields to avoid selecting fields that were not created yet, like this:
op.add_column('my_table', sa.Column('last_name', sa.String(length=100), nullable=True))
values = session.query(MyTable.last_name, MyTable.first_name).filter(MyTable.age == 0).all()
for value in values:
first_name = value.first_name
value.last_name = do_some_processing(first_name)
session.commit()
But this results an error: can't set attribute
I also tried different variations of select *
also with no success.
What is the correct solution?
Adding here a solution written by Oren S
here is the usage:
from alembic_custom_ops import visit
def upgrade():
for row in op.visit_rows('my_table', ['id', 'customer_id']):
row['id'] = str(_make_unique_id(row["customer_id"]))
And here is the util class you should have in your code
from collections import ChainMap
from sqlalchemy import MetaData, update
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import sessionmaker
from typing import List
from alembic.operations import Operations, MigrateOperation
from sqlalchemy import MetaData
from sqlalchemy.ext.automap import automap_base
class VisitException(Exception):
pass
@Operations.register_operation("visit_rows")
class VisitOp(MigrateOperation):
def __init__(
self,
table_name: str,
field_names: List[str],
index_field: str,
commit_every: int,
):
self.table_name = table_name
self.field_names = field_names
self.index_field = index_field
self.commit_every = commit_every
@classmethod
def visit_rows(
cls,
operations,
table_name: str,
field_names: List[str],
index_field: str = "index",
commit_every: int = 0, # 0 means at end only
):
op = VisitOp(table_name, field_names, index_field, commit_every)
return operations.invoke(op)
@Operations.implementation_for(VisitOp)
def visit(operations, operation: VisitOp):
engine = operations.get_bind()
session_type = sessionmaker(bind=engine)
meta = MetaData(bind=engine)
meta.reflect(only=(operation.table_name,))
base = automap_base(metadata=meta)
base.prepare()
table = getattr(base.classes, operation.table_name)
session = session_type()
field_names_set = frozenset(operation.field_names)
all_fields_names = [operation.index_field] + operation.field_names
for running_count, row in enumerate(
session.query(*[getattr(table, field_name) for field_name in all_fields_names]),
start=1,
):
if len(row) != len(all_fields_names):
raise VisitException("Internal error: lists' lengths should be equal")
index_value = row[0]
db_values = dict(zip(operation.field_names, row[1:]))
changes = {}
yield ChainMap(changes, db_values)
if changes:
if changes.keys() - field_names_set:
raise VisitException("Only requested fields may be updated")
if operation.index_field in changes:
raise VisitException("Can't rewrite the selected index field")
session.execute(
update(table)
.where(getattr(table, operation.index_field) == index_value)
.values(**changes)
)
if operation.commit_every and ((running_count % operation.commit_every) == 0):
session.commit()
session.commit()