Tags: r, prediction, glmnet

k-fold cross validation of glmnet group predictions


I have a set of predictors that are partially correlated, and I would like to reduce them to a functional set and use the reduced model for prediction. I am able to find a good lambda with the following:

> require(glmnet); require(glmnetUtils)
> cvfit <- cv.glmnet(
+         SpeakerGroup ~ Age +transient_mean +syllablerate+syllablerate_sd+intensitysfraction_mean + NucleusPercentVoiced_mean +NucleusPercentVoiced_sd +OnsetPercentVoiced_mean + OnsetPercentVoiced_sd +Shim + Jitt +intensityslope+rateslope + APQ3 +APQ5+DDP_A+RAP +PPQ5 +DDP, nfolds = 20
+         ,family="binomial",data=curr.df,type.measure = "class")
> 
> plot(cvfit)

(see the plot here: https://umu.box.com/s/9rt60v3btfo8qhz870vludv6whxlgfx0 ; imgur is not working for me).

> cbind(coef.cv.glmnet(cvfit, s = "lambda.1se"),coef.cv.glmnet(cvfit, s = "lambda.min"))
20 x 2 sparse Matrix of class "dgCMatrix"
                                  1           1
(Intercept)               -1.229948 -0.84290372
Age                        .         .         
transient_mean             .         .         
syllablerate               .        -0.31636610
syllablerate_sd            .         .         
intensitysfraction_mean    .         .         
NucleusPercentVoiced_mean  .         .         
NucleusPercentVoiced_sd    .         .         
OnsetPercentVoiced_mean    .         0.01119326
OnsetPercentVoiced_sd      .         .         
Shim                       .         .         
Jitt                       .         8.09912574
intensityslope             .        -1.68472631
rateslope                  .         .         
APQ3                       .         .         
APQ5                       .         .         
DDP_A                      .         .         
RAP                        .         .         
PPQ5                       .         .         
DDP                        .         .         
> 

OK, this model gives me a set of predictors for which I would like to evaluate accuracy, specificity, and so on in terms of predicting group membership (two possible groups).
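
(To be explicit about what I mean by those metrics, here is a quick sketch with made-up counts, not my actual data:)

    # rows = predicted group, columns = actual group (invented counts)
    cm <- matrix(c(30,  5,
                    8, 27),
                 nrow = 2, byrow = TRUE,
                 dimnames = list(predicted = c("A", "B"), actual = c("A", "B")))
    accuracy    <- sum(diag(cm)) / sum(cm)   # (TP + TN) / total
    sensitivity <- cm[1, 1] / sum(cm[, 1])   # TP / (TP + FN), taking "A" as the positive class
    specificity <- cm[2, 2] / sum(cm[, 2])   # TN / (TN + FP)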

> require(dplyr); require(purrr) # for %>%, group_by() and map() below
> require(resamplr) # https://github.com/jrnold/resamplr
> # compute 5 folds that have the same balance between groups as the original data set
> curr.df %>% group_by(SpeakerGroup) %>% crossv_kfold(k=5,stratify=TRUE) -> folds
> 

I can then fit a model on each training fold:

> folds <- folds %>% mutate(model = map(train, ~ glmnet(
+         SpeakerGroup ~ Age + transient_mean +syllablerate+syllablerate_sd+intensitysfraction_mean +NucleusPercentVoiced_sd +OnsetPercentVoiced_mean + OnsetPercentVoiced_sd  + Jitt +intensityslope + APQ3 +DDP_A,data=.,family="binomial")))

(you can find the resulting folds object here https://umu.box.com/s/ktxbba4ptzf3hke8g5ze6qgvt0rv42fp)

Now I want to predict from each model using the corresponding test set created by the 5-fold procedure.

> 
> predicted <- folds %>% mutate(predicted =map2(model, test, ~ predict(.x, data = .y,type="response",s=cvfit$lambda.min)))

but I get an error:

Error in mutate_impl(.data, dots) : 
  Evaluation error: argument "data" is missing, with no default.

I am confused by this, as I have provided a data argument.
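
For reference, my best guess at a working call (assuming the predict method for these glmnetUtils formula fits takes newdata rather than data, and that the resample objects returned by crossv_kfold first need to be converted with as.data.frame()):

    predicted <- folds %>%
        mutate(predicted = map2(model, test,
            ~ predict(.x,
                      newdata = as.data.frame(.y),  # resample object -> data frame
                      type = "response",
                      s = cvfit$lambda.min)))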

Any ideas of what could have gone wrong here?

Is there a simpler way to get a standard 2x2 confusion matrix from a 5-fold cross-validation of a glmnet model?

Thanks!

Fredrik


Solution

  • As pointed out in the comment above, the cv.glmnet procedure is already a cross-validation. The issue was that I saw no way to extract the fitted values from the model.

    For future reference, this works

    cvfit <- cv.glmnet(
            <description of model>... ,type.measure = "auc",keep=TRUE)
    

    The main point is keep=TRUE, which keeps the pre-validated fitted values (cvfit$fit.preval) so you can extract them later.

    # index of lambda.min in the lambda sequence of the underlying fit
    currInd <- match(cvfit$lambda.min, cvfit$glmnet.fit$lambda)
    # There is also a 'cvfit$lambda.1se' to have a look at
    cutoff <- 0.5
    # dichotomize the pre-validated fitted values at the chosen cutoff
    predicted <- cut(as.numeric(cvfit$fit.preval[,currInd]), c(-1000, cutoff, 1000), labels=<your labels>)
    

    which gives you a vector of predictions, using your cutoff value, that you can then compare to the actual values.
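
    A minimal sketch of turning that into the 2x2 confusion matrix, assuming the true labels are in curr.df$SpeakerGroup and that the rows of fit.preval line up with the rows of curr.df (they should, since that is the data cv.glmnet was fitted on):

    actual <- curr.df$SpeakerGroup
    cm <- table(predicted = predicted, actual = actual)  # standard 2x2 confusion matrix
    accuracy <- sum(diag(cm)) / sum(cm)
    # sensitivity, specificity etc. can be read off cm the same way as sketched in the question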

    I do wish there were a standardized way of doing this without extracting parameters by hand, but there it is. At least this works.
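
    As an aside: if I am not mistaken, more recent glmnet releases (3.0 and later) ship assessment helpers such as assess.glmnet() and confusion.glmnet() that can work directly on the pre-validated fits kept by keep=TRUE. Something along these lines might replace the manual extraction above; I have not tested it, so treat it as a sketch:

    # untested sketch, assumes glmnet >= 3.0; should give one confusion table per lambda value
    cnf <- confusion.glmnet(cvfit$fit.preval, newy = curr.df$SpeakerGroup, family = "binomial")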