I am failing to implement a numba jitted priority queue.
Heavily plagiarized from the python docs, I am fairly happy with this class.
import itertools
import numba as nb
from numba.experimental import jitclass
from typing import List, Tuple, Dict
from heapq import heappush, heappop
class PurePythonPriorityQueue:
def __init__(self):
self.pq = [] # list of entries arranged in a heap
self.entry_finder = {} # mapping of indices to entries
self.REMOVED = -1 # placeholder for a removed item
self.counter = itertools.count() # unique sequence count
def put(self, item: Tuple[int, int], priority: float = 0.0):
"""Add a new item or update the priority of an existing item"""
if item in self.entry_finder:
self.remove_item(item)
count = next(self.counter)
entry = [priority, count, item]
self.entry_finder[item] = entry
heappush(self.pq, entry)
def remove_item(self, item: Tuple[int, int]):
"""Mark an existing item as REMOVED. Raise KeyError if not found."""
entry = self.entry_finder.pop(item)
entry[-1] = self.REMOVED
def pop(self):
"""Remove and return the lowest priority item. Raise KeyError if empty."""
while self.pq:
priority, count, item = heappop(self.pq)
if item is not self.REMOVED:
del self.entry_finder[item]
return item
raise KeyError("pop from an empty priority queue")
Now I would like to call this from a numba jitted function doing heavy numerical work, so I tried to make this a numba jitclass. Since entries are heterogeneous list in the vanilla python implementation, I figured I should implement other jitclasses as well. However, I am getting a Failed in nopython mode pipeline (step: nopython frontend)
(full trace below).
Here is my attempt:
@jitclass
class Item:
i: int
j: int
def __init__(self, i, j):
self.i = i
self.j = j
@jitclass
class Entry:
priority: float
count: int
item: Item
removed: bool
def __init__(self, p: float, c: int, i: Item):
self.priority = p
self.count = c
self.item = i
self.removed = False
@jitclass
class PriorityQueue:
pq: List[Entry]
entry_finder: Dict[Item, Entry]
counter: int
def __init__(self):
self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
self.counter = 0
def put(self, item: Item, priority: float = 0.0):
"""Add a new item or update the priority of an existing item"""
if item in self.entry_finder:
self.remove_item(item)
self.counter += 1
entry = Entry(priority, self.counter, item)
self.entry_finder[item] = entry
heappush(self.pq, entry)
def remove_item(self, item: Item):
"""Mark an existing item as REMOVED. Raise KeyError if not found."""
entry = self.entry_finder.pop(item)
entry.removed = True
def pop(self):
"""Remove and return the lowest priority item. Raise KeyError if empty."""
while self.pq:
priority, count, item = heappop(self.pq)
entry = heappop(self.pq)
if not entry.removed:
del self.entry_finder[entry.item]
return item
raise KeyError("pop from an empty priority queue")
if __name__ == "__main__":
queue1 = PurePythonPriorityQueue()
queue1.put((4, 5), 5.4)
queue1.put((5, 6), 1.0)
print(queue1.pop()) # Yay this works!
queue2 = PriorityQueue() # Nope
queue2.put(Item(4, 5), 5.4)
queue2.put(Item(5, 6), 1.0)
print(queue2.pop())
Is this type of data structure implementable with numba? What is wrong with my current implementation?
Full trace:
(5, 6)
Traceback (most recent call last):
File "/home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py", line 106, in <module>
queue2 = PriorityQueue() # Nope
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/experimental/jitclass/base.py", line 122, in __call__
return cls._ctor(*bind.args[1:], **bind.kwargs)
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 420, in _compile_for_args
error_rewrite(e, 'typing')
File "/home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/dispatcher.py", line 361, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function typeddict_empty at 0x7fead8c3f8b0>) found for signature:
>>> typeddict_empty(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'typeddict_empty': File: numba/typed/typeddict.py: Line 213.
With argument(s): '(typeref[<class 'numba.core.types.containers.DictType'>], instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function new_dict at 0x7fead9002a60>) found for signature:
>>> new_dict(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython mode backend)
No implementation of function Function(<built-in function eq>) found for signature:
>>> eq(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)
There are 30 candidate implementations:
- Of which 28 did not match due to:
Overload of function 'eq': File: <numerous>: Line N/A.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'eq': File: unknown: Line unknown.
With argument(s): '(instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>, instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>)':
No match for registered cases:
* (bool, bool) -> bool
* (int8, int8) -> bool
* (int16, int16) -> bool
* (int32, int32) -> bool
* (int64, int64) -> bool
* (uint8, uint8) -> bool
* (uint16, uint16) -> bool
* (uint32, uint32) -> bool
* (uint64, uint64) -> bool
* (float32, float32) -> bool
* (float64, float64) -> bool
* (complex64, complex64) -> bool
* (complex128, complex128) -> bool
During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None)" at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/types/functions.py:229
During: resolving callee type: Function(<function new_dict at 0x7fead9002a60>)
During: typing of call at /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py (219)
File "../../../../../.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/typed/typeddict.py", line 219:
def impl(cls, key_type, value_type):
return dictobject.new_dict(key_type, value_type)
^
raised from /home/nicoco/.cache/pypoetry/virtualenvs/work-research-r4deHn84-py3.8/lib/python3.8/site-packages/numba/core/typeinfer.py:1071
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.abstract.TypeRef'>, 'empty') for typeref[<class 'numba.core.types.containers.DictType'>])
During: typing of call at /home/nicoco/src/work/work-research/scripts/thickness/priorityqueue.py (72)
File "priorityqueue.py", line 72:
def __init__(self):
<source elided>
self.pq = nb.typed.List.empty_list(Entry(0.0, 0, Item(0, 0)))
self.entry_finder = nb.typed.Dict.empty(Item(0, 0), Entry(0, 0, Item(0, 0)))
^
During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)
During: resolving callee type: jitclass.PriorityQueue#7fead8ba2b20<pq:ListType[instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>],entry_finder:DictType[instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,instance.jitclass.Entry#7feb3119d3d0<priority:float64,count:int64,item:instance.jitclass.Item#7fead907c1f0<i:int64,j:int64>,removed:bool>]<iv=None>,counter:int64>
During: typing of call at <string> (3)
File "<string>", line 3:
<source missing, REPL/exec in use?>
Process finished with exit code 1
This was not possible due to several issues in numba, but should be fixed for the next release (0.55) if I understood correctly. As a workaround for now, I could get it working by compiling llvmlite 0.38.0dev0 and the master branch of numba. I do not use conda but it is apparently easier to get pre-releases of llvmlite and numba this way.
Here is my implementation:
from heapq import heappush, heappop
from typing import List, Tuple, Dict, Any
import numba as nb
import numpy as np
from numba.experimental import jitclass
class UpdatablePriorityQueueEntry:
def __init__(self, p: float, i: Any):
self.priority = p
self.item = i
def __lt__(self, other: "UpdatablePriorityQueueEntry"):
return self.priority < other.priority
class UpdatablePriorityQueue:
def __init__(self):
self.pq = []
self.entries_priority = {}
def put(self, item: Any, priority: float = 0.0):
entry = UpdatablePriorityQueueEntry(priority, item)
self.entries_priority[item] = priority
heappush(self.pq, entry)
def pop(self) -> Any:
while self.pq:
entry = heappop(self.pq)
if entry.priority == self.entries_priority[entry.item]:
self.entries_priority[entry.item] = np.inf
return entry.item
raise KeyError("pop from an empty priority queue")
def clear(self):
self.pq.clear()
self.entries_priority.clear()
@jitclass
class PriorityQueueEntry(UpdatablePriorityQueueEntry):
priority: float
item: Tuple[int, int]
def __init__(self, p: float, i: Tuple[int, int]):
self.priority = p
self.item = i
@jitclass
class UpdatablePriorityQueue(UpdatablePriorityQueue):
pq: List[PriorityQueueEntry2d]
entries_priority: Dict[Tuple[int, int], float]
def __init__(self):
self.pq = nb.typed.List.empty_list(PriorityQueueEntry2d(0.0, (0, 0)))
self.entries_priority = nb.typed.Dict.empty((0, 0), 0.0)
def put(self, item: Tuple[int, int], priority: float = 0.0):
entry = PriorityQueueEntry2d(priority, item)
self.entries_priority[item] = priority
heappush(self.pq, entry)