Tags: python, pytorch, gradient-descent, autograd

Problem backpropagating through a sample from a Beta distribution in PyTorch


Say I have obtained some alphas and betas from a neural network, which I use as the parameters of a Beta distribution. I sample from the Beta distribution, compute a loss from the samples, and backpropagate through them. Is that possible? After sampling, I call .requires_grad_(True) on the samples and then compute the loss. This runs without errors, but the loss does not seem to converge. Is there another way to do this in PyTorch?
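A minimal sketch of the setup described above, with hypothetical fixed tensors standing in for the network's alpha and beta and a placeholder loss, shows why the loss cannot converge: the tensor returned by .sample() is already detached, so after .requires_grad_(True) the backward pass stops at the sample and never reaches the parameters.

```python
import torch
from torch.distributions import Beta

torch.manual_seed(0)

# Hypothetical stand-ins for the network's alpha/beta outputs
alpha = torch.full((4, 3), 2.0, requires_grad=True)
beta = torch.full((4, 3), 5.0, requires_grad=True)

dist = Beta(alpha, beta)

x = dist.sample()
print(x.requires_grad)  # False: the sample is detached from alpha and beta

x.requires_grad_(True)  # makes x a new leaf; it does not reconnect x to alpha/beta
loss = (x - 0.5).pow(2).mean()  # placeholder loss
loss.backward()

print(alpha.grad)  # None: no gradient reaches the distribution parameters
print(beta.grad)   # None
```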

Say I get the following variables from a neural network:

mu, sigma, pred = model(input)  # calling the module runs forward() plus any registered hooks

Here mu and sigma are both tensors of shape (batch_size x 30). I compute alpha and beta (also of shape (batch_size x 30)) from mu and sigma, and then sample from a Beta distribution as follows:

import torch

def sample_from_beta_distribution(alpha, beta, eps=1e-6):
    # Clamp alpha and beta to be strictly positive (Beta requires concentrations > 0)
    alpha_positive = torch.clamp(alpha, min=eps)
    beta_positive = torch.clamp(beta, min=eps)

    # Create a Beta distribution; it broadcasts over the batch dimension
    beta_dist = torch.distributions.Beta(alpha_positive, beta_positive)

    # Draw one sample per element: shape (batch_size, 30)
    # Note: .sample() does not track gradients, so the result is detached from alpha and beta
    samples = beta_dist.sample()

    return samples
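The question does not say how alpha and beta are computed from mu and sigma. One common choice, assumed here, is moment matching: for a Beta distribution with mean mu and variance sigma², alpha + beta = mu(1 - mu)/sigma² - 1, alpha = mu(alpha + beta) and beta = (1 - mu)(alpha + beta), which requires 0 < mu < 1 and sigma² < mu(1 - mu). The helper name beta_params_from_mean_std and the sigmoid-based squashing of the raw outputs are hypothetical.

```python
import torch

def beta_params_from_mean_std(mu, sigma, eps=1e-6):
    """Hypothetical helper: method-of-moments conversion from (mean, std) to (alpha, beta).

    Assumes 0 < mu < 1 and sigma**2 < mu * (1 - mu); otherwise no Beta
    distribution has these moments.
    """
    var = sigma.pow(2)
    # nu = alpha + beta, solved from var = mu * (1 - mu) / (nu + 1)
    nu = mu * (1 - mu) / var - 1
    # Guard against invalid (mu, sigma) pairs; the clamp has zero gradient where it is active
    nu = nu.clamp(min=eps)
    return mu * nu, (1 - mu) * nu

torch.manual_seed(0)
# Hypothetical raw network outputs of shape (batch_size, 30)
raw_mu = torch.randn(8, 30)
raw_sigma = torch.randn(8, 30)

mu = torch.sigmoid(raw_mu)  # mean in (0, 1)
# Keep the std strictly below sqrt(mu * (1 - mu)) so that a valid Beta exists
sigma = torch.sigmoid(raw_sigma) * torch.sqrt(mu * (1 - mu))

alpha, beta = beta_params_from_mean_std(mu, sigma)
dist = torch.distributions.Beta(alpha, beta)
print(torch.allclose(dist.mean, mu, atol=1e-5))
```

Parameterizing the raw outputs this way keeps alpha and beta positive by construction, so the eps clamp in sample_from_beta_distribution should rarely be active.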

I then take the samples (also of shape (batch_size x 30)), perform some operations on them, and compute the loss. I expected the gradient to propagate back through the samples to the network, but the loss does not seem to converge.

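Any leads would help. Note that this is not as simple as the reparameterization trick for the Normal distribution (sample = mu + sigma * noise), because a Beta sample is not a simple location-scale transform of parameter-free noise.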
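Even though Beta is not a location-scale family, it can still be sampled differentiably: if X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1) are independent, then X / (X + Y) ~ Beta(alpha, beta), and PyTorch's Gamma.rsample() provides gradients with respect to the concentration parameter. The sketch below uses that construction purely for illustration, with hypothetical parameter values; PyTorch's own Beta is built on its Dirichlet distribution, which supplies its own reparameterized gradient, rather than on this exact code.

```python
import torch
from torch.distributions import Beta, Gamma

torch.manual_seed(0)

# Hypothetical concentration parameters
alpha = torch.tensor([2.0, 1.5], requires_grad=True)
beta = torch.tensor([3.0, 4.0], requires_grad=True)

# X ~ Gamma(alpha, 1), Y ~ Gamma(beta, 1)  =>  X / (X + Y) ~ Beta(alpha, beta)
x = Gamma(alpha, torch.ones_like(alpha)).rsample()
y = Gamma(beta, torch.ones_like(beta)).rsample()
z = x / (x + y)

z.sum().backward()
print(alpha.grad)        # gradients w.r.t. the concentrations exist
print(Beta.has_rsample)  # True
```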


Solution

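  • Using .rsample() instead of .sample() does the trick. .sample() (inherited from Distribution) runs under torch.no_grad() and returns a tensor detached from alpha and beta, so calling .requires_grad_(True) on it only creates a new leaf and no gradient reaches the network. .rsample() draws a reparameterized sample and keeps the computational graph alive, so the loss gradient flows back through the samples to alpha and beta.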
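A sketch of the question's sampler with .sample() swapped for .rsample(), plus a check that gradients now reach alpha and beta. The random alpha/beta leaf tensors and the squared-error loss are hypothetical placeholders for the network outputs and the real loss.

```python
import torch
from torch.distributions import Beta

def sample_from_beta_distribution(alpha, beta, eps=1e-6):
    # Clamp to keep the concentrations strictly positive
    alpha_positive = torch.clamp(alpha, min=eps)
    beta_positive = torch.clamp(beta, min=eps)
    beta_dist = Beta(alpha_positive, beta_positive)
    # rsample() keeps the graph from alpha/beta to the samples
    return beta_dist.rsample()

torch.manual_seed(0)
batch_size = 38
# Hypothetical stand-ins for network outputs, shape (batch_size, 30)
alpha = torch.rand(batch_size, 30).add(1.0).requires_grad_()
beta = torch.rand(batch_size, 30).add(1.0).requires_grad_()

samples = sample_from_beta_distribution(alpha, beta)
loss = (samples - 0.3).pow(2).mean()  # placeholder loss
loss.backward()

print(samples.shape)          # torch.Size([38, 30])
print(samples.requires_grad)  # True
print(alpha.grad.shape)       # torch.Size([38, 30])
```

In the real model alpha and beta are non-leaf outputs of the network, so after loss.backward() the gradient ends up in the model's parameters rather than in alpha.grad.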