Search code examples
pythonheapmin-heap

Min-heap insert function doesn't work properly


Can anyone help me with my insert function, please? I don't know why it doesn't insert the data correctly.

In the test file, the expected array is: [None, -12, -11, -6, -9, -3, -5, -2, -1, -4]

But the function returns: [None, -12, -9, -6, -11, -3, -5, -2, -1, -4]

from functools import total_ordering
import math
class MinHeap:
    def __init__(self):
        self.arr_heap = [None]

    def __str__(self):
        return str(self.arr_heap[1:])

    def __repr__(self):
        return str(self)
        
    def get_left_pos(self, i:int) ->int:
        return 2*i

    def get_right_pos(self, i:int) ->int:
        return 2*i+1

    def get_parent_pos(self, i) ->int:
        return math.floor(i/2)

    def swap(self, pos_1, pos_2):
        aux = self.arr_heap[pos_1]
        self.arr_heap[pos_1] = self.arr_heap[pos_2]
        self.arr_heap[pos_2] = aux

    def is_a_leaf(self, posicao):
        return posicao >= len(self.arr_heap)//2 and posicao <= len(self.arr_heap)

    def heapify(self, pos_raiz_sub_arvore:int):
        parent_pos = pos_raiz_sub_arvore
        left_child_pos = self.get_left_pos(parent_pos)
        right_child_pos = self.get_right_pos(parent_pos)
        if not self.is_a_leaf(parent_pos):
            if (self.arr_heap[parent_pos] > self.arr_heap[left_child_pos] or self.arr_heap[parent_pos] > self.arr_heap[right_child_pos]):
                if self.arr_heap[left_child_pos] < self.arr_heap[right_child_pos]:
                    self.swap(parent_pos, left_child_pos)
                    self.heapify(left_child_pos)
                else:
                    self.swap(parent_pos, right_child_pos)
                    self.heapify(right_child_pos)
    
    def insert(self, element):
        self.arr_heap.append(element)
        current = self.arr_heap.index(element)
        while self.arr_heap[current] < self.arr_heap[self.get_parent_pos(current)-1]:
            self.swap(current, self.get_parent_pos(current))
            current = self.get_parent_pos(current)

    def remove(self):
        element = self.arr_heap[1]
        element_pos = len(self.arr_heap)-1
        self.arr_heap[1] = self.arr_heap[element_pos]
        self.arr_heap.pop(element_pos)
        self.heapify(1)
        return element

    def __str__(self):
        return str(self.arr_heap)

    def __repr__(self):
        return str(self)

The test file is:

import unittest
from typing import List, Dict
from heap import MinHeap
class TestHeap(unittest.TestCase):
    def test_heapify(self):
        obj_heap = MinHeap()

        obj_heap.arr_heap = [None,-12,-9,-6]
        obj_heap.heapify(1)
        self.assertListEqual(obj_heap.arr_heap, [None,-12,-9,-6], f"The heapify operation was not performed correctly. Input list: {[None,-12,-9,-6]}")

        obj_heap.arr_heap = [None,-12,-15,-6]
        obj_heap.heapify(1)
        self.assertListEqual(obj_heap.arr_heap, [None,-15,-12,-6], f"The heapify operation was not performed correctly. Input list: {[None,-12,-15,-6]}")

        obj_heap.arr_heap = [None,-12,-9,-15]
        obj_heap.heapify(1)
        self.assertListEqual(obj_heap.arr_heap, [None,-15,-9,-12], f"The heapify operation was not performed correctly. Input list: {[None,-12,-9,-15]}")

        obj_heap.arr_heap = [None,-12,-2,-6,-4,-5,3,0,-1,1, -3, 2]
        obj_heap.heapify(2)
        self.assertListEqual(obj_heap.arr_heap, [None,-12,-5,-6,-4,-3,3,0,-1,1, -2, 2], f"The heapify operation was not performed correctly. Expected outcome: {[None,-12,-5,-6,-4,-3,3,0,-1,1, -2, 2]} result obtained: {obj_heap.arr_heap}. Input list: {[None,-12,-2,-6,-4,-5,3,0,-1,1, -3, 2]}")

        obj_heap.arr_heap = [None,-12,-2,-4,-6,-5,3,0,1,-3, -3, 2]
        obj_heap.heapify(2)
        self.assertListEqual(obj_heap.arr_heap, [None,-12,-6,-4,-3,-5,3,0,1,-2, -3, 2], f"The heapify operation was not performed correctly. Expected outcome: {[None,-12,-6,-4,-3,-5,3,0,1,-2, -3, 2]} result obtained: {obj_heap.arr_heap}. Input list: {[None,-12,-2,-4,-6,-5,3,0,1,-3, -3, 2]}")

    def test_insert(self):
        arr_test = [1,-8,-11,-14]
        arr_heap_expected = [[None,-12,-9,-6,-4,-3,-5,-2,-1,1],
                             [None,-12,-9,-6,-8,-3,-5,-2,-1,-4],
                             [None,-12,-11,-6,-9,-3,-5,-2,-1,-4],
                             [None,-14,-12,-6,-9,-3,-5,-2,-1,-4],
                            ]
    
        for val_inserir in arr_test:
            objHeap = MinHeap()
            objHeap.insert(val_inserir)
            self.assertListEqual([None,val_inserir],objHeap.arr_heap,f"Incorrect insertion when inserting the value {val_inserir} in the heap {[None,-12,-9,-6,-4,-3,-5,-2,-1]}, expected: {[None,val_inserir]} result: {objHeap.arr_heap}")

        for i,val_inserir in enumerate(arr_test):
            objHeap = MinHeap()
            objHeap.arr_heap = [None,-12,-9,-6,-4,-3,-5,-2,-1]
            objHeap.insert(val_inserir)
            self.assertListEqual(arr_heap_expected[i],objHeap.arr_heap,f"Incorrect insertion when inserting the value {val_inserir} in the heap {[None,-12,-9,-6,-4,-3,-5,-2,-1]}, expected: {arr_heap_expected[i]} result: {objHeap.arr_heap}")


    def test_remove(self):
        obj_heap = MinHeap()
        obj_heap.arr_heap = [None,-12,-9,-4,-7,-5,3,0,1,-2, -3, 2]

        min_val = obj_heap.remove()
        self.assertEqual(min_val, -12, f"Incorrect insertion when inserting the value (-12) but {min_val} ")
        self.assertListEqual(obj_heap.arr_heap, [None,-9,-7,-4,-2,-5,3,0,1,2, -3], f"The test_remove operation did not end with the expected heap.")

        obj_heap.arr_heap = [None,-12]
        min_val = obj_heap.remove()
        self.assertEqual(min_val, -12, f"Incorrect insertion when inserting the value (-12) but {min_val} ")
        self.assertListEqual(obj_heap.arr_heap, [None], f"The test_remove operation did not end with the expected heap.")

if __name__ == "__main__":
    unittest.main()

All the other tests work just fine.


Solution

  • I learned about MinHeaps so I could work out the problem(s) with your code (yes...I have no life ;) ). There are two problems, and they're in the same line. That line is this one in the insert method:

    while self.arr_heap[current] < self.arr_heap[self.get_parent_pos(current)-1]:
    

    The first problem is the -1. You want to compare the value of the parent element against the value of the current element. The parent element is at self.get_parent_pos(current). I don't know why you think you want to look at the element just before that. You dont. So with this line, your third insert test succeeds, but your fourth test fails:

    while self.arr_heap[current] < self.arr_heap[self.get_parent_pos(current)]:
    

    The reason for the fourth test failing is that the new element reaches the top of the tree and your algorithm keeps going. Since the element now has no parent, the code crashes trying to get the parent's value. The answer is to recognize when the new element has reached the top of the tree, and stop. To do that, you modify the line to this:

    while current > 0 and self.arr_heap[current] < self.arr_heap[self.get_parent_pos(current)]:
    

    With that version of the line, all of your insert tests pass.

    One little tid-bit for you. Your method:

    def get_parent_pos(self, i) ->int:
        return math.floor(i/2)
    

    can be written more cleanly as:

    def get_parent_pos(self, i) ->int:
        return i // 2
    

    / produces a float result, whereas '//' produces an integer result by rounding down. Keep this in mind, as you often want to do an integer divide, and '//' is a cleaner and easier way to go about it.