I'm learning about policy gradients and I'm having a hard time understanding how the gradient passes through a random operation. From here: It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through.
They have an example of the score function:
probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()
I tried to build my own example based on it:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
from tqdm import tqdm
softplus = torch.nn.Softplus()
class Model_RL(nn.Module):
    def __init__(self):
        super(Model_RL, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 2)
    def forward(self, x):
        x1 = self.fc1(x)
        x = torch.relu(x1)
        x2 = self.fc2(x)
        x = torch.relu(x2)
        x3 = softplus(self.fc3(x))
        return x3, x2, x1
# basic
net_RL = Model_RL()
features = torch.tensor([1.0])
x = torch.tensor([1.0])
y = torch.tensor(3.0)
baseline = 0
baseline_lr = 0.1
epochs = 3
opt_RL = optim.Adam(net_RL.parameters(), lr=1e-3)
losses = []
xs = []
for _ in tqdm(range(epochs)):
    out_RL = net_RL(x)
    mu, std = out_RL[0]
    dist = Normal(mu, std)
    print(dist)
    a = dist.sample()
    log_p = dist.log_prob(a)
    out = features * a
    reward = -torch.square((y - out))
    baseline = (1-baseline_lr)*baseline + baseline_lr*reward
    loss = -(reward-baseline)*log_p
    opt_RL.zero_grad()
    loss.backward()
    opt_RL.step()
    losses.append(loss.item())
This seems to work fine, which I don't understand: they say the gradient can't pass through the random operation, yet somehow it does.
Now, since the gradient can't flow through the random operation, I tried replacing
mu, std = out_RL[0]
with
mu, std = out_RL[0].detach()
and that caused the error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
If the gradient doesn't pass through the random operation, I don't understand why detaching a tensor before the operation would matter.
It is indeed true that sampling is not a differentiable operation per se. However, there are two broad ways to work around this: [1] the REINFORCE way and [2] the reparameterization way. Since your example uses [1], I will restrict my answer to REINFORCE.
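To make the two options concrete, here is a small pure-Python sketch (no autograd) that estimates the same gradient both ways. The numbers mu, sigma and y and the helper names are assumptions chosen for illustration, loosely mirroring the reward -(y - a)^2 from the question; it is a sanity check, not the way PyTorch implements either estimator.

```python
import random

random.seed(0)
mu, sigma, y = 1.0, 1.0, 3.0   # illustrative values, not taken from the question
n = 200_000

def reward(a):
    return -(y - a) ** 2

def d_reward(a):
    # derivative of reward with respect to a
    return 2.0 * (y - a)

score_est = 0.0
reparam_est = 0.0
for _ in range(n):
    eps = random.gauss(0.0, 1.0)
    a = mu + sigma * eps
    # [1] REINFORCE / score function: reward(a) * d/dmu log N(a | mu, sigma),
    #     with the sample a treated as fixed data.
    score_est += reward(a) * (a - mu) / sigma ** 2
    # [2] Reparameterization: differentiate reward(mu + sigma * eps) w.r.t. mu.
    reparam_est += d_reward(a)
score_est /= n
reparam_est /= n

# Closed form: E[reward] = -((y - mu)^2 + sigma^2), so d/dmu = 2 * (y - mu).
exact = 2.0 * (y - mu)
print(score_est, reparam_est, exact)
```

Both estimates should land close to the closed-form value, with the score-function one being noticeably noisier.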
REINFORCE removes the sampling operation from the computation graph entirely. Sampling still happens, but outside the graph. So, your statement
.. how does the gradient passes through a random operation ..
isn't correct. The gradient does not pass through any random operation. Let's look at your example:
mu, std = out_RL[0]
dist = Normal(mu, std)
a = dist.sample()
log_p = dist.log_prob(a)
The computation of the sample a does not create any computation graph. It is technically equivalent to plugging in some offline data from a dataset (as in supervised learning):
mu, std = out_RL[0]
dist = Normal(mu, std)
# a = dist.sample()
a = torch.tensor([1.23, 4.01, -1.2, ...], device='cuda')
log_p = dist.log_prob(a)
Since we don't have offline data beforehand, we create it on the fly, and the .sample() method does exactly that.
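This can be checked directly. The sketch below uses two hand-made scalar tensors as stand-ins for the network outputs (an assumption for illustration; in your code they come from net_RL(x)) and inspects what autograd knows about the sample and about its log-probability.

```python
import torch
from torch.distributions import Normal

# Hypothetical stand-ins for the network outputs mu and std.
mu = torch.tensor(0.5, requires_grad=True)
std = torch.tensor(1.2, requires_grad=True)
dist = Normal(mu, std)

a = dist.sample()          # drawn without gradient tracking: just data to autograd
log_p = dist.log_prob(a)   # deterministic function of mu and std, with a held fixed

print(a.requires_grad)             # False
print(log_p.grad_fn is not None)   # True

log_p.backward()
print(mu.grad, std.grad)   # both populated: the gradient reached mu and std via log_prob
```

The sample carries no graph, while log_p does, which is why the gradient reaches the network parameters without ever touching the sampling step.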
So, there is no random operation in the graph. log_p depends on mu and std deterministically, just like in any standard computation graph. If you cut the connection like this
mu, std = out_RL[0].detach()
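.. of course it is going to complain: once mu and std are detached, nothing that log_p depends on requires a gradient, so loss has no grad_fn and loss.backward() raises the RuntimeError you saw.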
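A minimal reproduction of that failure, assuming a hand-made tensor named out in place of out_RL[0]:

```python
import torch
from torch.distributions import Normal

# Hypothetical stand-in for out_RL[0]: a 2-element tensor that requires grad.
out = torch.tensor([0.5, 1.2], requires_grad=True)

mu, std = out.detach()     # mu and std are now cut off from the graph
dist = Normal(mu, std)
a = dist.sample()
log_p = dist.log_prob(a)
print(log_p.requires_grad)  # False: nothing upstream of log_p tracks gradients

try:
    log_p.backward()
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```

Removing the .detach() call restores the path from log_p back to out, and backward() runs without error.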
Also, do not get confused by this operation
dist = Normal(mu, std)
log_p = dist.log_prob(a)
as it does not contain any randomness by itself. This is merely a shortcut for writing out the tedious log-likelihood formula for the Normal distribution.
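As a quick check of that claim, the sketch below writes the Normal log-density by hand and compares it with log_prob. The values of mu, std and a are arbitrary assumptions for illustration.

```python
import math
import torch
from torch.distributions import Normal

mu = torch.tensor(0.5, requires_grad=True)
std = torch.tensor(1.2, requires_grad=True)
a = torch.tensor(1.23)  # any fixed value plays the role of the sample

# log N(a | mu, std) = -(a - mu)^2 / (2 * std^2) - log(std) - 0.5 * log(2 * pi)
manual = -((a - mu) ** 2) / (2 * std ** 2) - torch.log(std) - 0.5 * math.log(2 * math.pi)
builtin = Normal(mu, std).log_prob(a)

print(torch.allclose(manual, builtin))  # True
```

Both expressions are ordinary differentiable functions of mu and std, so either one could be used in the loss.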