
PyTorch - Selecting n indices without replacement from dimension x


Suppose I have the following embeddings emb_user = torch.randn(64, 128, 256). From the second dimension (of length 128), I wish to pick out 16 at random at each instance. I was wondering if there was a more efficient way of doing the following:

idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]

What I also find curious is that the above multinomial call would not work if the weight matrix (torch.ones(64, 128)) had more than 2 dimensions.


Solution

  • Since in your case you want a uniform distribution, you could speed it up with

    # Draw 16 sorted values in [0, 112] per row, then add 0..15 to make them distinct.
    idx = torch.sort(torch.randint(
        0, 128 - 15, (64, 16), device=device
    ), dim=1).values + torch.arange(0, 16, device=device).reshape(1, -1)
    sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
    

    Instead of

    idx = torch.multinomial(torch.ones(64, 128, device=device), 16)
    sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
    

    The runtimes on my machine are 427 µs and 784 µs with device='cpu'; 135 µs and 260 µs and 469 µs with device='cuda'.

    How does it work?

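    The sorted randint output gives indices sampled with replacement, in non-decreasing order. Adding the arange term makes the sequence strictly increasing, which eliminates the repeats. Note that the result comes back sorted, and because a sorted sequence with repeated values arises from fewer orderings of the draws than one with all-distinct values, the 16-element subsets are not weighted exactly equally. Treat this as a fast approximation of uniform sampling without replacement.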

    Illustrating with a small case

    idx = torch.sort(torch.randint(0, 7, (4,))).values
    print('Indices with replacement in the range from 0 to 6: ', idx)
    print('Indices without replacement in the slice: ', idx + torch.arange(4))
    
    Indices with replacement in the range from 0 to 6:  tensor([0, 5, 5, 6])
    Indices without replacement in the slice:  tensor([0, 6, 7, 9])
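
    The same argument can be checked at the sizes from the question. The sketch below is only a sanity check, not part of the method: the variable names and the fixed seed are illustrative assumptions. It confirms that every row is strictly increasing (hence free of repeats) and stays within 0..127.

```python
import torch

torch.manual_seed(0)  # fixed seed, only so the check is repeatable

# Sizes from the question; the variable names are illustrative.
n_total, n_pick, batch = 128, 16, 64

# Sorted draws with replacement from [0, n_total - n_pick], i.e. [0, 112].
base = torch.sort(
    torch.randint(0, n_total - n_pick + 1, (batch, n_pick)), dim=1
).values
# Adding 0..n_pick-1 turns the non-decreasing rows into strictly increasing ones.
idx = base + torch.arange(n_pick).reshape(1, -1)

# Strictly increasing rows cannot contain duplicates.
strictly_increasing = bool((idx[:, 1:] > idx[:, :-1]).all())
# Largest possible value is (n_total - n_pick) + (n_pick - 1) = n_total - 1.
in_range = bool(((idx >= 0) & (idx < n_total)).all())
print(strictly_increasing, in_range)
```

    Both flags should hold for any seed, since the largest base value is 112 and the largest offset is 15.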
    

    A possibly faster solution, though not drawn from exactly the same distribution, is the following:

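    # Gaps between 17 sorted draws, each increased by 1, accumulate into strictly increasing indices.
    # With an upper bound of 128 - 15 the resulting indices cover 0..127.
    idx = torch.cumsum(torch.diff(
        torch.sort(torch.randint(
            0, 128 - 15, (64, 17), device=device
        ), dim=1).values,
        dim=1) + 1, dim=1) - 1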
    sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
    

    One more way, which I expect to be closer to the exact method, although I have not analyzed it rigorously:

    # 1 - rand() lies in (0, 1]: it can be 1 but never 0.
    d = torch.cumsum(1 - torch.rand(64, 17, device=device), dim=1)
    # Dividing by the last partial sum and truncating gives a sorted tensor with integer values in [0, 128 - 16].
    d = (((128 - 15) * d[:, :-1]) / d[:, -1:]).to(torch.long)
    idx = d + torch.arange(0, 16, device=device).reshape(1, -1)
    

    But in the end it tends to be slower than the method using sort.
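
    As for the side remark in the question: torch.multinomial accepts only 1-D or 2-D input, so a weight tensor with more dimensions has to be flattened first. A minimal sketch of one possible workaround, assuming a hypothetical 3-D weight tensor (the shapes and names here are illustrative, not taken from the question):

```python
import torch

# Hypothetical 3-D weight tensor: a 4 x 16 grid of rows, each over 128 candidates.
weights = torch.ones(4, 16, 128)
n_pick = 16

# torch.multinomial takes 1-D or 2-D input, so fold the leading dimensions into one.
flat_weights = weights.reshape(-1, weights.shape[-1])   # shape (64, 128)
# replacement defaults to False, so each row holds distinct indices.
flat_idx = torch.multinomial(flat_weights, n_pick)      # shape (64, 16)
# Restore the leading dimensions on the sampled indices.
idx = flat_idx.reshape(*weights.shape[:-1], n_pick)     # shape (4, 16, 16)
```

    The reshape is cheap compared with the sampling itself, so this should cost about the same as the 2-D call in the question.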