Tags: scala, pattern-matching, case-class, scala-3

Pattern matching case classes, all cases have same return


I'm new to Scala, and the only way I've seen people access case class fields is with pattern matching.

I recently solved a Huffman coding problem with the following code:


  abstract class TreeNodes 
  case class Leaf(weight: Int, value: String) extends TreeNodes
  case class Node(weights: Int, left: TreeNodes, right: TreeNodes) extends TreeNodes

  // P50
  def huffman(freq: List[(String,Int)]): List[(String,String)] = {
    def makeNode(left: TreeNodes, right: TreeNodes): Node = 
      (left,right) match 
        case (Leaf(w1,_),Leaf(w2,_)) => Node(w1 + w2, left,right)
        case (Node(w1,_,_),Leaf(w2,_)) => Node(w1 + w2, left ,right)
        case (Leaf(w1,_), Node(w2,_,_)) => Node(w1 + w2, left, right)
        case (Node(w1,_,_), Node(w2,_,_)) => Node(w1 + w2, left, right)

    def makeLeaves(freq: List[(String,Int)]): List[TreeNodes] = freq.map((s: String, i: Int) => Leaf(i,s))

    def makeTree(nodes: List[TreeNodes]): List[TreeNodes] = {
      if nodes.size == 1 then nodes else {
        val sortedNodes = nodes.sortBy(_ match {case Leaf(w,_) => w; case Node(w,_,_) => w})
        makeTree(sortedNodes.appended(makeNode(sortedNodes.head,sortedNodes.tail.head)).drop(2))
      }
    }

    def traverseTree(node: TreeNodes, acc: String): List[(String, String)] = {
      node match  
        case Leaf(w, s) => List((s,acc))
        case Node(w, left, right) => traverseTree(left, acc + "0") ::: traverseTree(right, acc + "1")
    }
    traverseTree(makeTree(makeLeaves(freq)).head, "").sortBy(_._1)
  }

The expected input and output look like this:

scala> huffman(List(("a", 45), ("b", 13), ("c", 12), ("d", 16), ("e", 9), ("f", 5)))
res0: List[(String, String)] = List((a,0), (b,101), (c,100), (d,111), (e,1101), (f,1100))

The code is correct and produces the desired output. However, I'd like to refactor makeNode into something cleaner, since all four cases return the same expression, Node(w1 + w2, left, right).

Is there a more concise way to write it? For example, can I declare a field shared by different case classes, or access case class fields some other way than pattern matching?


Solution

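  • Declare the common member in the shared base type: here, add def weight: Int (or val weight: Int) to the base class. Because Node's parameter is named weights in the question, rename it to weight so that the case-class parameter of both Leaf and Node implements the abstract member. Case-class parameters are public vals, so node.weight then works on any TreeNodes value without pattern matching.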

    In Scala 2.x (this version also compiles unchanged in Scala 3):

    sealed abstract class TreeNodes {
      def weight: Int
    }
    case class Leaf(weight: Int, value: String) extends TreeNodes
    case class Node(weight: Int, left: TreeNodes, right: TreeNodes) extends TreeNodes
    
    //...
    
      def makeNode(left: TreeNodes, right: TreeNodes): Node =
        Node(left.weight + right.weight, left,right)
    

    In Scala 3.x you can use an enum instead:

    enum TreeNodes:
      case Leaf(weight: Int, value: String)
      case Node(weight: Int, left: TreeNodes, right: TreeNodes)
      def weight: Int
    
    

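    Because enum cases are members of the TreeNodes companion, add import TreeNodes._ (or import TreeNodes.* in Scala 3 syntax) so that the rest of the code can keep referring to Leaf and Node unqualified, as before.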
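Putting the pieces together, here is a minimal Scala 3 sketch of the whole function using the enum variant, assuming the rest of the question's algorithm is kept as-is. Besides the one-line makeNode, the shared weight member also lets the sortBy in makeTree drop its match and become sortBy(_.weight); that second simplification is an addition of this sketch, not part of the answer above. Top-level definitions and the indentation-based syntax require Scala 3.

```scala
// Sketch (Scala 3): the answer's enum, with a shared abstract `weight` member.
enum TreeNodes:
  case Leaf(weight: Int, value: String)
  case Node(weight: Int, left: TreeNodes, right: TreeNodes)
  // Implemented by the `weight` parameter of both cases
  def weight: Int

// Bring Leaf and Node into scope unqualified
import TreeNodes.*

def huffman(freq: List[(String, Int)]): List[(String, String)] =
  // No match needed: both children expose `weight` through the enum type.
  // The declared result type Node keeps Node(...) from widening to TreeNodes.
  def makeNode(left: TreeNodes, right: TreeNodes): Node =
    Node(left.weight + right.weight, left, right)

  // Parameter untupling: each (symbol, frequency) pair becomes a leaf
  def makeLeaves(freq: List[(String, Int)]): List[TreeNodes] =
    freq.map((s, i) => Leaf(i, s))

  // Repeatedly merge the two lightest nodes until a single root remains
  def makeTree(nodes: List[TreeNodes]): List[TreeNodes] =
    if nodes.size == 1 then nodes
    else
      val sortedNodes = nodes.sortBy(_.weight) // was: _ match { case Leaf(w, _) => w; ... }
      makeTree(sortedNodes.appended(makeNode(sortedNodes.head, sortedNodes.tail.head)).drop(2))

  // Left edges contribute "0", right edges contribute "1"
  def traverseTree(node: TreeNodes, acc: String): List[(String, String)] =
    node match
      case Leaf(_, s)           => List((s, acc))
      case Node(_, left, right) =>
        traverseTree(left, acc + "0") ::: traverseTree(right, acc + "1")

  traverseTree(makeTree(makeLeaves(freq)).head, "").sortBy(_._1)
```

Because List.sortBy is stable and the merge order is unchanged, calling huffman on the question's sample input yields the same result as the original: List((a,0), (b,101), (c,100), (d,111), (e,1101), (f,1100)).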