PyTorch's `EmbeddingBag` allows for efficient lookup-plus-reduce operations on variable-length collections of embedding indices. There are three modes for the reduction: `"sum"`, `"mean"`, and `"max"`. With `"sum"`, you can also provide `per_sample_weights`, giving you a weighted sum.
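For example (a minimal sketch; the sizes and index values are made up):

```python
import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum')

# Two variable-length bags, flattened into one index tensor;
# offsets marks where each bag starts.
indices = torch.tensor([1, 2, 4, 5, 4, 3])   # bag 0: [1, 2, 4, 5], bag 1: [4, 3]
offsets = torch.tensor([0, 4])

# One weight per index -> a weighted sum per bag (only allowed with mode='sum').
weights = torch.tensor([0.1, 0.2, 0.3, 0.4, 1.0, 1.0])
out = bag(indices, offsets, per_sample_weights=weights)   # shape: (2, 4)
```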
Why is `per_sample_weights` not allowed for the `"max"` operation? Looking at how it's implemented, I can only assume there is an issue with performing a "ReduceMean" or "ReduceMax" operation after a "Mul" operation. Could that have something to do with calculating gradients?
P.S.: It's easy enough to turn a weighted sum into a weighted mean by dividing by the sum of the weights (as sketched below), but there is no such trick to get a weighted equivalent of `"max"`.
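For example, a weighted mean can be recovered like this (a minimal sketch, using fixed-length bags so that no `offsets` are needed):

```python
import torch
import torch.nn as nn

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='sum')
indices = torch.tensor([[1, 2, 4], [5, 4, 3]])            # two fixed-length bags
weights = torch.tensor([[0.5, 1.0, 2.0], [1.0, 1.0, 1.0]])

weighted_sum = bag(indices, per_sample_weights=weights)   # shape: (2, 4)
# Weighted mean = weighted sum divided by the total weight of each bag.
weighted_mean = weighted_sum / weights.sum(dim=1, keepdim=True)
```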
The argument `per_sample_weights` was only implemented for `mode='sum'`, not due to technical limitations, but because the developers found no use cases for a "weighted max":
> I haven't been able to find use cases for "weighted mean" (which can be emulated via weighted sum) and "weighted max".
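So nothing about gradients prevents it. If you do need a weighted max, you can emulate it with a plain `nn.Embedding` followed by a multiply and a max reduction (a minimal sketch; the names and sizes are mine, not part of the PyTorch API):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
indices = torch.tensor([[1, 2, 4], [5, 4, 3]])            # two fixed-length bags
weights = torch.tensor([[0.5, 1.0, 2.0], [1.0, 1.0, 1.0]])

vectors = emb(indices)                                    # (2, 3, 4)
weighted = vectors * weights.unsqueeze(-1)                # scale each embedding
weighted_max = weighted.max(dim=1).values                 # (2, 4): max over each bag
weighted_max.sum().backward()                             # gradients flow fine
```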