pytest, fastapi, sqlmodel

Unit Tests in FastAPI


I have a backend app developed with FastAPI, using SQLModel (SQLAlchemy & Pydantic) and connected to a Postgres database. I have integration tests that check whether my endpoints work correctly against a staging PG DB. Now I have to write unit tests, and I don't know how to test my endpoints and the functions they call in an isolated way.


Here is a really simplified version of my project:

The architecture of my project (assume there is an __init__.py file in each folder):

app/
├── api/
│   ├── core/
│   │   ├── config.py   # loads the env settings and distributes them to the app
│   │   ├── .env
│   ├── crud/
│   │   ├── items.py    # the CRUD functions called by the router
│   ├── db/
│   │   ├── session.py  # the get_session function handling the DB engine
│   ├── models/
│   │   ├── items.py    # the SQLModel table definitions, as they are in the DB
│   ├── routers/
│   │   ├── items.py    # the routing system
│   ├── schemas/
│   │   ├── items.py    # the Python objects, as they are used in the app
│   ├── main.py         #the main app
├── tests/              #the pytest tests
│   ├── unit_tests/
│   ├── integration_tests/
│   │   ├── test_items.py

In the crud/items.py:

from fastapi.encoders import jsonable_encoder
from sqlmodel import Session, select
from api.models import Item
from api.schemas import ItemCreate


def get_item(db_session: Session, item_id: int) -> Item:
    query = select(Item).where(Item.id == item_id)
    return db_session.exec(query).first()


def create_new_item(db_session: Session, *, obj_input: ItemCreate) -> Item:
    obj_in_data = jsonable_encoder(obj_input)
    db_obj = Item(**obj_in_data)
    db_session.add(db_obj)
    db_session.commit()
    db_session.refresh(db_obj)
    return db_obj

In the db/session.py:

from typing import Generator

from sqlalchemy.engine import Engine
from sqlmodel import create_engine, Session
from api.core.config import settings

engine: Engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)


def get_session() -> Generator[Session, None, None]:
    # A generator dependency: FastAPI closes the session after the request.
    with Session(engine) as session:
        yield session

In the models/items.py:

from sqlmodel import SQLModel, Field, MetaData

meta = MetaData(schema="pouetpouet")  # https://github.com/tiangolo/sqlmodel/issues/20


class Item(SQLModel, table=True):
    __tablename__ = "cities"
    # __table_args__ = {"schema": "pouetpouet"}
    metadata = meta

    id: int = Field(primary_key=True, default=None)
    city_name: str

In the routers/items.py:

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session
from api.crud import get_item, create_new_item
from api.db.session import get_session
from api.models import Item
from api.schemas import ItemRead, ItemCreate

router = APIRouter(prefix="/api/items", tags=["Items"])


@router.get("/{item_id}", response_model=ItemRead)
def read_item(
    *,
    db_session: Session = Depends(get_session),
    item_id: int,
) -> Item:
    item = get_item(db_session=db_session, item_id=item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return item


@router.post("/", response_model=ItemRead)
def create_item(
    *,
    db_session: Session = Depends(get_session),
    item_input: ItemCreate,
) -> Item:
    item = create_new_item(db_session=db_session, obj_input=item_input)
    return item

In the schemas/items.py:

from typing import Optional
from sqlmodel import SQLModel


class ItemBase(SQLModel):
    city_name: Optional[str] = None


class ItemCreate(ItemBase):
    pass

class ItemRead(ItemBase):
    id: int

    class Config:
        orm_mode = True

In the tests/integration_tests/test_items.py:

from fastapi.testclient import TestClient
from api.main import app

client = TestClient(app)

def test_create_item() -> None:
    data = {"city_name": "Las Vegas"}
    response = client.post("/api/items/", json=data)
    assert response.status_code == 200
    content = response.json()
    assert content["city_name"] == data["city_name"]
    assert "id" in content

What I am stuck on is the db_session: Session argument used by all the functions in crud/items.py and routers/items.py, because I think the tests need a valid session from a valid Postgres connection.


PS: Since I am not very experienced in backend development, feel free to make constructive remarks about my code if you notice anything strange. They will be very welcome.


Solution

  • As I mentioned in my comments to your previous question, I think unit tests should not deal with databases of any kind, but instead mock all the functions interacting with databases.

    Suggested approach

    Here is an example of how I would write a unit test for your CRUD function create_new_item.

    First, a simplified setup for demo purposes:

    from sqlmodel import Field, SQLModel, Session
    
    
    class ItemCreate(SQLModel):
        x: str
        y: float
    
    
    class Item(ItemCreate, table=True):
        id: int | None = Field(primary_key=True)
    
    
    def create_new_item(db_session: Session, *, obj_input: ItemCreate) -> Item:
        db_obj = Item.from_orm(obj_input)
        db_session.add(db_obj)
        db_session.commit()
        db_session.refresh(db_obj)
        return db_obj
    

    The test case:

    from unittest import TestCase
    from unittest.mock import call, create_autospec
    
    from sqlmodel import Session
    
    # ... import Item, ItemCreate
    
    
    class MyTestCase(TestCase):
        def test_create_new_item(self) -> None:
            test_item = ItemCreate(x="foo", y=3.14)
            mock_session = create_autospec(Session, instance=True)
    
            expected_output = Item.from_orm(test_item)
            expected_session_calls = [
                call.add(expected_output),
                call.commit(),
                call.refresh(expected_output),
            ]
    
            output = create_new_item(mock_session, obj_input=test_item)
            self.assertEqual(expected_output, output)
            self.assertListEqual(expected_session_calls, mock_session.mock_calls)
    

    Notice that the test never connects to any database. We must assume that the SQLAlchemy functions (and by extension those of SQLModel) work as advertised. Otherwise we should not be using them in the first place. Our unit tests need to ensure that we call those functions correctly.
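The same pattern works for the read side. The sketch below unit-tests a get_item with the same body as the one in your crud/items.py; the only new trick is configuring the chained call db_session.exec(query).first() through the mock's return_value attributes. So that it runs without SQLModel installed, Session, select and Item here are hypothetical minimal stand-ins (in your project you would import the real ones, and the test methods should carry over essentially unchanged). As above, the test does not inspect the query itself, on the assumption that building it correctly is SQLModel's job.

```python
import unittest
from typing import Any, List, Optional
from unittest.mock import create_autospec


# Hypothetical stand-ins so the sketch runs without SQLModel installed.
# In the real project: from sqlmodel import Session, select; from api.models import Item
class Session:
    def exec(self, statement: Any) -> Any:
        raise RuntimeError("a real Session would query the database here")


class Select:
    """Tiny statement object that only records the model and WHERE clauses."""

    def __init__(self, model: type) -> None:
        self.model = model
        self.clauses: List[Any] = []

    def where(self, clause: Any) -> "Select":
        self.clauses.append(clause)
        return self


def select(model: type) -> Select:
    return Select(model)


class Item:
    id: Optional[int] = None


# Function under test: same body as get_item in crud/items.py.
def get_item(db_session: Session, item_id: int) -> Optional[Item]:
    query = select(Item).where(Item.id == item_id)
    return db_session.exec(query).first()


class GetItemTestCase(unittest.TestCase):
    def test_returns_first_result(self) -> None:
        mock_session = create_autospec(Session, instance=True)
        expected_output = Item()
        # Configure the chained call db_session.exec(query).first().
        mock_session.exec.return_value.first.return_value = expected_output

        output = get_item(mock_session, item_id=1)

        self.assertIs(expected_output, output)
        mock_session.exec.assert_called_once()

    def test_returns_none_when_not_found(self) -> None:
        mock_session = create_autospec(Session, instance=True)
        mock_session.exec.return_value.first.return_value = None

        self.assertIsNone(get_item(mock_session, item_id=999))
```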


    Even greater isolation

    Technically, the way we are testing this function now, we are still relying on other code we wrote to function a certain way, namely the models. This is a grey area in my opinion because we technically did not write any methods for those classes, but rely on existing methods of the base classes.

    But you could argue that the model definitions themselves should be irrelevant to the create_new_item function. In other words, whether a unit test for it succeeds or fails should not depend on whether our model definitions for Item and ItemCreate match in a way that allows one to be constructed from the other.

    So a test with the strictest isolation could look like this:

    from unittest import TestCase
    from unittest.mock import MagicMock, call, create_autospec, patch
    
    from sqlmodel import Session
    
    # ... import Item
    
    
    class MyTestCase(TestCase):
        @patch.object(Item, "from_orm")
        def test_create_new_item2(self, mock_from_orm: MagicMock) -> None:
            test_item = MagicMock()
            mock_session = create_autospec(Session, instance=True)
    
            expected_output = mock_from_orm.return_value = object()
            expected_session_calls = [
                call.add(expected_output),
                call.commit(),
                call.refresh(expected_output),
            ]
    
            output = create_new_item(mock_session, obj_input=test_item)
            self.assertEqual(expected_output, output)
            self.assertListEqual(expected_session_calls, mock_session.mock_calls)
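The same isolation idea can be carried over to the route functions in routers/items.py. FastAPI's router decorators register the function and return it unchanged, so a unit test can call read_item directly, pass any object as db_session (the Depends(get_session) default is only used when FastAPI itself resolves the arguments), and patch out the CRUD function it calls. In this sketch HTTPException and get_item are hypothetical stand-ins so it runs without FastAPI; in your project the patch target would be the name where the router looks it up, i.e. "api.routers.items.get_item".

```python
import unittest
from unittest.mock import MagicMock, patch, sentinel


class HTTPException(Exception):
    """Hypothetical stand-in for fastapi.HTTPException."""

    def __init__(self, status_code: int, detail: str = "") -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


def get_item(db_session, item_id):
    """Stand-in for api.crud.get_item; the tests always patch it out."""
    raise RuntimeError("get_item should have been patched")


# Same body as read_item in routers/items.py, without the @router.get decorator
# and the Depends(get_session) default, which FastAPI only uses when it calls the route.
def read_item(*, db_session, item_id: int):
    item = get_item(db_session=db_session, item_id=item_id)
    if not item:
        raise HTTPException(status_code=404, detail="Item not found")
    return item


class ReadItemTestCase(unittest.TestCase):
    def setUp(self) -> None:
        self.mock_get_item = MagicMock()
        # In the real project: patch("api.routers.items.get_item").
        # Patching read_item's own globals keeps this sketch self-contained.
        patcher = patch.dict(read_item.__globals__, get_item=self.mock_get_item)
        patcher.start()
        self.addCleanup(patcher.stop)

    def test_returns_item(self) -> None:
        self.mock_get_item.return_value = sentinel.item

        output = read_item(db_session=sentinel.session, item_id=1)

        self.assertIs(sentinel.item, output)
        self.mock_get_item.assert_called_once_with(
            db_session=sentinel.session, item_id=1
        )

    def test_raises_404_when_not_found(self) -> None:
        self.mock_get_item.return_value = None

        with self.assertRaises(HTTPException) as ctx:
            read_item(db_session=sentinel.session, item_id=1)

        self.assertEqual(404, ctx.exception.status_code)
```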
    

    Less isolation: A step towards integration

    If you want to have a separate type of test that actually does perform database queries to check if the results are as expected, you could inject a different session object bound to a SQLite in-memory database engine.

    from unittest import TestCase
    
    from sqlmodel import SQLModel, Session, create_engine, select
    
    # ... import Item, ItemCreate
    
    
    class MyTestCase(TestCase):
        def test_create_new_item3(self) -> None:
            engine = create_engine("sqlite:///")
            SQLModel.metadata.create_all(engine)
    
            test_item = ItemCreate(x="foo", y=3.14)
            expected_output = Item(**test_item.dict(), id=1)
    
            with Session(engine) as session:
                output = create_new_item(session, obj_input=test_item)
                self.assertEqual(expected_output, output)
                result = session.exec(select(Item)).all()
                self.assertListEqual([expected_output], result)
    

    Of course, this will not work if your models have database-specific settings that are incompatible with SQLite (such as a Postgres schema).

    In that case, you have various options, as discussed in your previous question. For example, you could create a separate testing database, write some isolation/cleanup logic around your tests, and then supply an engine connected to that testing database instead.
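As a rough sketch of such isolation/cleanup logic, assuming Python's built-in sqlite3 module as a stand-in for a dedicated Postgres testing database (only so the example runs anywhere): all tests in the class share one database created in setUpClass, and tearDown deletes the rows each test wrote. Deleting rather than rolling back matters here because the code under test commits, as create_new_item does. The cities table mirrors the question's model; the class, helper and test names are hypothetical.

```python
import sqlite3
import unittest


class CitiesCleanupTestCase(unittest.TestCase):
    """Tests share one database; tearDown wipes the table after each test."""

    @classmethod
    def setUpClass(cls) -> None:
        # Stand-in for connecting to a dedicated testing database.
        cls.conn = sqlite3.connect(":memory:")
        cls.conn.execute(
            "CREATE TABLE cities (id INTEGER PRIMARY KEY, city_name TEXT NOT NULL)"
        )
        cls.conn.commit()

    @classmethod
    def tearDownClass(cls) -> None:
        cls.conn.close()

    def tearDown(self) -> None:
        # Cleanup still works even though the tests committed their writes.
        self.conn.execute("DELETE FROM cities")
        self.conn.commit()

    def insert_city(self, name: str) -> None:
        # Mimics code under test that commits, like create_new_item.
        self.conn.execute("INSERT INTO cities (city_name) VALUES (?)", (name,))
        self.conn.commit()

    def count_cities(self) -> int:
        return self.conn.execute("SELECT COUNT(*) FROM cities").fetchone()[0]

    def test_create_las_vegas(self) -> None:
        self.insert_city("Las Vegas")
        self.assertEqual(1, self.count_cities())

    def test_create_reno(self) -> None:
        # Would count 2 rows if the other test's committed insert had leaked.
        self.insert_city("Reno")
        self.assertEqual(1, self.count_cities())
```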

    Or you could try to cleverly monkey-patch your way around those specific settings.
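One way around the schema problem without monkey-patching, sketched with Python's built-in sqlite3 module: SQLite lets you ATTACH an additional database under a name, and that name can then be used as a schema qualifier, so a table such as pouetpouet.cities from the question can live in an in-memory SQLite database. Hooking this into the SQLAlchemy engine (for example by running the ATTACH statement for every new DB-API connection) is not shown here and would need to be checked against your setup.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Attach a second in-memory database under the Postgres schema's name;
# SQLite then resolves "pouetpouet.<table>" like a schema-qualified name.
conn.execute("ATTACH DATABASE ':memory:' AS pouetpouet")
conn.execute(
    "CREATE TABLE pouetpouet.cities (id INTEGER PRIMARY KEY, city_name TEXT)"
)
conn.execute("INSERT INTO pouetpouet.cities (city_name) VALUES (?)", ("Las Vegas",))
rows = conn.execute("SELECT id, city_name FROM pouetpouet.cities").fetchall()
print(rows)  # [(1, 'Las Vegas')]
```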