Tags: java, java-8, java-stream, collectors

Custom Collector for Collectors.groupingBy doesn't work as expected


Consider the simple class Foo:

public class Foo {

    public Float v1;
    public Float v2;
    public String name;

    public Foo(String name, Float v1, Float v2) {
        this.name = name;
        this.v1 = v1;
        this.v2 = v2;
    }

    public String getName() {
        return name;
    }
}

Now, I have a collection of Foos and I'd like to group them by Foo::getName. I wrote a custom Collector to sum v1 * v2 within each group, but it doesn't work as expected. More precisely, combiner() never gets called. Why?

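import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;

public class Main {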

    public static void main(String[] args) {

        List<Foo> foos = new ArrayList<>();
        foos.add(new Foo("blue", 2f, 2f));
        foos.add(new Foo("blue", 2f, 3f));
        foos.add(new Foo("green", 3f, 4f));

        Map<String, Float> fooGroups = foos.stream().collect(Collectors.groupingBy(Foo::getName, new FooCollector()));
        System.out.println(fooGroups);
    }

    private static class FooCollector implements Collector<Foo, Float, Float> {

        @Override
        public Supplier<Float> supplier() {
            return () -> new Float(0);
        }

        @Override
        public BiConsumer<Float, Foo> accumulator() {
            return (v, foo) -> v += foo.v1 * foo.v2;
        }

        @Override
        public BinaryOperator<Float> combiner() {
            return (v1, v2) -> v1 + v2;
        }

        @Override
        public Function<Float, Float> finisher() {
            return Function.identity();
        }

        @Override
        public Set<Characteristics> characteristics() {
            Set<Characteristics> characteristics = new TreeSet<>();
            return characteristics;
        }
    }
}

Solution

  • First, the combiner is only called when partial results have to be merged, which happens when you use multiple threads (a parallel stream). It combines the results computed on separate chunks of your stream. Your stream is sequential, so there is nothing to merge and the combiner is never called.
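To see this for yourself, here is a minimal sketch that counts combiner invocations. The class CombinerDemo, the counter combinerCalls and the helper countingSum() are made-up names for this illustration, and it sums plain integers rather than Foos to keep it short. On a sequential stream the counter stays at 0; on a parallel stream the count depends on how the runtime splits the work, so no particular number should be relied on.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collector;

class CombinerDemo {
    // Hypothetical counter for this demo: how often the combiner ran.
    static final AtomicInteger combinerCalls = new AtomicInteger();

    // Sums integers in a one-element int[] container; the combiner bumps the counter.
    static Collector<Integer, int[], Integer> countingSum() {
        return Collector.of(
            () -> new int[1],
            (acc, x) -> acc[0] += x,
            (left, right) -> {
                combinerCalls.incrementAndGet();
                left[0] += right[0];
                return left;
            },
            acc -> acc[0]);
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8);

        // Sequential: a single container is used, nothing to merge.
        int sequential = data.stream().collect(countingSum());
        System.out.println("sequential sum=" + sequential
            + ", combiner calls=" + combinerCalls.get());

        // Parallel: partial containers may be merged via the combiner.
        combinerCalls.set(0);
        int parallel = data.parallelStream().collect(countingSum());
        System.out.println("parallel sum=" + parallel
            + ", combiner calls=" + combinerCalls.get());
    }
}
```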

    You are getting zero values because of your accumulator function. The expression

    v += foo.v1 * foo.v2;
    

    only reassigns the lambda's local parameter v to a new Float object. The original accumulator object held by the collector is not modified; it is still 0f. Besides, Float, like the other numeric wrapper types (and String), is immutable and cannot be changed in place.
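The effect can be reproduced outside of any stream. The following is a small sketch with made-up names (ReassignDemo, BROKEN, WORKING): the first lambda has the same shape as the accumulator in the question, the second writes into a one-element float[] that the caller also holds, which is just one possible mutable container.

```java
import java.util.function.BiConsumer;

class ReassignDemo {
    // Same shape as the question's accumulator: only the local parameter v changes.
    static final BiConsumer<Float, Float> BROKEN = (v, x) -> v += x;

    // Writes into the array itself, so the caller sees the update.
    static final BiConsumer<float[], Float> WORKING = (v, x) -> v[0] += x;

    public static void main(String[] args) {
        Float boxed = 0f;
        BROKEN.accept(boxed, 5f);
        System.out.println(boxed);   // prints 0.0

        float[] cell = new float[1];
        WORKING.accept(cell, 5f);
        System.out.println(cell[0]); // prints 5.0
    }
}
```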

    You need some other kind of accumulator object that is mutable.

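    // Mutable holder for a running total; the collector keeps one per group.
    class FloatAcc {
        private Float total;
        public FloatAcc(Float initial) {
            total = initial;
        }
        // Updates the field, so the change is visible through every reference.
        public void accumulate(Float item) {
            total += item;
        }
        public Float get() {
            return total;
        }
    }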
    

    Then you can modify your custom Collector to use FloatAcc: supply a new FloatAcc, call accumulate in the accumulator and combiner functions, and return the total from the finisher.

    class FooCollector implements Collector<Foo, FloatAcc, Float> {
        @Override
        public Supplier<FloatAcc> supplier() {
            return () -> new FloatAcc(0f);
        }
        @Override
        public BiConsumer<FloatAcc, Foo> accumulator() {
            return (v, foo) -> v.accumulate(foo.v1 * foo.v2);
        }
        @Override
        public BinaryOperator<FloatAcc> combiner() {
            return (v1, v2) -> {
                v1.accumulate(v2.get());
                return v1;
            };
        }
        @Override
        public Function<FloatAcc, Float> finisher() {
            return FloatAcc::get;
        }
        @Override
        public Set<Characteristics> characteristics() {
            Set<Characteristics> characteristics = new TreeSet<>();
            return characteristics;
        }
    }
    

    With these changes I get what you're expecting:

    {green=12.0, blue=10.0}
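If you don't need a dedicated class, the same per-group sum can probably be written more compactly. The sketch below shows two options under stated assumptions: Collector.of with a one-element float[] as the mutable container, and the built-in Collectors.reducing with Float::sum. The names AlternativesDemo, productSum, sumWithCollectorOf and sumWithReducing are invented for this example, and the nested Foo is a trimmed copy of the class from the question so that the snippet compiles on its own.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;

class AlternativesDemo {
    // Trimmed copy of the question's Foo, only what the sketch needs.
    static class Foo {
        final String name;
        final Float v1;
        final Float v2;

        Foo(String name, Float v1, Float v2) {
            this.name = name;
            this.v1 = v1;
            this.v2 = v2;
        }

        String getName() {
            return name;
        }
    }

    // Option 1: Collector.of with a float[] cell as the mutable accumulator.
    static Collector<Foo, float[], Float> productSum() {
        return Collector.of(
            () -> new float[1],
            (acc, foo) -> acc[0] += foo.v1 * foo.v2,
            (left, right) -> { left[0] += right[0]; return left; },
            acc -> acc[0]);
    }

    static Map<String, Float> sumWithCollectorOf(List<Foo> foos) {
        return foos.stream()
            .collect(Collectors.groupingBy(Foo::getName, productSum()));
    }

    // Option 2: built-in reducing collector; no mutable container involved.
    static Map<String, Float> sumWithReducing(List<Foo> foos) {
        return foos.stream()
            .collect(Collectors.groupingBy(
                Foo::getName,
                Collectors.reducing(0f, foo -> foo.v1 * foo.v2, Float::sum)));
    }

    public static void main(String[] args) {
        List<Foo> foos = Arrays.asList(
            new Foo("blue", 2f, 2f),
            new Foo("blue", 2f, 3f),
            new Foo("green", 3f, 4f));

        System.out.println(sumWithCollectorOf(foos));
        System.out.println(sumWithReducing(foos));
    }
}
```

Both variants give 10.0 for blue and 12.0 for green on the sample data. The reducing version boxes a new Float at every step, which is fine for small inputs; the array version avoids that at the cost of a slightly less obvious container.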