Tags: java, apache-spark, apache-spark-sql, apache-spark-ml, apache-spark-dataset

How to transform a csv string into a Spark-ML compatible Dataset<Row> format?


I have a Dataset<Row> df that contains two columns ("key" and "value") of type string. df.printSchema() gives the following output:

root
 |-- key: string (nullable = true)
 |-- value: string (nullable = true)

The content of the value column is actually a CSV-formatted line (coming from a Kafka topic), with the last entry of each line representing the class label and all previous entries being the features (the header row is not included in the dataset):

feature0,feature1,label
0.6720004294237854,-0.4033586564886893,0
0.6659082469383558,0.07688976580256132,0
0.8086502311695247,0.564354801275521,1
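
For reproducing the setup locally, such a Dataset can be built by hand. A minimal sketch, assuming a local SparkSession; the key values ("k0" etc.) are made up, the value strings are the rows from above:

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();

// schema matching the printSchema() output above
StructType schema = new StructType()
    .add("key", DataTypes.StringType)
    .add("value", DataTypes.StringType);

// hypothetical keys; each value is one CSV line from above
List<Row> rows = Arrays.asList(
    RowFactory.create("k0", "0.6720004294237854,-0.4033586564886893,0"),
    RowFactory.create("k1", "0.6659082469383558,0.07688976580256132,0"),
    RowFactory.create("k2", "0.8086502311695247,0.564354801275521,1"));

Dataset<Row> df = spark.createDataFrame(rows, schema);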

Since I would like to train a classifier on this data, I need to transform this representation into a column of type dense vector containing all the feature values and a column of type double containing the label value:

root
 |-- indexedFeatures: vector (nullable = false)
 |-- indexedLabel: double (nullable = false)

How can I do this using Java 1.8 and Spark 2.2.0?

Edit: I got further, but while attempting to make it work with a flexible number of feature dimensions, I got stuck again. I created a follow-up question.


Solution

  • A VectorAssembler (javadocs) can transform the dataset into the required format.

    First, the input is split into three columns (the explicit FlatMapFunction cast disambiguates between the Java and Scala overloads of flatMap):

    import java.util.Collections;

    import org.apache.spark.api.java.function.FlatMapFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;

    Dataset<FeaturesAndLabelData> featuresAndLabelData = inputDf.select("value").as(Encoders.STRING())
      .flatMap((FlatMapFunction<String, FeaturesAndLabelData>) s -> {
        String[] parts = s.split(",");
        if (parts.length == 3) {
          // one well-formed CSV line -> one bean instance
          return Collections.singleton(new FeaturesAndLabelData(
            Double.parseDouble(parts[0]),
            Double.parseDouble(parts[1]),
            Integer.parseInt(parts[2]))).iterator();
        } else {
          // malformed line: apply some error handling, then skip it
          return Collections.emptyIterator();
        }
      }, Encoders.bean(FeaturesAndLabelData.class));
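
    Using flatMap instead of map has the advantage that malformed lines can simply be skipped by returning an empty iterator, instead of failing the whole job on the first bad record.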
    

    The result is then transformed by a VectorAssembler:

    import org.apache.spark.ml.feature.VectorAssembler;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.functions;

    VectorAssembler assembler = new VectorAssembler()
      .setInputCols(new String[] { "feature1", "feature2" })
      .setOutputCol("indexedFeatures");
    Dataset<Row> result = assembler.transform(featuresAndLabelData)
      .withColumn("indexedLabel", functions.col("label").cast("double"))
      .select("indexedFeatures", "indexedLabel");
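
    The explicit cast of the label column to double is there because the bean stores the label as an int, while Spark ML estimators expect a numeric (typically DoubleType) label column.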
    

    The resulting dataframe has the required format:

    +----------------------------------------+------------+
    |indexedFeatures                         |indexedLabel|
    +----------------------------------------+------------+
    |[0.6720004294237854,-0.4033586564886893]|0.0         |
    |[0.6659082469383558,0.07688976580256132]|0.0         |
    |[0.8086502311695247,0.564354801275521]  |1.0         |
    +----------------------------------------+------------+
    
    root
     |-- indexedFeatures: vector (nullable = true)
     |-- indexedLabel: double (nullable = true)
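
    Note that both columns come out as nullable = true rather than the nullable = false shown in the question; as far as I know, the ML estimators only validate column types, not nullability, so this difference should not matter for training.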
    

    FeaturesAndLabelData is a simple Java bean to make sure that the column names are correct:

    public class FeaturesAndLabelData {
      private double feature1;
      private double feature2;
      private int label;

      // Encoders.bean additionally requires a public no-arg constructor;
      // the three-argument constructor is the one used in the flatMap above.

      //getters and setters...
    }
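
    From there, the result can be fed into any Spark ML classifier. A minimal sketch; LogisticRegression is just one example choice, not something the question prescribes:

    import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;

    // train on the assembled feature vector and double label
    LogisticRegression lr = new LogisticRegression()
      .setFeaturesCol("indexedFeatures")
      .setLabelCol("indexedLabel");
    LogisticRegressionModel model = lr.fit(result);
    model.transform(result).show(false);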