Tags: pytest, ray, pytest-mock

Using the @patch decorator with pytest to test a Ray actor's remote method


I'm trying to unit test a Ray remote function. I use the @patch decorator to patch the method that the actor calls. Here is the code:

foo.py

class Foo(object):
    def __init__(self):
        self.value = 0

    def bar(self):
        self.value = 100
        print("In original method")
        assert False

test_foo.py

from unittest.mock import patch

import pytest
import unittest
import ray

from foo import Foo


@pytest.fixture
def ray_fixture():
    print("Initializing ray")
    if not ray.is_initialized():
        ray.init()
    yield None
    print("Terminating ray")
    ray.shutdown()


def fake_bar(self):
    print("In fake method")
    assert True


@pytest.mark.usefixtures("ray_fixture")
class FooTestCase(unittest.TestCase):
    """Test cases for Foo module"""

    @patch("foo.Foo.bar", new=fake_bar)
    def test_bar(self):
        Foo().bar()

    @patch("foo.Foo.bar", new=fake_bar)
    def test_bar_remote(self):
        foo_actor = ray.remote(Foo).remote()
        obj_ref = foo_actor.bar.remote()
        ray.get(obj_ref)

test_bar passes, but test_bar_remote fails. If I use ray.init(local_mode=True), both tests pass, but I cannot use local_mode=True because of other limitations.

How can I patch a Ray actor's remote method using @patch?
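Some background on why test_bar_remote fails, as far as I understand Ray's behavior: the actor runs in a separate worker process, and Ray sends the class there by serializing it with cloudpickle. For a class defined in an importable module such as foo, only a reference (module name plus class name) is serialized. The worker therefore imports foo itself and gets the original bar(), while the patch exists only in the test process. With local_mode=True everything runs in the test process, so the patch applies. The minimal sketch below illustrates the by-reference part using the standard pickle module instead of Ray. The in-memory module named foo is a hypothetical stand-in for foo.py, used only so the example fits in one file.

```python
import pickle
import sys
import types
from unittest.mock import patch

# Hypothetical in-memory stand-in for foo.py (keeps the sketch in one file).
FOO_SOURCE = '''
class Foo(object):
    def __init__(self):
        self.value = 0

    def bar(self):
        return "original"
'''
foo = types.ModuleType("foo")
exec(FOO_SOURCE, foo.__dict__)  # defines foo.Foo with __module__ == "foo"
sys.modules["foo"] = foo        # make "foo" importable, like a real module


def fake_bar(self):
    return "fake"


before = pickle.dumps(foo.Foo)
with patch("foo.Foo.bar", new=fake_bar):
    during = pickle.dumps(foo.Foo)
    patched_result = foo.Foo().bar()  # the patch is visible in this process

# The class is pickled as a reference ("foo" + "Foo"), not as its methods,
# so the bytes another process would receive are identical with or without
# the patch; that process imports foo on its own and sees the original bar.
print(before == during)   # True
print(patched_result)     # fake
print(foo.Foo().bar())    # original (the patch is undone on leaving the block)
```

Unpickling the reference in this process returns whatever foo.Foo currently is; in a fresh worker process that is the unpatched class.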


Solution

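  • Here's an alternative: subclass Foo with a stubbed (mocked) implementation and pass that subclass to Ray. This leaves the Foo class intact; you only override the methods that need to be mocked, e.g. bar(). Because the override is part of the class definition itself rather than a patch applied at runtime in the test process, it is still in effect when Ray creates the actor in a worker process.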

    test_foo.py

    ...
    class FooStub(Foo):
        def bar(self, *args, **kwargs):
            print("In another fake method")
            assert True
    
            # Optionally, you can also call the real method if you want. You may update the arguments as needed.
            # super().bar(*args, **kwargs)
    
    @pytest.mark.usefixtures("ray_fixture")
    class FooTestCase(unittest.TestCase):
        ...
        def test_bar_remote(self):
            foo_actor = ray.remote(FooStub).remote()
            obj_ref = foo_actor.bar.remote()
            ray.get(obj_ref)
    ...
    

    Output

    $ pytest -q -rP
    ..
    ================================================================================================= PASSES ==================================================================================================
    __________________________________________________________________________________________ FooTestCase.test_bar ___________________________________________________________________________________________
    ------------------------------------------------------------------------------------------ Captured stdout setup ------------------------------------------------------------------------------------------
    Initializing ray
    ------------------------------------------------------------------------------------------ Captured stdout call -------------------------------------------------------------------------------------------
    In fake method
    ---------------------------------------------------------------------------------------- Captured stdout teardown -----------------------------------------------------------------------------------------
    Terminating ray
    _______________________________________________________________________________________ FooTestCase.test_bar_remote _______________________________________________________________________________________
    ------------------------------------------------------------------------------------------ Captured stdout setup ------------------------------------------------------------------------------------------
    Initializing ray
    ---------------------------------------------------------------------------------------- Captured stdout teardown -----------------------------------------------------------------------------------------
    Terminating ray
    2 passed, 1 warning in 5.03s
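If several tests need different fakes, writing one stub subclass per test can get repetitive. A possible variation on the same idea is a small helper, make_stub (a hypothetical name, not a Ray or pytest API), that builds the stub subclass with type(). Foo is copied from the question so the block runs on its own, and Ray is left out so it runs without a cluster; the commented lines show the intended use with ray.remote, which I have not verified here.

```python
class Foo(object):
    """Copy of foo.Foo from the question, so the sketch is self-contained."""

    def __init__(self):
        self.value = 0

    def bar(self):
        self.value = 100
        print("In original method")
        assert False


def make_stub(cls, **overrides):
    """Hypothetical helper: return a subclass of cls with methods replaced.

    Builds the same kind of class as FooStub above, via type(), so each
    test can pick its own fake implementations.
    """
    return type(cls.__name__ + "Stub", (cls,), dict(overrides))


def fake_bar(self):
    print("In fake method")
    return "fake"


FooStub = make_stub(Foo, bar=fake_bar)
stub = FooStub()
print(stub.bar())             # prints "In fake method", then "fake"
print(isinstance(stub, Foo))  # True
print(stub.value)             # 0: __init__ is inherited unchanged

# Intended use with Ray (assumption, not run here):
#     foo_actor = ray.remote(FooStub).remote()
#     ray.get(foo_actor.bar.remote())
```

As with the hand-written FooStub, the fakes are ordinary methods of the generated class rather than a runtime patch, so they should be part of whatever class the worker loads.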