
Add a progress bar for SNOPT solver in pydrake


I am using the AddVisualizationCallback method of pydrake MathematicalProgram objects to build a progress bar that updates while the solver is running. That way I can see how the solver is doing without reading SNOPT's very verbose output (which, by the way, I have only managed to write to a file, not to the console). My approach looks like this:

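import pydrake.solvers.mathematicalprogram as mp
from tqdm import tqdm

prog = mp.MathematicalProgram()
x = prog.NewContinuousVariables(2, "x")  # decision variables passed to the callback

max_iterations = 20
pbar = tqdm(total=max_iterations)

def update(x_values):
    # Called by the solver with the current values of x; advances the bar one step.
    pbar.update(1)

prog.AddVisualizationCallback(update, x)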

However, SNOPT has both major and minor iterations, and my update callback seems to be called on every iteration, major or minor. Ideally it would be called only on major iterations; otherwise my progress bar fills up sooner than it should.

My question: is there a way to tell whether a given iteration is a major or a minor SNOPT iteration, so that I can pass that information to my update callback and advance the progress bar only on major iterations? Alternatively, can I configure the visualization callback to be called ONLY on major iterations?
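Whatever the source of the major-iteration count turns out to be, the filtering asked about can be done on the Python side instead of inside the solver. Below is a minimal sketch, assuming some function (the hypothetical get_major) can report SNOPT's current major-iteration count, or None when it is not yet known. It does not rely on any Drake option for major-only callbacks; major_only and on_major are illustrative names.

```python
from typing import Callable, Optional, Sequence


def major_only(get_major: Callable[[], Optional[int]],
               on_major: Callable[[Sequence[float], int], None]):
    """Wrap on_major so it runs only when the major-iteration count grows.

    get_major is a hypothetical hook returning the current major-iteration
    count (or None if unknown). The returned function takes the single
    argument a visualization callback receives: the current decision values.
    """
    last_seen: Optional[int] = None  # last count passed to on_major

    def callback(x_values):
        nonlocal last_seen
        major = get_major()
        if major is None:
            return  # nothing known yet, so do nothing
        if last_seen is None or major > last_seen:
            last_seen = major
            on_major(x_values, major)
        # Otherwise the call belongs to a major iteration already reported
        # (e.g. a minor iteration), so it is ignored.

    return callback


# Usage demo with a fake counter standing in for the solver's state.
fake_counts = iter([None, 0, 0, 1, 1, 1, 2])
seen = []
cb = major_only(lambda: next(fake_counts), lambda x, k: seen.append(k))
for _ in range(7):
    cb([0.0, 0.0])
print(seen)  # [0, 1, 2]
```

With Drake, the wrapper would be registered like any other visualization callback, e.g. prog.AddVisualizationCallback(major_only(get_major, on_major), x), where on_major might set pbar.n = k and call pbar.refresh(). Setting the bar's absolute position rather than calling update(1) also means duplicate calls cannot overfill it.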


Solution

  • I've managed to hack together a solution using the log file that SNOPT writes:

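    from tqdm import tqdm
    import pydrake.solvers.mathematicalprogram as mp
    from pydrake.solvers.snopt import SnoptSolver
    
    prog = mp.MathematicalProgram()
    x = prog.NewContinuousVariables(2, "x")
    # ... add costs and constraints on x here ...
    
    max_iterations = 20
    log_filename = '/tmp/snopt.out'
    # Create (or truncate) the print file and keep a handle for reading it back.
    log = open(log_filename, 'w+')
    pbar = tqdm(total=max_iterations)
    
    def update(x_values):
        # read() returns only the text SNOPT has appended since the previous call.
        lines = log.read()
        if 'Itns Major Minors' in lines:
            # Look at the first data row after the most recent header in that text.
            idx = lines.rfind('Itns Major Minors')
            rows = lines[idx:].split('\n')
            if len(rows) < 2:
                return  # header was read before its data row was written
            header, info = rows[:2]
            for name, val in zip(header.split(), info.split()):
                if name == 'Major' and val.isdigit():
                    pbar.n = int(val)
                    pbar.refresh()
                    break
    
    prog.AddVisualizationCallback(update, x)
    
    solver_options = mp.SolverOptions()
    # Send SNOPT's output to the file instead of the console.
    solver_options.SetOption(mp.CommonSolverOption.kPrintFileName, log_filename)
    solver_options.SetOption(SnoptSolver.id(), "Major iterations limit", max_iterations)
    # Ask SNOPT to write a log line every iteration.
    solver_options.SetOption(SnoptSolver.id(), "Print frequency", 1)
    
    # Call SNOPT directly so the SNOPT-specific options and log format apply.
    result = SnoptSolver().Solve(prog, None, solver_options)
    pbar.n = max_iterations
    pbar.refresh()
    pbar.close()
    log.close()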
    

    This gives me a clean progress bar while SNOPT is running, without the clutter it writes to its log file:

    40%|██████████████▍                     | 8/20 [00:07<00:11,  1.07it/s]
    

    It would be nicer to have a programmatic hook into the solver's iteration state from within the callback; maybe that's possible? Short of that, this seems to do the job.
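The update callback in this answer relies on log.read() returning only the text appended since the previous call. As a result, the bar can only move when the newly read text contains a header line, and it then reports the first row after that header rather than the latest one. Below is a minimal sketch of a more incremental alternative. It assumes the print-file layout the answer already relies on: a header line containing "Itns Major Minors", followed by rows whose first two integers are Itns and Major. SnoptLogTracker is a hypothetical helper, not part of Drake.

```python
import os
import re
import tempfile
from typing import Optional

# Assumed SNOPT print-file layout: a header containing "Itns Major Minors",
# followed by rows whose first two integers are Itns and Major.
_HEADER = "Itns Major Minors"
_ROW = re.compile(r"^\s*(\d+)\s+(\d+)\b")


class SnoptLogTracker:
    """Tail a SNOPT print file and remember the latest major-iteration number."""

    def __init__(self, path: str):
        self.path = path
        self.offset = 0        # number of bytes already consumed
        self.partial = b""     # trailing bytes of a line not yet terminated
        self.in_table = False  # True while inside a major-iteration table
        self.major: Optional[int] = None

    def poll(self) -> Optional[int]:
        """Consume newly appended lines; return the latest major iteration seen."""
        if not os.path.exists(self.path):
            return self.major
        with open(self.path, "rb") as f:
            f.seek(self.offset)
            chunk = f.read()
            self.offset = f.tell()
        lines = (self.partial + chunk).split(b"\n")
        self.partial = lines.pop()  # incomplete last line (b"" if none)
        for raw in lines:
            line = raw.decode("ascii", errors="replace")
            if _HEADER in line:
                self.in_table = True
                continue
            if not line.strip():
                continue  # blank lines do not end the table
            m = _ROW.match(line)
            if self.in_table and m:
                self.major = int(m.group(2))
            else:
                self.in_table = False  # some other section of the log
        return self.major


# Usage demo with a temporary file standing in for SNOPT's print file.
observed = []
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "snopt.out")
    tracker = SnoptLogTracker(path)
    with open(path, "w") as log:
        log.write(" Illustrative banner line\n")
        log.write(" Itns Major Minors    Step   nCon Feasible  Optimal\n")
        log.write("     2     0      2              1  1.0E+00  1.0E+00\n")
        log.write("     5     1      3  1.0E+00     2  5.0E-01  2.0E-01\n")
        log.flush()
        observed.append(tracker.poll())  # 1
        log.write("     9     2      4  1.0E+00     3  1.0E-01  1.0E-01\n")
        log.write("    1")               # a row not finished being written
        log.flush()
        observed.append(tracker.poll())  # 2 (the partial row is held back)
        log.write("2     3      3  1.0E+00     4  1.0E-02  1.0E-02\n")
        log.flush()
        observed.append(tracker.poll())  # 3
print(observed)  # [1, 2, 3]
```

In the answer's script, one could create tracker = SnoptLogTracker(log_filename) before solving and have update call tracker.poll(), setting pbar.n to the result and calling pbar.refresh() whenever it is not None. This assumes SNOPT recreates the print file at the start of each solve, so the tracker starts reading from a fresh file.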