
How to test operations in a context manager using pytest


I have a database handler that uses the SQLAlchemy ORM to communicate with a database. Following SQLAlchemy's recommended practice, I interact with the session by using it as a context manager. How can I test what a function does with the session inside that context manager?

EDIT: I realized the file structure mattered because of the complexity it introduced. I restructured the code below to more closely mirror what the final file structure will be, and what a typical production repo in my environment looks like: code defined in one file and tests in a completely separate file.

For example:

Code File (delete_things_from_table.py):

from db_handler import delete, SomeTable


def delete_stuff(handler):
    stmt = delete(SomeTable)
    with handler.Session.begin() as session:
        session.execute(stmt)
        session.commit()

Test File:

import pytest
import delete_things_from_table as dlt
from db_handler import Handler

def test_delete_stuff():
    handler = Handler()
    dlt.delete_stuff(handler)

    # Test that session.execute was called
    # Test the value of 'stmt'
    # Test that session.commit was called

I am not looking for a solution specific to SQLAlchemy; I am only utilizing this to highlight what I want to test within a context manager, and any strategies for testing context managers are welcome.
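One possible strategy, if a MagicMock is acceptable, is to let it stand in for the whole handler, since MagicMock already supports __enter__ and __exit__. The sketch below is an assumption-laden illustration, not the code from this question: delete_stuff is a stand-in that executes a plain string instead of delete(SomeTable) so it runs without SQLAlchemy, and the MagicMock takes the place of db_handler.Handler. The key detail is that the object bound by "as session" is handler.Session.begin.return_value.__enter__.return_value.

```python
from unittest.mock import MagicMock


# Stand-in for delete_things_from_table.delete_stuff. The real function builds
# stmt = delete(SomeTable); a plain string is used here (assumption) so the
# sketch runs without SQLAlchemy or a database.
def delete_stuff(handler):
    stmt = "DELETE FROM some_table"
    with handler.Session.begin() as session:
        session.execute(stmt)
        session.commit()


def test_delete_stuff():
    # MagicMock supports the context-manager protocol out of the box.
    handler = MagicMock()
    context = handler.Session.begin.return_value
    # The object bound by "with ... as session" is whatever __enter__ returns.
    session = context.__enter__.return_value

    delete_stuff(handler)

    # Test that session.execute was called, and with which statement
    session.execute.assert_called_once()
    stmt = session.execute.call_args.args[0]
    assert str(stmt) == "DELETE FROM some_table"
    # Test that session.commit was called
    session.commit.assert_called_once_with()
    # The context manager was entered and exited exactly once
    context.__enter__.assert_called_once()
    context.__exit__.assert_called_once()
```

In the real test file the same MagicMock would be passed to dlt.delete_stuff; the str(stmt) comparison is the same check the solution below relies on for a SQLAlchemy statement.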


Solution

  • After sleeping on it, I came up with a solution. I'd love additional/less complex solutions if there are any available, but this works:

    import pytest
    import delete_things_from_table as dlt
    from db_handler import Handler
    
    class MockSession:
        def __init__(self):
            self.execute_params = []
            self.commit_called = False
    
        def execute(self, *args, **kwargs):
            self.execute_params.append(["call", args, kwargs])
            return self
    
        def commit(self):
            self.commit_called = True
            return self
        
        def begin(self):
            return self
    
        def __enter__(self):
            return self
        def __exit__(self, type, value, traceback):
            pass
    
    
    def test_delete_stuff(monkeypatch):
        handler = Handler()
        # The parentheses in 'MockSession()' below are important: pass an instance, not the class
        monkeypatch.setattr(handler, "Session", MockSession())
        dlt.delete_stuff(handler)
    
        # Test that session.execute was called (exactly once)
        assert len(handler.Session.execute_params) == 1
        # Test the value of 'stmt'
        assert str(handler.Session.execute_params[0][1][0]) == "DELETE FROM some_table"
        # Test that session.commit was called
        assert handler.Session.commit_called
    

    Some key things to note:

    1. I wrote a custom mock class instead of using a MagicMock, since it's easier to control the methods and data flow with a class of my own.
    2. Since the code under test calls handler.Session.begin() to start the context, my mock class needed a begin() method. Returning self from begin() means the object that records the calls is the same one the test inspects later.
    3. Context managers rely on the magic methods __enter__ and __exit__, with the argument signatures you see above. Whatever __enter__ returns is what gets bound to the name after 'as' (here, session).
    4. The mock class contains mocked methods that record their calls in instance variables, which the test can assert on afterwards.
    5. This relies on monkeypatch (I'm sure there are other ways), but the important point is that you want to patch in an instance of the mock class, not the class itself. With the class, handler.Session.begin() would be called without an instance and fail, and there would be no single object holding the recorded calls. The parentheses make a world of difference.
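Notes 3 and 5 can be checked directly in a few lines. This is a minimal sketch using a trimmed copy of MockSession with an added events list (an illustration-only addition, not part of the solution above) to show what a with statement calls, what __exit__ receives, and why patching in the class instead of an instance breaks.

```python
class MockSession:
    """Cut-down version of the MockSession in the solution: only begin() and
    the context-manager methods, with an event log added for illustration."""

    def __init__(self):
        self.events = []

    def begin(self):
        return self

    def __enter__(self):
        self.events.append("enter")
        # Whatever __enter__ returns is bound to the name after "as".
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # On a normal exit all three arguments are None; on an exception they
        # describe it. Returning False lets the exception propagate.
        self.events.append(("exit", exc_type))
        return False


# 1. Patching in an instance: begin() is a bound method and works.
mock = MockSession()
with mock.begin() as session:
    assert session is mock
print(mock.events)  # ['enter', ('exit', None)]

# 2. An exception inside the block is passed to __exit__, then re-raised.
failing = MockSession()
try:
    with failing.begin():
        raise ValueError("boom")
except ValueError:
    pass
print(failing.events)  # ['enter', ('exit', <class 'ValueError'>)]

# 3. Patching in the class: begin() is called without an instance, so the
#    with statement fails with a TypeError before __enter__ is ever reached.
class_error = None
try:
    with MockSession.begin():
        pass
except TypeError as exc:
    class_error = exc
print(type(class_error).__name__)  # TypeError
```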

    I don't think it's an elegant solution, but it's working. I'll happily take any suggestions for improvement.
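One possible refinement, sketched under assumptions: instead of separate execute_params and commit_called attributes, a mock can append every interaction, including __enter__ and __exit__, to a single ordered list. That directly tests the question's point, that execute and commit happen inside the context manager. RecordingSession is a hypothetical name, delete_stuff is a stand-in using a plain string instead of delete(SomeTable), and SimpleNamespace stands in for Handler.

```python
from types import SimpleNamespace


class RecordingSession:
    """Hypothetical variant of MockSession: every interaction goes into one
    ordered list, so a single assertion can check what was called and that
    it happened between __enter__ and __exit__."""

    def __init__(self):
        self.calls = []

    def begin(self):
        self.calls.append("begin")
        return self

    def __enter__(self):
        self.calls.append("enter")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.calls.append("exit")
        return False  # never swallow exceptions from the code under test

    def execute(self, stmt, *args, **kwargs):
        # str() so a SQLAlchemy statement would be recorded as its SQL text
        self.calls.append(("execute", str(stmt)))
        return self

    def commit(self):
        self.calls.append("commit")


# Stand-in for delete_things_from_table.delete_stuff; a plain string replaces
# delete(SomeTable) so the sketch runs without SQLAlchemy.
def delete_stuff(handler):
    stmt = "DELETE FROM some_table"
    with handler.Session.begin() as session:
        session.execute(stmt)
        session.commit()


def test_delete_stuff_inside_context():
    # SimpleNamespace stands in for Handler; with a real Handler you would use
    # monkeypatch.setattr(handler, "Session", RecordingSession()) instead.
    handler = SimpleNamespace(Session=RecordingSession())
    delete_stuff(handler)
    assert handler.Session.calls == [
        "begin",
        "enter",
        ("execute", "DELETE FROM some_table"),
        "commit",
        "exit",
    ]
```

A single ordered list also catches a commit that happens outside the with block or in the wrong order, and when the assertion fails pytest prints a diff of the two lists.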