How to sum up all nodes in this Scala fold_tree function


I am a beginner with Scala. I have been given a fold_tree_preorder function that implements the higher-order function fold on a binary tree. The Tree, Node and Leaf definitions are below:

    abstract class Tree[+A]
    case class Leaf[A](value: A) extends Tree[A]
    case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]
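To experiment with these definitions, a small tree can be built directly from the case classes. The sketch below uses a hypothetical sample tree (the names SampleTree and sample are illustrative, not from the question); printing it shows the case-class structure, which is handy when checking what a fold visits.

```scala
// Tree definitions from the question
abstract class Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

object SampleTree {
  // Hypothetical sample tree used for experimenting:
  //
  //        1
  //       / \
  //      2   3
  //     / \
  //    4   5
  val sample: Tree[Int] =
    Node(1, Node(2, Leaf(4), Leaf(5)), Leaf(3))

  def main(args: Array[String]): Unit =
    println(sample) // Node(1,Node(2,Leaf(4),Leaf(5)),Leaf(3))
}
```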

This is the function I have been given

    def fold_tree_preorder[Z, A](f: (Z, A) => Z)(z: Z)(t: Tree[A]): Z =
      t match {
        case Leaf(value) => f(z, value)
        case Node(value, lt, rt) => {
          val z1 = f(z, value)
          val z2 = fold_tree_preorder(f)(z1)(lt)
          fold_tree_preorder(f)(z2)(rt)
        }
      }
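The fold applies f to the current node's value first, then folds the whole left subtree, then the right subtree (preorder). A minimal sketch that makes this order visible, assuming the same definitions, instantiates Z as List[A] and records each value as it is visited. It also shows that Z (the result type) need not be the same as A (the element type). The names PreorderDemo, visitOrder and sample are hypothetical.

```scala
abstract class Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

object PreorderDemo {
  // The fold from the question, unchanged apart from formatting
  def fold_tree_preorder[Z, A](f: (Z, A) => Z)(z: Z)(t: Tree[A]): Z =
    t match {
      case Leaf(value) => f(z, value)
      case Node(value, lt, rt) =>
        val z1 = f(z, value)                    // visit this node first
        val z2 = fold_tree_preorder(f)(z1)(lt)  // then the left subtree
        fold_tree_preorder(f)(z2)(rt)           // then the right subtree
    }

  // Hypothetical helper: Z = List[A]; prepend each visited value,
  // then reverse so the list reads in visiting order
  def visitOrder[A](t: Tree[A]): List[A] =
    fold_tree_preorder[List[A], A]((acc, a) => a :: acc)(Nil)(t).reverse

  val sample: Tree[Int] = Node(1, Node(2, Leaf(4), Leaf(5)), Leaf(3))

  def main(args: Array[String]): Unit =
    println(visitOrder(sample)) // List(1, 2, 4, 5, 3)
}
```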

I am not sure how to actually call this function. I am trying to do something like the following:

    def count_tree [A](t:Tree[A]) : Int =
        fold_tree_preorder[A,A=>A]((z,a)=>(z+a))(0)(t)

But I am getting errors such as a type mismatch. I don't think the parameters themselves are correct either, but I can't even test what the output looks like, because I can't work out the correct way to call fold_tree_preorder. What is the correct syntax for calling this function?


Solution

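  • Z in fold_tree_preorder is the type of the result you are accumulating, which here is Int, and A is the type of the values stored in the tree. Your call passes [A, A=>A], which makes the result type A and declares the tree's elements to be functions of type A=>A, so neither the initial value 0 (an Int) nor t (a Tree[A]) has the type the compiler expects.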

    Assuming that count_tree is meant to count the nodes of the tree, call the function like this:

    def count_tree [A](t:Tree[A]) : Int =
        fold_tree_preorder[Int, A]((z,a) => z + 1 )(0)(t)
    

    Just add 1 to z each time a node is visited. Since the fold visits every Node and every Leaf exactly once, the final value is the number of nodes in the tree.
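The question title asks about summing rather than counting. As a sketch, assuming the tree holds Int values, the same call pattern sums the values when both Z and A are Int and f adds the visited value instead of 1. The names SumDemo, sum_tree and sample are hypothetical; count_tree is the answer's function, with the unused parameter written as _.

```scala
abstract class Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Node[A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

object SumDemo {
  // The fold from the question
  def fold_tree_preorder[Z, A](f: (Z, A) => Z)(z: Z)(t: Tree[A]): Z =
    t match {
      case Leaf(value) => f(z, value)
      case Node(value, lt, rt) =>
        val z1 = f(z, value)
        val z2 = fold_tree_preorder(f)(z1)(lt)
        fold_tree_preorder(f)(z2)(rt)
    }

  // Counting: works for any element type A; the visited value is ignored
  def count_tree[A](t: Tree[A]): Int =
    fold_tree_preorder[Int, A]((z, _) => z + 1)(0)(t)

  // Hypothetical summing variant: only meaningful for Int elements,
  // so A is fixed to Int and each visited value is added to z
  def sum_tree(t: Tree[Int]): Int =
    fold_tree_preorder[Int, Int]((z, a) => z + a)(0)(t)

  val sample: Tree[Int] = Node(1, Node(2, Leaf(4), Leaf(5)), Leaf(3))

  def main(args: Array[String]): Unit = {
    println(count_tree(sample)) // 5
    println(sum_tree(sample))   // 15
  }
}
```

The only differences between the two functions are the type arguments and what f adds to the accumulator: 1 per node for counting, the node's value for summing.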