
Compiling njit nopython version of function fails due to data types


I'm writing a function with Numba's njit to speed up a very slow reservoir-operations optimization code. The function returns the maximum spill release based on the reservoir level and gate availability. I pass in a parameter size that specifies the number of flows to calculate (in some calls it's one and in others it's many). I also pass in a numpy.zeros array that the function fills with its output. A simplified version of the function is as follows:

import numpy as np
from numba import njit

@njit(cache=True)
def fncMaxFlow(elev, flag, size, MaxQ):
    if (flag == 1): # SPOG2 running
        if size==0:
            if (elev>367.28):
                return 861.1 
            else: return 0
        else:
            for i in range(size):
                if((elev[i]>367.28) & (elev[i]<385)):
                    MaxQ[i]=861.1
            return MaxQ
    else:
        if size==0: return 0
        else: return MaxQ

fncMaxFlow(np.random.randint(368, 380, 3), 1, 3, np.zeros(3))

The error I'm getting:

Can't unify return type from the following types: array(float64, 1d, C), float64, int32

What is the reason for this? Is there a workaround, or some step I'm missing, so that I can use Numba to speed things up? This function and others like it are called millions of times, so they are a major factor in the computational efficiency. Any advice would help; I'm pretty new to Python.


Solution

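  • A variable within a Numba function must have a single, consistent type, and that includes the return value. In your code the function can return MaxQ (an array), 861.1 (a float) or 0 (an int). Numba can unify the float and the int into float64, but it cannot unify either of them with an array, which is what the error message reports.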

    You need to refactor this code so that it always returns a consistent type regardless of code path.

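    Also be careful wherever a NumPy array is compared to a scalar. If elev is an array, elev > 367.28 gives back an array of booleans rather than a single True/False, and using that array as an if condition is ambiguous (plain NumPy raises a ValueError when the array has more than one element). In your function, the size == 0 branch relies on elev being a scalar; if an array reaches that branch, the comparison produces a boolean array instead. Dropping the numba decorator and running the function as plain Python is a quick way to catch this kind of problem before compiling.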
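A minimal sketch of one possible refactor, assuming the single-value case can be handled by its own scalar function. The names max_flow_scalar and max_flow_array are hypothetical, not from the question. The scalar version always returns a float (note 0.0 rather than 0) and the array version always returns MaxQ, so each function has exactly one return type. Both apply the question's array-branch test (367.28 < elev < 385); the original scalar branch only checked the lower bound, so adjust if that difference matters. The array version writes 0.0 for out-of-range entries, which matches the original only because MaxQ starts as np.zeros. The ImportError fallback is also an assumption, added so the logic can be checked as plain Python.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumed fallback: a no-op decorator so the sketch also runs as plain Python
    def njit(func):
        return func


@njit
def max_flow_scalar(elev, flag):
    # Hypothetical scalar kernel: every path returns a float64.
    if flag == 1 and elev > 367.28 and elev < 385.0:  # SPOG2 running
        return 861.1
    return 0.0  # 0.0, not 0, so there is no int/float mixing


@njit
def max_flow_array(elev, flag, MaxQ):
    # Hypothetical array version: fills MaxQ element by element and always
    # returns it, so the return type is always the array type of MaxQ.
    for i in range(elev.shape[0]):
        MaxQ[i] = max_flow_scalar(elev[i], flag)
    return MaxQ
```

Calls that need one flow can use max_flow_scalar(elev, flag) directly, and calls that need many can pass a level array plus a preallocated np.zeros array to max_flow_array, for example max_flow_array(np.random.randint(368, 380, 3), 1, np.zeros(3)). Numba compiles a separate specialisation for each combination of argument types (for instance integer versus float level arrays), so mixing them is allowed as long as each individual function keeps one return type.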
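To illustrate the comparison point in plain NumPy (no Numba involved), the snippet below uses made-up reservoir levels. The np.where line is a swapped-in alternative, not something from the question: it computes the same 861.1-or-0 flows with whole-array operations, which may be worth comparing against the loop when size is large.

```python
import numpy as np

elev = np.array([366.0, 370.0, 390.0])  # illustrative levels only

# Comparing an array to a scalar is elementwise and yields a boolean array
in_range = (elev > 367.28) & (elev < 385)
print(in_range)

# Using such an array directly as an if condition is ambiguous for more than one element
try:
    if elev > 367.28:
        pass
except ValueError as err:
    print("ValueError:", err)

# Vectorised alternative: 861.1 where the mask is True, 0.0 elsewhere
MaxQ = np.where(in_range, 861.1, 0.0)
print(MaxQ)
```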