
Haskell - How to transform maximum (xs ++ map (x+) xs) to max (maximum xs) (x + maximum xs)


One of the exercises in "Thinking Functionally With Haskell" is about making a program more efficient using the fusion law. I am having some trouble trying to replicate the answer.

A part of the calculation requires that you transform maximum (xs ++ map (x+) xs) to max (maximum xs) (x + maximum xs) through equational reasoning.

maximum is defined as foldr1 max, and as I don't know many rules about foldr1, I'm stuck on even the first part: transforming foldr1 max (xs ++ map (x+) xs) to max (foldr1 max xs) (foldr1 max (map (x+) xs)). So that's the first thing I'd like to understand.

Once we get past that, the next part seems harder i.e. transforming foldr1 max (map (x+) xs) to x + foldr1 max xs. Intuitively it makes sense; if you are finding the maximum value of a bunch of numbers that all have 'x' added to them then that's the same as finding the maximum of all the numbers before 'x' was added and adding 'x' to the result.

The only thing I've found to help me in this second stage is a Stack Overflow answer where the result is basically given to you (if you assume p = q), without the individual, easy-to-follow steps you normally see in equational reasoning.

So could someone please show me the steps of the transformation?


Solution

  • This can be seen by induction.

    Suppose xs == []. Then the equation holds trivially: maximum [] is an error, so both sides yield an error.

    Suppose, xs == [y]

    maximum([y]++map(x+)[y]) == -- by definition of map
                             == maximum([y]++[x+y])
                                -- by definition of ++
                             == maximum([y,x+y])
                                -- by definition of maximum
                             == foldr1 max [y,x+y]
                                -- by definition of foldr1
                             == max y (foldr1 max [x+y])
                                -- by definition of foldr1
                             == max y (x+y)
                                -- by definition of foldr1 and maximum [y]
                             == max (maximum [y]) (x+maximum [y])
    

    Next, we will need a lemma, which I'll call commutativity of maximum: maximum (xs++(y:ys)) == max y (maximum (xs++ys)), for non-empty ys (which is all we need below). You will notice why it is needed if you skip this proof and go straight to the proof for maximum (y:ys ++ map(x+)(y:ys)): one step there has to move (x+y) out of the middle of the list ys++(x+y):map(x+)ys.

    Suppose, xs==[]:

    maximum ([]++(y:ys)) == maximum (y:ys)
                         -- by definition of foldr1 and maximum
                         == max y (maximum ys)
                         == max y (maximum ([]++ys))
    

    Suppose, xs==x:xx:

    maximum(x:xx++(y:ys)) == maximum (x:(xx++(y:ys)))
                         -- by definition of foldr1 and maximum
                          == max x (maximum (xx++(y:ys)))
                         -- by induction
                          == max x (max y (maximum (xx++ys)))
     -- by commutativity and associativity of max, max a (max b c) == max b (max a c)
                          == max y (max x (maximum (xx++ys)))
                         -- by definition of foldr1 and maximum
                          == max y (maximum (x:(xx++ys)))
                         -- by definition of ++
                          == max y (maximum ((x:xx) ++ ys))
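
    A proof like this can be sanity-checked by evaluating both sides on a few inputs. The following is only a minimal sketch, not part of the proof: lemmaHolds, prefixes and suffixes are hypothetical names, and Int is picked purely for illustration. The suffix lists are all non-empty, matching the assumption the base case relies on.

```haskell
-- Spot check of the lemma
--   maximum (xs ++ (y:ys)) == max y (maximum (xs ++ ys))
-- lemmaHolds is a hypothetical helper; only Prelude functions are used.
lemmaHolds :: [Int] -> Int -> [Int] -> Bool
lemmaHolds xs y ys =
  maximum (xs ++ (y : ys)) == max y (maximum (xs ++ ys))

main :: IO ()
main = print (and [ lemmaHolds xs y ys
                  | xs <- prefixes, y <- [-2, 0, 7], ys <- suffixes ])
  where
    -- The prefix may be empty; the suffix is kept non-empty so that
    -- maximum (xs ++ ys) is always defined.
    prefixes = [[], [3], [1, 9, 4]]
    suffixes = [[0], [5, -5], [8, 8]]
```

    A check like this cannot replace the induction, but it catches a mis-stated lemma quickly.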
    

    OK, now back to proving the original statement. Suppose xs == y:ys, with ys non-empty (the single-element case was handled above).

     maximum (y:ys ++ map(x+)(y:ys))
     -- by definition of map
                             == maximum (y:ys ++ (x+y):map(x+)ys)
     -- by definition of ++, foldr1 and maximum
                             == max y (maximum (ys ++ (x+y):map(x+)ys))
     -- by commutativity of maximum
                             == max y (max (x+y) (maximum (ys++map(x+)ys)))
     -- by induction, maximum (ys++map(x+)ys) == max (maximum ys) (x+maximum ys)
                             == max y (max (x+y)
                                           (max (maximum ys) (x+maximum ys)))
     -- by commutativity and associativity of max (ie max a (max b c) == max b (max a c))
                             == max y (max (maximum ys)
                                           (max (x+y) (x+maximum ys)))
     -- by associativity of max (ie max a (max b c) == max (max a b) c)
                             == max (max y (maximum ys))
                                    (max (x+y) (x+maximum ys))
     -- because + distributes over max: max (x+y) (x+z) == x + max y z
                             == max (max y (maximum ys))
                                    (x + max y (maximum ys))
     -- by definition of foldr1 and maximum
                             == max (maximum (y:ys)) (x + maximum (y:ys))
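
    The identity itself can be spot-checked the same way. This is a sketch under stated assumptions: lhs, rhs, lawHolds and samples are hypothetical names, and Int with small values is used so that overflow plays no role.

```haskell
-- Left-hand side of the identity from the question.
lhs :: Int -> [Int] -> Int
lhs x xs = maximum (xs ++ map (x +) xs)

-- Right-hand side: maximum xs is computed once and reused.
rhs :: Int -> [Int] -> Int
rhs x xs = let m = maximum xs in max m (x + m)

-- True when both sides agree for the given x and non-empty xs.
lawHolds :: Int -> [Int] -> Bool
lawHolds x xs = lhs x xs == rhs x xs

main :: IO ()
main = print (and [ lawHolds x xs | x <- [-3 .. 3], xs <- samples ])
  where
    -- Only non-empty lists: maximum [] is an error on both sides.
    samples = [[0], [1, 5, 2], [-4, -1, -7], [2, 2, 2], [10, -10]]
```

    Note that rhs traverses xs once, whereas lhs builds and traverses a list twice as long, which is the efficiency gain the exercise is after.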
    

    Since you asked also about induction and how to see a certain thing can be proven by induction, here's some more.

    You can see some of the steps are "by definition": we know they are true by looking at how the function is written. For example, maximum = foldr1 max and foldr1 f (x:xs) = f x $ foldr1 f xs for non-empty xs. Other steps are less clear because the definition of max is not shown; yet it can be shown by induction that max (x+y) (x+z) == x + max y z. Here one would start with the definition max 0 y == y, then work out max for greater x. (Then you'd also need to cover the cases for negative x and y in a similar way.)
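
    To make the "by definition" steps concrete, here is a minimal sketch of the two definitions the proof unfolds. The primed names foldr1' and maximum' are hypothetical, chosen to avoid clashing with the Prelude; they follow the definition used in the question, which is not necessarily how a given library implements maximum.

```haskell
-- foldr1 as used in the proof: defined for non-empty lists only.
foldr1' :: (a -> a -> a) -> [a] -> a
foldr1' _ [x]      = x
foldr1' f (x : xs) = f x (foldr1' f xs)
foldr1' _ []       = error "foldr1': empty list"

-- maximum as defined in the question.
maximum' :: Ord a => [a] -> a
maximum' = foldr1' max

main :: IO ()
main = do
  let xs = [3, 1, 4, 1, 5] :: [Int]
  -- Compare the sketch against the Prelude's maximum.
  print (maximum' xs == maximum xs)
  -- One unfolding step: maximum' (y:ys) == max y (maximum' ys) for non-empty ys.
  print (maximum' xs == max (head xs) (maximum' (tail xs)))
```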

    For example, natural numbers are zero and any successor of a natural number. You see, here we don't have any comparison defined, nothing. So, the properties of addition, subtraction, max, etc, stem from the definition of the functions:

    data Nat = Z | S Nat -- zero and any successor of a natural number
    (+) :: Nat -> Nat -> Nat -- addition is...
    Z + x = x -- adding zero is neutral
    (S x) + y = S (x + y) -- recursive definition of (1+x)+y = 1+(x+y)
    -- here unwittingly we introduced associativity of addition:
    -- (x+y)+z=x+(y+z)
    -- so, let's see the simplest case:
    -- x == Z
    -- (Z+y)+z == -- by definition, Z+y=y -- see the first line of (+)
    --         == y+z
    --         == Z+(y+z) -- by definition, Z+(y+z)=(y+z)
    --
    -- ok, now try x == S m
    -- (S m + y) + z == -- by definition, (S m)+y=S(m+y) -- see the second line of(+)
    --               == S (m+y) + z
    --               == S ((m+y)+z) -- see the second line of (+)
    --                                 - S (m+y) + z = S ((m+y)+z)
    --               == S (m+(y+z)) -- by induction, the simpler
    --                                 case of (m+y)+z=m+(y+z)
    --                                 is known to be true
    --               == (S m)+(y+z) -- by definition, see second line of (+)
    -- proven
    

    Then, because we don't have a comparison on Nats yet, we have to define max directly by pattern matching.

    max :: Nat -> Nat -> Nat
    max Z y = y -- we know Z is not the max
    max x Z = x -- and the other way around, too
    -- this inadvertently introduced commutativity of max already
    
    max (S x) (S y) = S (max x y) -- this inadvertently introduces the law
    -- that max (x+y) (x+z) == x + (max y z)
    

    Suppose, we want to prove the latter. Assume x == Z

    max (Z+y) (Z+z) == -- by definition of (+)
                    == max y z
                    == Z + (max y z) -- by definition of (+)
    

    ok, now assume x == S m

    max ((S m) + y) ((S m)+z) == -- by definition of (+)
                    == max (S(m+y)) (S(m+z))
                    -- by definition of max
                    == S (max (m+y) (m+z))
                    -- by induction
                    == S (m+(max y z))
                    -- by definition of (+)
                    == (S m)+(max y z)
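
    The Nat definitions above can be put into one compilable module to check the law on small values. This is a sketch, with a few assumptions added for the sake of running it: the Prelude's (+) and max are hidden so the names can be reused, the fixity declaration and the deriving clause are additions, and fromInt and distributes are hypothetical helpers.

```haskell
import Prelude hiding (max, (+))

-- Natural numbers: zero and successors.
data Nat = Z | S Nat deriving (Eq, Show)

infixl 6 +

(+) :: Nat -> Nat -> Nat
Z + y = y
S x + y = S (x + y)

max :: Nat -> Nat -> Nat
max Z y = y
max x Z = x
max (S x) (S y) = S (max x y)

-- Hypothetical helper: build a Nat from an Int (values below 1 map to Z).
fromInt :: Int -> Nat
fromInt n
  | n <= 0    = Z
  | otherwise = S (fromInt (n - 1))

-- The law proved above: max (x+y) (x+z) == x + max y z.
distributes :: Nat -> Nat -> Nat -> Bool
distributes x y z = max (x + y) (x + z) == x + max y z

main :: IO ()
main = print (and [ distributes (fromInt a) (fromInt b) (fromInt c)
                  | a <- [0 .. 4], b <- [0 .. 4], c <- [0 .. 4] ])
```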
    

    So, you see, it is important to know the definitions, important to prove the simplest case, and important to use the proof for the simpler case in the slightly more complex case.