Tags: python, unit-testing, python-unittest, database-testing

Unit testing a function that depends on database


I am writing tests for some functions, one of which runs database queries. The blogs and docs I have read say to create an in-memory or test database for functions like this. Here is my function:

def already_exists(story_data, c):
    # TODO(salmanhaseeb): Implement de-dupe functionality by checking if it already
    # exists in the DB.
    c.execute("""SELECT COUNT(*) from posts where post_id = ?""", (story_data.post_id,))
    (number_of_rows,) = c.fetchone()
    if number_of_rows > 0:
        return True
    return False

In normal use this function hits the production database. In my tests I can create an in-memory database and populate it with test values, but when a test calls already_exists(), the production database is queried instead. How do I make the function query the test database while it is under test?


Solution

  • The issue is ensuring that your code consistently uses the same database connection. Then you can set it once to whatever is appropriate for the current environment.

    Rather than passing the database connection around from method to method, it might make more sense to make it a singleton.

    def already_exists(story_data):
        # Here `connection` is a singleton which returns the database connection.
        # execute() returns a cursor (as in sqlite3); rows are fetched from that cursor.
        cursor = connection.execute("""SELECT COUNT(*) from posts where post_id = ?""", (story_data.post_id,))
        (number_of_rows,) = cursor.fetchone()
        if number_of_rows > 0:
            return True
        return False
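
To make the "set it once per environment" idea concrete, here is a minimal sketch using the standard-library sqlite3 module. The names `set_connection`, `get_connection`, and `Story` are hypothetical, and the `posts` schema is only guessed from the query in the question. A test installs an in-memory database before calling the function, so production is never touched.

```python
import sqlite3
from collections import namedtuple

# Hypothetical stand-in for whatever object carries `post_id` in the real code.
Story = namedtuple("Story", ["post_id"])

_connection = None  # module-level holder: one shared connection per process


def set_connection(connection):
    """Install the connection the rest of the code should use."""
    global _connection
    _connection = connection


def get_connection():
    """Return the shared connection, failing loudly if none was installed."""
    if _connection is None:
        raise RuntimeError("call set_connection() first")
    return _connection


def already_exists(story_data):
    # sqlite3's Connection.execute() returns a cursor to fetch from.
    cursor = get_connection().execute(
        "SELECT COUNT(*) FROM posts WHERE post_id = ?", (story_data.post_id,)
    )
    (number_of_rows,) = cursor.fetchone()
    return number_of_rows > 0


# In a test: point the code at a throwaway in-memory database.
# The column type is an assumption; only the table and column names come from the question.
test_db = sqlite3.connect(":memory:")
test_db.execute("CREATE TABLE posts (post_id TEXT PRIMARY KEY)")
test_db.execute("INSERT INTO posts (post_id) VALUES (?)", ("abc123",))
set_connection(test_db)

found = already_exists(Story(post_id="abc123"))
missing = already_exists(Story(post_id="zzz999"))
```

Production start-up code would call `set_connection()` with the real connection instead, and `already_exists` stays unchanged in both environments.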
    

    Or make the connection an attribute of each class and turn already_exists into a method. It should probably be a method regardless.

    def already_exists(self):
        # Here the connection is associated with the object.
        # execute() returns a cursor (as in sqlite3); rows are fetched from that cursor.
        cursor = self.connection.execute("""SELECT COUNT(*) from posts where post_id = ?""", (self.post_id,))
        (number_of_rows,) = cursor.fetchone()
        if number_of_rows > 0:
            return True
        return False
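
With the connection held by the object, a unittest test case can hand each object a fresh in-memory database. The sketch below assumes sqlite3 and a hypothetical `Post` class; only `post_id`, `connection`, and the `posts` table name come from the code above, and `run_tests` is just a helper for running the suite programmatically.

```python
import sqlite3
import unittest


class Post:
    """Hypothetical model class that carries its own database connection."""

    def __init__(self, post_id, connection):
        self.post_id = post_id
        self.connection = connection

    def already_exists(self):
        cursor = self.connection.execute(
            "SELECT COUNT(*) FROM posts WHERE post_id = ?", (self.post_id,)
        )
        (number_of_rows,) = cursor.fetchone()
        return number_of_rows > 0


class AlreadyExistsTest(unittest.TestCase):
    def setUp(self):
        # A fresh in-memory database per test, so tests cannot affect each other.
        self.connection = sqlite3.connect(":memory:")
        self.connection.execute("CREATE TABLE posts (post_id TEXT PRIMARY KEY)")
        self.connection.execute("INSERT INTO posts (post_id) VALUES ('abc123')")

    def tearDown(self):
        self.connection.close()

    def test_existing_post_is_found(self):
        self.assertTrue(Post("abc123", self.connection).already_exists())

    def test_unknown_post_is_not_found(self):
        self.assertFalse(Post("zzz999", self.connection).already_exists())


def run_tests():
    """Run the test case and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(AlreadyExistsTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

In a real project the test class would live in its own test module and be discovered with `python -m unittest`.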
    

    But really you shouldn't be writing this code yourself. Instead, use an ORM such as SQLAlchemy, which takes care of basic queries and connection management like this for you. All database access goes through a single object, the "session", which manages the underlying connection.

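    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    # `sqlalchemy_declarative` is assumed to be your own module that defines
    # the declarative Base and the mapped model classes.
    from sqlalchemy_declarative import Address, Base, Person

    # Point this URL at a different database (for example in-memory SQLite) in tests.
    engine = create_engine('sqlite:///sqlalchemy_example.db')

    # The session factory is bound to the engine directly, so the older
    # `Base.metadata.bind = engine` (removed in SQLAlchemy 2.0) is not needed.
    DBSession = sessionmaker(bind=engine)
    session = DBSession()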
    

    Then you use the session to make queries. For example, a query has an exists() method.

    q = session.query(Post).filter(Post.post_id == story_data.post_id)  # Post: your mapped class
    session.query(q.exists()).scalar()
    

    Using an ORM will greatly simplify your code. Here's a short tutorial for the basics, and a longer and more complete tutorial.
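
As a rough sketch of how this helps with testing, the example below runs the exists check against an in-memory SQLite engine. It assumes SQLAlchemy 1.4 or later; the `Post` mapping, its columns, and the `already_exists(session, post_id)` signature are made up for illustration and are not taken from the original code.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Post(Base):
    # Hypothetical mapping; only the table name and `post_id` come from the question.
    __tablename__ = "posts"
    id = Column(Integer, primary_key=True)
    post_id = Column(String, unique=True, nullable=False)


def already_exists(session, post_id):
    """Return True if a post with this post_id is present in the session's database."""
    q = session.query(Post).filter(Post.post_id == post_id)
    return bool(session.query(q.exists()).scalar())


# Test setup: an in-memory SQLite database instead of the production URL.
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add(Post(post_id="abc123"))
session.commit()

found = already_exists(session, "abc123")
missing = already_exists(session, "zzz999")
```

Because only the engine URL differs between environments, the same `already_exists` code runs unchanged against production and against the test database.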