Tags: python, machine-learning, mxnet

What is the recommended operation to protect some weights from being changed by the trainer in MxNet?


As far as I know, if I want to protect some weights in TensorFlow, I should prevent them from being passed to the optimizer. So I tried the same approach in MXNet with the following code.

all_params = net.collect_params()
while True:
    firstKey = next(iter(all_params._params))
    if 'resnet' not in firstKey:
        break
    all_params._params.popitem(last=False)
trainer = mx.gluon.Trainer(all_params, 'sgd')

The variable all_params._params is an OrderedDict, so the order of its entries seems to matter and I should not change it. As shown above, that means I can only remove parameters from the beginning of the network, which is very inconvenient. Also, _params starts with an underscore, which suggests it is private and should not be touched by the general user.
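
(For reference, this is the OrderedDict behaviour I mean; a plain-Python illustration:)

from collections import OrderedDict

d = OrderedDict([('a', 1), ('b', 2), ('c', 3)])
d.popitem(last=False)   # removes ('a', 1), i.e. only the oldest entry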

I do not get any errors, but I suspect this is not the recommended way to do it.
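
I also considered letting collect_params do the filtering through its regex select argument instead of touching the private dict, but I am not sure this is the recommended way either (the '^(?!resnet)' pattern is just my guess at matching everything not prefixed with 'resnet'):

# Sketch: ask collect_params for only the parameters the trainer should update.
trainable_params = net.collect_params('^(?!resnet)')
trainer = mx.gluon.Trainer(trainable_params, 'sgd')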


Solution

  • As far as I understand, you want to freeze some layers (so that their parameters remain unchanged during training) and you are using Gluon.

    In that case you can set the parameter's grad_req attribute to 'null' (it is a string) to prevent that parameter from being updated. Here is an example: I define a set of parameter names I want to freeze and freeze them after creating the model, but before initialization.

    import mxnet as mx
    from mxnet import gluon

    num_hidden = 10
    num_outputs = 10          # assumed output size, not given in the original snippet
    ctx = mx.cpu()            # assumed context

    net = gluon.nn.Sequential()
    with net.name_scope():
        net.add(gluon.nn.Dense(num_hidden, activation="relu"))
        net.add(gluon.nn.Dense(num_hidden, activation="relu"))
        net.add(gluon.nn.Dense(num_outputs))

    # Names of the parameters of the first two Dense layers. The 'sequential1'
    # prefix depends on how many Sequential blocks were created before this one,
    # so check the actual names by printing them (see below).
    layers_to_freeze = set(['sequential1_dense0_weight', 'sequential1_dense0_bias',
                            'sequential1_dense1_weight', 'sequential1_dense1_bias'])

    # grad_req='null' means no gradient is computed or applied for these parameters.
    for p in net.collect_params().items():
        if p[0] in layers_to_freeze:
            p[1].grad_req = 'null'

    net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
    

    If you run training, these parameters shouldn't change. You can find the names of the parameters by printing p[0] inside the loop.
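
    As a quick sanity check (a minimal sketch that reuses the net, ctx and num_outputs defined above, uses random data only for illustration, and assumes the names in layers_to_freeze actually matched), you can run one update step and confirm the frozen weights did not move:

    from mxnet import autograd

    x = mx.nd.random.uniform(shape=(4, 20), ctx=ctx)            # dummy inputs
    y = mx.nd.random.uniform(shape=(4, num_outputs), ctx=ctx)   # dummy targets

    net(x)  # one forward pass to finish the deferred shape initialization
    frozen_before = net[0].weight.data().copy()                 # a frozen layer's weights

    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    with autograd.record():
        loss = gluon.loss.L2Loss()(net(x), y)
    loss.backward()
    trainer.step(batch_size=4)

    # Prints 0.0 if the frozen weights were left untouched by the update.
    print(mx.nd.abs(net[0].weight.data() - frozen_before).max().asscalar())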