
What does a tensor argument to Normal mean in PyTorch's distributions package?


I understand that torch.distributions.Normal(loc, scale) is the class for a univariate normal distribution in PyTorch, and I understand how it works when loc and scale are numbers. What I do not understand is what happens when the arguments are tensors instead. What is the exact interpretation of such tensor arguments? See, for example, y_dist in the code below, where loc and scale are both tensors. I do not think this turns the univariate distribution into a multivariate one, does it? Does it instead form a group of univariate distributions?

import torch as pt
ptd = pt.distributions
x_dist = ptd.Normal(loc = 2, scale = 3)
x_samples = x_dist.sample()

batch_size = 256
y_dist = ptd.Normal(loc = 0.25 * pt.ones(batch_size, dtype=pt.float32), scale = pt.ones(batch_size, dtype=pt.float32))

Solution

  • As you said, if loc (a.k.a. mu) and scale (a.k.a. sigma) are floats, then it will sample from a single normal distribution with loc as the mean and scale as the standard deviation.

    Providing tensors instead of floats makes it sample from several univariate normal distributions independently, one per tensor element (unlike torch.distributions.MultivariateNormal, which models a single joint distribution over a vector).

    If you look at the source code, you will see that loc and scale are broadcast to a common shape in __init__.


    Here's an example to show this behavior:

    >>> import torch
    >>> from torch.distributions import Normal
    >>> mu = torch.tensor([-10, 10], dtype=torch.float)
    >>> sigma = torch.ones(2, 2)
    >>> y_dist = Normal(loc=mu, scale=sigma)
    

    Above, mu is 1D with shape (2,) while sigma is 2D with shape (2, 2), yet loc ends up broadcast to shape (2, 2):

    >>> y_dist.loc
    tensor([[-10.,  10.],
            [-10.,  10.]])
    

    So each call to sample() draws four independent values: the two in the first column come from N(-10, 1) and the two in the second column come from N(10, 1).

    >>> y_dist.sample()
    tensor([[ -9.1686,  10.6062],
            [-10.0974,   8.5439]])
    

    Similarly:

    >>> mu = torch.zeros(2, 2)
    >>> sigma = torch.tensor([0.001, 1000], dtype=torch.float)
    >>> y_dist = Normal(loc=mu, scale=sigma)
    

    Here scale is broadcast to shape (2, 2):

    >>> y_dist.scale
    tensor([[1.0000e-03, 1.0000e+03],
            [1.0000e-03, 1.0000e+03]])
    
    >>> y_dist.sample()
    tensor([[-8.0329e-04,  1.4213e+01],
            [-1.4907e-03,  3.1190e+02]])
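
To tie this back to the y_dist from the question, the distinction between a batch of univariate normals and a single multivariate normal can be checked through the batch_shape and event_shape attributes. The following is a minimal sketch, not the only way to do it; wrapping the batch in torch.distributions.Independent and using a diagonal covariance for the comparison are choices made here for illustration, not something the original answer does.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

batch_size = 256
loc = 0.25 * torch.ones(batch_size)
scale = torch.ones(batch_size)

# Tensor arguments create a batch of independent univariate normals.
y_dist = Normal(loc=loc, scale=scale)
batch_shape = y_dist.batch_shape    # torch.Size([256]): 256 separate distributions
event_shape = y_dist.event_shape    # torch.Size([]): each one is scalar-valued

sample = y_dist.sample()            # one draw per distribution, shape (256,)
log_prob = y_dist.log_prob(sample)  # one log-density per distribution, shape (256,)

# Independent reinterprets the batch dimension as an event dimension,
# so the 256 normals are treated as one 256-dimensional distribution.
joint = Independent(y_dist, 1)
joint_log_prob = joint.log_prob(sample)  # a single scalar: the sum of log_prob

# A MultivariateNormal with a diagonal covariance has the same shapes as `joint`.
mvn = MultivariateNormal(loc, covariance_matrix=torch.diag(scale ** 2))
mvn_shapes = (mvn.batch_shape, mvn.event_shape)  # (torch.Size([]), torch.Size([256]))
```

With plain Normal, log_prob returns one value per element, which confirms that the tensor arguments describe a group of univariate distributions rather than one multivariate distribution. Only after wrapping with Independent (or switching to MultivariateNormal) does log_prob return a single joint density.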