I was trying to calculate the average of the values in a ternary tree. It does not seem possible to do it inside a single function. Is there a way to solve this, or is it necessary to use two functions? Thanks.
-- define a tree
data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
-- get the Ternary tree and return the average
treeAverage :: Ttree Double -> Double
treeAverage Nil = 0.0 -- empty tree
treeAverage tree = treeAverage' tree (0.0)
-- try to use accumulator and another function
where
treeAverage' Nil _ = 0.0 -- empty tree
treeAverage' (Node3 n left mid right) (sum/count) =
((n+sumL+sumM+sumR) / (1+countL+countM+countR)) -- average
where
(sumL,countL) = treeAverage' left
-- calculate left subtree with sum and count
(sumM,countM) = treeAverage' mid
(sumR,countR) = treeAverage' right
In order to compute an average value, you have to perform a single division at the very end of the process, something like (allSum / allCount). As the division cannot be part of the recursive tree traversal process, it seems difficult to achieve what you want within a single function.
Let's start by providing a little fix for your code. It is unclear whether your auxiliary treeAverage' function returns a pair or a single numeric value. We can rewrite the whole thing like this, so that it unambiguously returns a pair:
-- define a tree structure:
data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
    deriving (Eq, Show)
treeAverage1 :: Ttree Double -> Double
treeAverage1 Nil = 0.0 -- empty tree
treeAverage1 tree =
    let (sum1, count1) = treeAverage' tree
    in  sum1 / count1
  where
    treeAverage' Nil = (0,0) -- empty tree
    treeAverage' (Node3 n left mid right) =
        let (sumL,countL) = treeAverage' left -- calculate left subtree
            (sumM,countM) = treeAverage' mid
            (sumR,countR) = treeAverage' right
        in
            ((n+sumL+sumM+sumR) , (1+countL+countM+countR)) -- (sum, count)
and that code appears to work:
$ ghci
GHCi, version 8.8.4: https://www.haskell.org/ghc/ :? for help
λ>
λ> :load q67816203.hs
Ok, one module loaded.
λ>
λ> leaf x = Node3 x Nil Nil Nil
λ>
λ> tr0 = Node3 4 (leaf 5) (leaf 6) (leaf 7) :: Ttree Double
λ> tr1 = Node3 2 (leaf 10) tr0 (leaf 15)
λ>
λ> tr1
Node3 2.0 (Node3 10.0 Nil Nil Nil) (Node3 4.0 (Node3 5.0 Nil Nil Nil) (Node3 6.0 Nil Nil Nil) (Node3 7.0 Nil Nil Nil)) (Node3 15.0 Nil Nil Nil)
λ>
λ> treeAverage1 tr1
7.0
λ>
However, in this code, tree traversal is inextricably intertwined with arithmetic.
The common Haskell practice would be to improve matters by subcontracting tree traversal to general purpose functions, that is, functions we (or the language library) would provide anyway in order to support our tree structure, regardless of any numeric concerns.
At that point, we can look at a simpler problem: how do we compute an average for a plain list of numbers?
As mentioned in a comment by chepner, you can use something like this (length returns an Int, hence the fromIntegral conversion):
listAverage xs = sum xs / fromIntegral (length xs)
We could adapt this approach to Ttree objects, coming up with treeSum and treeLeafCount functions. But that would be suboptimal. On modern hardware, memory access is far more expensive than arithmetic, and listAverage needlessly traverses the list twice.
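For reference, a two-pass version for trees might look like the sketch below. The names treeSum, treeCount and treeAverage2 are illustrative only (treeCount counts every node, not just the leaves). Each helper walks the whole tree once, so the average costs two traversals.

```haskell
-- Same tree type as in the question.
data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)

-- First traversal: add up every value stored in the tree.
treeSum :: Ttree Double -> Double
treeSum Nil = 0
treeSum (Node3 v l m r) = v + treeSum l + treeSum m + treeSum r

-- Second traversal: count the nodes.
treeCount :: Ttree t -> Int
treeCount Nil = 0
treeCount (Node3 _ l m r) = 1 + treeCount l + treeCount m + treeCount r

-- Two full traversals of the same tree; an empty tree yields NaN (0/0).
treeAverage2 :: Ttree Double -> Double
treeAverage2 tr = treeSum tr / fromIntegral (treeCount tr)
```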
How do we get to traverse the list just once? Well, computing an average is clearly a fold operation, that is, you traverse a complex structure in order to produce a single value. See the classic paper by Graham Hutton about the merits of fold operations.
Lists have an instance of the Foldable class mentioned in the comment by chepner. So the library provides, among other things, a foldr function for lists:
foldr :: (a -> b -> b) -> b -> [a] -> b
The first argument of foldr is a combining function, which takes a value from the input list and the current accumulator value, in that order, and returns an updated accumulator value. The second argument is just the initial value for the accumulator.
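To make the argument order concrete: foldr f z [a, b, c] expands to f a (f b (f c z)). The two helpers below, whose names are invented for this illustration, rebuild sum and length out of foldr. In both, the list element is the first argument of the combining function and the accumulator is the second.

```haskell
-- foldr f z [a, b, c]  ==  f a (f b (f c z))

-- Sum of a list: the accumulator starts at 0 and each element is added to it.
sumViaFoldr :: [Int] -> Int
sumViaFoldr = foldr (\x acc -> x + acc) 0

-- Length of a list: the element itself is ignored, the accumulator counts.
lengthViaFoldr :: [a] -> Int
lengthViaFoldr = foldr (\_ acc -> acc + 1) 0
```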
So we can write a single-traversal list average like this:
listAverage :: [Double] -> Double
listAverage xs = sum1 / count1
  where
    cbf x (sum0, count0) = (sum0+x, count0+1) -- combining function
    (sum1, count1) = foldr cbf (0,0) xs
This works fine:
λ>
λ> :type listAverage
listAverage :: [Double] -> Double
λ>
λ> listAverage [1,2,3,4,5]
3.0
λ>
Now, can we adapt this approach to trees?
So we need to somehow get a version of foldr for our trees.
We can write it manually, working our way through the structure from right to left:
treeFoldr :: (v -> r -> r) -> r -> Ttree v -> r
treeFoldr cbf r0 Nil = r0
treeFoldr cbf r0 (Node3 v left mid right) =
    let rr = treeFoldr cbf r0 right
        rm = treeFoldr cbf rr mid
        rl = treeFoldr cbf rm left
    in
        cbf v rl
Note that it is critical here to be able to specify the initial accumulator value.
So we now have a tree traversal mechanism that is fully general purpose and detached from any numeric concerns.
For example, we can use it to flatten any sort of (possibly non-numeric) tree:
toListFromTree :: Ttree v -> [v]
toListFromTree tr = let cbf = \v vs -> v:vs
                    in  treeFoldr cbf [] tr
This can be further simplified:
toListFromTree tr = treeFoldr (:) [] tr
Testing:
λ>
λ> treeFoldr (:) [] tr1
[2.0,10.0,4.0,5.0,6.0,7.0,15.0]
λ>
At that point, we can define the Foldable instance for trees:
instance Foldable Ttree where foldr = treeFoldr
and the fairly short code of the list averager above can now be reused to average trees, changing essentially nothing but its type signature:
treeAverage :: Ttree Double -> Double
treeAverage tr = sum1 / count1
  where
    cbf x (sum0, count0) = (sum0+x, count0+1) -- combining function
    (sum1, count1) = foldr cbf (0,0) tr
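A side benefit of the Foldable instance is that the generic functions built on top of it become available for trees as well. Here is a small sketch, assuming the hand-written treeFoldr shown earlier (written more compactly here); the sample tree and the sample* names are made up for the example.

```haskell
import Data.Foldable (toList)

data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)

-- Same right-to-left fold as before, in a single expression.
treeFoldr :: (v -> r -> r) -> r -> Ttree v -> r
treeFoldr _   r0 Nil = r0
treeFoldr cbf r0 (Node3 v left mid right) =
    cbf v (treeFoldr cbf (treeFoldr cbf (treeFoldr cbf r0 right) mid) left)

instance Foldable Ttree where
    foldr = treeFoldr

sample :: Ttree Double
sample = Node3 2 (Node3 10 Nil Nil Nil) Nil (Node3 15 Nil Nil Nil)

-- All of these come from Foldable; none was written for Ttree specifically.
sampleSum :: Double
sampleSum = sum sample        -- 27.0

sampleSize :: Int
sampleSize = length sample    -- 3

sampleList :: [Double]
sampleList = toList sample    -- [2.0,10.0,15.0]

sampleMax :: Double
sampleMax = maximum sample    -- 15.0
```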
Now, we can do something easier. GHC happens to provide an extension, DeriveFoldable, that allows us to ask the compiler to write treeFoldr automagically. This leads directly to our final code:
{-# LANGUAGE DeriveFoldable #-}
-- define a tree structure:
data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
    deriving (Eq, Show, Foldable)
treeAverage :: Ttree Double -> Double
treeAverage tr = sum1 / count1
  where
    cbf x (s,c) = (s+x,c+1)
    (sum1, count1) = foldr cbf (0,0) tr
And I think most Haskell programmers would agree that this counts as a single function :-)
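One caveat: unlike treeAverage1, this last version has no special case for Nil, so an empty tree gives 0/0, which is NaN for Double. If that matters, one option is to return a Maybe. The sketch below is only one possible design, and the name treeAverageMaybe is made up for this example.

```haskell
{-# LANGUAGE DeriveFoldable #-}

data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
    deriving (Eq, Show, Foldable)

-- Nothing for an empty tree, Just the average otherwise.
-- Still a single traversal: sum and count are accumulated together.
treeAverageMaybe :: Ttree Double -> Maybe Double
treeAverageMaybe tr
    | count1 == 0 = Nothing
    | otherwise   = Just (sum1 / fromIntegral count1)
  where
    cbf x (s, c) = (s + x, c + 1)
    (sum1, count1) = foldr cbf (0, 0 :: Int) tr
```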
Note that it is also possible to provide a Functor instance, and hence an fmap function, using the DeriveFunctor GHC extension.
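As a rough illustration of that last point, here is a sketch with both extensions enabled. The leaf helper is the one from the GHCi session above; the doubled value is just an example.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable #-}

data Ttree t = Nil | Node3 t (Ttree t) (Ttree t) (Ttree t)
    deriving (Eq, Show, Functor, Foldable)

leaf :: t -> Ttree t
leaf x = Node3 x Nil Nil Nil

-- fmap applies a function to every value and keeps the shape of the tree.
doubled :: Ttree Double
doubled = fmap (* 2) (Node3 1 (leaf 2) (leaf 3) (leaf 4))
-- Node3 2.0 (Node3 4.0 Nil Nil Nil) (Node3 6.0 Nil Nil Nil) (Node3 8.0 Nil Nil Nil)
```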