
What is the Scala equivalent of Python's NumPy np.random.choice? (Random weighted selection in Scala)


I am looking for Scala code equivalent to Python's np.random.choice (with NumPy imported as np), or for the theory behind it. I have a similar implementation in Python that uses np.random.choice to select random moves from a probability distribution.

Python example:

Input list: ['pooh', 'rabbit', 'piglet', 'Christopher'] and probabilities: [0.5, 0.1, 0.1, 0.3]

I want to select one of the values from the input list according to the probability associated with each element.


Solution

  • The Scala standard library has no equivalent to np.random.choice, but it shouldn't be too difficult to build your own, depending on which options/features you want to emulate.

    Here, for example, is a way to get an infinite Stream of the given items, with the probability of each item weighted relative to the others.

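    def weightedSelect[T](input :(T,Int)*): Stream[T] = {
      // repeat each item as many times as its integer weight
      val items  :Seq[T]    = input.flatMap{x => Seq.fill(x._2)(x._1)}
      // shuffle that block, then lazily append another freshly shuffled block, forever
      def output :Stream[T] = util.Random.shuffle(items).toStream #::: output
      output
    }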
    

    Here each input item is paired with an integer multiplier (its relative weight). So, to get an infinite pseudorandom selection of the characters c and v, with c coming up 3/5 of the time and v coming up 2/5 of the time:

    val cvs = weightedSelect(('c',3),('v',2))
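    Stream has been deprecated since Scala 2.13 in favour of LazyList. Below is a minimal sketch of the same shuffled-block idea built on Iterator instead, which is available across Scala versions. The name shuffledBlocks and the explicit Random parameter (handy for seeding) are my own choices, not standard library API.

    ```scala
    import scala.util.Random

    object ShuffledBlocks {
      // Hypothetical helper: same technique as weightedSelect, on Iterator instead of Stream.
      // Each (item, weight) pair contributes `weight` copies of `item` to a block,
      // and every block is shuffled independently before being emitted.
      def shuffledBlocks[T](rng: Random, input: (T, Int)*): Iterator[T] = {
        require(input.forall(_._2 >= 0), "weights must be non-negative")
        val items: Seq[T] = input.flatMap { case (item, weight) => Seq.fill(weight)(item) }
        require(items.nonEmpty, "at least one weight must be positive")
        // Iterator.continually re-evaluates its argument, so each block is a fresh shuffle.
        Iterator.continually(rng.shuffle(items)).flatten
      }
    }
    ```

    Because every block of five draws is a shuffle of c, c, c, v, v, each aligned window of five contains exactly three c's and two v's. The proportions are exact, but successive draws are not independent. Note also that, unlike a Stream, an Iterator does not memoize its elements, so it can only be consumed once.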
    

    Thus, the rough equivalent of the np.random.choice(aa_milne_arr, 5, p=[0.5, 0.1, 0.1, 0.3]) example (each probability multiplied by 10 to give an integer weight) would be:

    weightedSelect("pooh"-> 5
                  ,"rabbit" -> 1
                  ,"piglet" -> 1
                  ,"Christopher" -> 3).take(5).toArray
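    The weights 5, 1, 1, 3 above are just the probabilities multiplied by 10. A minimal sketch of doing that conversion in code, assuming it is acceptable to round each probability to a fixed resolution; toWeights and resolution are hypothetical names, not library API:

    ```scala
    object ProbabilityWeights {
      @annotation.tailrec
      private def gcd(a: Long, b: Long): Long = if (b == 0) a else gcd(b, a % b)

      // Hypothetical helper: turns probabilities into the integer multipliers that the
      // shuffled-block weightedSelect expects. Each probability is rounded to
      // `resolution` steps, then all weights are divided by their greatest common divisor.
      def toWeights(probs: Seq[Double], resolution: Int = 1000): Seq[Int] = {
        require(probs.nonEmpty && probs.forall(_ >= 0.0), "probabilities must be non-negative")
        val scaled: Seq[Long] = probs.map(p => math.round(p * resolution))
        val divisor = scaled.foldLeft(0L)(gcd)
        require(divisor > 0, "all probabilities rounded to zero")
        scaled.map(w => (w / divisor).toInt)
      }
    }
    ```

    For example, toWeights(Seq(0.5, 0.1, 0.1, 0.3)) gives 5, 1, 1, 3. For a lopsided distribution such as 4998/5000 and 2/5000 you need resolution = 5000 to keep the rare item (the default rounds it to weight 0), and the result is 2499 and 1, so every shuffled block is 2,500 items long. That is one reason to prefer working with the probabilities directly.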
    

    Or perhaps you want each draw to be independent (less pseudo random than repeating shuffled blocks), from a distribution that might be heavily lopsided.

    def weightedSelect[T](items :Seq[T], distribution :Seq[Double]) :Stream[T] = {
      assert(items.length == distribution.length)
      assert(math.abs(1.0 - distribution.sum) < 0.001) // must be at least close
    
      val dsums  :Seq[Double] = distribution.scanLeft(0.0)(_+_).tail // cumulative probabilities
      val distro :Seq[Double] = dsums.init :+ 1.1 // cover a possible rounding gap below 1.0
      Stream.continually {
        val r = util.Random.nextDouble() // draw ONE random number per element
        items(distro.indexWhere(_ > r))  // first bucket whose upper bound exceeds r
      }
    }
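    The indexWhere lookup scans the cumulative sums linearly, which is fine for a handful of items. For many items, a binary search over the cumulative weights finds the bucket in O(log n). Below is a sketch of that swapped-in technique, which also accepts unnormalized weights (so 5, 1, 1, 3 works as well as 0.5, 0.1, 0.1, 0.3). The name choice, its parameters and the fixed-size Vector result are my own assumptions, loosely modelled on np.random.choice(a, size, p=...).

    ```scala
    import scala.util.Random

    object CumulativeChoice {
      // Hypothetical helper: draws `n` independent samples from `items`, where item i is
      // chosen with probability weights(i) / weights.sum. Weights need not sum to 1.
      def choice[T](items: IndexedSeq[T], weights: IndexedSeq[Double], n: Int, rng: Random): Vector[T] = {
        require(items.nonEmpty && items.length == weights.length, "need one weight per item")
        require(weights.forall(w => w >= 0.0 && java.lang.Double.isFinite(w)),
          "weights must be finite and non-negative")
        require(n >= 0, "sample size must be non-negative")
        // cumulative(i) = weights(0) + ... + weights(i)
        val cumulative: Array[Double] = weights.scanLeft(0.0)(_ + _).tail.toArray
        val total = cumulative.last
        require(total > 0.0, "at least one weight must be positive")

        // Binary search for the first index whose cumulative weight exceeds r.
        // The search never goes past the last index, which also covers the rare
        // floating-point case where r rounds up to `total`.
        def indexFor(r: Double): Int = {
          var lo = 0
          var hi = cumulative.length - 1
          while (lo < hi) {
            val mid = (lo + hi) >>> 1
            if (cumulative(mid) > r) hi = mid else lo = mid + 1
          }
          lo
        }

        // Vector.fill evaluates its element expression n times: one fresh draw per sample.
        Vector.fill(n)(items(indexFor(rng.nextDouble() * total)))
      }
    }
    ```

    An item with weight 0 can never be returned (unless it is the last item and the rare rounding case above occurs), because its cumulative value equals that of the previous item, so the search always stops earlier.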
    

    The result is still an infinite Stream of the specified elements, but the arguments are now a sequence of items plus a parallel sequence of probabilities that should sum to 1.

    val choices :Stream[String] = weightedSelect( List("this"     , "that")
                                               , Array(4998/5000.0, 2/5000.0))
    
    // let's test the distribution
    val (choiceA, choiceB) = choices.take(10000).partition(_ == "this")
    
    choiceA.length  //res0: Int = 9995
    choiceB.length  //res1: Int = 5  (not bad)
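    np.random.choice also takes a replace flag, and none of the versions above emulate replace=False. A minimal sketch under one straightforward interpretation: draw one item in proportion to its weight, remove it, renormalize over the remaining items, and repeat. This is not claimed to match NumPy's internal algorithm, and choiceNoReplace is a hypothetical name.

    ```scala
    import scala.util.Random

    object ChoiceWithoutReplacement {
      // Hypothetical helper sketching np.random.choice(a, k, replace=False, p=weights):
      // successive weighted draws, each chosen item removed before the next draw.
      // O(k * n), which is fine for small inputs.
      def choiceNoReplace[T](items: Seq[T], weights: Seq[Double], k: Int, rng: Random): List[T] = {
        require(items.length == weights.length, "need one weight per item")
        require(weights.forall(w => w >= 0.0 && java.lang.Double.isFinite(w)),
          "weights must be finite and non-negative")
        require(k >= 0 && k <= weights.count(_ > 0.0),
          "cannot draw more items than have positive weight")

        @annotation.tailrec
        def loop(pool: List[(T, Double)], remaining: Int, acc: List[T]): List[T] =
          if (remaining == 0) acc.reverse
          else {
            val total = pool.map(_._2).sum
            val r = rng.nextDouble() * total
            val cumulative = pool.scanLeft(0.0)(_ + _._2).tail
            val found = cumulative.indexWhere(_ > r)
            // Rounding can leave r >= the final prefix sum; fall back to the last entry.
            val idx = if (found >= 0) found else pool.length - 1
            loop(pool.take(idx) ++ pool.drop(idx + 1), remaining - 1, pool(idx)._1 :: acc)
          }

        // Zero-weight items can never be drawn, so drop them up front.
        loop(items.zip(weights).filter(_._2 > 0.0).toList, k, Nil)
      }
    }
    ```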