I have some batched input x of shape [batch, time, feature], and some batched indices i of shape [batch, new_time], which I want to gather into the time dim of x. As output of this operation I want a tensor y of shape [batch, new_time, feature] with values like this:
y[b, t', f] = x[b, i[b, t'], f]
In TensorFlow, I can accomplish this by using the batch_dims: int argument of tf.gather: y = tf.gather(x, i, axis=1, batch_dims=1).
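To pin down the intended semantics independently of any framework, the indexing rule above can be spelled out on nested Python lists. This is only a reference sketch for checking vectorized versions; the helper name batched_gather_reference and the sample values are made up for illustration.

```python
def batched_gather_reference(x, i):
    """Reference semantics on nested lists: y[b][t][f] = x[b][i[b][t]][f]."""
    return [
        [list(x_b[idx]) for idx in i_b]  # pick whole feature rows along the time dim
        for x_b, i_b in zip(x, i)        # the batch dim is walked in lockstep
    ]

# x: batch=2, time=3, feature=2
x = [
    [[0, 1], [10, 11], [20, 21]],
    [[100, 101], [110, 111], [120, 121]],
]
# i: batch=2, new_time=4
i = [[2, 0, 0, 1], [1, 1, 2, 0]]

y = batched_gather_reference(x, i)
print(y[0])  # [[20, 21], [0, 1], [0, 1], [10, 11]]
print(y[1])  # [[110, 111], [110, 111], [120, 121], [100, 101]]
```

Each batch entry uses its own row of indices, which is what batch_dims=1 expresses in tf.gather.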
In PyTorch, I can think of some functions which do similar things:
torch.gather, of course, but this does not have an argument similar to TensorFlow's batch_dims. The output of torch.gather will always have the same shape as the indices, so I would need to unbroadcast the feature dim into i before passing it to torch.gather.
torch.index_select, but here the indices must be one-dimensional. To make it work I would need to unbroadcast x to add a "batch * new_time" dim, and then reshape the output after torch.index_select.
torch.nn.functional.embedding. Here, the embedding matrices would correspond to x. But this embedding function does not support batched weights, so I run into the same issue as for torch.index_select (looking at the code, torch.nn.functional.embedding uses torch.index_select under the hood).
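The shape constraint on torch.gather mentioned in the list above can be seen directly in a small sketch. The tensor sizes and the variable name narrow are arbitrary choices for illustration: an index with a trailing singleton dim yields an output with that same singleton dim.

```python
import torch

x = torch.arange(2 * 3 * 5).reshape(2, 3, 5)    # [batch=2, time=3, feature=5]
i = torch.tensor([[2, 0, 0, 1], [1, 1, 2, 0]])  # [batch=2, new_time=4]

# Adding a trailing singleton dim gives an index of shape [2, 4, 1] ...
narrow = x.gather(1, i[:, :, None])

# ... and the output takes the index's shape, so only feature 0 is gathered.
print(narrow.shape)  # torch.Size([2, 4, 1])
```

To obtain all features, the index therefore has to cover the feature dim as well.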
Is it possible to accomplish such a gather operation without relying on unbroadcasting, which is inefficient for large dims?
This is actually the most frequent case: the input and index tensors don't have the same number of dimensions. You can still utilize torch.gather, though, since you can rewrite your expression:
y[b, t, f] = x[b, i[b, t], f]
as:
y[b, t, f] = x[b, i[b, t, f], f]
which ensures all three tensors have an equal number of dimensions. This reveals a third dimension on i, which we can create for free by unsqueezing a trailing dimension and expanding it to the feature size of x; expand returns a view, so no index data is copied. You can do so with i[:, :, None].expand(-1, -1, x.size(2)).
Here is a minimal example:
>>> import torch
>>> b = 2; t = 3; t_new = 4; f = 5
>>> x = torch.rand(b, t, f)
>>> i = torch.randint(0, t, (b, t_new))
>>> y = x.gather(1, i[:, :, None].expand(-1, -1, f))
>>> y.shape
torch.Size([2, 4, 5])
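As a sanity check, the gather-plus-expand idea can be wrapped in a small helper and compared against a plain per-batch loop and against advanced indexing. This is a sketch assuming i has shape [batch, new_time] and dtype int64; the name batched_gather and the sizes used are hypothetical, not part of PyTorch.

```python
import torch

def batched_gather(x: torch.Tensor, i: torch.Tensor) -> torch.Tensor:
    """Gather along the time dim: y[b, t', f] = x[b, i[b, t'], f].

    x: [batch, time, feature], i: [batch, new_time] (int64).
    """
    # expand() returns a view: the new feature axis has stride 0,
    # so the index values are not copied.
    index = i[:, :, None].expand(-1, -1, x.size(2))
    return x.gather(1, index)

torch.manual_seed(0)
b, t, t_new, f = 2, 3, 4, 5
x = torch.rand(b, t, f)
i = torch.randint(0, t, (b, t_new))

y = batched_gather(x, i)

# Cross-check against a per-batch loop and against advanced indexing.
y_loop = torch.stack([x[n, i[n]] for n in range(b)])
y_adv = x[torch.arange(b)[:, None], i]

print(y.shape)                 # torch.Size([2, 4, 5])
print(torch.equal(y, y_loop))  # True
print(torch.equal(y, y_adv))   # True
```

The advanced-indexing form x[torch.arange(b)[:, None], i] is an alternative that avoids building the expanded index at all; which of the two is faster likely depends on the shapes and the device.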