Tags: apache-spark, apache-spark-sql, udf

Merge multiple columns in a Spark DataFrame [Java]


How can I combine multiple columns (say 3) of a DataFrame into a single column (in a new DataFrame), where each row becomes a Spark DenseVector? Similar to this thread, but in Java and with a few tweaks mentioned below.

I tried using a UDF like this:

private UDF3<Double, Double, Double, Row> toColumn = new UDF3<Double, Double, Double, Row>() {

    private static final long serialVersionUID = 1L;

    public Row call(Double first, Double second, Double third) throws Exception {           
        Row row = RowFactory.create(Vectors.dense(first, second, third));

        return row; 
    }
};

And then I register the UDF:

sqlContext.udf().register("toColumn", toColumn, dataType);

Where the dataType is:

StructType dataType = DataTypes.createStructType(new StructField[]{
    new StructField("bla", new VectorUDT(), false, Metadata.empty())
});

When I call this UDF on a DataFrame with 3 columns and print out the schema of the new DataFrame, I get this:

root
 |-- features: struct (nullable = true)
 |    |-- bla: vector (nullable = false)

The problem here is that I need the vector to be at the top level, not nested within a struct. Something like this:

root
 |-- features: vector (nullable = true)

I don't know how to achieve this, since the register function requires the return type of the UDF to be a DataType (which, in turn, doesn't provide a VectorType).


Solution

  • You actually nested the vector type into a struct manually by using this data type:

    new StructField("bla", new VectorUDT(), false, Metadata.empty()),
    

    If you remove the outer StructField, you will get what you want. Of course, in that case you need to adjust the signature of your function definition slightly: the UDF has to return a Vector instead of a Row. This also answers your last point: VectorUDT itself extends DataType, so new VectorUDT() can be passed directly to register.
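
    In outline, the two changes look like this (a minimal sketch; the complete, tested version follows below):

    // the UDF now returns the Vector itself instead of wrapping it in a Row
    private UDF3<Double, Double, Double, Vector> toColumn = new UDF3<Double, Double, Double, Vector>() {

      private static final long serialVersionUID = 1L;

      public Vector call(Double first, Double second, Double third) throws Exception {
        return Vectors.dense(first, second, third);
      }
    };

    // VectorUDT extends DataType, so it is a valid return type for register()
    sqlContext.udf().register("toColumn", toColumn, new VectorUDT());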

    Please see below the complete example of what I mean, in the form of a simple JUnit test.

    package sample.spark.test;
    
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.mllib.linalg.Vector;
    import org.apache.spark.mllib.linalg.VectorUDT;
    import org.apache.spark.mllib.linalg.Vectors;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.api.java.UDF3;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.Metadata;
    import org.apache.spark.sql.types.StructField;
    import org.junit.Test;
    
    import java.io.Serializable;
    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;
    
    import static org.junit.Assert.assertEquals;
    import static org.junit.Assert.assertTrue;
    
    public class ToVectorTest implements Serializable {
      private static final long serialVersionUID = 2L;
    
      private UDF3<Double, Double, Double, Vector> toColumn = new UDF3<Double, Double, Double, Vector>() {
    
        private static final long serialVersionUID = 1L;
    
        public Vector call(Double first, Double second, Double third) throws Exception {
          return Vectors.dense(first, second, third);
        }
      };
    
      @Test
      public void testUDF() {
        // context
        final JavaSparkContext sc = new JavaSparkContext("local", "ToVectorTest");
        final SQLContext sqlContext = new SQLContext(sc);
    
        // test input
        final DataFrame input = sqlContext.createDataFrame(
            sc.parallelize(
                Arrays.asList(
                    RowFactory.create(1.0, 2.0, 3.0),
                    RowFactory.create(4.0, 5.0, 6.0),
                    RowFactory.create(7.0, 8.0, 9.0),
                    RowFactory.create(10.0, 11.0, 12.0)
                )),
            DataTypes.createStructType(
                Arrays.asList(
                    new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
                    new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
                    new StructField("feature3", DataTypes.DoubleType, false, Metadata.empty())
                )
            )
        );
        input.registerTempTable("input");
    
        // expected output
        final Set<Vector> expectedOutput = new HashSet<>(Arrays.asList(
            Vectors.dense(1.0, 2.0, 3.0),
            Vectors.dense(4.0, 5.0, 6.0),
            Vectors.dense(7.0, 8.0, 9.0),
            Vectors.dense(10.0, 11.0, 12.0)
        ));
    
        // processing
        sqlContext.udf().register("toColumn", toColumn, new VectorUDT());
        final DataFrame outputDF = sqlContext.sql("SELECT toColumn(feature1, feature2, feature3) AS x FROM input");
        final Set<Vector> output = new HashSet<>(outputDF.toJavaRDD().map(r -> r.<Vector>getAs("x")).collect());
    
        // evaluation
        assertEquals(expectedOutput.size(), output.size());
        for (Vector x : output) {
          assertTrue(expectedOutput.contains(x));
        }
    
        // show the schema and the content
        System.out.println(outputDF.schema());
        outputDF.show();
    
        sc.stop();
      }
    }
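
    When the test passes, the println and show calls at the end should confirm that x is now a top-level vector column; in printSchema() tree form the result would look roughly like this:

    root
     |-- x: vector (nullable = true)

    i.e. exactly the shape requested in the question, with no wrapping struct.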