I'm trying to update a table with info from other rows in the same table. However, I cannot get SQLAlchemy to generate the proper SQL: it always ends up with a WHERE false clause in the subquery, which nullifies the effect.
I have tried several approaches, and this one seems the most correct, but it still doesn't work. Other examples I found here are for older versions of SQLAlchemy.
Here's my code to execute it. (Please forgive the ambiguous naming; I'm trying to obscure the code from the original source but keep it readable for troubleshooting.)
from uuid import UUID
from sqlalchemy import select, update

parent_id: UUID = ...
iteration: int = ...
current_generation_number: int = ...
previous_generation_number: int = current_generation_number - 1
previous_generation = (
    select(Run)
    .where(Run.parent_id == parent_id)
    .where(Run.grand_iteration == iteration)
    .where(Run.generation == previous_generation_number)
    .where(Run.data_partition in [DataPartition.VALIDATION, DataPartition.TEST])
    .subquery(name="previous_generation")
)

update_operation = (
    update(Run)
    .where(Run.parent_id == parent_id)
    .where(Run.grand_iteration == iteration)
    .where(Run.generation == current_generation_number)
    .where(Run.arguments == previous_generation.c.arguments)
    .where(Run.data_partition == previous_generation.c.data_partition)
    .values(
        metric1=previous_generation.c.metric1,
        metric2=previous_generation.c.metric2,
        metric3=previous_generation.c.metric3,
    )
)

self.db.execute(update_operation)
self.db.commit()
What I expect to be generated is something of the sort:
UPDATE runs
SET
    metric1=previous_generation.metric1,
    metric2=previous_generation.metric2,
    metric3=previous_generation.metric3
FROM (
    SELECT /* ... columns ... */
    FROM runs
    WHERE
        parent_id = %(parent_id_1)s::UUID
        AND iteration = %(iteration_1)s
        AND generation = %(generation_1)s
        AND data_partition IN ('TEST', 'VALIDATION')
) AS previous_generation
WHERE
    runs.parent_id = %(parent_id_1)s::UUID
    AND runs.iteration = %(iteration_1)s
    AND runs.generation = %(generation_2)s
    AND runs.arguments = previous_generation.arguments
    AND runs.data_partition = previous_generation.data_partition
And here's the SQL that SQLAlchemy logs. Interestingly, it is logged twice (I'm not sure whether that's part of the problem). Notes below.
UPDATE runs
SET
    metric1=previous_generation.metric1,
    metric2=previous_generation.metric2,
    metric3=previous_generation.metric3
FROM (
    SELECT
        runs.id AS id,
        runs.parent_id AS parent_id,
        runs.generation AS generation,
        runs.iteration AS iteration,
        runs.arguments AS arguments,
        runs.data_partition AS data_partition,
        runs.metric1 AS metric1,
        runs.metric2 AS metric2,
        runs.metric3 AS metric3
    FROM runs
    WHERE false
) AS previous_generation
WHERE
    runs.parent_id = %(parent_id_1)s::UUID
    AND runs.iteration = %(iteration_1)s
    AND runs.generation = %(generation_1)s
    AND runs.arguments = previous_generation.arguments
    AND runs.data_partition = previous_generation.data_partition
RETURNING runs.id
And the parameters:
{
'parent_id_1': UUID('1cb259e1-9f2e-40b8-884a-5706a8275312'),
'iteration_1': 1,
'generation_1': 3
}
Note the differences: the subquery has WHERE false, and my conditions are not included at all.
What am I doing wrong here? Any guidance is appreciated.
Context: SQLAlchemy 2.0, Python 3.9, PostgreSQL 16.2
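Seems the problem is with this line:
.where(Run.data_partition in [DataPartition.VALIDATION, DataPartition.TEST])
Here Python's own in operator is evaluated, not the SQLAlchemy operator. Python compares the column with each list element using ==; each comparison returns a SQLAlchemy expression object whose truth value is False, so the whole in test is just the Python value False. .where(False) is then rendered as WHERE false, and since anything AND false is false, SQLAlchemy collapses the subquery's criteria to that single false, which is why your other conditions disappear. Use the column's in_() method instead, which builds a SQL IN expression. This should fix the query:
.where(Run.data_partition.in_([DataPartition.VALIDATION, DataPartition.TEST]))
More about this: https://docs.sqlalchemy.org/en/20/core/operators.html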
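To see the difference without a database, one can compile both forms to PostgreSQL SQL. The sketch below assumes a hypothetical minimal Run model and DataPartition enum (only a few columns, with guessed types) standing in for the real ones. It shows that the Python in produces the plain value False, that .where(False) renders as WHERE false, and that in_() yields a proper IN inside the UPDATE ... FROM subquery.

```python
import enum

from sqlalchemy import Column, Float, Integer, String, select, update
from sqlalchemy import Enum as SAEnum
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class DataPartition(enum.Enum):
    # Hypothetical stand-in for the question's enum.
    TRAIN = "TRAIN"
    VALIDATION = "VALIDATION"
    TEST = "TEST"


class Run(Base):
    # Hypothetical minimal model with only the columns used below;
    # the column types are assumptions.
    __tablename__ = "runs"
    id = Column(Integer, primary_key=True)
    generation = Column(Integer)
    arguments = Column(String)
    data_partition = Column(SAEnum(DataPartition))
    metric1 = Column(Float)


def to_sql(stmt) -> str:
    """Compile a statement to PostgreSQL SQL text; no database needed."""
    return str(stmt.compile(dialect=postgresql.dialect()))


wanted = [DataPartition.VALIDATION, DataPartition.TEST]

# Python's `in` runs immediately and returns a plain bool (False here).
python_in = Run.data_partition in wanted

# Handing that bool to .where() yields WHERE false and drops the
# other criteria of the same statement.
broken = select(Run.id).where(Run.generation == 2).where(python_in)

# in_() builds a SQL IN expression instead.
previous_generation = (
    select(Run)
    .where(Run.generation == 2)
    .where(Run.data_partition.in_(wanted))
    .subquery(name="previous_generation")
)
fixed = (
    update(Run)
    .where(Run.generation == 3)
    .where(Run.arguments == previous_generation.c.arguments)
    .where(Run.data_partition == previous_generation.c.data_partition)
    .values(metric1=previous_generation.c.metric1)
)

print(python_in)
print(to_sql(broken))
print(to_sql(fixed))
```

In the compiled UPDATE, the IN list should appear as a placeholder rather than the literal enum values, because SQLAlchemy fills it in when the statement is executed.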