Tags: swift, promise, future, swift-nio

Rewriting looping blocking code to SwiftNIO style non-blocking code


I'm working on a driver that will read data from the network. It doesn't know how much data a response contains; the only signal that it is done is that a read returns 0 bytes. So my naive blocking Swift code looks like this:

func readAllBlocking() -> [Byte] {

  var buffer: [Byte] = []
  var fullBuffer: [Byte] = []

  repeat {
    buffer = read() // synchronous, blocking
    fullBuffer.append(contentsOf: buffer) // append the chunk's elements, not the array itself
  } while buffer.count > 0

  return fullBuffer
}

How can I rewrite this as a promise that will keep on running until the entire result is read? After trying to wrap my brain around it, I'm still stuck here:

func readAllNonBlocking() -> EventLoopFuture<[Byte]> {

  ///...?
}

I should add that I can rewrite read() so that it returns an EventLoopFuture<[Byte]> instead of a [Byte].


Solution

  • Generally, loops in synchronous code turn into recursion when you want the same effect in asynchronous code that uses futures (the same happens in functional programming).

    So your function could look like this:

    func readAllNonBlocking(on eventLoop: EventLoop) -> EventLoopFuture<[Byte]> {
        // The accumulated chunks
        var accumulatedChunks: [Byte] = []
    
        // The promise that will hold the overall result
        let promise = eventLoop.makePromise(of: [Byte].self)
    
        // We turn the loop into recursion:
        func loop() {
            // First, we call `read` to read in the next chunk and hop
            // over to `eventLoop` so we can safely write to `accumulatedChunks`
            // without a lock.
            read().hop(to: eventLoop).map { nextChunk in
                // Next, we just append the chunk to the accumulation
                accumulatedChunks.append(contentsOf: nextChunk)
                guard nextChunk.count > 0 else {
                    promise.succeed(accumulatedChunks)
                    return
                }
                // and if it wasn't empty, we loop again.
                loop()
            }.cascadeFailure(to: promise) // if anything goes wrong, we fail the whole thing.
        }
    
        loop() // Let's kick everything off.
    
        return promise.futureResult
    }
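    If you want to experiment with the loop-to-recursion idea without pulling in SwiftNIO, here is a minimal, NIO-free sketch of the same shape. FakeChunkSource and its readChunk(completion:) method are hypothetical stand-ins for an asynchronous read() (they are not SwiftNIO APIs), and an empty chunk plays the role of EOF, as in the question.

```swift
typealias Byte = UInt8

// Hypothetical callback-based stand-in for an asynchronous `read()`.
// Not a SwiftNIO API; an empty chunk signals EOF.
final class FakeChunkSource {
    private var chunks: [[Byte]]
    init(chunks: [[Byte]]) { self.chunks = chunks }

    // Delivers the next chunk, or an empty array once the data is exhausted.
    func readChunk(completion: @escaping ([Byte]) -> Void) {
        completion(chunks.isEmpty ? [] : chunks.removeFirst())
    }
}

func readAll(from source: FakeChunkSource, completion: @escaping ([Byte]) -> Void) {
    var accumulated: [Byte] = []

    // The body of the old `repeat ... while` loop becomes a local function
    // that starts its own next "iteration" from inside the callback.
    func loop() {
        source.readChunk { chunk in
            accumulated.append(contentsOf: chunk)
            if chunk.isEmpty {
                completion(accumulated)  // loop condition is false: we're done
            } else {
                loop()                   // loop condition is true: go around again
            }
        }
    }
    loop()  // kick everything off
}

var result: [Byte]? = nil
readAll(from: FakeChunkSource(chunks: [[1, 2], [3], [4, 5, 6]])) { bytes in
    result = bytes
}
print(result ?? [])  // prints "[1, 2, 3, 4, 5, 6]"
```

    Because this fake source calls its completion handler synchronously, each "iteration" adds a stack frame; that's fine for a handful of chunks, but a source that completes asynchronously unwinds the stack between reads.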
    

    I'd like to add two things, however:

    First, what you're implementing above simply reads everything until it sees EOF. If that piece of software is exposed to the internet, you should definitely add a limit on the maximum number of bytes it holds in memory.
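    As a rough sketch of such a limit, the recursive reader can keep a running total and give up once it would exceed a caller-supplied maxBytes. This version is NIO-free: FakeChunkSource is a hypothetical callback-based stand-in for read() (not a SwiftNIO API), and TooManyBytes is an illustrative error type.

```swift
typealias Byte = UInt8

// Illustrative error for a peer that sends more than we are willing to buffer.
struct TooManyBytes: Error {
    let limit: Int
}

// Hypothetical callback-based stand-in for an asynchronous `read()`;
// an empty chunk signals EOF.
final class FakeChunkSource {
    private var chunks: [[Byte]]
    init(chunks: [[Byte]]) { self.chunks = chunks }
    func readChunk(completion: @escaping ([Byte]) -> Void) {
        completion(chunks.isEmpty ? [] : chunks.removeFirst())
    }
}

// Reads until EOF, but fails as soon as more than `maxBytes` would be held
// in memory instead of buffering an unbounded amount of untrusted data.
func readAll(from source: FakeChunkSource,
             maxBytes: Int,
             completion: @escaping (Result<[Byte], TooManyBytes>) -> Void) {
    var accumulated: [Byte] = []

    func loop() {
        source.readChunk { chunk in
            guard accumulated.count + chunk.count <= maxBytes else {
                completion(.failure(TooManyBytes(limit: maxBytes)))
                return  // stop reading; a real driver would also close the connection
            }
            accumulated.append(contentsOf: chunk)
            if chunk.isEmpty {
                completion(.success(accumulated))  // EOF within the limit
            } else {
                loop()
            }
        }
    }
    loop()
}

var small: Result<[Byte], TooManyBytes>?
readAll(from: FakeChunkSource(chunks: [[1, 2], [3]]), maxBytes: 4) { small = $0 }

var large: Result<[Byte], TooManyBytes>?
readAll(from: FakeChunkSource(chunks: [[1, 2], [3, 4, 5]]), maxBytes: 4) { large = $0 }
```

    Here the first read stays within 4 bytes and succeeds, while the second fails once the running total would reach 5 bytes.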

    Second, SwiftNIO is an event-driven system, so if you were to read these bytes with SwiftNIO, the program would actually look slightly different. If you're interested in what it looks like to simply accumulate all bytes until EOF in SwiftNIO, it's this:

    struct AccumulateUntilEOF: ByteToMessageDecoder {
        typealias InboundOut = ByteBuffer
    
        func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
            // `decode` will be called if new data is coming in.
            // We simply return `.needMoreData` because we always need more data: our message only ends at EOF.
            // ByteToMessageHandler will automatically accumulate everything for us because we tell it that we need more
            // data to decode a message.
            return .needMoreData
        }
    
        func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState {
            // `decodeLast` will be called if NIO knows that this is the _last_ time a decode function is called. Usually,
            // this is because of EOF or an error.
            if seenEOF {
                // This is what we've been waiting for, `buffer` should contain all bytes, let's fire them through
                // the pipeline.
                context.fireChannelRead(self.wrapInboundOut(buffer))
            } else {
                // Odd, something else happened, probably an error or we were just removed from the pipeline. `buffer`
                // will now contain what we received so far but maybe we should just drop it on the floor.
            }
            buffer.clear()
            return .needMoreData
        }
    }
    

    If you want to make a whole program out of this with SwiftNIO, here's an example server that accepts all data until it sees EOF and then literally just writes back the number of received bytes :). Of course, in the real world you would never hold on to all the received bytes just to count them (you could simply add up the size of each individual chunk as it arrives), but it serves as an example.

    import NIO
    
    let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
    defer {
        try! group.syncShutdownGracefully()
    }
    
    struct AccumulateUntilEOF: ByteToMessageDecoder {
        typealias InboundOut = ByteBuffer
    
        func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
            // `decode` will be called if new data is coming in.
            // We simply return `.needMoreData` because we always need more data: our message only ends at EOF.
            // ByteToMessageHandler will automatically accumulate everything for us because we tell it that we need more
            // data to decode a message.
            return .needMoreData
        }
    
        func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState {
            // `decodeLast` will be called if NIO knows that this is the _last_ time a decode function is called. Usually,
            // this is because of EOF or an error.
            if seenEOF {
                // This is what we've been waiting for, `buffer` should contain all bytes, let's fire them through
                // the pipeline.
                context.fireChannelRead(self.wrapInboundOut(buffer))
            } else {
                // Odd, something else happened, probably an error or we were just removed from the pipeline. `buffer`
                // will now contain what we received so far but maybe we should just drop it on the floor.
            }
            buffer.clear()
            return .needMoreData
        }
    }
    
    // Just an example "business logic" handler. It will wait for one message
    // and just write back the length.
    final class SendBackLengthOfFirstInput: ChannelInboundHandler {
        typealias InboundIn = ByteBuffer
        typealias OutboundOut = ByteBuffer
    
        func channelRead(context: ChannelHandlerContext, data: NIOAny) {
            // Once we receive the message, we allocate a response buffer and just write the length of the received
            // message in there. We then also close the channel.
            let allData = self.unwrapInboundIn(data)
            var response = context.channel.allocator.buffer(capacity: 10)
            response.writeString("\(allData.readableBytes)\n")
            context.writeAndFlush(self.wrapOutboundOut(response)).flatMap {
                context.close(mode: .output)
            }.whenSuccess {
                context.close(promise: nil)
            }
        }
    
        func errorCaught(context: ChannelHandlerContext, error: Error) {
            print("ERROR: \(error)")
            context.channel.close(promise: nil)
        }
    }
    
    let server = try ServerBootstrap(group: group)
        // Allow us to reuse the port after the process quits.
        .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1)
        // We should allow half-closure because we want to write back after having received an EOF on the input
        .childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
        // Our program consists of two parts:
        .childChannelInitializer { channel in
            channel.pipeline.addHandlers([
                // 1: The accumulate everything until EOF handler
                ByteToMessageHandler(AccumulateUntilEOF(),
                                     // We want at most 1 MB of buffering. If you remove this parameter, it will
                                     // buffer indefinitely.
                                     maximumBufferSize: 1024 * 1024),
                // 2: Our "business logic"
                SendBackLengthOfFirstInput()
            ])
        }
        // Let's bind port 9999
        .bind(to: SocketAddress(ipAddress: "127.0.0.1", port: 9999))
        .wait()
    
    // This will never return.
    try server.closeFuture.wait()
    

    Demo:

    $ echo -n "hello world" | nc localhost 9999
    11
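    As mentioned above, a real server wouldn't keep every received byte just to count it. Here is a minimal, NIO-free sketch of the alternative: ByteCounter (a hypothetical type, not part of SwiftNIO) keeps only a running total, roughly what a channelRead would do with each chunk, and produces the same reply as the demo once EOF arrives.

```swift
typealias Byte = UInt8

// Hypothetical sketch: instead of accumulating every chunk until EOF,
// keep only a running total so memory use stays constant regardless of input size.
struct ByteCounter {
    private(set) var total = 0
    private(set) var finished = false

    // Roughly what a `channelRead` would do with each incoming chunk.
    mutating func received(_ chunk: [Byte]) {
        precondition(!finished, "no data expected after EOF")
        total += chunk.count
    }

    // Roughly what reacting to EOF would do: build the reply the demo server sends.
    mutating func eof() -> String {
        finished = true
        return "\(total)\n"
    }
}

var counter = ByteCounter()
for chunk in ["hello ", "world"] {
    counter.received(Array(chunk.utf8))
}
let reply = counter.eof()
print(reply, terminator: "")  // prints "11" followed by a newline
```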