Suppose I have the following embeddings: emb_user = torch.randn(64, 128, 256). From the second dimension (of length 128), I wish to pick out 16 at random at each instance. I was wondering if there was a more efficient way of doing the following:
idx = torch.multinomial(torch.ones(64, 128), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
What I also find curious is that the above multinomial would not work if the weight matrix (torch.ones(64, 128)) had more than 2 dimensions.
Since in your case you want a uniform distribution, you could speed it up with
idx = torch.sort(torch.randint(
0, 128 - 15, (64, 16), device=device
), axis=1).values + torch.arange(0, 16, device=device).reshape(1, -1)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
Instead of
idx = torch.multinomial(torch.ones(64, 128, device=device), 16)
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]
The runtimes on my machine are 427 µs and 784 µs with device='cpu'; 135 µs and 260 µs and 469 µs with device='cuda'.
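If you want to reproduce this kind of comparison yourself, here is a minimal timing sketch. It assumes a CPU run and uses the standard-library timeit module; the function names and the repeat count n_runs are my own choices for illustration, not part of the measurements above. On CUDA you would also need to call torch.cuda.synchronize() around the timed region to get meaningful numbers.

```python
import timeit
import torch

device = "cpu"  # assumption: CPU only; CUDA timing needs torch.cuda.synchronize()
emb_user = torch.randn(64, 128, 256, device=device)
rows = torch.arange(len(emb_user), device=device).unsqueeze(-1)
offsets = torch.arange(16, device=device).reshape(1, -1)

def sample_sorted_randint():
    # Sorted draws with replacement from [0, 112], shifted by 0..15.
    draws = torch.randint(0, 128 - 15, (64, 16), device=device)
    idx = torch.sort(draws, dim=1).values + offsets
    return emb_user[rows, idx]

def sample_multinomial():
    # Reference: uniform weights, sampling without replacement.
    idx = torch.multinomial(torch.ones(64, 128, device=device), 16)
    return emb_user[rows, idx]

n_runs = 200  # hypothetical repeat count, adjust to taste
timings = {}
for fn in (sample_sorted_randint, sample_multinomial):
    fn()  # warm-up call, not timed
    timings[fn.__name__] = timeit.timeit(fn, number=n_runs) / n_runs
    print(f"{fn.__name__}: {timings[fn.__name__] * 1e6:.0f} us per call")
```

The absolute numbers will depend on the machine and the PyTorch build, so treat them only as a relative comparison.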
How does it work?
Sorting the randint output gives indices drawn with replacement, in non-decreasing order. Adding the arange term makes them strictly increasing, which eliminates the repeats. The upper bound 128 - 15 keeps the largest shifted index at 127.
Illustrating with a small case
idx = torch.sort(torch.randint(0, 7, (4,))).values
print('Indices with replacement in the range from 0 to 6: ', idx)
print('Indices without replacement in the slice: ', idx + torch.arange(4))
Indices with replacement in the range from 0 to 6: tensor([0, 5, 5, 6])
Indices without replacement in the slice: tensor([0, 6, 7, 9])
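The same argument can be checked on the full-size case. The sketch below is only a sanity check under the sizes from the question; the variable names n_items, n_pick and batch are mine. It verifies that every index is in range and that each row is strictly increasing, which implies there are no repeats within a row.

```python
import torch

n_items, n_pick, batch = 128, 16, 64  # sizes taken from the question
emb_user = torch.randn(batch, n_items, 256)

# Sorted draws with replacement from [0, n_items - n_pick], then shifted by 0..n_pick-1.
draws = torch.randint(0, n_items - n_pick + 1, (batch, n_pick))
idx = torch.sort(draws, dim=1).values + torch.arange(n_pick).reshape(1, -1)

# Every index must be a valid position along the second dimension.
in_range = bool((idx >= 0).all() and (idx < n_items).all())
# Strictly increasing rows cannot contain a repeated index.
strictly_increasing = bool((torch.diff(idx, dim=1) > 0).all())

sampled_emb_user = emb_user[torch.arange(batch).unsqueeze(-1), idx]
print(in_range, strictly_increasing, sampled_emb_user.shape)
```

Note that the indices come out sorted within each row, unlike the output of torch.multinomial.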
A possibly faster solution, though not from exactly the same distribution, is the following:
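# Draw 17 sorted values per row; the 16 gaps between neighbours are >= 0.
# Upper bound 128 - 15 so that the largest index can reach 127.
s = torch.sort(torch.randint(0, 128 - 15, (64, 17), device=device), axis=1).values
# Adding 1 to each gap makes the running sum strictly increasing;
# subtracting 1 shifts the result into the range [0, 127].
idx = torch.cumsum(torch.diff(s, axis=1) + 1, axis=1) - 1
sampled_emb_user = emb_user[torch.arange(len(emb_user)).unsqueeze(-1), idx]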
One more way, which I expect to be closer to the exact method, although I have not analyzed it rigorously:
# 1 - rand() to include 1 and exclude zero.
d = torch.cumsum(1 - torch.rand(64, 17, device=device), axis=1)
# This produces a sorted tensor with values in the range [0, 128 - 16].
d = (((128 - 15) * d[:, :-1]) / d[:, -1:]).to(torch.long)
idx = d + torch.arange(0, 16, device=device).reshape(1, -1)
But in the end it tends to be slower than the method using sort.
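To judge how close any of these approximations is to the exact sampler, one option is to compare how often each index gets selected. The sketch below does this for the sort-based method against torch.multinomial; the helper inclusion_frequency and the trial count n_rows are hypothetical names I introduce here, and this is an empirical check, not a proof. With uniform sampling without replacement, every index should be included in about 16/128 = 0.125 of the rows.

```python
import torch

torch.manual_seed(0)
n_items, n_pick, n_rows = 128, 16, 20_000  # n_rows is an arbitrary trial count

def inclusion_frequency(idx, n_items):
    """Fraction of rows in which each index 0..n_items-1 was selected."""
    counts = torch.bincount(idx.flatten(), minlength=n_items)
    return counts / idx.shape[0]

# Exact reference: uniform weights, sampling without replacement.
exact_idx = torch.multinomial(torch.ones(n_rows, n_items), n_pick)

# Sort-based approximation: sorted draws with replacement, shifted by 0..n_pick-1.
draws = torch.randint(0, n_items - n_pick + 1, (n_rows, n_pick))
approx_idx = torch.sort(draws, dim=1).values + torch.arange(n_pick).reshape(1, -1)

freq_exact = inclusion_frequency(exact_idx, n_items)
freq_approx = inclusion_frequency(approx_idx, n_items)
target = n_pick / n_items  # 0.125 for a uniform sample without replacement

print("max deviation, multinomial:   ", (freq_exact - target).abs().max().item())
print("max deviation, sorted randint:", (freq_approx - target).abs().max().item())
```

The same helper can be applied to the indices from the other variants to see where they over- or under-sample.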