I have a ResNet18 model trained on the Places365 image dataset, and I'd like to use this pre-trained model to expedite the training needed to identify distressed houses. My dataset consists of images of size 800x800x3, but the pre-trained model expects inputs of size 224x224x3. I do not want to resize the images to 224x224, since I'm worried that indicators of house distress (chipped paint and loose shingles) will be lost during the conversion.
My idea was to add extra layers that can handle the larger images before feeding them into ResNet. So far I have the following PyTorch model:
import torch
from torch import nn
from torchvision import models
class NewModel(nn.Module):
    """Pre-trained ResNet-18 preceded by a small learnable downsampling stem.

    The stem reduces the spatial resolution of large inputs with strided
    convolutions instead of a fixed image resize, and emits 3 channels so
    that the pre-trained first convolution of the backbone can be kept.
    The backbone's 365-way classifier is replaced by a fresh 4-way linear
    layer.
    """

    def __init__(self, pretrain_model_path) -> None:
        """Build the stem and load the pre-trained backbone.

        Args:
            pretrain_model_path: Location of the checkpoint, passed straight
                to ``torch.load``. The checkpoint must contain a
                ``'state_dict'`` entry matching
                ``resnet18(num_classes=365)``.
        """
        # nn.Module has to be initialised before sub-modules are assigned,
        # otherwise attribute assignment raises an AttributeError.
        super().__init__()

        # Two stride-2 convolutions shrink each spatial side by 4x
        # (e.g. 800 -> 400 -> 200) with learned filters rather than a fixed
        # interpolation. This is one possible stem; tune it for your data.
        self.pre_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, kernel_size=3, stride=2, padding=1),
        )

        # import the trained model
        model = models.resnet18(num_classes=365)
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(pretrain_model_path, map_location="cpu")
        # Strip the 'module.' prefix from parameter names (presumably left
        # by nn.DataParallel when the checkpoint was saved).
        state_dict = {
            key.replace("module.", ""): value
            for key, value in checkpoint["state_dict"].items()
        }
        # Load the weights while the classifier still has 365 outputs so the
        # checkpoint shapes match.
        model.load_state_dict(state_dict)

        # change prediction class count
        model.fc = nn.Linear(model.fc.in_features, 4)
        self.backbone = model

    def forward(self, x):
        """Run the stem, then the backbone, and return the 4-class scores."""
        x = self.pre_layers(x)
        x = self.backbone(x)
        return x
Is this a common practice, or is it better to build a model from scratch that is designed specifically for images of this size? If adding extra layers is reasonable, how would I go about implementing them?
You can use the ResNet model as-is and pass the 800x800 images to it directly, but this may be slow and consume more memory.
import torch
from torchvision import models
# Stock ResNet-18 with a 4-way classifier, fed a single all-zero 800x800 RGB
# image to show that the unmodified network accepts the larger input.
model = models.resnet18(num_classes=4)
dummy_batch = torch.zeros(1, 3, 800, 800)
output = model(dummy_batch)
print(output.shape)  # torch.Size([1, 4])
import torch
from torch import nn
from torchvision import models
class NewModel(nn.Module):
    """ResNet-18 with an extra strided convolution in front of it.

    The extra layer halves the spatial resolution of the input and produces
    ``intermediate_features`` channels; the backbone's first convolution is
    replaced so that it accepts that many input channels.
    """

    def __init__(self, intermediate_features=64) -> None:
        """Build the pre-model and the adapted 4-class backbone.

        Args:
            intermediate_features: Number of channels produced by the
                pre-model and consumed by the backbone's first convolution.
        """
        # nn.Module has to be initialised before sub-modules are assigned,
        # otherwise attribute assignment raises an AttributeError.
        super().__init__()

        model = models.resnet18(num_classes=4)
        self.backbone = model

        # Stride-2 convolution: halves height and width (800 -> 400).
        self.pre_model = nn.Sequential(
            nn.Conv2d(3, intermediate_features, 3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Rebuild the backbone's first convolution with the same geometry
        # (and bias setting) but with `intermediate_features` input channels.
        conv1 = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            intermediate_features,
            conv1.out_channels,
            conv1.kernel_size,
            conv1.stride,
            conv1.padding,
            bias=conv1.bias is not None,
        )

    def forward(self, x):
        """Apply the pre-model, then the backbone; return the class scores."""
        # 3x800x800
        x = self.pre_model(x)
        # intermediate_features x 400x400
        x = self.backbone(x)
        # 4
        return x
# Dummy batch holding a single all-zero 3x800x800 image, plus a model built
# with the default number of intermediate features.
# NOTE(review): `x` is not used within the lines visible here; presumably it
# is meant to be passed through `model` -- confirm.
x = torch.zeros(1, 3, 800, 800)
model = NewModel()
Depending on your data, the different approaches may perform better or worse, so you may need to experiment with model architectures.