Tags: python, overloading, decorator, typing

Can I use a python decorator to preprocess input and postprocess output based on an input type?


I have functions that can take either a List or a str. The return type should match the input type. E.g., if given a str the function should return a str.

This is a toy example to illustrate the point:

from typing import List, Union

def swap_first_and_last(a: Union[List, str]) -> Union[List, str]:
    # Not in-place.
    STR = isinstance(a, str)
    a = list(a)

    a[0], a[-1] = a[-1], a[0]

    return "".join(a) if STR else a
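The annotation Union[List, str] -> Union[List, str] does not tell a type checker that the output type follows the input type: as far as the checker knows, a str argument could come back as a list. A constrained TypeVar can express "same type in, same type out". Below is a minimal sketch for the toy example. The name S is arbitrary, and this is only about static typing; at runtime it behaves like the original.

```python
from typing import TypeVar

# S is either a list or a str; a checker such as mypy binds it to whichever
# type the caller passes, so the declared return type matches the argument.
S = TypeVar("S", list, str)


def swap_first_and_last(a: S) -> S:
    """Return a copy of `a` with its first and last elements swapped."""
    items = list(a)  # new list for both str and list input (not in-place)
    items[0], items[-1] = items[-1], items[0]
    if isinstance(a, str):
        return "".join(items)
    return items


word: str = swap_first_and_last("asd")         # checker should infer str
parts: list = swap_first_and_last(["x", "y"])  # checker should infer list
print(word, parts)
```

The isinstance check is still needed at runtime. The TypeVar only changes what a static checker reports at the call site.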

This is an actual example:

from typing import List, Optional, Union

def next_permutation(a: Union[List, str]) -> Optional[Union[List, str]]:
    """
    Not in-place.
    Returns `None` if there is no next permutation
    (i.e. if `a` is the last permutation).
    The type of the output is the same as the type of the input
    (i.e. str input -> str output).
    """
    STR = isinstance(a, str)
    a = list(a)

    N = len(a)
    i = next((i for i in reversed(range(N-1)) if a[i] < a[i + 1]), None)
    if i is None:
        return None
    # Strict `>`: with `>=`, [1, 2, 1] would wrongly yield [1, 1, 2] instead of [2, 1, 1].
    j = next(j for j in reversed(range(i+1, N)) if a[j] > a[i])

    a[i], a[j] = a[j], a[i]
    a[i + 1:] = reversed(a[i + 1:])

    return "".join(a) if STR else a

Only a few lines deal with str input versus List input:

    # preprocess
    STR = isinstance(a, str)
    a = list(a)
    
    # main logic
    ...
    
    # postprocess
    return "".join(a) if STR else a

Can I use a decorator to do this slight preprocessing and postprocessing?


Solution

Yes. You can use a decorator like this:

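    import functools
    from typing import List, Optional, Union
    
    def pre_and_post_processing(func):
        @functools.wraps(func)  # keep func's name and docstring on the wrapper
        def inner(a: Union[List, str]) -> Optional[Union[List, str]]:
            STR = isinstance(a, str)  # preprocess: remember the input type
            a = list(a)               # and always pass func a new list
            b = func(a)
            # postprocess: join back to str, but let None (e.g. "no next permutation") pass through
            return "".join(b) if STR and b is not None else b
        return inner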
    
    
    @pre_and_post_processing
    def swap_first_and_last(a: List) -> List:
        a[0], a[-1] = a[-1], a[0]
        return a
    
    
    print(swap_first_and_last("asd"))  # -> dsa
    print(swap_first_and_last(["asd", "ds", "sds"]))  # -> ['sds', 'ds', 'asd']
    

    Note that the decorated swap_first_and_last now only receives and returns a List. The decorator handles the conversion to and from str.
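The same pattern can wrap the question's next_permutation. Below is a minimal sketch. The decorator name list_or_str_io is hypothetical. It assumes the wrapped function returns either a list or None, passes None through instead of calling "".join(None), and uses functools.wraps so the wrapped function keeps its name and docstring. The core searches for j with a strict `>`, so inputs with repeated elements such as "aab" are stepped through correctly.

```python
import functools
from typing import Callable, List, Optional, Union


def list_or_str_io(func: Callable[[List[str]], Optional[List[str]]]):
    """Hypothetical decorator: turn str input into a list, and turn the list
    result back into a str when the input was a str. None passes through."""
    @functools.wraps(func)
    def inner(a: Union[List[str], str]) -> Optional[Union[List[str], str]]:
        is_str = isinstance(a, str)
        result = func(list(a))  # always hand the core a fresh list (not in-place)
        if is_str and result is not None:
            return "".join(result)
        return result
    return inner


@list_or_str_io
def next_permutation(a: List[str]) -> Optional[List[str]]:
    """Rearrange `a` into the next lexicographic permutation, or return None."""
    n = len(a)
    # Rightmost i with a[i] < a[i + 1]; if none, `a` is the last permutation.
    i = next((i for i in reversed(range(n - 1)) if a[i] < a[i + 1]), None)
    if i is None:
        return None
    # Rightmost j > i whose element is strictly greater than a[i].
    j = next(j for j in reversed(range(i + 1, n)) if a[j] > a[i])
    a[i], a[j] = a[j], a[i]
    a[i + 1:] = reversed(a[i + 1:])  # the suffix was descending; make it ascending
    return a


# Walk through every distinct permutation of "aab" in lexicographic order.
perm = "aab"
while perm is not None:
    print(perm)  # aab, then aba, then baa
    perm = next_permutation(perm)

print(next_permutation(["a", "b", "b"]))  # ['b', 'a', 'b']
```

Because the wrapper passes a copy to the core, list arguments are never modified in place, matching the question's "Not in-place" requirement.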