Tags: scala, apache-spark, mockito, hbase, rdd

How to create a Spark RDD of mocked elements?


I'd like to create an RDD (an actual one, not mocked) that contains mocked elements (with Mockito) in a unit test.

My attempt is:

lazy val sc = SparkContext.getOrCreate()
val myRDD = sc.parallelize(Seq( (Mockito.mock(classOf[ImmutableBytesWritable]), Mockito.mock(classOf[Result])) ))

where ImmutableBytesWritable and Result come from the HBase API. This fails with org.apache.spark.SparkException: Task not serializable.

Is there any way to achieve this? Thank you!


Solution

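  • By default, Mockito mocks are not serializable, which is why you get the error: the collection passed to parallelize is stored in the RDD, and Spark serializes the RDD when it ships tasks to the executors.

    To create serializable mocks, you have to request serialization explicitly: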

    val mockImmutableBytesWritable = Mockito.mock(
        classOf[ImmutableBytesWritable],
        Mockito.withSettings().serializable()
    )
    

    Apply the same setting to your Result mock.

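    If you then get a java.lang.ClassNotFoundException: org.apache.hadoop.hbase.io.ImmutableBytesWritable$MockitoMock$... exception, the mock is probably being deserialized by a class loader that cannot see Mockito's generated class. In that case, use SerializableMode.ACROSS_CLASSLOADERS: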

    import org.mockito.mock.SerializableMode
    
    val mockImmutableBytesWritable = Mockito.mock(
        classOf[ImmutableBytesWritable],
        Mockito.withSettings().serializable(SerializableMode.ACROSS_CLASSLOADERS)
    )
    

    Putting it all together, you should end up with something like:

    import org.apache.spark.SparkContext 
    import org.apache.spark.SparkConf    
    
    import org.apache.hadoop.hbase.io.ImmutableBytesWritable
    import org.apache.hadoop.hbase.client.Result
    
    import org.mockito.Mockito
    import org.mockito.mock.SerializableMode
    
    object Test extends App {
    
        val conf = new SparkConf()
            .setMaster("local[2]")
            .setAppName("test")
        lazy val sc = new SparkContext(conf)
    
        val mockImmutableBytesWritable = Mockito.mock(
            classOf[ImmutableBytesWritable],
            Mockito.withSettings().serializable(
                SerializableMode.ACROSS_CLASSLOADERS
            )
        )
    
        val mockResult = Mockito.mock(
            classOf[Result],
            Mockito.withSettings().serializable(
                SerializableMode.ACROSS_CLASSLOADERS
            )
        )
    
        val myRDD = sc.parallelize(Seq((mockImmutableBytesWritable, mockResult)))
    
        println(myRDD.count())
    
        sc.stop()
    
    }
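Because the "Task not serializable" error comes from Java serialization of the RDD (and, for parallelize, the collection it holds), it can help to check candidate elements before handing them to Spark. The following is a minimal sketch in plain Scala with no Spark or Mockito involved. `isJavaSerializable`, `PlainElement` and `SerializableElement` are hypothetical names used only for illustration.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-ins for RDD elements (not HBase or Mockito types).
class PlainElement(val id: Int) // does not implement Serializable
class SerializableElement(val id: Int) extends Serializable

object SerializabilityCheck {

  // Tries to write `obj` with plain Java serialization, the mechanism behind
  // Spark's "Task not serializable" error. Only the write side is exercised;
  // nothing is read back, so class-loading problems on deserialization are not detected.
  def isJavaSerializable(obj: AnyRef): Boolean = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      out.writeObject(obj)
      true
    } catch {
      case _: NotSerializableException => false
    } finally {
      out.close()
    }
  }

  def main(args: Array[String]): Unit = {
    println(isJavaSerializable(new PlainElement(1)))             // false
    println(isJavaSerializable(new SerializableElement(1)))      // true
    // A Seq is serializable only if every element is.
    println(isJavaSerializable(Seq(new PlainElement(1))))        // false
    println(isJavaSerializable(Seq(new SerializableElement(1)))) // true
  }
}
```

In a test like the one above, calling isJavaSerializable(mockResult) before parallelize should report false for a default mock and true for one created with Mockito.withSettings().serializable(...). Because it only writes, it will not catch the ClassNotFoundException case, which occurs when the mock is read back.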