Search code examples
r gbm

Trying to get gbm2sas package to work


I am experimenting with the R packages gbm2sas and gbm.

I am trying to create a gbm model object (using the gbm() function) and generate SAS code that implements the model (using the gbm2sas() function), but I cannot get it to work: the gbm2sas() call fails with the error shown at the end of the output below.

Here is my R code:

# Reproduction script: fit a tiny boosted model on iris, inspect its trees,
# then ask gbm2sas() to write the model out as SAS code.
library(gbm)
library(gbm2sas)
data(iris)
# Binary response: 1 when Species is "setosa", 0 otherwise.
iris$setosaFlag = (iris$Species == "setosa")*1
# Three-tree Bernoulli model on the four measurement columns.
iris.gbm = gbm(setosaFlag ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, 
                                data=iris, 
                                dist="bernoulli", 
                                n.tree = 3,
                                interaction.depth=3,
                                shrinkage = 0.01,
                                keep.data=TRUE,
                                verbose=TRUE,
                                n.cores=1)
print(iris.gbm)
# Show the node table of each of the three fitted trees.
pretty.gbm.tree(iris.gbm, i.tree=1)
pretty.gbm.tree(iris.gbm, i.tree=2)
pretty.gbm.tree(iris.gbm, i.tree=3)

# Export the model as SAS code; this is the call that raises the error.
gbm2sas(
                iris.gbm, # gbm object from above
                sasfile="studyGBM.R", # name to use for SAS code file
                ntrees=3, # number of trees
                mysasdata="sasdataset", 
                treeval="treevalue", 
                prefix="dobranch_" 
)

I get the following output and error:

> library(gbm)
> library(gbm2sas)
> data(iris)
> iris$setosaFlag = (iris$Species == "setosa")*1
> iris.gbm = gbm(setosaFlag ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, 
+                                 data=iris, 
+                                 dist="bernoulli", 
+                                 n.tree = 3,
+                                 interaction.depth=3,
+                                 shrinkage = 0.01,
+                                 keep.data=TRUE,
+                                 verbose=TRUE,
+                                 n.cores=1)
Iter   TrainDeviance   ValidDeviance   StepSize   Improve
     1        1.2531             nan     0.0100    0.0096
     2        1.2337             nan     0.0100    0.0093
     3        1.2148             nan     0.0100    0.0082

> print(iris.gbm)
gbm(formula = setosaFlag ~ Sepal.Length + Sepal.Width + Petal.Length + 
    Petal.Width, distribution = "bernoulli", data = iris, n.trees = 3, 
    interaction.depth = 3, shrinkage = 0.01, keep.data = TRUE, 
    verbose = TRUE, n.cores = 1)
A gradient boosted model with bernoulli loss function.
3 iterations were performed.
There were 4 predictors of which 3 had non-zero influence.
> pretty.gbm.tree(iris.gbm, i.tree=1)
  SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight Prediction
0        2        2.4500        1         5           9    1.72800e+01     75     0.0012
1        0        5.0500        2         3           4    3.28692e-31     27     0.0300
2       -1        0.0300       -1        -1          -1    0.00000e+00     15     0.0300
3       -1        0.0300       -1        -1          -1    0.00000e+00     12     0.0300
4       -1        0.0300       -1        -1          -1    0.00000e+00     27     0.0300
5        0        6.8500        6         7           8    5.48890e-30     48    -0.0150
6       -1       -0.0150       -1        -1          -1    0.00000e+00     38    -0.0150
7       -1       -0.0150       -1        -1          -1    0.00000e+00     10    -0.0150
8       -1       -0.0150       -1        -1          -1    0.00000e+00     48    -0.0150
9       -1        0.0012       -1        -1          -1    0.00000e+00     75     0.0012
> pretty.gbm.tree(iris.gbm, i.tree=2)
  SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight  Prediction
0        2    2.35000000        1         5           9   1.693529e+01     75  0.00103485
1        3    0.25000000        2         3           4   3.104314e-31     27  0.02940891
2       -1    0.02940891       -1        -1          -1   0.000000e+00     17  0.02940891
3       -1    0.02940891       -1        -1          -1   0.000000e+00     10  0.02940891
4       -1    0.02940891       -1        -1          -1   0.000000e+00     27  0.02940891
5        3    2.05000000        6         7           8   1.672221e-30     48 -0.01492556
6       -1   -0.01492556       -1        -1          -1   0.000000e+00     37 -0.01492556
7       -1   -0.01492556       -1        -1          -1   0.000000e+00     11 -0.01492556
8       -1   -0.01492556       -1        -1          -1   0.000000e+00     48 -0.01492556
9       -1    0.00103485       -1        -1          -1   0.000000e+00     75  0.00103485
> pretty.gbm.tree(iris.gbm, i.tree=3)
  SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight   Prediction
0        2   2.700000000        1         5           9   1.762206e+01     75  0.003792325
1        0   5.050000000        2         3           4   1.479114e-30     32  0.028846427
2       -1   0.028846427       -1        -1          -1   0.000000e+00     20  0.028846427
3       -1   0.028846427       -1        -1          -1   0.000000e+00     12  0.028846427
4       -1   0.028846427       -1        -1          -1   0.000000e+00     32  0.028846427
5        0   6.750000000        6         7           8   8.513506e-31     43 -0.014852589
6       -1  -0.014852589       -1        -1          -1   0.000000e+00     33 -0.014852589
7       -1  -0.014852589       -1        -1          -1   0.000000e+00     10 -0.014852589
8       -1  -0.014852589       -1        -1          -1   0.000000e+00     43 -0.014852589
9       -1   0.003792325       -1        -1          -1   0.000000e+00     75  0.003792325
> 

> gbm2sas(
+                 iris.gbm, # gbm object from above
+                 sasfile="studyGBM.R", # name to use for SAS code file
+                 ntrees=3, # number of trees
+ mysasdata="sasdataset", 
+ treeval="treevalue", 
+ prefix="dobranch_" 
+ )
Error in data[, gbmobject$var.names] : 
  object of type 'closure' is not subsettable
> 
> 

Can anyone please point out what I am doing wrong?

Thanks.


Solution

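  • The error isn't on your end (though I would change a few things). I dug into the source code for the gbm2sas function and found a problem with the way it looks up the model's var.names: it subsets an object called `data` (the `data[, gbmobject$var.names]` in your error message) that is never passed to the function, so R resolves the name to the built-in data() function, and a function ("closure") cannot be subsetted.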

    First, run this fixed version of the gbm2sas function:

    # gbm2sas(): write a SAS data step that reproduces the trees of a fitted
    # gbm model.
    #
    # Arguments:
    #   gbmobject  fitted model as returned by gbm(). The components n.trees,
    #              var.names, var.type, var.levels and c.splits are read.
    #   sasfile    path of the SAS program to create (an existing file is
    #              overwritten). Required, despite the NULL default.
    #   ntrees     number of trees to export; defaults to gbmobject$n.trees.
    #   mysasdata  name of the SAS dataset that the data step reads and rewrites.
    #   treeval    stem of the per-tree output variables (<treeval>1, <treeval>2,
    #              ...), each holding the terminal-node value of one tree.
    #   prefix     stem of the temporary node-indicator variables, which are
    #              dropped at the end of the data step.
    #
    # Returns NULL invisibly; the function is called for the file it writes.
    #
    # Predictor kinds are taken from the model object itself rather than from a
    # training data frame, so no external `data` object is needed. Deriving the
    # kind from class() of the variable *names* would label every predictor
    # "character" and no numeric `lt` comparison would ever be written.
    gbm2sas <- function(
      gbmobject,
      sasfile = NULL,
      ntrees = NULL,
      mysasdata = "mysasdata",
      treeval = "treeval",
      prefix = "do_"
    ) {
      if (is.null(sasfile)) {
        stop("'sasfile' must be the path of the SAS file to write", call. = FALSE)
      }
      if (is.null(ntrees)) {
        ntrees <- gbmobject$n.trees
      }

      var_names <- gbmobject$var.names
      var_levels <- gbmobject$var.levels
      # var.type > 0 marks an unordered factor (the value is its level count).
      # NOTE(review): gbm appears to store ordered factors with var.type 0 and
      # their labels (character) in var.levels, while numeric predictors keep
      # numeric quantiles there -- confirm against the installed gbm version.
      var_kind <- vapply(
        seq_along(var_names),
        function(i) {
          if (gbmobject$var.type[i] > 0) {
            "factor"
          } else if (is.character(var_levels[[i]])) {
            "ordered"
          } else {
            "numeric"
          }
        },
        character(1)
      )

      # Open the output once; every emit() call appends one line of SAS code.
      con <- file(sasfile, open = "wt")
      on.exit(close(con), add = TRUE)
      emit <- function(...) writeLines(paste0(...), con)

      emit("data ", mysasdata, "; set ", mysasdata, ";")

      max_nodes <- 0
      for (tree_idx in seq_len(ntrees)) {
        tree <- pretty.gbm.tree(gbmobject, i.tree = tree_idx)[1:7]
        n_nodes <- nrow(tree)
        max_nodes <- max(max_nodes, n_nodes)

        # Reset the node indicators: only the root (node 0) starts active.
        emit(prefix, 0, "=1;")
        for (node in seq_len(n_nodes - 1)) {
          emit(prefix, node, "=0;")
        }

        for (row in seq_len(n_nodes)) {
          emit("if ", prefix, row - 1, ">0 then do;")

          # SplitVar is 0-based; -1 marks a terminal node, giving split_var 0.
          split_var <- 1 + as.numeric(tree$SplitVar[row])
          split_code <- as.numeric(tree$SplitCodePred[row])
          left_node <- as.numeric(tree$LeftNode[row])
          right_node <- as.numeric(tree$RightNode[row])
          missing_node <- as.numeric(tree$MissingNode[row])

          if (split_var > 0) {
            var_name <- var_names[split_var]
            kind <- var_kind[split_var]

            # Opens an `else do;` that the branch statement below closes.
            emit(
              "if missing(", var_name, ") then ", prefix, missing_node,
              "=1; else do;"
            )

            if (kind == "numeric") {
              emit(
                "if ", var_name, " lt ", split_code, " then ", prefix, left_node,
                "=1; else ", prefix, right_node, "=1; end;"
              )
            } else {
              lv <- var_levels[[split_var]]
              if (kind == "ordered") {
                # The first ceiling(split) levels go to the left child.
                n_left <- ceiling(split_code)
                left_levels <- if (n_left >= 1) lv[seq_len(n_left)] else character(0)
              } else {
                # For unordered factors SplitCodePred indexes c.splits
                # (0-based); entries equal to -1 mark levels sent left.
                describer <- unlist(gbmobject$c.splits[1 + split_code])
                left_levels <- lv[describer == -1]
              }
              left_levels <- left_levels[!is.na(left_levels)]

              if (length(left_levels) > 0) {
                emit(
                  "if ", var_name, " in (",
                  paste0("'", left_levels, "'", collapse = ", "),
                  ") then ", prefix, left_node, "=1; else ", prefix, right_node,
                  "=1; end;"
                )
              } else {
                # No level goes left: every non-missing row goes right.
                emit(prefix, right_node, "=1; end;")
              }
            }
          } else {
            # Terminal node: SplitCodePred holds the node's prediction.
            emit(treeval, tree_idx, "=", split_code, ";")
          }

          emit("end;")
        }
      }

      # Remove the temporary node indicators from the output dataset.
      for (node in seq_len(max_nodes) - 1) {
        emit("drop ", prefix, node, ";")
      }
      emit("run;")

      invisible(NULL)
    }
    

    Here is an updated version of your code as well:

    # Example: fit a small gbm model on iris and export it as SAS code.
    # Run the fixed gbm2sas() definition first; an object defined in the global
    # environment is found before the package's version of the same name.
    library(gbm)
    library(gbm2sas)
    library(dplyr)
    data(iris)
    iris$setosaFlag <- (iris$Species == "setosa") * 1

    # Remove '.' from variable names. SAS doesn't like anything but underscores.
    # Only the response and the predictors are kept: it may not be necessary to
    # drop the unused columns, but it doesn't hurt to remove them just in case.
    iris <- select(
      iris,
      setosaFlag,
      sepal_length = Sepal.Length,
      sepal_width = Sepal.Width,
      petal_length = Petal.Length,
      petal_width = Petal.Width
    )

    # Argument names are spelled out in full (distribution, n.trees) instead of
    # relying on partial matching.
    iris.gbm <- gbm(
      setosaFlag ~ sepal_length + sepal_width + petal_length + petal_width,
      data = iris,
      distribution = "bernoulli",
      n.trees = 3,
      interaction.depth = 3,
      shrinkage = 0.01,
      keep.data = TRUE,
      verbose = TRUE,
      n.cores = 1
    )
    print(iris.gbm)
    pretty.gbm.tree(iris.gbm, i.tree = 1)
    pretty.gbm.tree(iris.gbm, i.tree = 2)
    pretty.gbm.tree(iris.gbm, i.tree = 3)

    # The output is SAS code, so give the file a .sas extension.
    gbm2sas(
      iris.gbm, # gbm object from above
      sasfile = "studyGBM.sas", # name to use for SAS code file
      ntrees = 3, # number of trees
      mysasdata = "sasdataset",
      treeval = "treevalue",
      prefix = "dobranch_"
    )
    

    Test it out and let me know if you still encounter any issues.