
Confusion matrix for training and validation sets


What is the purpose of the "newdata" argument? Why don't we need to specify newdata = tlFLAG.t in the first case?

pred <- predict(tree1, type = "class")
confusionMatrix(pred, factor(tlFLAG.t$TERM_FLAG)) 

pred.v <- predict(tree1, type = "class", newdata = tlFLAG.v)
confusionMatrix(pred.v, factor(tlFLAG.v$TERM_FLAG)) 

Solution

    In every machine learning process (in this case a classification problem), you should split your data into a training set and a test set.

    This lets you train your algorithm on the first set and evaluate it on the second.

    If you train and evaluate on the same data, you expose yourself to overfitting, because almost every algorithm tries to fit the data you feed it as closely as possible.

    You can even end up with a model that is perfect on your training data but predicts very poorly on new data it has not seen yet.
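A small sketch of that effect, using the built-in iris data as a stand-in (the question's tlFLAG data isn't available). The rpart.control() settings cp = 0, minsplit = 2 and minbucket = 1 are chosen here only to force an overgrown tree that effectively memorises the training rows; the seed and the 100/50 split are arbitrary. The training accuracy typically comes out at or near 1, while the test accuracy is lower; the exact numbers depend on the seed.

```r
# Sketch: a deliberately overgrown rpart tree on iris (stand-in data),
# comparing accuracy on the rows it was fit to with accuracy on held-out rows.
library(rpart)

set.seed(42)                       # arbitrary seed, for reproducibility
idx   <- sample(nrow(iris), 100)   # 100 random row indices out of 150
train <- iris[idx, ]
test  <- iris[-idx, ]

# cp = 0, minsplit = 2, minbucket = 1: keep splitting until leaves are pure,
# so the tree fits the training set as closely as it can
deep <- rpart(Species ~ ., data = train, method = "class",
              control = rpart.control(cp = 0, minsplit = 2, minbucket = 1))

# share of predictions that match the true class
accuracy <- function(pred, truth) mean(pred == truth)

train_acc <- accuracy(predict(deep, type = "class"), train$Species)
test_acc  <- accuracy(predict(deep, newdata = test, type = "class"), test$Species)

cat("training accuracy:", round(train_acc, 3), "\n")
cat("test accuracy:    ", round(test_acc, 3), "\n")
```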

    That is why predict() has a newdata= argument: it lets you score observations the model has not seen, to check how good the model really is.

    In your first case you leave newdata= out, so predict() returns the fitted classes for the data the tree was trained on (presumably tlFLAG.t). That is why you don't need newdata = tlFLAG.t there: the result would be the same. But the confusionMatrix is then computed on data the model has already seen, so it can be over-optimistic.

    In the second case you pass newdata = tlFLAG.v, so the predictions are made on the validation set. That confusion matrix gives a more realistic estimate of performance, and it is the more interesting of the two.
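A quick way to check that omitting newdata really means "predict on the training data" for an rpart model: fit a tree (iris is used here as a stand-in, and the 100-row split is arbitrary) and compare the default predictions with predictions where the training set is passed explicitly. This assumes tree1 in the question is an rpart model, which `type = "class"` suggests but the question doesn't show.

```r
# Sketch: for an rpart model, predict() without newdata returns the fitted
# classes of the training rows, the same as passing newdata = train.
library(rpart)

set.seed(123)                      # arbitrary seed, for reproducibility
idx   <- sample(nrow(iris), 100)   # 100 random row indices out of 150
train <- iris[idx, ]

tree <- rpart(Species ~ ., data = train, method = "class")

pred_default <- predict(tree, type = "class")                  # newdata omitted
pred_train   <- predict(tree, newdata = train, type = "class") # training set passed explicitly

# compare element by element (ignores any difference in names)
same_preds <- all(pred_default == pred_train)
print(same_preds)
```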

    Here is an example of the classic approach:

    library(rpart)
    
    data <- iris # iris dataset
    
    # first split the data at random
    set.seed(123) # for reproducibility
    # sample(nrow(data), 100) draws 100 of the 150 row indices; sample(100) alone
    # would only shuffle rows 1-100 (setosa and versicolor), putting every
    # virginica row in the test set
    pos <- sample(nrow(data), 100)
    
    train <- data[pos, ]  # random pick of 100 obs
    test  <- data[-pos, ] # remaining 50
    
    # now you can start with your model - please note that this is a dummy example
    tree <- rpart(Species ~ ., data = train) # fit tree on train data
    
    # make prediction on train data (no need to specify newdata = ) # NOT very useful
    pred <- predict(tree, type = "class")
    caret::confusionMatrix(pred, train$Species)
    
    # make prediction on test data (the response column is not needed)
    pred <- predict(tree, type = "class", newdata = test[, -5]) # drop Species (5th column of test)
    # build the confusion matrix from predictions against the truth (i.e. test$Species)
    caret::confusionMatrix(pred, test$Species)
    

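    Compare the two outputs: the accuracy on the training data is the over-optimistic figure, while the accuracy on the test data is the honest estimate of how the tree will do on unseen observations, and it is usually lower.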
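If caret isn't installed, the core of the same confusion matrix can be built with base R's table(). The sketch below, again on iris with an arbitrary seed, puts predictions on the rows and the true classes on the columns, which is the orientation caret::confusionMatrix prints. Overall accuracy is the diagonal (correct predictions) divided by the total. caret also reports extra statistics such as Kappa and per-class sensitivity, which this sketch does not reproduce.

```r
# Sketch: confusion matrix and overall accuracy with base R only.
library(rpart)

set.seed(123)                      # arbitrary seed, for reproducibility
idx   <- sample(nrow(iris), 100)   # 100 random row indices out of 150
train <- iris[idx, ]
test  <- iris[-idx, ]

tree <- rpart(Species ~ ., data = train, method = "class")
pred <- predict(tree, newdata = test, type = "class")

# rows = predicted class, columns = true class
cm <- table(Predicted = pred, Reference = test$Species)
print(cm)

# overall accuracy = correctly classified (diagonal) / all test rows
accuracy <- sum(diag(cm)) / sum(cm)
cat("accuracy:", round(accuracy, 3), "\n")
```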