Search code examples
pythonpython-3.xcachingdata-structureslru

LRU cache implementation in Python using priority queue


I have the following code for implementing LRU cache.

from __future__ import annotations

from time import time

import heapq

from typing import List, Dict, TypeVar, Generic, Optional, Tuple

# LRU Cache
T = TypeVar('T')


class Element:
    def __init__(self, key: str) -> None:
        self.key = key
        self.unixtime = time()

    def __lt__(self, other: Element) -> bool:
        return self.unixtime < other.unixtime

    def __eq__(self, other: Element) -> bool:
        return self.unixtime == other.unixtime

    def __gt__(self, other: Element) -> bool:
        return (not self.unixtime < other.unixtime) and self.unixtime != other.unixtime

    def __repr__(self) -> str:
        return f'({self.key}, {self.unixtime})'


class PriorityQueue(Generic[T], list):
    def __init__(self) -> None:
        self._data: List[Optional[T]] = []
        super().__init__()

    @property
    def is_empty(self) -> bool:
        return not self._data

    def push(self, v: T) -> None:
        heapq.heappush(self._data, v)

    def popq(self) -> Optional[T]:
        if not self.is_empty:
            return heapq.heappop(self._data)
        else:
            return None

    def __repr__(self) -> str:
        return repr(self._data)


class LRUCache:
    def __init__(self, limit: int) -> None:
        self._data: Dict[str, int] = {}
        self.limit = limit
        self._keyqueue: PriorityQueue[Element] = PriorityQueue()

    def put(self, key: str, value: T) -> None:
        if len(self._data) < self.limit:    # there is still room in the cache
            if key not in self._data:
                self._keyqueue.push(Element(key))
            else:
                correct_key = [item for item in self._keyqueue._data if item.key == key][0]
                ind = self._keyqueue._data.index(correct_key)
                self._keyqueue._data[ind].unixtime = time()
            self._data[key] = value
        else:                               # cache is full
            if key not in self._data:
                out_key = self._keyqueue.popq()
                self._data.pop(out_key.key)
                self._keyqueue.push(Element(key))
            else:
                correct_key = [item for item in self._keyqueue._data if item.key == key][0]
                ind = self._keyqueue._data.index(correct_key)
                self._keyqueue._data[ind].unixtime = time()
            self._data[key] = value

    def get(self, key: str) -> Optional[T]:
        if key in self._data:
            correct_key = [item for item in self._keyqueue._data if item.key == key][0]
            ind = self._keyqueue._data.index(correct_key)
            self._keyqueue._data[ind].unixtime = time()
            return self._data[key]
        else:
            raise KeyError('Key not found in cache')

    def __repr__(self) -> str:
        return repr(self._data)

cache = LRUCache(3)
cache.put('owen', 45)
cache.put('john', 32)
cache.put('box', 4556)

cache.get('owen')
cache.get('owen')

cache.put('new', 9)
cache

I use the unixtime attribute of the Element class to decide the priority. I am using the heapq module together with a list to implement the priority queue. Maybe it is not the most efficient way to implement LRU cache in Python but this is what I came up with.

My problem is that after I access the owen key twice with .get() and then issue cache.put('new', 9) - It should remove john because it is the least recently used. Instead it removes owen.
I have checked _keyqueue and owen has the highest unixtime and john has the lowest, and as I understand, the heapq module in Python uses min_heap so the john record should be replaced by the new value. What am I missing here?


Solution

  • I finally discovered what was the problem: Whenever updating the times, we need to call heapq.heapify() on the heap data after the update. I have also written a slightly more efficient implementation, if anyone needs it:

    from typing import List, Optional, TypeVar, Tuple, Dict, Generic
    
    from time import time
    
    import heapq
    
    T = TypeVar('T')
    
    
    class LRUTuple(tuple):
        def __init__(self, key: Tuple[str]) -> None:
            self.key = key
            self.time = time()
    
        def __lt__(self, other) -> bool:
            return self.time < other.time
    
        def __gt__(self, other) -> bool:
            return not self.time < other.time
    
    
    # test class
    a = LRUTuple(('owen',))
    b = LRUTuple(('aheek',))
    assert b > a
    assert a < b
    
    
    class PriorityQueue(Generic[T]):
        def __init__(self) -> None:
            self._data: List[T] = []
    
        @property
        def is_empty(self) -> bool:
            return not self._data
    
        def add(self, v: T) -> None:
            heapq.heappush(self._data, v)
    
        def pop_queue(self) -> Optional[T]:
            if not self.is_empty:
                return heapq.heappop(self._data)
            else:
                print('Empty Queue')
                return None
    
        def _heapify(self) -> None:
            heapq.heapify(self._data)
    
        def peek(self) -> Optional[T]:
            if not self.is_empty:
                return self._data[0]
            else:
                print('Empty Queue')
                return None
    
        def __repr__(self) -> str:
            return repr(self._data)
    
    
    class LRUCache:
        def __init__(self, limit: int) -> None:
            self._data: Dict[str, T] = {}
            self.limit = limit
            self._keyqueue: PriorityQueue[LRUTuple] = PriorityQueue()
    
        def _update_key_time(self, key: str) -> None:
            self._keyqueue._data[self._keyqueue._data.index((key,))].time = time()
            self._keyqueue._heapify()
    
        def put(self, key: str, value: T) -> None:
            if len(self._keyqueue._data) < self.limit:
                if key not in self._data:
                    self._data[key] = value
                    self._keyqueue.add(LRUTuple((key,)))
                else:
                    self._data[key] = value
                    self._update_key_time(key)
            else:
                # remove lru key
                poped_key = self._keyqueue.pop_queue()
                self._data.pop(poped_key[0])
                self.put(key, value)
    
        def get(self, key: str) -> Optional[T]:
            if key in self._data:
                self._update_key_time(key)
                return self._data[key]
            else:
                print('KeyError: key not found')
                return None
    
        def __repr__(self) -> str:
            return repr([(k[0], k.time) for k in self._keyqueue._data])
    
    
    # test LRUCache usage
    lr = LRUCache(3)
    lr.put('owen', 54)
    lr.put('arwen', 4)
    lr.put('jiji', 98)
    lr._keyqueue.peek()
    lr.get('owen')
    lr._keyqueue.peek()
    lr
    lr.put('bone', 7)   # should replace arwen!
    lr
    lr._keyqueue.peek()