What are solvers callback functions in pycaffe, and how can I use them?


Looking at this PR, I see that one can define on_start and on_gradient callbacks for a caffe.Solver object.

import caffe
solver = caffe.AdamSolver('solver.prototxt')
solver.add_callback(on_start, on_gradient)  # <- ??

What type of objects are on_start and on_gradient?
What are these callbacks for?
How can one use them (an example would be nice...)?


Solution

    1. Where and how are the callbacks defined?

    The callbacks are part of the Solver and are therefore defined in the solver.hpp file. To be exact, the Solver class contains a nested Callback class, an accessor for the registered callbacks, and an add_callback method:

      // Invoked at specific points during an iteration
      class Callback {
       protected:
        virtual void on_start() = 0;
        virtual void on_gradients_ready() = 0;
    
        template <typename T>
        friend class Solver;
      };
      const vector<Callback*>& callbacks() const { return callbacks_; }
      void add_callback(Callback* value) {
        callbacks_.push_back(value);
      }
    

    and a protected vector of such callbacks, which is a member of the Solver class.

      vector<Callback*> callbacks_;
    

    So the Solver class provides an add_callback function that appends a pointer to a Callback object to this vector. Because both methods of Callback are pure virtual, every concrete callback must implement on_start() and on_gradients_ready().
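
To make the role of the abstract class concrete, here is a rough Python analogue. It is only a sketch of the pattern, not Caffe code: the Python classes `Callback`, `Solver`, `Incomplete` and `Complete` are illustrative stand-ins, and `abc.abstractmethod` plays the part of the C++ pure virtual methods.

```python
import abc


class Callback(abc.ABC):
    """Illustrative stand-in for Solver::Callback: both hooks are mandatory."""

    @abc.abstractmethod
    def on_start(self):
        """Hook invoked at the start of an iteration."""

    @abc.abstractmethod
    def on_gradients_ready(self):
        """Hook invoked once the gradients have been computed."""


class Solver:
    """Toy solver that models nothing but the callback list."""

    def __init__(self):
        self._callbacks = []

    def add_callback(self, callback):
        self._callbacks.append(callback)

    def callbacks(self):
        return self._callbacks


class Incomplete(Callback):
    # on_gradients_ready is missing, so this class stays abstract.
    def on_start(self):
        pass


class Complete(Callback):
    def on_start(self):
        pass

    def on_gradients_ready(self):
        pass


solver = Solver()
solver.add_callback(Complete())

try:
    Incomplete()
    incomplete_rejected = False
except TypeError:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    incomplete_rejected = True

print(len(solver.callbacks()), incomplete_rejected)
```

In C++ the incomplete class would be rejected at compile time rather than at instantiation, but the guarantee is the same: anything stored in the callback list has both methods.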

    2. Where are the callbacks called?

    This happens in the solver.cpp file, in the step() function, which contains the main worker loop. Here's that main loop part (with lots of things stripped out for simplicity):

    while (iter_ < stop_iter) {
    
        for (int i = 0; i < callbacks_.size(); ++i) {
            callbacks_[i]->on_start();
        }
    
        // accumulate the loss and gradient
        Dtype loss = 0;
        for (int i = 0; i < param_.iter_size(); ++i) {
            loss += net_->ForwardBackward();
        }
        loss /= param_.iter_size();
    
        for (int i = 0; i < callbacks_.size(); ++i) {
          callbacks_[i]->on_gradients_ready();
        }
    
        ApplyUpdate();
    
        ++iter_;
    }
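
The order of calls in the simplified loop above can be mimicked in plain Python. This is a sketch under assumptions, not Caffe's implementation: the Python `step` function, the `(on_start, on_gradients_ready)` tuples and `fake_forward_backward` are hypothetical names, and the forward/backward pass is replaced by a function that returns a constant loss.

```python
def step(iters, iter_size, callbacks, forward_backward, apply_update):
    """Mimic the simplified loop and return the averaged loss per iteration."""
    losses = []
    for _ in range(iters):
        # 1. Every registered callback is notified that the iteration starts.
        for on_start, _on_gradients_ready in callbacks:
            on_start()

        # 2. Accumulate the loss over iter_size forward/backward passes.
        loss = 0.0
        for _ in range(iter_size):
            loss += forward_backward()
        loss /= iter_size

        # 3. Gradients are ready but have not been applied yet.
        for _on_start, on_gradients_ready in callbacks:
            on_gradients_ready()

        # 4. Only now are the weights updated.
        apply_update()
        losses.append(loss)
    return losses


events = []


def fake_forward_backward():
    events.append("forward_backward")
    return 2.0  # constant stand-in loss


losses = step(
    iters=1,
    iter_size=2,
    callbacks=[(lambda: events.append("on_start"),
                lambda: events.append("on_gradients_ready"))],
    forward_backward=fake_forward_backward,
    apply_update=lambda: events.append("apply_update"),
)

print(events)
print(losses)
```

The recorded events show that on_gradients_ready fires after all iter_size passes and before the weight update, which is the only point at which a callback can inspect or modify the gradients before they are applied.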
    

    3. Where is this used?

    This callback feature was implemented when multi-GPU support was added. The only place I know of where callbacks are used is the synchronization of the solver across multiple GPUs:

    The P2PSync class in parallel.hpp inherits from the Solver::Callback class and implements the on_start() and on_gradients_ready() methods, which synchronize the GPUs and finally accumulate all the gradient updates.

    4. How to use callbacks from Python?

    As the pull request #3020 explains,

    on_start and on_gradient are python functions.

    so they should be straightforward to use. A full, runnable example is shown in this GitHub Gist I created.
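
As a minimal sketch of what "python functions" means here: the two callbacks are ordinary callables that are invoked without arguments, so any state they need has to come from a closure or an object. The helper `make_counting_callbacks` below is a hypothetical name, and the pycaffe calls are left as comments so that the snippet runs without Caffe installed.

```python
def make_counting_callbacks():
    """Build two zero-argument callbacks that share a counter dictionary."""
    counts = {"on_start": 0, "on_gradient": 0}

    def on_start():
        counts["on_start"] += 1

    def on_gradient():
        counts["on_gradient"] += 1

    return on_start, on_gradient, counts


on_start, on_gradient, counts = make_counting_callbacks()

# With pycaffe available, registration would look like this:
#   solver = caffe.SGDSolver('solver.prototxt')
#   solver.add_callback(on_start, on_gradient)
#   solver.step(2)

# Without Caffe, simulate what the solver does in each of two iterations.
for _ in range(2):
    on_start()
    on_gradient()

print(counts)
```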

    5. How is this useful?

    As the two callback functions do not take any arguments, you can't simply use them to keep track of the loss or similar things. To do that, you have to create a wrapper class around the solver and pass two of its bound methods to add_callback. This allows you to access the net from within the callbacks through self.solver.net. In the following example, I use the on_start callback to load data into the net, and the on_gradients_ready callback to print the loss.

    import caffe
    import numpy as np
    
    class SolverWithCallback:
        def __init__(self, solver_file):
            self.solver = caffe.SGDSolver(solver_file)
            self.solver.add_callback(self.load, self.loss)
    
        def solve(self):
            self.solver.solve()
    
        def load(self):
            inp = np.random.randint(0, 255)
            self.solver.net.blobs['data'].data[...] = inp
            self.solver.net.blobs['labels'].data[...] = 2 * inp
    
        def loss(self):
            print("loss: " + str(self.solver.net.blobs['loss'].data))
    
    if __name__=='__main__':
        solver = SolverWithCallback('solver.prototxt')
        solver.solve()