I am writing a unit test. In my test_model.py file, I have a class TestPlayerAnomalyDetectionModel that creates an instance of PlayerAnomalyDetectionModel and then calls its .fit method using a fixture that returns sample data. The .fit method calls a function get_player_account_status, which sends an API request, so I want to mock this function. I also want mock_get_player_account_status to return a different value every time it is called, so I am using a side_effect, but I am not sure I am doing this correctly.
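My understanding of side_effect is that assigning it a list makes the mock return the next item on each successive call, and raise StopIteration once the list is exhausted; here is a minimal standalone sketch of what I expect to happen (independent of my actual code):

from unittest import mock

fake = mock.Mock(side_effect=['open', 'tosViolation', 'closed'])
assert fake() == 'open'          # first call returns the first item
assert fake() == 'tosViolation'  # second call returns the second item
assert fake() == 'closed'        # third call returns the third item
# a fourth call would raise StopIteration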
If I use the decorator @mock.patch('get_player_labels.get_player_account_status'), the function isn't mocked when it's called from inside PlayerAnomalyDetectionModel, and actual API requests are sent. Based on this answer, I believe I want to use the decorator @mock.patch('model.get_player_account_status') instead, because get_player_account_status is called from inside model.py. However, when I run pytest test_model.py, I can see that get_player_account_status is getting called instead of mock_get_player_account_status.
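For reference, my mental model of why the patch target matters, using a toy pair of modules (the names producer and consumer are made up for illustration): because the consumer does a from-import, it holds its own reference to the function, so the patch has to target the consumer's module, not the module where the function is defined.

# producer.py
def fetch():
    return "real value from the API"

# consumer.py
from producer import fetch   # consumer now holds its own reference to fetch

def run():
    return fetch()

# test_consumer.py
from unittest import mock
import consumer

@mock.patch('consumer.fetch')   # patch the name where it is looked up
def test_run(mock_fetch):
    mock_fetch.return_value = 'stubbed'
    assert consumer.run() == 'stubbed'

# patching 'producer.fetch' instead would leave consumer.fetch pointing
# at the real function, so run() would still hit the API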
I have replaced some of the actual code with comments for the sake of clarity, but if any more information is needed to help answer my question, please let me know!
Here is the structure of the test_model.py file:
import pandas as pd
import pytest
import unittest
from unittest import mock

from model import PlayerAnomalyDetectionModel


@pytest.fixture(scope="class")
def get_sample_train_data():
    ## creates sample_train_data DataFrame
    return sample_train_data


@pytest.mark.usefixtures("get_sample_train_data", "build_training_data")
class TestPlayerAnomalyDetectionModel(unittest.TestCase):

    def setUp(self):
        self.model = PlayerAnomalyDetectionModel()

    @pytest.fixture(autouse=True)
    def build_training_data(self, get_sample_train_data):
        self.sample_train_data = get_sample_train_data
        print("building training data...")
        print(self.sample_train_data)

    @mock.patch('model.get_player_account_status')
    def test_fit(self, mock_get_player_account_status):
        mock_get_player_account_status.side_effect = [
            'open', 'open', 'tosViolation', 'tosViolation', 'tosViolation', 'closed',
            'open', 'open', 'tosViolation', 'tosViolation', 'tosViolation', 'closed',
        ]
        self.model.fit(self.sample_train_data, generate_plots=False)

    def test_predict(self):
        pass

    def test_save_model(self):
        pass

    def test_load_model(self):
        pass
The model.py file looks like the following:
import os
import pickle

import numpy as np
import pandas as pd

from get_player_labels import get_player_account_status

# TimeControl, SAVED_MODELS_FOLDER, BASE_FILE_NAME, and all_flagged_players
# are defined elsewhere in the project (omitted here)


class PlayerAnomalyDetectionModel:
    """
    The PlayerAnomalyDetectionModel class returns a model with methods:
        .fit to tune the model's internal thresholds on training data
        .predict to make predictions on test data
        .load_model to load a predefined model from a pkl file
        .save_model to save the model to a file
    """

    def __init__(self):
        self.is_fitted = False
        self._thresholds = {
            (time_control, 'perf_delta_thresholds'): {
                f"{rating_bin}-{rating_bin+100}": 0.15
                for rating_bin in np.arange(0, 4000, 100)
            }
            for time_control in TimeControl.ALL.value
        }
        self._account_statuses = {}  # store account statuses for each model instance
        self._ACCOUNT_STATUS_SCORE_MAP = {
            "open": 0,
            "tosViolation": 1,
            "closed": 0.75,  # weight closed account as closer to a tosViolation
        }

    def load_model(self, model_file_name: str):
        """
        Loads a model from a file
        """
        pass

    def fit(self, train_data: pd.DataFrame, generate_plots=True):
        if self.is_fitted:
            # issue a warning that the user is retraining the model!
            # give the user the option to combine multiple training data sets
            pass
        else:
            self._set_thresholds(train_data, generate_plots)
            self.is_fitted = True

    def _set_thresholds(self, train_data, generate_plots):
        while True:
            for player in all_flagged_players:
                if self._account_statuses.get(player) is None:
                    get_player_account_status(player, self._account_statuses)
                else:
                    pass
            ## set thresholds

    def predict(self, test_data: pd.DataFrame):
        """Returns pd.DataFrame of size (m+2, k)
        where k = number of flagged games, and m = number of features
        """
        if not self.is_fitted:
            print("Warning: model is not fitted and will use default thresholds")
        # make predictions
        return predictions

    def save_model(
        self,
        saved_models_folder: str = SAVED_MODELS_FOLDER,
        model_name: str = "player_anomaly_detection_model"
    ):
        if not os.path.exists(saved_models_folder):
            os.mkdir(saved_models_folder)
        with open(f'{saved_models_folder}/{BASE_FILE_NAME}_{model_name}.pkl', 'wb') as f:
            pickle.dump(self._thresholds, f)
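For reference, the comprehension in __init__ produces nested dicts keyed like this (assuming TimeControl.ALL.value contains strings such as 'blitz' and 'bullet', which is how I index _thresholds in my workaround further down):

model = PlayerAnomalyDetectionModel()
print(model._thresholds[('blitz', 'perf_delta_thresholds')]['1500-1600'])  # 0.15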
Below is the structure of the function get_player_account_status inside get_player_labels.py:
def get_player_account_status(player, account_statuses):
    try:
        user = makeAPIcall()
        if user.get('tosViolation'):
            account_statuses[player] = "tosViolation"
        elif user.get('disabled'):
            account_statuses[player] = "closed"
        else:
            account_statuses[player] = "open"
    except ApiHttpError:
        account_statuses[player] = "not found"
This is not ideal and doesn't resolve the question I asked, but I came up with a workaround: the get_player_account_status function is used to populate a private variable called _account_statuses in each instance of the model, and this variable is a dictionary containing the status of each player. Since get_player_account_status is only called when a player's status is not already present in _account_statuses, I explicitly set this dictionary inside the test_fit function:
def test_fit(self):
    ## this is a workaround to avoid calling get_player_account_status
    self.model._account_statuses = {
        'test_player1': 'open',
        'test_player2': 'open',
        'test_player3': 'tosViolation',
        'test_player4': 'tosViolation',
        'test_player5': 'tosViolation',
        'test_player6': 'closed'
    }
    expected_thresholds = self.model._thresholds.copy()
    expected_thresholds[('blitz', 'perf_delta_thresholds')]['1500-1600'] = 0.16
    expected_thresholds[('bullet', 'perf_delta_thresholds')]['1600-1700'] = 0.17
    self.model.fit(self.sample_train_data, generate_plots=False)
    assert expected_thresholds == self.model._thresholds
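One more thing I noticed while writing the workaround: get_player_account_status doesn't return the status, it writes it into the account_statuses dict passed to it, and _set_thresholds only reads self._account_statuses. So even with the correct patch target, a list-valued side_effect would only set the mock's return values and would never populate the dict. If that turns out to be part of the problem, a callable side_effect that mimics the real function's mutation might work instead (an untested sketch, with made-up statuses):

from unittest import mock

fake_statuses = iter([
    'open', 'open', 'tosViolation', 'tosViolation', 'tosViolation', 'closed',
])

def fake_get_player_account_status(player, account_statuses):
    # mimic the real function's side effect: write the status into the dict
    account_statuses[player] = next(fake_statuses)

@mock.patch('model.get_player_account_status', side_effect=fake_get_player_account_status)
def test_fit(self, mock_get_player_account_status):
    self.model.fit(self.sample_train_data, generate_plots=False)
    # self.model._account_statuses should now be populated from fake_statuses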