
How to test SQLAlchemy versioning in unit tests - Python


Note: Using flask_sqlalchemy here

I'm working on adding versioning to multiple services that share the same DB. To make sure it works, I'm adding unit tests that confirm I get an error (in this case, StaleDataError). For the services written in other languages, I pulled the same object from the DB twice, updated one instance and saved it, then updated the other instance and tried to save that as well.

However, the SQLAlchemy session keeps an identity map (which behaves a bit like a cache) between the DB and the service. Querying the same row twice therefore returns the same Python object, so when I update the first object, the "other" object I hold in memory changes too. Does anyone have a way around this? I tried creating a second session (that approach worked in other languages), but SQLAlchemy won't let the same object be attached to two sessions.
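To make the identity-map behaviour concrete, here is a minimal sketch, assuming plain SQLAlchemy 1.4+ (no Flask) and a throwaway SQLite file; the Customer model below is a hypothetical stand-in for the question's model. Within one session, both queries hand back the same Python object; two separate sessions each load their own copy of the row, which is what an optimistic-locking test needs.

```python
import os
import tempfile

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Customer(Base):
    # Hypothetical stand-in for the question's Customer model.
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    formal_name = Column(String(50), nullable=False)


with tempfile.TemporaryDirectory() as tmpdir:
    engine = create_engine(f"sqlite:///{os.path.join(tmpdir, 'identity.db')}")
    Base.metadata.create_all(engine)

    # Insert the row and commit it so any session can load it.
    with Session(engine) as setup:
        setup.add(Customer(id=1, formal_name="John"))
        setup.commit()

    # One session: both queries resolve to the same identity-map entry.
    with Session(engine) as session:
        cust = session.query(Customer).filter_by(id=1).first()
        same_cust = session.query(Customer).filter_by(id=1).first()
        one_session_same_object = cust is same_cust

    # Two sessions: each has its own identity map, hence its own copy of the row.
    with Session(engine) as s1, Session(engine) as s2:
        c1 = s1.get(Customer, 1)
        c2 = s2.get(Customer, 1)
        two_sessions_same_object = c1 is c2

    # Release pooled connections before the temporary directory is removed.
    engine.dispose()
```

Only a distinct object in a second session can hold an out-of-date version number, so that is the copy a conflict test has to modify last.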

I was able to test it manually by putting time.sleep() halfway through the test and changing the data in the DB by hand, but I'd like a way to test this purely from the unit test code.
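One way to replace the manual time.sleep() step is to make the "other service" edit from the test itself with raw SQL: a text() UPDATE goes straight to the database and does not touch objects already loaded in the session, so the loaded object keeps its old version number. A minimal sketch, assuming plain SQLAlchemy 1.4+ and an in-memory SQLite database; the model, its version_id column and the helper concurrent_edit_raises_stale_data are hypothetical.

```python
from sqlalchemy import Column, Integer, String, create_engine, text
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm.exc import StaleDataError

Base = declarative_base()


class Customer(Base):
    # Hypothetical model mirroring the question's Customer, plus a version column.
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    formal_name = Column(String(50), nullable=False)
    version_id = Column(Integer, nullable=False)
    # SQLAlchemy increments version_id on each UPDATE and checks it in the WHERE clause.
    __mapper_args__ = {"version_id_col": version_id}


def concurrent_edit_raises_stale_data(session: Session) -> bool:
    """Hypothetical helper: True if a simulated concurrent edit breaks the next flush."""
    cust = session.get(Customer, 1)  # loaded copy with version_id == 1
    # Simulate another service: raw SQL bypasses the ORM and the identity map,
    # so `cust` still believes the row is at version 1.
    session.execute(
        text(
            "UPDATE customer SET formal_name = 'Tim', "
            "version_id = version_id + 1 WHERE id = 1"
        )
    )
    cust.formal_name = "Jon"
    try:
        # Emits UPDATE ... WHERE id = 1 AND version_id = 1, which now matches 0 rows.
        session.flush()
    except StaleDataError:
        session.rollback()
        return True
    return False


engine = create_engine("sqlite://")  # in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Customer(id=1, formal_name="John"))
    session.commit()
    raised = concurrent_edit_raises_stale_data(session)
```

With flask_sqlalchemy, the same idea would go through db.session.execute(text(...)) before modifying the already-loaded object.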

Example code:

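from sqlalchemy.orm.exc import StaleDataError

# `db` (the flask_sqlalchemy instance) and `Customer` come from the app under test.
def test_optimistic_locking(self):
    c = Customer(formal_name='John', id=1)
    db.session.add(c)
    db.session.flush()
    cust = Customer.query.filter_by(id=1).first()
    db.session.expire(cust)
    # Same session, same primary key: the identity map returns the same
    # Python object, so `same_cust is cust` is True.
    same_cust = Customer.query.filter_by(id=1).first()
    db.session.expire(same_cust)
    same_cust.formal_name = 'Tim'
    db.session.add(same_cust)
    db.session.flush()
    db.session.expire(same_cust)
    cust.formal_name = 'Jon'
    db.session.add(cust)
    # Hoped for StaleDataError here, but cust and same_cust are one object,
    # so there is no stale in-memory copy left to conflict with.
    with self.assertRaises(StaleDataError):
        db.session.flush()
    db.session.rollback()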

Solution

  • It is actually possible: you need to create two separate sessions. See SQLAlchemy's own unit tests for inspiration. Here's a snippet from one of our unit tests, written with pytest:

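    import pytest
    from sqlalchemy.orm import Session
    from sqlalchemy.orm.exc import StaleDataError

    # ProductSheet, ProductSheetFactory and the `connection` / `db_session`
    # fixtures are defined elsewhere in our test suite.
    def test_article__versioning(connection, db_session: Session):
        article = ProductSheetFactory(title="Old Title", version=1)
        db_session.refresh(article)
        assert article.version == 1

        # A second session has its own identity map, so article2 is a separate
        # Python object from article, even though both map to the same row.
        db_session2 = Session(bind=connection)
        # On SQLAlchemy 1.4+ this can be written db_session2.get(ProductSheet, article.id).
        article2 = db_session2.query(ProductSheet).get(article.id)
        assert article2.version == 1

        # The first session saves its change and bumps the version to 2.
        article.title = "New Title"
        article.version += 1
        db_session.commit()
        assert article.version == 2

        # article2 still carries version 1, so its UPDATE ... WHERE version = 1
        # matches no rows and the flush inside commit() raises StaleDataError.
        article2.title = "Yet another title"
        assert article2.version == 1
        article2.version += 1
        with pytest.raises(StaleDataError):
            db_session2.commit()
        db_session2.close()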

    Hope that helps. Note that we set "version_id_generator": False in the model's __mapper_args__, which is why we increment the version ourselves. See the SQLAlchemy docs on configuring a version counter for details.
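For reference, here is a self-contained sketch of the same two-session pattern, assuming plain SQLAlchemy 1.4+ and a temporary SQLite file in place of the answer's pytest fixtures and factory. ProductSheet below is a hypothetical minimal model configured with version_id_generator=False, so the code bumps the version by hand, just as the answer does.

```python
import os
import tempfile

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.orm.exc import StaleDataError

Base = declarative_base()


class ProductSheet(Base):
    # Hypothetical stand-in for the answer's ProductSheet model.
    __tablename__ = "product_sheet"
    id = Column(Integer, primary_key=True)
    title = Column(String(100), nullable=False)
    version = Column(Integer, nullable=False)
    # version_id_generator=False: SQLAlchemy still checks `version` in the
    # UPDATE's WHERE clause, but the application must set the new value itself.
    __mapper_args__ = {"version_id_col": version, "version_id_generator": False}


with tempfile.TemporaryDirectory() as tmpdir:
    engine = create_engine(f"sqlite:///{os.path.join(tmpdir, 'versioning.db')}")
    Base.metadata.create_all(engine)

    with Session(engine) as setup:
        setup.add(ProductSheet(id=1, title="Old Title", version=1))
        setup.commit()

    # Two independent sessions, each with its own identity map and connection.
    session1, session2 = Session(engine), Session(engine)
    article = session1.get(ProductSheet, 1)
    article2 = session2.get(ProductSheet, 1)

    article.title = "New Title"
    article.version += 1  # manual bump because version_id_generator is False
    session1.commit()

    article2.title = "Yet another title"
    article2.version += 1  # computed from the stale version 1
    try:
        session2.commit()  # UPDATE ... WHERE version = 1 matches no rows
        stale = False
    except StaleDataError:
        session2.rollback()
        stale = True
    finally:
        session1.close()
        session2.close()
        engine.dispose()
```

Using a file rather than an in-memory database keeps the two sessions on genuinely separate connections, which mirrors two services talking to the same DB.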