Search code examples
pythonunit-testingpytestpython-unittest.mockpytest-mock

How do I use pytest and unittest.mock to mock interactions with a paramiko.SSHClient


I am trying to write a unit test for code that interacts with an SFTP server using the paramiko library. The code under test receives a list of remote file locations and a callback. Each file is fetched and sent into the callback. The test shall simulate a scenario, where the caller sends two files to visit and one of the files fails with an IOError. I want to make sure that the failing file is excluded from the response.

Here is the code.py:

import io
from typing import Callable, List
import typing
import paramiko


def visit_files(files: List[str], callback: Callable[[typing.BinaryIO], None]) -> List[str]:
    response = []
    with paramiko.SSHClient() as ssh:
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())        
        ssh.connect("test.rebex.net", port=22, username="demo", password="password")

        with ssh.open_sftp() as sftp:
            for file_name in files:
                try:
                    with sftp.open(file_name, "rb") as f:
                        try:
                            b = f.read()
                            callback(io.BytesIO(b))
                            response.append(file_name)
                        except ValueError:
                            print("Something went wrong")
                except IOError:
                    print("Unknown IO error")
    return response

And my test_code.py:

import typing
from unittest.mock import Mock
from pytest_mock import MockerFixture

from src.utils.code import visit_files

def test_visiting(mocker: MockerFixture):
        mock = mocker.patch('paramiko.SSHClient')

        ssh_client_mock = mock.return_value
        ssh_client_mock.connect.return_value = Mock()

        sftp_mock = ssh_client_mock.open_sftp.return_value
        sftp_mock.open.side_effect = [
            Mock(read=Mock(return_value=b'Hello, World!')),  # Mock for the first file
            IOError("Unable to open file"),  # Simulate IOError for the second file
        ]

        def print_size(b: typing.BinaryIO) -> None:
                print(b.tell())

        response = visit_files(files=["file1.txt", "file2.txt"], callback=print_size)
        assert response == ["file1.txt"]

The error I am receiving is: TypeError: a bytes-like object is required, not 'MagicMock' in line callback(io.BytesIO(b)). I can't figure out where my mocks are not set up properly.


Solution

  • I would do something like this, starting by refactoring visit_files so you can inject the client dependency

    import io
    from typing import Callable, List
    import typing
    import paramiko
    
    
    def visit_files(files: List[str], callback: Callable[[typing.BinaryIO], None], client: paramiko.SSHClient = None) -> List[str]:
        response = []
        if not client:
            client = paramiko.SSHClient
        with client() as ssh:
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())        
            ssh.connect("test.rebex.net", port=22, username="demo", password="password")
    
            with ssh.open_sftp() as sftp:
                for file_name in files:
                    try:
                        with sftp.open(file_name, "rb") as f:
                            try:
                                b = f.read()
                                callback(io.BytesIO(b))
                                response.append(file_name)
                            except ValueError:
                                print("Something went wrong")
                    except IOError:
                        print("Unknown IO error")
        return response
    

    and then, create SSHClientMock taking into account the context managers:

    import typing
    from unittest.mock import MagicMock
    import paramiko
    from contextlib import contextmanager
    import io
    
    from src.utils.code import visit_files
    
    
    class SSHClientMock(MagicMock):
        def __init__(self, **kwargs):
            super().__init__(spec=paramiko.SSHClient, **kwargs)
    
        def __enter__(self):
            return self
    
        def __exit__(self):
            pass
    
        @contextmanager
        def open_sftp(self):
            def _open_sftp_mock_open(filename, mode):
                if filename == "file1.txt":
                    return io.BytesIO(b"Hello World!")
                elif filename == "file2.txt":
                    raise IOError("Unable to open file")
                assert False
    
            open_sftp_mock = MagicMock()
            open_sftp_mock.open = _open_sftp_mock_open
            yield open_sftp_mock
    
    
    def test_visiting():
            def print_size(b: typing.BinaryIO) -> None:
                    print(b.tell())
    
            response = visit_files(files=["file1.txt", "file2.txt"], callback=print_size, client=SSHClientMock())
            assert response == ["file1.txt"]