I'm attempting to select a single sample from a range of Normal distributions based upon the output of a categorical distribution, but I can't seem to come up with quite the right way to do it. Using something along the lines of:
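import tensorflow_probability as tfp

tfp.distributions.JointDistributionSequential([
    tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
    # Slicing the 4-batch of Normals with c works only when c is a scalar index.
    lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
])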
This returns exactly what I want for the single case, but if I want multiple samples at once it breaks (as c becomes an array of indices rather than a single integer). Is this possible, and if so, how should I go about it?
(I also attempted using OneHotCategorical and multiplying but that didn't work at all!)
If you don't want to use MixtureSameFamily as Brian suggests, you could do this:
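import tensorflow as tf
import tensorflow_probability as tfp

tfp.distributions.JointDistributionSequential([
    tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
    # tf.gather picks one loc per sampled index, whatever the shape of c.
    lambda c: tfp.distributions.Normal(tf.gather([0., 1, -10, 30], c), 1)
])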
Note that I needed to add a . to the locs in the gather (0. instead of 0) to avoid a dtype error: without it, the gathered locs come out as int32 rather than float32.
Here, what we end up doing is:
- drawing n samples from the Categorical;
- building n Normals, whose locs are obtained by indexing n times into the 4-vector of locs;
- sampling once from that n-batch of Normals.

The previous approach doesn't work because Distribution slicing doesn't support this kind of "fancy indexing". It would be cool if we did! TF doesn't support it in general either: tensors can't be indexed with an integer array the way NumPy arrays can, which is why tf.gather is used instead.