
Trouble patching a function of a function


I'm pretty new to patching and I've run into something I don't know how to patch. In the file I want to test there is a function, difficult_method(), that looks roughly like this:

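from import_location import User

def difficult_method():
    # list_of_ids is defined elsewhere in the real code
    for user_id in list_of_ids:
        try:
            # First user with this id; .all() returns a list, so [0]
            # raises IndexError when nothing matches
            user = User.query.filter(User.id == user_id).all()[0]
        except IndexError:
            continue  # skip ids with no matching user
        # do lots of stuff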

The part I want to mock is User.query.filter(User.id == user_id).all(); for this test it can simply return a static list. What would I patch in a test that looks something like this:

from unittest.mock import patch  # or "from mock import patch" with the standalone package

@patch(...)  # what would go here?
def test_difficult_method():
    from file_to_test import difficult_method
    assert difficult_method() == ...  # expected result
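What goes in `@patch(...)` depends on where the name is looked up. `from import_location import User` copies a reference to the class into `file_to_test`'s own namespace, so `patch` has to replace it there; this is the "where to patch" rule from the `unittest.mock` documentation. A minimal sketch, assuming hypothetical in-memory stand-ins for the two modules built with `types.ModuleType` (the real modules are not needed to see the effect):

```python
import sys
import types
from unittest.mock import patch

# Hypothetical stand-in for import_location, defining the "real" User.
import_location = types.ModuleType("import_location")


class User:
    """Placeholder for the real model class."""


import_location.User = User
sys.modules["import_location"] = import_location

# Hypothetical stand-in for file_to_test: it binds its own name User.
file_to_test = types.ModuleType("file_to_test")
exec("from import_location import User", file_to_test.__dict__)
sys.modules["file_to_test"] = file_to_test

FakeUser = object()

# Patching the defining module does not touch the name file_to_test already holds.
with patch("import_location.User", FakeUser):
    seen_via_source = file_to_test.User is FakeUser

# Patching the name where it is looked up does.
with patch("file_to_test.User", FakeUser):
    seen_via_consumer = file_to_test.User is FakeUser

print(seen_via_source, seen_via_consumer)  # False True
```

In other words, patching `import_location.User` only reaches `difficult_method` if `file_to_test` is imported for the first time while that patch is active; patching `file_to_test.User` works regardless of import order.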

Solution

  • I figured it out! The key was to create a fake User class (MockUser) whose query.filter(...).all() chain returns a static list, like so:

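    from types import SimpleNamespace

    # Placeholder user; build whatever object difficult_method() needs
    user = SimpleNamespace(id='2')


    class MockFilter(object):
        def all(self):
            # Always return the same static list
            return [user]


    class MockQuery(object):
        def filter(self, match):
            # Ignore the filter expression entirely
            return MockFilter()


    class MockUser(object):
        query = MockQuery()
        id = '2'  # only needed so that User.id == user_id can be evaluated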
    

    Then I patched it in like so:

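    from unittest.mock import patch  # or "from mock import patch"

    # Patch User where difficult_method looks it up. Patching
    # 'import_location.User' only takes effect if file_to_test is first
    # imported while the patch is active.
    @patch('file_to_test.User', MockUser)
    def test_difficult_method():
        from file_to_test import difficult_method
        assert difficult_method() == ...  # expected result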
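The hand-written MockFilter/MockQuery/MockUser chain can also be produced by a single `MagicMock`, which creates each attribute and call result on first access, so only the final return value needs configuring. A minimal sketch, assuming a hypothetical in-memory `file_to_test` whose `first_user()` repeats the query line from `difficult_method()`, and a hypothetical `user` object:

```python
import types
from unittest.mock import patch

# Hypothetical in-memory stand-in for file_to_test so the sketch runs on its own;
# first_user() repeats the query line from difficult_method().
file_to_test = types.ModuleType("file_to_test")
exec(
    "User = None  # in the real module this comes from import_location\n"
    "def first_user(user_id):\n"
    "    return User.query.filter(User.id == user_id).all()[0]\n",
    file_to_test.__dict__,
)

# Hypothetical user the fake query should return.
user = types.SimpleNamespace(id="2", name="alice")

# patch.object with no replacement substitutes a MagicMock for the module's User.
with patch.object(file_to_test, "User") as MockUser:
    # Configure only the end of the chain: User.query.filter(<anything>).all() -> [user]
    MockUser.query.filter.return_value.all.return_value = [user]
    found = file_to_test.first_user("2")
    # Unlike the hand-written classes, the mock records how it was used.
    MockUser.query.filter.assert_called_once()

print(found.name)  # alice
```

In a real test the equivalent is decorating the test with `@patch('file_to_test.User')` (no second argument); the test function then receives the MagicMock as an extra parameter and can configure the chain the same way.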