
Can't get rid of mypy error about wrong type numpy.bool_ and bool


I have a class containing several NumPy arrays:

from typing import Optional

import numpy as np


class VECMParams(ModelParams):
    def __init__(
        self,
        ecm_gamma: np.ndarray,
        ecm_mu: Optional[np.ndarray],
        ecm_lambda: np.ndarray,
        ecm_beta: np.ndarray,
        intercept_coint: bool,
    ):
        self.ecm_gamma = ecm_gamma
        self.ecm_mu = ecm_mu
        self.ecm_lambda = ecm_lambda
        self.ecm_beta = ecm_beta
        self.intercept_coint = intercept_coint

I want to override the == operator: a VECMParams instance is equal to another when each of its arrays equals the corresponding array of rhs:

def __eq__(self, rhs: object) -> bool:
    if not isinstance(rhs, VECMParams):
        raise NotImplementedError()

    return (
        np.all(self.ecm_gamma == rhs.ecm_gamma) and
        np.all(self.ecm_mu == rhs.ecm_mu) and
        np.all(self.ecm_lambda == rhs.ecm_lambda) and
        np.all(self.ecm_beta == rhs.ecm_beta) 
    )

However, mypy keeps reporting Incompatible return value type (got "Union[bool_, bool]", expected "bool")  [return-value], because np.all returns numpy.bool_ while __eq__ is annotated to return a native bool. I have searched for hours, and it looks like there is no way to convert a bool_ to a native bool. Is anyone else having the same problem?
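A minimal runtime check of the types involved (a sketch, not part of the original post): reducing a whole array with np.all gives a numpy.bool_, which is not a subclass of the built-in bool, while calling bool() on it or using its .item() method gives a native Python bool.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.0, 3.0])

result = np.all(a == b)  # reduces the whole array to one NumPy scalar

print(isinstance(result, np.bool_))  # True
print(isinstance(result, bool))      # False: numpy.bool_ does not subclass bool

# Two ways to get a native Python bool out of it.
print(type(bool(result)) is bool)    # True
print(type(result.item()) is bool)   # True
```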

PS: my_bool_ is True does not give the right native bool value; it is always False, because a numpy.bool_ is never the built-in True object.
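The pitfall from the PS can be reproduced directly (illustrative only): is compares object identity rather than value, so a true NumPy boolean still fails the check, whereas bool() returns the built-in True singleton.

```python
import numpy as np

flag = np.all(np.array([True, True]))

print(flag is True)        # False: flag is a numpy.bool_, a different object
print(bool(flag))          # True: the value itself is true
print(bool(flag) is True)  # True: bool() returns the built-in singleton
```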


Solution

  • Look at the documentation of numpy.all():

    A new boolean or array is returned unless out is specified, in which case a reference to out is returned.
    

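    In other words, np.all returns a NumPy boolean (numpy.bool_) or an ndarray, not a native Python bool, so mypy does not accept it as the bool that __eq__ is annotated to return.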

    How to fix it: wrap the whole expression in bool(), which converts the result to a native bool:

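    def __eq__(self, rhs: object) -> bool:
        # Keep the parameter typed as object: narrowing it to 'VECMParams'
        # breaks the signature inherited from object and mypy reports [override].
        if not isinstance(rhs, VECMParams):
            # Lets Python try the other operand's __eq__ instead of raising.
            return NotImplemented

        # bool() converts the numpy.bool_ result into a native bool.
        return bool(
            np.all(self.ecm_gamma == rhs.ecm_gamma) and
            np.all(self.ecm_mu == rhs.ecm_mu) and
            np.all(self.ecm_lambda == rhs.ecm_lambda) and
            np.all(self.ecm_beta == rhs.ecm_beta)
        )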