Search code examples
pythonpython-3.xunit-testingmockingpsycopg2

Mock psycopg2 database insertion in Python


My unit test does not mock the database insertion because I do see in the logs and in the database that the record id 42 had been inserted in the database, see logs:

./tests/test_requesthandler.py::TestRequestHandler::test_handle_post_coordinates Failed: [undefined]AssertionError: {'id': 42, 'latitude': 12.9, 'longitude': 77.6} != [{'latitude': 12.9, 'longitude': 77.6}]

How can I mock the database insertion in my unit test?

Here is my unit test:

import unittest
from unittest.mock import patch, Mock, MagicMock

from http.server import BaseHTTPRequestHandler
import json
from src.requesthandler import RequestHandler    # The code to test

class TestRequestHandler(unittest.TestCase):
    def setUp(self):
        self.handler = RequestHandler()

    @patch('src.requesthandler.psycopg2.connect')
    def test_handle_post_coordinates(self, mock_connect):
        print(json.dumps({"latitude": "12.9", "longitude": "77.6"}).encode('utf-8'))
        expected =  [{'latitude': 12.9, 'longitude': 77.6}]
        # This will disable the database connection
        # mock_connect.return_value.cursor.return_value.execute.return_value = None
        mock_con = mock_connect.return_value  # result of psycopg2.connect(**connection_stuff)
        mock_cur = mock_con.cursor.return_value  # result of con.cursor(cursor_factory=DictCursor)
        mock_cur.execute.return_value = expected  # return this when calling cur.fetchall()
        mock_cur.fetchone.return_value = expected  # return this when calling cur.fetchall()
        mock_con.commit.return_value = expected  # return this when calling cur.fetchall()
        environ = {
            'CONTENT_LENGTH': '23',
            'REQUEST_METHOD': 'POST',
            'PATH_INFO': '/coordinates',
            'wsgi.input': Mock(read=Mock(return_value=json.dumps({'latitude': 12.9, 'longitude': 77.6}).encode('utf-8')))
        }
        start_response = Mock()

        response = self.handler.handle_post_coordinates(environ, start_response)

        self.assertEqual(json.loads(response[0].decode().replace("'", '"')), [{'latitude': 12.9, 'longitude': 77.6}])
        start_response.assert_called_with('200 OK', [('Content-type', 'text/plain')])

    def test_handle_get(self):
        environ = {
            'REQUEST_METHOD': 'GET',
            'PATH_INFO': '/coordinates',
        }
        start_response = Mock()

        response = self.handler.handle_get(environ, start_response)

        self.assertEqual(json.loads(response[0].decode()), {'mssg': 'werkt123'})
        start_response.assert_called_with('200 OK', [('Content-type', 'application/json')])

if __name__ == '__main__':
    unittest.main()

And here is my code:

import json
from http.server import BaseHTTPRequestHandler, HTTPServer
import psycopg2
from json import dumps
from waitress import serve
import logging

class RequestHandler(BaseHTTPRequestHandler):

    # the constructor is called "__init__"for convenience
    def __init__(self):
        self.coordinates = []
        print('qwe')
        # Connect to the PostgreSQL database
        self.conn = psycopg2.connect(
            host="localhost",
            database="postgres",
            user="postgres",
            password="admin"
        )

        # Create a cursor object
        self.cursor = self.conn.cursor()


    def _send_response(self, message, status=200):
        self.send_response(status)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(bytes(json.dumps(message), "utf8"))

    def handle_post_coordinates(self, environ, start_response):
        content_length = int(environ.get('CONTENT_LENGTH', 0))
        request_body = environ['wsgi.input'].read(content_length).decode()

        coordinates = json.loads(request_body)
        self.coordinates.append(coordinates)

        self.cursor.execute("INSERT INTO coordinates (latitude, longitude) VALUES (%s, %s) RETURNING id", (coordinates['latitude'], coordinates['longitude']))
        new_coordinate_id = self.cursor.fetchone()[0]
        self.conn.commit()
        new_coordinate = {'id': new_coordinate_id, 'latitude': coordinates['latitude'], 'longitude': coordinates['longitude']}

        status = '200 OK'
        headers = [('Content-type', 'text/plain')]
        start_response(status, headers)
        # return [b"Coordinates added"]
        return [bytes(str(new_coordinate), 'utf-8')]

    def handle_get(self, environ, start_response):
        if environ['PATH_INFO'] == '/coordinates':
            # self.cursor.execute("SELECT * FROM coordinates")
            # coordinates = self.cursor.fetchall()
            # coordinates = [{'id': c[0], 'latitude': c[1], 'longitude': c[2]} for c in coordinates]
            # response = dumps(coordinates)
            response = dumps({'mssg' : 'werkt123'})
            # response = {'mssg' : 'haha3'}
        elif environ['PATH_INFO'].startswith('/coordinates/'):
            coordinate_id = int(environ['PATH_INFO'].split('/')[-1])
            self.cursor.execute("SELECT * FROM coordinates WHERE id = %s", (coordinate_id,))
            coordinate = self.cursor.fetchone()
            if coordinate:
                coordinate = {'id': coordinate[0], 'latitude': coordinate[1], 'longitude': coordinate[2]}
                response = dumps(coordinate)
            else:
                response = dumps({'error': 'Coordinate not found'})
                start_response("404 Not Found", [('Content-type', 'application/json')])
                return [response.encode()]
        else:
            response = dumps({'error': 'Invalid endpoint'})
            start_response("404 Not Found", [('Content-type', 'application/json')])
            return [response.encode()]

        start_response("200 OK", [('Content-type', 'application/json')])
        return [response.encode()]

    def application(self, environ, start_response):
        try: 
            path = environ.get('PATH_INFO', '').lstrip('/')
            if path == 'coordinates':
                if environ['REQUEST_METHOD'] == 'GET':
                    return self.handle_get(environ, start_response)
                elif environ['REQUEST_METHOD'] == 'POST':
                    return self.handle_post_coordinates(environ, start_response)
                    # return self.do_POST(environ, start_response)
                # elif environ['REQUEST_METHOD'] == 'PUT':
                #     return handle_put_coordinates(environ, start_response)
                # elif environ['REQUEST_METHOD'] == 'DELETE':
                #     return handle_delete_coordinates(environ, start_response)
            else:
                start_response("404 Not Found", [('Content-type', 'text/plain')])
                return [b"Not Found"]
        except Exception as e:
            print("Error:", str(e))
            start_response("500 Internal Server Error", [])
            return [b"Error: " + str(e).encode()]

Solution

  • I have managed to fix it like this:

    from unittest import TestCase, mock
    from unittest.mock import patch, Mock, MagicMock
    
    from http.server import BaseHTTPRequestHandler
    import json
    from src.requesthandler import RequestHandler    # The code to test
    
    class TestRequestHandler(TestCase):
        @mock.patch('src.requesthandler.psycopg2.connect')
        def setUp(self, mock_connect):
            self.handler = RequestHandler()
            self.mock_connect = mock_connect
    
        # @patch('src.requesthandler.psycopg2')
    
        def test_handle_post_coordinates(self):
            expected =  [{'latitude': 12.9, 'longitude': 77.6}]
            mock_con = self.mock_connect.return_value
            mock_cur = mock_con.cursor.return_value
            mock_cur.execute.return_value = expected
            mock_cur.fetchone.return_value = expected
            mock_con.commit.return_value = expected
            environ = {
                'CONTENT_LENGTH': '23',
                'REQUEST_METHOD': 'POST',
                'PATH_INFO': '/coordinates',
                'wsgi.input': Mock(read=Mock(return_value=json.dumps({'latitude': 12.9, 'longitude': 77.6}).encode('utf-8')))
            }
            start_response = Mock()
    
            response = self.handler.handle_post_coordinates(environ, start_response)
            print(json.loads(response[0].decode().replace("'", '"')))
    
            print(response[0].decode())
            print([{'id': {'latitude': 12.9, 'longitude': 77.6}, 'latitude': 12.9, 'longitude': 77.6}])
    
            self.assertEqual(json.loads(response[0].decode().replace("'", '"')), {'id': {'latitude': 12.9, 'longitude': 77.6}, 'latitude': 12.9, 'longitude': 77.6})
            start_response.assert_called_with('200 OK', [('Content-type', 'text/plain')])
    
        def test_handle_get(self):
            environ = {
                'REQUEST_METHOD': 'GET',
                'PATH_INFO': '/coordinates',
            }
            start_response = Mock()
    
            response = self.handler.handle_get(environ, start_response)
    
            self.assertEqual(json.loads(response[0].decode()), {'mssg': 'werkt123'})
            start_response.assert_called_with('200 OK', [('Content-type', 'application/json')])
    
    if __name__ == '__main__':
        unittest.main()