
Get function name when ContextDecorator is used as a decorator


I have the following context manager and decorator to time any given function or code block:

import time
from contextlib import ContextDecorator


class timer(ContextDecorator):
    def __init__(self, label: str):
        self.label = label

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        net_time = time.perf_counter() - self.start_time
        print(f"{self.label} took {net_time:.1f} seconds")
        return False

You can use it as a context manager:

with timer("my code block"):
    time.sleep(2)

# my code block took 2.0 seconds

You can also use it as a decorator:

@timer("my_func")
def my_func():
    time.sleep(3)

my_func()

# my_func took 3.0 seconds
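Roughly speaking, `@timer("my_func")` is shorthand for `my_func = timer("my_func")(my_func)`: the instance is created first and then called with the function. `ContextDecorator.__call__` handles that second step by wrapping the function so each call runs inside a `with` block. The sketch below is a simplified stand-in for that mechanism, not CPython's exact code (the real one also goes through a private `_recreate_cm()` hook). `SimpleContextDecorator` is a hypothetical name used only for illustration:

```python
import functools
import time


class SimpleContextDecorator:
    """Hypothetical, simplified stand-in for contextlib.ContextDecorator."""

    def __call__(self, func):
        @functools.wraps(func)  # keep func.__name__, __doc__, etc.
        def inner(*args, **kwargs):
            # Every call to the decorated function runs inside `with self:`.
            with self:
                return func(*args, **kwargs)
        return inner


class timer(SimpleContextDecorator):
    def __init__(self, label: str):
        self.label = label

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        net_time = time.perf_counter() - self.start_time
        print(f"{self.label} took {net_time:.1f} seconds")
        return False  # don't suppress exceptions


def my_func():
    return "done"


# Equivalent to decorating my_func with @timer("my_func").
my_func = timer("my_func")(my_func)
```

The key point is that `__call__` is the one place where the decorated function is handed to the timer object, which makes it the natural hook for reading `func.__name__`.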

The only thing I don't like is having to manually pass the function name as the label when it's used as a decorator. I would love for the decorator to automatically use the function name if no label is passed:

@timer()
def my_func():
    time.sleep(3)

my_func()

# my_func took 3.0 seconds

Is there any way to do this?


Solution

  • If you also override the __call__() method that your timer class inherits from ContextDecorator, and give the label argument a sentinel default of None, you can check for that sentinel and use the decorated function's __name__ when the instance is applied to it:

    import time
    from contextlib import ContextDecorator
    from typing import Optional
    
    
    class timer(ContextDecorator):
        def __init__(self, label: Optional[str] = None):
            self.label = label
    
        def __call__(self, func):
            if self.label is None:  # No label was provided...
                self.label = func.__name__  # ...so fall back to the function's name.
            return super().__call__(func)
    
        def __enter__(self):
            self.start_time = time.perf_counter()
            return self
    
        def __exit__(self, *exc):
            net_time = time.perf_counter() - self.start_time
            print(f"{self.label} took {net_time:.1f} seconds")
            return False
    
    
    @timer()
    def my_func():
        time.sleep(3)
    
    my_func()  # -> my_func took 3.0 seconds
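One caveat with the approach above: `__call__` stores the name on the instance itself, so a single unlabelled `timer()` object applied to two functions would report both under the first function's name, and `with timer():` would print `None took ...`. A minimal variant that avoids both builds a fresh, labelled instance inside `__call__` instead of mutating `self`. The fallback label `"code block"` for the context-manager case is my own assumption, not something from the question:

```python
import time
from contextlib import ContextDecorator
from typing import Optional


class timer(ContextDecorator):
    def __init__(self, label: Optional[str] = None):
        self.label = label

    def __call__(self, func):
        # Pick the label without touching self, so one unlabelled timer()
        # object can safely decorate several functions.
        label = self.label if self.label is not None else func.__name__
        # Delegate to ContextDecorator.__call__ on a *new* instance; it wraps
        # func so each call runs inside `with <new instance>:`.
        return ContextDecorator.__call__(type(self)(label), func)

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        net_time = time.perf_counter() - self.start_time
        # Assumed fallback when used as `with timer():` without a label.
        name = self.label if self.label is not None else "code block"
        print(f"{name} took {net_time:.1f} seconds")
        return False  # don't suppress exceptions


@timer()
def my_func():
    time.sleep(0.1)


@timer("custom label")
def other_func():
    time.sleep(0.1)
```

Because `functools.wraps` is applied inside `ContextDecorator.__call__`, the wrapped functions still report their original `__name__`.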