Tags: scala, apache-spark

How to apply udf on a dataframe and on a column in scala?


I am a beginner in Scala. I tried the Scala REPL window in IntelliJ. I have a sample DataFrame and am trying to write my own UDF (not a built-in function) to understand how UDFs work.

df:

scala> import org.apache.spark.sql.SparkSession
scala> val spark: SparkSession = SparkSession.builder.appName("elephant").config("spark.master", "local[*]").getOrCreate()
scala> val df = spark.createDataFrame(Seq(("A",1),("B",2),("C",3))).toDF("Letter", "Number")
scala> df.show()

output:

+------+------+
|Letter|Number|
+------+------+
|     A|     1|
|     B|     2|
|     C|     3|
+------+------+

My UDF for filtering the DataFrame:

scala> def kill_4(n: String): Boolean = {
     |   if (n == "A") { true } else { false }
     | } // please validate if it's correct ???
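(Editor's note: the function above is valid Scala, but the `if`/`else` is redundant since the comparison already yields a `Boolean`. A sketch of the same predicate written more idiomatically:)

```scala
// Same predicate: returns true only for "A"
def kill_4(n: String): Boolean = n == "A"
```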

I tried

df.withColumn("new_col", kill_4(col("Letter"))).show() // please tell correct way???

This fails with: error: type mismatch

Second, I tried filtering directly:

df.filter(kill_4(col("Letter"))).show()

Desired output:

+------+------+
|Letter|Number|
+------+------+
|     B|     2|
|     C|     3|
+------+------+

Solution

  • A plain Scala function cannot be applied to a `Column`, which is what causes the type mismatch. You can wrap it in a `udf` (and import `udf`) and use it as follows:

    import org.apache.spark.sql.functions.{col, udf}

    def kill_4(n: String): Boolean = {
      if (n == "A") { true } else { false }
    }

    val kill_udf = udf((x: String) => kill_4(x))

    df.select(col("Letter"), col("Number"),
      kill_udf(col("Letter")).as("Kill_4")).show(false)
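The same wrapped UDF also answers the two original attempts: `withColumn` works once the function is lifted to operate on a `Column`, and the desired output is the negation of the predicate inside `filter`. A sketch, assuming the `df` and `kill_4` defined above:

```scala
import org.apache.spark.sql.functions.{col, udf}

// Same predicate as kill_4 above, wrapped so it operates on a Column
val kill_udf = udf((x: String) => x == "A")

// Attempt 1: add a boolean column (the plain function failed here
// because kill_4 takes a String, not a Column)
df.withColumn("new_col", kill_udf(col("Letter"))).show()

// Attempt 2: keep only rows where the predicate is false, dropping "A"
df.filter(!kill_udf(col("Letter"))).show()
// +------+------+
// |Letter|Number|
// +------+------+
// |     B|     2|
// |     C|     3|
// +------+------+
```

Note that `!` on a boolean `Column` is Spark's `unary_!` operator, so no extra import is needed for the negation.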