java, iterator, java-stream

Java stream's iterator forces flatmap to traverse substream before getting the first item


I need to create an iterator from a stream of streams. Both the parent and the child streams are composed of non-interfering, stateless operations, and the obvious strategy is to use flatMap.

It turns out that the iterator, at the first hasNext() invocation, traverses the entire first substream, and I don't understand why. Although iterator() is a terminal operation, it is clearly stated that it shouldn't consume the stream. I need the objects from the substream to be generated one by one.

To replicate the behaviour, I've mocked my real code with a sample that shows the same problem:

import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Stream;

public class FreeRunner {

    public static void main(String[] args) {
        AtomicInteger x = new AtomicInteger();
        Iterator<C> iterator = Stream.generate(() -> null)
                .takeWhile(y -> x.incrementAndGet() < 5)
                .filter(y -> x.get() % 2 == 0)
                .map(n -> new A("A" + x.get()))
                .flatMap(A::getBStream)
                .filter(Objects::nonNull)
                .map(B::toC)
                .iterator();

        while(iterator.hasNext()) {
            System.out.println("after hasNext()");
            C next = iterator.next();
            System.out.println(next);
        }

    }

    private static class A {
        private final String name;

        public A(String name) {
            this.name = name;
            System.out.println(" > created " + name);
        }

        public Stream<B> getBStream() {
            AtomicInteger c = new AtomicInteger();
            return Stream.generate(() -> null)
                    .takeWhile(x -> c.incrementAndGet() < 5)
                    .map(n -> c.get() % 2 == 0 ? null : new B(this.name + "->B" + c.get()));
        }

        public String toString() {
            return name;
        }
    }

    private static class B {

        private final String name;

        public B(String name) {
            this.name = name;
            System.out.println(" >> created " + name);
        }

        public String toString() {
            return name;
        }

        public C toC() {
            return new C(this.name + "+C");
        }

    }

    private static class C {

        private final String name;

        public C(String name) {
            this.name = name;
            System.out.println(" >>> created " + name);
        }

        public String toString() {
            return name;
        }
    }
}

When it is executed it shows:

 > created A2
 >> created A2->B1
 >>> created A2->B1+C
 >> created A2->B3
 >>> created A2->B3+C
after hasNext()
A2->B1+C
after hasNext()
A2->B3+C
 > created A4
 >> created A4->B1
 >>> created A4->B1+C
 >> created A4->B3
 >>> created A4->B3+C
after hasNext()
A4->B1+C
after hasNext()
A4->B3+C

Process finished with exit code 0

In the debugger it is clear that iterator.hasNext() triggers the generation of the B and C objects.

The desired behaviour, instead, is:

 > created A2
 >> created A2->B1
 >>> created A2->B1+C
after hasNext()
A2->B1+C
 >> created A2->B3
 >>> created A2->B3+C
after hasNext()
A2->B3+C
 > created A4
 >> created A4->B1
 >>> created A4->B1+C
after hasNext()
A4->B1+C
 >> created A4->B3
 >>> created A4->B3+C
after hasNext()
A4->B3+C

What am I missing here?


Solution

  • I found a way out, but I had to sacrifice the laziness of the primary stream. As I posted in the comment above, the real problem, which I tried to simplify by mocking the code, is reading an Excel file sheet by sheet (filtered by sheet name) and traversing all rows to create objects according to the data in the spreadsheet.

    The original idea still seems sound to me but, apparently, the Stream.iterator() implementation consumes each nested stream in full at the first hasNext() invocation that reaches it, that is, the one that triggers the creation of the corresponding A object.

    So I abandoned flatMap() and used reduce(Stream::concat) to concatenate all streams generated by A.getBStream():

        public static void main(String[] args) {
            AtomicInteger x = new AtomicInteger();
            Iterator<C> it = Stream.generate(() -> null)
                    .takeWhile(y -> x.incrementAndGet() < 5)
                    .filter(y -> x.get() % 2 == 0)
                    .map(a -> new A("A" + x.get()))
                    .map(A::getBStream)
                    .filter(Objects::nonNull)
                    .reduce(Stream::concat)
                    .orElseGet(Stream::empty)
                    .filter(Objects::nonNull)
                    .map(B::toC)
                    .iterator();
    
            while(it.hasNext()) {
                System.out.println("after hasNext()");
                C next = it.next();
                System.out.println(next);
            }
        }
    

    This produces the following output:

     > created A2
     > created A4
     >> created A2->B0
     >>> created A2->B0+C
    after hasNext()
    A2->B0+C
     >> created A2->B1
     >>> created A2->B1+C
    after hasNext()
    A2->B1+C
     >> created A2->B2
     >>> created A2->B2+C
    after hasNext()
    A2->B2+C
     >> created A2->B3
     >>> created A2->B3+C
    after hasNext()
    A2->B3+C
     >> created A2->B4
     >> created A4->B0
     >>> created A4->B0+C
    after hasNext()
    A4->B0+C
     >> created A4->B1
     >>> created A4->B1+C
    after hasNext()
    A4->B1+C
     >> created A4->B2
     >>> created A4->B2+C
    after hasNext()
    A4->B2+C
     >> created A4->B3
     >>> created A4->B3+C
    after hasNext()
    A4->B3+C
     >> created A4->B4
    

    The price to pay is that A2 and A4 are generated up front, but all the B objects are generated lazily.
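
If creating every A up front is also a problem (for example, when opening a sheet is expensive), one possible alternative is to skip both flatMap() and reduce(Stream::concat) and flatten by hand. The following is only a sketch: LazyFlatIterator is a hypothetical helper name, not a JDK class, and it assumes that iterator() on a stream without flatMap pulls one source element at a time, which is what the outputs above suggest. Unlike flatMap(), it does not close the substreams.

```java
import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Function;
import java.util.stream.Stream;

// Hypothetical helper (not part of the JDK): flattens a stream of streams
// into an Iterator, pulling outer and inner elements only on demand.
final class LazyFlatIterator<T, R> implements Iterator<R> {

    private final Iterator<T> outer;
    private final Function<? super T, ? extends Stream<? extends R>> mapper;
    private Iterator<? extends R> inner = Collections.emptyIterator();

    LazyFlatIterator(Stream<T> source,
                     Function<? super T, ? extends Stream<? extends R>> mapper) {
        this.outer = source.iterator();
        this.mapper = mapper;
    }

    @Override
    public boolean hasNext() {
        // Move to the next non-empty substream; an outer element is only
        // requested once the current substream is exhausted.
        while (!inner.hasNext()) {
            if (!outer.hasNext()) {
                return false;
            }
            Stream<? extends R> sub = mapper.apply(outer.next());
            if (sub == null) {
                // Mirror flatMap(): a null substream counts as empty.
                inner = Collections.emptyIterator();
            } else {
                // Note: the substream is never closed here.
                inner = sub.iterator();
            }
        }
        return true;
    }

    @Override
    public R next() {
        if (!hasNext()) {
            throw new NoSuchElementException();
        }
        return inner.next();
    }
}
```

With the classes from the question, the pipeline up to `.map(n -> new A("A" + x.get()))` would be passed as the source, and `a -> a.getBStream().filter(Objects::nonNull).map(B::toC)` as the mapper, so the per-element filtering and mapping stay inside each substream.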