Here is the forward() method of my PyTorch model:
def forward(self, x, output_type, *unused_args, **unused_kwargs):
    gru_output, gru_hn = self.gru(x)
    # Decoder (Graph Adjacency Reconstruction)
    for data_batch_idx in range(x.shape[0]):
        pred = self.decoder(gru_output[data_batch_idx, -1, :])  # gru_output[-1] => only take last time-step
        pred_graph_adj = pred.reshape(1, -1) if data_batch_idx == 0 else torch.cat((pred_graph_adj, pred.reshape(1, -1)), dim=0)
    if output_type == "discretize":
        bins = torch.tensor(self.model_cfg['output_bins']).reshape(-1, 1)
        num_bins = len(bins)-1
        bins = torch.concat((bins[:-1], bins[1:]), dim=1)
        discretize_values = np.linspace(2, 4, num_bins)
        for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
            pred_graph_adj = torch.where((pred_graph_adj <= upper) & (pred_graph_adj > lower), discretize_value, pred_graph_adj)
        pred_graph_adj = torch.where(pred_graph_adj < bins.min(), 2, pred_graph_adj)
    return pred_graph_adj
And here is the snippet of training:
pred = self.forward(x, output_type=self.model_cfg['output_type'])
batch_loss = self.loss_fn(pred, y)
self.optimizer.zero_grad()
batch_loss.backward()
self.optimizer.step()
self.scheduler.step()
When output_type is not "discretize" (not using torch.where), sum([p.grad.sum() for p in self.decoder.parameters()]) is non-zero.
When output_type is "discretize" (using torch.where), sum([p.grad.sum() for p in self.decoder.parameters()]) is zero.
I checked batch_loss, and it is not zero.
I checked requires_grad on the model's weights, and they are all True.
pred and batch_loss are connected to the model's weights.
My questions are:
Does torch.where cause the model's parameter gradients to become zero?
If torch.where does not cause that, what are other possible reasons?
Before torch.where, the values of pred_graph_adj are between -1 and 1.
After torch.where, the values of pred_graph_adj are between 2 and 4.
The arguments of torch.where (lower, upper, and discretize_value) in each for-loop iteration are, written as (lower, upper] -> discretize_value:
(-1, -0.25] -> 2
(-0.25, 0.25] -> 3
(0.25, 1] -> 4
OK, I think I have enough information about your issue to answer it at this point. Some preliminaries:
Your model outputs pass through a torch.where statement before being used to compute the loss.
The torch.where statement replaces some model outputs that fall outside of a certain range with some other values (lower, upper, discretize_values) which are generated independently of the model outputs.
Let's use some simpler variable names to avoid the confusion of in-place assignments. Let output be your model outputs, replace be the replacement values, and condition be the condition statement. The simplified expression is something like:
result = torch.where(condition, output, replace)
loss = f(result) # some arbitrary function
It is true that the partial derivative of loss w.r.t. output is not ALWAYS zero. However, torch.where is a simple element-wise selection, meaning that the partial derivative of result with respect to each of the two value inputs is 1 for all selected elements and zero for unselected elements. Since all of output is unselected, the partial of result with respect to output is zero. The partial of result with respect to replace is non-zero, but these elements are not a function of the model parameters or outputs (and furthermore are defined with a numpy statement, a dead giveaway that the function graph will not flow backwards past discretize_values).
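The selection behaviour described above can be checked directly with autograd. The following is a minimal sketch, not the asker's model: the tensor values, the squared-sum loss, and the helper name grad_of_output are all assumptions made for illustration.

```python
import torch

# Stand-in for the model outputs (hypothetical values); requires_grad lets us inspect gradients.
output = torch.tensor([-0.5, 0.1, 0.8], requires_grad=True)
# Replacement values created outside of autograd, like the np.linspace values in the question.
replace = torch.tensor([2.0, 3.0, 4.0])


def grad_of_output(condition):
    """Return d(loss)/d(output) for loss = sum(where(condition, output, replace) ** 2)."""
    result = torch.where(condition, output, replace)
    loss = (result ** 2).sum()
    (grad,) = torch.autograd.grad(loss, output)
    return grad


# Nothing is selected from `output`: every element is replaced.
grad_all_replaced = grad_of_output(torch.tensor([False, False, False]))
# Only the middle element is selected from `output`.
grad_one_kept = grad_of_output(torch.tensor([False, True, False]))

print(grad_all_replaced)  # zero everywhere
print(grad_one_kept)      # non-zero only at the selected position
```

When every element is replaced, the loss is non-zero but nothing flows back to output, which matches the symptoms in the question.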
So, while your problem is not structural (i.e. it is possible for the loss to be backpropagated to the model parameters), your initial conditions are such that your optimization routine is stuck in a region where all elements are always replaced and the gradient with respect to the model parameters is always zero.
But wait!! The conditional expression condition is itself a function of the model outputs output, so the gradient should still be backpropagated through it. While technically true, torch.where is not differentiable with respect to condition (think about the fact that any arbitrarily complex, horrendously non-differentiable expression can be used as condition provided it is of type bool, so this derivative would be noisy, messy, and often undefined). For this reason, torch.where is only differentiable with respect to the other inputs (i.e. output and replace here).
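To see that the condition carries no gradient information, one can inspect the tensor produced by the comparison. This is a small sketch with made-up values; the bin (0.25, 1] and the replacement value 4.0 are borrowed from the question only as an example.

```python
import torch

output = torch.tensor([-0.5, 0.1, 0.8], requires_grad=True)

# A comparison yields a bool tensor, which autograd does not track.
condition = (output > 0.25) & (output <= 1.0)

print(condition.dtype)          # torch.bool
print(condition.requires_grad)  # False
print(condition.grad_fn)        # None

# Where the condition holds, a constant is written; only the untouched
# elements remain connected to `output` in the graph.
result = torch.where(condition, torch.tensor(4.0), output)
result.sum().backward()
print(output.grad)  # zero at the replaced position, one elsewhere
```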
To solve your issue you could try random parameter initialization, or else choose not to replace values with torch.where at the beginning to allow for initial convergence, or come up with some other clever solution that depends more heavily on your actual use case.
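As one possible example of such a solution, which is an assumption on my part and may not suit your use case, a straight-through estimator keeps the discretized values in the forward pass while treating the discretization as the identity in the backward pass. The helper names discretize and discretize_straight_through below are hypothetical, and the bin edges and levels are taken from the question.

```python
import torch


def discretize(values, edges, levels):
    """Map values in (edges[i], edges[i+1]] to levels[i].

    Values that fall in no bin fall back to levels[0] (a simplification).
    """
    out = torch.full_like(values, levels[0])
    for lower, upper, level in zip(edges[:-1], edges[1:], levels):
        in_bin = (values > lower) & (values <= upper)
        out = torch.where(in_bin, torch.full_like(values, level), out)
    return out


def discretize_straight_through(values, edges, levels):
    """Forward: discretized values. Backward: gradient passes through unchanged."""
    hard = discretize(values.detach(), edges, levels)
    # The detached difference contributes to the value but not to the gradient.
    return values + (hard - values).detach()


# Hypothetical stand-in for the decoder output.
pred = torch.tensor([-0.5, 0.1, 0.8], requires_grad=True)
edges = [-1.0, -0.25, 0.25, 1.0]
levels = [2.0, 3.0, 4.0]

out = discretize_straight_through(pred, edges, levels)
out.sum().backward()

print(out)        # discretized values, one per bin
print(pred.grad)  # gradient of one per element, as if no discretization happened
```

The gradient this produces is a surrogate rather than the true derivative of the discretization, so whether it trains well has to be checked on the actual task.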
Hope this helps!