
Recursive merge sort in Lua behaves incorrectly, while almost identical Python code works


Sorry if my English is bad; Korean is my native language.

I tried to implement a recursive merge sort in Lua, based on the pseudocode below:

void merge(int h, int m, const keytype U[], const keytype V[], keytype S[]) {
    index i, j, k;
    i = 1; j = 1; k = 1;
    while (i <= h && j <= m) {
        if (U[i] < V[j]) {
            S[k] = U[i];
            i++;
        } else {
            S[k] = V[j];
            j++;
        }
        k++;
    }
    if (i > h) {
        copy V[j] through V[m] to S[k] through S[h+m];
    } else {
        copy U[i] through U[h] to S[k] through S[h+m];
    }
}
void mergesort(int n, keytype S[]) {
    if (n > 1) {
        const int h = ⌊n/2⌋, m = n - h;
        keytype U[1..h], V[1..m];
        copy S[1] through S[h] to U[1] through U[h];
        copy S[h+1] through S[n] to V[1] through V[m];
        mergesort(h, U);
        mergesort(m, V);
        merge(h, m, U, V, S);
    }
}
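As a point of reference, a direct translation of this pseudocode into Python could look like the sketch below. It is not the poster's code: it switches to ordinary 0-based lists (an adaptation, since the pseudocode indexes from 1) and turns each "copy ... through ..." line into a slice copy or slice assignment.

```python
def merge(h, m, U, V, S):
    """Merge sorted U (h items) and sorted V (m items) into S[0:h + m]."""
    i = j = k = 0  # 0-based, where the pseudocode starts at 1
    while i < h and j < m:
        if U[i] < V[j]:
            S[k] = U[i]
            i += 1
        else:
            S[k] = V[j]
            j += 1
        k += 1
    # "copy V[j] through V[m] to S[k] through S[h+m]" as a slice assignment
    if i >= h:
        S[k:h + m] = V[j:m]
    else:
        S[k:h + m] = U[i:h]


def mergesort(n, S):
    """Sort the first n items of list S in place."""
    if n > 1:
        h = n // 2      # floor(n/2)
        m = n - h
        U = S[:h]       # copy S[1] through S[h]
        V = S[h:n]      # copy S[h+1] through S[n]
        mergesort(h, U)
        mergesort(m, V)
        merge(h, m, U, V, S)


values = [52, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20]
mergesort(len(values), values)
print(values)  # [5, 8, 11, 14, 20, 24, 27, 31, 33, 36, 44, 47, 52]
```

Here U and V are fresh slices, so each recursive call works on its own copy, just like the local arrays U[1..h] and V[1..m] in the pseudocode.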

and wrote this in Lua:

function merge(h, m, U, V, S)
    print(S[1], S[h+m])
    i, j, k = 1, 1, 1
    while i<=h and j<=m do
        if U[i] < V[j] then
            S[k] = U[i]
            i = i+1
        else
            S[k] = V[j]
            j = j+1
        end
        k = k+1
    end
    if i>h then
        while j<=m do
            S[k] = V[j]
            j, k = j+1, k+1
        end
    else
        while i<=h do
            S[k] = U[i]
            i, k = i+1, k+1
        end
    end
end

function mergeSort(n, S)
    if n>1 then
        h = math.floor(n/2)
        m = n - h
        local U = {}
        local V = {}
        i = 1
        while i<=h do
            U[i] = S[i]
            i = i + 1
        end
        while i<=n do
            V[i-h] = S[i]
            i = i + 1
        end
        print(h, m)
        mergeSort(h, U)
        mergeSort(m, V)
        merge(h, m, U, V, S)
    end
end

s = {52, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20}
mergeSort(#s, s)
for k, v in ipairs(s) do
    print(v)
end

but it printed 14, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20: the result is not sorted, and an element is duplicated. However, when I converted this Lua code into Python:

import math

def merge(h, m, U, V, S):
    print(S[1], S[h+m])
    i, j, k = 1, 1, 1
    while i<=h and j<=m:
        if U[i] < V[j]:
            S[k] = U[i]
            i = i+1
        else:
            S[k] = V[j]
            j = j+1
        k = k+1
    
    if i>h:
        while j<=m:
            S[k] = V[j]
            j, k = j+1, k+1
    else:
        while i<=h:
            S[k] = U[i]
            i, k = i+1, k+1

    
def mergeSort(n, S):
    if n>1:
        h = math.floor(n/2)
        m = n - h
        U = {}
        V = {}
        i = 1
        while i<=h:
            U[i] = S[i]
            i = i + 1
        while i<=n:
            V[i-h] = S[i]
            i = i + 1
        print(h, m)
        mergeSort(h, U)
        mergeSort(m, V)
        merge(h, m, U, V, S)

a = [52, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20]
s = dict(enumerate(a, 1))
mergeSort(len(s), s)
print([s[i] for i in range(1, len(s)+1)])

I know this is probably far from Pythonic, but this version does sort the sequence correctly, producing [5, 8, 11, 14, 20, 24, 27, 31, 33, 36, 44, 47, 52]. So I figured the bug was related to Lua tables, but debugging with print() gave me little information, and I couldn't find any Lua reference about it. What is wrong with my Lua code? Also, how can I make my Python code more Pythonic?
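On the second question, one possible more Pythonic version is sketched below; the names merge_sort and merge are illustrative choices, not from the original post. It uses ordinary 0-based lists instead of a dict keyed from 1, slicing instead of manual copy loops, // instead of math.floor(n/2), and returns a new list instead of passing lengths around. It also uses <= in the comparison (the pseudocode uses <), which takes from the left half on ties and so keeps the sort stable.

```python
def merge(left, right):
    """Merge two already-sorted lists into a new sorted list."""
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        # `<=` takes from `left` on ties, which keeps the sort stable
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    # At most one of these slices is non-empty.
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged


def merge_sort(items):
    """Return a new sorted list; the input sequence is left unchanged."""
    if len(items) <= 1:
        return list(items)
    mid = len(items) // 2             # floor division, no math module needed
    left = merge_sort(items[:mid])    # slicing replaces the manual copy loops
    right = merge_sort(items[mid:])
    return merge(left, right)


data = [52, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20]
print(merge_sort(data))  # [5, 8, 11, 14, 20, 24, 27, 31, 33, 36, 44, 47, 52]
```

The standard library's heapq.merge(left, right) lazily yields the merged sequence of two sorted inputs and could replace the hand-written merge helper; in real code, the built-in sorted() is the practical choice.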


Solution

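  • The problem stems from using global variables instead of local ones. In Lua, any variable assigned without the local keyword is global, so h, m, i, j and k are shared by every call. Each recursive call therefore overwrites values that its calling "parent" function still needs: once mergeSort(h, U) returns, the caller's h and m hold whatever the recursive calls last assigned, so the following mergeSort(m, V) and merge(h, m, U, V, S) run with the wrong lengths. Declaring all of these variables local inside the functions both fixes the bug and improves performance, because local variables are accessed faster than globals. Consider using a linter such as luacheck to spot accidental globals.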

    Fixed code:

    local function merge(h, m, U, V, S)
        print(S[1], S[h+m])
        local i, j, k = 1, 1, 1
        while i<=h and j<=m do
            if U[i] < V[j] then
                S[k] = U[i]
                i = i+1
            else
                S[k] = V[j]
                j = j+1
            end
            k = k+1
        end
        if i>h then
            while j<=m do
                S[k] = V[j]
                j, k = j+1, k+1
            end
        else
            while i<=h do
                S[k] = U[i]
                i, k = i+1, k+1
            end
        end
    end
    
    local function mergeSort(n, S)
        if n>1 then
            local h = math.floor(n/2)
            local m = n - h
            local U = {}
            local V = {}
            local i = 1
            while i<=h do
                U[i] = S[i]
                i = i + 1
            end
            while i<=n do
                V[i-h] = S[i]
                i = i + 1
            end
            print(h, m)
            mergeSort(h, U)
            mergeSort(m, V)
            merge(h, m, U, V, S)
        end
    end
    
    local s = {52, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20}
    mergeSort(#s, s)
    for _, v in ipairs(s) do
        print(v)
    end
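This also explains why the Python port worked: in Python, a name assigned inside a function is local to that function unless it is declared global, which is exactly the default the Lua code was missing. As a demonstration only, the sketch below adds hypothetical global declarations to the Python version; with h, m, i, j and k shared between calls, it fails the same way the Lua code does (unsorted output with duplicated elements).

```python
# Hypothetical reproduction of the Lua bug in Python: the `global`
# declarations make h, m, i, j and k shared by every call, just like
# Lua variables that are assigned without `local`.

def merge(h, m, U, V, S):
    global i, j, k
    i, j, k = 1, 1, 1
    while i <= h and j <= m:
        if U[i] < V[j]:
            S[k] = U[i]
            i += 1
        else:
            S[k] = V[j]
            j += 1
        k += 1
    # Copy whichever half still has elements left.
    while i <= h:
        S[k] = U[i]
        i, k = i + 1, k + 1
    while j <= m:
        S[k] = V[j]
        j, k = j + 1, k + 1


def merge_sort(n, S):
    global h, m  # shared with the recursive calls below
    if n > 1:
        h = n // 2
        m = n - h
        U = {idx: S[idx] for idx in range(1, h + 1)}
        V = {idx - h: S[idx] for idx in range(h + 1, n + 1)}
        merge_sort(h, U)
        # If that call recursed further, it overwrote the global h and m,
        # so the next two lines may use its values instead of this call's.
        merge_sort(m, V)
        merge(h, m, U, V, S)


data = [52, 33, 14, 27, 8, 31, 24, 11, 5, 36, 44, 47, 20]
s = dict(enumerate(data, 1))  # 1-based "table", as in the question
merge_sort(len(s), s)
result = [s[idx] for idx in range(1, len(s) + 1)]
print(result)  # not sorted, and some values appear twice
```

Deleting the two global lines restores the question's working Python code, which mirrors what adding local does in the fixed Lua code.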