Tags: python, numba, try-except

Numba: "if array.shape[1]" inside try raises "tuple index out of range". Works without Numba, fails with @njit(fastmath=True, nogil=True, cache=True)


Numba 0.53.1, Python 3.7.9, Windows 10 64bit

This doctest works fine:

import numpy as np

def example_numba_tri(yp):
    """
    >>> example_numba_tri(np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1]))
    array([[0.1, 0.3, 0.6],
           [0.5, 0.1, 0.4],
           [0.2, 0.7, 0.1]])
    """
    try:
        if yp.shape[1] == 3:
            pass
    except:
        yp = yp.reshape(int(len(yp) / 3), -1, order='F')

    return yp
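
For context, the plain-Python version works because the shape of a 1-D array is a one-element tuple, so yp.shape[1] raises an ordinary IndexError at run time and the except branch handles it. A minimal sketch of that behaviour (the names flat and caught are only illustrative):

```python
import numpy as np

flat = np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1])

# A 1-D array has a one-element shape tuple, so index 1 does not exist.
print(flat.shape)  # (9,)

try:
    flat.shape[1]
    caught = None
except IndexError as exc:
    # This is the run-time error that the bare "except:" in the function swallows.
    caught = str(exc)

print(caught)  # tuple index out of range
```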

Just add @njit(fastmath=True, nogil=True, cache=True):

from numba import njit
import numpy as np

@njit(fastmath=True, nogil=True, cache=True)
def example_numba_tri(yp):
    """
    >>> example_numba_tri(np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1]))
    array([[0.1, 0.3, 0.6],
           [0.5, 0.1, 0.4],
           [0.2, 0.7, 0.1]])
    """
    try:
        if yp.shape[1] == 3:
            pass
    except:
        yp = yp.reshape(int(len(yp) / 3), -1, order='F')

    return yp

and get an error:

    numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    Internal error at <numba.core.typeinfer.StaticGetItemConstraint object at 0x0000020CB55A7608>.
    tuple index out of range
    During: typing of static-get-item at C:/U1/main.py (1675)
    Enable logging at debug level for details.
    
    File "main.py", line 1675:
    def example_numba_tri(yp):
        <source elided>
        try:
            if yp.shape[1] == 3:
            ^

How can I fix this, and why does it happen? Or is this a bug? I read https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#pysupported-exception-handling , and it seems I did everything as described there.

Updates:

  1. https://github.com/numba/numba/issues/6872
  2. This was also useful for me: https://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
  3. Profiling and restructuring helped me convert only the most CPU-intensive parts to Numba.
  4. According to https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html there seems to be no support for the order argument of reshape (order='F').

Solution

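  • Consider restructuring the code. Numba infers the types of the inputs at compile time, and the number of dimensions of an array is part of its type. For a 1-D input, the check if yp.shape[1] == 3: is therefore rejected during type inference, before any code runs, which is why the try/except never gets a chance to handle it.

    Please try the code below. It follows your code, but moves the shape check into a plain-Python wrapper and drops order='F', which I could not get to work with Numba. Note that without order='F' the reshape is row-major, so the element order differs from the doctest in the question.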

    from numba import njit
    import numpy as np
    
    
    @njit(fastmath=True, nogil=True, cache=True)
    def example_numba_tri(yp):
        return yp.reshape(int(len(yp) / 3), -1)
    
    
    def wrapper_example_numba_tri(yp):
        if len(yp.shape) > 1:
            if yp.shape[1] == 3:
                return yp
        return example_numba_tri(yp)
    
    
    if __name__ == '__main__':
        x = np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1])
        wrapper_example_numba_tri(x)
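
If the column-major layout from the question's doctest is needed, one possible workaround is to reshape to the transposed shape in the default order and then transpose, which gives the same element layout as order='F'. The sketch below is plain NumPy so the equivalence can be checked directly. The function names are hypothetical, it assumes the input is either 1-D with a length divisible by 3 or already has three columns, and whether it compiles under @njit should be verified against your Numba version.

```python
import numpy as np


def reshape_three_columns_fortran(yp):
    """Column-major reshape of a 1-D array to (len(yp) // 3, 3) without order='F'."""
    n_rows = len(yp) // 3
    # C-order reshape to the transposed shape, then transpose:
    # element k ends up at (k % n_rows, k // n_rows), the same as with order='F'.
    return yp.reshape(3, n_rows).T


def as_three_columns(yp):
    """Return yp unchanged if it already has three columns, else reshape it."""
    # ndim is checked first, so shape[1] is only read for 2-D input.
    if yp.ndim == 2 and yp.shape[1] == 3:
        return yp
    return reshape_three_columns_fortran(yp)


x = np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1])
result = as_three_columns(x)
print(result)
```

Checking yp.ndim before touching yp.shape[1] avoids relying on an exception for control flow, which is the part that failed under Numba in the question.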