
Appropriate type-hints for Generic/Frozen Multivariate Distributions in scipy.stats


I need to differentiate between 4 different sets of probability distributions from the scipy.stats module: generic univariate, frozen univariate, generic multivariate, & frozen multivariate. Throughout the application, I would like to add type hints for these 4 sets.

For the univariate cases, mypy has no problems with type hints like in this MWE:

from typing import Any

from scipy.stats import norm
from scipy.stats._distn_infrastructure import rv_generic, rv_frozen


def sample_frozen_univariate(n_samples: int, dist: rv_frozen):
    return dist.rvs(n_samples)


def sample_generic_univariate(n_samples: int, dist: rv_generic,
                              *distparams: Any):
    return dist(*distparams).rvs(n_samples)


n_samples = 4
frozen_dist = norm()
print(sample_frozen_univariate(n_samples, frozen_dist))

n_samples = 4
generic_dist = norm
loc, scale = 1, 2
print(sample_generic_univariate(n_samples, generic_dist, loc, scale))

Now here's an MWE of the multivariate cases. If executed it provides the correct result:

from typing import Any

from scipy.stats import multivariate_normal
from scipy.stats._multivariate import multi_rv_generic, multi_rv_frozen


def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multi_rv_frozen):
    if dist.dim != n_variates:
        msg = 'distribution dimension %s inconsistent with n_variates=%s'
        raise ValueError(msg % (dist.dim, n_variates))
    sample = dist.rvs(n_samples)
    return sample


def sample_generic_multivariate(n_samples: int, n_variates: int, dist: multi_rv_generic,
                                *distparams: Any):
    sample = dist(*distparams).rvs(n_samples)
    if sample.shape[1] != n_variates:
        msg = 'sample dimension %s inconsistent with n_variates=%s'
        raise ValueError(msg % (sample.shape[1], n_variates))
    return sample


n_samples = 4
n_variates = 2
frozen_dist = multivariate_normal(mean=[0, 0], cov=[[1, 0], [0, 1]])
print(sample_frozen_multivariate(n_samples, n_variates, frozen_dist))

n_samples = 4
n_variates = 2
generic_dist = multivariate_normal
mean, cov = [-1, 1], [[1, 0], [0, 1]]
print(sample_generic_multivariate(n_samples, n_variates, generic_dist, mean, cov))

However, mypy reports the following issues:

example_2.py:8: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "dim"  [attr-defined]
example_2.py:10: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "dim"  [attr-defined]
example_2.py:11: error: "multi_rv_frozen[multi_rv_generic]" has no attribute "rvs"  [attr-defined]
example_2.py:17: error: "multi_rv_generic" not callable  [operator]

I understand the issues that mypy is reporting, just not how to properly address them. What approach(es) to type hints for generic and frozen multivariate distributions from scipy.stats would appease type-checkers? (As an aside, I am aware of the problem with relying on types from "private" modules in scipy.stats. Though I understand that is a separate issue, I welcome solutions that simultaneously address that problem.)


Solution

  • multi_rv_generic and multi_rv_frozen are minimal base classes for multivariate random variables, and they do not expose the interface you need (dim, rvs, or a callable generic object). SciPy probably lacks a richer common interface because not all multivariate distributions share the same characteristics, which is why these two classes cannot simply be used as annotations here.
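
To make that point concrete, here is a minimal sketch that uses stand-in classes rather than SciPy itself. The names `BaseFrozen`, `NormalFrozen`, `sample_any`, and `sample_normal` are hypothetical and only mimic the base/subclass relationship described above. A type checker sees only what the annotated class declares, so annotating with the bare base hides members that exist on the concrete subclass.

```python
import random


class BaseFrozen:
    """Stand-in for a minimal base class: no sampling interface at all."""


class NormalFrozen(BaseFrozen):
    """Stand-in for a concrete frozen distribution."""

    def __init__(self, mean: list[float]) -> None:
        self.mean = mean
        self.dim = len(mean)

    def rvs(self, size: int) -> list[list[float]]:
        # Draw `size` samples, one unit-variance Gaussian per coordinate.
        return [[random.gauss(m, 1.0) for m in self.mean] for _ in range(size)]


def sample_any(dist: BaseFrozen, size: int):
    # A type checker would reject this call: BaseFrozen declares no `rvs`.
    # It still works at run time when a NormalFrozen is passed in.
    return dist.rvs(size)  # type: ignore[attr-defined]


def sample_normal(dist: NormalFrozen, size: int) -> list[list[float]]:
    # Accepted: the annotation names the class that declares `rvs`.
    return dist.rvs(size)


frozen = NormalFrozen([0.0, 0.0])
print(hasattr(BaseFrozen, "rvs"))    # the base declares no such member
print(hasattr(NormalFrozen, "rvs"))  # the subclass does
print(len(sample_normal(frozen, 4)), frozen.dim)
```

The same reasoning applies to the generic (unfrozen) side: the annotation has to name a class that actually declares `__call__`.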

    If you only want to support the multivariate normal, it is fairly easy:

    from typing import Any
    
    from scipy.stats._multivariate import (
        multivariate_normal_frozen,
        multivariate_normal_gen,
    )
    
    def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: multivariate_normal_frozen):
        if dist.dim != n_variates:
            msg = "distribution dimension %s inconsistent with n_variates=%s"
            raise ValueError(msg % (dist.dim, n_variates))
        sample = dist.rvs(n_samples)
        return sample
    
    
    def sample_generic_multivariate(n_samples: int, n_variates: int, dist: multivariate_normal_gen, *distparams: Any):
        sample = dist(*distparams).rvs(n_samples)
        if sample.shape[1] != n_variates:
            msg = "sample dimension %s inconsistent with n_variates=%s"
            raise ValueError(msg % (sample.shape[1], n_variates))
        return sample
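
On the aside about private modules: one partial mitigation, sketched here as an assumption and not as an official SciPy recommendation, is to import the private names only under `typing.TYPE_CHECKING` and to quote the annotation. The type checker still resolves the class, but the private path is never imported at run time. The helper `frozen_dim` is a hypothetical name used only for illustration.

```python
from types import SimpleNamespace
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; never executed at run time.
    from scipy.stats._multivariate import multivariate_normal_frozen


def frozen_dim(dist: "multivariate_normal_frozen") -> int:
    # The quoted annotation is not evaluated when the function is defined.
    return dist.dim


# Run-time demonstration with a stand-in object that only carries `dim`.
# A type checker would flag this call, hence the ignore comment.
print(frozen_dim(SimpleNamespace(dim=2)))  # type: ignore[arg-type]
```

The private path can still change between SciPy releases, so this only removes the run-time dependency; the type checker remains exposed to such changes.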
    

    Extend to all multivariates

    As noted above, not all distributions provide the same functionality. Two examples are dirichlet_multinomial and normal_inverse_gamma.

    Cover all existing distributions

    To be on the safe side (as long as no new distribution is added), you can try to cover all types explicitly. Type-checking this code produces quite a few errors, which show the distributions your current code cannot handle yet. You then either have to exclude those distributions or update your code.

    from typing import Any
    
    from scipy.stats import multivariate_normal
    from scipy.stats._multivariate import (
        multivariate_normal_frozen,
        matrix_normal_frozen,
        dirichlet_frozen,
        wishart_frozen,
        invwishart_frozen,
        multinomial_frozen,
        special_ortho_group_frozen,
        ortho_group_frozen,
        random_correlation_frozen,
        unitary_group_frozen,
        multivariate_t_frozen,
        multivariate_hypergeom_frozen,
        random_table_frozen,
        uniform_direction_frozen,
        dirichlet_multinomial_frozen,
        vonmises_fisher_frozen,
        normal_inverse_gamma_frozen,
        
        multivariate_normal_gen,
        matrix_normal_gen,
        dirichlet_gen,
        wishart_gen,
        multinomial_gen,
        special_ortho_group_gen,
        ortho_group_gen,
        random_correlation_gen,
        unitary_group_gen,
        multivariate_t_gen,
        multivariate_hypergeom_gen,
        random_table_gen,
        uniform_direction_gen,
        dirichlet_multinomial_gen,
        vonmises_fisher_gen,
        normal_inverse_gamma_gen,
        invwishart_gen,
    )
    
    # Use TypeAlias or TypeAliasType for Python < 3.12
    type FrozenDistType = (
        multivariate_normal_frozen
        | matrix_normal_frozen
        | dirichlet_frozen
        | wishart_frozen
        | invwishart_frozen
        | multinomial_frozen
        | special_ortho_group_frozen
        | ortho_group_frozen
        | random_correlation_frozen
        | unitary_group_frozen
        | multivariate_t_frozen
        | multivariate_hypergeom_frozen
        | random_table_frozen
        | uniform_direction_frozen
        | dirichlet_multinomial_frozen
        | vonmises_fisher_frozen
        | normal_inverse_gamma_frozen
    )
    
    type GenericDistType = (
        multivariate_normal_gen
        | matrix_normal_gen
        | dirichlet_gen
        | wishart_gen
        | invwishart_gen
        | multinomial_gen
        | special_ortho_group_gen
        | ortho_group_gen
        | random_correlation_gen
        | unitary_group_gen
        | multivariate_t_gen
        | multivariate_hypergeom_gen
        | random_table_gen
        | uniform_direction_gen
        | dirichlet_multinomial_gen
        | vonmises_fisher_gen
        | normal_inverse_gamma_gen
    )
    
    def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: FrozenDistType):
        if dist.dim != n_variates:
            msg = "distribution dimension %s inconsistent with n_variates=%s"
            raise ValueError(msg % (dist.dim, n_variates))
        sample = dist.rvs(n_samples)
        return sample
    
    
    def sample_generic_multivariate(n_samples: int, n_variates: int, dist: GenericDistType, *distparams: Any):
        sample = dist(*distparams).rvs(n_samples)
        if sample.shape[1] != n_variates:
            msg = "sample dimension %s inconsistent with n_variates=%s"
            raise ValueError(msg % (sample.shape[1], n_variates))
        return sample
    
    Cover only the distributions your code can handle (type checkers are happy; not all distributions are supported)

    Instead of listing every distribution, including ones your code cannot handle, you can use a Protocol to restrict the annotation to exactly the interface you rely on. Only distributions that satisfy that interface are accepted.

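    from typing import Any, Protocol
    
    import numpy as np
    
    
    class FrozenDistProtocol(Protocol):
        # Declared at class level: protocol members cannot be defined
        # through assignment to self inside a method.
        dim: int
    
        def rvs(self, size: int, *args: Any, **kwargs: Any) -> np.ndarray:
            """
            Returns
            -------
            rvs : ndarray or scalar
            """
            ...
    
    
    class GenericDistProtocol(Protocol):
        def rvs(self, *args: Any, **kwargs: Any) -> np.ndarray | Any:
            """
            Returns
            -------
            rvs : ndarray or scalar
            """
            ...
    
        # Calling a generic distribution with its parameters freezes it,
        # so the protocol needs __call__ (not __init__) to be callable.
        def __call__(self, *args: Any, **kwargs: Any) -> FrozenDistProtocol: ...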
    
    def sample_frozen_multivariate(n_samples: int, n_variates: int, dist: FrozenDistProtocol):
        if dist.dim != n_variates:
            msg = "distribution dimension %s inconsistent with n_variates=%s"
            raise ValueError(msg % (dist.dim, n_variates))
        sample = dist.rvs(n_samples)
        return sample
    
    
    def sample_generic_multivariate(n_samples: int, n_variates: int, dist: GenericDistProtocol, *distparams: Any):
        sample = dist(*distparams).rvs(n_samples)
        if sample.shape[1] != n_variates:
            msg = "sample dimension %s inconsistent with n_variates=%s"
            raise ValueError(msg % (sample.shape[1], n_variates))
        return sample
    

    You should test with some other distributions to check whether the signatures of the Protocol classes I provided are sufficient.
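
One quick way to try candidate distributions against a protocol is `typing.runtime_checkable`, which lets `isinstance` test whether the required members are present. This is a sketch with hypothetical stand-in classes (`SupportsDimAndRvs`, `GoodFrozen`, `NoDimFrozen`), not the SciPy types. The run-time check only looks at member names, not at signatures or types, so it complements a type checker and does not replace it.

```python
import random
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsDimAndRvs(Protocol):
    dim: int

    def rvs(self, size: int) -> list[list[float]]: ...


class GoodFrozen:
    """Stand-in that provides both members the protocol asks for."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def rvs(self, size: int) -> list[list[float]]:
        return [[random.random() for _ in range(self.dim)] for _ in range(size)]


class NoDimFrozen:
    """Stand-in that can sample but exposes no `dim`."""

    def rvs(self, size: int) -> list[list[float]]:
        return [[random.random()] for _ in range(size)]


for candidate in (GoodFrozen(2), NoDimFrozen()):
    ok = isinstance(candidate, SupportsDimAndRvs)
    print(type(candidate).__name__, "satisfies protocol:", ok)
```

Running the same loop over real frozen SciPy distributions would show which of them carry the members your functions rely on.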