Search code examples

How do I mock a function called by an instance of a class

I am writing a unit test, and in my test file I have a class TestPlayerAnomalyDetectionModel that creates an instance of PlayerAnomalyDetectionModel and then calls the .fit method using a fixture that returns sample data. The .fit method calls a function get_player_account_status, which sends an API request, so I want to mock this function. I also want mock_get_player_account_status to return a different value every time it is called, so I am using a side_effect, but I am not sure I am doing this correctly.

If I use the decorator @mock.patch('get_player_labels.get_player_account_status'), the function isn't mocked when it's called from inside PlayerAnomalyDetectionModel, and actual API requests are sent.

Based on this answer, I believe I want to use the decorator @mock.patch('model.get_player_account_status') because get_player_account_status is called from inside the model module. However, when I run pytest, I can see that get_player_account_status is getting called instead of mock_get_player_account_status.

I have replaced some of the actual code with comments for the sake of clarity, but if any more information is needed to help answer my question, please let me know!

Here is the structure of the test file:

import pandas as pd
import pytest
import unittest
import mock
from unittest.mock import Mock, MagicMock, patch
from model import PlayerAnomalyDetectionModel

# NOTE(review): this name is listed in @pytest.mark.usefixtures(...) and used as a
# fixture argument of build_training_data, but it is not decorated with
# @pytest.fixture here, so pytest will not register it as a fixture.
def get_sample_train_data():
    """Build and return the sample training DataFrame used by the tests."""
    ## creates sample_train_data DataFrame

    return sample_train_data

# Unit tests for PlayerAnomalyDetectionModel (excerpt: some lines are elided or
# garbled, so this class is not runnable exactly as shown).
# NOTE(review): "build_training_data" is a plain method of this class, not a
# @pytest.fixture, so pytest will not find a fixture by that name. Also, pytest
# does not pass fixture values as arguments to unittest.TestCase methods, so
# build_training_data(self, get_sample_train_data) cannot receive the fixture
# value that way.
@pytest.mark.usefixtures("get_sample_train_data", "build_training_data")
class TestPlayerAnomalyDetectionModel(unittest.TestCase):
    def setUp(self):
        # Create a fresh, unfitted model before every test method.
        self.model = PlayerAnomalyDetectionModel()

    def build_training_data(self, get_sample_train_data):
        # Store the sample training data on the test instance for later use.
        self.sample_train_data = get_sample_train_data
        print("building training data...")

    # NOTE(review): no @mock.patch(...) decorator is applied here, so nothing
    # injects the mock_get_player_account_status argument. The model module binds
    # the function with `from get_player_labels import get_player_account_status`
    # and calls that name unqualified, so the patch target is
    # "model.get_player_account_status" (the namespace where it is looked up).
    def test_fit(self, mock_get_player_account_status):
        # NOTE(review): the real get_player_account_status records its result by
        # writing account_statuses[player]; the model ignores its return value.
        # A side_effect *list* only supplies return values, so it never populates
        # model._account_statuses. A side_effect *function* that assigns into its
        # second argument would mimic the real function's behaviour.
        mock_get_player_account_status.side_effect = [
            'open', 'open', 'tosViolation', 'tosViolation', 'tosViolation', 'closed', 
            'open', 'open', 'tosViolation', 'tosViolation', 'tosViolation', 'closed', 
        # NOTE(review): the next line looks like the tail of a
        # self.model.fit(..., generate_plots=False) call merged into the list
        # literal by the excerpt; as written it is a syntax error.
        ], generate_plots=False)

    def test_predict(self):
        # TODO: test body not included in this excerpt.

    def test_save_model(self):
        # TODO: test body not included in this excerpt.

    def test_load_model(self):
        # TODO: test body not included in this excerpt.

The model module (which defines PlayerAnomalyDetectionModel) looks like the following:

from get_player_labels import get_player_account_status

# Excerpt of the model module as posted. Several lines were lost (docstring
# quotes, closing brackets, parts of method bodies), so it is not runnable as shown.
# NOTE(review): this module does `from get_player_labels import
# get_player_account_status`, and _set_thresholds calls that name unqualified, so
# it resolves through this module's globals. Tests must therefore patch
# "model.get_player_account_status", not "get_player_labels.get_player_account_status".
# NOTE(review): as shown, save_model has no `self` parameter although its body
# uses self._thresholds, and its saved_models_folder argument is unused (the body
# uses SAVED_MODELS_FOLDER instead). Confirm against the real source.
class PlayerAnomalyDetectionModel:
    The PlayerAnomalyDetectionModel class returns a model with methods:
    .fit to tune the model's internal thresholds on training data
    .predict to make predictions on test data
    .load_model to load a predefined model from a pkl file
    .save_model to save the model to a file
    def __init__(self):
        # A new model starts unfitted, using the default thresholds below.
        self.is_fitted = False
        # Default thresholds: each (time_control, 'perf_delta_thresholds') key maps
        # to a dict of rating-bin labels "0-100", "100-200", ..., "3900-4000",
        # each set to 0.15.
        self._thresholds = {
            (time_control,'perf_delta_thresholds'): {
                f"{rating_bin}-{rating_bin+100}": 0.15
                for rating_bin in np.arange(0,4000,100)
            for time_control in TimeControl.ALL.value
        self._account_statuses = {} # store account statuses for each model instance
        # Numeric score per account status.
        # NOTE(review): get_player_account_status can also store "not found",
        # which is not a key here as shown (this literal may be truncated).
        self._ACCOUNT_STATUS_SCORE_MAP = {
            "open": 0,
            "tosViolation": 1,
            "closed": 0.75, # weight closed account as closer to a tosViolation
    def load_model(self, model_file_name: str):
        Loads a model from a file

    def fit(self, train_data: pd.DataFrame, generate_plots=True):
        # NOTE(review): as shown, _set_thresholds() and `self.is_fitted = True`
        # are indented under `if self.is_fitted:`, so a freshly constructed model
        # (is_fitted is False) would never be fitted. The indentation was probably
        # lost in the excerpt; confirm against the real source.
        if self.is_fitted:
            # issue a warning that the user is retraining the model! 
            # give the user the option to combine multiple training data sets
            self._set_thresholds(train_data, generate_plots)
            self.is_fitted = True

    def _set_thresholds(self, train_data, generate_plots):
        # Body partially elided in this excerpt.

                for player in all_flagged_players:
                    # Only look up players whose status is not cached yet.
                    if self._account_statuses.get(player) is None:
                        # The call stores the status in self._account_statuses[player]
                        # itself; its return value is not used here.
                        get_player_account_status(player, self._account_statuses)
                ## set thresholds

    def predict(self, test_data: pd.DataFrame):
        """Returns pd.DataFrame of size (m+2, k)
        where k = number of flagged games, and m = number of features 
        if not self.is_fitted:
            print("Warning: model is not fitted and will use default thresholds")
        # make predictions
        return predictions

    def save_model(
        saved_models_folder = SAVED_MODELS_FOLDER, 
        model_name: str = "player_anomaly_detection_model"
        if not os.path.exists(SAVED_MODELS_FOLDER):
        with open(f'{SAVED_MODELS_FOLDER}/{BASE_FILE_NAME}_{model_name}.pkl', 'wb') as f:
            pickle.dump(self._thresholds, f)

Below is the structure of the function get_player_account_status inside the get_player_labels module:

def get_player_account_status(player, account_statuses):
        """Look up a player's account status and record it in ``account_statuses``.

        The status is written to ``account_statuses[player]`` as "tosViolation",
        "closed", "open", or "not found" (when ApiHttpError is raised). No return
        value is visible in this excerpt, so callers read the dict instead. A mock
        standing in for this function must therefore also write into the dict
        (e.g. via a side_effect function); returning a status string is not enough.
        """
        # NOTE(review): the `try:` line, and an `else:` before the "open"
        # assignment, appear to be missing from this excerpt. As shown, a disabled
        # account would be overwritten with "open".
        user = makeAPIcall()
        if user.get('tosViolation'):
            account_statuses[player] = "tosViolation"
        elif user.get('disabled'):
            account_statuses[player] = "closed"
            account_statuses[player] = "open"
    except ApiHttpError: 
        account_statuses[player] = "not found"


  • This is not ideal and doesn't resolve the question I asked, but I came up with a workaround. The get_player_account_status function is used to populate a private attribute called _account_statuses on each instance of the model; this attribute is a dictionary containing the status of each player.

    Since get_player_account_status is only called when the status of a player is not already present in _account_statuses, I set this dictionary explicitly inside the test_fit function:

    def test_fit(self):
        """Check the thresholds produced by fit() with pre-seeded account statuses.

        Pre-populating self.model._account_statuses makes the model skip
        get_player_account_status for these players, because it only calls it
        when _account_statuses.get(player) is None. Flagged players in the sample
        data that are not listed here would still trigger a real call.
        """
        ## this is a workaround to avoid calling get_player_account_status
        self.model._account_statuses = {
            'test_player1': 'open',
            'test_player2': 'open',
            'test_player3': 'tosViolation',
            'test_player4': 'tosViolation',
            'test_player5': 'tosViolation',
            'test_player6': 'closed'
        # NOTE(review): dict.copy() is shallow. The inner per-key dicts are shared
        # with self.model._thresholds, so the assignments below also modify the
        # model's own thresholds. The expectation can then alias the values under
        # test, and the assert may pass regardless of what fit() computed.
        # copy.deepcopy(self.model._thresholds) keeps the expectation independent.
        expected_thresholds = self.model._thresholds.copy()
        expected_thresholds[('blitz','perf_delta_thresholds')]['1500-1600'] = 0.16
        # NOTE(review): the end of the next line looks like a fragment of a
        # self.model.fit(..., generate_plots=False) call merged in by the excerpt.
        expected_thresholds[('bullet','perf_delta_thresholds')]['1600-1700'] = 0.17, generate_plots=False)
        assert expected_thresholds == self.model._thresholds