
TensorFlow: Efficient multinomial sampling (Theano 50x faster?)


I want to sample from a multinomial distribution very efficiently, but my TensorFlow code is very... very slow.

The setup is that I have:

  • A vector: counts = [40, 50, 26, ..., 19] for example
  • A matrix of probabilities: probs = [[0.1, ..., 0.5], ... [0.3, ..., 0.02]] such that np.sum(probs, axis=1) = 1

Let's say len(counts) = N and probs.shape = (N, 50). What I want to do is (in our example):

  • sample 40 times from the first probability vector of the matrix probs
  • sample 50 times from the second probability vector of the matrix probs
  • ...
  • sample 19 times from the Nth probability vector of the matrix probs

so that my final matrix looks like (for example) A = [[22, ..., 13], ..., [12, ..., 3]], where np.sum(A, axis=1) == counts (i.e. each row of A sums to the corresponding entry of the counts vector).
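To pin down the intended semantics, here is a plain NumPy reference that loops over the rows. It is only a sketch for checking results, not a fast implementation; `sample_rows_reference` is a hypothetical helper, and the small `counts`/`probs` below are made-up example data.

```python
import numpy as np

def sample_rows_reference(counts, probs, seed=None):
    """Row i of the result is one draw from Multinomial(counts[i], probs[i])."""
    rng = np.random.default_rng(seed)
    counts = np.asarray(counts)
    probs = np.asarray(probs, dtype=np.float64)
    out = np.empty(probs.shape, dtype=np.int64)
    for i, (n, p) in enumerate(zip(counts, probs)):
        out[i] = rng.multinomial(n, p)  # counts[i] draws from distribution i
    return out

# Hypothetical example data: 4 distributions over 50 categories
counts = np.array([40, 50, 26, 19])
probs = np.random.default_rng(0).uniform(size=(4, 50))
probs /= probs.sum(axis=1, keepdims=True)

A = sample_rows_reference(counts, probs, seed=1)
print(A.shape)                                # (4, 50)
print(np.array_equal(A.sum(axis=1), counts))  # True
```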

Here is my TensorFlow code sample:

import numpy as np
import tensorflow as tf
import tensorflow.contrib.distributions as ds
import time

nb_distribution = 100 # number of probability distributions

counts = np.random.randint(2000, 3500, size=nb_distribution) # number of draws per distribution: vector of 100 ints in [2000, 3500)
# print(counts[:40]) # should match the output of print(np.sum(res, 1)[:40]) inside the tf.Session()

# probsn is a matrix of probability:
# each row of probsn contains a vector of size 30 that sums to 1
probsn = np.random.uniform(size=(nb_distribution, 30))
probsn /= np.sum(probsn, axis=1)[:, None]

counts = tf.Variable(counts, dtype=tf.float32)
probs = tf.Variable(tf.convert_to_tensor(probsn.astype(np.float32)))

# sample from the multinomial
dist = ds.Multinomial(total_count=counts, probs=probs)
out = dist.sample()

start = time.time()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    res = sess.run(out)
    # print(np.sum(res, 1)[:40])
print(time.time() - start)

elapsed time: 0.12 seconds

My equivalent code in Theano:

import numpy as np
import time
import theano
from theano.tensor import _shared

nb_distribution = 100 # number of probability distributions

counts = np.random.randint(2000, 3500, size=nb_distribution) # number of draws per distribution: vector of 100 ints in [2000, 3500)
# print(counts[:40]) # should match the output of print(np.sum(v_sample(), 1)[:40])

counts = _shared(counts) # wrap the counts in a Theano shared variable

# probsn is a matrix of probability:
# each row of probsn contains a vector that sums to 1
probsn = np.random.uniform(size=(nb_distribution, 30)) 
probsn /= np.sum(probsn, axis=1)[:, None]
probsn = _shared(probsn)

from theano.tensor.shared_randomstreams import RandomStreams

np_rng = np.random.RandomState(12345)
theano_rng = RandomStreams(np_rng.randint(2 ** 30))

v_sample = theano.function(inputs=[], outputs=theano_rng.multinomial(n=counts, pvals=probsn))

start_t = time.time()
out = np.sum(v_sample(), 1)[:40]
# print(out)
print(time.time() - start_t)

elapsed time: 0.0025 seconds

Theano is about 50x faster (0.0025 s vs. 0.12 s)... Is there something wrong with my TensorFlow code? How can I sample from a multinomial distribution efficiently in TensorFlow?
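A caveat about the comparison above: the two timings are not measured the same way. The TensorFlow timer starts before `tf.Session()` is created, so it also counts session creation, variable initialization and the first, setup-heavy `sess.run`, while the Theano timer starts only after `theano.function` has compiled the graph. A minimal helper that discards warm-up calls and reports the median of several runs could look like this (`time_call` is a hypothetical name, not part of either library):

```python
import statistics
import time

def time_call(fn, warmup=1, repeats=5):
    """Median wall-clock seconds of fn(), measured after `warmup` untimed calls."""
    for _ in range(warmup):
        fn()  # absorbs one-time costs (setup, compilation, caches)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

Inside the `with tf.Session() as sess:` block one would call `time_call(lambda: sess.run(out))`, and for Theano `time_call(v_sample)`, so both numbers exclude one-time setup.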


Solution

  • The problem is that the TensorFlow Multinomial sample() method actually calls the method _sample_n(), which is defined in the tf.contrib Multinomial source. Looking at that code, for each row it draws the category indices with tf.multinomial, builds a one-hot matrix of shape (n_draws, k) from them, and then reduces that matrix to a count vector by summing over the draws:

    math_ops.reduce_sum(array_ops.one_hot(x, depth=k), axis=-2)

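    This is inefficient because it materializes an (n_draws, k) one-hot matrix for every row (here roughly 3500 x 30 floats) only to count how often each category was drawn. To avoid this, I used tf.scatter_nd, which sums the updates for duplicate indices, so scattering a 1 for every draw yields the counts directly. Here is a fully runnable example, starting with the ds.Multinomial baseline: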

    import tensorflow as tf
    import numpy as np
    import tensorflow.contrib.distributions as ds
    import time
    
    tf.reset_default_graph()
    
    nb_distribution = 100 # number of probability distributions
    
    u = np.random.randint(2000, 3500, size=nb_distribution) # define number of counts (vector of size 100 with int in 2000, 3500)
    
    # probsn is a matrix of probability:
    # each row of probsn contains a vector of size 30 that sums to 1
    probsn = np.random.uniform(size=(nb_distribution, 30))
    probsn /= np.sum(probsn, axis=1)[:, None]
    
    counts = tf.Variable(u, dtype=tf.float32)
    probs = tf.Variable(tf.convert_to_tensor(probsn.astype(np.float32)))
    
    # sample from the multinomial
    dist = ds.Multinomial(total_count=counts, probs=probs)
    out = dist.sample()
    
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        res = sess.run(out) # warm-up run: without it, the timed run below is slower (first-run setup cost)
        start = time.time()
        res = sess.run(out)
        print(time.time() - start)
        print(np.all(u == np.sum(res, axis=1)))
    

    This code took 0.05 seconds to compute

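    def vmultinomial_sampling(counts, pvals, seed=None):
        """Draw counts[i] samples from the categorical distribution pvals[i]
        and return an (N, k) float32 matrix of per-category counts."""
        counts = tf.cast(counts, tf.int32)  # tf.multinomial expects an int32 num_samples
        pvals = tf.cast(pvals, tf.float32)
        k = tf.shape(pvals)[1]
        logits = tf.expand_dims(tf.log(pvals), 1)  # (N, 1, k): one [1, k] logits row per distribution
    
        def sample_single(args):
            logits_, n_draw_ = args[0], args[1]
            x = tf.multinomial(logits_, n_draw_, seed=seed)  # [1, n_draw_] category indices
            indices = tf.cast(tf.reshape(x, [-1, 1]), tf.int32)
            # one update per draw (equivalently tf.shape(indices)[:1]);
            # scatter_nd sums duplicate indices, which yields the counts
            updates = tf.ones([n_draw_])
            return tf.scatter_nd(indices, updates, [k])
    
        # map_fn still processes the rows one at a time
        x = tf.map_fn(sample_single, [logits, counts], dtype=tf.float32)
    
        return x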
    
    xx = vmultinomial_sampling(u, probsn)
    # check = tf.expand_dims(counts, 1) * probs
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        res = sess.run(xx) # warm-up run: without it, the timed run below is slower (first-run setup cost)
        start_t = time.time()
        res = sess.run(xx)
        print(time.time() - start_t)
        #print(np.sum(res, axis=1))
        print(np.all(u == np.sum(res, axis=1)))
    

    This code took 0.016 seconds
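The difference between the two counting strategies (one_hot + reduce_sum in _sample_n versus scatter_nd with duplicate indices in vmultinomial_sampling) can be illustrated in NumPy. This is only an analogy: `np.add.at` stands in for scatter_nd's accumulation, and the helper names are hypothetical. Both give identical counts, but the one-hot route first builds the full (n_draws, k) matrix.

```python
import numpy as np

def counts_via_one_hot(indices, k):
    # Analogue of reduce_sum(one_hot(x, depth=k), axis=-2): an (n_draws, k) matrix first
    one_hot = np.eye(k, dtype=np.float32)[indices]
    return one_hot.sum(axis=0), one_hot.nbytes

def counts_via_scatter_add(indices, k):
    # Analogue of scatter_nd with duplicate indices: add 1 per draw into k bins
    out = np.zeros(k, dtype=np.float32)
    np.add.at(out, indices, 1.0)
    return out

rng = np.random.default_rng(0)
k = 30
indices = rng.integers(0, k, size=3500)  # 3500 draws of category indices

dense, dense_bytes = counts_via_one_hot(indices, k)
sparse = counts_via_scatter_add(indices, k)
print(np.array_equal(dense, sparse))  # True
print(dense_bytes)                    # 420000 (3500 * 30 * 4 bytes of one-hot data)
```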

    The drawback is that my code doesn't actually parallelize the computation across rows (map_fn's parallel_iterations parameter defaults to 10, yet setting it to 1 doesn't change the timing).

    Maybe someone will find something better, because it is still very slow compared to Theano's implementation. It does not take advantage of parallelization, even though parallelization makes sense here: sampling one row is independent of sampling any other.
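One way to remove the per-row loop, which is not what the answer above does, is the conditional-binomial decomposition of the multinomial: category j's count is drawn from Binomial(remaining draws, p_j / remaining probability mass), for all rows at once. The loop then runs over the k = 30 categories instead of the N rows. The sketch below stays in NumPy; porting it to TensorFlow would need a batched binomial sampler, and I have not checked which TensorFlow 1.x releases provide one. `multinomial_rows_binomial` is a hypothetical name.

```python
import numpy as np

def multinomial_rows_binomial(counts, probs, rng=None):
    """Sample row i from Multinomial(counts[i], probs[i]) for all rows at once.

    Conditional-binomial decomposition: the loop runs over the k categories,
    and each step is vectorized across the N rows.
    """
    rng = np.random.default_rng() if rng is None else rng
    counts = np.asarray(counts, dtype=np.int64)
    probs = np.asarray(probs, dtype=np.float64)
    n_rows, k = probs.shape
    out = np.zeros((n_rows, k), dtype=np.int64)
    remaining_n = counts.copy()    # draws not yet assigned, per row
    remaining_p = np.ones(n_rows)  # probability mass not yet assigned, per row
    for j in range(k - 1):
        # P(category j | not one of categories 0..j-1), guarded against rounding
        p_cond = np.clip(probs[:, j] / np.maximum(remaining_p, 1e-12), 0.0, 1.0)
        out[:, j] = rng.binomial(remaining_n, p_cond)
        remaining_n -= out[:, j]
        remaining_p -= probs[:, j]
    out[:, -1] = remaining_n       # the last category takes whatever is left
    return out

rng = np.random.default_rng(12345)
counts = rng.integers(2000, 3500, size=100)
probs = rng.uniform(size=(100, 30))
probs /= probs.sum(axis=1, keepdims=True)

A = multinomial_rows_binomial(counts, probs, rng)
print(A.shape)                                # (100, 30)
print(np.array_equal(A.sum(axis=1), counts))  # True
```

Each binomial call covers all N rows, so the Python-level loop length no longer grows with the number of distributions.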