Tags: python, caching, python-dataclasses

Custom cache with iterator does not work as intended


I have the following class, where:

• iterable is the argument passed in, for example range(20);
• n_max is an optional value that limits the number of elements the cache may hold;
• iterator is a field that is initialized from iterable;
• cache is the list I am trying to fill;
• finished is a bool that signals whether the iterator is "empty".

Here is an example of the intended behavior:

>>> iterable = range(20)
>>> cachedtuple = CachedTuple(iterable)
>>> print(cachedtuple[0])
0
>>> print(len(cachedtuple.cache))
1
>>> print(cachedtuple[10])
10
>>> print(len(cachedtuple.cache))
11
>>> print(len(cachedtuple))
20
>>> print(len(cachedtuple.cache))
20
>>> print(cachedtuple[25])


from dataclasses import dataclass, field
from typing import Iterable, Iterator, Optional

@dataclass
class CachedTuple:
    iterable: Iterable = field(init=True)
    n_max: Optional[int] = None
    iterator: Iterator = field(init=False)
    cache: list = field(default_factory=list)
    finished: bool = False

    def __post_init__(self):
        self.iterator = iter(self.iterable)

    def cache_next(self):
        
        if self.n_max and self.n_max <= len(self.cache):
            self.finished = True
        else:
            try:
                nxt = next(self.iterator)
                self.cache.append(nxt)

            except StopIteration:
                self.finished = True


    def __getitem__(self, item: int):

        match item:
            case item if type(item) != int:
                raise IndexError

            case item if item < 0:
                raise IndexError

            case item if self.finished or self.n_max and item > self.n_max:
                raise IndexError(f"Index {item} out of range")

            case item if item >= len(self.cache):
                while item - len(self.cache) >= 0:
                    self.cache_next()

                return self.__getitem__(item)

            case item if item < len(self.cache):
                return self.cache[item]


    def __len__(self):

        while not self.finished:
            self.cache_next()
        return len(self.cache)

Although this code is certainly not good, it at least works in almost every scenario, except, for example, with Python's range function. If I run

cachedtuple = CachedTuple(range(20))
for element in cachedtuple:
    print(element)

I get the elements up to 19, and then the program loops forever. I suspect part of the problem is that my code never raises StopIteration, but I am kind of lost on how to fix this mess.


Solution

  • Your bug is due to these lines:

    case item if item >= len(self.cache):
        while item - len(self.cache) >= 0:
            self.cache_next()
    

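    Basically, CachedTuple((1,2,3))[50] will loop forever: 50 is at least the length of the cache, and once the iterator is exhausted, self.cache_next() only sets self.finished = True without appending anything, so the while condition never becomes false. Your for loop reaches the same case at index 20, because a class without __iter__ is iterated by calling __getitem__ with increasing indices.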
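Under Python's legacy sequence-iteration protocol, a for loop over an object that defines __getitem__ but no __iter__ calls __getitem__ with 0, 1, 2, ... and stops only when IndexError is raised; StopIteration plays no part. The Recorder class below is hypothetical (for illustration only) and simply logs which indices the loop requests:

```python
class Recorder:
    """Hypothetical class with __getitem__ but no __iter__."""

    def __init__(self, data):
        self.data = data
        self.requested = []  # indices the for loop asked for, in order

    def __getitem__(self, index):
        self.requested.append(index)
        if index >= len(self.data):
            # Raising IndexError is what ends a for loop under this protocol.
            raise IndexError(index)
        return self.data[index]


r = Recorder("abc")
for ch in r:
    print(ch)
print(r.requested)  # [0, 1, 2, 3]
```

The loop asks for one index past the end, so __getitem__ must raise IndexError for that index; if it loops instead, the for loop never finishes.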

    A simple change adding a self.finished check will work:

    case item if item >= len(self.cache):
        while item - len(self.cache) >= 0 and not self.finished:
            self.cache_next()
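To check the fix, here is a sketch of the question's class with that one-line change applied. As an assumption for portability, the match statement is rewritten as an equivalent if/elif chain (match needs Python 3.10+); the logic is otherwise the question's:

```python
from dataclasses import dataclass, field
from typing import Iterable, Iterator, Optional


@dataclass
class CachedTuple:
    iterable: Iterable
    n_max: Optional[int] = None
    iterator: Iterator = field(init=False)
    cache: list = field(default_factory=list)
    finished: bool = False

    def __post_init__(self):
        self.iterator = iter(self.iterable)

    def cache_next(self):
        if self.n_max and self.n_max <= len(self.cache):
            self.finished = True
        else:
            try:
                self.cache.append(next(self.iterator))
            except StopIteration:
                self.finished = True

    def __getitem__(self, item: int):
        # Same checks as the question's match statement, as if/elif.
        if type(item) != int or item < 0:
            raise IndexError
        if self.finished or self.n_max and item > self.n_max:
            raise IndexError(f"Index {item} out of range")
        if item >= len(self.cache):
            # The fix: stop looping once the iterator is exhausted.
            while item - len(self.cache) >= 0 and not self.finished:
                self.cache_next()
            return self.__getitem__(item)
        return self.cache[item]

    def __len__(self):
        while not self.finished:
            self.cache_next()
        return len(self.cache)


ct = CachedTuple(range(20))
first_pass = []
for element in ct:  # now terminates after 19
    first_pass.append(element)
print(first_pass == list(range(20)))  # True

second_pass = []
for element in ct:  # yields nothing: finished is already True
    second_pass.append(element)
```

Note that a second loop over the same object yields nothing: once finished is True, the self.finished branch raises IndexError before the cached-index branch is ever reached. Serving indices below len(self.cache) before that check would fix it, which is one more reason to restructure the class.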
    

    However, I believe your code has numerous other issues, and I think you can improve it tremendously:

    1. Drop the match statement. Every case is just a guard, so it adds nothing over a plain if/elif chain.
    2. Implement iteration using __iter__ instead of relying on the legacy iteration protocol, in which Python calls __getitem__ with 0, 1, 2, ... until IndexError is raised.
    3. Inherit from collections.abc.Sequence and adhere to the Sequence protocol.
    4. Drop the dataclass. This is not a dataclass. You seem to enjoy the delightful new language features, but unfortunately none of them are relevant and it's causing your code to be longer, less clear, and not working as intended.

    Remember, simple readable code is infinitely more important than using new language features.
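Points 2 and 3 can be illustrated with a small, hypothetical Squares class (not part of the question): subclassing collections.abc.Sequence and defining only __getitem__ and __len__ provides __iter__, __contains__, __reversed__, index and count as mixin methods. The mixin __iter__ still works by calling __getitem__ until IndexError, so a lazily cached class would typically override __iter__ as well.

```python
from collections.abc import Sequence


class Squares(Sequence):
    """Hypothetical example: the first n square numbers, computed on demand."""

    def __init__(self, n: int) -> None:
        self._n = n

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Let range do the slice arithmetic (start/stop/step, negatives).
            return [i * i for i in range(self._n)[index]]
        if index < 0:
            index += self._n
        if not 0 <= index < self._n:
            raise IndexError("index out of range")
        return index * index


s = Squares(5)
print(list(s))            # [0, 1, 4, 9, 16]
print(9 in s)             # True
print(s.index(16))        # 4
print(list(reversed(s)))  # [16, 9, 4, 1, 0]
```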


    I took the liberty of spending a few hours writing example code that complies with collections.abc.Sequence. Enjoy!

    from collections.abc import Sequence
    import itertools
    from typing import Iterable, Iterator, Optional, TypeVar, overload
    
    _T_co = TypeVar("_T_co", covariant=True)
    
    class CachedIterable(Sequence[_T_co]):
        def __init__(self, iterable: Iterable[_T_co], *, max_length: Optional[int] = None) -> None:
            self._cache: list[_T_co] = []
            
            if max_length is not None:
                if max_length <= 0:
                    raise ValueError('max_length must be > 0')
                iterable = itertools.islice(iterable, max_length)
            else:
                try:
                    # Attempt to optimize and get a length.
                    max_length = len(iterable)  # type: ignore
                except TypeError:
                    max_length = None
    
            self._max_length = max_length
            self._iterator: Optional[Iterator] = iter(iterable)
        
        def __repr__(self) -> str:
            return (f'<{self.__class__.__name__} {self._cache!r}'
                    f'{"+" if self._iterator else ""}>')
        
        def _exhaust_iterator(self) -> None:
            """Fully exhaust the iterator."""
            assert self._iterator
            try:
                self._cache.extend(self._iterator)
            finally:
                self._iterator = None
    
        def _advance_iterator(self, n: int) -> None:
            """Attempt to advance the iterator by n steps.
    
            May advance by less than n steps if the iterator is exhausted.
            """
            assert self._iterator
            
            pre_advance_length = len(self._cache)
    
            try:
                self._cache.extend(itertools.islice(self._iterator, n))
            except Exception:
                # Iterator threw an exception.
                self._iterator = None
                raise
    
            # If iterator exhausted, clear it.
            if pre_advance_length + n > len(self._cache):
                self._iterator = None
            
        def _grow_cache(self, size: int) -> None:
            """Attempt to grow the cache to at least `size` items.
            
            May grow to less than size if the iterator is exhausted.
            """
            if size <= len(self._cache):
                return
    
            if self._max_length and size >= self._max_length:
                self._exhaust_iterator()
                return
            
            self._advance_iterator(size - len(self._cache))
        
        @overload
        def __getitem__(self, i: int) -> _T_co: ...
    
        @overload
        def __getitem__(self, s: slice) -> Sequence[_T_co]: ...
            
        def __getitem__(self, index):
            if not isinstance(index, (slice, int)):
                raise TypeError(f'index must be int or slice, not {index!r}')
    
            if not self._iterator:
                return self._cache[index]
    
            if isinstance(index, slice):
                step = 1 if index.step is None else index.step
                bounds = (index.start, index.stop)
                # The bound that can reach the end of the sequence is stop for
                # a positive step and start for a negative one. If it is
                # omitted, or any bound counts from the end, exhaust the iterator.
                far_bound = index.stop if step > 0 else index.start
                if far_bound is None or any(b is not None and b < 0 for b in bounds):
                    self._exhaust_iterator()
                else:
                    self._grow_cache(max(b or 0 for b in bounds) + 1)

                return self._cache[index]
    
            # Asking for a number beyond the limit.
            if self._max_length is not None and index >= self._max_length:
                raise IndexError(f'index {index} out of range')
    
            # If we're counting from the end, exhaust the iterator.
            if index < 0:
                self._exhaust_iterator()
            else:
                self._grow_cache(index + 1)
    
            return self._cache[index]
        
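        def __iter__(self) -> Iterator[_T_co]:
            # Walk the cache by index instead of pulling from self._iterator
            # directly, so several iterators (or __getitem__ calls made while
            # iterating) can share the cache without skipping items.
            i = 0
            while True:
                if i >= len(self._cache):
                    if not self._iterator:
                        return
                    # Pull one more item; exceptions raised by the source
                    # iterator propagate to the caller.
                    self._grow_cache(i + 1)
                    if i >= len(self._cache):
                        return  # The iterator is exhausted.
                # Yield outside any try block so GeneratorExit and other
                # gen.throw() exceptions are not captured.
                yield self._cache[i]
                i += 1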
    
    
        def __len__(self) -> int:
            # TODO: Can optimize for known lengths.
            if not self._iterator:
                return len(self._cache)
    
            self._exhaust_iterator()
            return len(self._cache)
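_advance_iterator relies on two properties of itertools.islice: it pulls at most n items from the shared underlying iterator (so repeated calls resume where the previous one stopped), and it yields fewer than n items when that iterator runs out. A standalone sketch of just that technique:

```python
import itertools

shared = iter(range(5))

# islice consumes at most n items from the *same* underlying iterator,
# so consecutive calls continue where the previous one stopped.
first = list(itertools.islice(shared, 3))   # [0, 1, 2]
second = list(itertools.islice(shared, 3))  # [3, 4]: only 2 items were left

# Getting fewer items than requested means the iterator is exhausted; this
# is what the pre_advance_length + n > len(self._cache) check detects.
exhausted = len(second) < 3
print(first, second, exhausted)  # [0, 1, 2] [3, 4] True
```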