python numpy pytorch array-broadcasting

Efficient PyTorch or NumPy broadcasting not found to avoid bottleneck operations


I have the following implementation in my PyTorch-based code, which involves a nested for loop. The nested loops, together with the if conditions, make the code very slow to execute. I tried to replace the nested loops with NumPy and PyTorch broadcasting, but I did not get it to work. Any help with avoiding the for loops would be appreciated.

Here are the links I have read: PyTorch, NumPy.

#!/usr/bin/env python
# coding: utf-8

import torch

batch_size=32
mask=torch.FloatTensor(batch_size).uniform_() > 0.8

teacher_count=510
student_count=420
feature_dim=750
student_output=torch.zeros([batch_size,student_count])
teacher_output=torch.zeros([batch_size,teacher_count])

student_adjacency_mat=torch.randint(0,1,(student_count,student_count))
teacher_adjacency_mat=torch.randint(0,1,(teacher_count,teacher_count))

student_feat=torch.rand([batch_size,feature_dim])
student_graph=torch.rand([student_count,feature_dim])
teacher_feat=torch.rand([batch_size,feature_dim])
teacher_graph=torch.rand([teacher_count,feature_dim])


for m in range(batch_size):
    if mask[m]==1:
        for i in range(student_count):
            for j in range(student_count):
                student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
    if mask[m]==0:
        for i in range(teacher_count):
            for j in range(teacher_count):
                teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])   

Solution

  • Problem statement

    The operation you are looking to perform is quite straightforward. If you look closely at your loop:

    for m in range(batch_size):
        if mask[m]==1:
            for i in range(student_count):
                for j in range(student_count):
                    student_output[m][i] += student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])

        if mask[m]==0:
            for i in range(teacher_count):
                for j in range(teacher_count):
                    teacher_output[m][i] += teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])
    

    Elements relevant to us:

    • You have two operations, selected by a mask, which can ultimately be computed independently.

    • Each operation loops over every entry of its adjacency matrix, i.e. student_count² (or teacher_count²) iterations per batch element.

    • The assignment operation comes down to

      output[m,i] += adj_matrix[i,j] * <feats[m], graph[j]>
      

      where adj_matrix[i,j] is a scalar and <feats[m], graph[j]> is the dot product over the feature dimension.
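Put differently, for a fixed batch row m the two inner loops are just a matrix-vector product: first compute the scores <feats[m], graph[j]> for every j, then take the adjacency-weighted sum over j. A minimal sketch checking this on small, hypothetical sizes (not the ones from the question):

```python
import torch

torch.manual_seed(0)

# Small hypothetical sizes so the explicit loop runs quickly
n_nodes, feature_dim = 5, 4
adj = torch.randint(0, 2, (n_nodes, n_nodes)).float()
graph = torch.rand(n_nodes, feature_dim)
feat_m = torch.rand(feature_dim)  # plays the role of feats[m] for one batch row

# Explicit double loop, as in the question
loop_out = torch.zeros(n_nodes)
for i in range(n_nodes):
    for j in range(n_nodes):
        loop_out[i] += adj[i, j] * torch.dot(feat_m, graph[j])

# Same quantity: scores[j] = <feat_m, graph[j]>, then a weighted sum over j
scores = graph @ feat_m   # shape [n_nodes]
vec_out = adj @ scores    # shape [n_nodes]
```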


    Using torch.einsum

    This is a typical use case for torch.einsum. You can read more on this thread where I also happen to have written an answer.

    Leaving implementation details aside, the formulation with torch.einsum is rather self-explanatory:

    o = torch.einsum('ij,mf,jf->mi', adj_matrix, feats, graph)
    

    In pseudo-code, this comes down to:

    o[m,i] += adj_matrix[i,j]*feats[m,f]*graph[j,f]
    

    for all i, j, m, and f in your domain of interest. Since j and f do not appear in the output subscripts mi, einsum sums over them, which is precisely the desired operation.
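To convince yourself that the einsum string matches the pseudo-code, you can compare it against a literal four-level loop on tiny, hypothetical sizes. This is only a sanity check; the loop is exactly what einsum lets you avoid:

```python
import torch

torch.manual_seed(0)

batch_size, n_nodes, feature_dim = 3, 5, 4  # hypothetical small sizes
adj_matrix = torch.randint(0, 2, (n_nodes, n_nodes)).float()
feats = torch.rand(batch_size, feature_dim)
graph = torch.rand(n_nodes, feature_dim)

# j and f appear only on the input side, so einsum sums over them
o = torch.einsum('ij,mf,jf->mi', adj_matrix, feats, graph)

# Reference: o[m,i] += adj_matrix[i,j]*feats[m,f]*graph[j,f] for all m, i, j, f
ref = torch.zeros(batch_size, n_nodes)
for m in range(batch_size):
    for i in range(n_nodes):
        for j in range(n_nodes):
            for f in range(feature_dim):
                ref[m, i] += adj_matrix[i, j] * feats[m, f] * graph[j, f]
```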

    Combined with the mask reshaped to [batch_size, 1] with M = mask[:,None], so that it broadcasts over the node dimension, this gives you for the student tensor:

    >>> student = M*torch.einsum('ij,mf,jf->mi', student_adjacency_mat, student_feat, student_graph)
    

    For the teacher result, you can invert the mask with ~M:

    >>> teacher = ~M*torch.einsum('ij,mf,jf->mi', teacher_adjacency_mat, teacher_feat, teacher_graph)
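Putting the two masked einsums together, a minimal end-to-end sketch could look like the following. The helper name masked_graph_outputs is hypothetical, the sizes are small made-up ones, and it assumes a boolean mask and float adjacency matrices; it checks the result against the question's original nested loops:

```python
import torch

def masked_graph_outputs(mask, s_adj, s_feat, s_graph, t_adj, t_feat, t_graph):
    """Vectorized version of the question's loops (hypothetical helper name).

    mask: bool tensor [batch]; True rows belong to the student branch,
    False rows to the teacher branch. Adjacency matrices must be float.
    """
    M = mask[:, None]  # [batch, 1], broadcasts over the node dimension
    student = M * torch.einsum('ij,mf,jf->mi', s_adj, s_feat, s_graph)
    teacher = ~M * torch.einsum('ij,mf,jf->mi', t_adj, t_feat, t_graph)
    return student, teacher

torch.manual_seed(0)
batch_size, feature_dim, n_s, n_t = 6, 4, 5, 7  # hypothetical small sizes
mask = torch.tensor([True, False, True, True, False, False])
s_adj = torch.randint(0, 2, (n_s, n_s)).float()
t_adj = torch.randint(0, 2, (n_t, n_t)).float()
s_feat, t_feat = torch.rand(batch_size, feature_dim), torch.rand(batch_size, feature_dim)
s_graph, t_graph = torch.rand(n_s, feature_dim), torch.rand(n_t, feature_dim)

student, teacher = masked_graph_outputs(mask, s_adj, s_feat, s_graph, t_adj, t_feat, t_graph)

# Reference: the original nested loops
ref_s = torch.zeros(batch_size, n_s)
ref_t = torch.zeros(batch_size, n_t)
for m in range(batch_size):
    if mask[m]:
        for i in range(n_s):
            for j in range(n_s):
                ref_s[m, i] += s_adj[i, j] * torch.dot(s_feat[m], s_graph[j])
    else:
        for i in range(n_t):
            for j in range(n_t):
                ref_t[m, i] += t_adj[i, j] * torch.dot(t_feat[m], t_graph[j])
```

Note that this still evaluates both einsums over the whole batch and then zeroes the rows of the other branch. If that extra work matters, selecting the rows first with boolean indexing (e.g. s_feat[mask]) is an alternative.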
    

    Using torch.matmul

    Alternatively, since this is a rather simple application of torch.einsum, you can also get away with two calls to torch.matmul. Given two matrices A and B indexed by ik and kj respectively, A@B corresponds to ik@kj -> ij. Therefore you can get the desired result with:

    >>> g = student_feat@student_graph.T # mf@jf.T -> mf@fj -> mj
    >>> g@student_adjacency_mat.T        # mj@ij.T -> mj@ji -> mi
    

    The two steps split the torch.einsum call 'ij,mf,jf->mi' into mf,jf->mj, followed by mj,ij->mi.
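Because matrix multiplication is associative, feats @ graph.T @ adj.T can also be grouped as feats @ (adj @ graph).T. Both groupings give the same result, but the cost differs with the shapes. With the question's student sizes (batch 32, 420 nodes, 750 features), the grouping above needs roughly 32·750·420 + 32·420·420 ≈ 15.7 million multiply-adds, while the regrouped form needs about 420·420·750 + 32·750·420 ≈ 142 million, so the order shown above looks like the better default unless adj @ graph can be reused across batches. A small sketch (hypothetical sizes) checking that all three forms agree:

```python
import torch

torch.manual_seed(0)
batch_size, n_nodes, feature_dim = 4, 6, 5  # hypothetical small sizes
adj = torch.randint(0, 2, (n_nodes, n_nodes)).float()
feats = torch.rand(batch_size, feature_dim)
graph = torch.rand(n_nodes, feature_dim)

# Two matmuls: mf,jf->mj followed by mj,ij->mi
g = feats @ graph.T         # [batch, n_nodes]
out_matmul = g @ adj.T      # [batch, n_nodes]

# Same product, associated differently: adj @ graph is ij,jf->if
out_regrouped = feats @ (adj @ graph).T

out_einsum = torch.einsum('ij,mf,jf->mi', adj, feats, graph)
```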


    Side note: your current dummy student and teacher adjacency matrices contain only zeros, because the upper bound of torch.randint is exclusive, so torch.randint(0,1,...) can only return 0. Maybe you meant to have:

    student_adjacency_mat=torch.randint(0,2,(student_count,student_count)).float()
    teacher_adjacency_mat=torch.randint(0,2,(teacher_count,teacher_count)).float()
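A quick check of the adjacency initialization: torch.randint excludes its upper bound, so (0, 1) yields only zeros while (0, 2) yields zeros and ones. The .float() matters for the vectorized versions as well. torch.randint returns int64 by default while the features are float32, and as far as I know torch.einsum and @ expect operands of matching dtype rather than promoting them, so it is safest to keep the adjacency matrices as float:

```python
import torch

student_count, teacher_count = 420, 510

# Upper bound is exclusive: randint(0, 1, ...) can only produce 0
zeros_only = torch.randint(0, 1, (student_count, student_count))

# randint(0, 2, ...) samples 0 or 1; .float() matches the float32 features
student_adjacency_mat = torch.randint(0, 2, (student_count, student_count)).float()
teacher_adjacency_mat = torch.randint(0, 2, (teacher_count, teacher_count)).float()
```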