Pytest Flask rollback transactions after tests when using the Unit of Work pattern


I am studying the "Cosmic Python" book. Chapter 6 explains how to use the Unit of Work pattern to change how the code interacts with the database and the repository.

Chapter 6 of the book can be accessed here: https://www.cosmicpython.com/book/chapter_06_uow.html

The code provided by the author is the following:

from __future__ import annotations
import abc
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session

from allocation import config
from allocation.adapters import repository


class AbstractUnitOfWork(abc.ABC):
    products: repository.AbstractRepository

    def __enter__(self) -> AbstractUnitOfWork:
        return self

    def __exit__(self, *args):
        self.rollback()

    @abc.abstractmethod
    def commit(self):
        raise NotImplementedError

    @abc.abstractmethod
    def rollback(self):
        raise NotImplementedError



DEFAULT_SESSION_FACTORY = sessionmaker(bind=create_engine(
    config.get_postgres_uri(),
    isolation_level="REPEATABLE READ",
))

class SqlAlchemyUnitOfWork(AbstractUnitOfWork):

    def __init__(self, session_factory=DEFAULT_SESSION_FACTORY):
        self.session_factory = session_factory

    def __enter__(self):
        self.session = self.session_factory()  # type: Session
        self.products = repository.SqlAlchemyRepository(self.session)
        return super().__enter__()

    def __exit__(self, *args):
        super().__exit__(*args)
        self.session.close()

    def commit(self):
        self.session.commit()

    def rollback(self):
        self.session.rollback()
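
A note on the class above: `__exit__` always calls `rollback()`, and a rollback only discards work that was never committed. Anything the code under test has already committed stays in the database, which is likely why data survives between tests. The sketch below illustrates this with a hypothetical in-memory `FakeUnitOfWork`; the `add` method and the `pending` and `committed` lists are illustration-only names, not part of the book's code.

```python
import abc


class AbstractUnitOfWork(abc.ABC):
    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Runs on every exit: a no-op after commit(), a discard otherwise.
        self.rollback()

    @abc.abstractmethod
    def commit(self):
        raise NotImplementedError

    @abc.abstractmethod
    def rollback(self):
        raise NotImplementedError


class FakeUnitOfWork(AbstractUnitOfWork):
    """Hypothetical stand-in for SqlAlchemyUnitOfWork (no database involved)."""

    def __init__(self):
        self.committed = []  # stands in for rows persisted in the database
        self.pending = []    # stands in for uncommitted session state

    def add(self, item):
        self.pending.append(item)

    def commit(self):
        self.committed.extend(self.pending)
        self.pending.clear()

    def rollback(self):
        self.pending.clear()


uow = FakeUnitOfWork()

with uow:
    uow.add("batch-1")
    uow.commit()         # persisted; the rollback in __exit__ cannot undo it

with uow:
    uow.add("batch-2")   # never committed, so __exit__ discards it

print(uow.committed)     # ['batch-1']
```

So the rollback built into the Unit of Work protects against half-finished work, not against data that a test deliberately committed.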

I am trying to test my Flask endpoints, but I cannot get the data inserted by each test to roll back afterwards.

To solve that, I tried the pytest-flask-sqlalchemy package, but it fails with the following error:

'SqlAlchemyUnitOfWork' object has no attribute 'engine'

I do not quite understand how pytest-flask-sqlalchemy works, and I do not know how to make the Unit of Work roll back transactions after a test.

Is it possible to make it work the way the author implemented it?

Edited

The problem can be reproduced with the following repository:

https://github.com/Santana94/CosmicPythonRollbackTest

Clone it and run make all to see that the test does not roll back the previous actions.


Solution

  • I finally got the database to reset after every test.

    The idea came from a package called pytest-postgresql, which implements this itself. I adapted its approach to the database I was working with by adding this fixture to conftest.py:

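    import pyramid_basemodel
    import pytest
    import transaction
    from sqlalchemy import create_engine
    from sqlalchemy.orm import scoped_session, sessionmaker
    from sqlalchemy.pool import NullPool
    from zope.sqlalchemy import ZopeTransactionExtension  # assumed: provided by older zope.sqlalchemy releases

    from allocation import config
    from allocation.adapters.orm import metadata  # assumed location, following the book's layout

    @pytest.fixture(scope='function')
    def db_session():
        engine = create_engine(config.get_postgres_uri(), echo=False, poolclass=NullPool)
        metadata.create_all(engine)
        pyramid_basemodel.Session = scoped_session(sessionmaker(extension=ZopeTransactionExtension()))
        pyramid_basemodel.bind_engine(
            engine, pyramid_basemodel.Session, should_create=True, should_drop=True)

        yield pyramid_basemodel.Session

        # Teardown: commit, then drop every table so the next test starts clean.
        transaction.commit()
        metadata.drop_all(engine)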
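
For readers who do not use pyramid_basemodel, the same create-then-drop idea can be sketched with plain SQLAlchemy. This is a minimal sketch under stated assumptions, not the author's code: the `batches` table, the `fresh_schema` helper, and the in-memory SQLite URL are all hypothetical stand-ins for the project's real metadata and Postgres database.

```python
from contextlib import contextmanager

from sqlalchemy import (
    Column, Integer, MetaData, String, Table, create_engine, inspect, text,
)
from sqlalchemy.orm import sessionmaker

# Hypothetical schema standing in for the project's real metadata.
metadata = MetaData()
batches = Table(
    "batches",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("reference", String(255)),
)


@contextmanager
def fresh_schema(engine):
    """Create all tables, hand out a session factory, drop all tables afterwards."""
    metadata.create_all(engine)
    try:
        yield sessionmaker(bind=engine)
    finally:
        # Dropping the tables removes committed rows too, unlike a rollback.
        metadata.drop_all(engine)


engine = create_engine("sqlite://")  # in-memory SQLite, for illustration only

with fresh_schema(engine) as session_factory:
    session = session_factory()
    session.execute(text("INSERT INTO batches (reference) VALUES ('batch-1')"))
    session.commit()
    rows_inside = session.execute(text("SELECT COUNT(*) FROM batches")).scalar()
    session.close()

tables_after = inspect(engine).get_table_names()
print(rows_inside, tables_after)
```

In a pytest suite the body of `fresh_schema` could live in a generator fixture instead of a context manager, with the `yield` separating setup from teardown in the same way.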
    

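    Strictly speaking, this fixture does not roll back a transaction: on teardown it commits and then drops every table, so the next test starts from an empty schema. To use it, I pass db_session as a parameter to each test that needs the cleanup: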

    @pytest.mark.usefixtures('postgres_db')
    @pytest.mark.usefixtures('restart_api')
    def test_happy_path_returns_202_and_batch_is_allocated(db_session):
        orderid = random_orderid()
        sku, othersku = random_sku(), random_sku('other')
        earlybatch = random_batchref(1)
        laterbatch = random_batchref(2)
        otherbatch = random_batchref(3)
        api_client.post_to_add_batch(laterbatch, sku, 100, '2011-01-02')
        api_client.post_to_add_batch(earlybatch, sku, 100, '2011-01-01')
        api_client.post_to_add_batch(otherbatch, othersku, 100, None)
    
        r = api_client.post_to_allocate(orderid, sku, qty=3)
        assert r.status_code == 202
    
        r = api_client.get_allocation(orderid)
        assert r.ok
        assert r.json() == [
            {'sku': sku, 'batchref': earlybatch},
        ]
    

    The requirements and the other details of this implementation are available in my GitHub repository:

    https://github.com/Santana94/CosmicPythonRollbackTest
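
For tests that build the Unit of Work directly in the test process, a true rollback is also possible, because `SqlAlchemyUnitOfWork` accepts a `session_factory`. The sketch below binds that factory to a connection that already has an outer transaction open, then rolls the outer transaction back. It is a sketch under stated assumptions: the class is a trimmed copy of the one in the question (repository omitted), the `batches` table is hypothetical, and in-memory SQLite stands in for Postgres.

```python
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker


class SqlAlchemyUnitOfWork:
    """Trimmed copy of the class from the question (repository omitted)."""

    def __init__(self, session_factory):
        self.session_factory = session_factory

    def __enter__(self):
        self.session = self.session_factory()
        return self

    def __exit__(self, *args):
        self.session.rollback()
        self.session.close()

    def commit(self):
        self.session.commit()


engine = create_engine("sqlite://")  # in-memory SQLite, for illustration only
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE batches (reference TEXT)"))

# Test setup: open a connection and start the outer transaction.
connection = engine.connect()
outer = connection.begin()
session_factory = sessionmaker(bind=connection)

# Test body: the Unit of Work commits, but only inside the outer transaction.
with SqlAlchemyUnitOfWork(session_factory) as uow:
    uow.session.execute(
        text("INSERT INTO batches (reference) VALUES ('batch-1')")
    )
    uow.commit()

rows_during_test = connection.execute(
    text("SELECT COUNT(*) FROM batches")
).scalar()

# Test teardown: rolling back the outer transaction discards the "commit".
outer.rollback()
rows_after_test = connection.execute(
    text("SELECT COUNT(*) FROM batches")
).scalar()
connection.close()

print(rows_during_test, rows_after_test)
```

Two caveats apply. First, if the code under test rolls back a session that has pending work, the outer transaction is rolled back with it; SQLAlchemy 2.0 offers `join_transaction_mode="create_savepoint"` on the session for that case. Second, this only helps when the test and the code under test share one connection. Requests sent to a separately running Flask server commit on the server's own connection, which is presumably why the fixture in this answer drops the tables instead.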