Tags: python, celery, python-multiprocessing

Celery: Spawn "sidecar" webserver process


I'm trying to collect metrics from my Celery workers, which seemed simple enough, but turns out to be utterly, ridiculously hard. After trying lots of approaches, I'm now trying to spawn an additional process next to the Celery worker/supervisor that hosts a simple HTTP server to expose Prometheus metrics.
To make this work, I need to spawn a process using the multiprocessing module, so the Celery task workers and the metrics server can use the same in-memory Prometheus registry. In theory, this would be as simple as:

# app/celery_worker.py

from multiprocessing import Process

from prometheus_client import start_http_server, REGISTRY

def start_server():
    start_http_server(port=9010, registry=REGISTRY)

if __name__ == "__main__":
    metric_server = Process(target=start_server, daemon=True)
    metric_server.start()

Alas, the worker is started using the Celery module:

python -m celery --app "app.celery_worker" worker

So my worker is never the main module. How can I spawn a process in the Celery worker?


Solution

  • I finally found a solution, making use of Celery Signals, which are essentially hooks into the Celery lifecycle.

    Hooking up Celery

    Since we want to spawn a separate process as the worker starts up, we can use the worker_init signal:

    worker_init
    Dispatched before the worker is started.

    To do so, add the signal decorator to a hook function in your application's main Celery module:

    # For posterity
    from multiprocessing import Process

    from celery import Celery
    from celery.signals import worker_init, worker_shutdown

    from .metrics_server import start_wsgi_server
    
    app = Celery("appName")
    
    # ...
    
    _metric_server_process: Process
    
    
    @worker_init.connect
    def start_metrics_server(**kwargs):  # noqa: ARG001
        # We need to keep the process in global state, so we can stop it later on
        global _metric_server_process  # noqa: PLW0603
    
        _metric_server_process = Process(target=start_wsgi_server)
        _metric_server_process.daemon = True
        _metric_server_process.start()
    

    What we do here is spawn a new daemon process with the server function, and bind its handle to a global variable (so we can access it later, see below).

    Shutting the server down with Celery

    To be able to kill the server if the main process stops, we can also attach to the worker_shutdown signal. This makes use of the global variable defined previously:

    @worker_shutdown.connect
    def stop_metrics_server(**kwargs):  # noqa: ARG001
        from prometheus_client import multiprocess
    
        multiprocess.mark_process_dead(_metric_server_process.pid)
        _metric_server_process.join(3)
    

    The metrics web server

    The server itself looks like the following. This code is adapted from the Prometheus client library, with the only change being that we don't want the server thread to run in daemon mode:

    from socket import AddressFamily, getaddrinfo
    from threading import Thread
    from wsgiref.simple_server import WSGIRequestHandler, make_server
    
    from prometheus_client import CollectorRegistry
    from prometheus_client.exposition import ThreadingWSGIServer, make_wsgi_app
    from prometheus_client.multiprocess import MultiProcessCollector
    
    
    def start_wsgi_server(port: int = 9010, addr: str = "0.0.0.0") -> None:
        class TmpServer(ThreadingWSGIServer):
            """
            Copy of ThreadingWSGIServer to update address_family locally.
            """
    
        registry = CollectorRegistry()
        MultiProcessCollector(registry)
        TmpServer.address_family, addr = _get_best_family(addr, port)
        app = make_wsgi_app(registry)
        httpd = make_server(addr, port, app, TmpServer, handler_class=_SilentHandler)
        thread = Thread(target=httpd.serve_forever)
        thread.start()
    
    
    def _get_best_family(address: str, port: int) -> tuple[AddressFamily, str]:
        infos = getaddrinfo(address, port)
        family, _, _, _, socket_address = next(iter(infos))
    
        return family, socket_address[0]
    
    
    class _SilentHandler(WSGIRequestHandler):
        def log_message(self, format, *args):  # noqa: A002
            """Log nothing."""
    

    Having the metrics server defined this way, you should be able to access http://localhost:9010/metrics when starting up a Celery worker, although no metrics will have been recorded yet. Hooray!
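
    For a quick smoke test (a minimal sketch; the URL simply reflects the default port used above), you can fetch the endpoint from Python:

    import urllib.request

    # Fetch the metrics endpoint exposed by the sidecar process
    with urllib.request.urlopen("http://localhost:9010/metrics") as response:
        print(response.read().decode("utf-8"))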

    Configuring Prometheus

    To use Prometheus metrics, you'll need to set the client library up for multiprocess mode; in this mode, the client shares metrics between multiple processes via memory-mapped files on disk, which is exactly what we want. In our case, the Celery worker processes (or threads, depending on your configuration) will stash their recorded metrics there, and the web server process (running on the same node) will read and expose them to the Prometheus scraper.

    Running in multiprocess mode comes with some caveats, but nothing too severe. Follow the client documentation to set this up.
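
    For example (a sketch of one part of that setup, not a substitute for the docs; the directory path is an assumption of mine), the client library expects the PROMETHEUS_MULTIPROC_DIR environment variable to point at a writable directory before any metrics are created. Older client releases spelled it prometheus_multiproc_dir:

    import os

    # prometheus_client keys on this variable to decide where to store the
    # per-process metric files; the path here is only an example.
    os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", "/tmp/prometheus")
    os.makedirs(os.environ["PROMETHEUS_MULTIPROC_DIR"], exist_ok=True)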

    Collecting metrics

    This is the neat part. We now have a separate process next to Celery that exposes a web server, which is started together with Celery and killed upon termination. It has access to all metrics collected in all Celery workers (on that machine or container). This means that you can simply use Prometheus metrics as usual!
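
    For instance, a task might bump a counter like this (a hypothetical task module; the task and metric names are made up for illustration):

    from celery import shared_task
    from prometheus_client import Counter

    # In multiprocess mode, this counter is aggregated across all worker processes
    TASKS_PROCESSED = Counter("tasks_processed_total", "Tasks processed by the workers")


    @shared_task
    def process_item(item_id: int) -> None:
        TASKS_PROCESSED.inc()
        # ... do the actual work ...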