
More efficient algorithm performs worse in Haskell


A friend of mine showed me a home exercise from a C++ course he is attending. Since I already know C++ but have just started learning Haskell, I tried to solve the exercise the "Haskell way".

These are the exercise instructions (I translated from our native language so please comment if the instructions aren't clear):

Write a program that reads non-zero coefficients (A, B, C, D) from the user and places them in the following equation: A*x + B*y + C*z = D. The program should also read N from the user, which represents a range. The program should find all integral solutions of the equation in the range -N/2 to N/2.

For example:

Input: A = 2, B = -3, C = -1, D = 5, N = 4
Output: (-1,-2,-1), (0,-2,1), (0,-1,-2), (1,-1,0), (2,-1,2), (2,0,-1)

The most straight-forward algorithm is to try all possibilities by brute force. I implemented it in Haskell in the following way:

triSolve :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
triSolve a b c d n =
  let equation x y z = (a * x + b * y + c * z) == d
      minN = div (-n) 2
      maxN = div n 2
  in [(x,y,z) | x <- [minN..maxN], y <- [minN..maxN], z <- [minN..maxN], equation x y z]

So far so good, but the exercise instructions note that a more efficient algorithm can be implemented, so I thought about how to improve it. Since the equation is linear and C is non-zero, each (X, Y) pair has at most one matching Z. So, assuming Z is always the first variable to be incremented, once a solution has been found there is no point in incrementing Z further. Instead, I should increment Y, reset Z to the minimum value of the range and keep going. This way I can skip redundant evaluations. Since there are no loops in Haskell (to my understanding at least), I figured that such an algorithm should be implemented using recursion. I implemented it in the following way:

solutions :: (Integer -> Integer -> Integer -> Bool) -> Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
solutions f maxN minN x y z
  | solved = (x,y,z):nextCall x (y + 1) minN
  | x >= maxN && y >= maxN && z >= maxN = []
  | z >= maxN && y >= maxN = nextCall (x + 1) minN minN
  | z >= maxN = nextCall x (y + 1) minN
  | otherwise = nextCall x y (z + 1)
  where solved = f x y z
        nextCall = solutions f maxN minN

triSolve' :: Integer -> Integer -> Integer -> Integer -> Integer -> [(Integer,Integer,Integer)]
triSolve' a b c d n =
  let equation x y z = (a * x + b * y + c * z) == d
      minN = div (-n) 2
      maxN = div n 2
  in solutions equation maxN minN minN minN minN

Both yield the same results. However, trying to measure the execution time yielded the following results:

*Main> length $ triSolve' 2 (-3) (-1) 5 100
3398
(2.81 secs, 971648320 bytes)
*Main> length $ triSolve 2 (-3) (-1) 5 100
3398
(1.73 secs, 621862528 bytes)

Meaning that the dumb algorithm actually performs better than the more sophisticated one. Assuming that my algorithm is correct (which I hope won't turn out to be wrong :) ), I suspect that the second algorithm suffers from overhead created by the recursion, which the first algorithm avoids since it's implemented using a list comprehension. Is there a way to implement a better algorithm than the dumb one in Haskell? (Also, I'll be glad to receive general feedback about my coding style.)


Solution

  • Of course there is. We have:

    a*x + b*y + c*z = d
    

    and as soon as we assume values for x and y, we have that

    a*x + b*y = n
    

    where n is a number we know (this n is not the range N from the question). Hence

    c*z = d - n
    z = (d - n) / c
    

    And we keep only the integral values of z that lie within the range.
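
The idea above could be sketched roughly as follows. This is only an illustration, not code from the original answer: the name triSolveFast is made up here, and the range bounds are computed the same way as in the question's triSolve. It loops over x and y only and solves for z directly, using quotRem to test whether the division by c is exact.

```haskell
-- Hypothetical sketch: iterate over x and y, then solve c*z = d - (a*x + b*y).
-- The name triSolveFast is an assumption made for this example.
triSolveFast :: Integer -> Integer -> Integer -> Integer -> Integer
             -> [(Integer, Integer, Integer)]
triSolveFast a b c d n =
  [ (x, y, z)
  | x <- [minN .. maxN]
  , y <- [minN .. maxN]
    -- c is assumed to be non-zero, as the exercise states
  , let (z, r) = (d - a * x - b * y) `quotRem` c
  , r == 0                  -- keep only integral z
  , minN <= z && z <= maxN  -- z must also lie inside the search range
  ]
  where
    -- same bounds as in the question's triSolve
    minN = div (-n) 2
    maxN = div n 2
```

For the example input from the question (A = 2, B = -3, C = -1, D = 5, N = 4) this yields the same six triples as the brute-force version, while examining on the order of N^2 candidates instead of N^3.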