I have an array params with shape (N, 2), where each row holds the parameters (alpha, beta) of a Beta distribution. I need to draw one sample from each of these distributions; is there a way to do this without looping in Python? I know I can do this:
import numpy as np
import tensorflow_probability as tfp

u = [tfp.distributions.Beta(a, b).sample() for (a, b) in params]
# Or
u = [np.random.beta(a, b) for (a, b) in params]
but I imagine this would be very slow for large N, since each sample comes from a separate Python-level call. I am using tensorflow, tensorflow_probability and numpy, btw.
EDIT:
Thanks to JST99's answer, I found out that it is also possible to pass vectors of parameters to tfp.distributions.Beta:
u = tfp.distributions.Beta(params[:, 0], params[:, 1]).sample()
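You can separate out the alpha and beta parameters using zip, then call np.random.beta once in a vectorized fashion: it accepts array-like arguments, broadcasts them against each other, and draws one sample per (alpha, beta) pair.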
import numpy as np

alpha, beta = zip(*params)
u = np.random.beta(alpha, beta)
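Since params is already an (N, 2) array, the zip step can also be replaced by column slicing. The sketch below uses NumPy's Generator API (np.random.default_rng), which NumPy's documentation recommends for new code; the concrete parameter values and the seed are made up for illustration.

```python
import numpy as np

# Hypothetical example parameters: an (N, 2) array of (alpha, beta) pairs.
params = np.array([[1.0, 1.0], [2.0, 5.0], [0.5, 0.5], [10.0, 3.0]])

# Seeding is optional; it is only here to make the run reproducible.
rng = np.random.default_rng(seed=0)

# Column slicing yields two length-N arrays; Generator.beta broadcasts them
# and draws one sample per (alpha, beta) row in a single call.
u = rng.beta(params[:, 0], params[:, 1])
print(u.shape)  # (4,)
```

The result has shape (N,), with u[i] drawn from Beta(params[i, 0], params[i, 1]).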
Here is an example with some dummy parameters.
>>> params = [(i, i) for i in range(1, 100)]
>>> alpha, beta = zip(*params)
>>> u = np.random.beta(alpha, beta)
>>> u
array([0.88027947, 0.22507079, 0.46668932, 0.65091097, 0.62278597,
0.37450051, 0.53237829, 0.5589561 , 0.55190015, 0.64352003,
0.61396155, 0.58559066, 0.59525124, 0.49827492, 0.45065234,
0.53716919, 0.52950708, 0.41751582, 0.44912503, 0.57043946,
0.51909876, 0.34834858, 0.55753122, 0.41586101, 0.46762533,
0.4905744 , 0.53927006, 0.5234163 , 0.56215437, 0.38265575,
0.4940874 , 0.45066854, 0.53654453, 0.40955841, 0.49478651,
0.52974175, 0.43218663, 0.49791192, 0.47176042, 0.46717939,
0.45576387, 0.58941562, 0.44112651, 0.45401485, 0.48990107,
0.5640564 , 0.46720441, 0.439157 , 0.56098725, 0.43914691,
0.44654769, 0.5639682 , 0.41962566, 0.53689739, 0.46501042,
0.52775508, 0.55992535, 0.4948104 , 0.54856768, 0.4711496 ,
0.44694159, 0.54769584, 0.51792418, 0.48669042, 0.51969972,
0.51599904, 0.4818758 , 0.47555456, 0.47581746, 0.43417686,
0.49156854, 0.51359563, 0.52830314, 0.50988281, 0.47357164,
0.47619267, 0.52755645, 0.50141785, 0.48280575, 0.47817313,
0.47954096, 0.53885494, 0.5218641 , 0.50253071, 0.58804552,
0.50788384, 0.49429312, 0.47677202, 0.45542669, 0.47169082,
0.58838068, 0.4992328 , 0.5098452 , 0.44761298, 0.45971338,
0.4841432 , 0.47673295, 0.48205439, 0.4799415 ])
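If you need several draws from each distribution, the same broadcasting works with an explicit size. This is a sketch assuming the Generator API again: with alpha and beta of shape (N,), size=(k, N) makes column j hold k draws from Beta(alpha[j], beta[j]). The value of k and the seed are arbitrary choices; the column means are compared with the Beta mean alpha / (alpha + beta) as a rough sanity check.

```python
import numpy as np

# Same dummy parameters as the example above: 99 pairs (i, i).
params = np.array([(i, i) for i in range(1, 100)], dtype=float)
alpha, beta = params[:, 0], params[:, 1]

rng = np.random.default_rng(seed=42)

# size=(k, N) broadcasts against the length-N alpha and beta arrays,
# so samples[:, j] contains k draws from Beta(alpha[j], beta[j]).
k = 10_000
samples = rng.beta(alpha, beta, size=(k, len(alpha)))
print(samples.shape)  # (10000, 99)

# Each column mean should be close to alpha / (alpha + beta), i.e. 0.5 here.
col_means = samples.mean(axis=0)
print(np.abs(col_means - alpha / (alpha + beta)).max())
```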