
Spliterator Stream passing results to the next stream consumer


Given the following code

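import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import org.apache.commons.lang3.RandomStringUtils; // assumed: Apache Commons Lang 3

public class StreamSpliteratorTest {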

    public static void main(String[] args) {
        var ids = Stream.generate(() -> RandomStringUtils.randomAlphanumeric(5))
            .limit(100)
            .collect(Collectors.toList());
        get(ids)
            .forEach(resultMap -> {
                System.out.printf("Got result map with size %s%n", resultMap.size());
            });
    }

    static Stream<Map<String, String>> get(Collection<String> ids) {
        var remainingIds = new HashSet<>(ids);
        var initialCount = remainingIds.size();
        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(new Iterator<>() {

            @Override
            public boolean hasNext() {
                return !remainingIds.isEmpty();
            }

            @Override
            public Map<String, String> next() {
                Map<String, String> result = Map.of();

                try {
                    var chunk = remainingIds.stream().limit(10).collect(Collectors.toSet());
                    remainingIds.removeAll(chunk);
                    result = fetch(chunk).get(5, TimeUnit.SECONDS);
                    System.out.printf("%s of %s ids done%n", initialCount - remainingIds.size(), initialCount);
                } catch (Exception e) {
                    System.err.printf("Request thread pool was interrupted: %s%n", e.getMessage());
                }

                return result;
            }
        }, Spliterator.IMMUTABLE), true);
    }

    static CompletableFuture<Map<String, String>> fetch(Collection<String> ids) {
        var delay = CompletableFuture.delayedExecutor(1, TimeUnit.SECONDS);
        return CompletableFuture
            .supplyAsync(() -> ids.stream().collect(Collectors.toMap(e -> e, e -> e)), delay);
    }
}

The output is:

10 of 100 ids done
20 of 100 ids done
30 of 100 ids done
40 of 100 ids done
50 of 100 ids done
60 of 100 ids done
70 of 100 ids done
80 of 100 ids done
90 of 100 ids done
100 of 100 ids done
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10
Got result map with size 10

I am confused as to why the result of each next() call isn't passed on to the consumer in the calling code immediately. I expected this output instead:


10 of 100 ids done
Got result map with size 10
20 of 100 ids done
Got result map with size 10
30 of 100 ids done
Got result map with size 10
40 of 100 ids done
Got result map with size 10
50 of 100 ids done
Got result map with size 10
60 of 100 ids done
Got result map with size 10
70 of 100 ids done
Got result map with size 10
80 of 100 ids done
Got result map with size 10
90 of 100 ids done
Got result map with size 10
100 of 100 ids done
Got result map with size 10

What am I doing wrong here?


Solution

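  • The stream is parallel (the second argument to StreamSupport.stream is true), so forEach starts by calling trySplit() on the Spliterator returned from Spliterators.spliteratorUnknownSize. That Spliterator splits by pulling a batch of elements from the iterator into an array (in OpenJDK the first batch is up to 1024 elements) in order to increase parallel performance. With only ten chunks, all ten calls to next(), each blocking on fetch(...), therefore run before forEach receives any result.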

    You can see this here: https://github.com/openjdk/jdk/blob/51b53a821bb3cfb962f80a637f5fb8cde988975a/src/java.base/share/classes/java/util/Spliterators.java#L1828

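    To avoid this, you can either use a Spliterator that does not batch, or switch the stream to sequential mode by calling sequential() before forEach(...) (or by passing false to StreamSupport.stream). A sequential stream never splits, so each result is handed to forEach as soon as next() returns it.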

    One way to keep using a parallel stream is to do the batching before the stream is created and then stream the batches with Arrays.stream, whose spliterator splits the array by index instead of draining an iterator:

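    // Needs java.util.ArrayList, java.util.Arrays, java.util.List and
    // java.util.concurrent.atomic.AtomicInteger in addition to the question's imports.
    static Stream<Map<String, String>> get(Collection<String> ids) {
        // Round up so that no empty trailing chunk is created when ids.size() is a
        // multiple of 10 (ids.size() / 10 + 1 would give 11 chunks for 100 ids).
        @SuppressWarnings("unchecked")
        List<String>[] chunks = new List[(ids.size() + 9) / 10];
        Arrays.setAll(chunks, k -> new ArrayList<>());
        int i = 0;
        for (String id : ids) {
            chunks[i++ / 10].add(id);
        }
        AtomicInteger done = new AtomicInteger(0);
        int initialCount = ids.size();
        // Each chunk is fetched inside map(...), and the downstream forEach runs on the
        // same worker thread right afterwards, so each result is passed on as soon as it arrives.
        return Arrays.stream(chunks).map(c -> {
            Map<String, String> result = Map.of();
            try {
                result = fetch(c).get(5, TimeUnit.SECONDS);
                System.out.printf("%s of %s ids done%n", done.addAndGet(c.size()), initialCount);
            } catch (Exception e) {
                System.err.printf("Request thread pool was interrupted: %s%n", e.getMessage());
            }
            return result;
        }).parallel();
    }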
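To see the batching in isolation, here is a small self-contained sketch. The class BatchingDemo, its run method and the logging iterator are hypothetical and not part of the question. It streams the numbers 1..count through Spliterators.spliteratorUnknownSize as a parallel stream, just like get(...), and logs every call to next() and every element that reaches forEach. After sequential() the log should alternate between next and consume. In the parallel run all the next entries should come first, because the first trySplit() drains the iterator into an array before any element reaches forEach.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.StreamSupport;

public class BatchingDemo {

    // Streams 1..count through spliteratorUnknownSize and records when next() produces
    // each element and when forEach consumes it.
    public static List<String> run(int count, boolean sequential) {
        List<String> log = Collections.synchronizedList(new ArrayList<>());
        Iterator<Integer> source = new Iterator<>() {
            private int i = 0;

            @Override
            public boolean hasNext() {
                return i < count;
            }

            @Override
            public Integer next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int value = ++i;
                log.add("next " + value);
                return value;
            }
        };
        // Parallel, as in the question's get(...).
        var stream = StreamSupport.stream(
            Spliterators.spliteratorUnknownSize(source, Spliterator.ORDERED), true);
        if (sequential) {
            // A sequential stream never calls trySplit(), so nothing is batched.
            stream = stream.sequential();
        }
        stream.forEach(v -> log.add("consume " + v));
        return log;
    }

    public static void main(String[] args) {
        System.out.println("sequential: " + run(3, true));
        System.out.println("parallel:   " + run(3, false));
    }
}
```

The order of the consume entries in the parallel run can vary between runs; only the fact that every next entry comes first is expected to be stable.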
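If you want to keep an iterator-based parallel stream without batching up front, one option for the "different spliterator" mentioned in the answer is a Spliterator whose trySplit() hands off a single element at a time. This is a sketch under my own assumptions: OneAtATimeSpliterator, parallelStreamOf and run are hypothetical names, not a JDK API. Each split pulls exactly one element from the iterator, so the consumer for that element can start while the remaining elements are still being pulled. Only the task holding the "rest" of the spliterator advances the iterator, so the iterator itself should not need to be thread-safe.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.function.Consumer;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

// Hypothetical helper: wraps an Iterator so that each trySplit() hands off exactly one
// element, instead of the large batches taken by Spliterators.spliteratorUnknownSize.
public class OneAtATimeSpliterator<T> implements Spliterator<T> {
    private final Iterator<? extends T> it;

    public OneAtATimeSpliterator(Iterator<? extends T> it) {
        this.it = Objects.requireNonNull(it);
    }

    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        if (!it.hasNext()) {
            return false;
        }
        action.accept(it.next());
        return true;
    }

    @Override
    public Spliterator<T> trySplit() {
        if (!it.hasNext()) {
            return null; // nothing left to hand off
        }
        // Pull a single element and return it as its own spliterator, so the stream
        // can process it while this spliterator keeps pulling the rest.
        return Spliterators.spliterator(new Object[] {it.next()}, characteristics());
    }

    @Override
    public long estimateSize() {
        return Long.MAX_VALUE; // size is unknown
    }

    @Override
    public int characteristics() {
        return Spliterator.ORDERED;
    }

    public static <T> Stream<T> parallelStreamOf(Iterator<? extends T> it) {
        return StreamSupport.stream(new OneAtATimeSpliterator<>(it), true);
    }

    // Demo: logs "next i" when the iterator produces i and "consume i" when forEach receives it.
    public static List<String> run(int count) {
        List<String> log = Collections.synchronizedList(new ArrayList<>());
        Iterator<Integer> source = new Iterator<>() {
            private int i = 0;

            @Override
            public boolean hasNext() {
                return i < count;
            }

            @Override
            public Integer next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int value = ++i;
                log.add("next " + value);
                return value;
            }
        };
        parallelStreamOf(source).forEach(v -> log.add("consume " + v));
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run(5));
    }
}
```

In the question's code this would mean wrapping the anonymous iterator in something like new OneAtATimeSpliterator<>(...) instead of Spliterators.spliteratorUnknownSize(..., Spliterator.IMMUTABLE). The trade-off is that the next() calls, and therefore the blocking fetch(...) calls inside them, still run one after another; only the downstream work runs in parallel.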