Tags: python, python-typing, pyright

How to tell the type checker the output type of an external function


I am trying to add type hints to a package I maintain so that type checkers can catch more errors. However, the package uses external libraries such as numpy and scipy, which are not fully type-annotated. As a result, the type checker often cannot tell the type of a variable produced by one of these libraries, even though I know what it is. Is there a way to tell the type checker that type without modifying the external packages?

I am trying to type check code like this:

import numpy as np
from scipy.sparse import csr_matrix, spmatrix

def run() -> spmatrix:
    a = csr_matrix(np.eye(3))
    b = csr_matrix(np.eye(3))
    c = a @ b
    c = c.multiply(3.0)
    return c

For this code, pyright and similar type checkers complain about the c.multiply line, and sometimes about the return statement. They do not know that c is still a sparse matrix, so they assume it might not have a .multiply method or might not match the declared return type.

I am looking for some kind of magic comment or annotation that tells these checkers that c will be a sparse matrix, for example c: spmatrix = a @ b, or c = a @ b # pyright c: spmatrix, or whatever the right syntax might be.


Solution

  • The proper way to do this is with stub files that tell pyright the types of objects in these libraries. Such stubs exist for pandas and numpy, but I am not aware of any for scipy. One thing that often helps, and that works in this case, is an assert isinstance(...) statement, which narrows the variable's type for the static type checker:

    def run() -> spmatrix:
        a = csr_matrix(np.eye(3))
        b = csr_matrix(np.eye(3))
        c = a @ b
        assert isinstance(c, csr_matrix) # this is new
        c = c.multiply(3.0)
        return c
    

    Without the assert line, pyright reports:

    $ pyright test.py
    ~/test.py
      ~/test.py:9:11 - error: Cannot access attribute "multiply" for class "ndarray[Any, dtype[Unknown]]"
        Attribute "multiply" is unknown (reportAttributeAccessIssue)
      ~/test.py:9:11 - error: Cannot access attribute "multiply" for class "ndarray[Any, Unknown]"
        Attribute "multiply" is unknown (reportAttributeAccessIssue)
      ~/test.py:9:11 - error: Cannot access attribute "multiply" for class "matrix[Unknown, Unknown]"
        Attribute "multiply" is unknown (reportAttributeAccessIssue)
    3 errors, 0 warnings, 0 informations 
    

    With the assert statement added, the output is:

    $ pyright test.py
    0 errors, 0 warnings, 0 informations
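If a runtime check is not wanted, typing.cast from the standard library is the closest thing to the "magic comment" asked about: it tells the checker to treat a value as a given type and does nothing at runtime. The sketch below compares it with the assert approach. To stay self-contained it does not import scipy; SparseLike and external_matmul are hypothetical stand-ins for spmatrix and for an operation whose declared return type is a broad union, assumed here to mimic what pyright infers for a @ b.

```python
from typing import Union, cast


class SparseLike:
    """Hypothetical stand-in for a sparse-matrix class; not part of scipy."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def multiply(self, factor: float) -> "SparseLike":
        # Returns a new object and leaves the original unchanged.
        return SparseLike(self.scale * factor)


def external_matmul(a: SparseLike, b: SparseLike) -> Union[SparseLike, list[float]]:
    # Hypothetical "library" function: its declared return type is broader
    # than what it actually returns, so a checker cannot assume SparseLike.
    return SparseLike(a.scale * b.scale)


def run_with_cast() -> SparseLike:
    a, b = SparseLike(1.0), SparseLike(2.0)
    # cast() only informs the type checker; no check happens at runtime.
    c = cast(SparseLike, external_matmul(a, b))
    return c.multiply(3.0)


def run_with_assert() -> SparseLike:
    a, b = SparseLike(1.0), SparseLike(2.0)
    c = external_matmul(a, b)
    # isinstance narrowing: after this line the checker treats c as SparseLike,
    # and the program fails at runtime if the assumption is ever wrong.
    assert isinstance(c, SparseLike)
    return c.multiply(3.0)
```

The trade-off: cast stays silent if the assumption is wrong, while the assert fails loudly, but assert statements are skipped when Python runs with -O. A plain annotation such as c: SparseLike = external_matmul(a, b) is not an override; pyright checks the assigned value against the declared type and reports the mismatch.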