Search code examples
Tags: python, twisted, rate-limiting, deferred

How to rate limit a deferred HTTP client in Twisted?


I have an HTTP client written in Twisted that sends requests to an API of some site from a deferred. It goes something like this (somewhat simplified):

from json import loads

from core import output

from twisted.python.log import msg
from twisted.internet import reactor
from twisted.web.client import Agent, HTTPConnectionPool, _HTTP11ClientFactory, readBody
from twisted.web.http_headers import Headers
from twisted.internet.ssl import ClientContextFactory


class WebClientContextFactory(ClientContextFactory):
    """Context factory passed to ``Agent`` as its ``contextFactory``.

    ``getContext`` accepts ``(hostname, port)`` but ignores both and simply
    delegates to the argument-less ``ClientContextFactory.getContext``.

    NOTE(review): as far as I know, ``ClientContextFactory`` performs no
    server-certificate verification - confirm, and consider Twisted's default
    ``BrowserLikePolicyForHTTPS`` instead.
    """

    def getContext(self, hostname, port):
        base_get_context = ClientContextFactory.getContext
        return base_get_context(self)


class QuietHTTP11ClientFactory(_HTTP11ClientFactory):
    """HTTP/1.1 client factory that does not spam the log.

    ``noisy = False`` turns off the factory's own start/stop log messages.
    Note that ``_HTTP11ClientFactory`` is a private Twisted class.
    """

    noisy = False


class Output(output.Output):
    """Output plugin that forwards each event to the site's HTTP API.

    Requests are sent as soon as write() is called (no rate limiting).
    Responses other than 200/201 are logged together with their body.
    """

    def start(self):
        """Create the Agent, backed by a quiet connection pool, used for requests."""
        myQuietPool = HTTPConnectionPool(reactor)
        # _factory is a private HTTPConnectionPool attribute; overriding it is
        # presumably what makes the pool use the quiet factory defined above.
        myQuietPool._factory = QuietHTTP11ClientFactory
        self.agent = Agent(
            reactor,
            contextFactory=WebClientContextFactory(),
            pool=myQuietPool
        )

    def stop(self):
        pass

    def write(self, event):
        # Simplified example: the message does not depend on *event* here.
        messg = 'Whatever'
        self.send_message(messg)

    def send_message(self, message):
        """Send *message* as a GET request to the site's API.

        Returns:
            The Deferred from ``Agent.request`` with logging callbacks
            attached; it fires with None once the response has been handled.
        """
        from twisted.web.client import PartialDownloadError

        headers = Headers({
            b'User-Agent': [b'MyApp']
        })
        url = 'https://api.somesite.com/{}'.format(message)
        d = self.agent.request(b'GET', url.encode('utf-8'), headers, None)

        def cbBody(body):
            return processResult(body)

        def cbPartial(failure):
            failure.printTraceback()
            # readBody fails with PartialDownloadError when it cannot tell
            # whether the whole body arrived; the bytes received so far are on
            # .response. Any other failure carries no body to decode.
            if failure.check(PartialDownloadError):
                return processResult(failure.value.response)
            return None

        def cbResponse(response):
            if response.code in (200, 201):
                return None
            msg('Site response: {} {}'.format(response.code, response.phrase))
            body_d = readBody(response)
            body_d.addCallback(cbBody)
            body_d.addErrback(cbPartial)
            return body_d

        def cbError(failure):
            failure.printTraceback()

        def processResult(result):
            try:
                j = loads(result)
            except ValueError:
                # Error bodies are not necessarily JSON; log them raw instead.
                msg('Site response (non-JSON): {!r}'.format(result))
                return
            msg('Site response: {}'.format(j))

        d.addCallback(cbResponse)
        d.addErrback(cbError)
        return d

This works fine, but the site rate-limits requests and starts dropping them if they arrive too fast. So I need to rate-limit the client too, making sure that it isn't sending requests too fast - yet the requests must not be lost, so some kind of buffering/queuing is needed. I don't need precise rate limiting, like "no more than X requests per second"; just some reasonable delay (like 1 second) after each request is fine.

Unfortunately, I can't use sleep() from a deferred (it would block the whole reactor), so some other approach is necessary.

From googling around, it seems that the basic idea is to do something like

# Fragment quoted from another answer. NOTE(review): it assumes it runs inside a
# protocol that has self.transport and self.reactor attributes.
# Stop the transport from producing data now...
self.transport.pauseProducing()
delay = 1 # seconds
# ...and schedule resumeProducing() to run after `delay` seconds.
self.reactor.callLater(delay, self.transport.resumeProducing)

at least according to this answer. But the code there doesn't work "as is" - SlowDownloader is expected to take a parameter (a reactor), so SlowDownloader() causes an error.

I also found this answer, which uses the interesting idea of using the factory as storage, so you don't need to implement your own queues and so on - but it deals with rate limiting on the server side, while I need to rate-limit the client.

I feel that I'm pretty close to the solution but I still can't figure out how exactly to combine the information from these two answers, in order to produce working code, so any help would be appreciated.


Solution

  • OK, I managed to solve the problem. The code below does what I need - it makes sure that there is at least a 2-second delay between messages, and that the messages are sent in the right order. It works - the server no longer returns 429 errors and the messages are correctly ordered. It's not a generic rate-limiting solution (i.e., one making sure that exactly X messages are sent every Y seconds) but it does what I need. It's also not very elegant - I'm manually keeping track of when a request was last sent, maintaining a queue of requests to make sure that they are sent in the right order, and so on. If someone knows how to achieve all this with Twisted's built-in methods, please post a separate answer and I'll accept it.

    from json import loads
    from time import time
    
    from core import output
    
    from twisted.python.log import msg
    from twisted.internet import reactor
    from twisted.internet.task import deferLater
    from twisted.web.client import Agent, HTTPConnectionPool, _HTTP11ClientFactory, readBody
    from twisted.web.http_headers import Headers
    from twisted.internet.ssl import ClientContextFactory
    
    
    class WebClientContextFactory(ClientContextFactory):
        """Context factory passed to ``Agent`` as its ``contextFactory``.

        ``getContext`` accepts ``(hostname, port)`` but ignores both and simply
        delegates to the argument-less ``ClientContextFactory.getContext``.

        NOTE(review): as far as I know, ``ClientContextFactory`` performs no
        server-certificate verification - confirm, and consider Twisted's
        default ``BrowserLikePolicyForHTTPS`` instead.
        """

        def getContext(self, hostname, port):
            base_get_context = ClientContextFactory.getContext
            return base_get_context(self)
    
    
    class QuietHTTP11ClientFactory(_HTTP11ClientFactory):
        """HTTP/1.1 client factory that does not spam the log.

        ``noisy = False`` turns off the factory's own start/stop log messages.
        Note that ``_HTTP11ClientFactory`` is a private Twisted class.
        """

        noisy = False
    
    
    class Output(output.Output):
        """Output plugin that forwards each event to the site's HTTP API, rate limited.

        Messages are queued in arrival order and sent one at a time, at least
        ``min_interval`` seconds apart, so bursts are delayed instead of being
        dropped by the remote rate limiter. Responses other than 200/201 are
        logged together with their body.
        """

        # Minimum number of seconds between the starts of two consecutive requests.
        min_interval = 2.0

        def start(self):
            """Create the Agent and reset the rate-limiting state."""
            from collections import deque

            myQuietPool = HTTPConnectionPool(reactor)
            # _factory is a private HTTPConnectionPool attribute; overriding it is
            # presumably what makes the pool use the quiet factory defined above.
            myQuietPool._factory = QuietHTTP11ClientFactory
            self.agent = Agent(
                reactor,
                contextFactory=WebClientContextFactory(),
                pool=myQuietPool
            )
            # time() at which the most recent request was issued (0 = never).
            self.last_sent = 0
            # FIFO of messages still waiting to be sent.
            self.requests_queue = deque()
            # Pending IDelayedCall that will resume draining the queue, or None.
            self._next_send = None

        def stop(self):
            pass

        def write(self, event):
            # Simplified example: the message does not depend on *event* here.
            messg = 'Whatever'
            self.send_message(messg)

        def send_message(self, message):
            """Queue *message* and send it as soon as the rate limit allows.

            Messages are sent strictly in the order they were queued, each one
            exactly once.

            Returns:
                The request's Deferred if *message* was sent immediately,
                otherwise None (a scheduled call will send it later).
            """
            self.requests_queue.append(message)
            if self._next_send is not None and self._next_send.active():
                # A drain is already scheduled; this message waits its turn.
                return None
            return self._drain_queue()

        def _drain_queue(self):
            """Send queued messages while the interval allows; schedule the rest.

            Returns the Deferred of the last request sent by this call, or None.
            """
            self._next_send = None
            d = None
            while self.requests_queue:
                wait = self.last_sent + self.min_interval - time()
                if wait > 0:
                    # Too soon: resume draining once the interval has elapsed.
                    self._next_send = reactor.callLater(wait, self._drain_queue)
                    return d
                message = self.requests_queue.popleft()
                # Record the send time so the next message is spaced correctly.
                self.last_sent = time()
                d = self._request(message)
            return d

        def _request(self, message):
            """Issue the GET request for *message* with logging callbacks attached.

            Returns the Deferred from ``Agent.request``; it fires with None once
            the response has been handled.
            """
            from twisted.web.client import PartialDownloadError

            headers = Headers({
                b'User-Agent': [b'MyApp']
            })
            url = 'https://api.somesite.com/{}'.format(message)
            d = self.agent.request(b'GET', url.encode('utf-8'), headers, None)

            def cbBody(body):
                return processResult(body)

            def cbPartial(failure):
                failure.printTraceback()
                # readBody fails with PartialDownloadError when it cannot tell
                # whether the whole body arrived; the bytes received so far are
                # on .response. Any other failure carries no body to decode.
                if failure.check(PartialDownloadError):
                    return processResult(failure.value.response)
                return None

            def cbResponse(response):
                if response.code in (200, 201):
                    return None
                msg('Site response: {} {}'.format(response.code, response.phrase))
                body_d = readBody(response)
                body_d.addCallback(cbBody)
                body_d.addErrback(cbPartial)
                return body_d

            def cbError(failure):
                failure.printTraceback()

            def processResult(result):
                try:
                    j = loads(result)
                except ValueError:
                    # Error bodies are not necessarily JSON; log them raw instead.
                    msg('Site response (non-JSON): {!r}'.format(result))
                    return
                msg('Site response: {}'.format(j))

            d.addCallback(cbResponse)
            d.addErrback(cbError)
            return d