Tags: if-statement, decorator, python-decorators, flask-jwt-extended

How to choose which decorator to apply based on condition?


Can decorators be applied conditionally? Below is a trivial example:

import os


def bold_decorator(fn):
    def wrapper(*args, **kwargs):
        return '**' + fn(*args, **kwargs) + '**'
    return wrapper

def italic_decorator(fn):
    def wrapper(*args, **kwargs):
        return '_' + fn(*args, **kwargs) + '_'
    return wrapper


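# Intended behaviour (this is not valid Python: a decorator line cannot be
# placed inside an if/else block on its own, so this raises SyntaxError).
if os.environ.get('STYLE'):
    @bold_decorator
else:
    @italic_decorator
def create_text(text=''):
    return text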

if __name__ == '__main__':
    print(create_text('Decorator decisions'))

What I want is to apply bold_decorator when the environment variable is set, and italic_decorator when it is not. My real use case is flask-jwt-extended, which provides the decorators jwt_required and jwt_optional; I can't modify the source of these decorators. Any help would be appreciated.
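As background: a `@decorator` line above a `def` is shorthand for rebinding the function, `create_text = decorator(create_text)`, which is why an `if` cannot wrap the decorator line by itself. A minimal sketch, reusing the question's two decorators unchanged, that makes the same choice once, at definition time, with an ordinary conditional expression:

```python
import os


def bold_decorator(fn):
    def wrapper(*args, **kwargs):
        return '**' + fn(*args, **kwargs) + '**'
    return wrapper


def italic_decorator(fn):
    def wrapper(*args, **kwargs):
        return '_' + fn(*args, **kwargs) + '_'
    return wrapper


def create_text(text=''):
    return text


# Equivalent to decorating with @bold_decorator or @italic_decorator:
# the environment variable is read once, when this line runs.
create_text = (bold_decorator if os.environ.get('STYLE') else italic_decorator)(create_text)

if __name__ == '__main__':
    print(create_text('Decorator decisions'))
```

On Python 3.9 and later (PEP 614 relaxed the decorator grammar), the same conditional expression can also be written directly after the `@`, e.g. `@(bold_decorator if os.environ.get('STYLE') else italic_decorator)`.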


Solution

  • For the specific case of flask-jwt-extended, you can put the if/else logic inside your own custom decorator and apply that to your view functions. The full docs are here (https://flask-jwt-extended.readthedocs.io/en/stable/custom_decorators/), but it might look something like this:

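    import os
    from functools import wraps

    from flask_jwt_extended import (
        verify_jwt_in_request,
        verify_jwt_in_request_optional,  # flask-jwt-extended 3.x; removed in 4.x
    )

    def custom_jwt_required(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # The environment variable is checked on every request.
            if os.environ.get('ALLOW_OPTIONAL_JWT'):
                # In 4.x, use verify_jwt_in_request(optional=True) instead.
                verify_jwt_in_request_optional()
            else:
                verify_jwt_in_request()
            return fn(*args, **kwargs)
        return wrapper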
    

    For the more general case, you can choose which decorator to import based on the environment variable, something like:

    import os

    # 'decorators' is the module that defines bold_decorator and italic_decorator.
    if os.environ.get('STYLE'):
        from decorators import bold_decorator as decorator
    else:
        from decorators import italic_decorator as decorator

    @decorator
    def create_text(text=''):
        return text
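The import-level choice is made once, when the module loads. The flask-jwt-extended decorator takes a different approach: it checks the environment on each call. A minimal sketch of that call-time approach applied to the trivial bold/italic example, where `style_by_env` is a hypothetical name and `functools.wraps` has been added to the question's decorators so the wrapped function keeps its name:

```python
import os
from functools import wraps


def bold_decorator(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        return '**' + fn(*args, **kwargs) + '**'
    return wrapper


def italic_decorator(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        return '_' + fn(*args, **kwargs) + '_'
    return wrapper


def style_by_env(fn):
    """Hypothetical decorator: pick bold or italic on every call, based on STYLE."""
    # Build both decorated variants once, up front.
    bold = bold_decorator(fn)
    italic = italic_decorator(fn)

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Read the environment at call time, not at import time.
        chosen = bold if os.environ.get('STYLE') else italic
        return chosen(*args, **kwargs)
    return wrapper


@style_by_env
def create_text(text=''):
    return text
```

The call-time version follows changes to `STYLE` made after import, at the cost of an environment lookup per call. If the setting is fixed at startup, the import-time choice is simpler.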