Tags: python, pytorch, conv-neural-network, resnet, image-classification

How can I use larger input images when using a pre-trained CNN without resizing?


I have a ResNet18 model trained on the Places365 image dataset, and I'd like to use this pre-trained model to expedite the training needed to identify distressed houses. My dataset consists of 800x800x3 images, but the pre-trained model expects 224x224x3 inputs. I don't want to resize the images to 224x224, since I'm worried the downscaling will destroy the house-distress indicators (chipped paint and loose shingles).

My idea was to add extra layers in front of ResNet to handle the larger images before feeding them into the backbone. I have the following PyTorch model:

import torch
from torch import nn
from torchvision import models

class NewModel(nn.Module):
    def __init__(self, pretrain_model_path) -> None:
        super().__init__()
        # Not sure here
        self.pre_layers = None
        # load the pretrained Places365 model
        model = models.resnet18(num_classes=365)
        checkpoint = torch.load(pretrain_model_path, map_location=lambda storage, loc: storage)
        # strip the 'module.' prefix left by DataParallel
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        model.load_state_dict(state_dict)
        # change prediction class count
        model.fc = nn.Linear(model.fc.in_features, 4)
        self.backbone = model

    def forward(self, x):
        x = self.pre_layers(x)
        x = self.backbone(x)
        return x

Is this a common practice, or is it better to build a model from scratch that is designed for this image size? If so, how would I go about implementing it?


Solution

    1. You can use an ordinary resnet18 model and pass 800x800 images to it directly: torchvision's ResNet ends with an adaptive average pool before the fully connected layer, so it accepts larger inputs without modification. But it may be slow and consume more memory (a sketch that also loads the pretrained weights follows the snippet below).
    import torch
    from torchvision import models
    model = models.resnet18(num_classes=4)
    print(model(torch.zeros((1, 3, 800, 800))).shape)  # torch.Size([1, 4])
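
    Since the question starts from Places365 weights, here is a minimal sketch of the same direct approach with the pretrained checkpoint loaded first. The file name is a placeholder; the checkpoint format ('state_dict' with 'module.' prefixes) is taken from the question:
    import torch
    from torch import nn
    from torchvision import models

    # load the Places365-pretrained weights into a 365-class model first
    model = models.resnet18(num_classes=365)
    checkpoint = torch.load('places365_resnet18.pth.tar', map_location='cpu')  # placeholder path
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)
    # swap the classifier head for the 4 distress classes
    model.fc = nn.Linear(model.fc.in_features, 4)
    # the adaptive average pool before fc makes the input size flexible
    print(model(torch.zeros((1, 3, 800, 800))).shape)  # torch.Size([1, 4])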
    
    2. You can add any lightweight module that reduces the image resolution before it reaches the backbone:
    import torch
    from torch import nn
    from torchvision import models
    
    
    class NewModel(nn.Module):
        def __init__(self, intermediate_features=64) -> None:
            super().__init__()
            model = models.resnet18(num_classes=4)
            self.backbone = model

            # lightweight stem that halves the spatial resolution
            # before the image reaches the backbone
            self.pre_model = nn.Sequential(
                nn.Conv2d(3, intermediate_features, 3, stride=2, padding=1),
                nn.ReLU(),
            )

            # replace the backbone's first conv so it accepts the stem's
            # output channels; bias=False matches the original ResNet stem
            conv1 = self.backbone.conv1
            self.backbone.conv1 = nn.Conv2d(
                intermediate_features, conv1.out_channels,
                conv1.kernel_size, conv1.stride, conv1.padding,
                bias=False)

        def forward(self, x):
            # (N, 3, 800, 800)
            x = self.pre_model(x)
            # (N, intermediate_features, 400, 400)
            x = self.backbone(x)
            # (N, 4)
            return x
    
    
    model = NewModel()
    x = torch.zeros((1, 3, 800, 800))
    print(model(x).shape)
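
    If you want to reuse the Places365 weights from the question with this architecture, one option (a sketch, assuming the question's checkpoint format; the path is a placeholder) is to load them into the backbone with strict=False, skipping the layers that were replaced:
    checkpoint = torch.load('places365_resnet18.pth.tar', map_location='cpu')  # placeholder path
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    # conv1 now takes 64 input channels and fc now has 4 outputs, so the
    # pretrained shapes no longer match; drop them before loading
    for key in ('conv1.weight', 'fc.weight', 'fc.bias'):
        state_dict.pop(key, None)
    model.backbone.load_state_dict(state_dict, strict=False)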
    

    Depending on your data, the different approaches may perform better or worse, so you may need to experiment with model architectures.
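
    For example, here is a rough way to compare the forward-pass cost of the two approaches (a minimal sketch reusing NewModel from above; the numbers depend on your hardware):
    import time
    import torch
    from torchvision import models

    def forward_time(model, x, repeats=10):
        # average wall-clock time of one forward pass in eval mode
        model.eval()
        with torch.no_grad():
            start = time.perf_counter()
            for _ in range(repeats):
                model(x)
        return (time.perf_counter() - start) / repeats

    x = torch.zeros((1, 3, 800, 800))
    print(forward_time(models.resnet18(num_classes=4), x))
    print(forward_time(NewModel(), x))  # NewModel defined in the snippet above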