python-3.x · tensorflow · deep-learning · softmax · cross-entropy

How to batch compute cross entropy for pointer networks?


In pointer networks the output logits range over the length of the input, so batching requires padding every input to the maximum length in the batch. This is all fine until we have to compute the loss. Currently what I am doing is:

logits = stabilize(logits(inputs))   # [batch, max_length]; subtract max(logits) to stabilize
masks = masks(inputs)                # [batch, max_length]; 1 for actual inputs, 0 for padded positions
exp_logits = exp(logits)
exp_logits_masked = exp_logits * masks
probs = exp_logits_masked / sum(exp_logits_masked, axis=1)

Now I use these probabilities to compute the cross entropy:

cross_entropy = -sum_over_batches(log(probs[correct_class]))
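
For concreteness, here is a runnable TensorFlow (eager/TF 2 style) sketch of that masked softmax cross-entropy; the function name, the batched tf.gather and the small epsilon guard are my own illustrative choices, not anything prescribed by the question:

import tensorflow as tf

def masked_softmax_cross_entropy(logits, mask, labels, eps=1e-9):
    # logits, mask: [batch, max_length]; mask is 1.0 for real inputs, 0.0 for padding.
    # labels: [batch], index of the correct input position for each example.
    logits = logits - tf.reduce_max(logits, axis=1, keepdims=True)    # stabilize
    exp_logits = tf.exp(logits) * mask                                # zero out padded positions
    probs = exp_logits / (tf.reduce_sum(exp_logits, axis=1, keepdims=True) + eps)
    correct = tf.gather(probs, labels, axis=1, batch_dims=1)          # prob of the pointed-to input
    return -tf.reduce_mean(tf.math.log(correct + eps))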

Can I do better than this? Any ideas on how this is generally done by people working with pointer networks?

If I didn't have variable-size inputs, this could all be done with tf.nn.softmax_cross_entropy_with_logits on the logits and labels (which is highly optimized), but with variable lengths that would produce erroneous results: the softmax denominator picks up one extra term for every padded position in an input.
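
To see the mismatch concretely, here is a small check (my own illustration, not part of the original question). With padded logits left at 0, each padded position contributes exp(0) = 1 to the softmax denominator, so the built-in op disagrees with the masked computation:

import tensorflow as tf

logits = tf.constant([[2.0, 1.0, 0.0, 0.0]])   # last two positions are padding
onehot = tf.constant([[1.0, 0.0, 0.0, 0.0]])   # correct pointer = position 0
mask   = tf.constant([[1.0, 1.0, 0.0, 0.0]])

naive = tf.nn.softmax_cross_entropy_with_logits(labels=onehot, logits=logits)

exp_masked = tf.exp(logits) * mask
probs = exp_masked / tf.reduce_sum(exp_masked, axis=1, keepdims=True)
masked = -tf.math.log(tf.reduce_sum(probs * onehot, axis=1))

print(naive.numpy(), masked.numpy())           # the two losses disagree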


Solution

  • You look to be spot on with your approach, and to my knowledge this is how it is implemented in the RNN cells as well. Notice that the derivative of 1·x with respect to x is 1, while the derivative of 0·x is 0, so the mask zeroes out the gradient at padded positions. This produces the result you want because you are summing/averaging the gradients at the end of the network.

    The only thing you might consider is rescaling the loss based on the number of masked values (see the sketch below). When there are no masked values, your gradient will have a slightly different magnitude than when there are many. It's not clear to me that this will have a significant impact; perhaps it will have a very small one.

    Otherwise, I've used this same technique to great success myself, so I'm here to say you're on the right track.
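
    As a sketch of that rescaling idea (my own illustration, under the assumption that you sum per-step pointer losses over a padded decode sequence; step_losses and step_mask are hypothetical names), dividing each example's summed loss by its number of real steps keeps the magnitude comparable across different amounts of padding:

    import tensorflow as tf

    def length_normalized_loss(step_losses, step_mask):
        # step_losses, step_mask: [batch, max_steps]; mask is 1.0 for real steps.
        masked = step_losses * step_mask
        per_example = tf.reduce_sum(masked, axis=1) / tf.maximum(
            tf.reduce_sum(step_mask, axis=1), 1.0)   # avoid division by zero
        return tf.reduce_mean(per_example)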