tensorflow, deep-learning, tensorflow-probability

Tensorflow: Sampling a tensor according to another tensor?


I have a tensor T of shape Batch_Size x Num_Items x Item_Dimension and another tensor P of shape Batch_Size x Num_Items, where the Num_Items values in each batch row of P sum to 1 (a probability distribution over the items for each batch element). I want to sample N items from T without replacement, according to the distribution P. The resulting tensor should have shape Batch_Size x N x Item_Dimension. How would I do this?


Solution

  • Take a look at https://github.com/tensorflow/tensorflow/issues/9260

    Note, though, that I believe Gumbel-max sampling needs logits rather than probabilities. Taking log(P) gives valid logits, since softmax(log P) = P.
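A minimal sketch of one way to do this, assuming TensorFlow 2.x in eager mode: the "Gumbel-top-k" trick. Adding independent Gumbel noise to log(P) and keeping the indices of the N largest values draws N distinct items, equivalent to sampling one item at a time from P without replacement. The function name `sample_without_replacement` is hypothetical, and this is not necessarily the exact approach discussed in the linked issue.

```python
import tensorflow as tf


def sample_without_replacement(T, P, n, seed=None):
    """Hypothetical helper: draw n distinct items per batch row of T, weighted by P.

    T: [batch, num_items, item_dim]
    P: [batch, num_items], each row sums to 1
    Returns: [batch, n, item_dim]
    """
    # Gumbel-max operates on logits; log(P) works because softmax(log P) = P.
    # Zero-probability items become -inf and are only picked if n exceeds
    # the number of nonzero entries in that row.
    logits = tf.math.log(P)
    # Standard Gumbel noise -log(-log(U)); the small minval avoids log(0).
    u = tf.random.uniform(tf.shape(logits), minval=1e-20, maxval=1.0,
                          dtype=logits.dtype, seed=seed)
    gumbel = -tf.math.log(-tf.math.log(u))
    # The indices of the n largest perturbed logits form a sample without replacement.
    _, idx = tf.math.top_k(logits + gumbel, k=n)  # [batch, n]
    # Pick the chosen items from each batch row.
    return tf.gather(T, idx, batch_dims=1)  # [batch, n, item_dim]


# Usage: 4 batch elements, 10 items of dimension 3, sample 5 items each.
T = tf.random.normal([4, 10, 3])
P = tf.nn.softmax(tf.random.normal([4, 10]), axis=-1)
S = sample_without_replacement(T, P, n=5)
print(S.shape)  # (4, 5, 3)
```

`tf.math.top_k` returns distinct indices, so no item is repeated within a row. `tf.gather(..., batch_dims=1)` selects along the item axis separately for each batch element.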