I was looking for Scala's equivalent code, or the underlying theory, for Python's np.random.choice (NumPy as np). I have a similar implementation that uses Python's np.random.choice method to select random moves from a probability distribution.
Input list: ['pooh', 'rabbit', 'piglet', 'Christopher'] and probabilities: [0.5, 0.1, 0.1, 0.3]
I want to select one of the values from the input list, given the associated probability of each input element.
The Scala standard library has no equivalent to np.random.choice, but it shouldn't be too difficult to build your own, depending on which options/features you want to emulate.
Here, for example, is a way to get an infinite Stream of submitted items, with the probability of any one item weighted relative to the others.
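def weightedSelect[T](input: (T, Int)*): Stream[T] = {
  // repeat each item as many times as its multiplier
  val items: Seq[T] = input.flatMap { x => Seq.fill(x._2)(x._1) }
  // emit one shuffled copy of the pool after another, lazily and forever
  def output: Stream[T] = util.Random.shuffle(items).toStream #::: output
  output
}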
With this, each input item is given with a multiplier. So to get an infinite pseudorandom selection of the characters c and v, with c coming up 3/5ths of the time and v coming up 2/5ths of the time:
val cvs = weightedSelect(('c',3),('v',2))
Thus the rough equivalent of the np.random.choice(aa_milne_arr,5,p=[0.5,0.1,0.1,0.3]) example would be:
weightedSelect( "pooh" -> 5
              , "rabbit" -> 1
              , "piglet" -> 1
              , "Christopher" -> 3 ).take(5).toArray
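Because the multiplier version concatenates shuffled copies of one fixed pool, every consecutive run of pool-size items contains each element exactly as many times as its multiplier. The following is a minimal sketch that makes this property checkable. It assumes an Iterator-based variant rather than Stream, and the names WeightedPool, weightedIterator and blocksAreExact are illustrative only, not part of any library:

```scala
import scala.util.Random

object WeightedPool {
  // Illustrative variant of the multiplier approach, using Iterator instead of Stream.
  def weightedIterator[T](input: (T, Int)*): Iterator[T] = {
    // Build the pool: each item repeated as many times as its multiplier.
    val pool: Seq[T] = input.flatMap { case (item, weight) => Seq.fill(weight)(item) }
    require(pool.nonEmpty, "at least one item needs a positive multiplier")
    // Emit one freshly shuffled copy of the pool after another, forever.
    Iterator.continually(Random.shuffle(pool)).flatMap(shuffled => shuffled)
  }

  // True if every block of `blockSize` items holds exactly `expected` copies of `item`.
  def blocksAreExact[T](items: Seq[T], blockSize: Int, item: T, expected: Int): Boolean =
    items.grouped(blockSize).forall(_.count(_ == item) == expected)
}
```

For example, WeightedPool.weightedIterator('c' -> 3, 'v' -> 2).take(10).toList always contains exactly six c's and four v's, so the output is a sequence of shuffled blocks rather than a sequence of independent draws.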
Or perhaps you want independent draws (the version above just repeats shuffled copies of a fixed pool) from a distribution that might be heavily lopsided.
def weightedSelect[T](items: Seq[T], distribution: Seq[Double]): Stream[T] = {
  assert(items.length == distribution.length)
  assert(math.abs(1.0 - distribution.sum) < 0.001) // must be at least close
  val dsums: Seq[Double] = distribution.scanLeft(0.0)(_ + _).tail // cumulative sums
  val distro: Seq[Double] = dsums.init :+ 1.1 // close a possible gap
  Stream.continually {
    val r = util.Random.nextDouble() // draw once per selection, not once per comparison
    items(distro.indexWhere(_ > r))
  }
}
The result is still an infinite Stream of the specified elements, but the passed-in arguments are a bit different.
val choices: Stream[String] = weightedSelect( List("this", "that")
                                            , Array(4998/5000.0, 2/5000.0) )
// let's test the distribution
val (choiceA, choiceB) = choices.take(10000).partition(_ == "this")
choiceA.length //res0: Int = 9995
choiceB.length //res1: Int = 5 (not bad)
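If all you need is a single value, as in the question, the same cumulative-sum idea can be wrapped in a function that returns one element per call. This is only a sketch under stated assumptions: Choice, choice and choices are hypothetical names (not standard-library methods), the probabilities are assumed to be non-negative and to sum to roughly 1, and the optional rng parameter is there only so a seed can be supplied for reproducible runs.

```scala
import scala.util.Random

object Choice {
  // Hypothetical helper: pick one item according to the given probabilities.
  def choice[T](items: Seq[T], probs: Seq[Double], rng: Random = new Random()): T = {
    require(items.nonEmpty && items.length == probs.length, "need one probability per item")
    require(probs.forall(_ >= 0.0), "probabilities must be non-negative")
    require(math.abs(probs.sum - 1.0) < 1e-3, "probabilities must sum to (about) 1")
    // Running totals, e.g. [0.5, 0.1, 0.1, 0.3] becomes roughly [0.5, 0.6, 0.7, 1.0].
    val cumulative = probs.scanLeft(0.0)(_ + _).tail
    // One uniform draw in [0, 1) per selection.
    val r = rng.nextDouble()
    val idx = cumulative.indexWhere(_ > r)
    // Rounding can leave the last total slightly below 1.0; fall back to the
    // last item that has a positive probability in that case.
    items(if (idx >= 0) idx else probs.lastIndexWhere(_ > 0.0))
  }

  // Hypothetical helper: n independent draws, similar in spirit to passing a size to np.random.choice.
  def choices[T](items: Seq[T], probs: Seq[Double], n: Int, rng: Random = new Random()): Vector[T] =
    Vector.fill(n)(choice(items, probs, rng))
}
```

With the question's data, Choice.choice(Seq("pooh", "rabbit", "piglet", "Christopher"), Seq(0.5, 0.1, 0.1, 0.3)) returns one name, and Choice.choices(..., 5) returns five independent picks. Recomputing the cumulative sums on every call keeps the sketch short; for heavy use you would probably compute them once and reuse them.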