scalaz state-monad st-monad starray scala-cats

How can I implement a Fisher-Yates shuffle in Scala without side effects?


I want to implement the Fisher-Yates algorithm (an in-place array shuffle) without side effects by using an STArray for the local mutation effects, and a functional random number generator

type RNG[A] = State[Seed,A]

to produce the random integers needed by the algorithm.

I have a method def intInRange(max: Int): RNG[Int] which I can use to produce a random Int in [0,max).
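For concreteness, a generator of this shape might look roughly like the sketch below. Everything in it is an assumption made for illustration: `Seed` is a hypothetical wrapper around a linear congruential generator, and the tiny `State` class is a hand-rolled stand-in for the library's `State` so that the snippet needs no dependencies. The real `Seed` and `intInRange` may be defined quite differently.

```scala
object RngSketch {
  // Hypothetical seed type: the 48-bit state of a linear congruential generator.
  final case class Seed(value: Long) {
    // Multiplier and increment are the ones used by java.util.Random.
    def next: Seed = Seed((value * 0x5DEECE66DL + 0xBL) & ((1L << 48) - 1))
  }

  // Minimal stand-in for a State monad: a function from a state to (new state, result).
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State { s => val (s1, a) = run(s); (s1, f(a)) }
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State { s => val (s1, a) = run(s); f(a).run(s1) }
  }

  type RNG[A] = State[Seed, A]

  // Random Int in [0, max); assumes max > 0.
  // Taking the remainder introduces a slight bias, which is fine for a sketch.
  def intInRange(max: Int): RNG[Int] = State { seed =>
    val s1 = seed.next
    val bits = (s1.value >>> 17).toInt // top 31 bits, always non-negative
    (s1, bits % max)
  }
}
```

Running `RngSketch.intInRange(10).run(RngSketch.Seed(42L))` returns the next seed together with an Int in [0, 10), and the same seed always gives the same result.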

From Wikipedia:

To shuffle an array a of n elements (indices 0..n-1):
    for i from n − 1 downto 1 do
        j ← random integer such that 0 ≤ j ≤ i
        exchange a[j] and a[i]

I suppose I need to stack State with ST somehow, but this is confusing to me. Do I need a [S]StateT[ST[S,?],Seed,A]? Do I have to rewrite RNG to use StateT as well?

(Edit) I don't want to involve IO, and I don't want to substitute Vector for STArray because the shuffle wouldn't be performed in-place.

I know there is a Haskell implementation here, but I'm not currently capable of understanding and porting this to Scalaz. But maybe you can? :)

Thanks in advance.


Solution

  • Here is a more or less direct translation from the Haskell version you linked that uses a mutable STArray. The Scalaz STArray doesn't have an exact equivalent of the listArray function, so I've made one up. Otherwise, it's a straightforward transliteration:

    import scalaz._
    import scalaz.effect.{ST, STArray}
    import ST._
    import State._
    import syntax.traverse._
    import std.list._
    
    def shuffle[A:Manifest](xs: List[A]): RNG[List[A]] = {
      def newArray[S](n: Int, as: List[A]): ST[S, STArray[S, A]] =
        if (n <= 0) newArr(0, null.asInstanceOf[A])
        else for {
          r <- newArr[S,A](n, as.head)
          _ <- r.fill((_, a: A) => a, as.zipWithIndex.map(_.swap))
        } yield r
      for {
        seed <- get[Seed]
        n = xs.length
        r <- runST(new Forall[({type λ[σ] = ST[σ, RNG[List[A]]]})#λ] {
          def apply[S] = for {
            g <- newVar[S](seed)
            randomRST = (lo: Int, hi: Int) => for {
              p <- g.read.map(intInRange(hi - lo).apply)
              (sp, a) = p // scalaz's State yields (state, value), in that order
              _ <- g.write(sp)
            } yield a + lo
            ar  <- newArray[S](n, xs)
            xsp <- Range(0, n).toList.traverseU { i => for {
              j  <- randomRST(i, n)
              vi <- ar read i
              vj <- ar read j
              _  <- ar.write(j, vi)
            } yield vj }
            genp <- g.read
          } yield put(genp).map(_ => xsp)
        })
      } yield r
    }
    

    Although a mutable array gives the shuffle good asymptotics (linear time in the number of elements), do note that the constant factors of the ST monad in Scala are quite large. You may be better off just doing this in a monolithic block using regular mutable arrays. The overall shuffle function remains pure because all of your mutable state is local.
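As a rough illustration of the monolithic-block alternative, here is a minimal sketch that shuffles a private copy of the input in a plain `Array` and threads the seed through by hand. The `Seed` class and this two-argument `intInRange` are hypothetical stand-ins, not the definitions from the question, and the function returns a `(List[A], Seed)` pair rather than an `RNG[List[A]]` so that the snippet has no library dependency.

```scala
import scala.reflect.ClassTag

object LocalMutationShuffle {
  // Hypothetical seed type: the 48-bit state of a linear congruential generator.
  final case class Seed(value: Long) {
    def next: Seed = Seed((value * 0x5DEECE66DL + 0xBL) & ((1L << 48) - 1))
  }

  // Hypothetical helper: a random Int in [0, max) plus the next seed; assumes max > 0.
  def intInRange(max: Int, seed: Seed): (Int, Seed) = {
    val s1 = seed.next
    ((s1.value >>> 17).toInt % max, s1)
  }

  // Fisher-Yates on a local copy. The array and the vars never escape this
  // method, so callers only ever see a function from inputs to outputs.
  def shuffle[A: ClassTag](xs: List[A], seed: Seed): (List[A], Seed) = {
    val arr = xs.toArray // fresh array, not shared with the caller
    var s = seed
    var i = arr.length - 1
    while (i >= 1) {
      val (j, s1) = intInRange(i + 1, s) // 0 <= j <= i
      s = s1
      val tmp = arr(i); arr(i) = arr(j); arr(j) = tmp
      i -= 1
    }
    (arr.toList, s)
  }
}
```

Calling `LocalMutationShuffle.shuffle(List(1, 2, 3, 4, 5), LocalMutationShuffle.Seed(42L))` twice gives the same permutation and the same final seed both times. To recover an `RNG[List[A]]`, the function could presumably be wrapped in a `State` constructor, taking care to match the library's tuple order.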