
Python type hint for objects that have "@" (matrix-multiply)


I have a function fun() that accepts a "matrix" and a NumPy ArrayLike, and returns a NumPy array.

from numpy.typing import ArrayLike
import numpy as np

def fun(A, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0

What's the correct type hint for objects that support the @ operator? Note that fun() could also accept a scipy.sparse matrix, and perhaps other types.
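For context: `A @ x` is syntax for calling `type(A).__matmul__(A, x)`. If `A` lacks that method or it returns `NotImplemented`, Python falls back to `type(x).__rmatmul__(x, A)`. So at runtime `fun()` already accepts any object that defines `__matmul__`, not only NumPy arrays. Here is a small sketch that uses a hypothetical `Scale` class, invented purely for illustration:

```python
import numpy as np
from numpy.typing import ArrayLike


def fun(A, x: ArrayLike) -> np.ndarray:
    return (A @ x) ** 2 - 27.0


class Scale:
    """Hypothetical operator type (not from NumPy/SciPy): A @ x scales x."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __matmul__(self, x: ArrayLike) -> np.ndarray:
        # The interpreter dispatches `Scale(...) @ x` to this method.
        return self.factor * np.asarray(x, dtype=float)


dense = fun(np.eye(2), [3.0, 6.0])    # NumPy array as the "matrix"
custom = fun(Scale(3.0), [1.0, 2.0])  # user-defined class with __matmul__
```

Both calls compute `[3, 6] ** 2 - 27`. Because duck typing already works, what the type hint needs to express is "has `__matmul__`". A nominal base class is not required.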


Solution

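  • You can use typing.Protocol to describe any type that implements __matmul__, the method behind the @ operator.

    import typing

    import numpy as np
    from numpy.typing import ArrayLike


    # Structural type: matches any class that defines __matmul__
    class SupportsMatrixMultiplication(typing.Protocol):
        def __matmul__(self, x):
            ...


    def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
        return (A @ x) ** 2 - 27.0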
    

    You can, I believe, refine this further by annotating the x parameter and the return type of __matmul__, if you need to constrain more than the mere presence of the @ operator.
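If you also want a runtime guard, one option is to decorate the protocol with `typing.runtime_checkable`. This lets `isinstance` test whether an object has `__matmul__`. This is an addition to the answer above, not part of it. The sketch below assumes Python 3.8+ and assumes that raising `TypeError` is the behavior you want for unsupported inputs:

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np
from numpy.typing import ArrayLike


@runtime_checkable
class SupportsMatrixMultiplication(Protocol):
    """Any object whose type defines __matmul__."""

    # Any keeps NumPy arrays, scipy.sparse matrices, etc. compatible statically.
    def __matmul__(self, x: Any) -> Any: ...


def fun(A: SupportsMatrixMultiplication, x: ArrayLike) -> np.ndarray:
    # isinstance only checks that __matmul__ exists, not its signature.
    if not isinstance(A, SupportsMatrixMultiplication):
        raise TypeError(f"{type(A).__name__} does not support '@'")
    return (A @ x) ** 2 - 27.0
```

The runtime check ignores the annotations entirely. Narrowing `x: Any` to something like `ArrayLike` therefore only affects static checkers, and whether NumPy's or SciPy's types are then accepted depends on the signatures in their stubs. Also note that a nested list such as `[[1, 0], [0, 1]]` fails this check, even though `[[1, 0], [0, 1]] @ np.eye(2)` works through `ndarray.__rmatmul__`.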