Tags: python, pytorch, pytorch-geometric, graph-neural-network

Speeding up pytorch operations for custom message dropout


I am trying to implement message dropout in my custom MessagePassing convolution in PyTorch Geometric. Message dropout consists of randomly ignoring p% of the edges in the graph. My idea was to randomly remove p% of them from the input edge_index in forward().

The edge_index is a tensor of shape (2, num_edges), where the first row holds the "from" node IDs and the second row the "to" node IDs. So what I thought I could do is select a random sample of range(num_edges) and then use it to keep only those columns, masking out the rest:

    def forward(self, x, edge_index, edge_attr=None):
        if self.message_dropout is not None:
            # TODO: this is way too slow (4-5 times slower than without it)
            # message dropout -> randomly ignore p % of edges in the graph i.e. keep only (1-p) % of them
            num_keep = int((1.0 - self.message_dropout) * edge_index.shape[1])
            random_keep_inx = random.sample(range(edge_index.shape[1]), num_keep)
            edge_index_to_use = edge_index[:, random_keep_inx]
            edge_attr_to_use = edge_attr[random_keep_inx] if edge_attr is not None else None
        else:
            edge_index_to_use = edge_index
            edge_attr_to_use = edge_attr

        ...

However, it is way too slow: an epoch takes about 5 minutes instead of 1 minute without it (roughly 5 times slower). Is there a faster way to do this in PyTorch?

Edit: The bottleneck seems to be the random.sample() call, not the masking. So I guess what I should be asking for is faster alternatives to that.
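
For context, here is a minimal timing sketch that isolates the index-generation cost (the edge count and dropout probability below are arbitrary assumptions for illustration, not my actual graph):

    import random
    import timeit

    import torch

    num_edges = 100_000  # hypothetical edge count, illustration only
    p = 0.1              # dropout probability

    # current approach: CPU-side random.sample of the kept indices
    t_sample = timeit.timeit(
        lambda: random.sample(range(num_edges), int((1.0 - p) * num_edges)),
        number=100,
    )

    # vectorized alternative: one Bernoulli keep-mask drawn in a single op
    t_mask = timeit.timeit(
        lambda: torch.rand(num_edges) >= p,
        number=100,
    )

    print(f"random.sample: {t_sample:.3f}s, torch.rand mask: {t_mask:.3f}s")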


Solution

  • I managed to create a boolean mask using PyTorch's functional dropout (torch.nn.functional.dropout), which is much faster. Now an epoch takes about 1 minute again. This was faster than the permutation-based solutions (e.g. with torch.randperm) that I found elsewhere.

        def forward(self, x, edge_index, edge_attr=None):
            if self.message_dropout is not None:
                # message dropout -> randomly ignore p % of edges in the graph
                # build the mask on the same device as edge_index; F.dropout zeroes
                # ~p of the entries and scales survivors by 1/(1-p), so "> 0" keeps them
                mask = F.dropout(torch.ones(edge_index.shape[1], device=edge_index.device),
                                 self.message_dropout, self.training) > 0
                edge_index_to_use = edge_index[:, mask]
                edge_attr_to_use = edge_attr[mask] if edge_attr is not None else None
            else:
                edge_index_to_use = edge_index
                edge_attr_to_use = edge_attr
    
            ...
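
    Why this works: during training, F.dropout zeroes each entry of the all-ones tensor independently with probability p and scales the survivors by 1/(1-p), so comparing against zero yields a Bernoulli keep-mask over the edges (and with self.training False the mask is all True, so no edges are dropped at evaluation time). A standalone sketch of an equivalent, slightly more direct mask (the toy edge count is an assumption for illustration):

        import torch

        num_edges = 10  # hypothetical edge count, illustration only
        p = 0.3         # probability of dropping an edge

        edge_index = torch.randint(0, 5, (2, num_edges))  # toy edge_index

        # one Bernoulli draw per edge, on the same device; True = keep the edge
        mask = torch.rand(num_edges, device=edge_index.device) >= p

        edge_index_kept = edge_index[:, mask]
        print(mask)
        print(edge_index_kept.shape)  # (2, number of kept edges)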