Tags: scala, apache-spark, generics, pyspark, py4j

Converting Scala case class to PySpark schema


Given a simple Scala case class like this:

package com.foo.storage.schema   
case class Person(name: String, age: Int)

it's possible to create a Spark schema from a case class as follows:

import org.apache.spark.sql._
import com.foo.storage.schema.Person  

val schema = Encoders.product[Person].schema

I wonder if it's possible to access the schema of a case class from Python/PySpark. I would hope to do something like this in Python:

jvm = sc._jvm
py4j_class = jvm.com.foo.storage.schema.Person 
jvm.org.apache.spark.sql.Encoders.product(py4j_class)

This throws an error: com.foo.storage.schema.Person._get_object_id does not exist in the JVM. Encoders.product is generic in Scala, and I'm not sure how to specify the type parameter through Py4J. Is there a way to use the case class to create a PySpark schema?


Solution

  • I've found there's no clean or easy way to do this with generics, not even as a pure Scala function. What I ended up doing is adding a companion object to the case class that exposes the schema.

    Solution

    package com.foo.storage.schema
    import org.apache.spark.sql.Encoders

    case class Person(name: String, age: Int)

    object Person {
      def getSchema = Encoders.product[Person].schema
    }
    

    This function can be called via Py4J, but it returns a JavaObject wrapping the Scala StructType. It can be converted to a Python StructType with a helper function like this:

    from pyspark.sql.types import StructType
    import json

    def java_schema_to_python(j_schema):
        json_schema = json.loads(j_schema.json())
        return StructType.fromJson(json_schema)
    
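    For reference, here is a minimal, Spark-free sketch of the JSON shape that StructType.json() produces for a schema like Person's, which is exactly what StructType.fromJson consumes. The field order, nullability flags, and metadata shape shown here are assumed for illustration:

    ```python
    import json

    # Hypothetical example of the JSON emitted by StructType.json() for
    # Person(name: String, age: Int); exact nullability/metadata assumed.
    schema_json = json.loads("""
    {
      "type": "struct",
      "fields": [
        {"name": "name", "type": "string", "nullable": true, "metadata": {}},
        {"name": "age", "type": "integer", "nullable": false, "metadata": {}}
      ]
    }
    """)

    # The helper above only needs to parse this structure and hand it
    # to StructType.fromJson; here we just inspect the field names.
    field_names = [f["name"] for f in schema_json["fields"]]
    print(field_names)
    ```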

    Finally, we can extract our schema:

    jvm = sc._jvm
    j_schema = jvm.com.foo.storage.schema.Person.getSchema()
    schema = java_schema_to_python(j_schema)
    
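    As an aside, Py4J resolves jvm.com.foo.storage.schema.Person through chained attribute access, so a typo in the package path only fails at call time. A small, hypothetical helper (jvm_path and the SimpleNamespace stand-in below are my own additions, not Py4J API) can keep the path in one string:

    ```python
    from functools import reduce
    from types import SimpleNamespace

    def jvm_path(root, dotted):
        """Resolve a dotted JVM class path against a Py4J JVM view
        (e.g. sc._jvm) using chained attribute access."""
        return reduce(getattr, dotted.split("."), root)

    # Stand-in for sc._jvm to demonstrate the lookup; with Spark you would do:
    #   jvm_path(sc._jvm, "com.foo.storage.schema.Person").getSchema()
    fake_jvm = SimpleNamespace(
        com=SimpleNamespace(foo=SimpleNamespace(storage=SimpleNamespace(
            schema=SimpleNamespace(Person="<Person class>")))))

    result = jvm_path(fake_jvm, "com.foo.storage.schema.Person")
    print(result)
    ```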

    Alternative solution

    I found there is one more way to do this, but I like the first one better. You can write a generic Scala function that infers the type from its argument and uses that to build the schema:

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.types.StructType
    import scala.reflect.runtime.universe.TypeTag

    object SchemaConverter {
      def getSchemaFromType[T <: Product: TypeTag](obj: T): StructType = {
        Encoders.product[T].schema
      }
    }
    

    Which can be called like this:

    val schema = SchemaConverter.getSchemaFromType(Person("Joe", 42))
    

    I don't like this method as much, since it requires creating a dummy instance of the case class. I haven't tested it, but I believe this function could be called through Py4J as well.