Tags: java, java-stream, bigdecimal

Java streams adding multiple values conditionally


I have a List of objects like this, where amount can be negative or positive:

class Sale {
   String country;
   BigDecimal amount;
}

I would like to end up with, for each country, a pair of sums: the sum of all positive amounts and the sum of all negative amounts.

With these values:

country | amount
nl      | 9
nl      | -3
be      | 7.9
be      | -7

Is there a way to end up with Map<String, Pair<BigDecimal, BigDecimal>> using a single stream?

It's easy to do this with two separate streams, but I can't figure out how to do it with just one.
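For reference, the two-stream version the question alludes to might look roughly like this. It is a sketch under assumptions: Sale is redeclared as a Java 16+ record with the question's two fields (so it has the accessors country() and amount()), the class name TwoStreams and the method sumBySign are hypothetical, and zero is counted as positive. Each call makes its own pass over the list, and the result is two separate maps rather than one map of pairs.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class TwoStreams {
    // Assumed stand-in for the question's Sale class: same fields, plus accessors.
    record Sale(String country, BigDecimal amount) {}

    // The sample rows from the question's table.
    static final List<Sale> SAMPLE = List.of(
            new Sale("nl", new BigDecimal("9")),
            new Sale("nl", new BigDecimal("-3")),
            new Sale("be", new BigDecimal("7.9")),
            new Sale("be", new BigDecimal("-7")));

    // One stream per sign: keep the sales on the requested side of zero
    // (zero counts as positive) and sum their amounts per country.
    static Map<String, BigDecimal> sumBySign(List<Sale> sales, boolean positive) {
        return sales.stream()
                .filter(s -> (s.amount().signum() >= 0) == positive)
                .collect(Collectors.groupingBy(
                        Sale::country,
                        Collectors.reducing(BigDecimal.ZERO, Sale::amount, BigDecimal::add)));
    }

    public static void main(String[] args) {
        Map<String, BigDecimal> positives = sumBySign(SAMPLE, true);  // first pass
        Map<String, BigDecimal> negatives = sumBySign(SAMPLE, false); // second pass
        System.out.println("positives: " + positives);
        System.out.println("negatives: " + negatives);
    }
}
```

The cost is a second traversal and two maps that still have to be zipped together by country, which is what a single-stream collector avoids.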


Solution

  • This can be done with Collectors.toMap, using a merge function that sums the pairs.

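    Assuming that Sale has getters getCountry() and getAmount(), and that Pair is immutable with a two-argument constructor and getters getFirst() and getSecond(), the code may look like this (the first element of each pair sums the non-negative amounts, the second sums the negative ones):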

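    import java.math.BigDecimal;
    import java.util.List;
    import java.util.Map;
    import java.util.stream.Collectors;

    // Pair is not a JDK type; this relies on the assumptions stated above.
    static Map<String, Pair<BigDecimal, BigDecimal>> sumUp(List<Sale> list) {
        return list.stream()
                   .collect(Collectors.toMap(
                       Sale::getCountry,
                       // value: (amount, 0) for non-negative amounts, (0, amount) for negative ones
                       sale -> sale.getAmount().signum() >= 0
                           ? new Pair<>(sale.getAmount(), BigDecimal.ZERO)
                           : new Pair<>(BigDecimal.ZERO, sale.getAmount()),
                       // merge: when a country is seen again, add the pairs component-wise
                       (pair1, pair2) -> new Pair<>(
                           pair1.getFirst().add(pair2.getFirst()),
                           pair1.getSecond().add(pair2.getSecond())
                       )
                       // , LinkedHashMap::new // optional 4th argument to keep insertion order
                   ));
    }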
    

    Test

    List<Sale> list = Arrays.asList(
        new Sale("us", new BigDecimal(100)),
        new Sale("uk", new BigDecimal(-10)),
        new Sale("us", new BigDecimal(-50)),
        new Sale("us", new BigDecimal(200)),
        new Sale("uk", new BigDecimal(333)),
        new Sale("uk", new BigDecimal(-70))
    );
    
    Map<String, Pair<BigDecimal, BigDecimal>> map = sumUp(list);
    
    map.forEach((country, pair) -> 
        System.out.printf("%-4s|%s%n%-4s|%s%n", 
            country, pair.getFirst(), country, pair.getSecond()
    ));
    

    Output

    uk  |333
    uk  |-80
    us  |300
    us  |-50
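For comparison, the same single-pass result can also be sketched without a merge function by nesting Collectors.teeing (Java 12+) and Collectors.filtering (Java 9+) inside Collectors.groupingBy; this is a different technique from the toMap answer above. The sketch needs Java 16+ because it uses records, and the Pair record, the record form of Sale, the class name TeeingSums and the helper sumWhere are all hypothetical. It runs on the data from the question's table and, like the toMap version, counts zero with the non-negative side.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collector;
import java.util.stream.Collectors;

public class TeeingSums {
    // Hypothetical immutable Pair exposing the getFirst()/getSecond() getters the answer assumes.
    record Pair<A, B>(A first, B second) {
        A getFirst() { return first; }
        B getSecond() { return second; }
    }

    // Assumed stand-in for the question's Sale class: same fields, plus accessors.
    record Sale(String country, BigDecimal amount) {}

    // The sample rows from the question's table.
    static final List<Sale> SAMPLE = List.of(
            new Sale("nl", new BigDecimal("9")),
            new Sale("nl", new BigDecimal("-3")),
            new Sale("be", new BigDecimal("7.9")),
            new Sale("be", new BigDecimal("-7")));

    // Downstream collector that sums only the amounts on one side of zero
    // (zero is counted with the non-negative side).
    static Collector<Sale, ?, BigDecimal> sumWhere(boolean nonNegative) {
        return Collectors.filtering(
                s -> (s.amount().signum() >= 0) == nonNegative,
                Collectors.reducing(BigDecimal.ZERO, Sale::amount, BigDecimal::add));
    }

    // One stream: group by country, and feed each group to both sums at once.
    static Map<String, Pair<BigDecimal, BigDecimal>> sumUp(List<Sale> sales) {
        return sales.stream().collect(Collectors.groupingBy(
                Sale::country,
                TreeMap::new, // sorted by country, so the print order is predictable
                Collectors.teeing(sumWhere(true), sumWhere(false),
                        (pos, neg) -> new Pair<>(pos, neg))));
    }

    public static void main(String[] args) {
        sumUp(SAMPLE).forEach((country, pair) ->
                System.out.printf("%-4s|%s%n%-4s|%s%n",
                        country, pair.getFirst(), country, pair.getSecond()));
    }
}
```

With TreeMap as the map factory, main prints the lines "be  |7.9", "be  |-7", "nl  |9" and "nl  |-3", in that order. A country with no amounts on one side gets BigDecimal.ZERO for that side, because that is the identity passed to Collectors.reducing.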