
How to mock sqlContext.read.parquet()?


class Test {
  override def execute(sqlContext: SQLContext) {
    val df: DataFrame = sqlContext.read.parquet(path)
  }
}

How do I mock sqlContext.read.parquet? I need to read from a JSON file and return that dummy DataFrame when it is called.

class XTest extends FunSuite with MockitoSugar {

  test("Test") {

    val sparkSession = SparkSession
      .builder()
      .master("local[*]")
      .appName("View_Persistence_Spark_Job")
      .getOrCreate()
    sparkSession.sparkContext.setLogLevel("ERROR")

    val test = new Test()
    val df_from_json = sparkSession.read.option("multiline", "true").json("src/test/resources/test.json")
    val mockContext = mock[SQLContext]
    when(mockContext.read.parquet("src/test/resources/test.json")).thenReturn(df_from_json)
    test.execute(sparkSession.sqlContext)
  }
}

Solution

  • There are two issues to understand in your example.

    First, you cannot chain calls on a mock. Trying:

    when(mockContext.read.parquet("src/test/resources/test.json"))
    

    will always fail: mockContext.read has not been stubbed, so by default Mockito returns null for it, and calling .parquet on that null result throws a NullPointerException. To solve this, we need a second mock, which will be the result of mockContext.read. So this part becomes:

    val mockContext = mock[SQLContext]
    val dataFrameReader = mock[DataFrameReader]
    
    when(mockContext.read).thenReturn(dataFrameReader)
    when(dataFrameReader.parquet("src/test/resources/test.json")).thenReturn(df_from_json)
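    The chaining pitfall can be sketched without Spark or Mockito at all. The sketch below uses hypothetical minimal stand-ins (Reader and Context are illustrative traits, not Spark types): a bare stub whose read returns null behaves like an unstubbed mock and makes the chained call blow up, while stubbing each level separately, as above, makes the chain work:

    ```scala
    // Hypothetical minimal stand-ins for DataFrameReader and SQLContext.
    trait Reader { def parquet(path: String): String }
    trait Context { def read: Reader }

    object ChainedMockDemo {
      // Like an unstubbed Mockito mock: read returns null by default.
      val bareContext: Context = new Context { def read: Reader = null }

      // Chaining through the bare stub hits the null and throws.
      val chainedResult: String =
        try { bareContext.read.parquet("data.parquet") }
        catch { case _: NullPointerException => "NPE" }

      // Stubbing each level separately makes the chain work.
      val reader: Reader = new Reader {
        def parquet(path: String): String = s"df($path)"
      }
      val stubbedContext: Context = new Context { def read: Reader = reader }
      val stubbedResult: String = stubbedContext.read.parquet("data.parquet")
    }
    ```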
    

    Second, for the test to actually use that mock, you must pass the mock itself to execute, not sparkSession.sqlContext, which is a real object whose behaviour you cannot override.

    To sum up, a complete test will be:

    import org.apache.spark.sql.{DataFrame, DataFrameReader, SQLContext, SparkSession}
    import org.mockito.Mockito.when
    
    test("Test") {
    
      val sparkSession = SparkSession
        .builder()
        .master("local[*]")
        .appName("View_Persistence_Spark_Job")
        .getOrCreate()
      sparkSession.sparkContext.setLogLevel("ERROR")
    
      val test = new Test()
      val df_from_json: DataFrame = sparkSession.read.option("multiline", "true").json("src/test/resources/test.json")
      val mockContext = mock[SQLContext]
      val dataFrameReader = mock[DataFrameReader]
    
      when(mockContext.read).thenReturn(dataFrameReader)
      when(dataFrameReader.parquet("src/test/resources/test.json")).thenReturn(df_from_json)
      test.execute(mockContext)
    }
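    Once the mock is passed in, you can also confirm the interaction, e.g. with Mockito's verify(dataFrameReader).parquet("src/test/resources/test.json"). The same idea can be sketched with a hand-rolled recording stub; FrameReader, SqlCtx, and Job below are hypothetical stand-ins, not Spark types:

    ```scala
    // Hypothetical stand-ins for DataFrameReader and SQLContext.
    trait FrameReader { def parquet(path: String): String }
    trait SqlCtx { def read: FrameReader }

    // A recording stub plays the role of Mockito's verify: it remembers
    // every path it was asked to read.
    class RecordingReader extends FrameReader {
      var requested: List[String] = Nil
      def parquet(path: String): String = {
        requested = path :: requested
        s"df($path)"
      }
    }

    // A job shaped like the Test class in the question: it takes the
    // context as a parameter, so the test can hand it the stub.
    class Job {
      def execute(ctx: SqlCtx): String = ctx.read.parquet("data.parquet")
    }

    object VerifyDemo {
      val reader = new RecordingReader
      val ctx: SqlCtx = new SqlCtx { def read: FrameReader = reader }
      val result: String = new Job().execute(ctx)
    }
    ```

    Because execute receives the stub rather than a real context, the test can assert both on the returned value and on the recorded call.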