Tags: java, apache-spark, rdd, apache-spark-dataset

Using the Apache Spark RDD map method (Java API) to produce a non-columnar result


Please note: I believe I'm correct in trying to use the RDD map method here, but if there is another way to accomplish what I'm looking for, I'm all ears!


Brand new to Spark 2.4.x here, and using the Java (not Scala) API.

I'm trying to wrap my brain around the RDD map(...) method, specifically how it applies to Datasets and not only to RDDs. The canonical example of its use from the official docs is:

public class GetLength implements Function<String, Integer> {
  public Integer call(String s) { return s.length(); }
}

JavaRDD<String> lines = sc.textFile("data.txt");
JavaRDD<Integer> lineLengths = lines.map(new GetLength());
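
As an aside, I gather that in Java 8+ the same map can also be written with a lambda, since Function is a functional interface:

JavaRDD<Integer> lineLengths = lines.map(s -> s.length());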

So it seems that, in this case, after the lines RDD is created, it has a single column (whose name I'm unsure of), and each row holds a different line of the file as its one column value. Meaning, lines is an nx1 matrix, where n is the number of rows/lines in the file.

It also seems that when the GetLength function is executed, it is fed each row's one-and-only column as the input string and returns an integer representing the line length of that string as a new column value in a different dataset, which is also nx1 (just holding line length info instead of the actual lines/strings).

OK, so I get that trivial example. But what if we have nxm datasets, meaning, lots of rows and lots of columns, and we want to write functions that transform them into other nxm datasets?

For example, let's say I have the following "input" dataset:

+-------------------------+
| price | color | fizz    |
+-------------------------+
| 2.99  | red   | hallo   |
| 13.94 | blue  | yahtzee |
| 7.19  | red   | trueth  |
...
| 4.15  | green | whatevs |
+-------------------------+

Where price is a numeric/floating-point type and both color and fizz are strings. So here we have an nx3 shaped dataset; n rows and always 3 columns in each row.

How would I write a map function that also returned an nx3 dataset, with the same columns/column names/schema, but different values (based on the function)?

For instance, say I wanted a new nx3 dataset with the same schema, but with 2.0 added to the price column whenever the row's color value equals the string "red"?

Hence, using the arbitrary dataset above, the new dataset coming out of this map function would look like:

+-------------------------+
| price | color | fizz    |
+-------------------------+
| 4.99  | red   | hallo   |  <== added 2.0 to price since color == "red"
| 13.94 | blue  | yahtzee |
| 9.19  | red   | trueth  |  <== added 2.0 to price since color == "red"
...
| 4.15  | green | whatevs |
+-------------------------+

I'm tempted to do something like:

public class UpdatedPriceForRedColors implements Function2<String, Double, Double> {
  public Double call(String color, Double currentPrice) {

    if ("red".equals(color) {
        return currentPrice + 2.0;
    } else {
        return currentPrice;
    }
  }
}

JavaRDD<Double> updatedPrices = myDS.map(new UpdatedPriceForRedColors());

However, several issues here:

  1. updatedPrices is now only an nx1 dataset consisting of the correctly-computed prices for each row in myDS, whereas I want something that keeps all three price/color/fizz columns and looks like the 2nd arbitrary dataset up above
  2. How does the UpdatedPriceForRedColors know that its first string argument is the color column, and not the fizz column?
  3. The Function API seems to only go up to either Function5 or Function6 (it's hard to discern what is available to the Java API and what is exclusive to the Scala API). This means I can only write functions that take in 5 or 6 arguments, whereas I might have datasets with 10+ columns in them, and I might very well need most of those column values "injected" into the function so I can compute the return value of the new dataset. What options do I have available in this case?

Solution

  • First of all, RDDs kind of always have one column, because RDDs have no schema information and thus you are tied to the T type in RDD<T>.

    Option 1 is to use a Function<String,String> which parses the String in RDD<String>, does the logic to manipulate the inner elements in the String, and returns an updated String.
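
    A minimal sketch of option 1 could look like the following, assuming tab-separated lines in the price/color/fizz order from the question, and a starting JavaRDD<String> named rddString (the same name used in the snippets below):

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;

    JavaRDD<String> updatedLines = rddString.map(new Function<String, String>() {
        @Override
        public String call(String line) throws Exception {
          String[] parts = line.split("\\t"); // tab-separated: price, color, fizz
          float price = Float.parseFloat(parts[0]);
          if ("red".equals(parts[1])) {
            price += 2.0f; // add 2.0 to the price when the color is "red"
          }
          return price + "\t" + parts[1] + "\t" + parts[2];
        }
      });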

    If you want your RDD to have some schema information, you can use an RDD<Row> which lets you access separate elements inside a Row (option 2).

    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;

    JavaRDD<Row> rddRow = rddString.map(new Function<String, Row>() {
        @Override
        public Row call(String line) throws Exception {
          String[] parts = line.split("\\t"); // tab-separated values
          return RowFactory.create(parts[0], parts[1], parts[2]);
        }
      });
    

    Now you can map the rows:

    JavaRDD<Row> updatedRdd = rddRow.map(new Function<Row, Row>() {
        @Override
        public Row call(Row row) throws Exception {
          float price = Float.parseFloat(row.getString(0));
          String color = row.getString(1);
          String fizz = row.getString(2);
          // your logic here, e.g. add 2.0 to the price when the color is "red"
          if ("red".equals(color)) {
            price += 2.0f;
          }
          return RowFactory.create(String.valueOf(price), color, fizz);
        }
      });
    

    If you go one step further, you can use a true Dataset (as explained here) and leverage the Dataframe/Dataset API (option 3).

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;

    StructType schema = DataTypes.createStructType(
        new StructField[]{
                DataTypes.createStructField("price", DataTypes.FloatType, false),
                DataTypes.createStructField("color", DataTypes.StringType, false),
                DataTypes.createStructField("fizz", DataTypes.StringType, false)
        });
    
    
    JavaRDD<Row> rddRow = rddString.map(new Function<String, Row>() {
        @Override
        public Row call(String line) throws Exception {
          String[] parts = line.split("\\t"); // tab-separated values
          // parse the price so the value matches the FloatType in the schema
          return RowFactory.create(Float.parseFloat(parts[0]), parts[1], parts[2]);
        }
      });

    Dataset<Row> df = sqlContext.createDataFrame(rddRow, schema);
    

    Having a dataframe lets you now use something like this:

    import static org.apache.spark.sql.functions.col;
    import static org.apache.spark.sql.functions.when;

    Dataset<Row> df2 = df.withColumn("price",
        when(col("color").equalTo("red"), col("price").plus(2.0f))
            .otherwise(col("price")));
    

    Disclaimer: I haven't checked the Java syntax and API as I'm used to Scala.