
How to make a Truncated normal distribution in pytorch?


I want to create a truncated normal distribution (that is, a Gaussian distribution restricted to a range) in PyTorch.
I want to be able to change the mean, std, and range.
Is there a PyTorch method for that?


Solution

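  • Use torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0). It fills tensor in place, and mean, std, and the bounds a and b are all adjustable.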

    Description from the PyTorch documentation:

    Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution $\mathcal{N}(\text{mean}, \text{std}^2)$ with values outside $[a, b]$ redrawn until they are within the bounds. The method used for generating the random values works best when $a \leq \text{mean} \leq b$.
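Because trunc_normal_ fills an existing tensor rather than returning a distribution object, a small wrapper makes it easy to sample with any mean, std, and range. This is a minimal sketch; the helper name sample_truncated_normal is hypothetical. Note that a and b are absolute bounds, not multiples of std.

```python
import torch

def sample_truncated_normal(shape, mean, std, a, b):
    """Hypothetical helper: return a new tensor of the given shape filled with
    samples from N(mean, std^2) truncated to [a, b]."""
    out = torch.empty(shape)  # trunc_normal_ needs a floating-point tensor to fill
    torch.nn.init.trunc_normal_(out, mean=mean, std=std, a=a, b=b)
    return out

torch.manual_seed(0)
# Mean 0.5, std 0.2, range [0, 1]; a and b are absolute values, not in units of std.
x = sample_truncated_normal((10_000,), mean=0.5, std=0.2, a=0.0, b=1.0)
print(x.min().item() >= 0.0, x.max().item() <= 1.0)  # True True
print(x.mean().item())  # close to 0.5, since [0, 1] is symmetric about the mean
```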
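The description says out-of-range values are "redrawn until they are within the bounds." To make those semantics concrete, here is a naive rejection-sampling sketch. It illustrates the idea only and is not necessarily how PyTorch implements trunc_normal_ internally; the function name truncated_normal_rejection and its max_rounds cap are hypothetical.

```python
import torch

def truncated_normal_rejection(n, mean, std, a, b, max_rounds=1000):
    """Hypothetical sketch: draw n samples from N(mean, std^2) restricted to
    [a, b] by redrawing every out-of-range value (plain rejection sampling)."""
    if not a < b:
        raise ValueError("need a < b")
    out = torch.empty(n)
    todo = torch.ones(n, dtype=torch.bool)  # positions still needing a valid sample
    for _ in range(max_rounds):
        k = int(todo.sum())
        if k == 0:
            return out
        draws = torch.randn(k) * std + mean        # fresh normal draws for pending slots
        ok = (draws >= a) & (draws <= b)           # keep only draws inside [a, b]
        idx = todo.nonzero(as_tuple=True)[0]       # indices of pending slots
        out[idx[ok]] = draws[ok]
        todo[idx[ok]] = False
    raise RuntimeError("acceptance rate too low; [a, b] may lie far in the tail")

torch.manual_seed(0)
x = truncated_normal_rejection(5000, mean=0.0, std=1.0, a=-1.0, b=2.0)
print(x.min().item(), x.max().item())  # both within [-1, 2]
```

Each round accepts a draw with probability $\Phi((b-\text{mean})/\text{std}) - \Phi((a-\text{mean})/\text{std})$, so this approach slows down sharply when $[a, b]$ sits far out in a tail of the distribution.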