Search code examples
pythonpython-3.xunit-testingpython-unittest

How to mock a global variable in an imported module in a Python unit test?


I have a peculiar scenario, similar to the one in the post below.

Post

But the difference is my database module also has a global variable that needs to be mocked out!

    """Initialzation"""
from database.db_init import DB
from common.common_util import CommonUtil as util

# NOTE: these three lookups are module-level statements, so they execute as
# soon as this module is imported -- before any function in it is called.
db_username = util.get_value_from_ssm_parameter_store(
    '/_username')
db_password = util.get_value_from_ssm_parameter_store(
    '/rds/_password')
db_host = DB.get_rds_host()


def create_session():
    """Open a new database session.

    Builds a ``DB`` object for the ``'test'`` database from the module-level
    ``db_username``, ``db_password`` and ``db_host`` values, asks it for its
    session factory and returns the object that factory produces.
    """
    session_factory = DB(
        user=db_username,
        password=db_password,
        host=db_host,
        database='test',
    ).getSession()
    return session_factory()

I've mocked the get_value_from_ssm_parameter_store() function, and I've also tried mocking the variables with @patch('database.db_username', 'test') and with database.db_username = MagicMock(return_value='test') inside my unittest class.

But the call to AWS still happens. Can anybody help me mock the global variables in an imported module?


Solution

  • The util.get_value_from_ssm_parameter_store() and DB.get_rds_host() methods are executed at module scope in session.py, i.e. as soon as that module is imported.

    You should therefore patch these methods before importing the create_session function from the session module.

    E.g.

    common/common_util.py:

    class CommonUtil:
        """Example utility class exposing the parameter-store lookup."""

        @staticmethod
        def get_value_from_ssm_parameter_store(key):
            """Print a marker showing the un-mocked lookup ran; return ``None``.

            ``key`` is accepted but not used by this example implementation.
            """
            print('call real aws')
            return None
    

    database/db_init.py:

    class Session:
        """Empty placeholder class used as the session type in this example."""
    
    
    class DB:
        """Minimal example database wrapper."""

        def __init__(self, user, password, host, database) -> None:
            """Accept the connection settings; this example does not store them."""
            return None

        def getSession(self):
            """Return the ``Session`` class itself (not an instance of it)."""
            return Session

        @staticmethod
        def get_rds_host():
            """Return the hard-coded example host address."""
            return '127.0.0.1'
    

    session.py:

    from database.db_init import DB
    from common.common_util import CommonUtil as util
    
    
    # NOTE: these three lookups are module-level statements, so they execute as
    # soon as this module is imported -- before create_session() is ever called.
    db_username = util.get_value_from_ssm_parameter_store(
        '/_username')
    db_password = util.get_value_from_ssm_parameter_store(
        '/rds/_password')
    db_host = DB.get_rds_host()
    
    
    def create_session():
        """Open a new database session.

        Builds a ``DB`` object for the ``'test'`` database from the module-level
        ``db_username``, ``db_password`` and ``db_host`` values, asks it for its
        session factory and returns the object that factory produces.
        """
        session_factory = DB(
            user=db_username,
            password=db_password,
            host=db_host,
            database='test',
        ).getSession()
        return session_factory()
    

    test_session.py:

    import unittest
    from unittest.mock import patch, Mock, call
    from common.common_util import CommonUtil as util
    from database.db_init import DB
    
    
    def get_value_from_ssm_parameter_store_side_effect(key):
        """Return the canned value for a known parameter key, else ``None``.

        Used as the ``side_effect`` of the mocked parameter-store lookup so each
        key the code under test asks for yields a distinct, recognisable value.
        """
        canned_values = (
            ('/_username', 'teresa teng'),
            ('/rds/_password', '123456'),
        )
        for known_key, value in canned_values:
            if key == known_key:
                return value
        return None
    
    
    class TestSession(unittest.TestCase):
        """Tests for ``session.create_session`` with the import-time lookups mocked.

        ``session.py`` calls ``util.get_value_from_ssm_parameter_store()`` and
        ``DB.get_rds_host()`` at module scope, i.e. when it is imported. Both are
        therefore patched in ``setUp``, which runs before the test method and its
        ``@patch('session.DB')`` decorator trigger the import of ``session``.

        The patches are undone through ``addCleanup``, so they are removed even
        when an assertion fails, and ``patch.object`` puts back the original
        ``staticmethod`` objects rather than plain functions.
        """

        def setUp(self):
            """Start the patches for the import-time lookups and schedule their undo."""
            ssm_patcher = patch.object(
                util, 'get_value_from_ssm_parameter_store',
                side_effect=get_value_from_ssm_parameter_store_side_effect)
            rds_host_patcher = patch.object(
                DB, 'get_rds_host', return_value='192.168.1.1')
            self.mock_get_value = ssm_patcher.start()
            self.addCleanup(ssm_patcher.stop)
            self.mock_get_rds_host = rds_host_patcher.start()
            self.addCleanup(rds_host_patcher.stop)

        @patch('session.DB', autospec=True)
        def test_create_session(self, mock_DB):
            """create_session() builds DB from the mocked values and returns a session."""
            # Imported inside the test so the import never happens before setUp
            # has installed the patches.
            # NOTE(review): the call-count assertions below assume ``session``
            # was not already imported earlier in the same process.
            from session import create_session
            db_instance = mock_DB.return_value
            mock_session_factory = Mock()
            db_instance.getSession.return_value = mock_session_factory

            session = create_session()

            # Keyword form mirrors how create_session() actually calls DB.
            mock_DB.assert_called_once_with(
                user='teresa teng', password='123456',
                host='192.168.1.1', database='test')
            self.mock_get_value.assert_has_calls(
                [call('/_username'), call('/rds/_password')])
            self.mock_get_rds_host.assert_called_once_with()
            db_instance.getSession.assert_called_once()
            # The returned object must be what the session factory produced.
            mock_session_factory.assert_called_once_with()
            self.assertIs(session, mock_session_factory.return_value)
    
    
    # Run the tests when this file is executed directly as a script.
    if __name__ == '__main__':
        unittest.main()
    

    unit test result:

     ⚡  coverage run /Users/dulin/workspace/github.com/mrdulin/python-codelab/src/stackoverflow/64329623/test_session.py && coverage report -m --include='./src/**'
    .
    ----------------------------------------------------------------------
    Ran 1 test in 0.011s
    
    OK
    Name                                               Stmts   Miss  Cover   Missing
    --------------------------------------------------------------------------------
    src/stackoverflow/64329623/common/common_util.py       4      1    75%   4
    src/stackoverflow/64329623/database/db_init.py        10      3    70%   8, 11, 14
    src/stackoverflow/64329623/session.py                 10      0   100%
    src/stackoverflow/64329623/test_session.py            28      0   100%
    --------------------------------------------------------------------------------
    TOTAL                                                 52      4    92%