Tags: python, python-typing

How to define a nested generic type in Python?


I have a function that should flatten a list nested to an arbitrary depth:

from collections.abc import Sequence
from typing import TypeVar, cast

T = TypeVar("T")
type Nested[T] = Sequence[T | Sequence[Nested]]

def flatten(seq: Nested[T]) -> list[T]:
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, Sequence):
            flattened.extend(flatten(cast(Nested[T], elem)))
        else:
            flattened.append(elem)
    return flattened

Now if I pass in, for example, a list of lists, the inferred result type is still a list of lists:


test: list[list[str]]

# (parameter) flattened: list[list[str]]
flattened = flatten(test)

It looks like it always returns the same type that is passed in, as if whatever is inside the outermost list is taken to be the generic type T. Why is that? How can I define this nested (recursive) type so that flatten works and the type hints come out like this?


test: list[list[str]]
test2: list[list[list[list[int]]]]

# (parameter) flattened: list[str]
flattened = flatten(test)

# (parameter) flattened2: list[int]
flattened2 = flatten(test2)

I am using Python 3.12.


Edit:

Just after posting this I found a small mistake in my definitions. The alias has to refer to Nested[T] recursively, like this:

T = TypeVar("T")
type Nested[T] = Sequence[T | Nested[T]] # <--- Fix here!!

def flatten(seq: Nested[T]) -> list[T]:
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, Sequence):
            flattened.extend(flatten(cast(Nested[T], elem)))
        else:
            flattened.append(elem)
    return flattened

Another problem was that in my actual code I was assigning the flattened list back to a variable that was already declared as a list of lists (or deeper), so the hover showed the declared type. Assigning the flattened list to a new variable shows the inferred type correctly, except in the case of bytes. I guess that is because bytes is itself a sequence of int (a Sequence[int], and so also an Iterable[int]):

test: list[list[str]]
bytes_test: list[list[bytes]]

# test is already defined as list[list[str]]
# (parameter) test: list[list[str]]
test = flatten(test)

# type for flattened is inferred so it's list[str]
# (parameter) flattened: list[str]
flattened = flatten(test)

# bytes is itself a Sequence[int], so it is flattened as well
# (parameter) flattened_bytes: list[int]
flattened_bytes = flatten(bytes_test)

Now I wonder how I can keep bytes values intact, so that the result is correctly typed as list[bytes].
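The bytes result can also be reproduced at runtime with nothing but the standard library. The sketch below is only an illustration: `flatten_sequences` is a hypothetical, untyped stand-in for the `flatten` above, not the original function.

```python
from collections.abc import Sequence
from typing import Any


def flatten_sequences(seq: Sequence[Any]) -> list[Any]:
    """Hypothetical untyped stand-in: recurse into anything that is a Sequence."""
    flattened: list[Any] = []
    for elem in seq:
        if isinstance(elem, Sequence):
            flattened.extend(flatten_sequences(elem))
        else:
            flattened.append(elem)
    return flattened


# bytes counts as a Sequence at runtime, and iterating it yields ints.
print(isinstance(b"ab", Sequence))           # True
print(list(b"ab"))                           # [97, 98]

# So every bytes leaf is split into its integer byte values.
print(flatten_sequences([[b"ab"], [b"c"]]))  # [97, 98, 99]
```

This matches what the type checker reports: a bytes value is treated as one more level of nesting rather than as a leaf.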


Edit2:

There also seems to be infinite recursion at runtime when the elements are str, because a one-character str is a sequence whose only element is itself. bytes has a related problem: it is iterable too, so it gets split into ints. It looks like the best way to do this kind of generic flattening is to invert the isinstance check, like this:

T = TypeVar("T")
type Nested[T] = Sequence[T | Nested[T]]

def flatten(seq: Nested[T]) -> list[T]:
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, T): # <-- This doesn't work. Need custom generic isinstance checker
            flattened.append(elem)
        else:
            flattened.extend(flatten(elem))
    return flattened

But this would need some custom generic isinstance checker, because T is a TypeVar and cannot be passed to isinstance at runtime.
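One way to get a check of that kind, sketched here as an assumption rather than an established recipe, is to pass the leaf class explicitly so that `isinstance` has a real class to test against. `flatten_as` and its `leaf` parameter are hypothetical names, and the input is annotated loosely as `Iterable[Any]` to keep the sketch short.

```python
from collections.abc import Iterable
from typing import Any, TypeVar

T = TypeVar("T")


def flatten_as(seq: Iterable[Any], leaf: type[T]) -> list[T]:
    """Flatten seq, treating instances of leaf as atoms (hypothetical helper)."""
    flattened: list[T] = []
    for elem in seq:
        if isinstance(elem, leaf):
            # The leaf check comes first, so str and bytes are never iterated.
            flattened.append(elem)
        else:
            # Anything that is not a leaf is assumed to be a nested iterable.
            flattened.extend(flatten_as(elem, leaf))
    return flattened


print(flatten_as([["ab", ["cd"]], ["ef"]], str))  # ['ab', 'cd', 'ef']
print(flatten_as([[b"ab"], [[b"cd"]]], bytes))    # [b'ab', b'cd']
```

Because the return type is tied to the `leaf` argument, the result is annotated as `list[T]` without relying on the recursive alias. The trade-off is that an element which is neither a `leaf` instance nor iterable raises `TypeError` when the function tries to iterate it.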


Solution

  • Because Python types are not well-ordered (in the mathematical sense) with respect to iteration, i.e.

    int < Sequence[int] < Sequence[Sequence[int]] < ...
    

    but

    ... < str < str < str < Sequence[str] < Sequence[Sequence[str]] < ...
    

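    you need to check for str (and bytes) explicitly in addition to testing for iterability, rather than testing whether a value has the expected non-iterable type. Iterating a str yields one-character str values, which are themselves iterable, so the chain never bottoms out at a non-iterable element.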

    Something like

    def flatten(seq: Nested[T]) -> list[T]:
        flattened: list[T] = []
    
        for elem in seq:
            try:
                # Special case because str and bytes
                # are iterable, but you want to treat them
                # as non-iterables.
                if isinstance(elem, (str, bytes)):
                    raise TypeError()
    
                iter(elem)  # raises TypeError if elem is not iterable
            except TypeError:
                # Non-iterable types, str, and bytes
                flattened.append(elem)
            else:
                # Iterable types except str and bytes
                flattened.extend(flatten(elem))
    
        return flattened
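The same special-casing can be written without raising an exception for control flow. The following is a runtime-only sketch under stated assumptions: `flatten_iterables` and `ATOMIC_TYPES` are hypothetical names, and the annotations are loosened to `Iterable[Any]` so that the block runs on its own without the `Nested[T]` alias.

```python
from collections.abc import Iterable
from typing import Any

ATOMIC_TYPES = (str, bytes)  # iterable types treated as leaves (assumed set)


def flatten_iterables(seq: Iterable[Any]) -> list[Any]:
    """Flatten nested iterables, keeping str and bytes values whole."""
    flattened: list[Any] = []
    for elem in seq:
        if isinstance(elem, Iterable) and not isinstance(elem, ATOMIC_TYPES):
            # Iterable types except str and bytes
            flattened.extend(flatten_iterables(elem))
        else:
            # Non-iterable types, str, and bytes
            flattened.append(elem)
    return flattened


# A one-character str contains itself, which is why the special case is needed.
print("a"[0] == "a")                                 # True

print(flatten_iterables([["ab", ["cd"]], ("ef",)]))  # ['ab', 'cd', 'ef']
print(flatten_iterables([[b"ab"], [[1, 2], 3]]))     # [b'ab', 1, 2, 3]
```

One difference from the `iter(elem)` version: `isinstance(elem, Iterable)` only detects objects that define `__iter__`, while `iter()` also accepts objects that are iterable through `__getitem__` alone.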