I am adding a field to my table using Alembic.
The new field is last_name, and I fill it using the do_some_processing function, which loads the data for the field from another source.

This is the table model, with the new last_name field added:

class MyTable(db.Model):
    __tablename__ = "my_table"

    index = db.Column(db.Integer, primary_key=True, nullable=False)
    age = db.Column(db.Integer(), default=0)
    first_name = db.Column(db.String(100), nullable=False)
    last_name = db.Column(db.String(100), nullable=False)

Here is my migration, which works well:

# migration_add_last_name_field
op.add_column('my_table', sa.Column('last_name', sa.String(length=100), nullable=True)) 
values = session.query(MyTable).filter(MyTable.age == 5).all()

for value in values:
    first_name = value.first_name
    value.last_name = do_some_processing(first_name)
session.commit()

The problem is that using session.query(MyTable) breaks once later migrations change the model.

For example, suppose I later add a migration that adds a field foo to the table, and add the field to the MyTable class. In an environment that has not been upgraded yet, migration_add_last_name_field runs first and fails:

sqlalchemy.exc.OperationalError: (MySQLdb._exceptions.OperationalError) 
(1054, "Unknown column 'my_table.foo' in 'field list'")

[SQL: SELECT my_table.`index` AS my_table_index, my_table.first_name AS my_table_first_name, 
  my_table.last_name AS my_table_last_name, my_table.foo AS my_table_foo
FROM my_table 
WHERE my_table.age = %s]

[parameters: (0,)]
(Background on this error at: http://sqlalche.me/e/13/e3q8)

This happens because the migration that adds foo runs only afterwards, but session.query(MyTable) selects every field in the MyTable model, including foo.

I am trying to do the update without selecting all fields, to avoid selecting fields that have not been created yet, like this:

op.add_column('my_table', sa.Column('last_name', sa.String(length=100), nullable=True))

values = session.query(MyTable.last_name, MyTable.first_name).filter(MyTable.age == 0).all()

for value in values:
    first_name = value.first_name
    value.last_name = do_some_processing(first_name)
session.commit()

But this results in an error: can't set attribute

I also tried different variations of select *, also with no success.
What is the correct solution?

dina

2 Answers

The Alembic cookbook describes this issue here: data-migrations-general-techniques

Some options here might be:

  1. Use a separate metadata and reflection to load the table before and after.
  2. Create the table manually afterwards and reference only the columns you need.
  3. Best case, run an update that is SQL only and doesn't depend on Python-level processing. In my example below I could have done this by casting the int to a string with something like op.execute("UPDATE users SET name = CAST(user_id AS text)"). I know this isn't always possible, though.
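
For the my_table case from the question, option 1 might look roughly like the sketch below. It assumes SQLAlchemy 1.4 or newer; backfill_last_name is a hypothetical helper name, and process stands in for the question's do_some_processing. Because the table is reflected from the database at migration time, columns that later migrations add (such as foo) are never part of the query.

```python
import sqlalchemy as sa


def backfill_last_name(conn, process):
    """Fill my_table.last_name from first_name using a reflected table.

    conn is a SQLAlchemy Connection (op.get_bind() inside a migration).
    process is a stand-in for do_some_processing from the question.
    Returns the number of rows that were updated.
    """
    # Reflect the table as it exists in the database right now, instead of
    # importing the application model, which may already be ahead of the schema.
    my_table = sa.Table("my_table", sa.MetaData(), autoload_with=conn)

    # Select only the primary key and the source column.
    # Same filter as the question's second attempt (age == 0).
    rows = conn.execute(
        sa.select(my_table.c["index"], my_table.c.first_name).where(
            my_table.c.age == 0
        )
    ).all()

    # One UPDATE per row, addressed by primary key.
    for pk, first_name in rows:
        conn.execute(
            sa.update(my_table)
            .where(my_table.c["index"] == pk)
            .values(last_name=process(first_name))
        )
    return len(rows)
```

Inside upgrade() this would be called right after op.add_column(...), as backfill_last_name(op.get_bind(), do_some_processing).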

Here is an example of option 2:

In this case the users table only had a user_id column, which I then converted to a string to fill the new name column.

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    op.add_column("users", sa.Column("name", sa.String))
    # Lightweight table definition listing only the columns this migration
    # needs: the pre-existing key and the column that was just added.
    metadata = sa.MetaData()
    users_t = sa.Table(
        "users",
        metadata,
        sa.Column("user_id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String),
    )
    user_ids = op.get_bind().execute(sa.select(users_t.c.user_id)).scalars().all()
    # One UPDATE per row: slow, but it works.
    for user_id in user_ids:
        op.execute(
            users_t.update()
            .where(users_t.c.user_id == op.inline_literal(user_id))
            .values({"name": op.inline_literal(str(user_id))})
        )
Ian Wilson

Adding here a solution written by Oren S.
Here is the usage:

from alembic import op
from alembic_custom_ops import visit  # noqa: F401 (importing the module registers op.visit_rows)

def upgrade():
    for row in op.visit_rows('my_table', ['id', 'customer_id']):
        row['id'] = str(_make_unique_id(row["customer_id"]))

And here is the utility class you should have in your code (the alembic_custom_ops module imported above):

from collections import ChainMap
from typing import List

from alembic.operations import MigrateOperation, Operations
from sqlalchemy import MetaData, update
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import sessionmaker

class VisitException(Exception):
    pass

@Operations.register_operation("visit_rows")
class VisitOp(MigrateOperation):
    def __init__(
        self,
        table_name: str,
        field_names: List[str],
        index_field: str,
        commit_every: int,
    ):
        self.table_name = table_name
        self.field_names = field_names
        self.index_field = index_field
        self.commit_every = commit_every

    @classmethod
    def visit_rows(
        cls,
        operations,
        table_name: str,
        field_names: List[str],
        index_field: str = "index",
        commit_every: int = 0,  # 0 means at end only
    ):
        op = VisitOp(table_name, field_names, index_field, commit_every)
        return operations.invoke(op)

@Operations.implementation_for(VisitOp)
def visit(operations, operation: VisitOp):
    engine = operations.get_bind()
    session_type = sessionmaker(bind=engine)
    meta = MetaData()
    # Pass the bind to reflect() directly; MetaData(bind=...) was removed in SQLAlchemy 2.0.
    meta.reflect(bind=engine, only=(operation.table_name,))
    base = automap_base(metadata=meta)
    base.prepare()
    table = getattr(base.classes, operation.table_name)
    session = session_type()
    field_names_set = frozenset(operation.field_names)
    all_fields_names = [operation.index_field] + operation.field_names
    for running_count, row in enumerate(
        session.query(*[getattr(table, field_name) for field_name in all_fields_names]),
        start=1,
    ):
        if len(row) != len(all_fields_names):
            raise VisitException("Internal error: lists' lengths should be equal")
        index_value = row[0]
        db_values = dict(zip(operation.field_names, row[1:]))
        changes = {}
        yield ChainMap(changes, db_values)
        if changes:
            if changes.keys() - field_names_set:
                raise VisitException("Only requested fields may be updated")
            if operation.index_field in changes:
                raise VisitException("Can't rewrite the selected index field")
            session.execute(
                update(table)
                .where(getattr(table, operation.index_field) == index_value)
                .values(**changes)
            )
        if operation.commit_every and ((running_count % operation.commit_every) == 0):
            session.commit()
    session.commit()
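
The ChainMap(changes, db_values) line is what makes row['id'] = ... in the usage work. A ChainMap sends writes to its first mapping only, so changes ends up holding exactly the fields the caller assigned, while reads fall through to the values loaded from the database. A small standalone illustration, with made-up values:

```python
from collections import ChainMap

db_values = {"id": "old-id", "customer_id": 42}  # values as loaded for one row
changes = {}                                     # starts empty for every row

row = ChainMap(changes, db_values)

customer = row["customer_id"]   # read: not in changes, so it comes from db_values
row["id"] = f"cust-{customer}"  # write: stored in changes, db_values is untouched
```

After the loop body runs, visit only has to look at changes to decide whether an UPDATE is needed and which columns to set.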
dina