
Why do I get different cross-validation errors with rpart if I specify parms with default values?


I am puzzled by the following:

    set.seed(144)
    df = data.frame(outcome=as.factor(sample(c('a','b','c'), 1000, replace=TRUE)),
                    x=rnorm(1000), y=rnorm(1000), z=rnorm(1000))
    library(rpart)
    fit.default = rpart(outcome ~ x + y + z, data=df, method='class')
    fit.specified = rpart(outcome ~ x + y + z, data=df, method='class',
                          parms=list(split='gini',
                                     loss=matrix(c(0,1,1,1,0,1,1,1,0), nrow=3, ncol=3, byrow=TRUE)))
    fit.default$cptable
    fit.specified$cptable

It produces different values in the xerror and xstd columns for the specified fit than for the default one. But according to ?rpart the default split is 'gini' and the default loss matrix is a matrix of 1s with zeros on the diagonal, which is exactly what I provided. So why would it behave differently? I noticed this because I was picking a different tree based on the minimum xerror and wanted to verify the baseline default case.
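One way to narrow the puzzle down (a sketch added for illustration, not part of the original question): the CP, nsplit and rel error columns of cptable come from the fitted tree alone, while xerror and xstd come from cross-validation. The check below writes the same loss matrix more compactly as `1 - diag(3)` and asserts that the first three columns agree, which suggests the difference lies only in the cross-validation step.

```r
library(rpart)

# Same data as in the question
set.seed(144)
df <- data.frame(outcome = as.factor(sample(c("a", "b", "c"), 1000, replace = TRUE)),
                 x = rnorm(1000), y = rnorm(1000), z = rnorm(1000))

fit.default <- rpart(outcome ~ x + y + z, data = df, method = "class")
# 1 - diag(3) is the 3x3 matrix of 1s with zeros on the diagonal
fit.specified <- rpart(outcome ~ x + y + z, data = df, method = "class",
                       parms = list(split = "gini", loss = 1 - diag(3)))

# CP, nsplit and rel error depend only on the fitted tree
stopifnot(isTRUE(all.equal(fit.default$cptable[, 1:3], fit.specified$cptable[, 1:3])))

# xerror and xstd depend on the cross-validation folds
print(fit.default$cptable[, c("xerror", "xstd")] -
        fit.specified$cptable[, c("xerror", "xstd")])
```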


Solution

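  • The parms are not the cause. rpart() assigns observations to cross-validation folds at random (10 folds by default, controlled by xval in rpart.control()), so the first call advances the random number generator and the second call gets different folds, hence different xerror and xstd values. If you run the two fits completely disentangled, resetting the seed before each one: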

    set.seed(144)
    df = data.frame(outcome=as.factor(sample(c('a','b','c'), 1000, replace=TRUE)), 
                    x=rnorm(1000), 
                    y=rnorm(1000), 
                    z=rnorm(1000))
    library(rpart)
    fit.default = rpart(outcome ~ x + y + z, 
                        data=df, 
                        method='class')
    fit.default$cptable  
    
    # Reset the seed so that both the data and rpart's random
    # cross-validation folds are the same as in the first run
    set.seed(144)
    df = data.frame(outcome=as.factor(sample(c('a','b','c'), 1000, replace=TRUE)), 
                    x=rnorm(1000), 
                    y=rnorm(1000), 
                    z=rnorm(1000))
    library(rpart)
    fit.specified = rpart(outcome ~ x + y + z, 
                          data=df, 
                          method='class', 
                          parms=list(split='gini', 
                                     loss=matrix(c(0,1,1,1,0,1,1,1,0), 
                                                 nrow=3,
                                                 ncol=3,
                                                 byrow=TRUE)))
    
    fit.specified$cptable
    

    You get:

    > fit.default$cptable
             CP nsplit rel error    xerror       xstd
    1 0.0375000      0  1.000000 1.0000000 0.02371708
    2 0.0140625      1  0.962500 0.9640625 0.02401939
    3 0.0100000      3  0.934375 0.9921875 0.02378775
    

    and

    > fit.specified$cptable
             CP nsplit rel error    xerror       xstd
    1 0.0375000      0  1.000000 1.0000000 0.02371708
    2 0.0140625      1  0.962500 0.9640625 0.02401939
    3 0.0100000      3  0.934375 0.9921875 0.02378775
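    The two tables are identical. A more compact sketch of the same idea, assuming the data are generated only once: it first checks that a single rpart() call moves .Random.seed (because the folds are drawn at random), then resets the seed immediately before each fit so both fits use the same folds. The seed value 1 is arbitrary, so the printed xerror values will not match the tables above, but the two fits will match each other.

```r
library(rpart)

# Generate the data once, as in the question
set.seed(144)
df <- data.frame(outcome = as.factor(sample(c("a", "b", "c"), 1000, replace = TRUE)),
                 x = rnorm(1000), y = rnorm(1000), z = rnorm(1000))

# rpart() draws the cross-validation folds at random, which advances the RNG state
set.seed(1)
before <- .Random.seed
invisible(rpart(outcome ~ x + y + z, data = df, method = "class"))
stopifnot(!identical(before, .Random.seed))

# Reset the seed right before each fit so both fits use the same folds
set.seed(1)
fit.default <- rpart(outcome ~ x + y + z, data = df, method = "class")
set.seed(1)
fit.specified <- rpart(outcome ~ x + y + z, data = df, method = "class",
                       parms = list(split = "gini", loss = 1 - diag(3)))

# With matching folds the whole cptable, including xerror and xstd, agrees
stopifnot(isTRUE(all.equal(fit.default$cptable, fit.specified$cptable)))
print(fit.default$cptable)
```

    This keeps the data fixed and isolates the only random step, which is the fold assignment inside rpart().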