I understand that torch.distributions.Normal(loc, scale) is the class for the univariate normal distribution in PyTorch, and I understand how it works when loc and scale are numbers.
The problem is when the inputs to Normal are tensors rather than numbers; I do not understand that case well. What is the exact interpretation/usage of such tensor arguments?
See for example y_dist in the code below, where loc and scale are tensors. What does this mean exactly? I do not think this converts the univariate distribution to a multivariate one, does it? Does it instead form a group of univariate distributions?
import torch as pt
ptd = pt.distributions
x_dist = ptd.Normal(loc = 2, scale = 3)
x_samples = x_dist.sample()
batch_size = 256
y_dist = ptd.Normal(loc=0.25 * pt.ones(batch_size, dtype=pt.float32),
                    scale=pt.ones(batch_size, dtype=pt.float32))
As you said, if loc (a.k.a. mu) and scale (a.k.a. sigma) are floats, then it will sample from a normal distribution with loc as the mean and scale as the standard deviation.
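You can confirm this via the distribution's mean and stddev properties (a quick sanity check; note that scalar arguments are converted to tensors internally):
>>> import torch
>>> from torch.distributions import Normal
>>> x_dist = Normal(loc=2.0, scale=3.0)
>>> x_dist.mean, x_dist.stddev
(tensor(2.), tensor(3.))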
Providing tensors instead of floats makes it sample independently from more than one normal distribution (unlike torch.distributions.MultivariateNormal, of course): you get a batch of independent univariate normals, one per element, rather than a single multivariate distribution.
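You can see the batching directly in the distribution's shapes (these are shape checks only, so the outputs below are deterministic):
>>> d = Normal(loc=torch.tensor([0.0, 5.0]), scale=torch.tensor([1.0, 2.0]))
>>> d.batch_shape
torch.Size([2])
>>> d.sample().shape      # one draw from each of the two distributions
torch.Size([2])
>>> d.sample((3,)).shape  # three independent draws from each
torch.Size([3, 2])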
If you look at the source code you will see that loc and scale are broadcast to the same shape in __init__.
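Concretely, __init__ does this with torch.distributions.utils.broadcast_all (same semantics as torch.broadcast_tensors, at least in recent PyTorch versions), which you can invoke directly:
>>> from torch.distributions.utils import broadcast_all
>>> loc, scale = broadcast_all(torch.tensor([-10.0, 10.0]), torch.ones(2, 2))
>>> loc
tensor([[-10.,  10.],
        [-10.,  10.]])
>>> scale.shape
torch.Size([2, 2])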
Here's an example to show this behavior:
>>> mu = torch.tensor([-10, 10], dtype=torch.float)
>>> sigma = torch.ones(2, 2)
>>> y_dist = Normal(loc=mu, scale=sigma)
Above, mu is 1D while sigma is 2D, yet:
>>> y_dist.loc
tensor([[-10., 10.],
        [-10., 10.]])
So each call to sample() draws two values from N(-10, 1) (the first column) and two from N(10, 1) (the second column):
>>> y_dist.sample()
tensor([[ -9.1686, 10.6062],
        [-10.0974, 8.5439]])
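log_prob is applied elementwise over the batch in the same way; evaluating each distribution at its own mean gives -0.5 * log(2*pi) ≈ -0.9189 in every entry, since every scale here is 1:
>>> y_dist.log_prob(y_dist.loc)
tensor([[-0.9189, -0.9189],
        [-0.9189, -0.9189]])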
Similarly:
>>> mu = torch.zeros(2, 2)
>>> sigma = torch.tensor([0.001, 1000], dtype=torch.float)
>>> y_dist = Normal(loc=mu, scale=sigma)
This will broadcast scale to be:
>>> y_dist.scale
tensor([[1.0000e-03, 1.0000e+03],
        [1.0000e-03, 1.0000e+03]])
>>> y_dist.sample()
tensor([[-8.0329e-04, 1.4213e+01],
        [-1.4907e-03, 3.1190e+02]])
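Finally, on the multivariate question: a batched Normal never becomes multivariate on its own, but if you do want to treat the batch as one joint (diagonal-covariance) distribution, torch.distributions.Independent reinterprets batch dimensions as event dimensions:
>>> from torch.distributions import Independent
>>> joint = Independent(Normal(loc=torch.zeros(3), scale=torch.ones(3)), 1)
>>> joint.batch_shape, joint.event_shape
(torch.Size([]), torch.Size([3]))
>>> joint.log_prob(torch.zeros(3))  # one joint log-density, summed over the 3 dimensions
tensor(-2.7568)
This gives the same numbers as a MultivariateNormal with a diagonal covariance, without ever materializing the covariance matrix.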