I am currently making a Decision tree classifier using Gini and Information Gain and splitting the tree based on the the best attribute with the most gain each time. However, it is sticking the same attribute every time and simply adjusting the value for its question. This results in a very low accuracy of usually around 30% as it is only taking into account the very first attribute.
# Used to find the best split for data among all attributes
def split(r):
max_ig = 0
max_att = 0
max_att_val = 0
i = 0
curr_gini = gini_index(r)
n_att = len(att)
for c in range(n_att):
if c == 3:
continue
c_vals = get_column(r, c)
while i < len(c_vals):
# Value of the current attribute that is being tested
curr_att_val = r[i][c]
true, false = fork(r, c, curr_att_val)
ig = gain(true, false, curr_gini)
if ig > max_ig:
max_ig = ig
max_att = c
max_att_val = r[i][c]
i += 1
return max_ig, max_att, max_att_val
# Used to compare and test if the current row is greater than or equal to the test value
# in order to split up the data
def compare(r, test_c, test_val):
if r[test_c].isdigit():
return r[test_c] == test_val
elif float(r[test_c]) >= float(test_val):
return True
else:
return False
# Splits the data into two lists for the true/false results of the compare test
def fork(r, c, test_val):
true = []
false = []
for row in r:
if compare(row, c, test_val):
true.append(row)
else:
false.append(row)
return true, false
def rec_tree(r):
ig, att, curr_att_val = split(r)
if ig == 0:
return Leaf(r)
true_rows, false_rows = fork(r, att, curr_att_val)
true_branch = rec_tree(true_rows)
false_branch = rec_tree(false_rows)
return Node(att, curr_att_val, true_branch, false_branch)
The working solution i have was to change the split function as follows. To be completly honest i amnt able to see whats wrong but it might be obvious The working function is as follows
def split(r):
max_ig = 0
max_att = 0
max_att_val = 0
# calculates gini for the rows provided
curr_gini = gini_index(r)
no_att = len(r[0])
# Goes through the different attributes
for c in range(no_att):
# Skip the label column (beer style)
if c == 3:
continue
column_vals = get_column(r, c)
i = 0
while i < len(column_vals):
# value we want to check
att_val = r[i][c]
# Use the attribute value to fork the data to true and false streams
true, false = fork(r, c, att_val)
# Calculate the information gain
ig = gain(true, false, curr_gini)
# If this gain is the highest found then mark this as the best choice
if ig > max_ig:
max_ig = ig
max_att = c
max_att_val = r[i][c]
i += 1
return max_ig, max_att, max_att_val