I have the following implementation in my PyTorch-based code, which involves nested for loops. The nested loops, together with the if condition, make the code very slow to execute. I tried to replace the loops with broadcasting, as in NumPy and PyTorch, but that did not yield any result. Any help with removing the for loops would be appreciated.
Here are the links I have read: PyTorch, NumPy.
#!/usr/bin/env python
# coding: utf-8
import torch
batch_size=32
mask=torch.FloatTensor(batch_size).uniform_() > 0.8
teacher_count=510
student_count=420
feature_dim=750
student_output=torch.zeros([batch_size,student_count])
teacher_output=torch.zeros([batch_size,teacher_count])
student_adjacency_mat=torch.randint(0,1,(student_count,student_count))
teacher_adjacency_mat=torch.randint(0,1,(teacher_count,teacher_count))
student_feat=torch.rand([batch_size,feature_dim])
student_graph=torch.rand([student_count,feature_dim])
teacher_feat=torch.rand([batch_size,feature_dim])
teacher_graph=torch.rand([teacher_count,feature_dim])
for m in range(batch_size):
    if mask[m]==1:
        for i in range(student_count):
            for j in range(student_count):
                student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
    if mask[m]==0:
        for i in range(teacher_count):
            for j in range(teacher_count):
                teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])
The operation you are looking to perform is quite straightforward. If you look closely at your loop:
for m in range(batch_size):
    if mask[m]==1:
        for i in range(student_count):
            for j in range(student_count):
                student_output[m][i] += student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
    if mask[m]==0:
        for i in range(teacher_count):
            for j in range(teacher_count):
                teacher_output[m][i] += teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])
Elements relevant to us:
You have two operations, separated based on a mask, which can ultimately be computed independently.
Each operation loops over an entire adjacency matrix, i.e. student_count² (respectively teacher_count²) iterations.
The assignment operation comes down to
output[m,i] += adj_matrix[i,j] * <feats[m], graph[j]>
where adj_matrix[i,j] is a scalar and <·,·> denotes the dot product.
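Notice that the dot product <feats[m], graph[j]> does not depend on i, so it can be computed once for all (m, j) pairs. As a sketch (g is just an illustrative name):
g = student_feat@student_graph.T # g[m,j] = <student_feat[m], student_graph[j]>
The remaining accumulation, output[m,i] += adj_matrix[i,j]*g[m,j], is then a plain matrix product over j. This is exactly what both approaches below exploit.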
torch.einsum
This is a typical use case for torch.einsum. You can read more on this thread, where I also happen to have written an answer. Setting implementation details aside, the formulation with torch.einsum is rather self-explanatory:
o = torch.einsum('ij,mf,jf->mi', adj_matrix, feats, graph)
In pseudo-code, this comes down to:
o[m,i] += adj_matrix[i,j]*feats[m,f]*graph[j,f]
for all i, j, m, and f in your domain of interest, which is precisely the desired operation.
Combined with the mask expanded to the appropriate shape with M = mask[:,None], this gives you for the student tensor:
>>> student = M*torch.einsum('ij,mf,jf->mi', student_adjacency_mat, student_feat, student_graph)
For the teacher result, you can invert the mask with ~M:
>>> teacher = ~M*torch.einsum('ij,mf,jf->mi', teacher_adjacency_mat, teacher_feat, teacher_graph)
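If you want to sanity-check the vectorized result against your original loop, here is a minimal sketch. It assumes the adjacency matrices are floats (see the side note at the end) and only compares the first few entries of one masked sample:
m = int(mask.nonzero()[0]) # first sample where mask is True; assumes at least one exists
ref = torch.zeros(5)
for i in range(5):
    for j in range(student_count):
        ref[i] += student_adjacency_mat[i,j]*torch.dot(student_feat[m],student_graph[j])
print(torch.allclose(student[m,:5], ref, rtol=1e-4)) # expected: True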
torch.matmul
Alternatively, since this is a rather simple application of torch.einsum, you can also get away with two calls to torch.matmul. Given A and B, two matrices indexed by ik and kj respectively, A@B corresponds to ik@kj -> ij. Therefore you can get the desired result with:
>>> g = student_feat@student_graph.T # mf@jf.T -> mf@fj -> mj
>>> g@student_adjacency_mat.T # mj@ij.T -> mj@ji -> mi
See how the two steps relate to the torch.einsum call with 'ij,mf,jf->mi': first mf,jf->mj, followed by mj,ij->mi.
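Both routes compute the same quantity, so they should agree up to floating-point tolerance. A quick check (the names via_einsum and via_matmul are just for illustration; it again assumes float adjacency matrices, per the side note below):
>>> A = student_adjacency_mat.float()
>>> via_einsum = torch.einsum('ij,mf,jf->mi', A, student_feat, student_graph)
>>> via_matmul = (student_feat@student_graph.T)@A.T
>>> torch.allclose(via_einsum, via_matmul, rtol=1e-4)
True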
Side note: your current dummy student and teacher adjacency matrices are initialized with zeros, since torch.randint(0,1,...) only ever draws values below the exclusive upper bound 1. Maybe you meant to have:
student_adjacency_mat=torch.randint(0,2,(student_count,student_count)).float()
teacher_adjacency_mat=torch.randint(0,2,(teacher_count,teacher_count)).float()