Tags: python, numpy, indexing, type-inference, numba

Why does numba jit report error in numpy indexing?


Below is my code. It runs fine without numba, and without nopython=True it only emits a warning about falling back to object mode.

import numba
import numpy as np
import pandas as pd

@numba.jit(nopython=True)
def computation_numba(array_cmp: np.ndarray, y_hat: np.ndarray, y_df: np.ndarray):
    cold_start_ratio = 0.05
    opportunities = 0
    if array_cmp.shape[1] <= 100:
        idx = np.flip(np.argsort(y_hat))
        y_df = y_df[idx]
        opportunities = max(int(cold_start_ratio*len(y_df)), 1)
        pos = y_df[:opportunities]
        neg = y_df[-opportunities:]
    else:
        ys = np.append(array_cmp, np.stack((y_hat, y_df), axis=0), axis=1)
        # remove outliers beyond 5 sigma
        means = ys[0,:].mean()
        stds = ys[0,:].std()
        ys = ys[:,np.array(ys[0,:]<=means+5*stds)]
        ys = ys[:,np.array(ys[0,:]>=means-5*stds)]
        idx = np.flip(np.argsort(y_hat))
        ys = ys[:,idx]
        opportunities = int(0.1*ys.shape[1])
        pos_th = ys[0,opportunities]
        neg_th = ys[0,-opportunities]
        pos = ys[1,np.array(ys[0,:]>=pos_th)]
        neg = ys[1,np.array(ys[0,:]<=neg_th)]
    return pos.sum() - neg.sum()

array_cmp = np.random.random([2,200])
y_hat = np.random.random(50)
y_df = pd.DataFrame.from_dict({'A': y_hat+0.3, 'B': y_hat*3})
print(computation_numba(array_cmp, y_hat, y_df['A'].to_numpy()))

It raises this error:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[21], line 30
     28 y_hat = np.random.random(50)
     29 y_df = pd.DataFrame.from_dict({'A': y_hat+0.3, 'B': y_hat*3})
---> 30 print(computation_numba(array_cmp, y_hat, y_df['A'].to_numpy()))

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function array>) found for signature:
 
 >>> array(array(bool, 1d, C))
 
There are 4 candidate implementations:
      - Of which 4 did not match due to:
      Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 129.
        With argument(s): '(array(bool, 1d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: array(bool, 1d, C) not allowed in a homogeneous sequence
  raised from /usr/local/lib64/python3.9/site-packages/numba/core/typing/npydecl.py:488

During: resolving callee type: Function(<built-in function array>)
During: typing of call at /tmp/ipykernel_1655694/2732175936.py (16)


File "../../../../tmp/ipykernel_1655694/2732175936.py", line 16:
<source missing, REPL/exec in use?>
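The first traceback points at the `np.array(...)` wrapped around a boolean comparison. That comparison already yields a boolean ndarray, and the message says numba's `np.array` overload refused an array argument ("not allowed in a homogeneous sequence"), so the wrapper is at best redundant. A plain-NumPy sketch of that point, with random data standing in for the question's arrays (an assumption for illustration only):

```python
import numpy as np

# Illustrative stand-in for `ys` in the question: 2 rows x 200 columns.
rng = np.random.default_rng(0)
ys = rng.random((2, 200))

means = ys[0, :].mean()
stds = ys[0, :].std()

# The comparison already returns a boolean ndarray, so np.array(...) around it
# only makes a copy. In plain NumPy that is harmless, but it is the call
# numba 0.56.4 rejected in the traceback above.
mask = ys[0, :] <= means + 5 * stds
print(type(mask).__name__, mask.dtype)  # ndarray bool

# Masking columns with the bare boolean array selects the same data.
filtered = ys[:, mask]
assert np.array_equal(filtered, ys[:, np.array(mask)])
```

Dropping the wrapper removes this first error but, as Update 1 shows, exposes a second one.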

I tried the approach in this post, but it did not solve the problem. I also read this post, but could not figure out how to adapt it to my case.

---------------- Update 1-------------------

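If I remove the np.array type casts (the opposite of what this post suggests), the code raises a similar error. The main difference is that there are now 22 candidate implementations instead of 4, and the rejection reason is "only one advanced index supported".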

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[26], line 30
     28 y_hat = np.random.random(50)
     29 y_df = pd.DataFrame.from_dict({'A': y_hat+0.3, 'B': y_hat*3})
---> 30 print(computation_numba(array_cmp, y_hat, y_df['A'].to_numpy()))

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File /usr/local/lib64/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float64, 2d, C), Tuple(Literal[int](1), array(bool, 1d, C)))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 2d, C), Tuple(int64, array(bool, 1d, C)))':
       No match.
      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 166.
        With argument(s): '(array(float64, 2d, C), Tuple(int64, array(bool, 1d, C)))':
       Rejected as the implementation raised a specific error:
         NumbaNotImplementedError: only one advanced index supported
  raised from /usr/local/lib64/python3.9/site-packages/numba/core/typing/arraydecl.py:69

During: typing of intrinsic-call at /tmp/ipykernel_1655694/1304761145.py (23)

File "../../../../tmp/ipykernel_1655694/1304761145.py", line 23:
<source missing, REPL/exec in use?>
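The second error is about `ys[1, mask]`, which mixes an integer index with a boolean array in one expression; numba 0.56.4 rejects it ("only one advanced index supported") even though NumPy accepts it. A minimal sketch of a workaround, assuming the limitation lies only in that combination and that upgrading is not an option: take the row first with a plain integer index, then apply a 1-D mask. `split_pos_neg`, the data, and the thresholds are hypothetical, the `try`/`except` fallback only lets the snippet run where numba is not installed, and this rewrite has not been checked against 0.56.4 specifically.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # numba missing or incompatible: run as plain Python
    def njit(func):
        return func


@njit
def split_pos_neg(ys, pos_th, neg_th):
    # Hypothetical helper mirroring the last lines of computation_numba.
    # ys[0] and ys[1] are basic (integer) indexes that return 1-D rows;
    # each row is then filtered with a 1-D boolean mask, so no single
    # indexing expression mixes an integer with a boolean array.
    key = ys[0]
    vals = ys[1]
    pos = vals[key >= pos_th]
    neg = vals[key <= neg_th]
    return pos.sum() - neg.sum()


# Illustrative data and thresholds (not from the question).
rng = np.random.default_rng(1)
ys = rng.random((2, 50))
pos_th, neg_th = 0.7, 0.3

result = split_pos_neg(ys, pos_th, neg_th)
# Plain NumPy accepts the original mixed form; both should agree.
expected = ys[1, ys[0] >= pos_th].sum() - ys[1, ys[0] <= neg_th].sum()
print(result, expected)
```

The comparison uses a tolerance rather than exact equality because a compiled sum may accumulate in a different order than NumPy's.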

----------------------Update 2-------------------------

The code posted on July 14 is indeed what I was using, as can be seen in this screenshot.

It still raised the "22 candidates" error after I removed np.array() (the error in the screenshot is almost the same as the one in Update 1). I'm using Python 3.9 with numba==0.56.4, numpy==1.23.5, and pandas==2.0.1.


Solution

  • Thanks to @Rutger Kassies in the comments. The problem came from the numba version (possibly a bug in 0.56.4); the code works fine with numba==0.57.1.
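Since the fix was an upgrade, a small runtime check can make the requirement explicit in a notebook. A sketch, assuming 0.57.1 as the minimum only because that is the version reported to work here; `parse_version` and `numba_is_recent_enough` are hypothetical helpers, and `importlib.metadata` (Python 3.8+) reads the installed version without importing numba.

```python
from importlib import metadata

# Assumption: 0.57.1 is the version reported to work in this thread; the
# threshold is taken from that report, not from numba's changelog.
MIN_NUMBA = (0, 57, 1)


def parse_version(text):
    # Keep only the leading numeric components, e.g. "0.56.4" -> (0, 56, 4);
    # stops at the first non-numeric part such as "0rc1".
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def numba_is_recent_enough(version_text):
    # Tuple comparison: (0, 56, 4) < (0, 57, 1) <= (0, 58, 0).
    return parse_version(version_text) >= MIN_NUMBA


try:
    installed = metadata.version("numba")
except metadata.PackageNotFoundError:
    installed = None

if installed is None:
    print("numba is not installed")
elif not numba_is_recent_enough(installed):
    print(f"numba {installed} found; consider: pip install 'numba>=0.57.1'")
else:
    print(f"numba {installed} should be new enough")
```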