java, apache-spark, apache-spark-mllib, apache-spark-ml

Spark MLlib classification input format using Java


How can I transform a List of DTOs into the Spark ML input dataset format?

I have a DTO:

public class MachineLearningDTO implements Serializable {
    private double label;
    private double[] features;

    public MachineLearningDTO() {
    }

    public MachineLearningDTO(double label, double[] features) {
        this.label = label;
        this.features = features;
    }

    public double getLabel() {
        return label;
    }

    public void setLabel(double label) {
        this.label = label;
    }

    public double[] getFeatures() {
        return features;
    }

    public void setFeatures(double[] features) {
        this.features = features;
    }
}

And this code:

Dataset<MachineLearningDTO> mlInputDataSet = spark.createDataset(mlInputData, Encoders.bean(MachineLearningDTO.class));
LogisticRegression logisticRegression = new LogisticRegression();
LogisticRegressionModel model = logisticRegression.fit(MLUtils.convertMatrixColumnsToML(mlInputDataSet));

After executing the code I get:

java.lang.IllegalArgumentException: requirement failed: Column features must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually ArrayType(DoubleType,false).

So Encoders.bean maps the double[] field to ArrayType(DoubleType) rather than the Vector type the ML API expects. If I change it to org.apache.spark.ml.linalg.VectorUDT with this code:

VectorUDT vectorUDT = new VectorUDT();
vectorUDT.serialize(Vectors.dense(......));

Then I get:

java.lang.UnsupportedOperationException: Cannot infer type for class org.apache.spark.ml.linalg.VectorUDT because it is not bean-compliant
    at org.apache.spark.sql.catalyst.JavaTypeInference$.org$apache$spark$sql$catalyst$JavaTypeInference$$serializerFor(JavaTypeInference.scala:437)


Solution

  • I have figured it out. Just in case someone else gets stuck with this, I wrote a simple converter and it works:

    import java.util.List;
    import java.util.stream.Collectors;
    
    import org.apache.spark.ml.linalg.VectorUDT;
    import org.apache.spark.ml.linalg.Vectors;
    import org.apache.spark.sql.*;
    import org.apache.spark.sql.types.*;
    
    private Dataset<Row> convertToMlInputFormat(List<MachineLearningDTO> data) {
        // Turn each DTO into a Row: the label is already a double, and the
        // double[] features are wrapped in an ml.linalg dense vector
        List<Row> rowData = data.stream()
                .map(dto -> RowFactory.create(dto.getLabel(), Vectors.dense(dto.getFeatures())))
                .collect(Collectors.toList());
        // Declare the features column as VectorUDT so the ML estimators accept it
        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty()),
        });
        // spark is the surrounding SparkSession
        return spark.createDataFrame(rowData, schema);
    }
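
    With the converter in place, the training call from the question goes through. A minimal usage sketch, assuming spark and the mlInputData list from the question are in scope:

    Dataset<Row> mlInput = convertToMlInputFormat(mlInputData);
    LogisticRegression logisticRegression = new LogisticRegression();
    LogisticRegressionModel model = logisticRegression.fit(mlInput);

    As a side note, on Spark 3.1+ the same conversion can be done without a hand-written schema: org.apache.spark.ml.functions.array_to_vector converts an ArrayType(DoubleType) column into a Vector column directly.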