I have a peculiar scenario, similar to the one in the post below.
The difference is that my database module also has global variables that need to be mocked out.
"""Initialzation"""
from database.db_init import DB
from common.common_util import CommonUtil as util
# These lookups run in module scope, i.e. once, when this module is first
# imported -- not each time create_session() is called.
db_username = util.get_value_from_ssm_parameter_store('/_username')
db_password = util.get_value_from_ssm_parameter_store('/rds/_password')
db_host = DB.get_rds_host()
def create_session():
    """Create and return a new database session.

    Builds a ``DB`` from the module-level ``db_username``, ``db_password``
    and ``db_host`` values (database name ``'test'``), obtains the session
    factory from ``DB.getSession()`` and returns the result of calling it.
    """
    session_factory = DB(
        user=db_username,
        password=db_password,
        host=db_host,
        database='test',
    ).getSession()
    return session_factory()
I've mocked the get_value_from_ssm_parameter_store() function and even mocked the variables with @patch('database.db_username', 'test'); I also tried database.db_username = MagicMock(return_value='test') inside my unittest class,
but the call to AWS is still happening. Can anybody help me with mocking the global variables in an imported module?
Since the util.get_value_from_ssm_parameter_store() and DB.get_rds_host() methods are executed in the module scope of session.py (that is, at import time),
you should patch these methods before importing the create_session function from the session module.
E.g.
common/common_util.py
:
class CommonUtil:
    """Example stand-in for the project's utility class."""

    @staticmethod
    def get_value_from_ssm_parameter_store(key):
        """Placeholder for the real lookup; only prints a marker line.

        ``key`` is accepted but not used, and nothing is returned.
        """
        print('call real aws')
database/db_init.py
:
class Session:
    """Empty placeholder for the session type returned by ``DB.getSession``."""
class DB:
    """Minimal example database wrapper used to demonstrate the mocking."""

    @staticmethod
    def get_rds_host():
        """Return the hard-coded example host address."""
        return '127.0.0.1'

    def __init__(self, user, password, host, database) -> None:
        """Accept the connection settings; this example stores nothing."""

    def getSession(self):
        """Return the ``Session`` class itself (not an instance)."""
        return Session
session.py
:
from database.db_init import DB
from common.common_util import CommonUtil as util
# Module-scope lookups: executed once, at the moment session.py is imported.
# Anything that should replace these calls has to be in place before then.
db_username = util.get_value_from_ssm_parameter_store('/_username')
db_password = util.get_value_from_ssm_parameter_store('/rds/_password')
db_host = DB.get_rds_host()
def create_session():
    """Return a fresh session built from the module-level DB settings.

    The ``DB`` object is constructed with keyword arguments taken from
    ``db_username``, ``db_password`` and ``db_host`` plus the fixed
    database name ``'test'``; its ``getSession()`` result is then called
    to produce the session that is returned.
    """
    connection_settings = {
        'user': db_username,
        'password': db_password,
        'host': db_host,
        'database': 'test',
    }
    make_session = DB(**connection_settings).getSession()
    return make_session()
test_session.py
:
import unittest
from unittest.mock import patch, Mock, call
from common.common_util import CommonUtil as util
from database.db_init import DB
def get_value_from_ssm_parameter_store_side_effect(key):
    """Mock side effect: return a canned value for each known parameter key.

    Keys other than the two listed below produce ``None``.
    """
    canned_values = (
        ('/_username', 'teresa teng'),
        ('/rds/_password', '123456'),
    )
    for known_key, canned in canned_values:
        if key == known_key:
            return canned
    return None
# Install the mocks at module level, before ``session`` is imported anywhere,
# because session.py calls both of these in its own module scope.
# NOTE(review): in the example classes both targets are @staticmethod, so the
# attribute access below saves the plain underlying functions.
original_get_value_from_ssm_parameter_store = util.get_value_from_ssm_parameter_store
util.get_value_from_ssm_parameter_store = Mock(
    side_effect=get_value_from_ssm_parameter_store_side_effect)

original_get_rds_host = DB.get_rds_host
DB.get_rds_host = Mock(return_value='192.168.1.1')
class TestSession(unittest.TestCase):
    """Tests for ``session.create_session`` with the AWS/RDS lookups mocked."""

    @patch('session.DB', autospec=True)
    def test_create_session(self, mock_DB):
        """create_session builds ``DB`` from the mocked module-level values.

        ``mock_DB`` is the autospecced replacement for ``session.DB``
        injected by the ``patch`` decorator.
        """
        # Register the restoration first: cleanups run even when an
        # assertion below fails, so the module-level mocks never leak into
        # other tests (previously the restore lines were skipped on failure).
        self.addCleanup(
            setattr, util, 'get_value_from_ssm_parameter_store',
            original_get_value_from_ssm_parameter_store)
        self.addCleanup(setattr, DB, 'get_rds_host', original_get_rds_host)

        # Imported here rather than at the top of the file: session.py runs
        # its lookups at import time, so the mocks must already be installed.
        from session import create_session

        db_instance = mock_DB.return_value
        mock_session_factory = Mock()
        db_instance.getSession.return_value = mock_session_factory

        session = create_session()

        mock_DB.assert_called_once_with(
            'teresa teng', '123456', '192.168.1.1', 'test')
        util.get_value_from_ssm_parameter_store.assert_has_calls(
            [call('/_username'), call('/rds/_password')])
        db_instance.getSession.assert_called_once()
        # The session handed back is whatever calling the factory produced.
        mock_session_factory.assert_called_once_with()
        self.assertIs(session, mock_session_factory.return_value)
# Allow running this test file directly as a script.
if __name__ == '__main__':
    unittest.main()
unit test result:
⚡ coverage run /Users/dulin/workspace/github.com/mrdulin/python-codelab/src/stackoverflow/64329623/test_session.py && coverage report -m --include='./src/**'
.
----------------------------------------------------------------------
Ran 1 test in 0.011s
OK
Name Stmts Miss Cover Missing
--------------------------------------------------------------------------------
src/stackoverflow/64329623/common/common_util.py 4 1 75% 4
src/stackoverflow/64329623/database/db_init.py 10 3 70% 8, 11, 14
src/stackoverflow/64329623/session.py 10 0 100%
src/stackoverflow/64329623/test_session.py 28 0 100%
--------------------------------------------------------------------------------
TOTAL 52 4 92%