I have the following architecture (built using nngraph):
require 'nn'
require 'nngraph'
input = nn.Identity()()
net1 = nn.Sequential():add(nn.SpatialConvolution(1, 5, 3, 3)):add(nn.ReLU(true)):add(nn.SpatialConvolution(5, 20, 4, 4))
net2 = nn.Sequential():add(nn.SpatialFullConvolution(20, 5, 4, 4)):add(nn.ReLU(true)):add(nn.SpatialFullConvolution(5, 1, 3, 3)):add(nn.Sigmoid())
net3 = nn.Sequential():add(nn.SpatialConvolution(1, 20, 3, 3)):add(nn.ReLU(true)):add(nn.SpatialConvolution(20, 40, 4, 4)):add(nn.ReLU(true)):add(nn.SpatialConvolution(40, 2, 3, 3)):add(nn.Sigmoid())
output1 = net1(input)
output2 = net2(output1)
output3 = net3(output2)
gMod = nn.gModule({input}, {output1, output3})
target1 = torch.rand(20, 51, 51)
target2 = torch.rand(2, 49, 49)
target2[target2:gt(0.5)] = 1
target2[target2:lt(0.5)] = 0
-- Do a forward pass
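x = torch.rand(1, 56, 56) -- keep the input tensor so it can be reused in the backward pass
out1, out2 = unpack(gMod:forward(x))
cr1 = nn.MSECriterion()
cr1:forward(out1, target1)
gradient1 = cr1:backward(out1, target1)
cr2 = nn.BCECriterion()
cr2:forward(out2, target2)
gradient2 = cr2:backward(out2, target2)
-- Now update the weights for the networks
LR = 0.001
gMod:zeroGradParameters()
gMod:backward(x, {gradient1, gradient2}) -- backward expects the input tensor, not the nngraph node
gMod:updateParameters(LR)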
I wonder:
1) How can one stop gradient2 from updating the weights of net1, so that it only contributes to updating the weights of net2 and net3?
2) How can one prevent gradient2 from updating net3's weights while still letting it update the weights of the other sub-networks?
I found a solution to both problems. Below I post the relevant code for each one:
Question 1:
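This is a little bit tricky but totally doable. If net1's weights are not supposed to be updated by gradient2, one needs to override the updateGradInput() function of the layer that follows net1 (the first layer of net2) and make it output a zero tensor, so that no gradient flows back from net2 into net1. net1 still receives gradient1, because output1 is also a direct output of the graph. This is done in the following code: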
input = nn.Identity()()
net1 = nn.Sequential():add(nn.SpatialConvolution(1, 5, 3, 3)):add(nn.ReLU(true)):add(nn.SpatialConvolution(5, 20, 4, 4))
net2 = nn.Sequential():add(nn.SpatialFullConvolution(20, 5, 4, 4)):add(nn.ReLU(true)):add(nn.SpatialFullConvolution(5, 1, 3, 3)):add(nn.Sigmoid())
net3 = nn.Sequential():add(nn.SpatialConvolution(1, 20, 3, 3)):add(nn.ReLU(true)):add(nn.SpatialConvolution(20, 40, 4, 4)):add(nn.ReLU(true)):add(nn.SpatialConvolution(40, 2, 3, 3)):add(nn.Sigmoid())
-- Override updateGradInput on the first layer of net2 so that it returns a zeroed-out tensor and no gradient flows back into net1
local tempLayer = net2:get(1)
function tempLayer:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(input):zero()
   return self.gradInput
end
output1 = net1(input)
output2 = net2(output1)
output3 = net3(output2)
gMod = nn.gModule({input}, {output1, output3})
-- Everything else is the same ...
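The effect of such an override can be checked in isolation. Below is a minimal sketch, assuming a small nn.Linear layer as a hypothetical stand-in for the first layer of net2 (the names layer, x and gradInput are illustrative only). It shows that the gradient passed upstream becomes zero while the layer still accumulates its own weight gradient:

```lua
require 'nn'

-- Hypothetical toy layer standing in for the first layer of net2
local layer = nn.Linear(4, 3)

-- Override on this instance only: block the gradient flowing to the input
function layer:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(input):zero()
   return self.gradInput
end

local x = torch.rand(4)
layer:forward(x)
layer:zeroGradParameters()
-- backward calls updateGradInput (overridden) and then accGradParameters (untouched)
local gradInput = layer:backward(x, torch.ones(3))

print(torch.abs(gradInput):sum())        -- gradient passed upstream: all zeros
print(torch.abs(layer.gradWeight):sum()) -- the layer's own weight gradient is still accumulated
```

Because only updateGradInput is replaced, the overridden layer keeps learning from gradient2; only the modules before it are cut off from that gradient.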
Question 2:
input = nn.Identity()()
net1 = nn.Sequential():add(nn.SpatialConvolution(1, 5, 3, 3)):add(nn.ReLU(true)):add(nn.SpatialConvolution(5, 20, 4, 4))
net2 = nn.Sequential():add(nn.SpatialFullConvolution(20, 5, 4, 4)):add(nn.ReLU(true)):add(nn.SpatialFullConvolution(5, 1, 3, 3)):add(nn.Sigmoid())
net3 = nn.Sequential():add(nn.SpatialConvolution(1, 20, 3, 3)):add(nn.ReLU(true)):add(nn.SpatialConvolution(20, 40, 4, 4)):add(nn.ReLU(true)):add(nn.SpatialConvolution(40, 2, 3, 3)):add(nn.Sigmoid())
-- Overriding updateParameters with a no-op keeps net3's weights fixed when gMod:updateParameters(LR) is called; the backward pass itself is unchanged, so gradients still flow through net3 to net2 and net1
net3.updateParameters = function() end
output1 = net1(input)
output2 = net2(output1)
output3 = net3(output2)
gMod = nn.gModule({input}, {output1, output3})
-- Everything else is the same ...
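To see the no-op override at work, here is a minimal sketch, assuming two small nn.Linear layers as hypothetical stand-ins for a trainable sub-network and for net3 (the names trainable, frozen and model are illustrative only). It relies on the container forwarding updateParameters to each of its modules:

```lua
require 'nn'

-- Hypothetical toy layers: one trainable, one standing in for the frozen net3
local trainable = nn.Linear(4, 3)
local frozen = nn.Linear(3, 2)
frozen.updateParameters = function() end -- no-op: skip the weight update for this module

local model = nn.Sequential():add(trainable):add(frozen)

-- Snapshot the weights before the update step
local wTrainBefore = trainable.weight:clone()
local wFrozenBefore = frozen.weight:clone()

local x = torch.rand(4)
model:forward(x)
model:zeroGradParameters()
model:backward(x, torch.ones(2)) -- gradients still pass through the frozen layer
model:updateParameters(0.1)      -- the container calls updateParameters on each child

print(torch.abs(frozen.weight - wFrozenBefore):sum())   -- 0: frozen weights are unchanged
print(torch.abs(trainable.weight - wTrainBefore):sum()) -- non-zero: trainable weights moved
```

Note that this trick only applies when the update goes through updateParameters. If the weights are flattened with getParameters() and updated by an optim routine, the override is never called, so the frozen module's weights would be updated like any others.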