I have a tensor T of shape Batch_Size x Num_Items x Item_Dimension and another tensor P of shape Batch_Size x Num_Items, where the Num_Items values in each batch of P sum to 1 (a probability distribution over the items for each batch). I want to sample N items from T without replacement according to the probability distribution P. The resulting tensor should be of shape Batch_Size x N x Item_Dimension. How would I do this?
Take a look at https://github.com/tensorflow/tensorflow/issues/9260
Note, though, that I believe Gumbel-max sampling needs logits rather than probabilities, so you would take log(P) first.
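To make the Gumbel trick concrete, here is a minimal sketch in NumPy rather than TensorFlow, so treat it as an illustration of the idea and not as code from the linked issue. It adds Gumbel(0, 1) noise to log(P) and keeps the N largest entries per batch row, which yields N distinct indices. The function name sample_without_replacement and its arguments are my own hypothetical choices, and it assumes each row of P has at least N nonzero entries.

```python
import numpy as np

def sample_without_replacement(T, P, n, rng=None):
    """Pick n distinct items per batch row of T, weighted by P (Gumbel top-k).

    T: array of shape (batch, num_items, item_dim)
    P: array of shape (batch, num_items), each row summing to 1
    Returns (samples, indices) with shapes (batch, n, item_dim) and (batch, n).
    """
    rng = np.random.default_rng() if rng is None else rng
    # The Gumbel trick works on logits, so convert probabilities to log-probabilities.
    # Zero-probability items become -inf and sort last.
    with np.errstate(divide="ignore"):
        logits = np.log(P)
    # Perturb every logit with independent Gumbel(0, 1) noise.
    perturbed = logits + rng.gumbel(size=P.shape)
    # Indices of the n largest perturbed logits in each row, largest first.
    idx = np.argsort(-perturbed, axis=1)[:, :n]
    # Gather the chosen items; the trailing axis of idx broadcasts over item_dim.
    samples = np.take_along_axis(T, idx[:, :, None], axis=1)
    return samples, idx

# Example with made-up data: batch of 4, 6 items of dimension 3, draw 2 per row.
rng = np.random.default_rng(0)
T = rng.normal(size=(4, 6, 3))
P = rng.random((4, 6))
P /= P.sum(axis=1, keepdims=True)
samples, idx = sample_without_replacement(T, P, 2, rng)
print(samples.shape)
```

In TensorFlow the same steps should map onto tf.math.top_k applied to the perturbed logits, followed by tf.gather with batch_dims=1, but I have not run that version here.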