
How to collect field values from case classes in scala


I have several case classes with different fields that extend a common trait, and they are all mixed together in one List. How can I collect (or group by) the values of a specific field?

sealed trait Template

object Template {
  case class TemplateA(field: String) extends Template
  case class TemplateB() extends Template // case classes need a parameter list
}

object Runner {
  import Template._

  def main(args: Array[String]): Unit = {
    val list = List(TemplateA("abc"), TemplateB(), TemplateA("cde"))

    // need to output something like "abc;1", "cde;1"

  }
}
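For reference, here is a minimal sketch of one way to get that output with the question's original String-typed field. It assumes Scala 2.13+ (for `groupMapReduce`) and that only `TemplateA` carries the field. The object name `FieldCountExample` and the sorted `"value;count"` formatting are illustrative choices, not part of the question.

```scala
// Sketch: count TemplateA.field values and format them as "value;count".
// Assumes Scala 2.13+ (groupMapReduce); TemplateB has no field and is skipped.
sealed trait Template
final case class TemplateA(field: String) extends Template
final case class TemplateB() extends Template

object FieldCountExample {
  val list: List[Template] = List(TemplateA("abc"), TemplateB(), TemplateA("cde"))

  // The partial function only matches TemplateA, so TemplateB is dropped.
  val fields: List[String] = list.collect { case TemplateA(field) => field }

  // Count how many times each field value occurs.
  val counts: Map[String, Int] = fields.groupMapReduce(identity)(_ => 1)(_ + _)

  // Format as "value;count", sorted by value so the output order is stable.
  val lines: List[String] = counts.toList.sorted.map { case (value, n) => s"$value;$n" }

  def main(args: Array[String]): Unit = lines.foreach(println)
}
```

Pattern matching in `collect` does the type test and the extraction in one step, so no common field on the trait is needed for this particular case.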


Solution

  • I totally agree with @LuisMiguel; just to show one way of doing this, here's what I can think of:

    trait Template { val field: Option[String] } 
    case class TemplateA(field: Option[String]) extends Template 
    case class TemplateB() extends Template { override val field: Option[String] = None }
    
    val list: List[Template] = List(
      TemplateA(Some("abc")),
      TemplateB(),
      TemplateA(Some("cde"))
    )
    
    // groupMapReduce is available from Scala 2.13 onwards
    list.collect { 
      case template if template.field.nonEmpty =>
        template.field.get
    }.groupMapReduce(identity)(_ => 1)(_ + _)  
    
    // res8: Map[String, Int] = Map("abc" -> 1, "cde" -> 1)
    

    Or, if you want to avoid passing an Option when instantiating TemplateA, you can do this instead:

    case class TemplateA(private val value: String) extends Template {
      override val field: Option[String] = Option(value)
    }
    
    val list: List[Template] = List(TemplateA("abc"), TemplateB(), TemplateA("cde"))
    

    As @DmytroMitin mentioned, a bit of refactoring lets us avoid the if guard in the collect call. I'd rather use an unapply method that extracts the field value of TemplateA instances:

    object Template { // or any name as you wish
      // Extracts the field of TemplateA instances; works with either TemplateA definition above
      def unapply(t: Template): Option[String] = t match {
        case a: TemplateA => a.field
        case _ => None
      }
    } 
    

    And then, we can use pattern matching:

    list.collect {
      case Template(field) => field
    }.groupMapReduce(identity)(_ => 1)(_ + _)
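`groupMapReduce` was added in Scala 2.13. On Scala 2.12 or earlier, a sketch of the same counting step can use `groupBy` instead. It also uses `flatMap(_.field)`, which drops the `None`s and so replaces the `if`/`.get` step, relying on the Option-based `Template` from the first snippet. The object name `LegacyCounts` and the extra `"abc"` element are illustrative assumptions.

```scala
// Sketch for Scala 2.12 and earlier, where groupMapReduce is unavailable.
// Reuses the Option-based design from the answer: every Template exposes `field`.
trait Template { val field: Option[String] }
case class TemplateA(field: Option[String]) extends Template
case class TemplateB() extends Template { override val field: Option[String] = None }

object LegacyCounts {
  val list: List[Template] =
    List(TemplateA(Some("abc")), TemplateB(), TemplateA(Some("cde")), TemplateA(Some("abc")))

  // flatMap over Option keeps the Some values and drops the Nones.
  val fields: List[String] = list.flatMap(_.field)

  // groupBy builds Map[String, List[String]]; each group's size is the count.
  val counts: Map[String, Int] =
    fields.groupBy(identity).map { case (value, group) => value -> group.size }

  def main(args: Array[String]): Unit = println(counts)
}
```

The `groupBy` version builds the intermediate lists, whereas `groupMapReduce` folds counts directly, so prefer the latter when Scala 2.13+ is available.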