Tags: scala, fibonacci, scalaz, memoization

Fibonacci memoization in Scala with Memo.mutableHashMapMemo


I am trying to implement the Fibonacci function in Scala with memoization.

One example, given in the question "Is there a generic way to memoize in Scala?", uses a case statement:

import scalaz.Memo
lazy val fib: Int => BigInt = Memo.mutableHashMapMemo {
  case 0 => 0
  case 1 => 1
  case n => fib(n-2) + fib(n-1)
}

It seems the variable n is implicitly defined as the first argument, but I get a compilation error if I replace n with _.

Also, what advantage does the lazy keyword have here? The function seems to work equally well with and without it.

However, I wanted to generalize this to a more generic function definition with appropriate typing:

import scalaz.Memo
def fibonachi(n: Int) : Int = Memo.mutableHashMapMemo[Int, Int] {
    var value : Int = 0
    if( n <= 1 ) { value = n; }
    else         { value = fibonachi(n-1) + fibonachi(n-2) }
    return value
}

but I get the following compilation error:

cmd10.sc:4: type mismatch;
 found   : Int => Int
 required: Int
def fibonachi(n: Int) : Int = Memo.mutableHashMapMemo[Int, Int] {
                                                                         ^
Compilation Failed

So I am trying to understand the generic way of adding a memoization annotation to a Scala def function.


Solution

  • One way to produce the Fibonacci sequence is with a recursive Stream.

    val fib: Stream[BigInt] = 0 #:: fib.scan(1:BigInt)(_+_)
    

    An interesting aspect of streams is that, if something holds on to the head of the stream, the calculation results are auto-memoized. So, in this case, because the identifier fib is a val and not a def, the value of fib(n) is calculated only once and simply retrieved thereafter.

    However, indexing a Stream is still a linear operation. If you want to memoize that away you could create a simple memo-wrapper.

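    // Wraps f in a map-backed function: a result is computed on first use, then looked up.
    // WeakHashMap holds its keys weakly, so cached entries may be dropped by the garbage collector.
    def memo[A,R](f: A=>R): A=>R =
      new collection.mutable.WeakHashMap[A,R] {
        override def apply(a: A) = getOrElseUpdate(a,f(a))
      }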
    
    val fib: Stream[BigInt] = 0 #:: fib.scan(1:BigInt)(_+_)
    val mfib = memo(fib)
    mfib(99)  //res0: BigInt = 218922995834555169026
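
As for memoizing a def directly: the compilation error in the question arises because Memo.mutableHashMapMemo returns a function (Int => Int), while the def is declared to return an Int. One possible way around this, sketched here with only the standard library, is to keep the memoized function in a val and let the def delegate to it. The names FibMemo, memoize and fibMemo are hypothetical and used only for illustration; memoize is a stand-in written for this sketch, not scalaz API, and the result type is BigInt rather than Int to avoid overflow.

```scala
import scala.collection.mutable

object FibMemo {
  // Hypothetical stand-in for a memoizing wrapper (not scalaz API):
  // returns a function that caches each result of f in a mutable HashMap.
  def memoize[A, R](f: A => R): A => R = {
    val cache = mutable.HashMap.empty[A, R]
    a => cache.get(a) match {
      case Some(r) => r            // already computed: reuse the cached result
      case None =>
        val r = f(a)               // may recurse and fill the cache for smaller inputs
        cache.update(a, r)
        r
    }
  }

  // The memoized function lives in a val, so its cache is created only once.
  // Recursive calls go through fibonacci, and therefore through the cache.
  private val fibMemo: Int => BigInt = memoize[Int, BigInt] { n =>
    if (n <= 1) BigInt(n) else fibonacci(n - 1) + fibonacci(n - 2)
  }

  // The def keeps an ordinary signature and simply delegates to the cached function.
  def fibonacci(n: Int): BigInt = fibMemo(n)
}
```

With this in place, FibMemo.fibonacci(99) should yield the same value as mfib(99) above. The cache is not thread-safe and grows without bound, so this is only a starting point. With scalaz on the classpath, Memo.mutableHashMapMemo[Int, BigInt] can presumably take the place of memoize, since the question shows it producing an Int => BigInt function.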