Tags: java, functional-programming, rx-java, reactivex

How to correctly understand behavior of RxJava's groupBy?


I'm pretty new to RxJava and FP in general. I want to write code that joins two Observables. Let's say we have two sets of integers:

  • [0..4] with key selector i % 2, giving (key, value) = {(0,0), (1,1), (0,2), ...}
  • [0..9] with key selector i % 3, giving (key, value) = {(0,0), (1,1), (2,2), (0,3), (1,4), ...}

My steps to join them are as follows:

  1. Group each set by key. The first set yields two groups, with keys 0 and 1; the second yields three groups, with keys 0, 1 and 2.
  2. Take the Cartesian product of the two sets of groups, giving 6 pairs of groups in total, with keys 0-0, 0-1, 0-2, 1-0, 1-1, 1-2.
  3. Keep only the pairs whose keys are the same on both sides, leaving 0-0 and 1-1.
  4. Within each remaining pair, take the Cartesian product of the left and right groups.
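For reference, here is what steps 1 to 4 should produce, sketched with plain Java collections instead of Observables. This is an assumption made so the sketch runs without RxJava; `GroupJoinSketch`, `group` and `join` are hypothetical names. Because each group is materialized as a List, crossing it repeatedly is safe:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class GroupJoinSketch {

    // Step 1: group the range [0, n) by a key selector; TreeMap keeps keys sorted.
    public static Map<Integer, List<Integer>> group(int n, Function<Integer, Integer> key) {
        return IntStream.range(0, n).boxed()
                .collect(Collectors.groupingBy(key, TreeMap::new, Collectors.toList()));
    }

    // Steps 2-4: pair every left group with every right group, keep the pairs
    // whose keys match, then take the Cartesian product inside each kept pair.
    public static List<int[]> join(Map<Integer, List<Integer>> left,
                                   Map<Integer, List<Integer>> right) {
        List<int[]> out = new ArrayList<>();
        for (Map.Entry<Integer, List<Integer>> lg : left.entrySet()) {      // step 2 (outer)
            for (Map.Entry<Integer, List<Integer>> rg : right.entrySet()) { // step 2 (inner)
                if (!lg.getKey().equals(rg.getKey())) {
                    continue;                                               // step 3
                }
                for (int l : lg.getValue()) {                               // step 4
                    for (int r : rg.getValue()) {
                        out.add(new int[] {l, r});
                    }
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<int[]> pairs = join(group(5, i -> i % 2), group(10, i -> i % 3));
        pairs.forEach(p -> System.out.println("(" + p[0] + "," + p[1] + ")"));
        System.out.println(pairs.size() + " pairs");
    }
}
```

It yields the same 18 pairs as the expected output listed further down, but ordered by key (all key-0 pairs first) rather than by left value.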

Below is a helper class that computes the Cartesian product:

import rx.Observable;
import rx.functions.Func2;

public class Cross<TLeft, TRight, R> implements Observable.Transformer<TLeft, R> {
    private Observable<TRight>      _right;
    private Func2<TLeft, TRight, R> _resultSelector;

    public Cross(Observable<TRight> right, Func2<TLeft, TRight, R> resultSelector) {
        _right = right;
        _resultSelector = resultSelector;
    }

    @Override
    public Observable<R> call(Observable<TLeft> left) {
        // _right is subscribed to once for every item emitted by left
        return left.flatMap(l -> _right.map(r -> _resultSelector.call(l, r)));
    }
}
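The property of this helper that matters later is that `flatMap` subscribes to `_right` once per left item. A stdlib-only analogue makes that count visible; `CrossSketch` is a hypothetical name, and a `Supplier<Stream<R>>` stands in for a cold Observable, with each `get()` playing the role of a fresh subscription:

```java
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class CrossSketch {

    // Analogue of Cross: flatMap asks the supplier for a new right-hand stream
    // once per left element, like left.flatMap(l -> _right.map(...)).
    public static <L, R, T> List<T> cross(Stream<L> left, Supplier<Stream<R>> right,
                                          BiFunction<L, R, T> selector) {
        return left.flatMap(l -> right.get().map(r -> selector.apply(l, r)))
                   .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        AtomicInteger subscriptions = new AtomicInteger();
        Supplier<Stream<Integer>> right = () -> {
            subscriptions.incrementAndGet(); // count each "subscription"
            return Stream.of(10, 20, 30);
        };
        List<String> out = cross(Stream.of(1, 2), right, (l, r) -> l + "-" + r);
        System.out.println(out);                 // [1-10, 1-20, 1-30, 2-10, 2-20, 2-30]
        System.out.println(subscriptions.get()); // 2
    }
}
```

With a cold source such as `Observable.range`, each re-subscription simply replays the values; a problem only appears when `_right` is a source that refuses a second subscriber.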

And here's the code to join:

Observable.range(0, 5).groupBy(i -> i % 2)
        .compose(new Cross<>(Observable.range(0, 10).groupBy(i -> i % 3), ImmutablePair::new))
        .filter(pair -> pair.left.getKey().equals(pair.right.getKey()))
        .flatMap(pair -> pair.left.compose(new Cross<>(pair.right, ImmutablePair::new)))
        .subscribe(System.out::println);

However, the output is not correct:

(0,0)
(0,3)
(0,6)
(0,9)
(1,1)
(1,4)
(1,7)

If I remove the filter line, there is no output at all. The correct output should match what this produces:

Observable.range(0, 5)
        .compose(new Cross<>(Observable.range(0, 10), ImmutablePair::new))
        .filter(pair -> pair.left % 2 == pair.right % 3)
        .subscribe(System.out::println);

which gives:

(0,0)
(0,3)
(0,6)
(0,9)
(1,1)
(1,4)
(1,7)
(2,0)
(2,3)
(2,6)
(2,9)
(3,1)
(3,4)
(3,7)
(4,0)
(4,3)
(4,6)
(4,9)

Could someone explain this behavior? Many thanks.

Note: I use org.apache.commons.lang3.tuple.ImmutablePair, in case you're wondering.


Solution

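  • The problem is that this setup subscribes to the same group more than once, which is not allowed: a GroupedObservable accepts only a single Subscriber. In the inner Cross, pair.right is subscribed to once for every element of pair.left, so every subscription after the first fails. You'd see the exception with the subscribe(System.out::println, Throwable::printStackTrace) overload, which is always advisable over the one without an error handler. Here is a fixed example that allows reuse at the expense of another layer of ImmutablePair: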

    Func1<Integer, Integer> m2 = i -> i % 2;
    Func1<Integer, Integer> m3 = i -> i % 3;
    
    // cache() wraps each single-subscriber group in an Observable that replays
    // its values, so the group can be subscribed to more than once
    Observable<ImmutablePair<Integer, Observable<Integer>>> g2 = 
            Observable.range(0, 5).groupBy(m2).map(g -> new ImmutablePair<>(g.getKey(), g.cache()));
    Observable<ImmutablePair<Integer, Observable<Integer>>> g3 = 
            Observable.range(0, 10).groupBy(m3).map(g -> new ImmutablePair<>(g.getKey(), g.cache()));
    
    Observable<ImmutablePair<ImmutablePair<Integer, Observable<Integer>>, ImmutablePair<Integer, Observable<Integer>>>> x1
            = g2.compose(new Cross<>(g3, ImmutablePair::new));
    
    // keep only the pairs of groups whose keys match
    Observable<ImmutablePair<ImmutablePair<Integer, Observable<Integer>>, ImmutablePair<Integer, Observable<Integer>>>> x2
            = x1.filter(pair -> pair.left.getKey().equals(pair.right.getKey()));
    
    // cross the cached left and right groups of each matching pair
    Observable<ImmutablePair<Integer, Integer>> o = x2.flatMap(pair ->
            pair.left.right.compose(new Cross<>(pair.right.right, ImmutablePair::new)));
    
    o.subscribe(System.out::println, Throwable::printStackTrace);
    

    (Sorry about the long types; Eclipse has all sorts of inference problems if I try to inline them instead of using local variables.)
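The same failure mode can be reproduced without RxJava by using java.util.stream.Stream, which can be traversed only once, as an analogy for a single-subscriber GroupedObservable. This is a hypothetical sketch (`SingleUseSketch`, `crossOnce` and `crossCached` are illustrative names, not RxJava API); collecting into a List plays the role that `cache()` plays in the fix above:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class SingleUseSketch {

    // Crosses against one already-created Stream. The first left element
    // consumes it; the second traversal throws IllegalStateException.
    public static List<String> crossOnce(List<Integer> left, Stream<Integer> right) {
        List<String> out = new ArrayList<>();
        for (int l : left) {
            right.forEach(r -> out.add("(" + l + "," + r + ")"));
        }
        return out;
    }

    // Materializes the right side first (analogous to cache()),
    // so it can be replayed for every left element.
    public static List<String> crossCached(List<Integer> left, Stream<Integer> right) {
        List<Integer> cached = right.collect(Collectors.toList());
        List<String> out = new ArrayList<>();
        for (int l : left) {
            for (int r : cached) {
                out.add("(" + l + "," + r + ")");
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<Integer> left = Arrays.asList(0, 2, 4); // left group with key 0
        try {
            crossOnce(left, Stream.of(0, 3, 6, 9));
        } catch (IllegalStateException e) {
            System.out.println("second traversal failed: " + e.getMessage());
        }
        System.out.println(crossCached(left, Stream.of(0, 3, 6, 9)));
    }
}
```

Only the first left element gets paired before the failure, which is consistent with the question's output, where only 0 and 1 (the first element of each left group) appear on the left side.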