python, tensorflow, data-science, probability, tensorflow-probability

Selecting one normal from a tensor based on another random variable in TensorFlow Probability


I'm attempting to select a single sample from a range of Normal distributions based on the output of a Categorical distribution, but I can't seem to come up with quite the right way to do it. Using something along the lines of:

tfp.distributions.JointDistributionSequential([
    tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
    lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
])

This returns exactly what I want for the single-sample case; however, if I want multiple samples at once it breaks (as c becomes a NumPy array rather than an integer). Is this possible, and if so, how should I go about it?

(I also attempted using OneHotCategorical and multiplying but that didn't work at all!)
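
For concreteness, the failure only appears once I ask for more than one sample at a time; a minimal sketch (binding the joint above to a name, jd, purely for illustration):

    import tensorflow_probability as tfp

    # the same joint as above, bound to a name for illustration
    jd = tfp.distributions.JointDistributionSequential([
        tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
        lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
    ])

    jd.sample()   # single draw: c is a scalar, so the [..., c] slice picks one Normal
    jd.sample(5)  # five draws: c is now a whole vector rather than a scalar, and the slice fails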


Solution

  • You could do this, if you don't want to use MixtureSameFamily as Brian suggests (that alternative is sketched at the end of this answer):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfp.distributions.JointDistributionSequential([
        tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
        # one loc per draw of c, so the Normal batch matches the sample shape
        lambda c: tfp.distributions.Normal(tf.gather([0., 1, -10, 30], c), 1)
    ])
    

    Note I needed to add a . to the locs in the gather to avoid a dtype error.
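
    That error comes from dtypes: tf.gather over an all-integer list yields int32 locs, while the "." promotes the whole list to float32 (a small illustration; the exact error text depends on your TF/TFP version):

        import tensorflow as tf

        # without the ".", the gathered locs are int32, which is what tripped the dtype error
        tf.gather([0, 1, -10, 30], [2, 3])   # dtype: int32
        # with the ".", the list is promoted to float32, which Normal accepts as a loc
        tf.gather([0., 1, -10, 30], [2, 3])  # dtype: float32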

    Here, what we end up doing is:

    1. drawing n samples from the Categorical
    2. constructing a batch of n Normals, whose locs are obtained by indexing n times into the 4-vector of locs
    3. sampling from that n-batch of Normals.
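
    In code, that looks something like the following sketch (jd is just a hypothetical name for the joint above):

        import tensorflow as tf
        import tensorflow_probability as tfp

        jd = tfp.distributions.JointDistributionSequential([
            tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
            lambda c: tfp.distributions.Normal(tf.gather([0., 1, -10, 30], c), 1)
        ])

        c, x = jd.sample(5)  # c: shape [5] component indices, x: shape [5] normal draws
        # each x[i] is drawn from Normal(loc=[0., 1, -10, 30][c[i]], scale=1)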

    The previous approach doesn't work because Distribution slicing doesn't support this kind of "fancy indexing" (indexing with a tensor of indices). It would be cool if it did! TF doesn't support it in general either, for its own reasons.
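
    For completeness, here is a minimal sketch of the MixtureSameFamily route mentioned at the top (my reading of that suggestion, not necessarily Brian's exact construction); note that it gives you only the mixed draw x, not the (c, x) pair:

        import tensorflow_probability as tfp

        mix = tfp.distributions.MixtureSameFamily(
            mixture_distribution=tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
            components_distribution=tfp.distributions.Normal(loc=[0., 1, -10, 30], scale=1.)
        )

        mix.sample(5)  # shape [5]; each draw comes from one of the four Normals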