python pytorch rgba

Retraining a Model from 3 Channels (RGB) to 4 Channels (RGBA), can I use the 3 channel weights?


I need to expand a model from RGB to RGBA input. I can handle the code changes to the model, but instead of retraining the entire model from scratch, I would love to initialize it with its 3-channel weights plus zeros for the new channel.

Is there an easy way to convert a saved PyTorch checkpoint from 3-channel weights to 4-channel weights?


Solution

  • Yes, you can do a little bit of "model surgery". Assuming the only layer that processes the model's input directly is a convolutional layer, you can replace that conv layer with a new one that has in_channels set to 4. Set the new layer's weights to zero, then copy the old weights (and biases, if applicable) from the original conv layer into the first three input channels. The alpha channel's weights stay at zero, so the new model initially ignores alpha and behaves like the RGB model.

    For example, say we had a simple model that looked like this

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class SimpleModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=True)
            self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
            self.linear = nn.Linear(125, 1)
        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            return self.linear(x.flatten(start_dim=1))
    
    model = SimpleModel()
    

    Supposing that the model is trained at this point, we could perform the surgery as follows

    y_rgb = torch.randn(1, 3, 5, 5)
    
    # get the original RGB model's output for y_rgb
    z_rgb = model(y_rgb)
    
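    # perform model surgery
    with torch.no_grad():
        # new input layer: same settings as conv1, but with 4 input channels
        new_conv1 = nn.Conv2d(4, 10, kernel_size=3, padding=1, bias=True)
        # zero everything so the alpha channel starts with no influence
        new_conv1.weight.zero_()
        # copy the trained RGB kernels into the first 3 input channels
        new_conv1.weight[:, :3, ...] = model.conv1.weight
        new_conv1.bias.copy_(model.conv1.bias)
        model.conv1 = new_conv1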
    
    # append a random alpha channel to y_rgb to form y_rgba
    y_alpha = torch.randn(1, 1, 5, 5)
    y_rgba = torch.cat([y_rgb, y_alpha], dim=1)
    
    # get results on rgba model
    z_rgba = model(y_rgba)
    
    # compare z_rgb and z_rgba, print mean-square difference
    z_err = ((z_rgba-z_rgb)**2).mean().item()
    print('Err:', z_err)
    
    # save results to a new file
    torch.save(model.state_dict(), 'checkpoint_rgba.pt')
    

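    which should give you an error of zero or very close to zero, because the alpha channel's weights are all zero and so contribute nothing to the output of the first conv layer.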

    Of course if you don't have a bias term in your first conv layer then you don't need to copy that over.
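
    The surgery steps can also be wrapped in a small helper that skips the bias copy when the layer has none. This is only a sketch: `widen_conv_input` and its `extra_channels` argument are hypothetical names, and it assumes a plain `nn.Conv2d` with `groups=1` and the default padding mode.

```python
import torch
import torch.nn as nn

def widen_conv_input(old_conv, extra_channels=1):
    """Hypothetical helper: return a copy of old_conv that accepts extra zero-weighted input channels.

    Assumes groups=1 and the default padding mode.
    """
    new_conv = nn.Conv2d(
        old_conv.in_channels + extra_channels,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
        dilation=old_conv.dilation,
        bias=old_conv.bias is not None,
    )
    # match the device and dtype of the original layer
    new_conv = new_conv.to(device=old_conv.weight.device, dtype=old_conv.weight.dtype)
    with torch.no_grad():
        # zero all kernels, then restore the trained ones for the original channels
        new_conv.weight.zero_()
        new_conv.weight[:, :old_conv.in_channels] = old_conv.weight
        # only copy the bias if the original layer has one
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    return new_conv

# usage: a bias-free RGB conv widened to RGBA
old_conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
new_conv = widen_conv_input(old_conv)

x_rgb = torch.randn(2, 3, 6, 6)
x_rgba = torch.cat([x_rgb, torch.randn(2, 1, 6, 6)], dim=1)
with torch.no_grad():
    out_old = old_conv(x_rgb)
    out_new = new_conv(x_rgba)
print('Max abs difference:', (out_old - out_new).abs().max().item())
```

    In the example above you would then write model.conv1 = widen_conv_input(model.conv1).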

    Assuming you've saved the new state dictionary, then you will probably want to update the model class definition so that your input convolution layer takes 4 channel input instead of 3. Then next time you can directly load the new state dictionary without additional steps.
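
    As a rough sketch, the updated class definition and the direct load could look like the following. The `in_channels` constructor argument is an assumption added for illustration, and an in-memory buffer stands in for 'checkpoint_rgba.pt' so that the snippet runs on its own.

```python
import io
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleModel(nn.Module):
    # in_channels is a hypothetical argument; the original class hard-codes 3
    def __init__(self, in_channels=4):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 10, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(10, 5, kernel_size=3, padding=1, bias=True)
        self.linear = nn.Linear(125, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        return self.linear(x.flatten(start_dim=1))

# stand-in for 'checkpoint_rgba.pt': save a 4-channel state dict to memory
buffer = io.BytesIO()
torch.save(SimpleModel(in_channels=4).state_dict(), buffer)
buffer.seek(0)

# load the RGBA checkpoint directly into the updated class
model_rgba = SimpleModel(in_channels=4)
result = model_rgba.load_state_dict(torch.load(buffer))

# the model now accepts 4-channel input
out = model_rgba(torch.randn(2, 4, 5, 5))
print(out.shape)
```

    If the class still declared 3 input channels, load_state_dict would raise a RuntimeError about the size mismatch on conv1.weight, which is a useful sanity check.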


    It's not strictly necessary to do the surgery on the model directly, though I tend to prefer it because I find it easier to verify correctness.

    Assuming you saved off the state dictionary for the RGB model, you could also just directly modify the state dictionary.

    # assuming you saved the RGB model using torch.save(model.state_dict(), 'checkpoint_rgb.pt')
    import torch

    state_dict = torch.load('checkpoint_rgb.pt')
    old_weight = state_dict['conv1.weight']  # shape: (out_channels, 3, kH, kW)
    # zero tensor with one extra input channel, same dtype and device as old_weight
    new_weight = old_weight.new_zeros((
        old_weight.shape[0],
        old_weight.shape[1] + 1,
        old_weight.shape[2],
        old_weight.shape[3],
    ))
    # copy the RGB kernels; the alpha kernels stay zero
    new_weight[:, :old_weight.shape[1], ...] = old_weight
    state_dict['conv1.weight'] = new_weight
    torch.save(state_dict, 'checkpoint_rgba.pt')
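
    Since the state-dictionary route is harder to eyeball, it can help to check the result the same way as before. The sketch below generalizes the edit into a function and verifies it on a single conv layer; `expand_input_channels`, `key`, and `extra_channels` are hypothetical names, and a lone nn.Conv2d stands in for the full model, so the weight key is 'weight' rather than 'conv1.weight'.

```python
import torch
import torch.nn as nn

def expand_input_channels(state_dict, key, extra_channels=1):
    """Hypothetical helper: copy state_dict, zero-padding the conv weight at `key` along the input-channel axis."""
    old_weight = state_dict[key]
    out_ch, in_ch = old_weight.shape[:2]
    # new tensor with the same dtype and device, filled with zeros
    new_weight = old_weight.new_zeros((out_ch, in_ch + extra_channels, *old_weight.shape[2:]))
    new_weight[:, :in_ch] = old_weight
    new_state = dict(state_dict)  # shallow copy; the other tensors are reused as-is
    new_state[key] = new_weight
    return new_state

# stand-in for the first layer of a trained RGB model
conv_rgb = nn.Conv2d(3, 10, kernel_size=3, padding=1)
new_state = expand_input_channels(conv_rgb.state_dict(), 'weight')

# load the widened weights into a 4-channel layer
conv_rgba = nn.Conv2d(4, 10, kernel_size=3, padding=1)
conv_rgba.load_state_dict(new_state)

# compare outputs on the same RGB data with an arbitrary alpha channel appended
x_rgb = torch.randn(1, 3, 5, 5)
x_rgba = torch.cat([x_rgb, torch.randn(1, 1, 5, 5)], dim=1)
with torch.no_grad():
    out_rgb = conv_rgb(x_rgb)
    out_rgba = conv_rgba(x_rgba)
print('Outputs match:', torch.allclose(out_rgb, out_rgba, atol=1e-5))
```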