I have a FastAPI + SQLAlchemy project and I'm using Pytest for writing unit tests for the APIs.
In each test function, I use SQLAlchemy to create some data in several tables (user table, post table, comment table, etc.). That data stays in the tables after the test function finishes and affects the other test functions.
For example, in the first test function I create 3 posts and 2 users. In the second test function, those 3 posts and 2 users are still in the tables, which breaks my test expectations.
This is my pytest fixture:
```python
import pytest
from sqlalchemy.orm import sessionmaker


@pytest.fixture
def session(engine):
    Session = sessionmaker(bind=engine)
    session = Session()
    yield session
    session.rollback()  # Removes data created in each test method
    session.close()  # Close the session after each test
```
I used `session.rollback()` to remove all data created during the session, but it doesn't remove the data.
And these are my test functions:
```python
class TestAllPosts(PostBaseTestCase):
    def create_logged_in_user(self, db):
        user = self.create_user(db)
        return user.generate_tokens()["access"]

    def test_can_api_return_all_posts_without_query_parameters(self, client, session):
        posts_count = 5
        user_token = self.create_logged_in_user(session)
        for _ in range(posts_count):
            self.create_post(session)

        response = client.get(url, headers={"Authorization": f"Bearer {user_token}"})
        assert response.status_code == 200
        json_response = response.json()
        assert len(json_response) == posts_count

    def test_can_api_detect_there_is_no_post(self, client, session):
        user_token = self.create_logged_in_user(session)
        response = client.get(url, headers={"Authorization": f"Bearer {user_token}"})
        assert response.status_code == 404
```
In the last test function, instead of a 404 I get a 200 with 5 posts (created by the previous test function).
How can I remove the data created in each test function after that test function finishes?
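The problem is that there are multiple sessions: one is used by your tests, and the other(s) by the server.
Because you are using `client.get`, you are sending a request to the server, which uses its own database session. For that session to return the posts your test created, your helpers (such as `create_post`) must already have committed them, and `session.rollback()` only discards changes that have not been committed yet. So the rollback in your fixture has nothing left to undo, and the tables have to be cleaned explicitly after each test.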
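A minimal sketch reproduces both halves of the problem. It assumes a hypothetical `Post` model and a throwaway SQLite file, so each session gets its own connection. A second "server" session sees the committed rows, and rolling back the test session afterwards leaves them in place:

```python
import os
import tempfile

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Post(Base):
    # Hypothetical minimal model standing in for the question's post table
    __tablename__ = "posts"
    id = Column(Integer, primary_key=True)
    title = Column(String(100))


def demo():
    # File-based SQLite (not :memory:) so each session uses its own connection
    db_path = os.path.join(tempfile.mkdtemp(), "demo.db")
    engine = create_engine(f"sqlite:///{db_path}")
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)

    # The "test" session creates rows and commits them, which a helper such
    # as create_post() must do for the API's session to see them
    test_session = Session()
    for i in range(5):
        test_session.add(Post(title=f"post {i}"))
    test_session.commit()

    # The "server" session (what get_db hands to the endpoint) sees the rows
    server_session = Session()
    seen_by_server = server_session.query(Post).count()
    server_session.close()

    # rollback() only discards uncommitted work, so nothing is removed here
    test_session.rollback()
    test_session.close()

    check_session = Session()
    remaining = check_session.query(Post).count()
    check_session.close()
    engine.dispose()
    return seen_by_server, remaining


print(demo())  # (5, 5)
```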
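```python
import pytest
from sqlalchemy import text
from sqlalchemy.orm import sessionmaker

# `engine` is your test-database engine fixture and `Base` your declarative base


@pytest.fixture
def session(engine):
    Session = sessionmaker(bind=engine)
    session = Session()
    yield session
    session.rollback()  # Discard anything this session has not committed
    session.close()  # Close first: an open transaction could block TRUNCATE

    # Remove all data from the database (even data committed by other sessions).
    # TRUNCATE ... RESTART IDENTITY CASCADE is PostgreSQL syntax.
    table_names = ", ".join(table.name for table in reversed(Base.metadata.sorted_tables))
    with engine.begin() as connection:  # Commits when the block exits without error
        connection.execute(text(f"TRUNCATE TABLE {table_names} RESTART IDENTITY CASCADE;"))
```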
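If you are not on PostgreSQL, or want something database-agnostic, you can `DELETE` every table's rows instead, children before parents. This is a swapped-in alternative to the `TRUNCATE` above. The models (`User`, `Post`), the helpers `clear_all_tables` and `row_counts`, and the temporary SQLite file are all hypothetical, used only to keep the sketch self-contained:

```python
import os
import tempfile

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, func, select
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class User(Base):  # hypothetical model
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String(50))


class Post(Base):  # hypothetical model with a foreign key to users
    __tablename__ = "posts"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))


def clear_all_tables(engine, metadata):
    """Delete every row from every table, children before parents."""
    # sorted_tables lists parents first, so reverse it to respect foreign keys
    with engine.begin() as connection:  # commits on successful exit
        for table in reversed(metadata.sorted_tables):
            connection.execute(table.delete())


def row_counts(engine, metadata):
    """Return {table name: number of rows} for every table."""
    with engine.connect() as connection:
        return {
            table.name: connection.execute(select(func.count()).select_from(table)).scalar_one()
            for table in metadata.sorted_tables
        }


db_path = os.path.join(tempfile.mkdtemp(), "cleanup.db")
engine = create_engine(f"sqlite:///{db_path}")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

with Session() as session:
    user = User(name="alice")
    session.add(user)
    session.flush()  # assigns user.id
    session.add_all([Post(user_id=user.id) for _ in range(3)])
    session.commit()

print(row_counts(engine, Base.metadata))  # {'users': 1, 'posts': 3}
clear_all_tables(engine, Base.metadata)
print(row_counts(engine, Base.metadata))  # {'users': 0, 'posts': 0}
```

In the fixture, you would call `clear_all_tables(engine, Base.metadata)` after `session.close()` in place of the `TRUNCATE`. Unlike `TRUNCATE ... RESTART IDENTITY`, a plain `DELETE` does not reset PostgreSQL sequences, so generated IDs keep increasing across tests.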
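Also make sure the app itself uses the test database by overriding its `get_db` dependency:

```python
def override_get_db():
    db = TestingSessionLocal()  # a sessionmaker bound to the test database engine
    try:
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db
```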