Tags: python-3.x, pytorch, torchvision

How does a pytorch function (such as RoIPool) work?


For example, I'm trying to read the implementation of RoI pooling in PyTorch.

Here is a code fragment showing how to use RoIPool in PyTorch:

import torch
from torchvision.ops.roi_pool import RoIPool

device = torch.device('cuda')

# create feature layer, proposals and targets
num_proposals = 10
feature_map = torch.randn(1, 64, 32, 32)

proposals = torch.zeros((num_proposals, 4))
proposals[:, 0] = torch.randint(0, 16, (num_proposals,))
proposals[:, 1] = torch.randint(0, 16, (num_proposals,))
proposals[:, 2] = torch.randint(16, 32, (num_proposals,))
proposals[:, 3] = torch.randint(16, 32, (num_proposals,))

roi_pool_obj = RoIPool(3, 2**-1)
roi_pool = roi_pool_obj(feature_map, [proposals])

I'm using PyCharm, so when I follow RoIPool from the second line, it opens the file ~/anaconda3/envs/CV/lib/python3.8/site-packages/torchvision/ops/roi_pool.py, which is exactly the same as the code in the documentation.

I pasted the code below with the docstrings removed.

from typing import List, Union

import torch
from torch import nn, Tensor
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops

from ..utils import _log_api_usage_once
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape


def roi_pool(
    input: Tensor,
    boxes: Union[Tensor, List[Tensor]],
    output_size: BroadcastingList2[int],
    spatial_scale: float = 1.0,
) -> Tensor:

    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(roi_pool)
    _assert_has_ops()
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1])
    return output


class RoIPool(nn.Module):

    def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float):
        super().__init__()
        _log_api_usage_once(self)
        self.output_size = output_size
        self.spatial_scale = spatial_scale

    def forward(self, input: Tensor, rois: Tensor) -> Tensor:
        return roi_pool(input, rois, self.output_size, self.spatial_scale)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}(output_size={self.output_size}, spatial_scale={self.spatial_scale})"
        return s

So, in the code example:

When running roi_pool_obj = RoIPool(3, 2**-1), it creates an instance of RoIPool by calling its __init__ method, which only initializes two instance variables;

When running roi_pool = roi_pool_obj(feature_map, [proposals]), it must call the forward() method (but I don't know how), which then calls the roi_pool() function above;

When running the roi_pool() function, it does some checks first and then computes the output with the line output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1]).

But this doesn't show the details of how roi_pool is implemented, and PyCharm shows "Cannot find declaration to go to" when I try to follow torch.ops.torchvision.roi_pool.

To summarize, I have two questions:

  1. How is forward() called when running roi_pool = roi_pool_obj(feature_map, [proposals])?
  2. How can I view the source code of torch.ops.torchvision.roi_pool, or where is the file containing its implementation located?

Last but not least, I've just started reading source code, which is pretty difficult for me. I'd appreciate it if you could also provide some advice or tutorials.


Solution

    1. RoIPool is a subclass of torch.nn.Module. Source code:

    https://github.com/pytorch/vision/blob/07ae61bf9c21ddd1d5f65d326aa9636849b383ca/torchvision/ops/roi_pool.py#L56

    2. nn.Module defines a __call__ method, which in turn calls the forward method. Source code:

    https://github.com/pytorch/pytorch/blob/b2311192e6c4745aac3fdd774ac9d56a36b396d4/torch/nn/modules/module.py#L1234

    3. When you execute the statement roi_pool = roi_pool_obj(feature_map, [proposals]), the __call__ method invokes the forward() method of RoIPool. Source code:

    https://github.com/pytorch/vision/blob/07ae61bf9c21ddd1d5f65d326aa9636849b383ca/torchvision/ops/roi_pool.py#L67

    4. RoIPool.forward calls the Python function roi_pool(), which in turn calls torch.ops.torchvision.roi_pool. Source code:

    https://github.com/pytorch/vision/blob/07ae61bf9c21ddd1d5f65d326aa9636849b383ca/torchvision/ops/roi_pool.py#L52

    5. ops is an object that loads native libraries implemented in C++:

    https://github.com/pytorch/pytorch/blob/b2311192e6c4745aac3fdd774ac9d56a36b396d4/torch/_ops.py#L537

    so when you access torch.ops.torchvision, it uses the compiled torchvision library. This is also why PyCharm cannot find a declaration: the operator is looked up at runtime in the compiled extension rather than defined in a Python file.

    6. The roi_pool function is registered here:

    https://github.com/pytorch/vision/blob/7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0/torchvision/csrc/ops/roi_pool.cpp#L53

    7. The actual implementation of roi_pool is here:

    CPU: https://github.com/pytorch/vision/blob/7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0/torchvision/csrc/ops/cpu/roi_pool_kernel.cpp

    GPU: https://github.com/pytorch/vision/blob/7947fc8fb38b1d3a2aca03f22a2e6a3caa63f2a0/torchvision/csrc/ops/cuda/roi_pool_kernel.cu
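
To make the first question concrete, here is a minimal pure-Python sketch of how calling an instance can end up in forward(). MiniModule and Scale are hypothetical names made up for this illustration; the real nn.Module.__call__ does more work (for example, it runs registered hooks), so treat this as a simplified model rather than PyTorch's actual code.

```python
class MiniModule:
    """Simplified stand-in for torch.nn.Module (hypothetical, not the real class)."""

    def __call__(self, *args, **kwargs):
        # Python invokes __call__ whenever an instance is used like a function.
        # The real nn.Module.__call__ also runs hooks around this step;
        # this sketch only forwards the arguments.
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must override forward()")


class Scale(MiniModule):
    """Hypothetical subclass that defines only __init__ and forward, like RoIPool."""

    def __init__(self, factor):
        self.factor = factor

    def forward(self, values):
        return [v * self.factor for v in values]


scale = Scale(0.5)
# Calling the instance triggers MiniModule.__call__, which calls Scale.forward.
result = scale([2, 4, 6])
print(result)
```

The subclass never mentions __call__, yet scale(...) works because the method is inherited from the base class. RoIPool follows the same pattern with nn.Module as its base.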
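
If the C++ and CUDA kernels are hard to read at first, a small reference version may help. The sketch below is a pure-Python approximation of RoI max pooling for a single 2-D feature map and a single box, loosely modelled on the CPU kernel linked above. The function name roi_pool_reference and its list-based inputs are assumptions made for this example. The real operator works on batched tensors, reads a batch index from each RoI, and also returns the argmax indices needed for the backward pass. Rounding details may also differ from the native code.

```python
import math


def roi_pool_reference(feature, roi, output_size, spatial_scale=1.0):
    """Max-pool one RoI (x1, y1, x2, y2) over a 2-D feature map (list of rows)."""
    height, width = len(feature), len(feature[0])
    pooled_h, pooled_w = output_size

    # Scale the box corners to feature-map coordinates and round to integers.
    # Assumes non-negative coordinates, so floor(x + 0.5) rounds half up.
    x1, y1, x2, y2 = (int(math.floor(c * spatial_scale + 0.5)) for c in roi)

    # Treat the box as inclusive and force it to be at least 1x1.
    roi_w = max(x2 - x1 + 1, 1)
    roi_h = max(y2 - y1 + 1, 1)
    bin_h = roi_h / pooled_h
    bin_w = roi_w / pooled_w

    output = []
    for ph in range(pooled_h):
        row = []
        for pw in range(pooled_w):
            # Bin boundaries inside the RoI, shifted by the RoI offset
            # and clipped to the feature-map borders.
            h_start = min(max(math.floor(ph * bin_h) + y1, 0), height)
            h_end = min(max(math.ceil((ph + 1) * bin_h) + y1, 0), height)
            w_start = min(max(math.floor(pw * bin_w) + x1, 0), width)
            w_end = min(max(math.ceil((pw + 1) * bin_w) + x1, 0), width)
            cells = [
                feature[h][w]
                for h in range(h_start, h_end)
                for w in range(w_start, w_end)
            ]
            # Empty bins produce 0; otherwise take the maximum in the bin.
            row.append(max(cells) if cells else 0.0)
        output.append(row)
    return output


# A 4x4 feature map holding the values 0..15, pooled to 2x2 over the whole map.
feature = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
pooled = roi_pool_reference(feature, roi=(0, 0, 3, 3), output_size=(2, 2))
print(pooled)
```

Each output cell is the maximum of one bin of the RoI, which is the core idea the native kernels implement, with parallelism and batching added on top.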