Search code examples
pythonunit-testingpytestsqlobject

Mocking sqlobject function call for test db


I am trying to mock sqlbuilder.func for test cases with pytest

I successfully mocked sqlbuilder.func.TO_BASE64 with correct output but when I tried mocking sqlbuilder.func.FROM_UNIXTIME I didn't get any error but the resulted output is incorrect with the generated query. Below is the minimal working example of the problem.

models.py

from sqlobject import (
    sqlbuilder,
    sqlhub,
    SQLObject,
    StringCol,
    BLOBCol,
    TimestampCol,
)

class Store(SQLObject):
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()

DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    query = sqlbuilder.Select([
            sqlbuilder.func.TO_BASE64(Store.q.sample),
        ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )

    connection = sqlhub.getConnection()
    query = connection.sqlrepr(query)
    print(query)
    queryResult = connection.queryAll(query)
    return queryResult

conftest.py

import pytest

from models import Store
from sqlobject import sqlhub
from sqlobject.sqlite import sqliteconnection

@pytest.fixture(autouse=True, scope="session")
def sqlite_db_session(tmpdir_factory):
    file = tmpdir_factory.mktemp("db").join("sqlite.db")
    conn = sqliteconnection.SQLiteConnection(str(file))
    sqlhub.processConnection = conn
    init_tables()
    yield conn
    conn.close()

def init_tables():
    Store.createTable(ifNotExists=True)

test_ex1.py

import pytest

from sqlobject import sqlbuilder
from models import retrieve

try:
    import mock
    from mock import MagicMock
except ImportError:
    from unittest import mock
    from unittest.mock import MagicMock

def TO_BASE64(x):
    return x

def FROM_UNIXTIME(x, y):
    return 'strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x)

# @mock.patch("sqlobject.sqlbuilder.func.TO_BASE64")
# @mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", MagicMock(side_effect=lambda x: x))
# @mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", new_callable=MagicMock(side_effect=lambda x: x))
@mock.patch("sqlobject.sqlbuilder.func.TO_BASE64", TO_BASE64)
@mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME", FROM_UNIXTIME)
def test_retrieve():
    result = retrieve('Some')
    assert result == []

Current SQL:

SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND (1))

Expected SQL:

SELECT
  store.sample
FROM 
  store
WHERE
  store.name = 'Some'
AND
  strftime(
    '%Y%m%d',
    datetime(store.created_at, 'unixepoch', 'localtime')
  ) >= strftime(
    '%Y%m%d',
    datetime('2018-10-12', 'unixepoch', 'localtime')
  )

Edit Example

#! /usr/bin/env python

from sqlobject import *

__connection__ = "sqlite:/:memory:?debug=1&debugOutput=1"

try:
    import mock
    from mock import MagicMock
except ImportError:
    from unittest import mock
    from unittest.mock import MagicMock

class Store(SQLObject):
    name = StringCol()
    sample = BLOBCol()
    createdAt = TimestampCol()

Store.createTable()

DATE_FORMAT = "%Y-%m-%d"
def retrieve(name):
    query = sqlbuilder.Select([
            sqlbuilder.func.TO_BASE64(Store.q.sample),
        ],
        sqlbuilder.AND(
            Store.q.name == name,
            sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
        )
    )

    connection = Store._connection
    query = connection.sqlrepr(query)
    queryResult = connection.queryAll(query)
    return queryResult


def TO_BASE64(x):
    return x

def FROM_UNIXTIME(x, y):
    return 'strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x)

for p in [
    mock.patch("sqlobject.sqlbuilder.func.TO_BASE64",TO_BASE64),
    mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME",FROM_UNIXTIME),
]:
    p.start()

retrieve('Some')

mock.patch.stopall()

Solution

  • By default, sqlbuilder.func is an SQLExpression that passes its attribute (sqlbuilder.func.datetime, e.g.) to the SQL backend as a constant (sqlbuilder.func actually is an alias for sqlbuilder.ConstantSpace). See the docs about SQLExpression, the FAQ and the code for func.

    When you mock an attribute in func namespace it's evaluated by SQLObject and passed to the backend in reduced form. If you want to return a string literal from the mocking function you need to tell SQLObject it's a value that has to be passed to the backend as is, unevaluated. The way to do it is to wrap the literal in SQLConstant like this:

    def FROM_UNIXTIME(x, y):
        return sqlbuilder.SQLConstant('strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x))
    

    See SQLConstant.

    The entire test script now looks this

    #! /usr/bin/env python3.7
    
    from sqlobject import *
    
    __connection__ = "sqlite:/:memory:?debug=1&debugOutput=1"
    
    try:
        import mock
        from mock import MagicMock
    except ImportError:
        from unittest import mock
        from unittest.mock import MagicMock
    
    class Store(SQLObject):
        name = StringCol()
        sample = BLOBCol()
        createdAt = TimestampCol()
    
    Store.createTable()
    
    DATE_FORMAT = "%Y-%m-%d"
    def retrieve(name):
        query = sqlbuilder.Select([
                sqlbuilder.func.TO_BASE64(Store.q.sample),
            ],
            sqlbuilder.AND(
                Store.q.name == name,
                sqlbuilder.func.FROM_UNIXTIME(Store.q.createdAt, DATE_FORMAT) >= sqlbuilder.func.FROM_UNIXTIME("2018-10-12", DATE_FORMAT)
            )
        )
    
        connection = Store._connection
        query = connection.sqlrepr(query)
        queryResult = connection.queryAll(query)
        return queryResult
    
    
    def TO_BASE64(x):
        return x
    
    def FROM_UNIXTIME(x, y):
        return sqlbuilder.SQLConstant('strftime("%Y%m%d", datetime({},"unixepoch", "localtime"))'.format(x))
    
    for p in [
        mock.patch("sqlobject.sqlbuilder.func.TO_BASE64",TO_BASE64),
        mock.patch("sqlobject.sqlbuilder.func.FROM_UNIXTIME",FROM_UNIXTIME),
    ]:
        p.start()
    
    retrieve('Some')
    
    mock.patch.stopall()
    

    The output is:

     1/Query   :  CREATE TABLE store (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT,
        sample TEXT,
        created_at TIMESTAMP
    )
     1/QueryR  :  CREATE TABLE store (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        name TEXT,
        sample TEXT,
        created_at TIMESTAMP
    )
     2/QueryAll:  SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
     2/QueryR  :  SELECT store.sample FROM store WHERE (((store.name) = ('Some')) AND ((strftime("%Y%m%d", datetime(store.created_at,"unixepoch", "localtime"))) >= (strftime("%Y%m%d", datetime(2018-10-12,"unixepoch", "localtime")))))
     2/QueryAll-> []
    

    PS. Full disclosure: I'm the current maintainer of SQLObject.