Error inserting torch tensors into a heapq using duplicated priorities


How can I avoid "RuntimeError: bool value of Tensor with more than one value is ambiguous" in this code?

import torch
import heapq

h = []
heapq.heappush(h, (1, torch.Tensor([[1,2]])))
heapq.heappush(h, (1, torch.Tensor([[3,4]])))

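It happens because tuple comparison moves on to the second elements when the first ones are equal. Here the second elements are tensors, and comparing two tensors yields an element-wise tensor rather than a single True or False, so Python cannot decide the ordering.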


Solution

  • The heap must not compare the second elements of the tuples when it finds duplicate priorities. It is enough to wrap each entry in a class that redefines the < operator to look at the priority only:

    import torch
    import heapq
    
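    class HeapItem:
        """Pairs a priority with a tensor; ordering uses the priority only."""
        def __init__(self, p, t):
            self.p = p  # priority
            self.t = t  # tensor payload, never compared
    
        def __lt__(self, other):
            # heapq orders items with <, so the tensors are never compared
            return self.p < other.p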
    
    h = []
    heapq.heappush(h, HeapItem(1, torch.Tensor([[1,2]])))
    heapq.heappush(h, HeapItem(1, torch.Tensor([[3,4]])))
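
A different option, not part of the answer above, is to keep plain tuples and put a unique counter between the priority and the payload, so a tie on priority is settled by the counter before the payload is ever looked at. The sketch below is only an illustration under some assumptions: the helper names push and pop are made up, and dicts stand in for the tensors (they cannot be ordered with < either) so that it runs without torch installed.

```python
import heapq
import itertools

# Hypothetical helpers for illustration. The payload can be any object,
# for example a torch tensor in the setting of the question.
_counter = itertools.count()  # monotonically increasing tie-breaker

def push(heap, priority, payload):
    # The unique count sits between priority and payload, so tuple
    # comparison is always decided before the payload is reached.
    heapq.heappush(heap, (priority, next(_counter), payload))

def pop(heap):
    priority, _, payload = heapq.heappop(heap)
    return priority, payload

h = []
push(h, 1, {"name": "first"})   # dicts stand in for tensors here
push(h, 1, {"name": "second"})
push(h, 0, {"name": "urgent"})

# Pop everything: lowest priority first, ties in insertion order.
order = [pop(h)[1]["name"] for _ in range(3)]
print(order)  # ['urgent', 'first', 'second']
```

Compared with the wrapper class, this also makes entries with equal priority come out in the order they were inserted, at the cost of one extra integer per entry.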