Search code examples
pythondjangodecoratordry

How to apply the same decorator chain to multiple functions


@extend_schema(
    methods=['GET'],
    responses={(200, STYLES_MIME_TYPE): OpenApiTypes.BINARY})
@extend_schema(
    methods=['PUT'],
    request={STYLES_MIME_TYPE: OpenApiTypes.BINARY},
    responses={(204, 'application/json'): OpenApiResponse(
        response={'type': 'array', 'items': {'type': 'integer', 'format': 'int32'}},
        examples=[OpenApiExample(
            'Returned style IDs example',
            status_codes=['204'],
            value=[101, 102, 103])])})
@api_view(['GET', 'PUT'])
@permission_classes([IsAuthenticated | ReadOnly])
@renderer_classes([StylesRenderer, StylesJSONRenderer])
@parser_classes([StylesParser])
def styles(request: Request, pid: int) -> Response:
    """
    Get or save styles for the project whose primary key is ``pid``.
    GET - protobuf binary response
    PUT - returns IDs for saved styles
    """
    # Only the lookup is guarded: a Project.DoesNotExist raised while the
    # request is being handled must not be reported as "project not found".
    try:
        project = Project.objects.get(pk=pid)
    except Project.DoesNotExist:
        raise Http404()

    # The GET / PUT behavior itself lives in the shared helper.
    return _handle_styles(request, project)


@extend_schema(
    methods=['GET'],
    responses={(200, STYLES_MIME_TYPE): OpenApiTypes.BINARY})
@extend_schema(
    methods=['PUT'],
    request={STYLES_MIME_TYPE: OpenApiTypes.BINARY},
    responses={(204, 'application/json'): OpenApiResponse(
        response={'type': 'array', 'items': {'type': 'integer', 'format': 'int32'}},
        examples=[OpenApiExample(
            'Returned style IDs example',
            status_codes=['204'],
            value=[101, 102, 103])])})
@api_view(['GET', 'PUT'])
@permission_classes([IsAuthenticated | ReadOnly])
@renderer_classes([StylesRenderer, StylesJSONRenderer])
@parser_classes([StylesParser])
def styles_xref(request: Request, xref: uuid.UUID) -> Response:
    """
    Get or save styles for the project whose ``xref`` field equals ``xref``.
    GET - protobuf binary response
    PUT - returns IDs for saved styles
    """
    # Only the lookup is guarded: a Project.DoesNotExist raised while the
    # request is being handled must not be reported as "project not found".
    try:
        project = Project.objects.get(xref=xref)
    except Project.DoesNotExist:
        raise Http404()

    # The GET / PUT behavior itself lives in the shared helper.
    return _handle_styles(request, project)

This is Django, and I obviously want to use the same decorators for both views. The only difference is that one looks up the object by its integer ID, and the other by its UUID `xref` field. How can I keep this DRY?


Solution

  • You could define a new decorator which returns a pre-decorated function with the chain you want. For example, we can first define three custom decorators:

    import functools
    
    
    # A decorator factory: builds a decorator bound to the given message.
    def decorator_factory(message):
        """Return a decorator that prints ``message`` before each call."""

        def decorator(function):

            # functools.wraps keeps the decorated function's name and docstring.
            @functools.wraps(function)
            def announce_then_call(*args, **kwargs):
                # Example behavior: emit the message, then delegate.
                print(message)
                result = function(*args, **kwargs)
                return result

            return announce_then_call

        return decorator
    
    
    # Builds the three example decorators, one per message.
    decorator_1, decorator_2, decorator_3 = map(
        decorator_factory, ("Ham", "Spam", "Eggs")
    )
    

    The way these decorators are presently invoked resembles the following, which quickly becomes repetitive for multiple functions:

    # Every function has to repeat the identical three-decorator stack.
    @decorator_1
    @decorator_2
    @decorator_3
    def f():
        return None  # Placeholder for real work.


    @decorator_1
    @decorator_2
    @decorator_3
    def g():
        return None  # Placeholder for real work.


    @decorator_1
    @decorator_2
    @decorator_3
    def h():
        return None  # Placeholder for real work.
    

    However, you can decorate a wrapper function within the body of a decorator:

    def decorator_chain(function):
        """Decorate ``function`` with decorator_1, decorator_2 and decorator_3 in one step."""

        # Pass-through wrapper that the three decorators are applied to.
        def wrapper(*args, **kwargs):
            return function(*args, **kwargs)

        # Innermost decorator first, mirroring a bottom-up "@" stack.
        chained = decorator_1(decorator_2(decorator_3(wrapper)))

        # Copy the metadata of ``function`` onto the outermost callable.
        return functools.wraps(function)(chained)
    

    Which simplifies the function definitions to:

    # One decorator now stands in for the whole stack.
    @decorator_chain
    def f():
        return None  # Placeholder for real work.


    @decorator_chain
    def g():
        return None  # Placeholder for real work.


    @decorator_chain
    def h():
        return None  # Placeholder for real work.
    

    In your provided example, this might look something like the following:

    import functools
    
    
    def decorator_chain(function):
        """
        Apply the shared schema / API-view decorator stack to ``function``.

        The decorators are applied to ``function`` itself, bottom-up, exactly
        as if they had been stacked above its ``def``. No pass-through wrapper
        is used: DRF's ``api_view`` copies ``__name__`` and ``__doc__`` from the
        function it receives onto the view class it generates, so decorating a
        generic wrapper would register every view as "wrapper" with no
        docstring.
        """
        # Listed top-to-bottom in the order they would appear as "@" lines.
        decorators = (
            extend_schema(
                methods=['GET'],
                responses={(200, STYLES_MIME_TYPE): OpenApiTypes.BINARY},
            ),
            extend_schema(
                methods=['PUT'],
                request={STYLES_MIME_TYPE: OpenApiTypes.BINARY},
                responses={
                    (204, 'application/json'): OpenApiResponse(
                        response={'type': 'array', 'items': {'type': 'integer', 'format': 'int32'}},
                        examples=[
                            OpenApiExample(
                                'Returned style IDs example',
                                status_codes=['204'],
                                value=[101, 102, 103],
                            ),
                        ],
                    ),
                },
            ),
            api_view(['GET', 'PUT']),
            permission_classes([IsAuthenticated | ReadOnly]),
            renderer_classes([StylesRenderer, StylesJSONRenderer]),
            parser_classes([StylesParser]),
        )

        # The decorator closest to the function (last in the tuple) runs first.
        decorated = function
        for decorator in reversed(decorators):
            decorated = decorator(decorated)

        return decorated
    
    
    @decorator_chain
    def styles(request: Request, pid: int) -> Response:
        """
        Get or save styles for the project whose primary key is ``pid``.
        GET - protobuf binary response
        PUT - returns IDs for saved styles
        """
        # Only the lookup is guarded: a Project.DoesNotExist raised while the
        # request is being handled must not be reported as "project not found".
        try:
            project = Project.objects.get(pk=pid)
        except Project.DoesNotExist:
            raise Http404()

        # The GET / PUT behavior itself lives in the shared helper.
        return _handle_styles(request, project)
    
    
    @decorator_chain
    def styles_xref(request: Request, xref: uuid.UUID) -> Response:
        """
        Get or save styles for the project whose ``xref`` field equals ``xref``.
        GET - protobuf binary response
        PUT - returns IDs for saved styles
        """
        # Only the lookup is guarded: a Project.DoesNotExist raised while the
        # request is being handled must not be reported as "project not found".
        try:
            project = Project.objects.get(xref=xref)
        except Project.DoesNotExist:
            raise Http404()

        # The GET / PUT behavior itself lives in the shared helper.
        return _handle_styles(request, project)
    

    Using a decorator factory could even allow you to quickly create different variants of a given chain of decorators.