
Function wrappers inside class using decorators


I have a class that interacts with a database, so every member method repeats the same actions: establish a session before the real work, then commit and close the session after it.

As follows:

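from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# User is the application's SQLAlchemy model, defined elsewhere

class UserDatabaseManager(object):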

    DEFAULT_DB_PATH = 'test.db'

    def __init__(self, dbpath=DEFAULT_DB_PATH):
        dbpath = 'sqlite:///' + dbpath
        self.engine = create_engine(dbpath, echo=True)

    def add_user(self, username, password):
        Session = sessionmaker(bind=self.engine)
        session = Session()
        # <============================== To be wrapped
        user = User(username, password)
        session.add(user)
        # ==============================>
        session.commit()
        session.close()

    def delete_user(self, user):
        Session = sessionmaker(bind=self.engine)
        session = Session()
        # <============================== To be wrapped
        session.delete(user)
        # ==============================>
        session.commit()
        session.close()

What is an idiomatic way to abstract out the repeated session calls with a function wrapper?

I would prefer to do this with decorators, by declaring a private _Decorators class inside UserDatabaseManager and implementing the wrapper function there, but then that class won't be able to access the outer class's self.engine instance attribute.
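That restriction only applies at decoration time. A decorator runs once, while the class body is being executed, but the wrapper it returns is called with the instance as its first argument, so it can read self.engine (or any other attribute) on every call. A nested class is not needed: a plain function defined in the class body works. The sketch below is hypothetical: FakeSession, session_factory and _with_session are illustrative stand-ins for the SQLAlchemy objects so that it runs without a database.

```python
import functools


class FakeSession:
    """Hypothetical stand-in for an SQLAlchemy Session; it only records calls."""

    def __init__(self, log):
        self.log = log

    def add(self, obj):
        self.log.append(("add", obj))

    def commit(self):
        self.log.append("commit")

    def rollback(self):
        self.log.append("rollback")

    def close(self):
        self.log.append("close")


class UserDatabaseManager:
    def __init__(self):
        self.log = []
        # Hypothetical factory; real code would use sessionmaker(bind=self.engine)
        self.session_factory = lambda: FakeSession(self.log)

    def _with_session(method):
        # Executed once, while the class body runs: no instance exists yet.
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            # Executed on every call: `self` is the instance, so its
            # attributes (here session_factory) are available.
            session = self.session_factory()
            try:
                result = method(self, session, *args, **kwargs)
                session.commit()
                return result
            except Exception:
                session.rollback()
                raise
            finally:
                session.close()
        return wrapper

    @_with_session
    def add_user(self, session, username):
        # The wrapper injects `session`; callers pass only `username`.
        session.add(username)

    # Remove the helper from the class namespace once it has been applied.
    del _with_session


db = UserDatabaseManager()
db.add_user("alice")
print(db.log)  # [('add', 'alice'), 'commit', 'close']
```

The cost of this style is that every wrapped method has to accept the extra session parameter, whereas a context manager leaves method signatures unchanged.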


Solution

  • A simple (and in my opinion, the most idiomatic) way of doing this is to wrap the setup/teardown boilerplate in a context manager built with contextlib.contextmanager. Each method that does the work then simply uses a with statement, rather than trying to wrap the method itself.

    For example:

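    from contextlib import contextmanager

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    class UserDatabaseManager(object):

        DEFAULT_DB_PATH = 'test.db'

        def __init__(self, dbpath=DEFAULT_DB_PATH):
            dbpath = 'sqlite:///' + dbpath
            self.engine = create_engine(dbpath, echo=True)
            # Build the session factory once instead of on every call
            self.Session = sessionmaker(bind=self.engine)

        @contextmanager
        def session(self):
            """Yield a session; commit on success, roll back and re-raise on error."""
            # Created outside the try so `session` is always bound in except/finally
            session = self.Session()
            try:
                yield session
                session.commit()
            except Exception:
                session.rollback()
                raise  # don't silently swallow the original error
            finally:
                session.close()

        def add_user(self, username, password):
            # User is the application's SQLAlchemy model, defined elsewhere
            with self.session() as session:
                user = User(username, password)
                session.add(user)

        def delete_user(self, user):
            with self.session() as session:
                session.delete(user)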
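The same context-manager shape can be exercised without SQLAlchemy. The following is a minimal sketch that assumes the standard-library sqlite3 module as a stand-in and a hypothetical users table; the names transaction, add_users and count_users are illustrative and not part of the code above. It also shows why the except branch should re-raise: a failed batch is rolled back as a whole, and the caller still sees the error.

```python
import sqlite3
from contextlib import contextmanager


class UserDatabaseManager:
    # Hypothetical sqlite3 stand-in for the SQLAlchemy version;
    # the `users` table and its columns are assumptions for illustration.

    def __init__(self, dbpath=":memory:"):
        # One connection per manager; an in-memory DB lives as long as it does.
        self.conn = sqlite3.connect(dbpath)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS users (username TEXT PRIMARY KEY, password TEXT)"
        )
        self.conn.commit()

    @contextmanager
    def transaction(self):
        # Setup: a fresh cursor for the with-block.
        cur = self.conn.cursor()
        try:
            yield cur
            self.conn.commit()    # block finished normally
        except Exception:
            self.conn.rollback()  # undo everything the block did...
            raise                 # ...and let the caller see the error
        finally:
            cur.close()           # teardown runs on both paths

    def add_users(self, pairs):
        with self.transaction() as cur:
            cur.executemany("INSERT INTO users VALUES (?, ?)", pairs)

    def count_users(self):
        with self.transaction() as cur:
            return cur.execute("SELECT COUNT(*) FROM users").fetchone()[0]


db = UserDatabaseManager()
db.add_users([("alice", "pw1")])
try:
    # "alice" already exists, so this batch fails and "bob" is rolled back too.
    db.add_users([("bob", "pw2"), ("alice", "pw3")])
except sqlite3.IntegrityError as exc:
    print("rolled back:", exc)
print(db.count_users())  # 1
```

With the original bare `except:` and no `raise`, the IntegrityError would be swallowed inside the context manager, and add_users would appear to succeed even though nothing was written.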