python machine-learning deep-learning pytorch

Does using torch.where cause the model's parameter gradients to become zero?


Here is the forward() method of my pytorch model:

    def forward(self, x, output_type, *unused_args, **unused_kwargs):
        gru_output, gru_hn = self.gru(x)
        # Decoder (Graph Adjacency Reconstruction)
        for data_batch_idx in range(x.shape[0]):
            pred = self.decoder(gru_output[data_batch_idx, -1, :])  # gru_output[-1] => only take last time-step
            pred_graph_adj = pred.reshape(1, -1) if data_batch_idx == 0 else torch.cat((pred_graph_adj, pred.reshape(1, -1)), dim=0)
        if output_type == "discretize":
            bins = torch.tensor(self.model_cfg['output_bins']).reshape(-1, 1)
            num_bins = len(bins)-1
            bins = torch.concat((bins[:-1], bins[1:]), dim=1)
            discretize_values = np.linspace(2, 4, num_bins)
            for lower, upper, discretize_value in zip(bins[:, 0], bins[:, 1], discretize_values):
                pred_graph_adj = torch.where((pred_graph_adj <= upper) & (pred_graph_adj > lower), discretize_value, pred_graph_adj)
            pred_graph_adj = torch.where(pred_graph_adj < bins.min(), 2, pred_graph_adj)

        return pred_graph_adj

And here is the snippet of training:

    pred = self.forward(x, output_type=self.model_cfg['output_type'])
    batch_loss = self.loss_fn(pred, y)
    self.optimizer.zero_grad()
    batch_loss.backward()
    self.optimizer.step()
    self.scheduler.step()
  1. When output_type is not "discretize" (torch.where is not used), sum([p.grad.sum() for p in self.decoder.parameters()]) is non-zero.
    • But when output_type is "discretize" (torch.where is used), sum([p.grad.sum() for p in self.decoder.parameters()]) is zero.
  2. I've checked batch_loss; it's not zero.
  3. I've checked requires_grad for all of the model's weights; they are all True.
  4. I've checked the computational graph; pred and batch_loss are connected to the model's weights.
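The symptom can be reproduced outside the model. This is a minimal sketch under assumptions: a hypothetical `nn.Linear` + `nn.Tanh` stand-in replaces `self.decoder` so the outputs lie in (-1, 1), the targets are arbitrary, and the bin edges and values come from the update below. It sums absolute gradient values rather than `p.grad.sum()` so positive and negative entries cannot cancel out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for self.decoder: tanh keeps outputs in (-1, 1).
decoder = nn.Sequential(nn.Linear(8, 4), nn.Tanh())
x = torch.randn(16, 8)
y = torch.full((16, 4), 3.0)  # arbitrary targets, chosen only so the loss is non-zero

# (lower, upper] -> value, as listed in the question's update.
bins = [(-1.0, -0.25, 2.0), (-0.25, 0.25, 3.0), (0.25, 1.0, 4.0)]


def decoder_grad_sum(discretize: bool) -> float:
    """Run one forward/backward pass and return the total |grad| of the decoder."""
    decoder.zero_grad()
    pred = decoder(x)
    if discretize:
        for lower, upper, value in bins:
            pred = torch.where((pred <= upper) & (pred > lower), value, pred)
        pred = torch.where(pred < -1.0, 2.0, pred)
    loss = nn.functional.mse_loss(pred, y)
    loss.backward()
    return sum(
        p.grad.abs().sum().item() if p.grad is not None else 0.0
        for p in decoder.parameters()
    )


print(decoder_grad_sum(discretize=False))  # non-zero
print(decoder_grad_sum(discretize=True))   # 0.0
```

The loss is non-zero in both cases, yet only the undiscretized path produces a gradient, matching points 1 and 2 above.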

My questions are:

  1. Does using torch.where cause the model's parameter gradients to become zero?
  2. If torch.where doesn't cause that, what other reasons are possible?

Update:

  • The initial values of pred_graph_adj are between -1 and 1.
    • But I've checked the value range of the final pred_graph_adj (after torch.where); it is between 2 and 4.
  • The specific values of the torch.where arguments (lower, upper, and discretize_value) in each iteration of the for-loop are:
(lower, upper] -> discretize_value:
(-1, -0.25] -> 2
(-0.25, 0.25] -> 3
(0.25, 1] -> 4
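For reference, the mapping in the table above can also be written as one vectorized lookup. This is a sketch, not the asker's code: it swaps the torch.where loop for `torch.bucketize` and assumes every value lies in (-1, 1]. It is no more differentiable than the loop, because the result is read from a constant table using integer indices.

```python
import torch

# Bin edges and per-bin values from the table: (lower, upper] -> value.
boundaries = torch.tensor([-1.0, -0.25, 0.25, 1.0])
values = torch.tensor([2.0, 3.0, 4.0])

pred = torch.tensor([-0.9, -0.25, 0.0, 0.3, 1.0], requires_grad=True)

# With right=False (the default), bucketize returns i such that
# boundaries[i-1] < v <= boundaries[i], i.e. the (lower, upper] intervals above.
idx = torch.bucketize(pred.detach(), boundaries) - 1
discrete = values[idx]
print(discrete)                # tensor([2., 2., 3., 4., 4.])
print(discrete.requires_grad)  # False: built only from a constant table and integer indices
```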

Solution

  • OK, I think I now have enough information about your issue to answer it. Some preliminaries:

    • You are confused about why the gradient of the loss with respect to your model parameters is 0.
    • The outputs from your model pass through a torch.where statement before being used to compute the loss.
    • That torch.where statement replaces model outputs that fall within each (lower, upper] bin (and, in the final call, outputs below the smallest bin edge) with constant values (discretize_values) that are generated independently of the model outputs.
    • Initially, the outputs of your model fall in a range such that ALL model outputs are replaced by these constants.

    To avoid confusion from repeatedly reassigning pred_graph_adj, let's use simpler variable names: let output be your model outputs, replace the replacement values, and condition the boolean condition. The simplified expression is something like:

    result = torch.where(condition, output, replace)
    loss = f(result) # some arbitrary function
    

    It is true that the partial derivative of loss w.r.t. output is not ALWAYS zero. However, torch.where is a simple element-wise selection, meaning that the partial derivative of result with respect to each of its inputs is 1 for selected elements and 0 for unselected elements. Since no element of output is selected, the partial derivative of result with respect to output is zero. The partial derivative of result with respect to replace is non-zero, but these elements are not a function of the model parameters or outputs (furthermore, they are defined with a NumPy call, a dead giveaway that the computational graph will not flow backwards past discretize_values).
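The selection-mask gradient described above can be checked directly. A small sketch using the simplified names from the answer (output, replace, condition), with made-up values:

```python
import torch

output = torch.tensor([0.5, -0.5, 0.1], requires_grad=True)
replace = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)

# Mixed case: only the first element of output is selected.
condition = torch.tensor([True, False, False])
torch.where(condition, output, replace).sum().backward()
print(output.grad)   # tensor([1., 0., 0.])
print(replace.grad)  # tensor([0., 1., 1.])

# The situation in the question: no element of output is selected.
output.grad = None
replace.grad = None
condition = torch.zeros(3, dtype=torch.bool)
torch.where(condition, output, replace).sum().backward()
print(output.grad)   # tensor([0., 0., 0.])
```

In the second case everything that reaches the loss comes from replace, so output (and therefore every parameter upstream of it) gets an all-zero gradient.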

    So, while your problem is not structural (i.e., it is possible for the loss to be backpropagated to the model parameters), your initial conditions are such that your optimization routine is stuck in a flat region where all elements are always replaced and the gradient with respect to the model parameters is always zero.

    But wait! The conditional expression condition is itself a function of the model outputs output, so shouldn't the gradient still be backpropagated through it? While technically true, torch.where is not differentiable with respect to condition (consider that any arbitrarily complex, horrendously non-differentiable expression can be used as condition, provided it is of type bool, so this derivative would be noisy, messy, and often undefined). For this reason, torch.where is only differentiable with respect to the other inputs (i.e., output and replace here).
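A short sketch of that point, with made-up values and assuming a PyTorch version whose torch.where accepts a Python scalar (as the question's code already relies on). The comparison produces a bool tensor with no gradient, so an element that only influences which branch is taken receives nothing:

```python
import torch

output = torch.tensor([0.1, 0.6], requires_grad=True)

# A comparison yields a bool tensor, which cannot carry gradients.
condition = output > 0.5
print(condition.dtype, condition.requires_grad)  # torch.bool False

# output[1] decides which branch is taken, yet it receives no gradient:
# torch.where only backpropagates into the selected value tensors.
result = torch.where(condition, 2.0, output)
result.sum().backward()
print(output.grad)  # tensor([1., 0.])
```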

    To solve your issue, you could try random parameter initialization, avoid replacing values with torch.where at the beginning of training to allow for initial convergence, or come up with some other clever solution that depends more heavily on your actual use case. Hope this helps!
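One commonly used technique of the "other clever solution" kind, which the answer itself does not propose, is a straight-through estimator (STE): the forward pass returns the hard bin values, while the backward pass treats the discretization as the identity so gradients still reach the decoder. The sketch below is a hypothetical helper (`discretize_ste` is not a PyTorch API); it assumes the bins from the question and, unlike the question's code, also folds values above 1 into the last bin.

```python
import torch

# Bin edges and per-bin values from the question's update: (lower, upper] -> value.
BOUNDARIES = torch.tensor([-1.0, -0.25, 0.25, 1.0])
VALUES = torch.tensor([2.0, 3.0, 4.0])


def discretize_ste(pred: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: hard binning forward, identity gradient backward."""
    # bucketize (right=False) gives i with BOUNDARIES[i-1] < v <= BOUNDARIES[i];
    # clamping folds values outside (-1, 1] into the first or last bin.
    idx = torch.bucketize(pred.detach(), BOUNDARIES).clamp(1, len(VALUES)) - 1
    hard = VALUES[idx]
    # The detached term is a constant, so the forward value equals `hard`
    # (up to float rounding) while d(result)/d(pred) is 1 everywhere.
    return pred + (hard - pred).detach()


pred = torch.tensor([-0.9, 0.0, 0.3], requires_grad=True)
out = discretize_ste(pred)
print(out.detach())  # values close to 2, 3, 4
out.sum().backward()
print(pred.grad)     # tensor([1., 1., 1.])
```

The trade-off is that the gradient is a deliberate approximation: it tells the decoder how the loss would change if the output were continuous, which may or may not suit your use case.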