Tags: java, apache-spark, java-8, user-defined-functions

How to create a Spark UDF that returns a Tuple or updates two columns at the same time?


I'm trying to create a UDF in Java that performs a computation by iterating over all columns in the dataset, calculates a score for each column, and sets a flag with a specific string. I tried a UDF that updates two columns at the same time, but with no success. So my current approach is to have a single column composed of a tuple, which I plan to split into two columns in a subsequent step.
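The per-row logic described above amounts to a fold: start from an initial (score, reason) pair and, for each column, keep the lower score together with its reason. A minimal plain-Java sketch of that rule outside Spark, assuming one candidate (score, reason) pair per column; the class ScoreFold and its members are hypothetical names, and the initial value and update rule are taken from the example code that follows.

```java
public final class ScoreFold {

    // Immutable (value, reason) pair, standing in for Tuple2<Double, String>.
    public static final class Score {
        public final double value;
        public final String reason;

        public Score(double value, String reason) {
            this.value = value;
            this.reason = reason;
        }
    }

    // Same rule as the question's UDF: keep the lower score, and take the
    // candidate's reason only if its score is strictly lower than the current one.
    public static Score update(Score current, Score candidate) {
        double value = Math.min(current.value, candidate.value);
        String reason = candidate.value < current.value ? candidate.reason : current.reason;
        return new Score(value, reason);
    }

    // Folds one candidate per column into a single result, starting from
    // the initial value used in the question: (1.0, "No issues").
    public static Score scoreRow(Score[] perColumnCandidates) {
        Score acc = new Score(1.0, "No issues");
        for (Score candidate : perColumnCandidates) {
            acc = update(acc, candidate);
        }
        return acc;
    }
}
```

In Spark, the same fold would run per row, with the accumulator carried in the struct column.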

This is my attempt (example code):

import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;

import static org.apache.spark.sql.functions.col;

public class UDFExample {
    public static void main(String[] args) {
        // Spark session
        SparkSession spark = SparkSession.builder()
                .appName("UDFExample")
                .config("spark.master", "local")
                .getOrCreate();

        // ScoreValueAndReason is a struct column. At runtime Spark hands it to the UDF
        // as a Row (GenericRowWithSchema), so this Tuple2 parameter fails with a ClassCastException.
        UDF2<Integer, Tuple2<Double, String>, Tuple2<Double, String>> calculateScore = (feature, ScoreValueAndReason) -> {
            double new_ScoreValue = 0.5;

            String new_ScoreReason = "Null Issue";
            double ScoreValue = ScoreValueAndReason._1();
            String ScoreReason = ScoreValueAndReason._2();

            double updated_ScoreValue = Math.min(ScoreValue, new_ScoreValue);
            String updated_ScoreReason = new_ScoreValue < ScoreValue ? new_ScoreReason: ScoreReason;

            return new Tuple2<>(updated_ScoreValue, updated_ScoreReason);

        };
        Dataset<Row> df = spark.createDataFrame(Arrays.asList(
                RowFactory.create(1, 2, 3, 4),
                RowFactory.create(2, 1, 4, 3),
                RowFactory.create(3, 4, 9, 2),
                RowFactory.create(4, 3, 2, 1)
        ), new StructType(new StructField[]{
                new StructField("A", DataTypes.IntegerType, true, Metadata.empty()),
                new StructField("B", DataTypes.IntegerType, true, Metadata.empty()),
                new StructField("C", DataTypes.IntegerType, true, Metadata.empty()),
                new StructField("D", DataTypes.IntegerType, true, Metadata.empty())
        }));

        df = df.withColumn("ScoreValueAndReason",
                functions.struct(functions.lit(1.0), functions.lit("No issues"))
        );
        df.show();

        // Register the UDF
        spark.udf().register("calculateScore", calculateScore, DataTypes.createStructType(new StructField[]{
                new StructField("_1", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("_2", DataTypes.StringType, false, Metadata.empty())
        }));

        // Note: columns() is read after ScoreValueAndReason was added, so the loop also visits that column
        String[] columnNames = df.columns();
        for (String columnName: columnNames) {
            df = df.withColumn(columnName, functions.callUDF("calculateScore", col(columnName), col("ScoreValueAndReason")));
        }
        df.show();
    }
}

I tried the implementation above as well as a slightly different UDF, but I keep running into this error: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema incompatible with scala.Tuple2

I expect a resulting DataFrame with either two new columns, scoreReason and scoreValue, or (closer to the example I provided) a single column scoreReasonAndValue that contains the values updated by my UDF.
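Once a struct column holds the updated values, the subsequent split into two columns can be done with built-in column functions rather than a second UDF. A minimal sketch, assuming the struct's fields are named _1 and _2, as in the return type passed to register above (a struct built with functions.struct from unnamed literals gets generated field names instead, so adjust the names passed to getField accordingly); the class SplitScoreColumn and the output names scoreValue and scoreReason are illustrative.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.struct;

public class SplitScoreColumn {

    // Replaces a struct column with fields "_1" (score value) and "_2" (score reason)
    // by two top-level columns, scoreValue and scoreReason.
    public static Dataset<Row> splitScore(Dataset<Row> df, String structCol) {
        return df
                .withColumn("scoreValue", col(structCol).getField("_1"))
                .withColumn("scoreReason", col(structCol).getField("_2"))
                .drop(structCol);
    }

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("SplitScoreColumn")
                .master("local[1]")
                .getOrCreate();
        // One-row demo frame whose struct column uses the same field names as the UDF's return type.
        Dataset<Row> df = spark.range(1).select(
                struct(lit(0.5).as("_1"), lit("Null Issue").as("_2")).as("ScoreValueAndReason"));
        splitScore(df, "ScoreValueAndReason").show();
        spark.stop();
    }
}
```

getField keeps the split inside Catalyst expressions, so no extra serialization round trip through Java objects is needed.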


Solution

  • See here.

    Out of the box, UDFs don't support complex types such as Tuple2 as parameters: a struct argument arrives as a Row, so you are forced to use Row directly. For Scala there was https://github.com/gaborbarna/typedudf, and there is https://github.com/typelevel/frameless, which provide typed UDFs that derive and use encoders correctly to convert to and from Row.
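    For comparison, using Row directly means declaring both the struct parameter and the return value as Row. A minimal sketch, assuming the same scoring rule as the question's UDF and a struct whose first field is the double score and whose second is the string reason; RowBasedScoreUdf and its members are hypothetical names. Fields are read by position, so the input struct's field names do not matter.

```java
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class RowBasedScoreUdf {

    // Return type of the UDF: a struct with the same field names as in the question.
    public static final StructType SCORE_TYPE = DataTypes.createStructType(new StructField[]{
            DataTypes.createStructField("_1", DataTypes.DoubleType, false),
            DataTypes.createStructField("_2", DataTypes.StringType, false)
    });

    // The struct argument is received as a Row (not a Tuple2), read by ordinal,
    // and the result is returned as a Row matching SCORE_TYPE.
    public static final UDF2<Integer, Row, Row> CALCULATE_SCORE = (feature, scoreValueAndReason) -> {
        double newScoreValue = 0.5;
        String newScoreReason = "Null Issue";
        double scoreValue = scoreValueAndReason.getDouble(0);
        String scoreReason = scoreValueAndReason.getString(1);

        double updatedValue = Math.min(scoreValue, newScoreValue);
        String updatedReason = newScoreValue < scoreValue ? newScoreReason : scoreReason;
        return RowFactory.create(updatedValue, updatedReason);
    };

    // Registers the UDF under the name used in the question.
    public static void register(SparkSession spark) {
        spark.udf().register("calculateScore", CALCULATE_SCORE, SCORE_TYPE);
    }
}
```

    Registered this way, the UDF stays callable from SQL and via functions.callUDF("calculateScore", ...), as in the question's loop, provided the loop only visits the original integer columns.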

    That said, if you can live without calling the UDF from SQL, you can define the following helper function, which uses encoders and should work for any parameter types (tested on Spark 3.4.1; ScalaUDF may be different in other versions):

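    import java.util.function.BiFunction;

    import org.apache.spark.sql.Column;
    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.api.java.UDF2;
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
    import org.apache.spark.sql.catalyst.expressions.Expression;
    import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
    import org.apache.spark.sql.catalyst.util.CharVarcharUtils;
    import org.apache.spark.sql.types.DataType;
    import scala.Some$;
    import scala.collection.Seq$;

    // Wraps a Java UDF2 in a ScalaUDF expression that carries explicit input and output encoders,
    // so struct arguments are decoded into the declared Java types (e.g. Tuple2) instead of Row.
    @SuppressWarnings("unchecked")
    public static <A, B, R> BiFunction<Column, Column, Column> withEncoders(UDF2<A, B, R> udf2, String name, Encoder<A> encA, Encoder<B> encB, Encoder<R> encR, DataType returnType){
        DataType replaced = CharVarcharUtils.failIfHasCharVarchar(returnType);
        scala.Function2<A, B, R> func = new scala.runtime.AbstractFunction2<A, B, R>(){
    
            @Override
            public Object apply(Object v1, Object v2) {
                try {
                    return ((UDF2<Object, Object, Object>) udf2).call(v1, v2);
                } catch (Exception e) {
                    // Rethrow rather than returning null, so the failure surfaces in the Spark job
                    throw new RuntimeException(e);
                }
            }
        };
        return (Column a, Column b) -> {
            ScalaUDF udf = new ScalaUDF(
                    func,
                    replaced,
                    // children: the argument expressions
                    Seq$.MODULE$.<Expression>newBuilder().$plus$eq(a.expr()).$plus$eq(b.expr()).result(),
                    // inputEncoders: one per argument
                    Seq$.MODULE$.<scala.Option<ExpressionEncoder<?>>>newBuilder()
                            .$plus$eq(Some$.MODULE$.apply((ExpressionEncoder<A>) encA))
                            .$plus$eq(Some$.MODULE$.apply((ExpressionEncoder<B>) encB))
                            .result(),
                    // outputEncoder, udfName, nullable, udfDeterministic
                    Some$.MODULE$.apply((ExpressionEncoder<R>) encR),
                    Some$.MODULE$.apply(name),
                    true,
                    true);
            return new Column(udf);
        };
    }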
    

    Then:

    UDF2<Integer, Tuple2<Double, String>, Tuple2<Double, String>> calculateScoreRaw = (feature, ScoreValueAndReason) -> {
        double new_ScoreValue = 0.5;
    
        String new_ScoreReason = "Null Issue";
        double ScoreValue = ScoreValueAndReason._1();
        String ScoreReason = ScoreValueAndReason._2();
    
        double updated_ScoreValue = Math.min(ScoreValue, new_ScoreValue);
        String updated_ScoreReason = new_ScoreValue < ScoreValue ? new_ScoreReason: ScoreReason;
    
        return new Tuple2<>(updated_ScoreValue, updated_ScoreReason);
    
    };
    
    // One tuple encoder serves as the input encoder for the struct argument and as the output encoder
    Encoder<Tuple2<Double, String>> dsTEncoder = Encoders.tuple(Encoders.DOUBLE(), Encoders.STRING());
    
    BiFunction<Column, Column, Column> calculateScore = withEncoders(calculateScoreRaw, "calculateScore",
            Encoders.INT(), dsTEncoder, dsTEncoder
            , DataTypes.createStructType(new StructField[]{
            new StructField("_1", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("_2", DataTypes.StringType, false, Metadata.empty())
    }));
    
    Dataset<Row> df = spark.createDataFrame(Arrays.asList(
            RowFactory.create(1, 2, 3, 4),
            RowFactory.create(2, 1, 4, 3),
            RowFactory.create(3, 4, 9, 2),
            RowFactory.create(4, 3, 2, 1)
    ), new StructType(new StructField[]{
            new StructField("A", DataTypes.IntegerType, true, Metadata.empty()),
            new StructField("B", DataTypes.IntegerType, true, Metadata.empty()),
            new StructField("C", DataTypes.IntegerType, true, Metadata.empty()),
            new StructField("D", DataTypes.IntegerType, true, Metadata.empty())
    }));
    
    // Read the column names before ScoreValueAndReason is added, so the loop only visits A to D
    String[] columnNames = df.columns();
    
    df = df.withColumn("ScoreValueAndReason",
            functions.struct(functions.lit(1.0), functions.lit("No issues"))
    );
    df.show();
    
    // No spark.udf().register(...) here: calculateScore builds Column expressions directly
    
    for (String columnName: columnNames) {
        df = df.withColumn(columnName, calculateScore.apply(col(columnName), col("ScoreValueAndReason")));
    }
    df.show();
    

    produces:

    +-----------------+-----------------+-----------------+-----------------+-------------------+
    |                A|                B|                C|                D|ScoreValueAndReason|
    +-----------------+-----------------+-----------------+-----------------+-------------------+
    |{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|   {1.0, No issues}|
    |{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|   {1.0, No issues}|
    |{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|   {1.0, No issues}|
    |{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|{0.5, Null Issue}|   {1.0, No issues}|
    +-----------------+-----------------+-----------------+-----------------+-------------------+
    

    So a few bits have moved around from your original source (crucially, the columnNames assignment also has to move up, otherwise it will incorrectly include ScoreValueAndReason).

    The creation of the ScalaUDF maker:

    BiFunction<Column, Column, Column> calculateScore = withEncoders(calculateScoreRaw, "calculateScore",
            Encoders.INT(), dsTEncoder, dsTEncoder
            , DataTypes.createStructType(new StructField[]{
            new StructField("_1", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("_2", DataTypes.StringType, false, Metadata.empty())
    }));
    

    passes in the original UDF you defined, together with the correct encoders and the return type. You can also make this callable from SQL, but that's an answer for another question.