Tags: machine-learning, torch

Parameter sharing in network with nn.SpatialBatchNormalization


I have a network with three parallel branches, and I want to share all their parameters so that they are identical at the end of training. Let some_model be a standard nn.Sequential module made of cudnn.SpatialConvolution, nn.PReLU, and nn.SpatialBatchNormalization layers. There is also an nn.SpatialDropout, but its probability is set to 0, so it has no effect.
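
For illustration, some_model is built roughly along these lines (the layer sizes here are placeholders, not the actual ones):

require 'nn'
require 'cudnn'

some_model = nn.Sequential()
some_model:add(cudnn.SpatialConvolution(3, 64, 3, 3, 1, 1, 1, 1))
some_model:add(nn.PReLU())
some_model:add(nn.SpatialBatchNormalization(64))
some_model:add(nn.SpatialDropout(0))  -- p = 0, so no effect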

ptb=nn.ParallelTable()
-- the first branch keeps the master copy of the parameters
ptb:add(some_model)
-- the clones share storage for the learnable parameters and their gradients
ptb:add(some_model:clone('weight','bias','gradWeight','gradBias'))
ptb:add(some_model:clone('weight','bias','gradWeight','gradBias'))

triplet=nn.Sequential()
triplet:add(ptb)

I don't think the loss function is relevant, but just in case, I use nn.DistanceRatioCriterion. To check that all weights are correctly shared, I pass a table of three identical examples {A,A,A} to the network. Obviously, if the weights are correctly shared, the output of all three branches should be the same. This holds at network initialization, but once the parameters have been updated (say, after one mini-batch iteration), the results of the three branches become different. Through layer-by-layer inspection, I have noticed that this discrepancy comes from the nn.SpatialBatchNormalization layers in some_model. It therefore seems that the parameters of those layers are not properly shared.

Following this, I have tried calling clone with the additional parameters running_mean and running_std, but the output of the batchnorm layers still differs. Moreover, this seems to cancel the sharing of all the other network parameters as well. What is the proper way of sharing parameters between nn.SpatialBatchNormalization modules?
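
To make the check concrete, the comparison I run looks roughly like this (the input shape is just an example; the model is moved to the GPU because of cudnn):

triplet = triplet:cuda()
A = torch.randn(1, 3, 32, 32):cuda()  -- example input
out = triplet:forward({A, A, A})
print((out[1] - out[2]):abs():max())  -- 0 only if the branches are identical
print((out[1] - out[3]):abs():max())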


Solution

  • Ok, I found the solution! It seems that the parameter running_std has been renamed to running_var since the discussion I had linked to in the question. These running statistics are buffers that batch normalization updates during the training-mode forward pass, so sharing weight, bias, and their gradients does not cover them. Calling clone with

    ptb:add(some_model:clone('weight','bias','gradWeight','gradBias','running_mean','running_var'))

    solves the problem.
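
    For completeness, the full construction with the running statistics shared as well might look like this (same clone call for both branches):

    ptb=nn.ParallelTable()
    ptb:add(some_model)  -- master copy
    -- the clones now share parameters, gradients, and the BN running statistics
    ptb:add(some_model:clone('weight','bias','gradWeight','gradBias',
                             'running_mean','running_var'))
    ptb:add(some_model:clone('weight','bias','gradWeight','gradBias',
                             'running_mean','running_var'))

    With this, forwarding {A,A,A} keeps the three branch outputs identical even after parameter updates.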