python, time-complexity, mergesort

Time complexity of merge sort: function appears to be called 2*n-1 times rather than O(log n) times


I'm teaching a coding class and need an intuitive and obvious way to explain the time complexity of merge sort. I tried including a print statement at the start of my merge_sort() function, anticipating that the print statement would execute O(log n) times. However, as best as I can tell, it executes 2*n-1 times instead (Python code below):

merge_sort() function:

def merge_sort(my_list):
    print("hi") #prints 2*n-1 times??
    if(len(my_list) <= 1):
        return
    mid = len(my_list)//2
    l = my_list[:mid]
    r = my_list[mid:]
    merge_sort(l)
    merge_sort(r)
    i = 0
    j = 0
    k = 0
    while(i < len(l) or j < len(r)):
        #print("hey") #prints nlogn times as expected
        if(i >= len(l)):
            my_list[k] = r[j]
            j += 1
        elif(j >= len(r)):
            my_list[k] = l[i]
            i += 1
        elif(l[i] < r[j]):
            my_list[k] = l[i]
            i += 1
        else:  # l[i] >= r[j], so equal elements are handled too
            my_list[k] = r[j]
            j += 1
        k += 1

Driver code:

#print("Enter a list")
my_list = list(map(int, input().split()))
#print("Sorted list:")
merge_sort(my_list)
print(my_list)

Input:

1 2 3 4 5 6 7 8

Expected output:

hi
hi
hi

or some variation thereof that grows in proportion to log n.

Actual output:

hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi #15 times, i.e. 2*n-1

A few more iterations of this with different input sizes have given me the impression that this is 2*n-1, which makes no sense to me. Does anyone have an explanation for this?


Solution

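  • It is not true that there are only O(log n) recursive calls. What is O(log n) is the depth of the recursion tree, not the number of nodes in it. Every call on a list longer than one element makes exactly two further calls, so the recursion tree is a binary tree with n leaves (the single-element lists) and n-1 internal nodes: 2*n-1 calls in total, which is exactly the count you observed.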

    Looking at one level of the recursion tree, each call at that level deals with a distinct partition of the array. Together, the "nodes" at that level cover all elements of the array, so the merging done at that level takes O(n) time. This is true for each level.

    Since there are O(log n) levels and each one costs O(n), the total complexity comes down to O(n log n).

    Here is a suggestion on how to illustrate this:

    statistics = []
    
    def merge_sort(my_list, depth=0):
        if len(my_list) <= 1:
            return
        # manage statistics
        if depth >= len(statistics):
            statistics.append(0)  # for each depth we count operations
        mid = len(my_list)//2
        l = my_list[:mid]
        r = my_list[mid:]
        merge_sort(l, depth+1)
        merge_sort(r, depth+1)
        i = 0
        j = 0
        k = 0
        while i < len(l) or j < len(r):
            statistics[depth] += 1  # count each loop iteration as one O(1) unit of work
            if i >= len(l):
                my_list[k] = r[j]
                j += 1
            elif j >= len(r):
                my_list[k] = l[i]
                i += 1
            elif l[i] < r[j]:
                my_list[k] = l[i]
                i += 1
            else:  # l[i] >= r[j], so equal elements are handled too
                my_list[k] = r[j]
                j += 1
            k += 1
    
    import random
    
    my_list = list(range(32))
    random.shuffle(my_list)
    merge_sort(my_list)
    print(my_list)
    print(statistics)
    

    The statistics list holds the number of units of work done at each level. With an input size of 32, you'll get a list of 5 numbers, each equal to 32: [32, 32, 32, 32, 32]. That is n units of work on each of log2(n) = 5 levels, or 160 in total.

    NB: In Python, if and while conditions don't need parentheses.
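
To tie the recursion-tree argument back to the 2*n-1 count in the question, here is a minimal sketch that lists the sublist lengths merge_sort() would be called on at each depth. The helper name level_sizes is hypothetical and not part of the code above; it assumes the same split at len // 2 and tracks lengths only, without sorting anything.

```python
def level_sizes(n):
    """Return, per recursion depth, the lengths of the sublists merge sort is called on.

    Hypothetical helper for illustration: it mirrors the len // 2 split
    but only tracks lengths, it does not sort.
    """
    levels = []
    current = [n]
    while current:
        levels.append(current)
        next_level = []
        for size in current:
            if size > 1:  # only lists longer than one element are split further
                mid = size // 2
                next_level.extend([mid, size - mid])
        current = next_level
    return levels


levels = level_sizes(8)
for depth, sizes in enumerate(levels):
    print(depth, sizes, "sum =", sum(sizes))
# every entry in every level corresponds to one call of merge_sort()
print("calls:", sum(len(sizes) for sizes in levels))
```

For n = 8 this prints four levels ([8], [4, 4], [2, 2, 2, 2], and eight 1s), each summing to 8, and a total of 15 calls. The number of levels grows like log n, the work per level stays at n, and the number of calls is the number of nodes in the whole tree. The last level holds only the base-case calls, which return without merging, which is why the statistics list in the answer has one entry fewer than there are levels here.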