Tags: python, sqlalchemy, pytest, fastapi

Pytest: How to remove data created by each test function after it runs


I have a FastAPI + SQLAlchemy project and I'm using Pytest for writing unit tests for the APIs.

In each test function, I use SQLAlchemy to create data in several tables (user, post, comment, etc.). This data stays in the tables after the test function finishes and affects the other test functions.

For example, if the first test function creates 3 posts and 2 users, those 3 posts and 2 users are still in the tables when the second test function runs, which breaks my test expectations.

Here is my pytest fixture:

import pytest
from sqlalchemy.orm import sessionmaker


@pytest.fixture
def session(engine):
    Session = sessionmaker(bind=engine)
    session = Session()
    yield session
    session.rollback()  # Removes data created in each test method
    session.close()  # Close the session after each test

I call session.rollback() to remove all data created during the session, but the data is not removed.
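The behaviour above is consistent with how transactions work: rollback() discards only changes that have not been committed yet. This assumes the create_* helpers commit their inserts, which the question does not show. A minimal standalone illustration, using Python's built-in sqlite3 module rather than SQLAlchemy and a hypothetical posts table; a SQLAlchemy Session follows the same commit/rollback rules:

```python
import sqlite3

# In-memory database standing in for the test database (assumption: any
# transactional database behaves the same way for this demonstration).
connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT)")
connection.commit()

# 1) A row that is inserted and then committed survives a later rollback.
connection.execute("INSERT INTO posts (title) VALUES ('committed')")
connection.commit()
connection.rollback()  # nothing to undo: the transaction has already ended

# 2) A row that is inserted but never committed is discarded by rollback.
connection.execute("INSERT INTO posts (title) VALUES ('uncommitted')")
connection.rollback()

titles = [row[0] for row in connection.execute("SELECT title FROM posts")]
print(titles)  # ['committed']
```

So once a helper has committed, the fixture's session.rollback() has nothing left to undo for that data.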

And here are my test functions:

class TestAllPosts(PostBaseTestCase):

    def create_logged_in_user(self, db):
        user = self.create_user(db)
        return user.generate_tokens()["access"]

    def test_can_api_return_all_posts_without_query_parameters(self, client, session):
        posts_count = 5
        user_token = self.create_logged_in_user(session)
        for _ in range(posts_count):
            self.create_post(session)

        response = client.get(url, headers={"Authorization": f"Bearer {user_token}"})
        assert response.status_code == 200
        json_response = response.json()
        assert len(json_response) == posts_count

    def test_can_api_detect_there_is_no_post(self, client, session):
        user_token = self.create_logged_in_user(session)
        response = client.get(url, headers={"Authorization": f"Bearer {user_token}"})
        assert response.status_code == 404

In the last test function, instead of a 404, I get a 200 response containing the 5 posts created by the previous test function.

How can I remove the data created in each test function after that function finishes?


Solution

  • The problem is that there are multiple sessions.

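    One is used by your tests; the other(s) are used by the server.

    Because you are calling client.get, you are sending a request to the server, which uses its own database session. For that session to see the posts and users your tests create, the helpers must commit them, and session.rollback() cannot undo data that has already been committed.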

    1. To solve the problem, truncate all tables at the end of each test (the TRUNCATE ... RESTART IDENTITY CASCADE statement below is PostgreSQL syntax): https://stackoverflow.com/a/25220958/5521670
    import pytest
    from sqlalchemy import text
    from sqlalchemy.orm import sessionmaker

    from myapp.models import Base  # hypothetical path: import your project's declarative Base


    @pytest.fixture
    def session(engine):
        Session = sessionmaker(bind=engine)
        session = Session()
        yield session

        # Roll back and close first so this session holds no locks that could block TRUNCATE
        session.rollback()
        session.close()

        # Remove all data from the database (including data committed by other sessions)
        table_names = ", ".join(table.name for table in reversed(Base.metadata.sorted_tables))
        with engine.begin() as connection:  # commits the TRUNCATE on exit
            connection.execute(text(f"TRUNCATE TABLE {table_names} RESTART IDENTITY CASCADE;"))
    
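    2. Alternatively, as the FastAPI documentation suggests (https://fastapi.tiangolo.com/advanced/testing-database/), override the app's database dependency so the server gets its session from your test setup. For rollback-based cleanup to work, the override must give the server the same session (or connection) that the test uses.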
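    # TestingSessionLocal, get_db and app come from your project (see the FastAPI testing guide)
    def override_get_db():
        db = TestingSessionLocal()  # created outside try so finally never sees an unbound name
        try:
            yield db
        finally:
            db.close()


    app.dependency_overrides[get_db] = override_get_db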