Tags: scala, scala-3, cats-effect, http4s

How do you obey x-rate-limit headers using cats-effect and the http4s client?


Since this should be purely functional, I put together the limiter class below. It may still contain bugs because I haven't been able to use it yet, but it shows the idea. How do I actually use it? I tried to write a client middleware following the http4s documentation, but the types don't work out: IIUC, a middleware should use Client.run, but I can't suspend that code in IO.

import cats.effect.{Clock, IO}
import cats.effect.std.AtomicCell
import scala.concurrent.duration.*

class Limiter(using clock: Clock[IO])(
  private val requestsLeft: IO[AtomicCell[IO, Int]],
  private val resetAt: IO[AtomicCell[IO, FiniteDuration]]
):
  // Store the values reported by the latest rate-limit headers
  def update(currentRequestsLeft: Int, currentResetAt: Int): IO[Unit] =
    for
      requestsLeft <- requestsLeft
      resetAt      <- resetAt
      _ <- requestsLeft.set(currentRequestsLeft)
      _ <- resetAt.set(FiniteDuration(currentResetAt, SECONDS))
    yield ()

  // Sleep until the reset time if no requests are left, then run the effect
  def delay[T](effect: IO[T]): IO[T] =
    val delay = for
      requestsLeft <- requestsLeft.flatMap(_.get)
      resetAt      <- resetAt.flatMap(_.get)
      currentTime  <- clock.realTime
      delay <- if requestsLeft <= 0 then IO.sleep(resetAt - currentTime) else IO.unit
    yield delay

    delay >> effect
  end delay
end Limiter
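On the type problem: Client.run returns a Resource[IO, Response[IO]], not an IO, so an IO[Unit] throttling step can't simply be flatMapped in front of it. It first has to be lifted into Resource, for example with Resource.eval, or attached with Resource#preAllocate. The sketch below uses only cats-effect 3 and assumes nothing about HTTP: fakeRun and throttle are hypothetical stand-ins for the underlying client's run and the limiter's check, and the log records the order in which the steps execute.

```scala
import cats.effect.{IO, Ref, Resource}
import cats.effect.unsafe.implicits.global // runtime for unsafeRunSync in trace

object ResourceSequencing:
  // Hypothetical stand-in for Client#run: acquiring and releasing the "response" are logged
  def fakeRun(log: Ref[IO, List[String]]): Resource[IO, String] =
    Resource.make(log.update(_ :+ "acquire").as("response"))(_ => log.update(_ :+ "release"))

  // Hypothetical stand-in for the limiter: any IO[Unit] that must run before the request
  def throttle(log: Ref[IO, List[String]]): IO[Unit] =
    log.update(_ :+ "throttle")

  // Lift the IO into Resource, then sequence the request after it
  def viaEval(log: Ref[IO, List[String]]): Resource[IO, String] =
    Resource.eval(throttle(log)).flatMap(_ => fakeRun(log))

  // The same ordering expressed with preAllocate
  def viaPreAllocate(log: Ref[IO, List[String]]): Resource[IO, String] =
    fakeRun(log).preAllocate(throttle(log))

  // Use the resource once and return the order in which the steps happened
  def trace(build: Ref[IO, List[String]] => Resource[IO, String]): List[String] =
    Ref.of[IO, List[String]](Nil)
      .flatMap(log => build(log).use(_ => IO.unit).flatMap(_ => log.get))
      .unsafeRunSync()
```

Both variants produce the same order (throttle, acquire, release), which is why a middleware can wrap underlying.run(request) without ever leaving Resource.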

Solution

  • Final edit: a working version, in case anyone comes across this later looking for help.

    import cats.effect.{Clock, IO, Ref}
    import scala.concurrent.duration.*

    class Throttler(using clock: Clock[IO])(
      private val requestsLeft: Ref[IO, Long],
      private val resetAt: Ref[IO, FiniteDuration]
    ):
      // Store the limits reported by the most recent response;
      // currentResetAt is treated as a Unix timestamp in seconds
      def update(currentRequestsLeft: Int, currentResetAt: Long): IO[Unit] =
        for
          _ <- requestsLeft.set(currentRequestsLeft)
          _ <- resetAt.set(currentResetAt.seconds)
        yield ()

      // Sleep until the reset time if the last response reported no requests left
      def throttle(): IO[Unit] =
        for
          requestsLeft <- requestsLeft.get
          resetAt      <- resetAt.get
          currentTime  <- clock.realTime
          _ <- IO.println(s"remaining as seen by throttle: ${requestsLeft}")
          _ <- if requestsLeft <= 0
               then IO.println(s"throttling until ${resetAt}") >> IO.sleep(resetAt - currentTime)
               else IO.println("not throttling")
        yield ()
    end Throttler
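Unlike the Limiter in the question, which takes IO[AtomicCell[...]] parameters, this version takes the Ref values themselves. That difference matters: an IO[Ref[...]] is only a description of allocating a Ref, so every flatMap on it allocates a fresh one and an update is never seen by a later read. A minimal cats-effect 3 demonstration (SharedStateDemo and its members are illustrative names):

```scala
import cats.effect.{IO, Ref}
import cats.effect.unsafe.implicits.global // runtime for running the examples with unsafeRunSync

object SharedStateDemo:
  // A description of "allocate a fresh Ref"; it runs again each time it is sequenced
  val makeCounter: IO[Ref[IO, Long]] = Ref.of[IO, Long](0L)

  // Each step allocates its own Ref, so the write is never seen by the read
  val separateRefs: IO[Long] =
    for
      _     <- makeCounter.flatMap(_.set(5L))
      value <- makeCounter.flatMap(_.get)
    yield value

  // Allocating once and passing the same Ref to both steps shares the state
  val sharedRef: IO[Long] =
    makeCounter.flatMap(counter => counter.set(5L).flatMap(_ => counter.get))
```

In practice this means allocating the two Refs once at startup (for instance with Ref.of[IO, Long] and Ref.of[IO, FiniteDuration]) and passing the same instances to Throttler. Starting requestsLeft at a positive value is one way to make sure the very first request is not throttled.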
    

    --

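      import cats.effect.IO
      import cats.syntax.all.*
      import org.http4s.client.Client
      import org.typelevel.ci.*

      // Wrap a client so that each request first waits on the throttler,
      // then records the rate-limit headers of the response
      def throttledClient(underlying: Client[IO], throttler: Throttler): Client[IO] = Client[IO] { request =>
        underlying.run(request).evalTap { response =>
          (
            for
              remaining <- response.headers.get(ci"x-rate-limit-remaining").flatMap(_.head.value.toIntOption)
              resetAt   <- response.headers.get(ci"x-rate-limit-reset").flatMap(_.head.value.toLongOption)
            yield (remaining, resetAt)
          ).traverse { case (remaining, resetAt) => throttler.update(remaining, resetAt) }
        }.preAllocate(throttler.throttle())
      }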
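The header handling inside the middleware can also be pulled out into a pure function over Headers, which makes it testable without a server or a client. A sketch assuming http4s 0.23 and the header names used above; RateLimitHeaders and parse are hypothetical names, and other APIs may spell these headers differently.

```scala
import org.http4s.{Header, Headers}
import org.typelevel.ci.*

object RateLimitHeaders:
  // Header names as used in the middleware above (an assumption about the target API)
  private val Remaining = ci"x-rate-limit-remaining"
  private val Reset     = ci"x-rate-limit-reset"

  // (remaining requests, reset timestamp in seconds), only if both headers are present and numeric
  def parse(headers: Headers): Option[(Int, Long)] =
    for
      remaining <- headers.get(Remaining).flatMap(_.head.value.toIntOption)
      resetAt   <- headers.get(Reset).flatMap(_.head.value.toLongOption)
    yield (remaining, resetAt)
```

With this helper, the evalTap body reduces to RateLimitHeaders.parse(response.headers).traverse { case (r, t) => throttler.update(r, t) }. A response missing either header leaves the throttler state unchanged.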