I have the following class, where `iterable` is the passed argument (for example `range(20)`), `n_max` is an optional value that limits the number of elements the cache should hold, `iterator` is a field that gets initialized from the iterable, `cache` is the list I am trying to fill, and `finished` is a bool that signals whether the iterator is "empty" or not. Here is an example input:
>>> iterable = range(20)
>>> cachedtuple = CachedTuple(iterable)
>>> print(cachedtuple[0])
0
>>> print(len(cachedtuple.cache))
1
>>> print(cachedtuple[10])
10
>>> print(len(cachedtuple.cache))
11
>>> print(len(cachedtuple))
20
>>> print(len(cachedtuple.cache))
20
>>> print(cachedtuple[25])
@dataclass
class CachedTuple:
    """Index-addressable view over an iterable that caches items lazily.

    Items are pulled from the underlying iterator one at a time by
    cache_next() and appended to ``cache``; __getitem__ and __len__ drive
    that pulling on demand.
    """

    # The source iterable supplied by the caller.
    iterable: Iterable = field(init=True)
    # Optional cap on how many items the cache may hold.
    n_max: Optional[int] = None
    # Iterator over ``iterable``; created in __post_init__, not by the caller.
    iterator: Iterator = field(init=False)
    # Items pulled from the iterator so far, in order.
    cache: list = field(default_factory=list)
    # Set to True once the iterator is exhausted or the n_max cap is reached.
    finished: bool = False

    def __post_init__(self):
        """Create the iterator that cache_next() consumes."""
        self.iterator = iter(self.iterable)

    def cache_next(self):
        """Pull one more item into the cache, or mark the object finished.

        ``finished`` becomes True when the cache already holds ``n_max``
        items or when the iterator raises StopIteration; in both cases
        nothing is appended to the cache.
        """
        # NOTE(review): the truthiness test means n_max == 0 is treated the
        # same as None, i.e. as "no limit".
        if self.n_max and self.n_max <= len(self.cache):
            self.finished = True
        else:
            try:
                nxt = next(self.iterator)
                self.cache.append(nxt)
            except StopIteration:
                self.finished = True

    def __getitem__(self, item: int):
        """Return the item at a non-negative integer index, caching up to it.

        Raises:
            IndexError: for non-int or negative indexes, for indexes above
                ``n_max``, and for any index once ``finished`` is True.
        """
        match item:
            # Only exact ints are accepted (subclasses such as bool are not).
            case item if type(item) != int:
                raise IndexError
            # Negative indexes are not supported.
            case item if item < 0:
                raise IndexError
            # NOTE(review): once ``finished`` is True this rejects every
            # index, including ones that are already in the cache. The
            # ``>`` comparison also lets item == n_max fall through.
            case item if self.finished or self.n_max and item > self.n_max:
                raise IndexError(f"Index {item} out of range")
            case item if item >= len(self.cache):
                # NOTE(review): this loop never terminates when the iterator
                # runs out (or the n_max cap is hit) before the cache reaches
                # ``item``: cache_next() then only sets ``finished`` and the
                # cache length stops changing.
                while item - len(self.cache) >= 0:
                    self.cache_next()
                # Re-dispatch now that the cache is long enough.
                return self.__getitem__(item)
            case item if item < len(self.cache):
                return self.cache[item]

    def __len__(self):
        """Consume the rest of the iterator and return the cache length."""
        while not self.finished:
            self.cache_next()
        return len(self.cache)
Although this code is certainly not good, at least it works in almost every scenario, except, for example, when using Python's range function. If I use, for example,
# CachedTuple defines no __iter__, so the for loop falls back to calling
# __getitem__ with 0, 1, 2, ... and stops only when IndexError is raised.
cachedtuple = CachedTuple(range(20))
for element in cachedtuple:
    print(element)
I get the elements up to 19, and then the program loops infinitely. I think one problem might be that I have no `raise StopIteration` in my code, so I am kind of lost as to how to fix this mess.
Your bug is due to these lines:
case item if item >= len(self.cache):
while item - len(self.cache) >= 0:
self.cache_next()
Basically, CachedTuple((1,2,3))[50]
will loop indefinitely, as 50
is larger than the length of the cache, and self.cache_next()
won't generate any new values.
A simple change adding a self.finished
check will work:
case item if item >= len(self.cache):
while item - len(self.cache) >= 0 and not self.finished:
self.cache_next()
I do however believe you have numerous other issues with the code, and I think you can improve it tremendously:
- Implement `__iter__` instead of relying on the old iteration mechanism of `__getitem__`.
- Inherit from `collections.abc.Sequence` and adhere to the Sequence protocol.

Remember, simple readable code is infinitely more important than using new language features.
I took the liberty of spending a few hours creating example code that complies with `collections.abc.Sequence`. Enjoy!
from collections.abc import Sequence
import itertools
from typing import Iterable, Iterator, Optional, TypeVar, overload
_T_co =TypeVar("_T_co", covariant=True)
class CachedIterable(Sequence[_T_co]):
    """Read-only sequence that lazily caches the items of an iterable.

    Items are pulled from the underlying iterator only when an index, a
    slice, an iteration or ``len()`` requires them, and are kept in an
    internal list so that the iterable is consumed at most once.
    """

    def __init__(
        self,
        iterable: Iterable[_T_co],
        *,
        max_length: Optional[int] = None,
    ) -> None:
        """Wrap *iterable*, optionally keeping only its first items.

        Args:
            iterable: Source of the items.
            max_length: Upper bound on the number of items to take. When
                omitted, ``len(iterable)`` is used if the iterable has one.

        Raises:
            ValueError: If *max_length* is given and is not positive.
        """
        self._cache: list[_T_co] = []
        if max_length is not None:
            if max_length <= 0:
                raise ValueError('max_length must be > 0')
            iterable = itertools.islice(iterable, max_length)
        else:
            try:
                # Attempt to optimize and get a length.
                max_length = len(iterable)  # type: ignore
            except TypeError:
                max_length = None
        # Upper bound on the final length, or None when unknown.
        self._max_length = max_length
        # Set to None as soon as the iterator is known to be exhausted.
        self._iterator: Optional[Iterator[_T_co]] = iter(iterable)

    def __repr__(self) -> str:
        """Show the cached items, with a ``+`` while more may follow."""
        return (f'<{self.__class__.__name__} {self._cache!r}'
                f'{"+" if self._iterator is not None else ""}>')

    def _exhaust_iterator(self) -> None:
        """Pull every remaining item into the cache (no-op when done)."""
        if self._iterator is None:
            return
        try:
            self._cache.extend(self._iterator)
        finally:
            # Whether it finished or raised, the iterator is unusable now.
            self._iterator = None

    def _advance_iterator(self, n: int) -> None:
        """Attempt to advance the iterator by n steps.

        May advance by less than n steps if the iterator is exhausted.
        """
        if self._iterator is None:
            return
        pre_advance_length = len(self._cache)
        try:
            self._cache.extend(itertools.islice(self._iterator, n))
        except Exception:
            # Iterator threw an exception.
            self._iterator = None
            raise
        # Fewer than n new items means the iterator ran dry; clear it.
        if pre_advance_length + n > len(self._cache):
            self._iterator = None

    def _grow_cache(self, size: int) -> None:
        """Attempt to grow the cache to hold at least *size* items.

        May grow to less than *size* if the iterator is exhausted.
        """
        if size <= len(self._cache):
            return
        # ``is not None`` so that a known length of 0 is honoured too.
        if self._max_length is not None and size >= self._max_length:
            self._exhaust_iterator()
            return
        self._advance_iterator(size - len(self._cache))

    def _fill_for_slice(self, index: slice) -> None:
        """Cache enough items for *index* to be answered from the cache."""
        step = 1 if index.step is None else index.step
        if step == 0:
            # Same error a list raises, but before consuming anything.
            raise ValueError('slice step cannot be zero')
        # ``far`` is the bound on the high-index side of the slice.
        if step > 0:
            near, far = index.start, index.stop
        else:
            near, far = index.stop, index.start
        # An omitted far bound means "up to the end", and negative bounds
        # count from the end: both need the final length.
        if far is None or far < 0 or (near is not None and near < 0):
            self._exhaust_iterator()
        elif step > 0:
            # stop is exclusive: indexes 0 .. stop - 1 are needed.
            self._grow_cache(far)
        else:
            # start is inclusive: indexes 0 .. start are needed.
            self._grow_cache(far + 1)

    @overload
    def __getitem__(self, i: int) -> _T_co: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[_T_co]: ...

    def __getitem__(self, index):
        """Return the item or slice at *index*, caching items as needed.

        Raises:
            TypeError: If *index* is neither an int nor a slice.
            IndexError: If an int *index* is out of range.
            ValueError: If a slice has a step of zero.
        """
        if not isinstance(index, (slice, int)):
            raise TypeError(f'index must be int or slice, not {index!r}')

        # Everything is cached already: defer to the list.
        if self._iterator is None:
            return self._cache[index]

        if isinstance(index, slice):
            self._fill_for_slice(index)
            return self._cache[index]

        # Asking for a number beyond the limit.
        if self._max_length is not None and index >= self._max_length:
            raise IndexError(f'index {index} out of range')

        # If we're counting from the end, exhaust the iterator.
        if index < 0:
            self._exhaust_iterator()
        else:
            self._grow_cache(index + 1)
        return self._cache[index]

    def __iter__(self) -> Iterator[_T_co]:
        """Yield every item, pulling new ones from the iterator on demand.

        Each call keeps its own position, so nested or interleaved
        iterations over the same object all see every item.
        """
        position = 0
        while True:
            # Serve from the cache first; another consumer may have
            # extended it since this generator last ran.
            if position < len(self._cache):
                yield self._cache[position]
                position += 1
                continue
            if self._iterator is None:
                return
            # The yield stays outside this try block to prevent capturing
            # GeneratorExit and other gen.throw() exceptions.
            try:
                item = next(self._iterator)
            except StopIteration:
                self._iterator = None
                return
            except BaseException:
                # Iterator threw an exception.
                self._iterator = None
                raise
            self._cache.append(item)

    def __len__(self) -> int:
        """Return the number of items, consuming the whole iterator."""
        # TODO: Can optimize for known lengths.
        self._exhaust_iterator()
        return len(self._cache)