I have coded a Decision Tree classifier in Matlab. To the best of my knowledge everything should work, the logic checks out. When I try to call the fit method it breaks on one of my functions telling me I haven't got the right input arguments but I'm sure I do! Been trying to solve this and similar errors to do with functions and input arguments for a day or two now. I wondered if it had something to do from calling them from within the constructor but calling them from the main script still doesn't work. Pls help!
classdef my_ClassificationTree < handle
properties
X % training examples
Y % training labels
MinParentSize % minimum parent node size
MaxNumSplits % maximum number of splits
Verbose % are we printing out debug as we go?
% MinLeafSize
CutPoint
CutPredictorIndex
Children
numSplits
root
end
methods
% constructor: implementing the fitting phase
function obj = my_ClassificationTree(X, Y, MinParentSize, MaxNumSplits, Verbose)
obj.X = X;
obj.Y = Y;
obj.MinParentSize = MinParentSize;
obj.MaxNumSplits = MaxNumSplits;
obj.Verbose = Verbose;
% obj.Children = zeros(1, 2);
% obj.CutPoint = 0;
% obj.CutPredictorIndex = 0;
% obj.MinLeafSize = MinLeafSize;
obj.numSplits = 0;
obj.root = Node(1, size(obj.X,1));
root = Node(1, size(obj.X,1));
fit(obj,root);
end
function node = Node(sIndex,eIndex)
node.startIndex = sIndex;
node.endIndex = eIndex;
node.leaf = false;
node.Children = 0;
node.size = eIndex - sIndex + 1;
node.CutPoint = 0;
node.CutPredictorIndex = 0;
node.NodeClass = 0;
end
function fit(obj,node)
if node.size < obj.MinParentSize || obj.numSplits >= obj.MaxNumSplits
% Mark the node as a leaf node
node.Leaf = true;
% Calculate the majority class label for the examples at this node
labels = obj.Y(node.startIndex:node.endIndex); %gather all the labels for the data in the nodes range
node.NodeClass = mode(labels); %find the most frequent label and classify the node as such
return;
end
bestCutPoint = findBestCutPoint(node, obj.X, obj.Y);
leftChild = Node(node.startIndex, bestCutPoint.CutIndex - 1);
rightChild = Node(bestSplit.splitIndex, node.endIndex);
obj.numSplits = obj.numSplits + 1;
node.CutPoint = bestSplit.CutPoint;
node.CutPredictorIndex = bestSplit.CutPredictorIndex;
%Attach the child nodes to the parent node
node.Children = [leftChild, rightChild];
% Recursively build the tree for the left and right child nodes
fit(obj, leftChild);
fit(obj, rightChild);
end
function bestCutPoint = findBestCutPoint(node, X, labels)
bestCutPoint.CutPoint = 0;
bestCutPoint.CutPredictorIndex = 0;
bestCutPoint.CutIndex = 0;
bestGDI = Inf; % Initialize the best GDI to a large value
% Loop through all the features
for i = 1:size(X, 2)
% Loop through all the unique values of the feature
values = unique(X(node.startIndex:node.endIndex, i));
for j = 1:length(values)
% Calculate the weighted impurity of the two resulting
% cut
leftLabels = labels(node.startIndex:node.endIndex, 1);
rightLabels = labels(node.startIndex:node.endIndex, 1);
leftLabels = leftLabels(X(node.startIndex:node.endIndex, i) < values(j));
rightLabels = rightLabels(X(node.startIndex:node.endIndex, i) >= values(j));
leftGDI = weightedGDI(leftLabels, labels);
rightGDI = weightedGDI(rightLabels, labels);
% Calculate the weighted impurity of the split
cutGDI = leftGDI + rightGDI;
% Update the best split if the current split has a lower GDI
if cutGDI < bestGDI
bestGDI = cutGDI;
bestCutPoint.CutPoint = values(j);
bestCutPoint.CutPredictorIndex = i;
bestCutPoint.CutIndex = find(X(:, i) == values(j), 1, 'first');
end
end
end
end
% the prediction phase:
function predictions = predict(obj, test_examples)
% get ready to store our predicted class labels:
predictions = categorical;
% Iterate over each example in X
for i = 1:size(test_examples, 1)
% Set the current node to be the root node
currentNode = obj.root;
% While the current node is not a leaf node
while ~currentNode.leaf
% Check the value of the predictor feature specified by the CutPredictorIndex property of the current node
value = test_examples(i, currentNode.CutPredictorIndex);
% If the value is less than the CutPoint of the current node, set the current node to be the left child of the current node
if value < currentNode.CutPoint
currentNode = currentNode.Children(1);
% If the value is greater than or equal to the CutPoint of the current node, set the current node to be the right child of the current node
else
currentNode = currentNode.Children(2);
end
end
% Once the current node is a leaf node, add the NodeClass of the current node to the predictions vector
predictions(i) = currentNode.NodeClass;
end
end
% add any other methods you want on the lines below...
end
end
This is the function that calls myClassificationTree
function m = my_fitctree(train_examples, train_labels, varargin)
% take an extra name-value pair allowing us to turn debug on:
p = inputParser;
addParameter(p, 'Verbose', false);
%addParameter(p, 'MinLeafSize', false);
% take an extra name-value pair allowing us to set the minimum
% parent size (10 by default):
addParameter(p, 'MinParentSize', 10);
% take an extra name-value pair allowing us to set the maximum
% number of splits (number of training examples-1 by default):
addParameter(p, 'MaxNumSplits', size(train_examples,1) - 1);
p.parse(varargin{:});
% use the supplied parameters to create a new my_ClassificationTree
% object:
m = my_ClassificationTree(train_examples, train_labels, ...
p.Results.MinParentSize, p.Results.MaxNumSplits, p.Results.Verbose);
end
that is my code from the main block of code
mym2_dt = my_fitctree(train_examples, train_labels, 'MinParentSize', 10)
These are the errors these are the errors
I'm expecting it to build a decision tree and fill it. However it breaks on the findBestCutPoint function and I cannot fix it
The first argument of class methods (except the constructor) should be an instance of the class (i.e obj
). Your definition of Node
and findBestCutPoint
should have obj
as the first argument.
Moreover, calls to class methods from within other methods should have the syntax obj.theMethod
which seems not to be the case in your code.
So, for instance, the call to Node
should be:
obj.root = obj.Node(1, size(obj.X,1));
and Node
should be defined as follows:
function node = Node(obj,sIndex,eIndex)
Same applies to findBestCutPoint
. Note that, in the calls, the reference to the class instance is passed implicitly, so you don't need to actually include it in the call.