Tags: python, pytorch, torch

Gradient Checkpointing returning values


I have a checkpoint callback function (i.e., custom_dec) that returns a tensor and a dictionary. But it seems that torch.utils.checkpoint only accepts tensor return values, not dictionaries (or other data types). What is the workaround, given that the module I want to checkpoint returns a tensor plus a dictionary?

def custom_dec(self, module):
    def custom_forward(*inputs):
        output = module(inputs[0], inputs[1],
                        encoder_attn_mask=inputs[2],
                        decoder_padding_mask=inputs[3],
                        layer_state=inputs[4],
                        causal_mask=inputs[5],
                        output_attentions=inputs[6],
                        )
        # output[2] is a Python dictionary
        return output[0], output[2]
    return custom_forward

The following is the checkpoint call:

x, layer_past = checkpoint.checkpoint(
    self.custom_dec(decoder_layer),
    x,
    encoder_hidden_states,
    encoder_padding_mask,
    decoder_padding_mask,
    layer_state,
    decoder_causal_mask,
    output_attentions,
)

The error:

TypeError: CheckpointFunctionBackward.forward: expected Variable (got dictionary) for return value 1
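
For context, here is a minimal standalone sketch of the same constraint. The function returns_dict is a toy stand-in, not the asker's code; on the older PyTorch versions this question targets, the reentrant checkpoint implementation requires every return value to be a tensor, so this raises the TypeError above (newer releases may behave differently):

import torch
from torch.utils import checkpoint

def returns_dict(x):
    # The second return value is a dict; the reentrant checkpoint
    # implementation only supports tensor return values.
    return x * 2, {"cache": x}

x = torch.randn(4, requires_grad=True)
out = checkpoint.checkpoint(returns_dict, x)  # raises the TypeError above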


Solution

  • A similar situation was discussed here.

    What you can do is convert the dictionary into some tensor form. I faced a similar error, caused by an input list that torch.utils.checkpoint does not accept. My solution was to pass the tensors in the list as independent arguments and rebuild the list inside custom_forward.

    I don't know the form of your dictionary (e.g. whether every key will always have a value), but you can devise a dictionary-to-tensor conversion scheme that works for your case; see the sketch below.
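
    A minimal sketch of one such scheme, assuming the dictionary always maps the same known string keys to tensors. The key names, ToyLayer, dict_to_tensors, and tensors_to_dict below are all hypothetical stand-ins, not the asker's actual code:

    import torch
    from torch.utils import checkpoint

    # Assumption: the dict always maps the same known keys to tensors,
    # so it can be flattened to a fixed-order tuple and rebuilt later.
    DICT_KEYS = ("prev_key", "prev_value")  # hypothetical key names

    def dict_to_tensors(d):
        # Flatten the dict into a tuple of tensors in fixed key order.
        return tuple(d[k] for k in DICT_KEYS)

    def tensors_to_dict(tensors):
        # Rebuild the dict from the fixed-order tuple.
        return dict(zip(DICT_KEYS, tensors))

    # Toy stand-in for decoder_layer: returns a tensor and a dict.
    class ToyLayer(torch.nn.Module):
        def forward(self, x):
            return x * 2, {"prev_key": x + 1, "prev_value": x - 1}

    def make_custom_forward(module):
        def custom_forward(x):
            y, state = module(x)
            # Return tensors only: unpack the dict into its values.
            return (y,) + dict_to_tensors(state)
        return custom_forward

    layer = ToyLayer()
    x = torch.randn(4, requires_grad=True)
    outputs = checkpoint.checkpoint(make_custom_forward(layer), x)
    x, layer_past = outputs[0], tensors_to_dict(outputs[1:])

    Note that recent PyTorch releases also provide a non-reentrant checkpoint variant (use_reentrant=False) that is documented to support non-tensor outputs, which may let you return the dictionary directly.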