I have a checkpoint callback function (i.e., custom_dec) that should return a tensor and a dictionary. However, it seems that torch.utils.checkpoint only allows tensor return values, not dictionaries (or other data types). What is the workaround for this, given that the module I want to checkpoint returns a tensor plus a dictionary:
def custom_dec(self, module):
    def custom_forward(*inputs):
        output = module(inputs[0], inputs[1],
                        encoder_attn_mask=inputs[2],
                        decoder_padding_mask=inputs[3],
                        layer_state=inputs[4],
                        causal_mask=inputs[5],
                        output_attentions=inputs[6],
                        )
        # output[2] is a python dictionary
        return output[0], output[2]
    return custom_forward
The following is the checkpoint call:
x, layer_past = checkpoint.checkpoint(
    self.custom_dec(decoder_layer),
    x,
    encoder_hidden_states,
    encoder_padding_mask,
    decoder_padding_mask,
    layer_state,
    decoder_causal_mask,
    output_attentions,
)
The error:
TypeError: CheckpointFunctionBackward.forward: expected Variable (got dictionary) for return value 1
A similar situation was discussed here.
What you can do is convert the dictionary into some tensor form. I faced a similar error, caused by an input list, which torch.utils.checkpoint does not accept. My solution was to pass the tensors in the list as independent arguments and re-form the list inside custom_forward, as sketched below.
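Here is a minimal, runnable sketch of that list workaround; ListConsumer and the tensor shapes are made-up stand-ins for the actual module I was checkpointing:

import torch
from torch import nn
from torch.utils import checkpoint

class ListConsumer(nn.Module):
    # Toy module whose forward takes a list of tensors.
    def forward(self, x, tensor_list):
        return x + sum(tensor_list)

def make_custom_forward(module):
    # checkpoint only accepts plain tensors as arguments, so the
    # former list elements arrive as independent tensors and the
    # list is rebuilt inside custom_forward.
    def custom_forward(x, *list_items):
        return module(x, list(list_items))
    return custom_forward

module = ListConsumer()
x = torch.randn(4, requires_grad=True)
t1 = torch.randn(4, requires_grad=True)
t2 = torch.randn(4, requires_grad=True)

# Unpack what used to be the single list argument [t1, t2]:
out = checkpoint.checkpoint(make_custom_forward(module), x, t1, t2)
out.sum().backward()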
I don't know the form of your dictionary (e.g. whether every key always has a tensor value), but you can come up with a dictionary-tensor interchange scheme that works for your case.
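As a sketch of one such interchange scheme (assuming every value in the dictionary is a tensor; DictProducer and the key names are hypothetical): flatten the dictionary into its tensor values in a fixed key order inside custom_forward, then zip it back into a dictionary after the checkpoint call:

import torch
from torch import nn
from torch.utils import checkpoint

# Hypothetical fixed key order used to flatten and rebuild the dict.
STATE_KEYS = ("key_a", "key_b")

class DictProducer(nn.Module):
    # Toy stand-in for a module returning a tensor plus a dict of tensors.
    def forward(self, x):
        return x * 2, {"key_a": x + 1, "key_b": x - 1}

def make_custom_forward(module):
    def custom_forward(x):
        out, state = module(x)
        # Return only tensors: the main output followed by the
        # dict's values in the fixed key order.
        return (out,) + tuple(state[k] for k in STATE_KEYS)
    return custom_forward

module = DictProducer()
x = torch.randn(4, requires_grad=True)
out, *state_values = checkpoint.checkpoint(make_custom_forward(module), x)
# Rebuild the dictionary on the caller's side.
layer_past = dict(zip(STATE_KEYS, state_values))
out.sum().backward()

(As far as I know, on recent PyTorch versions the non-reentrant variant, checkpoint(..., use_reentrant=False), can handle non-tensor outputs directly, which may make such a scheme unnecessary.)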