
Customize PyTorch model export to ONNX


I am trying to export a pretrained Mask R-CNN model to ONNX format. In its basic configuration, the exported model has the following structure (here I added batch_size as a dynamic axis):

[Image: structure of the exported ONNX model]

I want to customize my model and add batch_size to the outputs (i.e., add a new dimension to each of the outputs).

I wrote the following code to make this possible:

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MaskRCNNModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")
        # Replace the box predictor head with one for 7 classes
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=7)
        # Load the fine-tuned weights
        self.model.load_state_dict(torch.load("saved_dict.torch", map_location="cpu"))

    def forward(self, input):
        # The torchvision model returns one dict of tensors per input image
        outputs = self.model(input)
        boxes = []
        labels = []
        scores = []
        masks = []
        for result in outputs:
            box, label, score, mask = result.values()
            boxes.append(box)
            labels.append(label)
            scores.append(score)
            masks.append(mask)

        return boxes, labels, scores, masks

maskrcnn_model = MaskRCNNModel()
maskrcnn_model.eval()
maskrcnn_model.to(device)

x = torch.rand(1, 3, 512, 512)
x = x.to(device)

# Sanity check in eager mode
maskrcnn_model(x)

torch.onnx.export(maskrcnn_model,
                  x,
                  "base_model_100_epochs.onnx",
                  opset_version=11,
                  input_names=["input"],
                  output_names=["boxes", "labels", "scores", "masks"])

However, the code above doesn't change the exported model at all. The structure of the outputs stays the same:

[Image: output structure of the exported ONNX model, unchanged]

How should I customize the forward method so that batch_size is added to the ONNX model outputs?


Solution

  • Avoid doing this

    As mentioned in my original comment, I would discourage deploying most torchvision models with ONNX. torchvision is an all-around great library; it just was not originally written with static graphs in mind.

    If throughput is a consideration, this Mask R-CNN implementation is not the way to go. With earlier ONNX opsets, I've seen this model spend most of its execution time on host-to-device and device-to-host (h2d/d2h) transfers when operators fell back to the CPU. I recommend checking out YOLOv8 by Ultralytics for a newer take on instance segmentation, or one of the many static implementations found on GitHub.

    Torchvision Mask R-CNN outputs

    The model is designed with user-friendliness in mind, so for each image in the input batch it outputs a dictionary of tensors with the accepted, post-processed results. For example, if you have two images, with ten detected objects in the first and three in the second, the output would be:

    # mask_rcnn: a torchvision Mask R-CNN model in eval mode
    batch = torch.randn((2, 3, 256, 256))  # Input two images
    output = mask_rcnn(batch)  # Run the model
    results1, results2 = output  # One dictionary per image
    for key, value in results1.items():
        print(key, list(value.shape))
    >>> boxes [10, 4]
    >>> labels [10]
    >>> scores [10]
    >>> masks [10, 1, 256, 256]
    for key, value in results2.items():
        print(key, list(value.shape))
    >>> boxes [3, 4]
    >>> labels [3]
    >>> scores [3]
    >>> masks [3, 1, 256, 256]
    
    Why your approach does not work

    The reason is that ONNX does not understand Python container types. During torch.onnx.export, lists, dictionaries, tuples, etc. have no special meaning: their entries are flattened and saved either as tensors or as constants. So the only thing your custom forward pass does is change the order of the outputs. With the previous example, the outputs change from

    >>> boxes1 [10, 4]
    >>> labels1 [10]
    >>> scores1 [10]
    >>> masks1 [10, 1, 256, 256]
    >>> boxes2 [3, 4]
    >>> labels2 [3]
    >>> scores2 [3]
    >>> masks2 [3, 1, 256, 256]
    

    to

    >>> boxes1 [10, 4]
    >>> boxes2 [3, 4]
    >>> labels1 [10]
    >>> labels2 [3]
    >>> scores1 [10]
    >>> scores2 [3]
    >>> masks1 [10, 1, 256, 256]
    >>> masks2 [3, 1, 256, 256]
    

    The torch.onnx documentation is worth reading to see how Python and torch types are interpreted during export.
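To make the reordering concrete, here is a minimal pure-Python sketch of how nested output containers end up as one flat list of graph outputs. `flatten_outputs` is a hypothetical helper that only mimics depth-first flattening of lists, tuples and dicts (dict values taken in insertion order, an assumption that matches the ordering shown above); strings stand in for tensors, and this is not the actual torch.onnx code.

```python
def flatten_outputs(obj):
    """Hypothetical depth-first flattening of nested lists/tuples/dicts."""
    if isinstance(obj, (list, tuple)):
        flat = []
        for item in obj:
            flat.extend(flatten_outputs(item))
        return flat
    if isinstance(obj, dict):
        flat = []
        for value in obj.values():  # insertion order (assumption)
            flat.extend(flatten_outputs(value))
        return flat
    return [obj]  # a leaf, standing in for a tensor

# Original model: a list with one dict per image
original = [
    {"boxes": "boxes1", "labels": "labels1", "scores": "scores1", "masks": "masks1"},
    {"boxes": "boxes2", "labels": "labels2", "scores": "scores2", "masks": "masks2"},
]
# Custom forward from the question: a tuple of per-field lists
custom = (["boxes1", "boxes2"], ["labels1", "labels2"],
          ["scores1", "scores2"], ["masks1", "masks2"])

print(flatten_outputs(original))
print(flatten_outputs(custom))
```

Either way the result is eight separate outputs with no batch dimension; only their order differs.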

    Goal

    The goal is to have the model output batched results, i.e. tensors with the following shapes:

    boxes [batch_size, num_detections, 4]
    labels [batch_size, num_detections]
    scores [batch_size, num_detections]
    masks [batch_size, num_detections, 1, 256, 256]
    

    We immediately see that this is impossible without some tricks. Since different images in the batch can have a varying number of predicted objects, we cannot create a single tensor with 10 bounding boxes at the first index and 3 at the second.
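A quick eager-mode check of this point, using dummy tensors with the shapes from the example above (10 and 3 detections): `torch.stack` rejects ragged inputs, while padding (here via `torch.nn.utils.rnn.pad_sequence`) yields a single batched tensor. This is only an illustration in plain PyTorch, not something meant for export; note that `pad_sequence` pads to the longest item in the batch rather than to a fixed maximum.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

boxes_img1 = torch.rand(10, 4)  # 10 detections in the first image
boxes_img2 = torch.rand(3, 4)   # 3 detections in the second image

try:
    torch.stack([boxes_img1, boxes_img2])
    stack_failed = False
except RuntimeError as err:
    # stack requires every tensor to have the same shape
    stack_failed = True
    print("stack failed:", err)

# Zero-pad the shorter result up to the longest one in the batch
padded = pad_sequence([boxes_img1, boxes_img2], batch_first=True)
print(padded.shape)  # torch.Size([2, 10, 4])
```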

    Solution - Padding

    To output batched results in this scenario, you can define constant-shaped output tensors and paste the results for each image into them. For instance:

    def forward(self, input):
        # Maximum number of detections the vision model will output per image
        max_detections = self.model.roi_heads.detections_per_img
        # Variables for output tensor shapes
        # Use tensor.size() instead of tensor.shape for dynamic inputs
        batch_size, _, input_height, input_width = input.size()
        # Create batched output tensors on the same device as the input
        all_boxes = torch.zeros((batch_size, max_detections, 4), device=input.device)
        all_labels = torch.zeros((batch_size, max_detections), dtype=torch.int64, device=input.device)
        all_scores = torch.zeros((batch_size, max_detections), device=input.device)
        # Masks are returned with a redundant channel in the second dimension
        all_masks = torch.zeros((batch_size, max_detections, 1, input_height, input_width), device=input.device)
        # Number of detections per image
        detections_per_batch = torch.zeros((batch_size, 1), dtype=torch.int64, device=input.device)
        # Run forward pass
        outputs = self.model(input)
        # Note: when exported via tracing, this Python loop is unrolled for the example input's batch size
        for idx, result in enumerate(outputs):
            # Access by key rather than relying on dict ordering
            boxes, labels, scores, masks = result["boxes"], result["labels"], result["scores"], result["masks"]
            # Number of detections for this image
            n_dets = boxes.size(0)
            detections_per_batch[idx] = n_dets
            # Paste this image's results into the output tensors
            all_boxes[idx, :n_dets] = boxes
            all_labels[idx, :n_dets] = labels
            all_scores[idx, :n_dets] = scores
            all_masks[idx, :n_dets] = masks
        return detections_per_batch, all_boxes, all_labels, all_scores, all_masks
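The padding logic of that forward pass can be checked in isolation without downloading any weights. The sketch below assumes fake per-image result dicts in the torchvision format; `pad_results` and `fake_result` are hypothetical helpers that mirror the loop above, and the snippet does not exercise the ONNX export itself.

```python
import torch

def pad_results(outputs, max_detections, height, width):
    """Hypothetical helper: paste per-image result dicts into fixed-size, zero-padded tensors."""
    batch_size = len(outputs)
    all_boxes = torch.zeros((batch_size, max_detections, 4))
    all_labels = torch.zeros((batch_size, max_detections), dtype=torch.int64)
    all_scores = torch.zeros((batch_size, max_detections))
    all_masks = torch.zeros((batch_size, max_detections, 1, height, width))
    detections_per_batch = torch.zeros((batch_size, 1), dtype=torch.int64)
    for idx, result in enumerate(outputs):
        n_dets = result["boxes"].size(0)
        detections_per_batch[idx] = n_dets
        all_boxes[idx, :n_dets] = result["boxes"]
        all_labels[idx, :n_dets] = result["labels"]
        all_scores[idx, :n_dets] = result["scores"]
        all_masks[idx, :n_dets] = result["masks"]
    return detections_per_batch, all_boxes, all_labels, all_scores, all_masks

def fake_result(n, height, width):
    # Hypothetical stand-in for one image's Mask R-CNN output dict
    return {
        "boxes": torch.rand(n, 4),
        "labels": torch.randint(1, 7, (n,)),
        "scores": torch.rand(n),
        "masks": torch.rand(n, 1, height, width),
    }

outputs = [fake_result(10, 32, 32), fake_result(3, 32, 32)]
counts, boxes, labels, scores, masks = pad_results(outputs, max_detections=100, height=32, width=32)
print(counts.flatten().tolist())  # [10, 3]
print(boxes.shape, masks.shape)
```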
    

    The padded forward pass above creates output tensors that can hold every potential object detection, and copies the actual detections for each image into them. To keep track of which entries are zero padding and which are real detections, a tensor detections_per_batch is returned in addition to the Mask R-CNN outputs. It can then be used to extract the real predictions from the ONNX outputs:

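    # outputs: the list of arrays returned by running the exported ONNX model
    for n_dets, boxes, labels, scores, masks in zip(*outputs):
        n = int(n_dets[0])  # Number of real detections for this image
        detected_boxes = boxes[:n]
        detected_labels = labels[:n]
        ...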
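A runtime such as ONNX Runtime returns the outputs as NumPy arrays. Below is a minimal, self-contained sketch of the unpadding step, assuming the output order of the padded forward pass above; `unpad` is a hypothetical helper, and the arrays are hand-built stand-ins so the snippet runs without a model.

```python
import numpy as np

def unpad(outputs):
    """Hypothetical helper: split padded batched outputs into per-image dicts."""
    per_image = []
    for n_dets, boxes, labels, scores, masks in zip(*outputs):
        n = int(n_dets[0])  # real number of detections for this image
        per_image.append({
            "boxes": boxes[:n],
            "labels": labels[:n],
            "scores": scores[:n],
            "masks": masks[:n],
        })
    return per_image

# Hand-built padded outputs: batch of 2, at most 5 detections, 8x8 masks
batch, max_det, h, w = 2, 5, 8, 8
outputs = [
    np.array([[4], [1]], dtype=np.int64),            # detections_per_batch
    np.zeros((batch, max_det, 4), dtype=np.float32),  # boxes
    np.zeros((batch, max_det), dtype=np.int64),       # labels
    np.zeros((batch, max_det), dtype=np.float32),     # scores
    np.zeros((batch, max_det, 1, h, w), dtype=np.float32),  # masks
]

results = unpad(outputs)
print([r["boxes"].shape for r in results])  # [(4, 4), (1, 4)]
```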
    
    Considerations

    This approach can cause problems in I/O- or memory-bound applications, since the model always returns outputs with room for every potential detection, including full-size masks. If you have a good upper bound on the number of objects, you can reduce this overhead by lowering model.roi_heads.detections_per_img.
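To put a number on that overhead, here is a back-of-the-envelope calculation for the padded masks output alone, assuming float32 masks (4 bytes per element), the 512 × 512 input from the question, and torchvision's default limit of 100 detections per image. `padded_mask_bytes` is a hypothetical helper.

```python
def padded_mask_bytes(batch_size, max_detections, height, width, bytes_per_element=4):
    """Size in bytes of a zero-padded [B, max_detections, 1, H, W] mask tensor."""
    return batch_size * max_detections * 1 * height * width * bytes_per_element

MIB = 1024 ** 2
print(padded_mask_bytes(1, 100, 512, 512) / MIB)  # 100.0
print(padded_mask_bytes(1, 10, 512, 512) / MIB)   # 10.0
```

The size grows linearly with batch size and with the detection limit, no matter how many objects are actually found. Besides assigning model.roi_heads.detections_per_img after construction, torchvision's detection builders also accept a box_detections_per_img keyword argument (default 100).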