Tags: testing, jupyter-notebook, mocking, pytest, python-unittest

Testing a Jupyter Notebook


I am trying to come up with a way to test a number of Jupyter notebooks. The tests should run whenever a new notebook is added on a GitHub branch and a pull request is opened. The tests are not complicated: they mostly check that the notebook runs end-to-end without errors, plus maybe a few asserts. However:

  • Some cells make calls that need to be mocked, e.g. a call that downloads data from a database.
  • Some notebooks may contain magic cells, e.g. ones that run a pip command.

I am open to using any testing library, such as pytest or unittest, although pytest is preferred.

I looked at a few notebook-testing libraries, such as nbmake, treon, and testbook, but I was unable to make them work. I also tried converting the notebook to a Python file, but the magic cells were converted to get_ipython().run_cell_magic(...) calls. That became an issue: pytest runs the converted file as plain Python rather than under IPython, and get_ipython() is only defined inside an IPython session.
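If the magic cells only do setup that the CI job can handle separately (for example, installing packages the pipeline already installs), one way around the get_ipython() problem is to skip the conversion and execute the notebook's JSON directly, dropping IPython-only lines. The sketch below assumes exactly that; the helper names are hypothetical, it uses only the standard library, and the magic filter is a line-based heuristic rather than IPython's own parser:

```python
import json
from pathlib import Path


def code_cells(nb):
    """Yield the source of each code cell in a parsed nbformat-v4 notebook dict."""
    for cell in nb.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        source = cell.get("source", "")
        # nbformat v4 stores "source" either as one string or as a list of lines.
        yield "".join(source) if isinstance(source, list) else source


def strip_magics(source):
    """Drop IPython-only lines so the cell can run under plain Python.

    Heuristic, not a parser: assumes magics and shell escapes are setup-only
    (e.g. %pip install ...) and safe to skip in CI. Line magics that wrap
    code, such as `%time x = f()`, are dropped together with that code.
    """
    lines = source.splitlines()
    # A cell magic (%%bash, %%capture, ...) governs the whole cell, so skip it.
    if lines and lines[0].lstrip().startswith("%%"):
        return ""
    return "\n".join(ln for ln in lines if not ln.lstrip().startswith(("%", "!")))


def run_code_cells(nb, namespace=None):
    """Execute code cells in order in one shared namespace, like a kernel would."""
    ns = {"__name__": "__notebook__"} if namespace is None else namespace
    for index, source in enumerate(code_cells(nb)):
        # The pseudo-filename makes tracebacks point at the failing cell.
        exec(compile(strip_magics(source), f"<cell {index}>", "exec"), ns)
    return ns


def run_notebook_without_magics(path, namespace=None):
    """Load a .ipynb file (plain JSON) and run its code cells without IPython."""
    nb = json.loads(Path(path).read_text(encoding="utf-8"))
    return run_code_cells(nb, namespace)
```

Because the cells run in an ordinary dict namespace, a regular pytest test can wrap run_notebook_without_magics(...) in a unittest.mock.patch context, and cells that import the patched module will see the mock. Code that genuinely needs IPython (display hooks, %%capture, `x = !ls`) will not work this way.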

So what is a good way to test Jupyter notebooks given these constraints? Any help is appreciated.


Solution

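  • Here is my own solution using testbook. Let's say I have a notebook called my_notebook.ipynb that creates a bigquery.Client, loads a query result with client.query(...).result().to_dataframe() into a variable called dataframe, and also defines a variable x.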

    The trick is to inject a cell that mocks bigquery.Client before the cell that calls it:

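    from testbook import testbook
    
    # Load the notebook without executing any cells up front.
    @testbook('./my_notebook.ipynb')
    def test_get_details(tb):
        # Insert a setup cell before cell index 2 (0-based), i.e. before the cell
        # that calls bigquery.Client. Assumes earlier cells already ran
        # `import pandas as pd` and `from google.cloud import bigquery`.
        tb.inject(
            """
            from unittest import mock  # stdlib; same API as the third-party `mock` package
            mock_client = mock.MagicMock()
            mock_df = pd.DataFrame()
            mock_df['week'] = range(10)
            mock_df['count'] = 5
            # bigquery.Client(...) now returns mock_client ...
            p1 = mock.patch.object(bigquery, 'Client', return_value=mock_client)
            # ... and client.query(...).result().to_dataframe() returns mock_df.
            mock_client.query().result().to_dataframe.return_value = mock_df
            p1.start()
            """,
            before=2,
            run=False,  # don't run it now; tb.execute() below runs every cell in order
        )
        tb.execute()
        dataframe = tb.get('dataframe')
        assert dataframe.shape == (10, 2)
    
        x = tb.get('x')
        assert x == 7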
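The injected setup relies on two unittest.mock behaviors: patch.object swaps bigquery.Client for a mock whose return value is mock_client, and a MagicMock returns the same child mock however it is called, so the configured result comes back for any SQL string. A minimal sketch of the same pattern outside the notebook, assuming a hypothetical stand-in `bigquery` object and plain lists instead of a pandas DataFrame so it runs with the standard library alone:

```python
import types
from unittest import mock

# Hypothetical stand-in for `google.cloud.bigquery`, so this runs without the real library.
bigquery = types.SimpleNamespace(Client=lambda: None)


def load_counts(sql):
    """Notebook-style code under test, using the same call chain as the notebook."""
    client = bigquery.Client()
    return client.query(sql).result().to_dataframe()


mock_client = mock.MagicMock()
# Plain dicts stand in for the DataFrame rows to keep the example dependency-free.
fake_rows = [{"week": week, "count": 5} for week in range(10)]

# A MagicMock returns the same .return_value no matter which arguments it is called
# with, so this configures client.query(<any sql>).result().to_dataframe().
mock_client.query.return_value.result.return_value.to_dataframe.return_value = fake_rows

# While the patch is active, bigquery.Client(...) returns mock_client.
with mock.patch.object(bigquery, "Client", return_value=mock_client):
    rows = load_counts("SELECT week, count FROM some_table")
```

Spelling the chain with .return_value attributes configures the same result as `mock_client.query().result().to_dataframe.return_value = mock_df` in the test above, but without recording the setup calls themselves, so assertions such as assert_called_once_with on the mock stay accurate.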