Tags: python, flask, pytest

How can I mock the return value of a method on a class that is used as a context manager?


I have a database module, `database.py`, containing a `DatabaseClient` class like the following (my route imports it as `Database` from `app.utils.Database`):

from someDatabase import theClient

class DatabaseClient:
    """Context manager that opens a client connection on entry and closes it on exit.

    Intended use: ``with DatabaseClient() as db: db.database_query(table, query)``.

    NOTE(review): the Flask route in this question imports ``Database`` from
    ``app.utils.Database``, while this class is ``DatabaseClient`` in
    ``database.py``; confirm the class and module names actually match.
    """

    def __init__(self) -> None:
        # No connection exists until the ``with`` block is entered.
        self.connection = None

    def __enter__(self) -> "DatabaseClient":
        """Open the connection and return this wrapper (bound by ``as``)."""
        self.connection = self.database_connection()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection if one was opened.

        Returns None, so an exception raised inside the ``with`` block is not
        suppressed.
        """
        if self.connection:
            self.connection.close()

    def database_connection(self):
        """Create and return a new ``theClient`` instance.

        ``<connection params>`` is a placeholder in this snippet, not valid
        Python; substitute the real connection arguments.
        """
        client = theClient(<connection params>)
        return client

    def database_query(self, table: str, query: str):
        """Run ``query`` against ``table`` and return the client's raw response.

        Only usable inside a ``with`` block: outside it ``self.connection`` is
        None and the ``.search`` call raises AttributeError.
        """
        response = self.connection.search(
            body = query,
            table = table
        )
        return response

Then, in my Flask app, I have a route:

from app.utils.Database import Database
from app.queries import thequery
from flask import Blueprint, jsonify

api = Blueprint('the_api', __name__)


@api.route("/api/this/route", methods=["GET"])
def get_some_stuff(**kwargs):
    """Run ``thequery`` against ``some_table`` and return the result as JSON.

    ``kwargs`` is accepted but not used by the visible body.

    Returns:
        A ``(Response, 200)`` tuple built with ``jsonify``.

    Any exception raised while querying propagates to the caller unchanged;
    there is intentionally no catch-and-re-raise wrapper here.
    """
    # get input params from request

    with Database() as db:
        response = db.database_query("some_table", thequery)

    # transform the response object

    return jsonify(response), 200

What I'm trying to figure out is how to patch `DatabaseClient` — or, more specifically, how to patch the `database_query()` method so that it returns some sample data I have.

I'd like to set the method's return value and then call the route in a test like the following. If possible, I'd like to use only pytest rather than mixing in `unittest`.

import pytest

from app import app

from tests.sample_data.api_responses.get_some_stuff import (
    get_some_stuff_api_response
)

from tests.sample_data.query_responses.get_some_stuff import (
    get_some_stuff_query_response
)


@pytest.fixture
def client():
    """Provide a Flask test client for ``app``.

    The client's context is entered before the test runs and exited after it
    finishes.
    """
    test_client = app.test_client()
    with test_client:
        yield test_client


def test_get_some_stuff(client):
    """GET /api/this/route should return 200 and the expected JSON body."""
    # What is the correct patch here to set the return value of Database.database_query?
    # patch so it's equal to `get_some_stuff_query_response`

    # Call the route (a plain string: the URL has no interpolated parts).
    response = client.get("/api/this/route")

    assert response.status_code == 200
    data = response.json
    assert data == get_some_stuff_api_response

Solution

  • You need to create a mock context manager that returns what you want.

    Here's a simple example that uses `monkeypatch` to do the same for the `open()` file context manager:

    import builtins
    
    
    def unit_under_test():
        """Read and return the entire contents of ``file.txt``."""
        with open("file.txt", "r") as handle:
            contents = handle.read()
        return contents
    
    
    class MockOpen:
        """Stand-in for ``open()`` whose file object always reads "Hello, World!"."""

        def __init__(self, *args, **kwargs):
            """Accept and ignore whatever arguments ``open()`` would receive."""

        def __enter__(self):
            """Serve as the file object bound by ``with ... as``."""
            return self

        def __exit__(self, *args):
            """Nothing to clean up; returning None lets exceptions propagate."""

        def read(self):
            """Return the canned file contents."""
            greeting = "Hello, World!"
            return greeting
    
    
    def test_unit_under_test(monkeypatch):
        """unit_under_test() returns whatever the patched open() reads."""
        # Swap the built-in so the ``with open(...)`` inside unit_under_test
        # builds a MockOpen instead of touching the filesystem; monkeypatch
        # restores the real open() after the test.
        monkeypatch.setattr(builtins, "open", MockOpen)

        result = unit_under_test()
        assert result == "Hello, World!"
    

    `MockOpen` implements the methods a context manager needs (`__enter__` and `__exit__`).

    Translated to your example, this would be something like:

    import app.utils.Database
    ...
    
    class MockDatabase:
        """Stand-in for ``Database`` usable as ``with MockDatabase() as db``.

        ``database_query`` returns the class attribute ``query_response``, so a
        test can choose the canned result without redefining the class, e.g.
        ``monkeypatch.setattr(MockDatabase, "query_response", sample_data)``.
        """

        # Value handed back by every ``database_query`` call.
        query_response = "Whatever the response needs to be"

        def __init__(self, *args, **kwargs):
            # Accept and ignore whatever the real constructor would receive.
            pass

        def __enter__(self):
            # Like the real client, the object itself is what ``as`` binds.
            return self

        def __exit__(self, *args):
            # Nothing to close; returning None lets exceptions propagate.
            pass

        def database_query(self, table, query):
            # The arguments are ignored; return the configured canned response.
            return self.query_response
    
    
    def test_unit_under_test(monkeypatch, client):
        """GET /api/this/route returns the expected JSON when the DB is mocked.

        ``client`` is the Flask test-client fixture and
        ``get_some_stuff_query_response`` / ``get_some_stuff_api_response`` are
        the sample data, all from the question's test module.
        """

        class SampleDatabase(MockDatabase):
            # Return the question's sample query data instead of the placeholder.
            def database_query(self, table, query):
                return get_some_stuff_query_response

        # The route module ran ``from app.utils.Database import Database``, which
        # bound its *own* global name ``Database``. Replacing the attribute on
        # ``app.utils.Database`` would leave that name untouched, so patch it where
        # the view looks it up: the view function's module globals. Going through
        # the Flask app also avoids ``app.utils...`` attribute access, which breaks
        # when the test module's ``from app import app`` rebinds ``app``.
        # NOTE(review): assumes the blueprint is registered under its own name,
        # giving endpoint "the_api.get_some_stuff" — confirm.
        view = client.application.view_functions["the_api.get_some_stuff"]
        monkeypatch.setitem(view.__globals__, "Database", SampleDatabase)

        response = client.get("/api/this/route")

        assert response.status_code == 200
        data = response.json
        assert data == get_some_stuff_api_response