
How to prevent Akka TCP Stream of incoming connections from connecting after a configured max number of connections have connected?


[EDIT]: The problem has been solved; the solution provided by Artur has been added as an edit at the end.

What I am trying to implement: a TCP server that allows at most n connections. If an (n+1)th connection arrives, it should be rejected.

To do that, I needed some way to cancel an individual connection, so I routed that particular connection to a Sink.cancelled().

What I have is a source of IncomingConnection connected to a custom flow that partitions the incoming connections based on a connection count. Once the maximum connection count is exceeded, the partition logic directs the connection to an outlet connected to Sink.cancelled().

I expected the connection to be cancelled immediately, but instead the client is allowed to connect and is only disconnected after some time.

Maybe I am running into the same issue described in the answers to "Why does Akka TCP stream server disconnect client when there is no flow for the connection.handlewith?", where no handler flow is attached to the connection, so it lingers and is eventually disconnected.

I am looking for:

  1. A clean way to reject incoming connections once the maximum is exceeded.
  2. An explanation of what Sink.cancelled() is doing here, if anything.

package com.example;

import java.util.concurrent.CompletionStage;

import akka.Done;
import akka.NotUsed;
import akka.actor.typed.ActorSystem;
import akka.actor.typed.javadsl.Behaviors;
import akka.stream.FlowShape;
import akka.stream.SinkShape;
import akka.stream.UniformFanOutShape;
import akka.stream.javadsl.Flow;
import akka.stream.javadsl.GraphDSL;
import akka.stream.javadsl.Partition;
import akka.stream.javadsl.Sink;
import akka.stream.javadsl.Source;
import akka.stream.javadsl.Tcp;
import akka.stream.javadsl.Tcp.IncomingConnection;
import akka.stream.javadsl.Tcp.ServerBinding;
import akka.util.ByteString;


public class SimpleStream03 {
    private static int connectionCount = 0;
    private static int maxConnectioCount = 2;

    public static void runServer() {

        ActorSystem actorSystem = ActorSystem.create(Behaviors.empty(), "actorSystem");

        Source<IncomingConnection, CompletionStage<ServerBinding>> source = Tcp.get(actorSystem).bind("127.0.0.1",
                8888); 

        Sink<IncomingConnection, CompletionStage<Done>> handler = Sink.foreach(conn -> {
            
            System.out.println("Handler Sink Connection Count " + connectionCount);
            System.out.println("Handler Sink Client connected from: " + conn.remoteAddress());

            conn.handleWith(Flow.of(ByteString.class), actorSystem);

        });

        Flow<IncomingConnection, IncomingConnection, NotUsed> connectioncountFlow = Flow
                .fromGraph(GraphDSL.create(builder -> {

                    SinkShape<IncomingConnection> sinkCancelled = builder.add(Sink.cancelled());
                    FlowShape<IncomingConnection, IncomingConnection> inFlowShape = builder
                            .add(Flow.of(IncomingConnection.class).map(conn -> {
                                connectionCount++;
                                return conn;
                            }));
                    UniformFanOutShape<IncomingConnection, IncomingConnection> partition = builder
                            .add(Partition.create(IncomingConnection.class, 2, param -> {
                                if (connectionCount > maxConnectioCount) {
                                    connectionCount = maxConnectioCount;
                                    System.out.println("Outlet 0 -> Sink.cancelled");
                                    return 0;
                                }
                                System.out.println("Outlet 1 -> forward to handler");
                                return 1;
                            }));

                    builder.from(inFlowShape).toFanOut(partition);
                    builder.from(partition.out(0)).to(sinkCancelled);
                    return new FlowShape<>(inFlowShape.in(), partition.out(1));

                }));

        CompletionStage<ServerBinding> bindingFuture = source.via(connectioncountFlow).to(handler).run(actorSystem);

        bindingFuture.handle((binding, throwable) -> {
            if (binding != null) {
                System.out.println("Server started, listening on: " + binding.localAddress());

            } else {
                System.err.println("Server could not bind to  : " + throwable.getMessage());
                actorSystem.terminate();
            }
            return NotUsed.getInstance();
        });

    }

    public static void main(String[] args) throws InterruptedException {
        SimpleStream03.runServer();

    }

}

The output confirms that the partitioning works and that the handler sink is reached by the first two connections, while the third is routed to Sink.cancelled:

Server started, listening on: /127.0.0.1:8888
Outlet 1 -> forward to handler
Handler Sink Connection Count 1
Handler Sink Client connected from: /127.0.0.1:60327
Outlet 1 -> forward to handler
Handler Sink Connection Count 2
Handler Sink Client connected from: /127.0.0.1:60330
Outlet 0 -> Sink.cancelled
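Separately from the cancellation question, note that connectionCount in the code above is a plain static int that is only ever incremented (and clamped back to the maximum in the partitioner), so a slot is never freed when a client disconnects, and once two clients have connected nobody else will ever be admitted. The following is a minimal, Akka-free sketch of a counter that could back such a limit, assuming you release a slot when a connection's stream terminates (hooking that into the handler flow's completion is not shown). ConnectionLimiter, tryAcquire and release are hypothetical names; an AtomicInteger is used because releases may happen on other threads than the one running the partitioner.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical helper, not part of Akka: admits at most maxConnections
// concurrent connections and frees a slot when a connection is released.
final class ConnectionLimiter {
    private final int maxConnections;
    private final AtomicInteger active = new AtomicInteger();

    ConnectionLimiter(int maxConnections) {
        this.maxConnections = maxConnections;
    }

    // Takes a slot and returns true if one is free; returns false (reject) otherwise.
    boolean tryAcquire() {
        while (true) {
            int current = active.get();
            if (current >= maxConnections) {
                return false;
            }
            if (active.compareAndSet(current, current + 1)) {
                return true;
            }
            // Another thread changed the count in the meantime: re-read and retry.
        }
    }

    // Frees a slot; call exactly once per successful tryAcquire, when that connection ends.
    void release() {
        active.decrementAndGet();
    }

    int activeCount() {
        return active.get();
    }

    public static void main(String[] args) {
        ConnectionLimiter limiter = new ConnectionLimiter(2);
        System.out.println(limiter.tryAcquire()); // true
        System.out.println(limiter.tryAcquire()); // true
        System.out.println(limiter.tryAcquire()); // false: third connection rejected
        limiter.release();                        // one client disconnects
        System.out.println(limiter.tryAcquire()); // true: the freed slot is reused
    }
}
```

In the graph above, the partitioner could return 0 whenever tryAcquire() returns false, since Partition calls it exactly once per element.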

Edit: After implementing the accepted answer, the following change rejects incoming connections once the threshold is exceeded. The client sees a "Connection reset by peer" error.

        +---------------------------------------------------------+
        |                                                         |
        |                                         Fail Flow       |
        |                                        +-------------+  |
        |                                    +-->+Sink  |Source|  |
        |                                    |   |cancel|fail  |  |
        |                                    |   +-------------+  |
        |                      +----------+  |                    |
        |                      |          |  |                    |
        |  +----------+        |        O0+--+                    |
connections|FLOW      |        |          |  O0:count > threshold |
+-------+-->          +------->+ Partition|                       |
        |  |count++   |        |        O1+----------------------------->
        |  +----------+        |          |                       |
        |                      |          |  O1:count <= threshold|
        |                      +----------+                       |
        |                                                         |
        +---------------------------------------------------------+

Replace

SinkShape<IncomingConnection> sinkCancelled = builder.add(Sink.cancelled());

With

Sink<IncomingConnection, CompletionStage<Done>> connectionCancellingSink = Sink.foreach(ic -> ic
        .handleWith(Flow.fromSinkAndSource(Sink.cancelled(), Source.failed(new Throwable("killed"))),
                actorSystem)); // Sink.ignore() and Sink.cancelled() give me the same expected result
SinkShape<IncomingConnection> sinkCancelledShape = builder.add(connectionCancellingSink);
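The "Connection reset by peer" mentioned above is what a client typically reports when the server aborts the TCP connection (a RST) instead of closing it in an orderly way (a FIN). The following plain java.net sketch, with no Akka involved, shows both outcomes from the client's side: after an orderly close read() returns -1, while closing the server-side socket with SO_LINGER set to 0 makes the client's read() fail with a SocketException. That Akka aborts the connection in a comparable way when the handler flow fails is an assumption here, inferred from the error the client sees; CloseModes and observeClientRead are hypothetical names.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;

// Plain-JDK illustration (no Akka): orderly close (FIN) versus abortive close (RST),
// as observed by the client.
final class CloseModes {

    // Opens a loopback server, connects one client, closes the accepted socket
    // (abortively if abort is true) and reports what the client's read() observed.
    static String observeClientRead(boolean abort) {
        InetAddress loopback = InetAddress.getLoopbackAddress();
        try (ServerSocket server = new ServerSocket(0, 1, loopback);
             Socket client = new Socket(loopback, server.getLocalPort());
             Socket accepted = server.accept()) {
            if (abort) {
                // SO_LINGER enabled with a 0 s timeout: close() drops the connection with a RST.
                accepted.setSoLinger(true, 0);
            }
            accepted.close(); // without SO_LINGER=0 this is an orderly close: the peer gets a FIN

            try {
                return client.getInputStream().read() == -1 ? "EOF" : "data";
            } catch (SocketException e) {
                return "reset: " + e.getMessage(); // message text varies by platform
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println("orderly close  -> " + observeClientRead(false));
        System.out.println("abortive close -> " + observeClientRead(true));
    }
}
```

The client never sends anything here; if the server closed a socket with unread data in its receive buffer, most TCP stacks would send a RST even without SO_LINGER.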
                    

Solution

  • Sink.cancelled() cancels its upstream immediately (https://doc.akka.io/docs/akka/current/stream/operators/Sink/cancelled.html).

    However, your Partition is created with eagerCancel set to false:

    
      /**
       * Create a new `Partition` operator with the specified input type, `eagerCancel` is `false`.
       *
       * @param clazz a type hint for this method
       * @param outputCount number of output ports
       * @param partitioner function deciding which output each element will be targeted
       */
      def create[T](
          @unused clazz: Class[T],
          outputCount: Int,
          partitioner: function.Function[T, Integer]): Graph[UniformFanOutShape[T, T], NotUsed] =
        new scaladsl.Partition(outputCount, partitioner.apply, eagerCancel = false)
    
    

    This means the Partition only cancels its upstream when ALL of its outputs have cancelled, which is not what you want. But you don't want eagerCancel=true either, because then the first connection over the limit would cancel the whole Partition and, with it, all your connections, basically trashing the whole server.
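The difference between the two eagerCancel settings fits in a few lines. The following is a toy model written for illustration, not Akka's implementation: FanOutCancellation (a hypothetical name) records which outputs of a fan-out stage have cancelled and reports whether the single upstream, here the stream of accepted connections behind the server binding, would be cancelled as a result.

```java
import java.util.BitSet;

// Toy model (not Akka code) of when a fan-out stage such as Partition
// propagates cancellation from its outputs to its single upstream.
final class FanOutCancellation {
    private final int outputCount;
    private final boolean eagerCancel;
    private final BitSet cancelledOutputs = new BitSet();

    FanOutCancellation(int outputCount, boolean eagerCancel) {
        this.outputCount = outputCount;
        this.eagerCancel = eagerCancel;
    }

    // Records that output `port` cancelled; returns true if the upstream is now cancelled.
    boolean cancelOutput(int port) {
        cancelledOutputs.set(port);
        return upstreamCancelled();
    }

    boolean upstreamCancelled() {
        if (eagerCancel) {
            // eagerCancel = true: the first cancelling output cancels the upstream.
            return !cancelledOutputs.isEmpty();
        }
        // eagerCancel = false: the upstream is cancelled only once every output has cancelled.
        return cancelledOutputs.cardinality() == outputCount;
    }
}
```

In the question's graph, Sink.cancelled() on outlet 0 corresponds to cancelOutput(0): with eagerCancel=false the connection source keeps running, while with eagerCancel=true the same call would cancel it and take the server down.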

    Perhaps it's useful to think about the situation here in terms of nested streams. The top-level Source<IncomingConnection> represents the stream of accepted TCP connections. You don't want to cancel that stream: if you do, you have just killed your server. Each IncomingConnection represents an individual TCP connection, and the exchange of bytes on that connection is also represented as a stream. It is this inner stream that you want to cancel for every connection above the threshold.

    To do that, you can define a connection-cancelling Sink like this:

    Sink<IncomingConnection, CompletionStage<Done>>
          connectionCancellingSink =
          Sink.foreach(
            ic ->
              ic.handleWith(
                Flow.fromSinkAndSource(Sink.cancelled(), Source.empty()),
                actorSystem));
    
    

    IncomingConnection lets you attach a handler using the handleWith method. It takes a Flow, because you both consume bytes from the client and potentially send bytes back: incoming bytes go into the Flow, and whatever you want to send to the client must be produced on the Flow's output. In our case we just want to cancel that stream immediately. Flow.fromSinkAndSource builds a Flow out of a separate Sink and Source, so you can plug in Sink.cancelled() and Source.empty(). Source.empty() means we won't send any bytes on the connection, and Sink.cancelled() will immediately cancel the stream and, hopefully, the underlying TCP connection. Let's give it a shot.

    The last thing to do is to plug the new cancelling Sink into the Partition:

    SinkShape<IncomingConnection> sinkCancelled =
        builder.add(connectionCancellingSink);
    // ...the rest stays the same
    builder.from(partition.out(0))
        .to(sinkCancelled);

    If you do that, on the third connection you will see the following message:

    Not aborting connection from 127.0.0.1:49874 because downstream cancelled stream without failure
    

    So Sink.cancelled() does not trigger what you want: cancelling the inbound side without a failure does not abort the connection. Let's redefine our cancelling Flow:

    Sink<IncomingConnection, CompletionStage<Done>>
          connectionCancellingSink =
          Sink.foreach(
            ic ->
              ic.handleWith(
                Flow.fromSinkAndSource(Sink.ignore(),
                                       Source.failed(new Throwable("killed"))),
                actorSystem));
    

    Now Sink.ignore() simply ignores the incoming bytes, while Source.failed(...) fails the stream. This causes the connection to be terminated immediately, and the stack trace will be printed in the server output. If you want to keep it quiet, you can create an exception without a stack trace:

    public static class TerminatedException extends Exception {
        public TerminatedException(String message) {
            super(message, null, false, false);
        }
    }
    

    and then use that to fail your connection stream

    Flow.fromSinkAndSource(Sink.ignore(),
                           Source.failed(new TerminatedException("killed")))
    

    that way you will get cleaner logs.
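A quick way to convince yourself that such an exception really carries no stack trace is to check getStackTrace() directly. This is a self-contained sketch: the exception is declared top-level here instead of as a nested static class, and TerminatedExceptionDemo is a hypothetical name used only for the demo. The four-argument constructor is Exception(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace), so passing false as the last argument skips stack-trace capture.

```java
// Same idea as the TerminatedException above: writableStackTrace = false
// means no stack trace is captured, so logging it shows only the message.
class TerminatedException extends Exception {
    TerminatedException(String message) {
        // (message, cause, enableSuppression, writableStackTrace)
        super(message, null, false, false);
    }
}

class TerminatedExceptionDemo {
    public static void main(String[] args) {
        TerminatedException e = new TerminatedException("killed");
        System.out.println(e.getMessage());           // killed
        System.out.println(e.getStackTrace().length); // 0: nothing was captured
    }
}
```

Because enableSuppression is also false, addSuppressed(...) on this exception is a no-op, which is harmless here since nothing is meant to be attached to it.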