
Nested Scala case classes to/from CSV


There are many nice libraries for writing/reading Scala case classes to/from CSV files. I'm looking for something that goes beyond that and can handle nested case classes. For example, here a Match has two Players:

case class Player(name: String, ranking: Int)
case class Match(place: String, winner: Player, loser: Player)

val matches = List(
  Match("London", Player("Jane",7), Player("Fred",23)),
  Match("Rome", Player("Marco",19), Player("Giulia",3)),
  Match("Paris", Player("Isabelle",2), Player("Julien",5))
)

I'd like to effortlessly (no boilerplate!) write/read matches to/from this CSV:

place,winner.name,winner.ranking,loser.name,loser.ranking
London,Jane,7,Fred,23
Rome,Marco,19,Giulia,3
Paris,Isabelle,2,Julien,5

Note the automated header line using the dot "." to form the column name for a nested field, e.g. winner.ranking. I'd be delighted if someone could demonstrate a simple way to do this (say, using reflection or Shapeless).

[Motivation. During data analysis it's convenient to have a flat CSV to play around with, for sorting, filtering, etc., even when case classes are nested. And it would be nice if you could load nested case classes back from such files.]


Solution

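  • Since a case class is a Product, getting the values of its fields is relatively easy (via productIterator). Getting the names of the fields/columns from the class alone requires Java reflection. Note that this relies on getDeclaredFields returning the fields in declaration order, which its Javadoc does not guarantee, although HotSpot usually does so in practice. The following function takes a list of case-class instances and returns a list of rows, each of which is a list of strings, with the header row first. It uses recursion to get the values and headers of nested case-class instances.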

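    import scala.reflect.ClassTag

    def toCsv[T <: Product](p: List[T])(implicit tag: ClassTag[T]): List[List[String]] = {
      // One column name per leaf field; fields of nested case classes get a "parent." prefix
      def header(c: Class[_], prefix: String = ""): List[String] = {
        c.getDeclaredFields.toList.flatMap { field =>
          val name = prefix + field.getName
          if (classOf[Product].isAssignableFrom(field.getType)) header(field.getType, name + ".")
          else List(name)
        }
      }
    
      // Values of one instance, in the same depth-first order as the header
      def flatten(p: Product): List[String] =
        p.productIterator.flatMap {
          case nested: Product => flatten(nested)
          case v: Any => List(v.toString)
        }.toList
    
      // Use the element type's class instead of a hard-coded classOf[Match]
      header(tag.runtimeClass) :: p.map(flatten)
    }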
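    On Scala 2.13 or later, the column names can also be obtained without Java reflection: Product.productElementNames returns the field names of an instance. The following is a minimal sketch under that assumption (the object name ProductNamesCsv and its helpers are illustrative, not part of any library). It derives the header from the first row instead of from the class, so it does not depend on the order of getDeclaredFields; the trade-off is that it needs at least one instance, and a field whose runtime value is some other Product (an Option, say) is expanded according to that particular value.

```scala
// Requires Scala 2.13+ (Product.productElementNames).
object ProductNamesCsv {
  case class Player(name: String, ranking: Int)
  case class Match(place: String, winner: Player, loser: Player)

  // Dotted column names taken from an instance's field names
  def header(p: Product, prefix: String = ""): List[String] =
    p.productElementNames.zip(p.productIterator).toList.flatMap {
      case (name, nested: Product) => header(nested, prefix + name + ".")
      case (name, _)               => List(prefix + name)
    }

  // Leaf values in the same depth-first order as header
  def values(p: Product): List[String] =
    p.productIterator.toList.flatMap {
      case nested: Product => values(nested)
      case v               => List(v.toString)
    }

  def toCsv(rows: List[Product]): List[List[String]] = rows match {
    case Nil        => Nil // no instance to take the column names from
    case first :: _ => header(first) :: rows.map(values)
  }

  def main(args: Array[String]): Unit = {
    val matches = List(
      Match("London", Player("Jane", 7), Player("Fred", 23)),
      Match("Rome", Player("Marco", 19), Player("Giulia", 3))
    )
    toCsv(matches).map(_.mkString(",")).foreach(println)
    // place,winner.name,winner.ranking,loser.name,loser.ranking
    // London,Jane,7,Fred,23
    // Rome,Marco,19,Giulia,3
  }
}
```

    The header produced this way matches the one in the question.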
    

    Constructing case classes from CSV, however, is far more involved: it requires reflection to get the types of the various fields, to create the values from the CSV strings, and to construct the case-class instances. For simplicity (not saying the code is simple, just so it won't get even more complicated), I assume that the columns in the CSV are in the same order as in a file produced by the toCsv(...) function above. The following function starts by creating a list of "instructions on how to process a single CSV row" (the instructions are also used to verify that the column headers in the CSV match the case-class properties). The instructions are then used to recursively build one case-class instance per CSV row.

    import scala.reflect.ClassTag

    def fromCsv[T <: Product](csv: List[List[String]])(implicit tag: ClassTag[T]): List[T] = {
      // Steps for walking one CSV row; header = true means the step consumes a column
      trait Instruction {
        val name: String
        def header: Boolean = true
      }
      case class BeginCaseClassField(name: String, clazz: Class[_]) extends Instruction {
        override def header = false
      }
      case class EndCaseClassField(name: String) extends Instruction {
        override def header = false
      }
      case class IntField(name: String) extends Instruction
      case class StringField(name: String) extends Instruction
      case class DoubleField(name: String) extends Instruction
    
      // Depth-first walk over the declared fields, in the same order as toCsv's header
      def scan(c: Class[_], prefix: String = ""): List[Instruction] = {
        c.getDeclaredFields.toList.flatMap { field =>
          val name = prefix + field.getName
          val fType = field.getType
    
          if (fType == classOf[Int]) List(IntField(name))
          else if (fType == classOf[Double]) List(DoubleField(name))
          else if (fType == classOf[String]) List(StringField(name))
          else if (classOf[Product].isAssignableFrom(fType)) BeginCaseClassField(name, fType) :: scan(fType, name + ".")
          else throw new IllegalArgumentException(s"Unsupported field type: $fType")
        } :+ EndCaseClassField(prefix)
      }
    
      // Consumes cells from the row while accumulating constructor arguments; a nested
      // case class is built from the arguments collected between its Begin and End steps
      def produce(instructions: List[Instruction], row: List[String], argAccumulator: List[Any]): (List[Instruction], List[String], List[Any]) = instructions match {
        case IntField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toInt)
        case StringField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head)
        case DoubleField(_) :: tail => produce(tail, row.drop(1), argAccumulator :+ row.head.toDouble)
        case BeginCaseClassField(_, clazz) :: tail =>
          val (instructionRemaining, rowRemaining, constructorArgs) = produce(tail, row, List.empty)
          // Int/Double arguments are boxed here and unboxed by the reflective call
          val newCaseClass = clazz.getConstructors.head.newInstance(constructorArgs.map(_.asInstanceOf[AnyRef]): _*)
          produce(instructionRemaining, rowRemaining, argAccumulator :+ newCaseClass)
        case EndCaseClassField(_) :: tail => (tail, row, argAccumulator)
        case Nil if row.isEmpty => (Nil, Nil, argAccumulator)
        case Nil => throw new IllegalArgumentException("Not all values from the CSV row were used")
      }
    
      // The top-level class is wrapped in Begin/End just like a nested field
      val instructions = BeginCaseClassField(".", tag.runtimeClass) :: scan(tag.runtimeClass)
      require(csv.head == instructions.filter(_.header).map(_.name), "CSV header doesn't match target case-class fields")
    
      csv.drop(1).map(row => produce(instructions, row, List.empty)._3.head.asInstanceOf[T])
    }
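    The core reflective step in produce is calling a case class's constructor with arguments converted from CSV strings. A minimal standalone sketch of just that step follows. It assumes Player is defined at top level or inside a top-level object (for a class nested inside another class, the constructor would also expect an outer instance); convert and build are hypothetical helper names, and only Int, Double and String parameters are handled.

```scala
object ReflectiveConstruction {
  case class Player(name: String, ranking: Int, price: Double)

  // Hypothetical helper: turn one CSV cell into a boxed value of the
  // constructor parameter's type (only Int, Double and String supported)
  def convert(paramType: Class[_], cell: String): AnyRef =
    if (paramType == classOf[Int]) Int.box(cell.toInt)
    else if (paramType == classOf[Double]) Double.box(cell.toDouble)
    else if (paramType == classOf[String]) cell
    else throw new IllegalArgumentException(s"Unsupported parameter type: $paramType")

  // Hypothetical helper: call the (single) public constructor reflectively
  def build[T](clazz: Class[T], cells: List[String]): T = {
    val ctor = clazz.getConstructors.head // a case class normally has one public constructor
    val params = ctor.getParameterTypes.toList
    require(params.length == cells.length, "cell count must match constructor arity")
    val args = params.zip(cells).map { case (t, cell) => convert(t, cell) }
    ctor.newInstance(args: _*).asInstanceOf[T]
  }
}
```

    For example, build(classOf[Player], List("Jane", "7", "12.5")) should yield Player("Jane", 7, 12.5); the boxed Integer and Double are unboxed by the reflective call to match the primitive int and double parameters.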
    

    I've tested this using:

    case class Player(name: String, ranking: Int, price: Double)
    case class Match(place: String, winner: Player, loser: Player)
    
    val matches = List(
      Match("London", Player("Jane", 7, 12.5), Player("Fred", 23, 11.1)),
      Match("Rome", Player("Marco", 19, 13.54), Player("Giulia", 3, 41.8)),
      Match("Paris", Player("Isabelle", 2, 31.7), Player("Julien", 5, 16.8))
    )
    val csv = toCsv(matches)
    val matchesFromCsv = fromCsv[Match](csv)
    
    assert(matches == matchesFromCsv)
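    toCsv and fromCsv work on rows of type List[List[String]] rather than on CSV text like the one in the question. A naive sketch of converting between the two, assuming no cell contains a comma, double quote or line break (no RFC 4180 quoting is done, so a CSV library that implements quoting is safer for real data); the CsvText object and its methods are illustrative names:

```scala
object CsvText {
  // Joins cells with commas and rows with newlines; no quoting or escaping
  def render(rows: List[List[String]]): String =
    rows.map(_.mkString(",")).mkString("\n")

  // Splits on newlines and commas; limit -1 keeps trailing empty cells
  def parse(text: String): List[List[String]] =
    text.linesIterator.filter(_.nonEmpty).map(_.split(",", -1).toList).toList
}
```

    With it, a full round trip through text would look like fromCsv[Match](CsvText.parse(CsvText.render(toCsv(matches)))).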
    

    Obviously this should be optimized and hardened if you ever want to use this for production...