I'm trying to implement this paper and I'm stuck on this simple step. Although this is to do with attention, the thing I'm stuck on is just how to implement a permutation of a vector added to a matrix without using for loops.
The attention scores have a learned bias vector added to them; the theory is that it encodes the relative position (j - i) of the two tokens the score represents.
So alpha is a T x T matrix, where T depends on the batch being forwarded, and B is a learned bias vector whose length has to be fixed and as large as 2T. My current implementation, which I believe does what the paper suggests, is:
def __init__(self, ...):
    ...
    # one bias entry per possible relative offset j - i
    self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True)
    stdv = 1. / math.sqrt(self.bias.data.size(0))
    self.bias.data.uniform_(-stdv, stdv)

def forward(self, ...):
    ...
    n = self.bias.size(0)  # n = 201 (2 * max_seq_len + 1)
    B_matrix = torch.zeros(self.T, self.T)  # 60 x 60
    for i in range(self.T):
        # row i holds bias[n//2 - i + j] for j = 0..T-1, i.e. one entry per offset j - i
        B_matrix[i] = self.bias[torch.arange(start=n // 2 - i, end=n // 2 - i + self.T)]
    attention_scores = attention_scores + B_matrix.unsqueeze(0)
    # 64 x 60 x 60
    ...
This is the only relevant part:
B_matrix = torch.zeros(self.T, self.T)  # 60 x 60
for i in range(self.T):
    B_matrix[i] = self.bias[torch.arange(start=n // 2 - i, end=n // 2 - i + self.T)]
Basically I'm trying not to use a for loop to go over each row. I know this must be really inefficient and costly when the model is very large: I'm doing an explicit for loop over each row just to get a permutation of the learned bias vector.
Can anyone help me out with a better way, perhaps through smart broadcasting?
After thinking about it, I realize I don't need to instantiate a zero matrix, but I still can't get rid of the for loop, and I can't use gather since B_matrix is a different size than a tiled bias vector.
# still a Python-level loop over rows, just hidden in a list comprehension
functor = lambda i: bias[torch.arange(start=n // 2 - i, end=n // 2 - i + T)]
B_matrix = torch.stack([functor(i) for i in range(T)])
I couldn't figure out what n was supposed to be in your code, but I think the following example using torch.meshgrid provides what you're looking for.
Supposing
n, m = 10, 20 # arbitrary
a = torch.randn(n, m)
b = torch.randn(n + m)
then
for i in range(n):
    for j in range(m):
        a[i, j] = a[i, j] + b[n - i + j]
is equivalent to
ii, jj = torch.meshgrid(torch.arange(n), torch.arange(m))  # ii[i, j] = i, jj[i, j] = j
a = a + b[n - ii + jj]
though the latter is an out-of-place operation, which is usually a good thing. If you actually wanted an in-place operation then replace a = with a[...] =.
Note that this is an example of integer array indexing, where we index b using a tensor that is the same shape as a.
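To connect this back to the original B_matrix, here is a sketch of the same indexing trick; it assumes n = 201 (2 * max_seq_len + 1) and T <= max_seq_len, so every index n // 2 - i + j stays in range, and uses a standalone bias tensor in place of self.bias:

import torch

T, n = 60, 201                     # assumed values from the question
bias = torch.randn(n)              # stands in for self.bias

# meshgrid defaults to 'ij' indexing: ii[i, j] = i and jj[i, j] = j
ii, jj = torch.meshgrid(torch.arange(T), torch.arange(T))
B_matrix = bias[n // 2 - ii + jj]  # B_matrix[i, j] = bias[n // 2 - i + j]

# sanity check against the row-by-row loop from the question
B_loop = torch.stack([bias[n // 2 - i + torch.arange(T)] for i in range(T)])
assert torch.equal(B_matrix, B_loop)

The result can then be added to the scores with attention_scores + B_matrix.unsqueeze(0) exactly as before.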