
Mocking out two redis hgets with different return values in the same python function


I have some code like this:

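import json

import redis
from flask import Response

# redis_host_ip, redis_port, redis_auth_password, user, something_a and
# something_b are defined elsewhere (omitted here)
redis_db = redis.Redis(host=redis_host_ip, port=redis_port, password=redis_auth_password)

def mygroovyfunction():
    # Two separate HGET calls on the same hash
    var_a = redis_db.hget(user, 'something_a')
    var_b = redis_db.hget(user, 'something_b')
    if var_a == something_a:
        return Response(json.dumps({}), status=200, mimetype='application/json')
    if var_b == something_b:
        return Response(json.dumps({}), status=400, mimetype='application/json')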

Then, in a tests.py file that unit tests this, I have code like this:

import unittest
from mock import patch  # on Python 3.3+: from unittest.mock import patch

@patch('redis.StrictRedis.hget', return_value='some value')
class MyGroovyFunctionTests(unittest.TestCase):
    def test_success(self, mock_redis_hget):
        response = self.app.get('mygroovyfunction')
        self.assertEqual(response.status_code, 200)

I've left out some other Flask code because it isn't relevant to this question.

What I want to know is whether it's possible to mock a separate return value for each individual redis hget call. With my current code, the same mocked return value is returned for every hget call, so var_a and var_b both end up with the same value, and the test goes down the path that returns a status code of 400.
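To see why the patched test cannot distinguish the two calls, here is a minimal standalone sketch using the standard library's unittest.mock. The `db` mock and the field names are hypothetical stand-ins for redis_db and its hash fields: a mock configured with return_value hands back that same value on every call, whatever the arguments.

```python
from unittest import mock

# Hypothetical stand-in for redis_db; only .hget is used here.
db = mock.Mock()
db.hget.return_value = 'some value'

var_a = db.hget('user', 'something_a')
var_b = db.hget('user', 'something_b')

# return_value is returned for every call, regardless of the arguments passed.
print(var_a, var_b)        # some value some value
print(db.hget.call_count)  # 2
```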

What is the proper way to do this kind of thing?


Solution

  • OK, I found the answer here: Python mock multiple return values

    The accepted answer to that question is what I was looking for: set side_effect to a list of values, so that each successive call to the patched redis hget returns the next value in the list. For my example, the solution is:

    import unittest
    from mock import patch  # on Python 3.3+: from unittest.mock import patch
    
    # The first hget call returns 'something_a', the second returns 'something_b'
    @patch('redis.StrictRedis.hget', side_effect=['something_a', 'something_b'])
    class MyGroovyFunctionTests(unittest.TestCase):
        def test_success(self, mock_redis_hget):
            response = self.app.get('mygroovyfunction')
            self.assertEqual(response.status_code, 200)
    

    Thanks to https://stackoverflow.com/users/100297/martijn-pieters for the answer I was looking for :)
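The side_effect approach can also be checked without Flask or a Redis server. The sketch below rests on assumptions: FakeRedis, groovy_status and the 404 fallback are hypothetical, and the Flask Response is replaced by a bare status code. It patches hget with a list side_effect, then shows a second option, a callable side_effect that picks the value by field name, so the test no longer depends on the order of the hget calls.

```python
from unittest import mock


class FakeRedis:
    """Hypothetical stand-in for redis.Redis; hget is always patched below."""

    def hget(self, name, key):
        raise RuntimeError("no Redis server in this sketch")


redis_db = FakeRedis()


def groovy_status(user):
    # Mirrors mygroovyfunction, but returns a bare status code instead of a Response.
    var_a = redis_db.hget(user, 'something_a')
    var_b = redis_db.hget(user, 'something_b')
    if var_a == 'something_a':
        return 200
    if var_b == 'something_b':
        return 400
    return 404  # hypothetical fallback so every path returns something


# Option 1: a list side_effect; each call consumes the next item, in call order.
with mock.patch.object(FakeRedis, 'hget',
                       side_effect=['something_a', 'something_b']) as m:
    status_list = groovy_status('alice')
    calls = m.call_args_list

# Option 2: a callable side_effect; the mock calls it with the same arguments
# and returns its result, so the value depends on the field, not the call order.
values = {'something_a': 'wrong', 'something_b': 'something_b'}


def fake_hget(name, key):
    return values.get(key)


with mock.patch.object(FakeRedis, 'hget', side_effect=fake_hget):
    status_dict = groovy_status('alice')

print(status_list)  # 200
print(status_dict)  # 400
```

With a list, the mock raises StopIteration once the list is used up, so a list that is too short shows up as an error rather than a silent wrong value. The callable form is a little more code but keeps working if the two hget calls are ever reordered.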