Tags: python, sql, postgresql, sqlalchemy

UPDATE + subquery with conditions not being rendered in SQLAlchemy 2.0


I'm trying to update rows in a table with values taken from other rows of the same table, but I cannot get SQLAlchemy to generate the proper SQL. The subquery always ends up with a WHERE false clause, so the update has no effect.

I have tried several approaches, and this one seems the most correct, but it still doesn't work. The other examples I found here are for older versions of SQLAlchemy.

Here's my code. (Please forgive the ambiguous naming: I'm obscuring the original source while keeping it readable for troubleshooting.)

from uuid import UUID

from sqlalchemy import select, update

parent_id: UUID = ...
iteration: int = ...
current_generation_number: int = ...
previous_generation_number: int = current_generation_number - 1

previous_generation = (
    select(Run)
    .where(Run.parent_id == parent_id)
    .where(Run.grand_iteration == iteration)
    .where(Run.generation == previous_generation_number)
    .where(Run.data_partition in [DataPartition.VALIDATION, DataPartition.TEST])
    .subquery(name="previous_generation")
)

update_operation = (
    update(Run)
    .where(Run.parent_id == parent_id)
    .where(Run.grand_iteration == iteration)
    .where(Run.generation == current_generation_number)
    .where(Run.arguments == previous_generation.c.arguments)
    .where(Run.data_partition == previous_generation.c.data_partition)
    .values(
        metric1=previous_generation.c.metric1,
        metric2=previous_generation.c.metric2,
        metric3=previous_generation.c.metric3,
    )
)

self.db.execute(update_operation)
self.db.commit()

What I expect to be generated is something like this:

UPDATE runs
SET 
    metric1=previous_generation.metric1, 
    metric2=previous_generation.metric2, 
    metric3=previous_generation.metric3
FROM (
    SELECT /* ... columns ... */
    FROM runs
    WHERE
        parent_id = %(parent_id_1)s::UUID 
        AND iteration = %(iteration_1)s 
        AND generation = %(generation_1)s 
        AND data_partition IN ('TEST', 'VALIDATION')
) AS previous_generation
WHERE
    runs.parent_id = %(parent_id_1)s::UUID 
    AND runs.iteration = %(iteration_1)s 
    AND runs.generation = %(generation_2)s 
    AND runs.arguments = previous_generation.arguments 
    AND runs.data_partition = previous_generation.data_partition 

And here's the SQL that SQLAlchemy logs. Interestingly, it is logged twice (I'm not sure whether that's part of the problem). Notes below.

UPDATE runs
SET 
    metric1=previous_generation.metric1, 
    metric2=previous_generation.metric2, 
    metric3=previous_generation.metric3
FROM (
    SELECT 
        runs.id AS id, 
        runs.parent_id AS parent_id, 
        runs.generation AS generation, 
        runs.iteration AS iteration, 
        runs.arguments AS arguments, 
        runs.data_partition AS data_partition, 
        runs.metric1 AS metric1, 
        runs.metric2 AS metric2, 
        runs.metric3 AS metric3
    FROM runs
    WHERE false
) AS previous_generation 
WHERE 
    runs.parent_id = %(parent_id_1)s::UUID 
    AND runs.iteration = %(iteration_1)s 
    AND runs.generation = %(generation_1)s 
    AND runs.arguments = previous_generation.arguments 
    AND runs.data_partition = previous_generation.data_partition 
RETURNING runs.id

And the parameters:

{
    'parent_id_1': UUID('1cb259e1-9f2e-40b8-884a-5706a8275312'),
    'iteration_1': 1, 
    'generation_1': 3
}

Note the differences:

  1. The conditions I pass in (and their bound parameters) are not rendered in the subquery at all.
  2. Instead, the subquery ends up with WHERE false, and none of my conditions are included.
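Observation 2 matches how SQLAlchemy treats a constant false among the WHERE criteria: an AND that contains false is simplified to just false, which drops the other conditions and their bound parameters. A minimal sketch that reproduces this without a database, assuming a lightweight table() stand-in for runs (not the real model) and passing a plain Python False to .where():

```python
from sqlalchemy import column, select, table
from sqlalchemy.dialects import postgresql

# Hypothetical lightweight stand-in for the "runs" table (two columns only).
runs = table("runs", column("parent_id"), column("generation"))

stmt = (
    select(runs)
    .where(runs.c.generation == 2)  # a real condition
    .where(False)  # a constant Python False, coerced to SQL false
)

# Compile against the PostgreSQL dialect so the constant renders as "false".
sql = str(stmt.compile(dialect=postgresql.dialect()))
print(sql)
```

The compiled SQL should end in WHERE false with no generation condition left, mirroring the logged subquery.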

What am I doing wrong here? Any guidance is appreciated.

Context: SQLAlchemy 2.0, Python 3.9, PostgreSQL 16.2


Solution

  • The problem seems to be with

    where(Run.data_partition in [DataPartition.VALIDATION, DataPartition.TEST])
    

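    Here Python's own "in" operator is evaluated instead of SQLAlchemy's IN operator: the expression is reduced to a plain Python False before SQLAlchemy sees it, so .where() receives a constant false and the subquery renders as WHERE false. Using the column's in_() method should fix the query: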

    where(Run.data_partition.in_([DataPartition.VALIDATION, DataPartition.TEST]))
    

    More about this: https://docs.sqlalchemy.org/en/20/core/operators.html