I have a class containing several NumPy arrays:
```python
from typing import Optional

import numpy as np


class VECMParams(ModelParams):
    def __init__(
        self,
        ecm_gamma: np.ndarray,
        ecm_mu: Optional[np.ndarray],
        ecm_lambda: np.ndarray,
        ecm_beta: np.ndarray,
        intercept_coint: bool,
    ):
        self.ecm_gamma = ecm_gamma
        self.ecm_mu = ecm_mu
        self.ecm_lambda = ecm_lambda
        self.ecm_beta = ecm_beta
        self.intercept_coint = intercept_coint
```
I want to override the `==` operator. Basically, one `VECMParams` is equal to another when each of its arrays is equal to the corresponding array of the right-hand side:
```python
def __eq__(self, rhs: object) -> bool:
    if not isinstance(rhs, VECMParams):
        raise NotImplementedError()
    return (
        np.all(self.ecm_gamma == rhs.ecm_gamma) and
        np.all(self.ecm_mu == rhs.ecm_mu) and
        np.all(self.ecm_lambda == rhs.ecm_lambda) and
        np.all(self.ecm_beta == rhs.ecm_beta)
    )
```
Still, mypy keeps reporting `Incompatible return value type (got "Union[bool_, bool]", expected "bool")  [return-value]`, because `np.all` returns `bool_` and `__eq__` needs to return a native `bool`. I searched for hours and it looks like there is no way to convert these `bool_` values to a native `bool`. Has anyone had the same problem?
PS: `my_bool_ is True` does not evaluate to the right native bool value either.
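Look at the documentation of `numpy.all()`:

> A new boolean or array is returned unless out is specified, in which case a reference to out is returned.

So `np.all` gives you back a NumPy object (for a full reduction like yours, an `np.bool_` scalar) rather than a built-in `bool`, which is why mypy refuses it as the return value of `__eq__`.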
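A quick way to see this behaviour, assuming only NumPy is installed (the array values are arbitrary): `np.all` yields an `np.bool_` scalar, an identity check against `True` fails (which explains the PS in the question), and both `bool()` and `.item()` give back a native `bool`.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.0, 2.0, 3.0])

# np.all reduces the element-wise comparison to a NumPy boolean scalar.
result = np.all(a == b)
is_numpy_bool = isinstance(result, np.bool_)

# `is True` compares identity with the built-in singleton, so it is False
# even though the value is truthy.
identity_check = result is True

# bool() and .item() both convert the NumPy scalar to a native Python bool.
native = bool(result)
native_via_item = result.item()

print(is_numpy_bool, identity_check, native, native_via_item)  # True False True True
```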
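How to fix: wrap the whole expression in `bool()`, which converts the NumPy boolean into a native Python `bool`:

```python
def __eq__(self, rhs: object) -> bool:
    # Keep the `object` annotation: narrowing it to 'VECMParams' breaks the
    # signature inherited from object.__eq__, and mypy reports that override.
    if not isinstance(rhs, VECMParams):
        # Let Python fall back to the reflected comparison instead of raising.
        return NotImplemented
    # bool() turns the np.bool_ produced by np.all into a native bool.
    return bool(
        np.all(self.ecm_gamma == rhs.ecm_gamma)
        and np.all(self.ecm_mu == rhs.ecm_mu)
        and np.all(self.ecm_lambda == rhs.ecm_lambda)
        and np.all(self.ecm_beta == rhs.ecm_beta)
    )
```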