Tags: optimization, mathematical-optimization, numerical-methods, differential-equations, gekko

Getting adjoint state of solution in Gekko


After solving an optimal control problem in Gekko (IMODE = 6), is there any way to access or reconstruct the adjoint state p? The documentation does not cover this, so I am hoping there is a way to retrieve some information that could lead to a reconstruction of the adjoint state.

Bonus question: is there any optimal control solver (with a Python API) that returns the adjoint state?


Solution

  • Gekko has the m.options.SENSITIVITY = 1 option, which writes sensitivity.txt to the run directory m.path. However, this only works when the problem has zero degrees of freedom (DOF): either a simulation, or an optimization problem that has zero DOF once the manipulated variable (MV) STATUS is turned off.

    The other alternative is to add the adjoint equations to the problem. Here is an example dynamic optimization problem.

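    import numpy as np
    from gekko import GEKKO
    m = GEKKO()
    nt = 101
    m.time = np.linspace(0,1,nt)      # time grid on [0, 1]
    x = m.Var(value=1)                # state, x(0) = 1
    u = m.Var(value=0,lb=-1,ub=1)     # control, bounded in [-1, 1]
    m.Equation(x.dt()==-x + u)        # dynamics
    m.Minimize(x**2)                  # objective on x over the horizon
    m.options.IMODE = 6               # dynamic optimization (optimal control) mode
    m.solve()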
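
    In continuous time, this example corresponds roughly to the problem below. This is a sketch: as far as I know, Gekko's m.Minimize(x**2) in IMODE = 6 is summed over the points of m.time, so the objective actually solved is a discrete sum that approximates this integral only up to a factor of the time step.

```latex
\min_{u(t)} \; J = \int_0^{1} x(t)^2 \, dt
\quad \text{subject to} \quad
\dot{x} = -x + u, \qquad x(0) = 1, \qquad -1 \le u(t) \le 1
```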
    

    Adding the adjoint (costate) equation gives the value of lam, which measures the sensitivity of the objective function to changes in x; with the Hamiltonian sign convention used in the code below, lam is the negative of that sensitivity. Because the final state x(T) is free, the transversality condition for this problem is lam(T) = 0, while the initial value lam(0) is left free and computed by the solver. This is achieved by declaring lam with fixed_initial=False and calling m.fix_final(lam,0) to fix its final value at zero.
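
    One way to see where lam.dt() == 2*x + lam comes from, assuming the maximum-principle sign convention for the Hamiltonian (H = -L + λf, with running cost L = x² and dynamics f = -x + u). Under the opposite (minimum-principle) convention H = x² + λ(-x + u), the costate equation becomes dλ/dt = -2x + λ and λ changes sign.

```latex
H(x, u, \lambda) = -x^2 + \lambda\,(-x + u)

\dot{\lambda} = -\frac{\partial H}{\partial x} = 2x + \lambda,
\qquad \lambda(T) = 0 \quad (x(T)\ \text{free})
```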

    [Figure: costate results]

    import numpy as np
    from gekko import GEKKO
    import matplotlib.pyplot as plt
    
    m = GEKKO()
    nt = 101
    m.time = np.linspace(0, 1, nt)
    x = m.Var(value=1)
    u = m.Var(value=0, lb=-1, ub=1)
    lam = m.Var(value=0,fixed_initial=False) # Adjoint variable; initial value is free
    m.fix_final(lam,0)                       # Transversality condition lam(T) = 0
    m.Equation(x.dt() == -x + u)
    m.Equation(lam.dt() == 2*x + lam)  # Adjoint: dlam/dt = -dH/dx with H = -x**2 + lam*(-x + u)
    m.Minimize(x**2)
    m.options.IMODE = 6
    m.solve(disp=False)
    
    # Plotting
    plt.figure(figsize=(7,4))
    plt.plot(m.time, x.value, 'b-', lw=2, label=r'$x$')
    plt.plot(m.time, u.value, 'r--', lw=2, label=r'$u$')
    plt.plot(m.time, lam.value, 'g-.', lw=2, label=r'$\lambda$')
    plt.xlabel('Time'); plt.legend()
    plt.grid(); plt.tight_layout()
    plt.show()
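
    As a rough sanity check of the sign convention, the sketch below uses plain NumPy (no Gekko) with a hypothetical fixed control U = -0.5 instead of the optimal one. It integrates the same adjoint equation backward from lam(1) = 0 with RK4 and compares lam(0) with a central finite-difference estimate of dJ/dx(0) for J = ∫ x² dt. The names x_traj, cost, and adjoint_at_zero are illustrative helpers, not part of Gekko. With this convention, lam(0) should come out close to -dJ/dx(0).

```python
import numpy as np

# Hypothetical setup: fixed (not optimized) control, same dynamics and cost
T, N = 1.0, 1000
U = -0.5

def x_traj(t, x0):
    # Closed-form solution of dx/dt = -x + U with x(0) = x0
    return U + (x0 - U) * np.exp(-t)

def cost(x0):
    # J = integral of x^2 over [0, T], approximated with the trapezoidal rule
    t = np.linspace(0.0, T, N + 1)
    y = x_traj(t, x0) ** 2
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(t)))

def adjoint_at_zero(x0):
    # Integrate dlam/dt = 2x + lam backward in time from lam(T) = 0 using RK4
    f = lambda t, lam: 2.0 * x_traj(t, x0) + lam
    h = -T / N  # negative step: march from t = T down to t = 0
    t, lam = T, 0.0
    for _ in range(N):
        k1 = f(t, lam)
        k2 = f(t + h / 2, lam + h / 2 * k1)
        k3 = f(t + h / 2, lam + h / 2 * k2)
        k4 = f(t + h, lam + h * k3)
        lam += h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return float(lam)

x0, eps = 1.0, 1e-4
dJ_dx0 = (cost(x0 + eps) - cost(x0 - eps)) / (2 * eps)  # central difference
lam0 = adjoint_at_zero(x0)
print(f"lam(0) = {lam0:.6f}, dJ/dx0 = {dJ_dx0:.6f}")
```

    Flipping the sign of the 2*x term (dlam/dt = -2x + lam, the minimum-principle costate) gives lam(0) close to +dJ/dx(0) instead.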