Search code examples
lua

How to pass dynamic statement into function?


This what I try:

local data = {
    {cond1 = 'yes', cond2 = 'yes', nums = 7, results = 'no', average = 9.5},
    {cond1 = 'yes', cond2 = 'no', nums = 12, results = 'no', average = 15},
    {cond1 = 'no', cond2 = 'yes', nums = 18, results = 'yes', average = 26.5},
    {cond1 = 'no', cond2 = 'yes', nums = 35, results = 'yes', average = 36.5},
    {cond1 = 'yes', cond2 = 'yes', nums = 38, results = 'yes', average = 44},
    {cond1 = 'yes', cond2 = 'no', nums = 50, results = 'no', average = 66},
    {cond1 = 'no', cond2 = 'no', nums = 83, results = 'no'},
}

function split_data(data, root_node, branch_node, root_condition_function, branch_condition_function)
    local true_node = {true_value = 0, false_value = 0}
    local false_node = {true_value = 0, false_value = 0}

    for _, entry in ipairs(data) do
        if root_condition_function(entry, root_node) then
            if branch_condition_function(entry, branch_node) then
                true_node.true_value = true_node.true_value + 1
            else
                true_node.false_value = true_node.false_value + 1
            end
        else
            if branch_condition_function(entry, branch_node) then
                false_node.true_value = false_node.true_value + 1
            else
                false_node.false_value = false_node.false_value + 1
            end
        end
    end

    return true_node, false_node
end

-- Statement for root node/string node
local function root_string_condition(entry, root_node)
    return entry[root_node] == 'yes'
end

-- Statement for numeric root node
local function root_numeric_condition(entry, root_node, branch_node)
    local nums = entry[root_node]
    local average = entry[branch_node]

    -- Check for nil values
    if nums == nil or average == nil then
        return false
    end

    return tonumber(nums) < tonumber(average), tonumber(nums) >= tonumber(average)
end

-- Statement for branch node
local function branch_condition(entry, branch_node)
    return entry[branch_node] == 'yes'
end

-- Example usage for string condition
local trueNodeStr, falseNodeStr = split_data(data, 'cond1', 'results', root_string_condition, branch_condition)

-- Print the results for string condition
print("String Condition Results:")
print(string.format("Cond1 True  - True Value: %d, False Value: %d", trueNodeStr.true_value, trueNodeStr.false_value))
print(string.format("Cond1 False - True Value: %d, False Value: %d", falseNodeStr.true_value, falseNodeStr.false_value))

-- Example usage for numeric condition
local trueNodeNum, falseNodeNum = split_data(data, 'nums', 'results', root_numeric_condition, branch_condition)

-- Print the results for numeric condition
print("\nNumeric Condition Results:")
print(string.format("Nums True  - True Value: %d, False Value: %d", trueNodeNum.true_value, trueNodeNum.false_value))
print(string.format("Nums False - True Value: %d, False Value: %d", falseNodeNum.true_value, falseNodeNum.false_value))

The results:

String Condition Results:
Cond1 True  - True Value: 1, False Value: 3
Cond1 False - True Value: 2, False Value: 1

Numeric Condition Results:
Nums True  - True Value: 0, False Value: 0
Nums False - True Value: 3, False Value: 4

Why for Numeric results give incorrect result?

Expected results:

Numeric Condition Results:
Nums True  - True Value: 0, False Value: 1
Nums False - True Value: 3, False Value: 3

Solution

  • Given this call:

    local trueNodeNum, falseNodeNum = split_data(data, 'nums', 'results', root_numeric_condition, branch_condition)
    

    split_data only passes two arguments to root_numeric_condition (root_condition_function)

    if root_condition_function(entry, root_node) then
    

    where three are expected:

    local function root_numeric_condition(entry, root_node, branch_node)
    

    So in root_numeric_condition

    local average = entry[branch_node]
    

    will always be nil, and the function will always return false.

    Given the variable name, it looks like you are actually trying to access the average field of the entry table, but

    return tonumber(nums) < tonumber(average), tonumber(nums) >= tonumber(average)
    

    the multiple return values of root_numeric_condition and the use of tonumber on what one would expect to already be numbers makes it rather unclear what your intentions are.


    As a start, split_data can be refactored into:

    local function split_data(data, root, branch, root_condition, branch_condition)
        local results = {
            { true_value = 0, false_value = 0 },
            { true_value = 0, false_value = 0 }
        }
    
        for _, entry in ipairs(data) do
            local b1 = root_condition(entry, root) and 1 or 2
            local b2 = branch_condition(entry, branch) and 'true_value' or 'false_value'
    
            results[b1][b2] = 1 + results[b1][b2]
        end
    
        return results[1], results[2] -- or table.unpack(results)
    end
    

    And purely based on your expectations, one would assume something like

    local function root_numeric_condition(entry, root)
        return not entry.average or entry[root] > entry.average 
    end