Tags: java, concurrency, parallel-processing, java-stream, completable-future

How to flatten a list inside a stream of completable futures?


I have this:

Stream<CompletableFuture<List<Item>>>

how can I convert it to

Stream<CompletableFuture<Item>>

Where the second stream consists of every Item inside each of the lists in the first stream.

I looked into thenCompose, but that solves a completely different problem that is also referred to as "flattening".
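For contrast, here is a small sketch of the kind of flattening thenCompose does: it collapses a future whose continuation returns another future into a single future. It never splits one future of a List into one future per element. The class and helper names are made up for illustration.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class ThenComposeDemo {
    // Hypothetical async step that itself returns a future.
    static CompletableFuture<Integer> sizeLater(List<String> items) {
        return CompletableFuture.supplyAsync(items::size);
    }

    static int demo() {
        CompletableFuture<List<String>> listFuture =
            CompletableFuture.supplyAsync(() -> List.of("a", "b", "c"));

        // thenApply with a future-returning function nests: a future of a future.
        CompletableFuture<CompletableFuture<Integer>> nested =
            listFuture.thenApply(ThenComposeDemo::sizeLater);

        // thenCompose "flattens" the nesting, but the result is still ONE future.
        CompletableFuture<Integer> flat = listFuture.thenCompose(ThenComposeDemo::sizeLater);

        return flat.join() + nested.join().join(); // 3 + 3
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints 6
    }
}
```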

How can this be done efficiently, in a streaming fashion, without blocking or prematurely consuming more stream items than necessary?

Here is my best attempt so far:

    ExecutorService pool = Executors.newFixedThreadPool(PARALLELISM);
    Stream<CompletableFuture<List<IncomingItem>>> reload = ... ;

    @SuppressWarnings("unchecked")
    CompletableFuture<List<IncomingItem>> allFutures[] = reload.toArray(CompletableFuture[]::new);
    CompletionService<List<IncomingItem>> queue = new ExecutorCompletionService<>(pool);
    for(CompletableFuture<List<IncomingItem>> item: allFutures) {
        queue.submit(item::get);
    }
    List<IncomingItem> THE_END = new ArrayList<IncomingItem>();
    CompletableFuture<List<IncomingItem>> ender = CompletableFuture.allOf(allFutures).thenApply(whatever -> {
        queue.submit(() -> THE_END);
        return THE_END;
    });
    queue.submit(() -> ender.get());
    Iterable<List<IncomingItem>> iter = () -> new Iterator<List<IncomingItem>>() {
        boolean checkNext = true;
        List<IncomingItem> next = null;
        @Override
        public boolean hasNext() {
            if(checkNext) {
                try {
                    next = queue.take().get();
                } catch (InterruptedException | ExecutionException e) {
                    throw new RuntimeException(e);
                }
                checkNext = false;
            }
            if(next == THE_END || next == null) {
                return false;
            }
            else {
                return true;
            }
        }
        @Override
        public List<IncomingItem> next() {
            if(checkNext) {
                hasNext();
            }
            if(!hasNext()) {
                throw new IllegalStateException();
            }
            checkNext = true;
            return next;
        }
    };
    Stream<IncomingItem> flat = StreamSupport.stream(iter.spliterator(), false).flatMap(List::stream);

This works at first; unfortunately, it has a fatal bug: the resulting stream seems to terminate prematurely, before retrieving all the items.
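A likely cause, offered as a reading of the code above rather than a confirmed diagnosis: THE_END is submitted from the allOf(...) callback as soon as all source futures complete, but the wrapper tasks running item::get on the pool may not have delivered their results to the completion queue yet, so the sentinel can be taken before some real lists. Since the attempt already materializes every future into an array, the number of results is known, and the sentinel can be dropped by taking exactly that many results. The sketch below assumes a cached thread pool (so every blocking f::get gets its own thread) and uses Integer in place of IncomingItem; the class and method names are hypothetical. Like the attempt, it still consumes the whole source stream eagerly.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;

public class CountedCompletion {
    // No THE_END sentinel: take exactly one result per future.
    static <T> Stream<T> flatten(Stream<CompletableFuture<List<T>>> source, ExecutorService pool) {
        @SuppressWarnings("unchecked")
        CompletableFuture<List<T>>[] all = source.toArray(CompletableFuture[]::new);
        CompletionService<List<T>> queue = new ExecutorCompletionService<>(pool);
        for (CompletableFuture<List<T>> f : all) {
            queue.submit(f::get); // blocks one pool thread until f completes
        }
        // Results arrive in completion order; the count, not a sentinel, ends the stream.
        return Stream.iterate(0, i -> i < all.length, i -> i + 1)
            .flatMap(i -> {
                try {
                    return queue.take().get().stream();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new CompletionException(e);
                } catch (ExecutionException e) {
                    throw new CompletionException(e.getCause());
                }
            });
    }

    static List<Integer> demo() {
        // Assumption of this sketch: a cached pool, one thread per blocking f::get.
        ExecutorService pool = Executors.newCachedThreadPool();
        try {
            // Futures that arrive later finish earlier: delays 500, 400, ..., 100 ms.
            Stream<CompletableFuture<List<Integer>>> source = Stream.iterate(0, i -> i < 5, i -> i + 1)
                .map(i -> CompletableFuture.supplyAsync(
                    () -> List.of(i * 10, i * 10 + 1),
                    CompletableFuture.delayedExecutor(500 - 100L * i, TimeUnit.MILLISECONDS)));
            return flatten(source, pool).toList();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // all 10 items, roughly in completion order
    }
}
```

With a fixed pool smaller than the number of futures, the excess f::get waits would only start as threads free up, so the output would follow completion order less closely.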

2023 UPDATE

Four years later, I still have not accepted an answer, because both current answers say this is "impossible". That doesn't sound right: I'm simply asking for an efficient way to take items that are completed in batches and monitor their completion individually. Perhaps this is easier with Virtual Threads in Java 21? I no longer maintain this codebase, but the question remains a great unsolved problem.

Update#2: Clarification of assumptions

According to https://www.geeksforgeeks.org/completablefuture-in-java/, CompletableFuture represents a future result of an asynchronous computation; interestingly, the JavaDocs are not quite as clear.
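The Javadoc's class description starts with "A Future that may be explicitly completed (setting its value and status)", which is one reason it reads less clearly: a CompletableFuture does not have to be backed by any asynchronous computation at all. A minimal illustration (class name hypothetical):

```java
import java.util.concurrent.CompletableFuture;

public class ExplicitCompletion {
    static String demo() {
        // No task is attached: the future is a placeholder until someone completes it.
        CompletableFuture<String> f = new CompletableFuture<>();
        boolean first = f.complete("set by caller"); // true: f was still incomplete
        boolean second = f.complete("ignored");      // false: f was already completed
        return f.join() + " " + first + " " + second;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints "set by caller true false"
    }
}
```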

Update#3:

Performance is not the focus of the question, only efficiency in a specific sense: if task 1 takes T1 to complete and task 2 takes T2, then waiting a total of max(T1, T2) for both is more efficient than waiting T1 + T2. Then extend that to N-sized lists of tasks T, U, V, ..., and the total wait probably looks something like max(T1..Tn) + max(U1..Un) + max(V1..Vn), or hopefully max(T1..Vn). Assuming the same N for all lists is only an explanatory simplification (the lengths really should be I, J, K, ...). In other words, please assume that the asynchronous tasks represented in the Stream of Lists are I/O-bound, not CPU-bound. This can be simulated by inserting random Thread.sleep() calls if you wish to demonstrate something in code. Apologies for the sloppy notation: I have a background in Computer Science, but it's been a while since I've tried to formally describe a problem like this.
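To make the max-versus-sum distinction concrete, here is a small timing sketch that, as suggested above, uses Thread.sleep as a stand-in for I/O. The class, method names and the 300/400 ms durations are hypothetical, and the measured times are approximate and machine-dependent.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class MaxVsSum {
    // Simulated I/O-bound task: sleeping stands in for waiting on I/O.
    static long task(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new CompletionException(e);
        }
        return millis;
    }

    // One after another: elapsed time is about T1 + T2 + ...
    static long sequentialMillis(List<Long> durations) {
        long start = System.nanoTime();
        durations.forEach(MaxVsSum::task);
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
    }

    // All started first, then awaited together: about max(T1, T2, ...),
    // provided the pool has a thread per task.
    static long concurrentMillis(List<Long> durations, ExecutorService pool) {
        long start = System.nanoTime();
        CompletableFuture<?>[] all = durations.stream()
            .map(d -> CompletableFuture.supplyAsync(() -> task(d), pool))
            .toArray(CompletableFuture[]::new);
        CompletableFuture.allOf(all).join();
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
    }

    public static void main(String[] args) {
        List<Long> durations = List.of(300L, 400L); // T1, T2 in milliseconds
        ExecutorService pool = Executors.newFixedThreadPool(durations.size());
        try {
            System.out.println("sequential: " + sequentialMillis(durations) + " ms"); // about 700
            System.out.println("concurrent: " + concurrentMillis(durations, pool) + " ms"); // about 400
        } finally {
            pool.shutdown();
        }
    }
}
```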

Update#4

Based on the answer from @Slaw, I can see that the core problem I was facing is the mismatch between stream arrival order (the order in which futures are provided through the Stream) and completion order (the order in which each CompletableFuture completes and unlocks the List within it). So a revised TL;DR for this question is: how can you take a Stream of futures produced in arbitrary order and re-order it into completion order, so that a flatMap() operation will not block unnecessarily?
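The re-ordering itself can be sketched independently of streams: each future pushes its own result into a BlockingQueue when it completes, so the queue's order is the completion order, and taking exactly one element per future avoids the need for a sentinel. This is a minimal sketch with hypothetical names; it ignores failures and, to keep the demo deterministic, completes plain CompletableFuture instances by hand instead of running real I/O.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.LinkedBlockingQueue;

public class CompletionOrder {
    // Each future offers its result to the queue on completion,
    // so queue order == completion order. Failed futures are ignored here.
    static <T> BlockingQueue<T> inCompletionOrder(List<CompletableFuture<T>> futures) {
        BlockingQueue<T> queue = new LinkedBlockingQueue<>();
        for (CompletableFuture<T> f : futures) {
            f.thenAccept(queue::offer);
        }
        return queue;
    }

    static List<String> demo() {
        // Arrival order: A, B, C.
        List<CompletableFuture<String>> futures = List.of(
            new CompletableFuture<String>(),
            new CompletableFuture<String>(),
            new CompletableFuture<String>());
        BlockingQueue<String> queue = inCompletionOrder(futures);

        // Completion order: C, then A, then B.
        futures.get(2).complete("C");
        futures.get(0).complete("A");
        futures.get(1).complete("B");

        List<String> order = new ArrayList<>();
        try {
            for (int i = 0; i < futures.size(); i++) {
                order.add(queue.take()); // one take per future: no sentinel needed
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new CompletionException(e);
        }
        return order;
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints [C, A, B]
    }
}
```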


Solution

  • Here is a solution which not only avoids blocking, by preferring to process already completed futures, but even retains lazy processing of the original stream as far as possible. As long as any of the already encountered futures has completed, it will not advance the source traversal.

    This is best demonstrated by an example using an infinite source, which can still complete in finite time (and quite fast, due to the preference for completed futures).

    Stream<CompletableFuture<List<Integer>>> streamOfFutures = Stream.generate(
        () -> CompletableFuture.supplyAsync(
            () -> ThreadLocalRandom.current().ints(10, 0, 200).boxed().toList(),
            CompletableFuture.delayedExecutor(
                ThreadLocalRandom.current().nextLong(5), TimeUnit.SECONDS))
    );
    System.out.println(flattenResults(streamOfFutures)
        .peek(System.out::println)
        .anyMatch(i -> i == 123)
    );
    

    The implementation processes already completed futures immediately on encounter. Only if a future has not completed yet is a queuing action chained and the pending counter incremented. The counter is decremented by the consuming side whenever it removes a list from the queue, rather than by the completion callback: a callback that queues first and decrements afterwards leaves a window in which the consumer has already taken and exhausted the last list but still sees a counter of 1, and would then block in take() forever. Since the consumer waits for exactly one queued list per pending future, care must be taken to queue an item (an empty list) even on exceptional completion, to unblock the consumer thread in case it's taking an element right at this point. The exception will be propagated to the caller when encountered. As with short-circuiting parallel streams, it's possible to miss errors if the result is found before processing all elements.

    If the terminal operation is short-circuiting and finishes without processing the entire stream, the counter is irrelevant and the operation will not wait for the completion of pending futures. Only once the source stream has been traversed completely does the counter become relevant for detecting when all futures have completed.

    static <T> Stream<T> flattenResults(Stream<CompletableFuture<List<T>>> stream) {
        Spliterator<CompletableFuture<List<T>>> srcSp = stream.spliterator();
        BlockingQueue<List<T>> queue = new LinkedBlockingQueue<>();
    
        return StreamSupport.stream(new Spliterators.AbstractSpliterator<T>(
                                                     srcSp.estimateSize(), 0) {
            // Number of not-yet-completed futures whose list has not been removed
            // from the queue yet; only modified by the consuming thread.
            final AtomicLong pending = new AtomicLong();
            Spliterator<T> fromCurrentList;
            volatile Throwable failure;
    
            @Override
            public boolean tryAdvance(Consumer<? super T> action) {
                if(checkExisting(action)) return true;
    
                while(srcSp.tryAdvance(this::checkNew)) {
                    if(checkExisting(action)) return true;
                }
    
                return checkAfterSourceExhausted(action);
            }
            // Non-blocking: drains the current list, then any lists already queued.
            private boolean checkExisting(Consumer<? super T> action) {
                for(;;) {
                    var sp = fromCurrentList;
                    if(sp == null) {
                        List<T> newList = queue.poll();
                        if(newList == null) return false;
                        pending.decrementAndGet();
                        checkFailure();
                        fromCurrentList = sp = newList.spliterator();
                    }
                    if(sp.tryAdvance(action)) return true;
                    fromCurrentList = null;
                }
            }
    
            private void checkNew(CompletableFuture<List<T>> f) {
                if(f.isDone()) fromCurrentList = f.join().spliterator();
                else {
                    pending.incrementAndGet();
                    f.whenComplete((r, t) -> {
                        if(t != null) {
                            failure = t;
                            r = List.of();
                        }
                        queue.offer(r); // exactly one list per pending future
                    });
                }
            }
    
            private boolean checkAfterSourceExhausted(Consumer<? super T> action) {
                // Every pending future queues exactly one list, so take() returns
                // as long as all futures eventually complete.
                while(pending.get() != 0) {
                    checkFailure();
                    try {
                        List<T> newList = queue.take();
                        pending.decrementAndGet();
                        fromCurrentList = newList.spliterator();
                        if(checkExisting(action)) return true;
                    } catch(InterruptedException ex) {
                        throw new CompletionException(ex);
                    }
                }
                return false;
            }
    
            private void checkFailure() {
                Throwable t = failure;
                if(t != null) {
                    if(t instanceof RuntimeException rt) throw rt;
                    if(t instanceof Error e) throw e;
                    throw new CompletionException(t);
                }
            }
        }, false);
    }
    

    You may use something like

    Stream<CompletableFuture<List<Integer>>> streamOfFutures = IntStream.range(0, 10)
        .mapToObj(i -> 
          CompletableFuture.supplyAsync(
            () -> IntStream.range(i * 10, (i + 1) * 10).boxed().toList(),
            CompletableFuture.delayedExecutor(10 - i, TimeUnit.SECONDS)));
    
    System.out.println(flattenResults(streamOfFutures)
        .peek(System.out::println)
        .anyMatch(i -> i == 34)
    );
    

    to visualize the “first completed, first processed” behavior. Or change the terminal operation to

    flattenResults(streamOfFutures).forEach(System.out::println);
    

    to demonstrate that the completion of all futures is correctly recognized, or

    if(flattenResults(streamOfFutures).count() != 100)
        throw new AssertionError();
    else
        System.out.println("Success");
    

    to have something which can be integrated into an automated test.