Search code examples
Tags: java, apache-beam, apache-beam-io, apache-beam-internals

Apache Beam update current row values based on the values from previous row


Apache Beam update values based on the values from the previous row

I have grouped the records from a CSV file. Within the grouped rows, a few amounts are missing; each missing amount must be filled with the amount from the previous row of the same customer. If the first row of a customer's group has no amount, it should be set to 0.

I am able to group the records, but I can't figure out the logic to update the missing values. How do I achieve this?

Records

customerId date amount
BS:89481 1/1/2012 100
BS:89482 1/1/2012
BS:89483 1/1/2012 300
BS:89481 1/2/2012 900
BS:89482 1/2/2012 200
BS:89483 1/2/2012

Records after grouping

customerId date amount
BS:89481 1/1/2012 100
BS:89481 1/2/2012 900
BS:89482 1/1/2012
BS:89482 1/2/2012 200
BS:89483 1/1/2012 300
BS:89483 1/2/2012

After updating the missing values (expected output)

customerId date amount
BS:89481 1/1/2012 100
BS:89481 1/2/2012 900
BS:89482 1/1/2012 000
BS:89482 1/2/2012 200
BS:89483 1/1/2012 300
BS:89483 1/2/2012 300

Code so far:

/**
 * Driver that reads a CSV file, validates each line against an Avro schema, and sums the
 * {@code amount} field per (customerId, date).
 */
public class GroupByTest {

    private static final String DEFAULT_SCHEMA_PATH =
            "C:\\AI\\Workspace\\office\\lombok\\artifact\\src\\main\\resources\\schema_transform2.avsc";
    private static final String DEFAULT_CSV_PATH =
            "C:\\AI\\Workspace\\office\\lombok\\artifact\\src\\main\\resources\\CustomerRequest-case2.csv";

    /**
     * Builds and runs the pipeline.
     *
     * @param args optional {@code [schemaPath [csvPath]]}; the hard-coded defaults are used for
     *             any argument that is not supplied
     * @throws IOException if the Avro schema file cannot be read
     */
    public static void main(String[] args) throws IOException {
        System.out.println("We are about to start!!");

        final File schemaFile = new File(args.length > 0 ? args[0] : DEFAULT_SCHEMA_PATH);
        final File csvFile = new File(args.length > 1 ? args[1] : DEFAULT_CSV_PATH);
        Schema schema = new Schema.Parser().parse(schemaFile);

        Pipeline pipeline = Pipeline.create();

        // Convert the Avro schema into the Beam schema used for Rows.
        org.apache.beam.sdk.schemas.Schema beamSchema = AvroUtils.toBeamSchema(schema);

        final PCollectionTuple tuples = pipeline

                // Match the CSV input file
                .apply("1", FileIO.match().filepattern(csvFile.getAbsolutePath()))

                // Open the matched files
                .apply("2", FileIO.readMatches())

                // Validate each line against the schema, convert it to a Row, and split the
                // result into valid and invalid outputs
                .apply("3", ParDo.of(new FileReader(beamSchema)).withOutputTags(FileReader.validTag(),
                        TupleTagList.of(invalidTag())));

        // Keep only the valid rows
        final PCollection<Row> rows = tuples.get(FileReader.validTag()).setCoder(RowCoder.of(beamSchema));

        // Sum "amount" per (customerId, date) into a field named "balances".
        // The schema defines "amount"; aggregating a field that is not in the schema cannot work.
        final Group.CombineFieldsByFields<Row> combine = Group.<Row>byFieldNames("customerId", "date")
                .aggregateField("amount", Sum.ofDoubles(), "balances");

        final PCollection<Row> aggregated = rows.apply(combine);

        // NOTE(review): the string output is not written to any sink yet.
        aggregated.apply(Select.flattenedSchema()).apply(ParDo.of(new RowToString()));

        pipeline.run().waitUntilFinish();
        System.out.println("The end");
    }

    /**
     * Renders one column of {@code row} as a string, choosing the Row accessor from the column's
     * Avro type, or from its logical type when one is present.
     *
     * <p>Nullable unions such as {@code ["double", "null"]} are first unwrapped to their non-null
     * branch. Otherwise the reported type would be "union", and every such field would fall
     * through to {@code getString}.
     *
     * @param columnName   name of the field in both {@code sourceSchema} and {@code row}
     * @param row          the Beam row to read from
     * @param sourceSchema the Avro schema the row was built from
     * @return the value as a string; may be {@code null} for string-typed (and unrecognized) columns
     * @throws NullPointerException if an int, long, double or timestamp column holds no value
     */
    private static String getColumnValue(String columnName, Row row, Schema sourceSchema) {
        Schema fieldSchema = nonNullBranch(sourceSchema.getField(columnName).schema());
        LogicalType logicalType = fieldSchema.getLogicalType();
        // Locale.ROOT: lowercasing "INT" in some default locales (e.g. Turkish) yields a dotless i.
        String type = logicalType != null
                ? logicalType.getName()
                : fieldSchema.getType().name().toLowerCase(java.util.Locale.ROOT);

        switch (type) {
        case "string":
            return row.getString(columnName);
        case "int":
            return Objects.requireNonNull(row.getInt32(columnName), columnName).toString();
        case "long":   // Avro's name for the 64-bit integer type
        case "bigint":
            return Objects.requireNonNull(row.getInt64(columnName), columnName).toString();
        case "double":
            return Objects.requireNonNull(row.getDouble(columnName), columnName).toString();
        case "timestamp-millis":
            return Instant.ofEpochMilli(
                    Objects.requireNonNull(row.getDateTime(columnName), columnName).getMillis()).toString();
        default:
            return row.getString(columnName);
        }
    }

    /**
     * Returns the first non-null branch of a union schema, or {@code schema} itself when it is
     * not a union (or has only null branches).
     */
    private static Schema nonNullBranch(Schema schema) {
        if (schema.getType() != Schema.Type.UNION) {
            return schema;
        }
        for (Schema branch : schema.getTypes()) {
            if (branch.getType() != Schema.Type.NULL) {
                return branch;
            }
        }
        return schema;
    }

}

Modified code (aggregating the `amount` field defined in the schema):

// Groups rows by the pair (customerId, date) and sums each group's "amount" into a new
// field named "balances". Because the grouping key includes the date, each output row
// covers a single date, so this alone cannot fill values from a previous date's row.
final Group.CombineFieldsByFields<Row> combine = Group.<Row>byFieldNames("customerId", "date")
        .aggregateField("amount", Sum.ofDoubles(), "balances");

DoFn that keys each row by a column value (used to group by customerId):

/**
 * Keys each {@link Row} by the string form of one of its fields, so that the rows can then be
 * grouped with {@code GroupByKey}.
 *
 * <p>The key column must be configured, via {@link #setColumnName1(String)} or the one-argument
 * constructor, before the pipeline runs.
 */
class ToKV extends DoFn<Row, KV<String, Row>> {

    private static final long serialVersionUID = -8093837716944809689L;

    /** Name of the field whose value becomes the output key; {@code null} until configured. */
    String columnName1;

    /** Creates an unconfigured instance; call {@link #setColumnName1(String)} before use. */
    public ToKV() {
    }

    /**
     * Creates an instance that keys rows by {@code columnName1}.
     *
     * @param columnName1 name of the field whose value becomes the output key
     */
    public ToKV(String columnName1) {
        this.columnName1 = columnName1;
    }

    /**
     * Emits {@code KV(String.valueOf-style key, row)} for the incoming row.
     *
     * @throws IllegalStateException    if no key column has been configured
     * @throws IllegalArgumentException if the row's key field holds no value
     */
    @ProcessElement
    public void processElement(ProcessContext context) {
        if (columnName1 == null) {
            throw new IllegalStateException("ToKV: key column is not set; call setColumnName1(...) first");
        }
        Row row = context.element();
        Object keyValue = row.getValue(columnName1);
        if (keyValue == null) {
            // A null key cannot be turned into a String key for grouping.
            throw new IllegalArgumentException("ToKV: field '" + columnName1 + "' is null in row " + row);
        }
        context.output(KV.of(keyValue.toString(), row));
    }

    /**
     * Sets the field whose value is used as the output key.
     *
     * @param columnName1 name of an existing field in the incoming rows' schema
     */
    public void setColumnName1(String columnName1) {
        this.columnName1 = columnName1;
    }


}

Grouping by customerId:

// Key each row by its customer. The column name must match the schema's field name
// exactly ("customerId").
ToKV toKV = new ToKV();
toKV.setColumnName1("customerId");
PCollection<KV<String, Row>> kvRows = rows.apply(ParDo.of(toKV))
        .setCoder(KvCoder.of(StringUtf8Coder.of(), rows.getCoder()));

// Collect all rows of one customer into a single Iterable.
PCollection<KV<String, Iterable<Row>>> groupedKVRows = kvRows.apply(GroupByKey.<String, Row>create());

// Within each customer's group: sort by date and fill in the missing amounts.
PCollection<Row> outputRow =
        groupedKVRows
        .apply(ParDo.of(new GroupByDate()))
        .setCoder(RowCoder.of(AvroUtils.toBeamSchema(schema)));

How do I write the logic to convert the Iterable to a PCollection so that the rows can be sorted by date?

class GroupByDate extends DoFn<KV<String,Iterable<Row>>, Row> {

    private static final long serialVersionUID = -1345126662309830332L;

    @ProcessElement
    public void processElement(ProcessContext context) {
        String strKey = context.element().getKey();
        Iterable<Row> rows = context.element().getValue();
        
    
        
        
    }

Avro schema:

{
  "type" : "record",
  "name" : "Entry",
  "namespace" : "transform",
  "fields" : [  {
    "name" : "customerId",
    "type" : [ "string", "null" ]
  }, {
    "name" : "date",
    "type" : [ "string", "null" ],
    "logicalType": "date"
    
  }, {
    "name" : "amount",
    "type" : [ "double", "null" ]
  } ]
}

Update: attempt to convert the grouped PCollection to Row[]

/**
 * Materializes each grouped {@code Iterable<Row>} into a {@code Row[]}. The key is discarded.
 *
 * <p>NOTE(review): Beam may not be able to infer a coder for {@code Row[]}; the resulting
 * PCollection will likely need an explicit {@code setCoder(...)}. Confirm before use.
 */
class KVToRow extends DoFn<KV<String, Iterable<Row>>, Row[]> {

    private static final long serialVersionUID = -1345126662309830332L;

    @ProcessElement
    public void processElement(ProcessContext context) {
        List<Row> rowList = new ArrayList<>();
        context.element().getValue().forEach(rowList::add);
        // toArray(new Row[0]) always returns an exactly-sized array, including for an empty
        // group. A pre-sized array of length size()-1 would be negative in that case.
        context.output(rowList.toArray(new Row[0]));
    }
}

Suggested Code

// Iterables.toArray needs an Iterable<Row>, such as the grouped value inside a DoFn
// (context.element().getValue()). A PCollection is not an Iterable, so passing one
// does not compile (see the error message that follows).
Row[] rowArray = Iterables.toArray(rows, Row.class);

Error:

The method toArray(Iterable<? extends T>, Class) in the type Iterables is not applicable for the arguments (PCollection, Class)

Convert iterable to array

// apply(...) returns a PCollection (a deferred dataset), not a Row[]. Each group's array
// only exists inside a DoFn while the pipeline runs, so process it there.
PCollection<Row[]> rowArrays = groupedKVRows.apply(ParDo.of(new KVToRow()));

Error:

Multiple markers at this line - Type mismatch: cannot convert from PCollection<Row[]> to Row[]


Solution

  • Beam does not guarantee any ordering of the elements in a PCollection, so you will have to group them as you did.

    But as far as I can understand from your case, you need to group by customerId. After that, you can apply a PTransform like ParDo to sort the grouped Rows by date and fill missing values however you wish.

    Example: sorting by converting to an array

    static class SortAndForwardFillFn extends DoFn<KV<String, Iterable<Row>>> {
    
        @ProcessElement
        public void processElement(@Element KV<String, Iterable<Row>> element, OutputReceiver<KV<String, Iterable<Row>>> outputReceiver) {
    
            // Create a formatter for parsing dates
            DateTimeFormatter formatter = DateTimeFormat.forPattern("dd/MM/yyyy HH:mm:ss");
    
            // Convert iterable to array
            Row[] rowArray = Iterables.toArray(rows, Row.class);
    
            // Sort array using dates
            Arrays
                .sort(
                    rowArray,
                    Comparator
                    .comparingLong(row -> formatter.parseDateTime(row.getString("date")).getMillis())
            );
    
            // Store the last amount
            Double lastAmount = 0.0;
    
            // Create a List for storing sorted and filled rows
            List<Row> resultRows = new ArrayList<>(rowArray.length);
    
            // Iterate over the array and fill in the missing parts
            for (Row row : rowArray) {
    
                // Get current amount
                Double currentAmount = row.getDouble("amount");
    
                // If null, fill the previous value and add to results, 
                // otherwise add as it is
                resultRows.add(...);
            }
    
            // Output using the output receiver
            outputReceiver
                .output(
                    KV.of(element.getKey(), resultRows)
                )
            );
        }
    }