
Why does it provide two different outputs with if2/if3?


I am testing the m.if3 function in Gekko against plain if-else conditional statements, but I get two different outputs. The optimal number I get from the code below is 12. I plug that into the next code, which uses an if-else statement, to check that the cost matches, but it does not. Am I using if3/if2 incorrectly? The rate is 0.1 for the first 5 days and it switches to 0.3 for the remaining 45 days.

I get different outputs even though I am doing the same thing in both of them.

I tried everything from using if-else statements to using if2 statements.


Solution

  • The if2 or if3 function isn't needed because the switching argument duration-5 is a constant, not a function of a Gekko variable. As in the validation script, the two segments can be calculated separately and added together to get the total cost and patient count.

    from gekko import GEKKO
    m = GEKKO(remote=False)
    
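    # parameters (same values as the validation script)
    cost_s = 9
    cost_p = 12
    var1 = 50       # minimum required p_count
    duration = 50
    x = m.Var(integer=True, lb=1)  # integer decision variable
    
    rate1 = 0.3     # rate for the first 5 days
    rate2 = 0.1     # rate for the remaining duration-5 days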
    
    cost1 = m.Intermediate((rate1 * cost_p * 5 + cost_s) * x)
    cost2 = m.Intermediate((rate2 * cost_p * (duration-5) + cost_s) * x)
    cost = m.Intermediate(cost1+cost2)
    
    countp1 = m.Intermediate(rate1 * 5 * x)
    countp2 = m.Intermediate(rate2 * (duration-5) * x)
    p_count = m.Intermediate(countp1+countp2)
    
    m.Minimize(cost)
    m.Equation(p_count >= var1)
    m.options.SOLVER = 1 # for MINLP solution
    m.solve(disp=False)
    num_s = x.value[0]
    print(f'num_s = {num_s}')
    print(f'cost: {cost.value[0]}')
    print(f'p_count: {p_count.value[0]}')
    

    The optimal solution is:

    num_s = 9.0
    cost: 810.0
    p_count: 54.0
    

    The solution validation agrees with this answer:

    # Solution validation
    # Parameters
    cost_s = 9  
    cost_p = 12  
    num_p = 50  
    duration = 50 
    
    x = 9  # optimal value from the Gekko solution
    
    # cost of each segment: first 5 days at rate 0.3, remaining 45 days at rate 0.1
    cost1 = (0.3 * cost_p * 5 + cost_s) * x
    cost2 = (0.1 * cost_p * 45 + cost_s) * x
    cost = cost1 + cost2
    
    # patient count of each segment
    countp1 = 0.3 * 5 * x
    countp2 = 0.1 * 45 * x
    countp = countp1 + countp2
    
    print(f'cost (validation): {cost}')
    print(f'count (validation): {countp}')
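
As a further cross-check that does not depend on Gekko at all, the same two-segment arithmetic can be searched by brute force in plain Python. This is only a sketch under the values from the validation script above; the helper names `split_days` and `evaluate`, the `switch_day` parameter, and the search bound of 100 are illustrative assumptions, not part of the original code. Because the day split comes from constants only, `min`/`max` on plain numbers is enough and no if2/if3 is involved. It should land on the same x = 9 as the solver.

```python
# Values taken from the validation script above.
cost_s = 9
cost_p = 12
num_p = 50          # minimum required patient count
duration = 50
switch_day = 5      # assumed name: day on which the rate changes
rate_early, rate_late = 0.3, 0.1


def split_days(duration, switch_day):
    """Split a constant duration into days before and after the switch."""
    early = min(duration, switch_day)
    late = max(duration - switch_day, 0)
    return early, late


def evaluate(x):
    """Return (cost, patient_count) for a given integer x."""
    early, late = split_days(duration, switch_day)
    cost = ((rate_early * cost_p * early + cost_s)
            + (rate_late * cost_p * late + cost_s)) * x
    count = (rate_early * early + rate_late * late) * x
    return cost, count


# Cost grows with x, so the first feasible x is the cheapest one.
# The upper bound of 100 is an arbitrary assumption for this sketch.
best_x = next(x for x in range(1, 101) if evaluate(x)[1] >= num_p)
best_cost, best_count = evaluate(best_x)
print(f'x = {best_x}, cost = {best_cost:.1f}, count = {best_count:.1f}')
```

If the switch point itself depended on a Gekko variable, this constant split would no longer apply, and that is the case where if2 or if3 would be worth considering.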