scala · unit-testing · apache-spark · mocking · scalamock

Mocking SparkSession for unit testing


I have a method in my Spark application that loads data from a MySQL database. The method looks something like this:

trait DataManager {

  val session: SparkSession

  def loadFromDatabase(input: Input): DataFrame = {
    session.read.jdbc(input.jdbcUrl, s"(${input.selectQuery}) T0",
      input.columnName, 0L, input.maxId, input.parallelism, input.connectionProperties)
  }
}

The method does nothing other than call the jdbc method to load data from the database. How can I test it? The standard approach would be to create a mock of the session object, which is an instance of SparkSession, but since SparkSession has a private constructor I was not able to mock it with ScalaMock.

The main issue here is that my function is purely side-effecting (the side effect being pulling data from a relational database), so how can I unit test it given that I have trouble mocking SparkSession?

So is there any way I can mock SparkSession or any other better way than mocking to test this method?


Solution

  • In your case I would recommend not mocking the SparkSession. Doing so would more or less mock the entire function (which you could do anyway). If you want to test this function, my suggestion would be to run an embedded database (like H2) and use a real SparkSession. To do this you need to provide the SparkSession to your DataManager.

    Untested sketch:

    Your code:

    class DataManager(session: SparkSession) {
      def loadFromDatabase(input: Input): DataFrame = {
        session.read.jdbc(input.jdbcUrl, s"(${input.selectQuery}) T0",
          input.columnName, 0L, input.maxId, input.parallelism, input.connectionProperties)
      }
    }
    

    Your test-case:

    import java.sql.DriverManager

    import org.apache.spark.sql.SparkSession
    import org.scalatest.{BeforeAndAfterAll, FunSuite}

    class DataManagerTest extends FunSuite with BeforeAndAfterAll {
      override def beforeAll(): Unit = {
        val conn = DriverManager.getConnection("jdbc:h2:~/test", "sa", "")
        // your insert statements go here
        conn.close()
      }

      test("should load data from database") {
        val dm = new DataManager(SparkSession.builder().master("local[*]").getOrCreate())
        val input = Input(jdbcUrl = "jdbc:h2:~/test", selectQuery = "SELECT whateveryouneed FROM whereeveryouputit")
        val actualData = dm.loadFromDatabase(input)
        // assert(...) your expectations on actualData here
      }
    }
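
    For the sketch above to compile and run, the H2 driver, ScalaTest, and Spark SQL need to be on the test classpath. Assuming sbt, the dependencies might look like this (the version numbers are illustrative, not prescriptive — pick ones that match your project):

    ```scala
    // build.sbt — add the embedded database and test framework alongside Spark.
    libraryDependencies ++= Seq(
      "org.apache.spark" %% "spark-sql" % "2.4.8",            // your Spark version
      "com.h2database"    % "h2"        % "1.4.200" % Test,   // embedded H2 for the test database
      "org.scalatest"    %% "scalatest" % "3.0.8"   % Test    // FunSuite lives here in the 3.0.x line
    )
    ```

    Building the SparkSession with `master("local[*]")` in the test keeps everything in-process, so no cluster is needed to run the suite.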