Search code examples
python flask decorator

Passing "self" to a decorator in a class


I am trying to imitate the behavior of flask's decorator, like so:

app = Flask(__name__)

@app.route("/")
...

However, I cannot seem to pass "self". Here's an example:


class A:
    # Minimal reproduction: `decorator` is meant to be applied to a function
    # as `@a.decorator("text")`.
    def __init__(self):
        pass

    def very_important_function(self, text):
        # processing goes here
        # (placeholder: the real body is omitted in this example)
        
    # NOTE(review): this method takes `func` and `text` in the same call, so
    # it is an ordinary two-argument method rather than a decorator factory.
    # When called as a.decorator("text"), `self` is bound to `a` and "text"
    # fills `func`, leaving `text` unsupplied.
    def decorator(self, func, text):
        def inner():
            # Refuse to run func when the check returns a falsy value.
            if not self.very_important_function(text):
                return "Fail!"
            # func's return value is discarded here, so inner() returns None
            # on the success path.
            func()
        return inner
 
a = A()

# `a.decorator("text")` is evaluated first; whatever that call returns is
# then called with `check` as its only argument.
@a.decorator("text")
def check():
    return True

check()

Running check() seems to output the following error:

A.decorator.<locals>.inner() takes 0 positional arguments but 1 was given

Quick research shows that I should use functools.wraps() to decorate inner() like so:

    def decorator(self, func, text):
        # wraps(func) copies func's metadata (e.g. __name__, __doc__) onto
        # inner; it does not change the parameters decorator() itself takes.
        @functools.wraps(func)
        def inner():
            if not self.very_important_function(text):
                return "Fail!"
            func()
        return inner

This time, I get the error:

A.decorator() missing 1 required positional argument: 'text'

This is better, but it shows that it is likely interpreting the "text" as the self argument!

At this point, I jumped into flask's source code to try and find the code for Flask().route(). However, it does not seem to be anywhere in the Flask class.

After doing some research, I stumbled across bound methods. Is there a way to bind a decorator function inside a class to the class, and how do I access this bound object?

P.S. I wish to mimic the app.route() decorator exactly, so that the user does not have to pass the object in as an argument. The decorator must also be available as a method of the instance, the same way app.route() is called via app, so decorators defined outside the class are not viable unless they can be "injected" into the class.

Edit:

import time
import flask
import datetime
import psycopg2
from uuid import uuid4
from errors import Auth
from hashlib import sha256
from functools import wraps


class AuthenticationManager:
    """Cookie-based session authentication on top of a database connection.

    The queries below use a ``sessions`` table (id, token, expiry) and a
    ``users`` table (id, username, hash, groups). The session token is handed
    to the client in the ``auth`` cookie and looked up again on each check.

    NOTE(review): no method here commits; this assumes the caller (or an
    autocommit connection) handles transactions -- TODO confirm.
    """

    def __init__(self, conn, uauth_response: flask.Response | None = None):
        """Store the connection and the optional default "unauthorised" response.

        Args:
            conn: Connection object providing ``cursor()``.
            uauth_response: Response returned by ``login_required`` wrappers
                when access is denied and no per-decorator response was given.
        """
        self.db_conn = conn
        self.uauth_response = uauth_response

    def _sha256hash(self, data: str) -> str:
        """Return the hex SHA-256 digest of *data* encoded as UTF-8."""
        # NOTE(review): a single unsalted SHA-256 pass is not a password
        # hashing scheme (see hashlib.scrypt / hashlib.pbkdf2_hmac). Left
        # unchanged so hashes that are already stored keep matching.
        return sha256(data.encode("UTF-8")).hexdigest()

    def create_session(self, uid: str, response: flask.Response) -> flask.Response:
        """Insert a 30-day session for *uid* and set the ``auth`` cookie.

        Args:
            uid: Id of the user the session belongs to.
            response: Response object the cookie is attached to.

        Returns:
            The same *response*, with the ``auth`` cookie set.
        """
        cur = self.db_conn.cursor()
        # Kept as text: the cookie value has to be a string, and check_auth()
        # rejects any token that is not a str.
        token = str(uuid4())
        # Timezone-aware so that both the stored timestamp and the cookie's
        # expiry refer to the same instant.
        expire_date = datetime.datetime.now(
            datetime.timezone.utc
        ) + datetime.timedelta(days=30)
        cur.execute(
            "INSERT INTO sessions (id, token, expiry) VALUES (%s, %s, %s)",
            # Stored as a Unix timestamp because login() and check_auth()
            # compare it with int(time.time()); int() on a datetime raises
            # TypeError.
            (uid, token, int(expire_date.timestamp())),
        )
        response.set_cookie("auth", token, expires=expire_date)
        return response

    def register(self, username: str, password: str, groups: list[str]) -> None:
        """Create a user row for *username*.

        Args:
            username: Name to register.
            password: Plaintext password; only its hash is stored.
            groups: Group names stored with the user.

        Raises:
            Auth.Registration.UserAlreadyExists: If the username is taken.
        """
        cur = self.db_conn.cursor()
        cur.execute("SELECT * FROM users WHERE username=%s", (username,))
        if cur.fetchone():
            raise Auth.Registration.UserAlreadyExists
        cur.execute(
            "INSERT INTO users (id, username, hash, groups) VALUES (%s, %s, %s, %s)",
            # The hash (not the plaintext password) goes into the `hash`
            # column; login() compares against the hashed value.
            (str(uuid4()), username, self._sha256hash(password), groups),
        )

    def login(
        self, username: str, password: str, response: flask.Response
    ) -> flask.Response:
        """Verify the credentials and start a new session.

        Args:
            username: Name of the user logging in.
            password: Plaintext password to verify.
            response: Response object the session cookie is attached to.

        Returns:
            *response* with the ``auth`` cookie set by ``create_session``.

        Raises:
            Auth.Login.AuthenticationFailure: If no user matches the
                username/password-hash pair.
        """
        cur = self.db_conn.cursor()
        phash = self._sha256hash(password)
        cur.execute(
            "SELECT * FROM users WHERE username=%s AND hash=%s", (username, phash)
        )
        user = cur.fetchone()
        if user is None:
            raise Auth.Login.AuthenticationFailure

        # Column 0 is used as the user id and column 2 of a session row as
        # its expiry -- presumably matching the table column order; verify
        # against the schema.
        uid = user[0]
        cur.execute("SELECT * FROM sessions WHERE id=%s", (uid,))
        session = cur.fetchone()
        if session is not None and session[2] < int(time.time()):
            # Drop the user's expired session(s) before issuing a new one.
            cur.execute("DELETE FROM sessions WHERE id=%s", (uid,))

        return self.create_session(uid, response)

    def check_auth(self) -> bool:
        """Return True if the request's ``auth`` cookie names a live session."""
        token = flask.request.cookies.get("auth")
        if not isinstance(token, str):
            return False

        cur = self.db_conn.cursor()
        cur.execute("SELECT * FROM sessions WHERE token=%s", (token,))
        session = cur.fetchone()
        # Unknown token, or a session whose expiry is already in the past.
        return session is not None and session[2] >= int(time.time())

    def get_groups(self) -> list[str] | None:
        """Return the groups of the user owning the current session.

        Returns:
            The value of the user's ``groups`` column, or None when the
            request has no ``auth`` cookie, the token is unknown, or the
            user row is missing.
        """
        token = flask.request.cookies.get("auth")
        if not isinstance(token, str):
            return None

        cur = self.db_conn.cursor()
        # execute() and fetchone() are separate calls: psycopg2's execute()
        # returns None, so chaining .fetchone() on it fails.
        cur.execute("SELECT id FROM sessions WHERE token=%s", (token,))
        session = cur.fetchone()
        if session is None:
            return None

        cur.execute("SELECT groups FROM users WHERE id=%s", (session[0],))
        user = cur.fetchone()
        return None if user is None else user[0]

    def login_required(
        self,
        func,
        uauth_response: flask.Response | None = None,
        groups: list[str] | None = None,
    ):
        """Return a decorator that only runs the wrapped view when authorised.

        Args:
            func: Not used by the body; the function to protect is received
                by the returned decorator instead.
                NOTE(review): because this is still a required positional
                parameter, ``manager.login_required(groups=[...])`` raises
                ``TypeError: ... missing 1 required positional argument:
                'func'``. The signature is left unchanged here; removing the
                parameter (or giving it a default) is what makes the method
                usable as ``@manager.login_required(...)``.
            uauth_response: Response returned when access is denied. Falls
                back to ``self.uauth_response`` and then to a plain 401.
            groups: If given, the user must belong to every listed group.

        Returns:
            A decorator that wraps a view function.
        """

        def decorator(view):
            @wraps(view)  # keep the view's __name__ on the wrapper
            def inner(*args, **kwargs):
                # A separate local is used on purpose: assigning to
                # `uauth_response` here would make that name local to
                # inner() and reading it first raises UnboundLocalError.
                denied = uauth_response
                if denied is None:
                    denied = self.uauth_response
                if denied is None:
                    denied = flask.Response(
                        "You are not authorised to access this resource.", status=401
                    )

                if not self.check_auth():
                    return denied
                if groups:
                    # get_groups() may return None; treat that as "no groups".
                    user_groups = self.get_groups() or []
                    if any(group not in user_groups for group in groups):
                        return denied
                # Hand the view's result back to the caller.
                return view(*args, **kwargs)

            return inner

        return decorator

This is my attempt to implement the decorator factory. However, it still seems like func is being interpreted as self, as I get the following error:

TypeError: AuthenticationManager.login_required() missing 1 required positional argument: 'func'


Solution

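  • Decorators in Python are functions that accept a function and return a function. In @a.decorator("text"), the call a.decorator("text") runs first, and whatever it returns is what actually gets applied to check. So you need another level of nesting: a method that takes the parameters and returns the real decorator, so that a.decorator_factory("text") returns a function that works as a valid decorator: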

    class A:
        def __init__(self):
            pass

        def very_important_function(self, text):
            """Stand-in check: a falsy return value makes the wrapper fail."""
            return text

        def decorator_factory(self, text):
            """Return a decorator that guards a function with the check on *text*.

            The returned decorator wraps ``func`` so that each call first runs
            ``self.very_important_function(text)``. If that is falsy the
            wrapper returns ``"Fail!"``; otherwise it calls ``func`` with the
            original arguments and returns its result.
            """
            import functools

            def decorator(func):
                @functools.wraps(func)  # keep func's __name__ and docstring
                def inner(*args, **kwargs):
                    if not self.very_important_function(text):
                        return "Fail!"
                    # Return the result; without this the decorated function
                    # would always evaluate to None.
                    return func(*args, **kwargs)
                return inner
            return decorator
    

    This way calling A().decorator_factory("foo") gives us a decorator we can use to decorate another function, already containing all the parameters we wanted to use for this decorator.

    a = A()
    
    # The factory call runs first; the decorator it returns is then applied
    # to check, so the name check ends up bound to the wrapper (inner).
    @a.decorator_factory("text")
    def check():
        return True
    
    # Calls the wrapper, which runs very_important_function("text") before
    # calling the original check.
    check()
    

    When working with decorators, it's very helpful to realise that

    # decorator is called once, when foo is defined, with foo as its argument.
    @decorator
    def foo():
        return "bar"
    

    is pretty much equivalent to this, much more standard-looking syntax:

    def foo():
        return "bar"
    # Rebind the name foo to whatever decorator(foo) returns.
    foo = decorator(foo)