
How to check a class/type is iterable (uninstantiated)


I am inspecting type hints such as list[int] which is a GenericAlias.

If I get the origin using typing.get_origin(list[int]) or list[int].__origin__, it returns the class list, as expected: <class 'list'>

How can I check if the class is iterable without instantiating it, or is that the only way?
The usual iter() and isinstance(obj, collections.abc.Iterable) checks obviously don't work, as they expect an instance, not the class.
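A minimal reproduction of the situation described above, assuming Python 3.9+ (needed for `list[int]` at runtime). It shows that once the alias is unwrapped you hold the class object itself, so instance-level checks end up asking whether the class object is iterable, not whether its instances are:

```python
import typing
from collections.abc import Iterable

hint = list[int]
origin = typing.get_origin(hint)  # the bare class `list`
print(origin)                     # <class 'list'>

# The check is applied to the class object, whose type is `type`,
# and `type` itself has no __iter__.
print(isinstance(origin, Iterable))  # False

try:
    iter(origin)
except TypeError as exc:
    print(exc)  # 'type' object is not iterable
```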

I saw this answer, but it doesn't seem to work correctly in Python 3.10 (even when the i_type variable is substituted with t).


Solution

  • This depends a bit on what you define as iterable.

    The Collections Abstract Base Classes module considers a class to implement the Iterable protocol once it defines the __iter__ method. Note that you do not need to define the __next__ method; that is only needed if you want to implement an Iterator. (The two often get confused.)
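A side note on the question's premise: isinstance() needs an instance, but issubclass() accepts a class, and the collections.abc ABCs run the same structural __iter__ check there. A small sketch; the classes OnlyIter, OnlyNext and NotCallableIter are made up for illustration:

```python
from collections.abc import Iterable, Iterator


class OnlyIter:
    """Illustrative: defines __iter__ but not __next__ (an Iterable, not an Iterator)."""
    def __iter__(self):
        return iter((1, 2, 3))


class OnlyNext:
    """Illustrative: defines __next__ but not __iter__."""
    def __next__(self):
        raise StopIteration


class NotCallableIter:
    """Illustrative: the ABC only checks that the name exists, not that it is callable."""
    __iter__ = "bar"


# issubclass() works on classes, so no instance is needed.
print(issubclass(list, Iterable))             # True
print(issubclass(int, Iterable))              # False
print(issubclass(OnlyIter, Iterable))         # True
print(issubclass(OnlyIter, Iterator))         # False: no __next__
print(issubclass(OnlyNext, Iterator))         # False: Iterator needs both methods
print(issubclass(NotCallableIter, Iterable))  # True, despite __iter__ being a string
```

The last line is why a stricter, callable-aware check can still be worth writing by hand.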

    A slightly broader definition, in accordance with the general notion of an iterable in the documentation, also includes classes that implement __getitem__ (with integer indexes starting at 0, as Sequences do).
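To illustrate the broader definition: a class with only __getitem__ is iterable in practice, because iter() falls back to calling __getitem__ with 0, 1, 2 and so on until IndexError is raised, yet the Iterable ABC does not recognise it. Squares is a hypothetical example class:

```python
from collections.abc import Iterable


class Squares:
    """Hypothetical sequence-like class: __getitem__ only, no __iter__."""
    def __init__(self, n: int) -> None:
        self._n = n

    def __getitem__(self, idx: int) -> int:
        if idx >= self._n:
            # IndexError tells the fallback iterator to stop.
            raise IndexError(idx)
        return idx * idx


print(list(Squares(4)))                # [0, 1, 4, 9]
print(issubclass(Squares, Iterable))   # False: the ABC only looks for __iter__
```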

    In practice, this means that a class is iterable if and only if you can call the built-in iter() function on an instance of it. That function looks up __iter__ on the instance's type and calls it if it finds one; otherwise it falls back to the __getitem__ sequence protocol.
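iter() looks up __iter__ on the object's type rather than on the instance, which is what makes inspecting the class meaningful in the first place. A quick check; the class Plain is made up for the demo:

```python
class Plain:
    pass


obj = Plain()
# An instance attribute named __iter__ is ignored by iter(),
# because special methods are looked up on type(obj).
obj.__iter__ = lambda: iter([1, 2, 3])

try:
    iter(obj)
    instance_attr_used = True
except TypeError:
    instance_attr_used = False
print(instance_attr_used)  # False

# Defining the method on the class is what counts.
Plain.__iter__ = lambda self: iter([1, 2, 3])
print(list(obj))  # [1, 2, 3]
```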

    If that matches what you consider iterable, the most reliable check I can think of is the following. First, we find out whether any class in the method resolution order (MRO) implements the desired instance method:

    (Thanks @user2357112 for reminding me of checking inherited methods.)

    def _implements_instance_method(cls: type, name: str) -> type | None:
        """
        Checks whether a class implements a certain instance method.
    
        Args:
            cls: Class to check; superclasses (except `object`) are also checked
            name: The name of the instance method that `cls` should have
    
        Returns:
            The earliest class in the MRO of `cls` implementing the instance
            method with the provided `name` or `None` if none of them do.
        """
        for base in cls.__mro__[:-1]:  # do not check `object`
            if name in base.__dict__ and callable(base.__dict__[name]):
                return base
        return None
    

    The first condition (name in base.__dict__) is self-explanatory: if it fails, that class does not define the method at all. The second condition is where it gets a little pedantic.

    The second check actually does more than one thing. First, it ensures that the name is bound to something callable, i.e. a method. But it also guards against descriptor shenanigans (to an extent). This is why we check callable(base.__dict__[name]) rather than simply callable(getattr(cls, name)): getattr would first run the descriptor protocol and could hand back a callable bound method.

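    If someone were to (for whatever reason) define a @classmethod or a @property with that name, it would be rejected here, because the raw classmethod and property objects stored in the class's __dict__ are not callable.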
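To see why the raw __dict__ entry matters, compare it with getattr() for a classmethod and a property. The class Descriptors is illustrative only. getattr() runs the descriptor protocol, so the classmethod comes back as a callable bound method, while the objects actually stored in the class namespace are not callable:

```python
class Descriptors:
    @classmethod
    def cm(cls):
        return "called on the class"

    @property
    def prop(self):
        return [1, 2, 3]


raw_cm = Descriptors.__dict__["cm"]      # the classmethod object itself
raw_prop = Descriptors.__dict__["prop"]  # the property object itself

print(callable(raw_cm))                      # False
print(callable(getattr(Descriptors, "cm")))  # True: a bound method
print(callable(raw_prop))                    # False
```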

    Next we write our actual iterable checking function:

    def is_iterable_class(cls: type, strict: bool = True) -> bool:
        """
        Returns `True` only if `cls` implements the iterable protocol.
    
        Args:
            cls:
                The class to check for being iterable
            strict (optional):
                If `True` (default), only classes that implement (or inherit)
                the `__iter__` instance method are considered iterable;
                if `False`, classes supporting `__getitem__` subscripting
                will also be considered iterable.
                -> https://docs.python.org/3/glossary.html#term-iterable
    
        Returns:
            `True` if `cls` is to be considered iterable; `False` otherwise.
        """
        if not isinstance(cls, type):
            return False
        if _implements_instance_method(cls, "__iter__") is None:
            if strict:
                return False
            return _implements_instance_method(cls, "__getitem__") is not None
        return True
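Tying this back to the question, here is a sketch of how the strict check could be applied to a type hint such as list[int] by unwrapping the alias with typing.get_origin first. is_iterable_hint is a hypothetical helper; the MRO walk is restated compactly so the block runs on its own, and it is meant to mirror is_iterable_class(cls, strict=True):

```python
import typing


def is_iterable_hint(hint: object) -> bool:
    """Hypothetical helper: True if the hint's origin class defines a callable __iter__."""
    # For aliases like list[int], get_origin() returns the bare class;
    # for plain classes it returns None, so fall back to the hint itself.
    cls = typing.get_origin(hint) or hint
    if not isinstance(cls, type):
        return False
    # Same MRO walk as _implements_instance_method, skipping `object`.
    return any(
        callable(base.__dict__.get("__iter__"))
        for base in cls.__mro__[:-1]
    )


print(is_iterable_hint(list[int]))         # True
print(is_iterable_hint(typing.List[int]))  # True
print(is_iterable_hint(dict[str, int]))    # True
print(is_iterable_hint(int))               # False
```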
    

    There are still a number of pitfalls here though.

    A little demo:

    from collections.abc import Iterable, Iterator
    from typing import Generic, TypeVar
    
    
    T = TypeVar("T")
    
    
    class MyIter(Iterable[T]):
        def __init__(self, *items: T) -> None:
            self._items = items
    
        def __iter__(self) -> Iterator[T]:
            return iter(self._items)
    
    
    class SubIter(MyIter[T]):
        pass
    
    
    class IdxIter(Generic[T]):
        def __init__(self, *items: T) -> None:
            self._items = items
    
        def __getitem__(self, idx: int) -> T:
            return self._items[idx]
    
    
    class Foo:
        __iter__ = "bar"
    
    
    class Bar:
        @classmethod
        def __iter__(cls) -> Iterator[int]:
            return iter(range(5))
    
    
    class Baz:
        def __iter__(self) -> int:
            return 1
    
    
    def _implements_instance_method(cls: type, name: str) -> type | None:
        """
        Checks whether a class implements a certain instance method.
    
        Args:
            cls: Class to check; base classes (except `object`) are also checked
            name: The name of the instance method that `cls` should have
    
        Returns:
            The earliest class in the MRO of `cls` implementing the instance
            method with the provided `name` or `None` if none of them do.
        """
        for base in cls.__mro__[:-1]:  # do not check `object`
            if name in base.__dict__ and callable(base.__dict__[name]):
                return base
        return None
    
    
    def is_iterable_class(cls: type, strict: bool = True) -> bool:
        """
        Returns `True` only if `cls` implements the iterable protocol.
    
        Args:
            cls:
                The class to check for being iterable
            strict (optional):
                If `True` (default), only classes that implement (or inherit)
                the `__iter__` instance method are considered iterable;
                if `False`, classes supporting `__getitem__` subscripting
                will also be considered iterable.
                -> https://docs.python.org/3/glossary.html#term-iterable
    
        Returns:
            `True` if `cls` is to be considered iterable; `False` otherwise.
        """
        if not isinstance(cls, type):
            return False
        if _implements_instance_method(cls, "__iter__") is None:
            if strict:
                return False
            return _implements_instance_method(cls, "__getitem__") is not None
        return True
    
    
    if __name__ == '__main__':
        import numpy as np
        print(f"{is_iterable_class(MyIter)=}")
        print(f"{is_iterable_class(SubIter)=}")
        print(f"{is_iterable_class(IdxIter)=}")
        print(f"{is_iterable_class(IdxIter, strict=False)=}")
        print(f"{is_iterable_class(Foo)=}")
        print(f"{is_iterable_class(Bar)=}")
        print(f"{is_iterable_class(Baz)=}")
        print(f"{is_iterable_class(np.ndarray)=}")
        try:
            iter(np.array(1))
        except TypeError as e:
            print(repr(e))
    

    The output:

    is_iterable_class(MyIter)=True
    is_iterable_class(SubIter)=True
    is_iterable_class(IdxIter)=False
    is_iterable_class(IdxIter, strict=False)=True
    is_iterable_class(Foo)=False
    is_iterable_class(Bar)=False
    is_iterable_class(Baz)=True
    is_iterable_class(np.ndarray)=True
    TypeError('iteration over a 0-d array')
    

    You should immediately notice that my function returns True for Baz, even though Baz clearly misbehaves by returning an integer instead of an Iterator. This demonstrates that the contract of the Iterable protocol ends at the definition of __iter__ and does not cover what it returns. One would reasonably assume that __iter__ must return an Iterator, but the class is technically still an Iterable even if it doesn't.
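Actually calling iter() on a Baz instance shows where that contract is enforced: iter() itself checks the return value of __iter__ and raises TypeError for anything that is not an iterator. A short sketch that re-declares Baz so it runs standalone:

```python
class Baz:
    def __iter__(self):
        return 1  # wrong on purpose: should return an iterator


try:
    iter(Baz())
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```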

    Another great practical example of this was pointed out by @user2357112: numpy.ndarray is certainly iterable, both by contract and in practice in most situations. However, when it is a 0-d array (i.e. a scalar), its __iter__ method raises a TypeError, because iterating over a scalar makes little sense.

    The non-strict version of the function is even less practical, since a class could easily and sensibly implement __getitem__, but not in the way iter() expects (integer indexes starting at 0, terminated by an IndexError).
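A sketch of that pitfall, using a hypothetical mapping-like class keyed by strings. Because it defines __getitem__, the non-strict check would accept it and iter() succeeds, but the first next() calls cfg[0], which is not a valid key, so the KeyError escapes instead of iteration ending cleanly:

```python
class Config:
    """Hypothetical mapping-like class keyed by strings, with no __iter__."""
    def __init__(self, **values: str) -> None:
        self._values = values

    def __getitem__(self, key: str) -> str:
        return self._values[key]


cfg = Config(host="localhost")
it = iter(cfg)  # succeeds: iter() only needs __getitem__ to exist
try:
    next(it)    # calls cfg[0] under the sequence fallback
    failed_with = None
except KeyError as exc:
    failed_with = exc
print(repr(failed_with))  # KeyError(0)
```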

    I see no way around those issues, and even the Python documentation for collections.abc.Iterable states that

    the only reliable way to determine whether an object is iterable is to call iter(obj).
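When an instance is available, or cheap to create, the documented approach boils down to simply trying iter(). A minimal EAFP-style sketch; the name is_iterable_instance is hypothetical:

```python
def is_iterable_instance(obj: object) -> bool:
    """Hypothetical helper: True if iter(obj) succeeds.

    Note: this really calls __iter__ (or sets up the __getitem__ fallback),
    so it may have side effects, and it cannot detect errors that only
    surface on the first next() call.
    """
    try:
        iter(obj)
    except TypeError:
        return False
    return True


print(is_iterable_instance([1, 2]))  # True
print(is_iterable_instance("abc"))   # True
print(is_iterable_instance(42))      # False
```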


    If it is actually the Iterator that you are interested in, you could of course extend the function to run the same check for __next__ that it runs for __iter__. But keep in mind that this will immediately exclude all the built-in collection types like list and dict, because they don't implement __next__. Again, referring to collections.abc, you can see that all Collection subtypes inherit only from Iterable, not from Iterator.
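A sketch of that extension, assuming an Iterator class must define a callable __iter__ and a callable __next__ somewhere in its MRO (excluding object). Both function names are hypothetical, and the MRO walk is restated so the block runs on its own:

```python
def _defines_callable(cls: type, name: str) -> bool:
    # Same idea as _implements_instance_method: walk the MRO except `object`.
    return any(callable(base.__dict__.get(name)) for base in cls.__mro__[:-1])


def is_iterator_class(cls: type) -> bool:
    """Hypothetical extension: require both __iter__ and __next__."""
    return isinstance(cls, type) and all(
        _defines_callable(cls, name) for name in ("__iter__", "__next__")
    )


print(is_iterator_class(list))                 # False: list has no __next__
print(is_iterator_class(dict))                 # False
print(is_iterator_class(type(iter([]))))       # True: list_iterator
print(is_iterator_class(type(x for x in ())))  # True: generator
```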

    Hope this helps.