
Haskell: ternary tree average, with nested `where`


I am trying to compute the average of the values in a ternary tree. It does not seem possible to do this within a single function. Is there a way to solve this, or is it necessary to use two functions? Thanks.

-- define a tree
data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)

-- get the Ternary tree and return the average
treeAverage :: Ttree Double -> Double
treeAverage Nil = 0.0    -- empty tree
treeAverage tree = treeAverage' tree (0.0) 
           -- try to use accumulator and another function 
  where
    treeAverage' Nil _ =  0.0    -- empty tree
    treeAverage' (Node3 n left mid right) (sum/count) = 
        ((n+sumL+sumM+sumR) / (1+countL+countM+countR))  -- average
      where
        (sumL,countL) = treeAverage' left 
            -- calculate left subtree with sum and count
        (sumM,countM) = treeAverage' mid 
        (sumR,countR) = treeAverage' right

Solution

  • In order to compute an average value, you have to perform a single division at the very end of the process, something like (allSum / allCount). As the division cannot be part of the recursive tree traversal process, it seems difficult to achieve what you want within a single function.

    Let's start by providing a little fix for your code. It is unclear whether your auxiliary treeAverage' function returns a pair or a single numeric value. We can rewrite the whole thing like this, where unambiguously it returns a pair:

    -- define a tree structure:
    data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
                      deriving (Eq, Show)
    
    treeAverage1 :: Ttree Double -> Double
    treeAverage1 Nil = 0.0 -- empty tree
    treeAverage1 tree =
      let   (sum1, count1) = treeAverage' tree
      in    sum1 / count1
        where
          treeAverage'  Nil  =  (0,0) -- empty tree
          treeAverage'  (Node3 n left mid right) =
              let  (sumL,countL) = treeAverage' left   -- calculate left subtree
                   (sumM,countM) = treeAverage' mid 
                   (sumR,countR) = treeAverage' right
              in
                  ((n+sumL+sumM+sumR) , (1+countL+countM+countR)) -- (sum, count)
    

    and that code appears to work:

    $ ghci
     GHCi, version 8.8.4: https://www.haskell.org/ghc/  :? for help
     λ> 
     λ> :load  q67816203.hs
     Ok, one module loaded.
     λ> 
     λ> leaf x = Node3 x Nil Nil Nil
     λ> 
     λ> tr0 = Node3 4 (leaf 5) (leaf 6) (leaf 7) :: Ttree Double
     λ> tr1 = Node3 2 (leaf 10) tr0 (leaf 15)
     λ> 
     λ> tr1
     Node3 2.0 (Node3 10.0 Nil Nil Nil) (Node3 4.0 (Node3 5.0 Nil Nil Nil) (Node3 6.0 Nil Nil Nil) (Node3 7.0 Nil Nil Nil)) (Node3 15.0 Nil Nil Nil)
     λ> 
     λ> treeAverage1 tr1
     7.0
     λ> 
    
    

    However, in this code, tree traversal is inextricably intertwined with arithmetic.

    Decoupling ...

    The common Haskell practice would be to improve matters by delegating tree traversal to general-purpose functions, that is, functions we (or the language library) would provide anyway in order to support our tree structure, regardless of any numeric concerns.

    About plain lists ...

    At that point, we can look at a simpler problem: how do we compute an average for a plain list of numbers?

    As mentioned in a comment by chepner, you can use:

    listAverage xs = sum xs / fromIntegral (length xs)  -- length returns an Int, so convert it before dividing
    

    We could adapt this approach to Ttree objects, coming up with treeSum and treeLeafCount functions. But that would be suboptimal. On modern hardware, memory access is far more expensive than arithmetic, and listAverage needlessly traverses the list twice.
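
    For comparison, the two-pass variant mentioned above might look like the following sketch. The names treeSum and treeLeafCount come from the text, but their definitions here are an assumption, and treeAverage2 is a hypothetical name. treeLeafCount is taken to count every node that holds a value (not only the leaves), since that is what the average needs.

    ```haskell
    data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)

    -- First traversal: add up all the node values.
    treeSum :: Ttree Double -> Double
    treeSum Nil             = 0
    treeSum (Node3 v l m r) = v + treeSum l + treeSum m + treeSum r

    -- Second traversal: count the nodes that hold a value.
    treeLeafCount :: Ttree t -> Int
    treeLeafCount Nil             = 0
    treeLeafCount (Node3 _ l m r) =
        1 + treeLeafCount l + treeLeafCount m + treeLeafCount r

    -- Hypothetical name. Walks the whole tree twice;
    -- an empty tree gives NaN (0/0).
    treeAverage2 :: Ttree Double -> Double
    treeAverage2 tr = treeSum tr / fromIntegral (treeLeafCount tr)
    ```

    It gives the right answer, but every node is visited twice, which is what the fold-based version avoids.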

    How do we get to traverse the list just once? Well, computing an average is a fold operation, that is, you traverse a complex structure in order to produce a single value. See the classic paper by Graham Hutton about the merits of fold operations.

    Lists have an instance of the Foldable class mentioned in the comment by chepner. So the library provides, among other things, a foldr function for lists:

    foldr :: (a -> b -> b) -> b -> [a] -> b
    

    The first argument of foldr is a combining function, which takes a value from the input list and an accumulator value, in that order, and returns an updated accumulator value. The second argument is the initial value of the accumulator.
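
    To make the argument order concrete, here is a small sketch (keepOrder and reverseOrder are made-up names for this illustration). It relies only on the standard identity foldr f z [a,b,c] = f a (f b (f c z)): the combining function receives the element first and the accumulator second, and the initial value enters at the right end of the list.

    ```haskell
    -- Hypothetical helper names, used only for illustration.

    -- Consing each element onto the accumulator rebuilds the list unchanged,
    -- because foldr f z [a,b,c] unfolds to f a (f b (f c z)).
    keepOrder :: [Int] -> [Int]
    keepOrder = foldr (\x acc -> x : acc) []

    -- Appending each element to the end of the accumulator reverses the list:
    -- the rightmost element is the first one combined with the initial value.
    reverseOrder :: [Int] -> [Int]
    reverseOrder = foldr (\x acc -> acc ++ [x]) []
    ```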

    So we can write a single-traversal list average like this:

    listAverage :: [Double] -> Double
    listAverage xs  =  sum1 / count1
      where
        cbf x (sum0, count0) = (sum0+x, count0+1)  --  combining function
        (sum1, count1) = foldr cbf (0,0) xs
    

    This works fine:

     λ> 
     λ> :type listAverage
     listAverage :: [Double] -> Double
     λ> 
     λ> listAverage [1,2,3,4,5]
     3.0
     λ> 
    

    Now, can we adapt this approach to trees?

    Tree traversal

    So we need to somehow get a version of foldr for our trees.

    We can write it manually, working our way through the structure from right to left:

    treeFoldr  ::  (v -> r -> r) -> r -> Ttree v -> r
    treeFoldr cbf r0  Nil  =  r0
    treeFoldr cbf r0  (Node3 v left mid right)  =
        let  rr = treeFoldr cbf  r0  right
             rm = treeFoldr cbf  rr  mid
             rl = treeFoldr cbf  rm  left
        in
             cbf v rl
    

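    Note that it is critical here to be able to specify the initial accumulator value: the result of folding the right subtree (rr) is passed as the initial value for the middle subtree, and that result (rm) in turn seeds the fold over the left subtree.

    So we now have a tree traversal mechanism that is fully general-purpose and detached from any numeric concerns.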

    For example, we can use it to flatten any sort of (possibly non-numeric) tree:

    toListFromTree :: Ttree v -> [v]
    toListFromTree tr  =  let  cbf = \v vs -> v:vs
                          in   treeFoldr cbf [] tr
    

    This can be further simplified:

    toListFromTree tr  =  treeFoldr (:) [] tr
    

    Testing:

     λ> 
     λ> treeFoldr (:) [] tr1
     [2.0,10.0,4.0,5.0,6.0,7.0,15.0]
     λ> 
    

    At that point, we can define the Foldable instance for trees:

    instance Foldable Ttree  where  foldr = treeFoldr
    

    and the short body of the list averager above can now be reused unchanged to average trees; only the name and the type signature need adapting.

    treeAverage :: Ttree Double -> Double
    treeAverage tr  =  sum1 / count1
      where
        cbf x (sum0, count0) = (sum0+x, count0+1)  --  combining function
        (sum1, count1) = foldr cbf (0,0) tr
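
    Since the body above only uses foldr, it arguably does not need to mention Ttree at all. The following sketch generalises it to any Foldable container. The names average and safeAverage are hypothetical, and safeAverage is just one possible way to handle an empty structure, for which the plain version yields NaN (0/0). The Foldable instance is written out by hand here so that the block stands on its own.

    ```haskell
    data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)

    -- Same right-to-left traversal as treeFoldr, written inline.
    instance Foldable Ttree where
      foldr _ r0 Nil             = r0
      foldr f r0 (Node3 v l m r) = f v (foldr f (foldr f (foldr f r0 r) m) l)

    -- Hypothetical name: single-traversal average over any Foldable container.
    average :: Foldable f => f Double -> Double
    average xs = total / count
      where
        step x (s, c)  = (s + x, c + 1)        -- combining function
        (total, count) = foldr step (0, 0) xs

    -- Hypothetical name: returns Nothing instead of NaN for an empty container.
    safeAverage :: Foldable f => f Double -> Maybe Double
    safeAverage xs
      | null xs   = Nothing
      | otherwise = Just (average xs)
    ```

    With this, average [1,2,3,4,5] and average applied to a Ttree Double go through exactly the same code path.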
    

    Now, we can do something easier. GHC provides an extension, DeriveFoldable, that lets us ask the compiler to generate the Foldable instance (and hence the equivalent of treeFoldr) automatically. This leads directly to our:

    Shortest solution:

    {-#  LANGUAGE  DeriveFoldable    #-}
    
    -- define a tree structure:
    data Ttree t  =  Nil  |  Node3 t  (Ttree t)  (Ttree t)  (Ttree t)
                      deriving  (Eq, Show, Foldable)
    
    treeAverage :: Ttree Double -> Double
    treeAverage tr = sum1 / count1
        where
            cbf x (s,c)    =  (s+x,c+1)
            (sum1, count1) =  foldr cbf (0,0) tr
    

    And I think most Haskell programmers would agree that this counts as a single function :-)

    Note that it is also possible to provide Functor instances, hence an fmap function, using the DeriveFunctor GHC extension.
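
    As a rough sketch of that last remark (scaleTree is a hypothetical name, and scaling is an arbitrary choice of operation), deriving Functor alongside Foldable provides fmap without any hand-written instance:

    ```haskell
    {-# LANGUAGE DeriveFoldable, DeriveFunctor #-}

    data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
      deriving (Eq, Show, Functor, Foldable)

    -- Hypothetical helper: multiply every value in the tree by a constant.
    -- fmap changes the values only; the shape of the tree is preserved.
    scaleTree :: Double -> Ttree Double -> Ttree Double
    scaleTree k = fmap (* k)
    ```

    Because the shape is preserved, the node count stays the same, and the average of scaleTree k tr is k times the average of tr (up to floating-point rounding).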