Search code examples
Tags: java, java-8, java-stream, bigdecimal

How can I use groupingBy to create a map whose values are averages of BigDecimal fields?


I have the following class:

/**
 * Immutable value object holding a person's name, the state they live in, and their salary.
 *
 * <p>The getters are what stream code such as {@code groupingBy(Person::getState, ...)} and
 * {@code mapping(Person::getSalary, ...)} refers to.
 */
public final class Person {

    private final String name;
    private final String state;
    private final BigDecimal salary;

    public Person(String name, String state, BigDecimal salary) {
        this.name = name;
        this.state = state;
        this.salary = salary;
    }

    /** Returns the person's name. */
    public String getName() {
        return name;
    }

    /** Returns the state the person lives in (the grouping key for per-state averages). */
    public String getState() {
        return state;
    }

    /** Returns the person's salary. */
    public BigDecimal getSalary() {
        return salary;
    }
}

I want to create a map that holds the average salary for each state. How can I do so using Java 8 streams? I tried to use a downstream collector with groupingBy but wasn't able to do so in an elegant way.

I did the following which works but is pretty hideous looking:

Stream.of(p1,p2,p3,p4).collect(groupingBy(Person::getState, mapping(Person::getSalary, toList())))
.forEach((state,wageList) -> {
        System.out.print(state+"-> ");
        // Pair every salary with a count of 1, then sum both components.
        // The reducer adds b[1] rather than a constant ONE so that it is associative,
        // as Stream.reduce requires (a constant would miscount on a parallel stream).
        final BigDecimal[] wagesArray = wageList.stream()
                .map(bd -> new BigDecimal[]{bd, BigDecimal.ONE})
                .reduce((a, b) -> new BigDecimal[]{a[0].add(b[0]), a[1].add(b[1])})
                .get(); // groupingBy never creates an empty group, so a value is present
        // divide(divisor) with no scale throws ArithmeticException for a non-terminating
        // quotient (e.g. 10 / 3); dividing straight to scale 2 with CEILING gives the same
        // result as the exact-divide-then-setScale(2, CEILING) for terminating quotients.
        System.out.println(wagesArray[0].divide(wagesArray[1], 2, RoundingMode.CEILING));
    });

Is there a better way?


Solution

  • Here's a complete example that uses only BigDecimal arithmetic and shows how to implement a custom collector:

    import java.math.BigDecimal;
    import java.math.RoundingMode;
    import java.util.Collections;
    import java.util.Map;
    import java.util.Set;
    import java.util.function.BiConsumer;
    import java.util.function.BinaryOperator;
    import java.util.function.Function;
    import java.util.function.Supplier;
    import java.util.stream.Collector;
    import java.util.stream.Collectors;
    import java.util.stream.Stream;
    
    /**
     * Immutable person record (name, state, salary) plus a demo of averaging
     * {@link BigDecimal} salaries per state with a custom {@link Collector}.
     */
    public final class Person {

        private final String name;
        private final String state;
        private final BigDecimal salary;

        public Person(String name, String state, BigDecimal salary) {
            this.name = name;
            this.state = state;
            this.salary = salary;
        }

        public String getName() {
            return name;
        }

        public String getState() {
            return state;
        }

        public BigDecimal getSalary() {
            return salary;
        }

        /** Demo: groups four sample people by state and prints each state's average salary. */
        public static void main(String[] args) {
            Person p1 = new Person("John", "NY", new BigDecimal("2000"));
            Person p2 = new Person("Jack", "NY", new BigDecimal("3000"));
            Person p3 = new Person("Jane", "GA", new BigDecimal("1500"));
            Person p4 = new Person("Jackie", "GA", new BigDecimal("2500"));

            Map<String, BigDecimal> result =
                Stream.of(p1, p2, p3, p4).collect(
                    Collectors.groupingBy(Person::getState,
                                          Collectors.mapping(Person::getSalary,
                                                             new AveragingCollector())));
            System.out.println("result = " + result);
        }

        /**
         * Averages {@link BigDecimal} values using only BigDecimal arithmetic.
         *
         * <p>The average is rounded to 2 decimal places with {@link RoundingMode#HALF_UP}.
         * An empty input yields {@code 0.00}, mirroring {@code Collectors.averagingDouble},
         * which yields 0 when there are no elements.
         */
        private static class AveragingCollector implements Collector<BigDecimal, IntermediateResult, BigDecimal> {
            @Override
            public Supplier<IntermediateResult> supplier() {
                return IntermediateResult::new;
            }

            @Override
            public BiConsumer<IntermediateResult, BigDecimal> accumulator() {
                return IntermediateResult::add;
            }

            @Override
            public BinaryOperator<IntermediateResult> combiner() {
                return IntermediateResult::combine;
            }

            @Override
            public Function<IntermediateResult, BigDecimal> finisher() {
                return IntermediateResult::finish;
            }

            @Override
            public Set<Characteristics> characteristics() {
                return Collections.emptySet();
            }
        }

        /** Mutable accumulation container: running sum and element count. */
        private static class IntermediateResult {
            private long count = 0;
            private BigDecimal sum = BigDecimal.ZERO;

            /** Folds one value into the running total. */
            void add(BigDecimal value) {
                this.sum = this.sum.add(value);
                this.count++;
            }

            /** Merges another partial result into this one (used for parallel streams). */
            IntermediateResult combine(IntermediateResult other) {
                this.sum = this.sum.add(other.sum);
                this.count += other.count;
                return this;
            }

            /** Returns the mean rounded HALF_UP to scale 2, or 0.00 when nothing was added. */
            BigDecimal finish() {
                if (count == 0) {
                    // Dividing by a zero count would throw ArithmeticException.
                    return BigDecimal.ZERO.setScale(2);
                }
                // RoundingMode replaces the int constant BigDecimal.ROUND_HALF_UP, deprecated since Java 9.
                return sum.divide(BigDecimal.valueOf(count), 2, RoundingMode.HALF_UP);
            }
        }
    }
    

    If you are willing to convert your BigDecimal values to double (which, for an average of salaries, is perfectly acceptable, IMHO), you can simply use Collectors.averagingDouble:

    // Extract each salary as a double directly inside averagingDouble;
    // no separate mapping step is needed.
    Map<String, Double> result2 =
                Stream.of(p1, p2, p3, p4).collect(
                    Collectors.groupingBy(
                        Person::getState,
                        Collectors.averagingDouble(person -> person.getSalary().doubleValue())));