Tags: java, google-cloud-dataflow, gcloud

How to do a cartesian product of two PCollections in Dataflow?


I would like to compute the cartesian product of two PCollections. Neither PCollection fits into memory, so using a side input is not feasible.

My goal is this: I have two datasets. One has many elements of small size. The other has few elements (~10), each very large. I would like to take the product of these two datasets and then produce key-value objects.


Solution

  • I think CoGroupByKey might work in your situation:

    https://cloud.google.com/dataflow/model/group-by-key#join

    That's what I did for a similar use case, though mine was probably not constrained by memory (have you tried a larger cluster with bigger machines?):

    PCollection<KV<String, TableRow>> inputClassifiedKeyed = inputClassified
            .apply(ParDo.named("Actuals : Keys").of(new ActualsRowToKeyedRow()));
    
    PCollection<KV<String, Iterable<Map<String, String>>>> groupedCategories = p
    [...]
    .apply(GroupByKey.create());
    

    So the collections are keyed by the same key.

    Then I declared the tags:

    final TupleTag<Iterable<Map<String, String>>> categoryTag = new TupleTag<>();
    final TupleTag<TableRow> actualsTag = new TupleTag<>();
    

    Combined them:

    PCollection<KV<String, CoGbkResult>> actualCategoriesCombined =
            KeyedPCollectionTuple.of(actualsTag, inputClassifiedKeyed)
                    .and(categoryTag, groupedCategories)
                    .apply(CoGroupByKey.create());
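
To make the shape of that combined result easier to picture, here is a minimal plain-Java sketch of the grouping idea. It does not use the Dataflow SDK and runs entirely in memory, so it only illustrates what ends up under each key, not how Dataflow executes the join. The class `CoGroupSketch`, its `Grouped` holder (a stand-in for `CoGbkResult`) and the sample data are all hypothetical.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class CoGroupSketch {

    /** Holds the values from both inputs that share one key (hypothetical stand-in for CoGbkResult). */
    public static final class Grouped<A, B> {
        public final List<A> left = new ArrayList<>();
        public final List<B> right = new ArrayList<>();
    }

    /**
     * Groups two keyed lists by key. A key that appears in only one input
     * still gets an entry, with an empty list on the other side.
     */
    public static <A, B> Map<String, Grouped<A, B>> coGroup(
            List<Map.Entry<String, A>> leftInput,
            List<Map.Entry<String, B>> rightInput) {
        Map<String, Grouped<A, B>> result = new LinkedHashMap<>();
        for (Map.Entry<String, A> e : leftInput) {
            result.computeIfAbsent(e.getKey(), k -> new Grouped<>()).left.add(e.getValue());
        }
        for (Map.Entry<String, B> e : rightInput) {
            result.computeIfAbsent(e.getKey(), k -> new Grouped<>()).right.add(e.getValue());
        }
        return result;
    }

    public static void main(String[] args) {
        // Sample data, invented for illustration only.
        List<Map.Entry<String, String>> actuals = List.of(
                Map.entry("k1", "actual-a"),
                Map.entry("k1", "actual-b"),
                Map.entry("k2", "actual-c"));
        List<Map.Entry<String, String>> categories = List.of(
                Map.entry("k1", "category-x"));

        coGroup(actuals, categories).forEach((key, g) ->
                System.out.println(key + " -> " + g.left + " / " + g.right));
    }
}
```

In this sketch the key `k2` ends up with an empty category list, which is the situation the "some of the actuals do not have categories" check in the formatting step has to handle.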
    

    In my case, the final step was reformatting the results from the tagged groups:

    actualCategoriesCombined.apply(ParDo.named("Actuals : Formatting").of(
        new DoFn<KV<String, CoGbkResult>, TableRow>() {
            @Override
            public void processElement(ProcessContext c) throws Exception {
                KV<String, CoGbkResult> e = c.element();
    
                Iterable<TableRow> actualTableRows =
                        e.getValue().getAll(actualsTag);
                Iterable<Iterable<Map<String, String>>> categoriesAll =
                        e.getValue().getAll(categoryTag);
    
                for (TableRow row : actualTableRows) {
                    // Some of the actuals do not have categories
                    if (categoriesAll.iterator().hasNext()) {
                        row.put("advertiser", categoriesAll.iterator().next());
                    }
                    c.output(row);
                }
            }
        }));
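
Note that the step above attaches only the first category group to each row. For the cartesian product asked about in the question, the per-key step would instead pair every element of one tagged group with every element of the other. Below is a minimal plain-Java sketch of that nested iteration, assuming both groups are available as iterables. `CrossProductSketch` and `crossProduct` are hypothetical names, and inside a `DoFn` each pair would presumably be emitted with `c.output(...)` rather than collected into a list as done here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

public class CrossProductSketch {

    /**
     * Pairs every element of 'small' with every element of 'large' and
     * combines each pair with 'combine'. Results are collected in a list
     * only to keep the sketch self-contained.
     */
    public static <A, B, R> List<R> crossProduct(
            Iterable<A> small, Iterable<B> large, BiFunction<A, B, R> combine) {
        List<R> out = new ArrayList<>();
        for (B big : large) {          // outer loop: the few large elements
            for (A item : small) {     // inner loop: re-iterates the many small elements
                out.add(combine.apply(item, big));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Sample data, invented for illustration only.
        List<String> small = List.of("s1", "s2", "s3");
        List<String> large = List.of("L1", "L2");
        List<String> pairs = crossProduct(small, large, (s, l) -> s + "x" + l);
        System.out.println(pairs.size() + " pairs: " + pairs);
    }
}
```

The outer loop runs over the large elements so that each one is visited once while the small group is re-iterated. Whether that ordering helps with the memory constraints in an actual pipeline is an assumption that would need testing.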
    

    Hope this helps. Again, I'm not sure how this approach behaves under your memory constraints. Please report back with the results if you try it.