
Verifying size of binary trees?


I have a datatype defined like this:

datatype 'a bin_tree = 
    Leaf of 'a
  | Node of 'a bin_tree    (* left tree *)
           * int           (* size of left tree *)
           * int           (* size of right tree *)
           * 'a bin_tree   (* right tree *)

So an example of a correct tree would be:

val tree1 =
  Node(Node(Node(Leaf 47, 1, 1, Leaf 38),
            2,1,
            Leaf 55),
       3,2,
       Node(Leaf 27, 1, 1, Leaf 96))

and an example of a violating tree would be:

val tree1false =
  Node(Node(Node(Leaf 47, 1, 1, Leaf 38),
            2,1,
            Leaf 55),
       4,2,
       Node(Leaf 27, 1, 1, Leaf 96))

How can I write a predicate test such that:

- test tree1;
val it = true : bool
- test tree1false;
val it = false : bool

Solution

  • This is a recursive problem. Before solving recursive problems on trees, it is a good idea to have a firm grasp of recursion on lists. You could say that trees are generalisations of lists, or that lists are special cases of trees: lists have one tail, while trees can have any number of tails depending on the type of tree (a binary tree has two). So here is how you could reconstruct and solve the problem using lists:

    If, instead of the typical list definition, you have a list that also memoizes its own length:

    (* datatype 'a list = [] | :: of 'a * 'a list *)
    
    datatype 'a lenlist = Nil | Cons of int * 'a * 'a lenlist
    

    Then you can test that the stored length is in accordance with the actual number of values.
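
    For illustration, here is a minimal sketch of how lenlist values could be built so that the stored length is always consistent. The helpers len, lcons and fromList are hypothetical names introduced here; they are not part of the question or of the rest of this answer.

```sml
datatype 'a lenlist = Nil | Cons of int * 'a * 'a lenlist

(* Hypothetical helper: the length advertised at the head (0 for Nil) *)
fun len Nil = 0
  | len (Cons (n, _, _)) = n

(* Hypothetical smart constructor: derives the stored length from the
   tail, so every cell it builds is consistent *)
fun lcons (x, xs) = Cons (len xs + 1, x, xs)

(* Build a lenlist from a built-in list, one lcons per element *)
fun fromList xs = List.foldr lcons Nil xs
```

    For example, fromList [#"c", #"o", #"o", #"l"] produces the same value as goodList below. A value built directly with Cons, like badList below, bypasses lcons, which is why a separate test is still needed.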

    I'll start by creating a function that counts the elements, to illustrate the part of the function that performs the recursion:

    (* For regular built-in lists *)
    fun count0 [] = 0
      | count0 (x::xs) = 1 + count0 xs
    
    (* Counting the memoized list type disregarding the n *)
    fun count1 Nil = 0
      | count1 (Cons (n, x, xs)) = 1 + count1 xs
    

    Next, I'd like each recursive step to also test that the stored number n agrees with the actual count. What should the return type of this function be? The test function that you want has type 'a lenlist -> bool, while the count function that I made has type 'a lenlist -> int.
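
    Before combining the two, note that they can simply be composed: a test can call count1 at every cell and compare the result with the stored n. A minimal sketch (testNaive is a hypothetical name) shows that this works, but it recounts the tail at each step and therefore takes quadratic time:

```sml
datatype 'a lenlist = Nil | Cons of int * 'a * 'a lenlist

(* count1 as above: ignores the stored n and counts the cells *)
fun count1 Nil = 0
  | count1 (Cons (_, _, xs)) = 1 + count1 xs

(* Hypothetical naive test: at every cell, recount the tail with count1
   and compare it with the stored length. Correct, but O(n^2) overall. *)
fun testNaive Nil = true
  | testNaive (Cons (n, _, xs)) =
      n = count1 xs + 1 andalso testNaive xs
```

    The types line up (count1 : 'a lenlist -> int is used inside testNaive : 'a lenlist -> bool), but the repeated counting is what motivates a single function that does both jobs in one pass.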

    I suggest that you make a testcount function that does both. You can do so in several ways, e.g. by giving it extra arguments or by giving it extra return values. I will demonstrate both, to show that sometimes one is better than the other; experience will tell you which.

    Here is a function testcount1 : 'a lenlist -> bool * int:

    fun testcount1 Nil = (true, 0)
      | testcount1 (Cons (n, x, xs)) =
        let val (good_so_far, m) = testcount1 xs
            val still_good = good_so_far andalso n = m + 1
        in (still_good, m + 1)
        end
    
    val goodList = Cons (4, #"c", Cons (3, #"o", Cons (2, #"o", Cons (1, #"l", Nil))))
    val badList = Cons (3, #"d", Cons (2, #"e", Cons (1, #"r", Cons (0, #"p", Nil))))
    

    Testing this,

    - testcount1 goodList;
    > val it = (true, 4) : bool * int
    - testcount1 badList;
    > val it = (false, 4) : bool * int
    

    This shows that testcount1 returns both whether the numbers add up and the list's actual length. The length was necessary during the recursion, to test that the numbers add up at each step, but in the end it is no longer needed. You could wrap testcount1 in a simpler test function that only cares about the bool:

    fun test xs = #1 (testcount1 xs)
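
    The extra return value does not have to be a pair. As a variation that is not from the original answer (lengthIfValid is a hypothetical name), the bool and the int can be merged into an int option, where NONE means that some stored length was wrong:

```sml
datatype 'a lenlist = Nil | Cons of int * 'a * 'a lenlist

(* SOME n: the list is consistent and has length n.
   NONE:   some stored length disagrees with the actual count. *)
fun lengthIfValid Nil = SOME 0
  | lengthIfValid (Cons (n, _, xs)) =
      (case lengthIfValid xs of
           SOME m => if n = m + 1 then SOME n else NONE
         | NONE => NONE)

(* Only care whether a length could be confirmed *)
fun test xs = isSome (lengthIfValid xs)
```

    Once a NONE appears it is passed upwards without further arithmetic, but the recursion still descends to Nil before any check happens, so the whole list is always traversed.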
    

    Here is another way to go about it. There is something unsatisfying about the way testcount1 recurses: it keeps calculating the m + 1s even after it has determined that a list is broken (e.g. at Cons (0, #"p", Nil)).

    Here is an alternative function testcount2 : 'a lenlist * int -> bool that carries the expected length in an extra argument instead:

    fun testcount2 (Nil, 0) = true
      | testcount2 (Nil, _) = false
      | testcount2 (Cons (n, x, xs), m) =
          n = m andalso testcount2 (xs, m - 1)
    

    This seems a lot neater to me: the function is tail-recursive, and it stops as soon as it detects that something is fishy, so it doesn't need to traverse the entire list if the list is broken at the head. The downside is that it needs to know the length in advance, which we don't know yet. But we can compensate by assuming that the length advertised at the head is correct and letting the recursion confirm or refute that.

    To test this function, you need to give it not only goodList or badList but also an m:

    - testcount2 (goodList, 4);
    > val it = true : bool
    - testcount2 (badList, 4);
    > val it = false : bool
    - testcount2 (badList, 3);
    > val it = false : bool
    

    It's important that this function doesn't just compare n = m in each Cons case. For badList with m = 3, that would wrongly agree that badList is 3 elements long, since n = m holds in every Cons case. The two Nil cases prevent this by requiring that the counter has reached exactly 0 at the end of the list, and not e.g. ~1 as is the case for badList.
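
    To make this concrete, here is a deliberately broken variant (testcountNoNil is a hypothetical name) that accepts any Nil regardless of the counter, next to testcount2 for comparison:

```sml
datatype 'a lenlist = Nil | Cons of int * 'a * 'a lenlist

(* Hypothetical, deliberately broken: never checks that m reached 0 *)
fun testcountNoNil (Nil, _) = true
  | testcountNoNil (Cons (n, _, xs), m) =
      n = m andalso testcountNoNil (xs, m - 1)

(* testcount2 as above: the Nil cases insist on a counter of exactly 0 *)
fun testcount2 (Nil, 0) = true
  | testcount2 (Nil, _) = false
  | testcount2 (Cons (n, _, xs), m) =
      n = m andalso testcount2 (xs, m - 1)

val badList = Cons (3, #"d", Cons (2, #"e", Cons (1, #"r", Cons (0, #"p", Nil))))
```

    Here testcountNoNil (badList, 3) evaluates to true, because it reaches Nil with m = ~1 and never looks at m, while testcount2 (badList, 3) evaluates to false.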

    This function can also be wrapped inside test to hide the fact that we feed it an extra argument derived from the 'a lenlist itself:

    fun size Nil = 0
      | size (Cons (n, _, _)) = n
    
    fun test xs = testcount2 (xs, size xs)
    

    Some morals so far:

    • Sometimes it is necessary to create helper functions to solve your initial problem.
    • Those helper functions are not restricted to having the same type signature as the function you deliver (whether this is for an exercise in school, or for an external API/library).
    • Sometimes it helps to extend the type that your function returns.
    • Sometimes it helps to extend the arguments of your functions.
    • Just because your task is "Write a function foo -> bar", this does not mean that you cannot build it by composing functions that take or return a great deal more or less than foo or bar.

    Now for some hints on solving this for binary trees:

    Repeating the data type,

    datatype 'a bin_tree = 
        Leaf of 'a
      | Node of 'a bin_tree    (* left tree *)
               * int           (* size of left tree *)
               * int           (* size of right tree *)
               * 'a bin_tree   (* right tree *)
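
    As a warm-up that mirrors count1, here is a sketch that counts the leaves of a tree while ignoring the stored sizes. It assumes, as tree1 suggests, that the size of a tree is its number of leaves; countLeaves is a hypothetical name.

```sml
datatype 'a bin_tree =
    Leaf of 'a
  | Node of 'a bin_tree * int * int * 'a bin_tree

(* Counts leaves and disregards leftC and rightC, just like count1
   disregards the n stored in each Cons cell *)
fun countLeaves (Leaf _) = 1
  | countLeaves (Node (left, _, _, right)) =
      countLeaves left + countLeaves right
```

    Where count1 has one recursive call per cell, countLeaves has two, one per subtree; that is the "any number of tails" difference between lists and trees.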
    

    You can start by constructing a skeleton for your function based on the ideas above:

    fun testcount3 (Leaf x, ...) = ...
      | testcount3 (Leaf x, ...) = ...
      | testcount3 (Node (left, leftC, rightC, right), ...) = ...
    

    I've included some hints here:

    • Your solution should probably contain pattern matches against Leaf x and Node (left, leftC, rightC, right). The "extra argument" style of solution (which at least proved nice for lists, but we'll see) needed two Nil cases for lists, which is why the skeleton has two Leaf x cases. Why was that?
    • If, in the case of lists, the "extra argument" m represents the expected length of the list, then what would an "extra argument" represent in the case of trees? You can't say "it's the length of the list", since it's a tree. How do you capture the part where a tree branches?
    • If this is still too difficult, consider solving the problem for lists without copy-pasting. That is, you're allowed to look at the solutions in this answer (but try not to), but you're not allowed to copy-paste code. You have to type it if you want to copy it.
    • As a start, try to define the helper function size that was used to produce test from testcount2, but for trees. Maybe call it sizeTree to avoid the name overlap, but other than that, try to make it resemble size. Here's a skeleton:
    fun sizeTree (Leaf x) = ...
      | sizeTree (Node (left, leftC, rightC, right)) = ...
    

    Once testcount3 and sizeTree are written, sticking them together should be easy as tau.
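
    For comparison after trying the exercise yourself, here is one possible sketch in the "extra argument" style. It assumes that a tree's size is its number of leaves (as in tree1) and that the extra argument is the size the parent advertises for the subtree; it is an illustration, not necessarily the intended solution.

```sml
datatype 'a bin_tree =
    Leaf of 'a
  | Node of 'a bin_tree * int * int * 'a bin_tree

(* Size advertised at the root; a leaf counts as 1 *)
fun sizeTree (Leaf _) = 1
  | sizeTree (Node (_, leftC, rightC, _)) = leftC + rightC

(* testcount3 (t, m): does t have the advertised size m, and do all
   stored sizes inside t agree with the actual leaf counts? *)
fun testcount3 (Leaf _, 1) = true
  | testcount3 (Leaf _, _) = false
  | testcount3 (Node (left, leftC, rightC, right), m) =
      leftC + rightC = m
      andalso testcount3 (left, leftC)
      andalso testcount3 (right, rightC)

(* Seed the check with the size the root itself advertises *)
fun test t = testcount3 (t, sizeTree t)
```

    With tree1 and tree1false from the question, test tree1 evaluates to true and test tree1false to false (the left subtree advertises 2 + 1 = 3, not the 4 its parent claims). Unlike testcount2, this is not fully tail-recursive: only the last recursive call is in tail position, because a node has to check both subtrees.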