Tags: python, pytorch, embedding, weighted-average

Why does Pytorch EmbeddingBag with mode "max" not accept `per_sample_weights`?


PyTorch's `EmbeddingBag` allows efficient lookup + reduce operations on variable-length collections of embedding indices. There are three modes for the reduce operation: `"sum"`, `"mean"` and `"max"`. With `"sum"`, you can also provide `per_sample_weights`, giving you a weighted sum.
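To make the lookup + reduce concrete, here is a minimal plain-Python model of what an embedding bag computes for a flat list of indices plus `offsets` marking where each bag starts. It is not PyTorch's implementation: the function name `embedding_bag` and the list-of-lists table are illustrative assumptions, and the `ValueError` only mimics the restriction this question is about.

```python
from typing import List, Optional, Sequence

Vector = List[float]


def embedding_bag(
    table: Sequence[Vector],
    indices: Sequence[int],
    offsets: Sequence[int],
    mode: str = "mean",
    per_sample_weights: Optional[Sequence[float]] = None,
) -> List[Vector]:
    """Plain-Python model of an embedding-bag lookup + reduce.

    `indices` is the flat concatenation of all bags; `offsets[i]` is the
    position in `indices` where bag i starts.
    """
    if per_sample_weights is not None and mode != "sum":
        # Mimics the restriction discussed in the question.
        raise ValueError("per_sample_weights is only supported with mode='sum'")
    bounds = list(offsets) + [len(indices)]
    dim = len(table[0])
    out: List[Vector] = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = []
        for pos in range(start, end):
            w = 1.0 if per_sample_weights is None else per_sample_weights[pos]
            # The weight multiplies each looked-up row before the reduction.
            rows.append([w * x for x in table[indices[pos]]])
        if not rows:
            out.append([0.0] * dim)  # empty bag -> vector of zeros
        elif mode == "sum":
            out.append([sum(col) for col in zip(*rows)])
        elif mode == "mean":
            out.append([sum(col) / len(rows) for col in zip(*rows)])
        elif mode == "max":
            out.append([max(col) for col in zip(*rows)])
        else:
            raise ValueError(f"unknown mode: {mode!r}")
    return out
```

In PyTorch itself, the weighted case is `nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum")` with `per_sample_weights` passed to the forward call alongside `input` and `offsets`; the other modes reject `per_sample_weights`.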

Why is `per_sample_weights` not allowed for the `"max"` operation? Looking at how it's implemented, I can only assume there is an issue with performing a "ReduceMean" or "ReduceMax" operation after a "Mul" operation. Could it have something to do with calculating gradients?


P.S. It's easy enough to turn a weighted sum into a weighted average by dividing by the sum of the weights, but for `"max"` you can't get a weighted equivalent that way.
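The weighted-average trick from the P.S. is small enough to spell out. Given per-bag weighted sums (for example the output of `mode="sum"` with `per_sample_weights`), each bag only needs to be divided by its own total weight. This is a plain-Python sketch; `weighted_mean_bags` is a hypothetical helper, and returning zeros when a bag's weights sum to zero is an assumed convention.

```python
from typing import List, Sequence


def weighted_mean_bags(
    weighted_sums: Sequence[Sequence[float]],
    per_sample_weights: Sequence[float],
    offsets: Sequence[int],
) -> List[List[float]]:
    """Turn per-bag weighted sums into per-bag weighted averages."""
    bounds = list(offsets) + [len(per_sample_weights)]
    result: List[List[float]] = []
    for bag_sum, start, end in zip(weighted_sums, bounds[:-1], bounds[1:]):
        # Total weight of the samples that fell into this bag.
        total = sum(per_sample_weights[start:end])
        if total == 0:
            # Assumed convention for empty bags or weights that cancel out.
            result.append([0.0] * len(bag_sum))
        else:
            result.append([x / total for x in bag_sum])
    return result
```

No comparable post-processing exists for max: a max reduction keeps one selected value per dimension, so rescaling it afterwards cannot account for the weight that belonged to the winning element.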


Solution

  • The argument `per_sample_weights` was only implemented for `mode='sum'`, not because of technical limitations, but because the developers found no use cases for a "weighted mean" or a "weighted max":

    I haven't been able to find use cases for "weighted mean" (which can be emulated via weighted sum) and "weighted max".
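The answer does not prescribe a workaround. If you do need something like a weighted max, one option (an assumption about what "weighted max" should mean, not anything PyTorch defines) is to do the lookup yourself, multiply each row by its weight (the "Mul" step), and then max-reduce within each bag. In PyTorch that would be an `nn.Embedding` lookup, an element-wise multiply, and a per-bag max; the sketch below models it in plain Python, and `weighted_max_bags` is a hypothetical helper.

```python
from typing import List, Sequence


def weighted_max_bags(
    table: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
    per_sample_weights: Sequence[float],
) -> List[List[float]]:
    """One possible 'weighted max': scale each looked-up row by its weight,
    then take the element-wise max within each bag."""
    bounds = list(offsets) + [len(indices)]
    dim = len(table[0])
    out: List[List[float]] = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        scaled = [
            [per_sample_weights[pos] * x for x in table[indices[pos]]]
            for pos in range(start, end)
        ]
        # Empty bags yield zeros, matching the built-in modes' convention.
        out.append([max(col) for col in zip(*scaled)] if scaled else [0.0] * dim)
    return out
```

Scaling before the max means a weight can change which item wins, not just how large the result is; with negative weights or negative embedding entries it can even reverse the ordering, which is one reason no single definition is obvious. As with the built-in `"max"` mode, under autograd the gradient of each output element would reach only the row (and here also the weight) that produced the maximum.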