Tags: scala, unit-testing, apache-spark, mockito, scalatest

How to mock a function to return a dummy value in Scala?


    import org.apache.spark.sql.{DataFrame, SQLContext}

    object ReadUtils {
      def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame] = {
        // some logic
        ???
      }
    }

I am writing a test for the execute function of this class:

    import com.utils.ReadUtils.readData
    import org.apache.spark.sql.{DataFrame, SQLContext}

    class Logs extends Interface with NativeImplicits {

      override def execute(sqlContext: SQLContext): Unit = {
        val inputDFs: List[DataFrame] = readData(sqlContext, FileType.PARQUET)
        // some logic
      }
    }

How can I mock the readData function so that it returns a dummy value when testing the execute function? Currently the test calls the actual function.

    test("Log Test") {
      val df1: DataFrame = ??? // some dummy DataFrame
      val sparkSession = SparkSession
        .builder()
        .master("local[*]")
        .appName("test")
        .getOrCreate()
      sparkSession.sparkContext.setLogLevel("ERROR")

      val log = new Logs()
      val mockedReadUtils = mock[ReadUtils.type]
      // readData returns List[DataFrame], so the stub has to return a List
      when(mockedReadUtils.readData(sparkSession.sqlContext, FileType.PARQUET)).thenReturn(List(df1))
      // Logs never sees mockedReadUtils: it still calls the real ReadUtils.readData
      log.execute(sparkSession.sqlContext)
    }

Solution

  • The short answer is that you can't, at least not with an ordinary mock. A Scala object is a singleton, and Logs calls ReadUtils.readData on it directly, so there is no point where a mock could be substituted. This is one of the reasons singletons are best avoided in code you want to unit test.

    You could mock sqlContext instead, stubbing every method of it that readData calls, so that the real readData ends up returning your dummy data.

    Another approach is to add dependency injection, for example with the Cake Pattern (https://medium.com/rahasak/scala-cake-pattern-e0cd894dae4e):

    import org.apache.spark.sql.{DataFrame, SQLContext}

    trait DataReader {
      def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame]
    }

    trait RealDataReader extends DataReader {
      def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame] = {
        // some code
        ???
      }
    }

    trait MockedDataReader extends DataReader {
      def readData(sqlContext: SQLContext, fileType: FileType.Value): List[DataFrame] = {
        // some mocking code that returns dummy DataFrames
        ???
      }
    }

    // Logs is abstract: it declares the DataReader dependency but does not pick an implementation
    abstract class Logs extends Interface with NativeImplicits with DataReader {

      override def execute(sqlContext: SQLContext): Unit = {
        val inputDFs: List[DataFrame] = readData(sqlContext, FileType.PARQUET)
        // some logic
      }
    }

    class RealLogs extends Logs with RealDataReader // that would be the real class

    class MockedLogs extends Logs with MockedDataReader // that would be the class for tests
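A lighter-weight alternative to the Cake Pattern (a different technique from the trait-based code in the answer) is to pass the reader into Logs as a constructor parameter, so a test can hand in a plain function instead of a mock. The sketch below is written under stated assumptions: FileType is a hypothetical stand-in for the question's enumeration, Interface and NativeImplicits are dropped to keep it self-contained, and having execute return the DataFrames it read is a hypothetical change made only so the behaviour can be checked.

```scala
import org.apache.spark.sql.{DataFrame, SQLContext}

// Hypothetical stand-in for the question's FileType enumeration.
object FileType extends Enumeration {
  val PARQUET, CSV = Value
}

// Hypothetical variant of Logs: the reader is injected through the constructor
// instead of being hard-wired to ReadUtils.readData.
// (Interface and NativeImplicits from the question are omitted here.)
class Logs(readData: (SQLContext, FileType.Value) => List[DataFrame]) {

  // Returns the DataFrames it read only so this sketch has an observable result.
  def execute(sqlContext: SQLContext): List[DataFrame] = {
    val inputDFs: List[DataFrame] = readData(sqlContext, FileType.PARQUET)
    // some logic
    inputDFs
  }
}
```

In production the class would be built as new Logs(ReadUtils.readData), which eta-expands the object's method into a function; in the test it would be new Logs((_, _) => List(df1)), so execute never reaches the real reader and no mocking library is needed. Compared with the Cake Pattern, this avoids one trait per implementation, at the cost of threading the dependency through the constructor.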