Tags: numpy, for-loop, matrix, pytorch, attention-model

Pytorch, get rid of a for loop when adding permutation of one vector to entries of a matrix?


I'm trying to implement this paper and am stuck on one simple step. Although the context is attention, what I'm stuck on is just how to add a permutation of a vector to the entries of a matrix without using for loops.

The attention scores have a learned bias vector added to them; the theory is that it encodes the relative position (j - i) of the two tokens the score represents. (The formula image is omitted here; in effect, each score alpha[i, j] gets the bias entry for offset j - i added to it.)

So alpha is a T x T matrix, where T depends on the batch being forwarded, and B is a learned bias vector whose length has to be fixed and at least as large as 2T. My current implementation, which I believe does what the paper suggests, is:

    def __init__(...):
        ...
        self.bias = torch.nn.Parameter(torch.randn(config.n), requires_grad=True)
        stdv = 1. / math.sqrt(self.bias.data.size(0))
        self.bias.data.uniform_(-stdv, stdv)

    def forward(...):
        ...
        # n = 201  (2 * max_seq_len + 1)

        B_matrix = torch.zeros(self.T, self.T)  # 60 x 60
        for i in range(self.T):
            # row i is a length-T window of the bias, shifted by i
            B_matrix[i] = self.bias[torch.arange(start=n // 2 - i, end=n // 2 - i + self.T)]

        attention_scores = attention_scores + B_matrix.unsqueeze(0)
        # 64 x 60 x 60
        ...

This is the only relevant part:

    B_matrix = torch.zeros(self.T, self.T)  # 60 x 60
    for i in range(self.T):
        B_matrix[i] = self.bias[torch.arange(start=n // 2 - i, end=n // 2 - i + self.T)]

Basically I'm trying not to use a for loop to go over each row.

I know this must be really inefficient and costly when the model is very large, since I'm doing an explicit for loop over each row just to get a permutation of the learned bias vector.

Can anyone help me out with a better way, through smart broadcasting perhaps?

After thinking about it, I realized I don't need to instantiate a zero matrix, but I still can't get rid of the for loop, and I can't use gather because B_matrix is a different size than a tiled bias vector:

    functor = lambda i: bias[torch.arange(start=n // 2 - i, end=n // 2 - i + T)]
    B_matrix = torch.stack([functor(i) for i in torch.arange(T)])
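
For what it's worth, gather could be made to fit here if the bias were first expanded (a view, not a copy) into a T x n matrix with one index row per output row. A rough sketch, with stand-in sizes and a stand-in for self.bias:

    import torch

    T, n = 60, 201                      # sizes from the comments above
    bias = torch.randn(n)               # stands in for self.bias

    # idx[i, j] = n // 2 - i + j, the same index the loop computes row by row
    idx = n // 2 - torch.arange(T).unsqueeze(1) + torch.arange(T).unsqueeze(0)
    B_matrix = torch.gather(bias.unsqueeze(0).expand(T, -1), 1, idx)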

Solution

  • I couldn't figure out what n was supposed to be in your code, but I think the following example using torch.meshgrid provides what you're looking for.

    Supposing

    n, m = 10, 20   # arbitrary
    a = torch.randn(n, m)
    b = torch.randn(n + m)
    

    then

    for i in range(n):
        for j in range(m):
            a[i, j] = a[i, j] + b[n - i + j]
    

    is equivalent to

    # indexing='ij' matches the loop's (row, column) order and avoids
    # a warning in newer PyTorch versions
    ii, jj = torch.meshgrid(torch.arange(n), torch.arange(m), indexing='ij')
    a = a + b[n - ii + jj]
    

    though the latter is an out-of-place operation, which is usually a good thing. If you actually wanted an in-place operation then replace a = with a[...] =.
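
    A quick sanity check that the two versions agree (reusing b, n, m, ii, jj from above, with a fresh a so the earlier loop hasn't already modified it):

    a = torch.randn(n, m)
    expected = a.clone()
    for i in range(n):
        for j in range(m):
            expected[i, j] = expected[i, j] + b[n - i + j]
    assert torch.allclose(a + b[n - ii + jj], expected)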

    Note that this is an example of integer array indexing where we index b using a tensor that is the same shape as a.
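
    Applied back to the question's bias matrix, the same integer array indexing trick gives the following sketch (T, n, and bias stand in for the question's self.T, n, and self.bias; B_matrix[i, j] comes out as bias[n // 2 - i + j], exactly what the loop built):

    T, n = 60, 201
    bias = torch.randn(n)                   # stands in for self.bias
    ii, jj = torch.meshgrid(torch.arange(T), torch.arange(T), indexing='ij')
    B_matrix = bias[n // 2 - ii + jj]       # T x T, no Python loop
    # attention_scores = attention_scores + B_matrix.unsqueeze(0)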