
How to check if a user-defined function is already registered in Julia/JuMP


I want to check whether a user-defined function is already registered in JuMP/Julia. Here's an example:

function foo( f, f1, f2 )

  if !function_is_registered(:f)  # This is what I'm looking for
    JuMP.register(:f, 1, f, f1, f2)  # name, number of arguments, function, 1st and 2nd derivatives
  end
  ####
    # Optimization problem here using f
    # Leads to some return statement
  ####
end

# Function to register, with its first and second derivatives
f(x) = exp( A * x )
f1(x) = A * exp( A * x )
f2(x) = A * A * exp( A * x )

A = 2
use1 = foo(f, f1, f2)
use2 = foo(f, f1, f2)  # This second call would fail without the check: f can't be re-registered.

As the comments say, the check is needed for the second call. As far as I can tell, JuMP registers functions globally: once registered, they can't be redefined locally (right? If they can, that solves my problem too!).
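If all you need is to avoid registering the same name twice in one session, one possible workaround (not a JuMP API) is to keep your own record of what you have registered. In this minimal sketch, REGISTERED_OPS and register_once! are hypothetical names, and the do-block stands in for the real JuMP.register(:f, 1, f, f1, f2) call. It only skips the second registration; it does not let you replace f with a new definition under the same name.

```julia
# Hypothetical session-wide record of operator names we have registered ourselves.
const REGISTERED_OPS = Set{Symbol}()

# Run `register_fn` only the first time `name` is seen.
# Returns true if registration happened, false if it was skipped.
function register_once!(register_fn::Function, name::Symbol)
    name in REGISTERED_OPS && return false
    register_fn()                # e.g. JuMP.register(name, 1, f, f1, f2)
    push!(REGISTERED_OPS, name)  # only recorded if register_fn did not throw
    return true
end

# Usage inside foo (sketch):
# register_once!(:f) do
#     JuMP.register(:f, 1, f, f1, f2)
# end
```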


Solution

  • This will do what you want: if :f is already registered, remove its old entries and register the new definition.

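    using JuMP
    using Ipopt
    import ReverseDiffSparse  # holds the operator tables edited below (a JuMP dependency)

    function set_A_sol( A )
      # Local redefinition of f and its first and second derivatives
      f = (x) -> exp( A * x ) - x
      f1 = (x) -> A * exp( A * x ) - 1.0
      f2 = (x) -> A * A * exp( A * x )
      try
        JuMP.register(:f, 1, f, f1, f2)
      catch e
        # Only handle the "already registered" error; re-throw anything else
        if isa(e, ErrorException) && e.msg == "Operator f has already been defined"
          # Remove every trace of the old :f ...
          # (caveat: if other user operators were registered after :f,
          #  deleteat! may shift their ids)
          ind = pop!( ReverseDiffSparse.univariate_operator_to_id, :f)
          deleteat!( ReverseDiffSparse.univariate_operators, ind)
          pop!( ReverseDiffSparse.user_univariate_operator_f, ind)
          pop!( ReverseDiffSparse.user_univariate_operator_fprime, ind)
          pop!( ReverseDiffSparse.user_univariate_operator_fprimeprime, ind)
          # ... then register the new definition
          JuMP.register(:f, 1, f, f1, f2)
        else
          rethrow()
        end
      end
      mod = Model(solver=Ipopt.IpoptSolver(print_level=0))
      @variable(mod, x)  # free variable, same as -Inf <= x <= Inf
      @NLobjective(mod, Min, f(x) )
      status = solve(mod)
      return getvalue(x)
    end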
    

    julia> ans1 = set_A_sol(0.5)
    1.3862943611200509

    julia> ans2 = set_A_sol(1.0)
    0.0

    julia> ans3 = set_A_sol(2.0)
    -0.34657359027997264
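
These values can be checked by hand. For A > 0 the objective exp(A*x) - x is strictly convex (its second derivative, f2 in the code, is A^2 exp(A*x) > 0), so the minimiser is where the first derivative f1 vanishes: A exp(A*x) - 1 = 0, i.e. x* = -log(A)/A. (For A <= 0 the objective is unbounded below, so there is no minimiser.) A short sketch comparing this closed form with the printed Ipopt results; xstar is a hypothetical helper name:

```julia
# Closed-form minimiser of exp(A*x) - x, valid for A > 0:
# setting A*exp(A*x) - 1 = 0 gives x* = -log(A)/A.
xstar(A) = -log(A) / A

# (A, value returned by set_A_sol) pairs from the REPL session
results = ((0.5, 1.3862943611200509),
           (1.0, 0.0),
           (2.0, -0.34657359027997264))

for (A, x_ipopt) in results
    println("A = $A: closed form ", xstar(A), ", Ipopt ", x_ipopt)
end
```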
    

    Explanation:

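    If you look at the register function, defined in JuMP's nlp.jl, you'll see that "registering" just means recording the symbol, together with the function and its derivatives, in tables held in ReverseDiffSparse (the ones edited in the code above: univariate_operator_to_id, univariate_operators, and the user_univariate_operator_* tables). Register a function and inspect those tables manually to see what they look like.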

    So "de-registering" simply involves removing all traces of :f and its derivatives from all the places where it has been recorded.
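
To make the explanation concrete without depending on JuMP internals, here is a toy model of the registry. The five tables only mimic the names used in the answer; their real layout in ReverseDiffSparse is an assumption (the answer's pop!(..., :f) call does suggest univariate_operator_to_id is a Dict keyed by Symbol). The model also gives the check the question asked for: toy_is_registered is a plain haskey lookup, and with the real package the analogous test would presumably be haskey(ReverseDiffSparse.univariate_operator_to_id, :f). Unlike the answer's snippet, toy_deregister! also renumbers operators registered after the removed one, because deleting from the vector shifts their positions.

```julia
# Toy model of the operator tables the answer edits. Names mirror the answer;
# the layout (a Vector plus Dicts keyed by integer id) is an assumption.
const univariate_operators = Symbol[]
const univariate_operator_to_id = Dict{Symbol,Int}()
const user_univariate_operator_f = Dict{Int,Function}()
const user_univariate_operator_fprime = Dict{Int,Function}()
const user_univariate_operator_fprimeprime = Dict{Int,Function}()

toy_is_registered(s::Symbol) = haskey(univariate_operator_to_id, s)

# "Registering": record the symbol and its three functions under a new id.
function toy_register!(s::Symbol, f, fprime, fprimeprime)
    toy_is_registered(s) && error("Operator $s has already been defined")
    push!(univariate_operators, s)
    id = length(univariate_operators)
    univariate_operator_to_id[s] = id
    user_univariate_operator_f[id] = f
    user_univariate_operator_fprime[id] = fprime
    user_univariate_operator_fprimeprime[id] = fprimeprime
    return id
end

# "De-registering": remove every trace of `s`, then shift the ids of operators
# registered after it so ids stay equal to positions in univariate_operators.
function toy_deregister!(s::Symbol)
    id = pop!(univariate_operator_to_id, s)
    deleteat!(univariate_operators, id)
    for tbl in (user_univariate_operator_f,
                user_univariate_operator_fprime,
                user_univariate_operator_fprimeprime)
        pop!(tbl, id)
        for k in sort!(collect(keys(tbl)))  # ascending, so slot k-1 is already free
            k > id && (tbl[k - 1] = pop!(tbl, k))
        end
    end
    for (name, k) in collect(univariate_operator_to_id)
        k > id && (univariate_operator_to_id[name] = k - 1)
    end
    return nothing
end
```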