python tensorflow object-detection object-detection-api transfer-learning

How to reset classes while retaining class specific weights in TensorFlow Object Detection API


I am currently using the TensorFlow Object Detection API and am attempting to fine-tune a pre-trained Faster R-CNN from the model zoo. If I choose a number of classes different from the number used by the original network, the weights and biases of SecondStageBoxPredictor/ClassPredictor are simply not initialised from the checkpoint, because this layer now has different dimensions from the original ClassPredictor. However, all of the classes I would like to train the network on are classes the original network has already been trained to identify. I would therefore like to retain the weights and biases in SecondStageBoxPredictor/ClassPredictor that are associated with the classes I want to use and prune all the others, rather than initialising these values from scratch (similar to the behaviour of this function).
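
The mismatch comes from the output width of the class predictor. A minimal sketch of the arithmetic, assuming the layer is a fully connected layer with one output slot per class plus one background slot; the helper name and the 2048-dimensional input feature size are hypothetical, and 90 is the class count used by the COCO-trained zoo models:

```python
def class_predictor_shapes(feature_dim, num_classes):
    """Shapes of a fully connected class predictor's variables (hypothetical helper).

    Assumption: one output slot per class plus one background slot.
    """
    num_slots = num_classes + 1
    return {"weights": (feature_dim, num_slots), "biases": (num_slots,)}


FEATURE_DIM = 2048  # hypothetical second-stage feature size

pretrained = class_predictor_shapes(FEATURE_DIM, num_classes=90)  # COCO zoo model
finetuned = class_predictor_shapes(FEATURE_DIM, num_classes=3)    # e.g. 3 target classes

print(pretrained["weights"], finetuned["weights"])  # (2048, 91) (2048, 4)
print(pretrained == finetuned)  # False
```

Because both the weights and the biases change shape with the class count, the checkpoint's ClassPredictor variables no longer fit the new layer.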

Is this possible, and if so, how would I go about modifying the structure of this layer in the Estimator?

N.B. This question asks something similar, and the answer there is to ignore irrelevant classes in the network output. In my situation, however, I am attempting to fine-tune the network; wouldn't the presence of these redundant classes complicate the training/evaluation process?
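
Ignoring irrelevant classes at the output, as in the answer to the similar question, amounts to a post-processing filter. A minimal sketch, assuming per-image NumPy arrays shaped like the `detection_classes`, `detection_scores` and `detection_boxes` outputs of an exported model; the helper name, the class ids and the score threshold are hypothetical:

```python
import numpy as np


def keep_wanted_classes(classes, scores, boxes, wanted_ids, min_score=0.5):
    """Keep detections whose class id is in wanted_ids and whose score >= min_score.

    classes: (N,) class ids (possibly stored as floats), scores: (N,), boxes: (N, 4).
    """
    classes = np.asarray(classes).astype(np.int64)
    scores = np.asarray(scores)
    boxes = np.asarray(boxes)
    mask = np.isin(classes, sorted(wanted_ids)) & (scores >= min_score)
    return classes[mask], scores[mask], boxes[mask]


# Hypothetical single-image output; ids 1 and 3 are person and car in the COCO label map.
classes = np.array([1.0, 18.0, 3.0, 1.0])
scores = np.array([0.9, 0.95, 0.4, 0.7])
boxes = np.zeros((4, 4))

kept_classes, kept_scores, kept_boxes = keep_wanted_classes(
    classes, scores, boxes, wanted_ids={1, 3})
print(kept_classes, kept_scores)  # [1 1] [0.9 0.7]
```

The class-18 detection is dropped because it is not wanted, and the class-3 detection because its score is below the threshold. The network itself is unchanged, which is why this only helps at inference time and not during fine-tuning.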


Solution

  • If all the classes you would like to train the network on are ones the network has already been trained to identify, couldn't you simply use the network as-is for detection?

    However, if you have extra classes and would like to do transfer learning, you can restore as many variables from the checkpoint as possible by setting:

    fine_tune_checkpoint_type: 'detection'
    load_all_detection_checkpoint_vars: True
    

    in the train_config field of the pipeline config file.

    Finally, the computation graph shows that the shape of SecondStageBoxPredictor/ClassPredictor/weights depends on the number of output classes.

    Note that in TensorFlow, checkpoint restoration works at the level of whole variables: if two variables have different shapes, one cannot be used to initialize the other. So in your case, preserving only some of the values of the weights variable is not feasible.
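
The whole-variable rule can be illustrated with a simplified, pure-Python sketch in which a model variable is restored only when the checkpoint holds a variable with the same name and exactly the same shape. This is not the Object Detection API's own code; the helper `split_restorable` and all variable names and shapes are hypothetical (3 target classes vs. a 90-class COCO checkpoint, assuming one extra background slot):

```python
def split_restorable(model_shapes, checkpoint_shapes):
    """Split model variable names into (restored, skipped) lists (hypothetical helper).

    A variable is restored only when the checkpoint has a variable with the same
    name AND exactly the same shape. There is no partial copy: on a shape mismatch
    the whole variable is left to its initializer.
    """
    restored, skipped = [], []
    for name in sorted(model_shapes):
        if checkpoint_shapes.get(name) == model_shapes[name]:
            restored.append(name)
        else:
            skipped.append(name)
    return restored, skipped


# Hypothetical name -> shape maps for the new model and the pre-trained checkpoint.
model = {
    "FeatureExtractor/conv1/weights": (7, 7, 3, 64),
    "SecondStageBoxPredictor/ClassPredictor/weights": (2048, 4),
    "SecondStageBoxPredictor/ClassPredictor/biases": (4,),
}
checkpoint = {
    "FeatureExtractor/conv1/weights": (7, 7, 3, 64),
    "SecondStageBoxPredictor/ClassPredictor/weights": (2048, 91),
    "SecondStageBoxPredictor/ClassPredictor/biases": (91,),
}

restored, skipped = split_restorable(model, checkpoint)
print("restored:", restored)
print("skipped:", skipped)
```

The shared backbone variable is restored, while both ClassPredictor variables are skipped entirely, even though every column of the new weights has a counterpart among the checkpoint's 91 columns.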