
Matlab's TreeBagger and k-fold cross validation


I am trying to get the 5-fold cross-validation error of a model created with TreeBagger using the function crossval, but I keep getting this error:

Error using crossval>evalFun The function 'regrTree' generated the following error: Too many input arguments.

My code is below. Can anyone point me in the right direction? Thanks.

%Random Forest
% XX is the training data matrix, Y is the training response vector
XX=X_Tbl(:,2:end);
Forest_Mdl = TreeBagger(1000,XX,Y,'Method','regression');

err_std = crossval('mse',XX,Y,'Predfun',@regrTree, 'kFold',5);


function yfit_std = regrTree(Forest_Mdl,XX) 
yfit_std = predict(Forest_Mdl,XX);
end

Solution

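  • Reading the documentation helps a lot here. crossval calls the 'Predfun' function once per fold as predfun(Xtrain,ytrain,Xtest), so the function has to train the model itself on each fold's training data instead of receiving an already-trained Forest_Mdl. It has to be defined as:

    (note that it takes 3 arguments, not 2)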

    function yfit = myfunction(Xtrain,ytrain,Xtest)
    % Calculate predicted response
    ...
    end
    

    Xtrain — Subset of the observations in X used as training predictor data. The function uses Xtrain and ytrain to construct a classification or regression model.

    ytrain — Subset of the responses in y used as training response data. The rows of ytrain correspond to the same observations in the rows of Xtrain. The function uses Xtrain and ytrain to construct a classification or regression model.

    Xtest — Subset of the observations in X used as test predictor data. The function uses Xtest and the model trained on Xtrain and ytrain to compute the predicted values yfit.

    yfit — Set of predicted values for observations in Xtest. The yfit values form a column vector with the same number of rows as Xtest.
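Applied to the question, a minimal sketch might look like the following. It is only an illustration under stated assumptions: the synthetic XX and Y are hypothetical stand-ins for the question's data, and 100 trees (instead of 1000) is an arbitrary choice to keep the run short. The anonymous function regrTree has the required three-argument signature and trains a new regression forest on each fold's training data before predicting that fold's held-out observations.

```matlab
% Hypothetical synthetic data standing in for the question's XX and Y
rng(0);                                          % make the example reproducible
XX = rand(200, 4);                               % 200 observations, 4 predictors
Y  = XX * [1; -2; 3; 0.5] + 0.1 * randn(200, 1); % numeric response vector

% Three-argument prediction function, as crossval requires:
% train a regression forest on this fold's training data, then
% return predictions (a numeric column vector) for this fold's test data.
% 100 trees is an assumption chosen only to keep the example fast.
regrTree = @(Xtrain, ytrain, Xtest) predict( ...
    TreeBagger(100, Xtrain, ytrain, 'Method', 'regression'), Xtest);

% 5-fold cross-validated mean squared error
err_std = crossval('mse', XX, Y, 'Predfun', regrTree, 'KFold', 5);
disp(err_std)
```

A named local function, function yfit = regrTree(Xtrain,ytrain,Xtest), works the same way if it is placed at the end of the script; in that case pass @regrTree to crossval. Note that the Forest_Mdl trained on all of XX in the question is never used by crossval, because crossval trains its own model inside every fold.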