I need to expand a model from RGB to RGBA. I can handle the code rewrite on the model, but instead of retraining the entire model from scratch, I would love to start it off with its existing 3-channel weights plus zeros for the new channel.
Is there an easy way to change torch's saved 3-channel weights into 4-channel weights?
Yes, you can do a little bit of "model surgery". Assuming the model's input is processed directly by only a convolutional layer, you can just replace that conv layer with another one that has in_channels set to 4. Then you can set the new layer's weights to zero and copy over the old weights (and biases, if applicable) from the original conv layer.
For example, say we had a simple model that looked like this:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))
model = SimpleModel()
Supposing that the model is trained at this point, we could perform the surgery as follows:
y_rgb = torch.randn(1, 3, 5, 5)

# get the output of the original RGB model on an RGB input
z_rgb = model(y_rgb)

# perform model surgery
with torch.no_grad():
    new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
    new_conv1.weight.zero_()
    new_conv1.weight[:, :3, ...] = model.conv1.weight
    new_conv1.bias.copy_(model.conv1.bias)
    model.conv1 = new_conv1

# append a random alpha channel to y_rgb to build an RGBA input
y_alpha = torch.randn(1, 1, 5, 5)
y_rgba = torch.cat([y_rgb, y_alpha], dim=1)

# get the output of the modified RGBA model
z_rgba = model(y_rgba)

# compare z_rgb and z_rgba, print mean-square difference
z_err = ((z_rgba - z_rgb) ** 2).mean().item()
print('Err:', z_err)

# save the new state dictionary to a new file
torch.save(model.state_dict(), 'checkpoint_rgba.pt')
which should give you an error of zero or very close to zero, since the new alpha channel is multiplied by zero weights and therefore contributes nothing to the output.
Of course, if you don't have a bias term in your first conv layer, then you don't need to copy that over.
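If you want a version of the surgery that works whether or not the original layer has a bias, one way is to guard the copy. This is a variant of the surgery block above (not something to run after it), shown here as a small sketch:

with torch.no_grad():
    has_bias = model.conv1.bias is not None
    new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=has_bias)
    new_conv1.weight.zero_()
    # copy the old RGB weights into the first 3 input channels
    new_conv1.weight[:, :3, ...] = model.conv1.weight
    if has_bias:
        new_conv1.bias.copy_(model.conv1.bias)
    model.conv1 = new_conv1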
Assuming you've saved the new state dictionary, you will probably also want to update the model class definition so that the input convolution layer takes 4-channel input instead of 3. Then next time you can load the new state dictionary directly, without any extra steps.
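For instance, the updated definition could look like this (a minimal sketch; SimpleModelRGBA is just an illustrative name, and everything except the first conv layer is unchanged):

class SimpleModelRGBA(nn.Module):
    def __init__(self):
        super().__init__()
        # only the first conv layer changes: 4 input channels instead of 3
        self.conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))

model = SimpleModelRGBA()
model.load_state_dict(torch.load('checkpoint_rgba.pt'))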
Now, it's not strictly necessary to do the surgery on the model directly, though I tend to prefer it since I find it easier to verify correctness. Assuming you saved the state dictionary for the RGB model, you could also modify the state dictionary directly:
# assuming you saved the RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
state_dict = torch.load('checkpoint_rgb.pt')
old_weight = state_dict['conv1.weight']
state_dict['conv1.weight'] = torch.zeros(
    old_weight.shape[0],
    old_weight.shape[1] + 1,
    old_weight.shape[2],
    old_weight.shape[3],
).type_as(old_weight)
state_dict['conv1.weight'][:, :3, ...] = old_weight
torch.save(state_dict, 'checkpoint_rgba.pt')
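If you want to verify correctness for this route as well, one option is to load both checkpoints into the two class definitions and compare outputs, for example (assuming the SimpleModel and SimpleModelRGBA classes from above):

rgb_model = SimpleModel()
rgb_model.load_state_dict(torch.load('checkpoint_rgb.pt'))
rgba_model = SimpleModelRGBA()
rgba_model.load_state_dict(torch.load('checkpoint_rgba.pt'))

# an RGB input and the same input with a random alpha channel appended
y_rgb = torch.randn(1, 3, 5, 5)
y_rgba = torch.cat([y_rgb, torch.randn(1, 1, 5, 5)], dim=1)

with torch.no_grad():
    z_err = ((rgba_model(y_rgba) - rgb_model(y_rgb)) ** 2).mean().item()
print('Err:', z_err)  # again should be zero or very close to zero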