Search code examples
r mlr iml

R, iml, mlr. Feature Importance always returns 1 for every feature


I'm doing something with the mlr framework that causes iml's FeatureImp to return an importance of 1 for every feature, and I can't put my finger on what it is. Here's an example:

library(caret)
#> Carregando pacotes exigidos: lattice
#> Carregando pacotes exigidos: ggplot2
library(mlr)
#> Carregando pacotes exigidos: ParamHelpers
#> 
#> Attaching package: 'mlr'
#> The following object is masked from 'package:caret':
#> 
#>     train
library(iml)

# Reduce iris to a two-class problem: drop setosa, then encode
# virginica as 1 and versicolor as 0 (stored as a factor).
data("iris")
iris <- iris[iris$Species != "setosa", ]
iris$Species <- as.factor(ifelse(iris$Species == "virginica", 1, 0))

# Single 80/20 train/test split with caret; the index is only needed here.
ind <- createDataPartition(iris$Species, times = 1, p = 0.8, list = FALSE)
train <- iris[ind, ]
test <- iris[-ind, ]
rm(ind)

# mlr classification tasks with class 1 as the positive class.
train.task <- makeClassifTask(data = train, target = "Species", positive = 1)
test.task <- makeClassifTask(data = test, target = "Species", positive = 1)

# Base learners, all producing class probabilities. The vector names become
# the names of the resulting list.
learner <- lapply(
  c(
    xgboost = "classif.xgboost",
    ksvm = "classif.ksvm",
    nnet = "classif.nnet",
    randomForest = "classif.randomForest"
  ),
  makeLearner,
  predict.type = "prob"
)

# Fit every base learner on the training task. `train` is also the name of
# the training data frame above, so the mlr function is called explicitly.
model <- lapply(learner, function(lrn) mlr::train(lrn, train.task))
#> # weights:  19
#> initial  value 57.506055 
#> iter  10 value 52.109027
#> iter  20 value 7.798098
#> iter  30 value 5.401193
#> iter  40 value 4.707935
#> iter  50 value 4.702049
#> final  value 4.701710 
#> converged
# Test-set predictions from each fitted base learner.
prediction <- lapply(model, predict, test.task)

# Stack the base learners: a random forest super learner is trained on their
# cross-validated predictions only (use.feat = FALSE).
ensemble <- makeStackedLearner(
  base.learners = learner,
  super.learner = "classif.randomForest",
  predict.type = "prob",
  method = "stack.cv",
  use.feat = FALSE
)
model[["ensemble"]] <- mlr::train(ensemble, train.task)
#> # weights:  19
#> initial  value 43.712841 
#> iter  10 value 5.444287
#> iter  20 value 4.536990
#> iter  30 value 4.527489
#> iter  40 value 4.481401
#> iter  50 value 4.481221
#> iter  50 value 4.481221
#> iter  50 value 4.481221
#> final  value 4.481221 
#> converged
#> # weights:  19
#> initial  value 52.864011 
#> iter  10 value 33.347827
#> iter  20 value 2.926847
#> iter  30 value 0.011104
#> final  value 0.000055 
#> converged
#> # weights:  19
#> initial  value 44.627604 
#> iter  10 value 31.360597
#> iter  20 value 5.798769
#> iter  30 value 4.290623
#> iter  40 value 3.751202
#> iter  50 value 3.547856
#> iter  60 value 3.469366
#> iter  70 value 3.373487
#> iter  80 value 3.317680
#> iter  90 value 3.310354
#> iter 100 value 3.301115
#> final  value 3.301115 
#> stopped after 100 iterations
#> # weights:  19
#> initial  value 46.410266 
#> iter  10 value 29.975896
#> iter  20 value 1.266423
#> iter  30 value 0.004667
#> final  value 0.000052 
#> converged
#> # weights:  19
#> initial  value 52.665930 
#> final  value 44.361399 
#> converged
#> # weights:  19
#> initial  value 60.471973 
#> iter  10 value 50.475349
#> iter  20 value 7.580138
#> iter  30 value 4.828646
#> iter  40 value 4.543112
#> iter  50 value 2.995374
#> iter  60 value 2.636710
#> iter  70 value 2.539857
#> iter  80 value 2.497281
#> iter  90 value 2.427158
#> iter 100 value 2.370383
#> final  value 2.370383 
#> stopped after 100 iterations
# Test-set predictions from the stacked ensemble.
prediction$ensemble <- predict(model$ensemble, test.task)

# Split the training task into features and target through mlr's public
# accessor instead of reaching into the task's internal environment
# (`train.task$env$data`) and hard-coding the target column name.
train.xy <- getTaskData(train.task, target.extra = TRUE)

# Wrap the ensemble for iml. The target factor has levels "0"/"1", so
# as.numeric() - 1 maps it back to numeric 0/1.
predictor <- Predictor$new(
  model$ensemble,
  data = train.xy$data,
  y = as.numeric(train.xy$target) - 1
)

# Permutation feature importance using classification error as the loss.
imp <- FeatureImp$new(predictor, loss = "ce")
imp$results
#>        feature importance.05 importance importance.95 permutation.error
#> 1 Sepal.Length             1          1             1                 1
#> 2  Sepal.Width             1          1             1                 1
#> 3 Petal.Length             1          1             1                 1
#> 4  Petal.Width             1          1             1                 1

Created on 2020-01-23 by the reprex package (v0.3.0)


Solution

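  • This appears to be fixed in the development version of {iml}.

    I could reproduce your issue with the current CRAN version. The output below was produced with the GitHub version (christophM/iml@54b2ce2, see the session info), which you can install with remotes::install_github("christophM/iml").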

    library(caret)
    #> Loading required package: lattice
    #> Loading required package: ggplot2
    library(mlr)
    #> Loading required package: ParamHelpers
    #> 'mlr' is in maintenance mode since July 2019. Future development
    #> efforts will go into its successor 'mlr3' (<https://mlr3.mlr-org.com>).
    #> 
    #> Attaching package: 'mlr'
    #> The following object is masked from 'package:caret':
    #> 
    #>     train
    library(iml)
    
    # Reduce iris to a two-class problem: drop setosa, then encode
    # virginica as 1 and versicolor as 0 (stored as a factor).
    data("iris")
    iris <- iris[iris$Species != "setosa", ]
    iris$Species <- as.factor(ifelse(iris$Species == "virginica", 1, 0))

    # Single 80/20 train/test split with caret; the index is only needed here.
    ind <- createDataPartition(iris$Species, times = 1, p = 0.8, list = FALSE)
    train <- iris[ind, ]
    test <- iris[-ind, ]
    rm(ind)

    # mlr classification tasks with class 1 as the positive class.
    train.task <- makeClassifTask(data = train, target = "Species", positive = 1)
    test.task <- makeClassifTask(data = test, target = "Species", positive = 1)

    # Base learners, all producing class probabilities. The vector names
    # become the names of the resulting list.
    learner <- lapply(
      c(
        xgboost = "classif.xgboost",
        ksvm = "classif.ksvm",
        nnet = "classif.nnet",
        randomForest = "classif.randomForest"
      ),
      makeLearner,
      predict.type = "prob"
    )

    # Fit every base learner on the training task. `train` is also the name
    # of the training data frame above, so the mlr function is called
    # explicitly.
    model <- lapply(learner, function(lrn) mlr::train(lrn, train.task))
    #> # weights:  19
    #> initial  value 59.040647 
    #> iter  10 value 54.908003
    #> iter  20 value 8.784817
    #> iter  30 value 2.906017
    #> iter  40 value 0.187334
    #> iter  50 value 0.000610
    #> final  value 0.000059 
    #> converged
    # Test-set predictions from each fitted base learner.
    prediction <- lapply(model, predict, test.task)

    # Stack the base learners: a random forest super learner is trained on
    # their cross-validated predictions only (use.feat = FALSE).
    ensemble <- makeStackedLearner(
      base.learners = learner,
      super.learner = "classif.randomForest",
      predict.type = "prob",
      method = "stack.cv",
      use.feat = FALSE
    )
    model[["ensemble"]] <- mlr::train(ensemble, train.task)
    #> # weights:  19
    #> initial  value 44.537254 
    #> iter  10 value 6.716784
    #> iter  20 value 4.750452
    #> iter  30 value 4.487501
    #> iter  40 value 4.481250
    #> final  value 4.481222 
    #> converged
    #> # weights:  19
    #> initial  value 54.135701 
    #> iter  10 value 13.081961
    #> iter  20 value 1.676063
    #> iter  30 value 0.002261
    #> final  value 0.000044 
    #> converged
    #> # weights:  19
    #> initial  value 42.621635 
    #> iter  10 value 5.201573
    #> iter  20 value 2.878946
    #> iter  30 value 1.133911
    #> iter  40 value 0.002784
    #> iter  50 value 0.000726
    #> final  value 0.000037 
    #> converged
    #> # weights:  19
    #> initial  value 43.795663 
    #> iter  10 value 4.478310
    #> iter  20 value 1.811306
    #> iter  30 value 0.027775
    #> iter  40 value 0.004873
    #> iter  50 value 0.001480
    #> iter  60 value 0.000230
    #> iter  70 value 0.000221
    #> final  value 0.000089 
    #> converged
    #> # weights:  19
    #> initial  value 44.433321 
    #> iter  10 value 7.252874
    #> iter  20 value 1.200457
    #> iter  30 value 0.001668
    #> final  value 0.000063 
    #> converged
    #> # weights:  19
    #> initial  value 67.012204 
    #> final  value 55.451774 
    #> converged
    # Test-set predictions from the stacked ensemble.
    prediction$ensemble <- predict(model$ensemble, test.task)

    # Split the training task into features and target through mlr's public
    # accessor instead of reaching into the task's internal environment
    # (`train.task$env$data`) and hard-coding the target column name.
    train.xy <- getTaskData(train.task, target.extra = TRUE)

    # Wrap the ensemble for iml. The target factor has levels "0"/"1", so
    # as.numeric() - 1 maps it back to numeric 0/1.
    predictor <- Predictor$new(
      model$ensemble,
      data = train.xy$data,
      y = as.numeric(train.xy$target) - 1
    )

    # Permutation feature importance using classification error as the loss.
    imp <- FeatureImp$new(predictor, loss = "ce")
    imp$results
    #>        feature importance.05 importance importance.95 permutation.error
    #> 1  Petal.Width          11.1       12.0          14.2            0.3000
    #> 2 Petal.Length          10.3       11.5          13.1            0.2875
    #> 3 Sepal.Length           3.3        4.5           6.3            0.1125
    #> 4  Sepal.Width           2.1        3.5           4.0            0.0875
    

    Created on 2020-01-23 by the reprex package (v0.3.0)

    Session info

    # Print the R version, platform details and the versions/sources of the
    # loaded packages so the result above can be tied to a specific setup.
    devtools::session_info()
    #> ─ Session info ───────────────────────────────────────────────────────────────
    #>  setting  value                                      
    #>  version  R version 3.6.2 Patched (2019-12-12 r77564)
    #>  os       macOS Mojave 10.14.6                       
    #>  system   x86_64, darwin15.6.0                       
    #>  ui       X11                                        
    #>  language (EN)                                       
    #>  collate  en_US.UTF-8                                
    #>  ctype    en_US.UTF-8                                
    #>  tz       Europe/Berlin                              
    #>  date     2020-01-23                                 
    #> 
    #> ─ Packages ───────────────────────────────────────────────────────────────────
    #>  package      * version     date       lib
    #>  assertthat     0.2.1       2019-03-21 [1]
    #>  backports      1.1.5       2019-10-02 [1]
    #>  BBmisc         1.11        2017-03-10 [1]
    #>  callr          3.4.0       2019-12-09 [1]
    #>  caret        * 6.0-85      2020-01-07 [1]
    #>  checkmate      1.9.4       2019-07-04 [1]
    #>  class          7.3-15      2019-01-01 [2]
    #>  cli            2.0.1.9000  2020-01-12 [1]
    #>  codetools      0.2-16      2018-12-24 [2]
    #>  colorspace     1.4-1       2019-03-18 [1]
    #>  crayon         1.3.4       2017-09-16 [1]
    #>  data.table     1.12.8      2019-12-09 [1]
    #>  desc           1.2.0       2018-05-01 [1]
    #>  devtools       2.2.1       2019-09-24 [1]
    #>  digest         0.6.23      2019-11-23 [1]
    #>  dplyr          0.8.3       2019-07-04 [1]
    #>  ellipsis       0.3.0       2019-09-20 [1]
    #>  evaluate       0.14        2019-05-28 [1]
    #>  fansi          0.4.1       2020-01-08 [1]
    #>  fastmatch      1.1-0       2017-01-28 [1]
    #>  foreach        1.4.7       2019-07-27 [1]
    #>  fs             1.3.1       2019-05-06 [1]
    #>  generics       0.0.2       2018-11-29 [1]
    #>  ggplot2      * 3.2.1       2019-08-10 [1]
    #>  glue           1.3.1       2019-03-12 [1]
    #>  gower          0.2.1       2019-05-14 [1]
    #>  gridExtra      2.3         2017-09-09 [1]
    #>  gtable         0.3.0       2019-03-25 [1]
    #>  highr          0.8         2019-03-20 [1]
    #>  htmltools      0.4.0       2019-10-04 [1]
    #>  iml          * 0.9.0       2020-01-23 [1]
    #>  ipred          0.9-9       2019-04-28 [1]
    #>  iterators      1.0.12      2019-07-26 [1]
    #>  kernlab        0.9-29      2019-11-12 [1]
    #>  knitr          1.27        2020-01-16 [1]
    #>  lattice      * 0.20-38     2018-11-04 [2]
    #>  lava           1.6.6       2019-08-01 [1]
    #>  lazyeval       0.2.2       2019-03-15 [1]
    #>  lifecycle      0.1.0       2019-08-01 [1]
    #>  lubridate      1.7.4       2018-04-11 [1]
    #>  magrittr       1.5         2014-11-22 [1]
    #>  MASS           7.3-51.4    2019-03-31 [1]
    #>  Matrix         1.2-18      2019-11-27 [2]
    #>  memoise        1.1.0       2017-04-21 [1]
    #>  Metrics        0.1.4       2018-07-09 [1]
    #>  mlr          * 2.17.0.9000 2020-01-13 [1]
    #>  ModelMetrics   1.2.2.1     2020-01-13 [1]
    #>  munsell        0.5.0       2018-06-12 [1]
    #>  nlme           3.1-143     2019-12-10 [2]
    #>  nnet           7.3-12      2016-02-02 [2]
    #>  parallelMap    1.4         2019-05-17 [1]
    #>  ParamHelpers * 1.13.0.9000 2019-12-11 [1]
    #>  pillar         1.4.3       2019-12-20 [1]
    #>  pkgbuild       1.0.6       2019-10-09 [1]
    #>  pkgconfig      2.0.3       2019-09-22 [1]
    #>  pkgload        1.0.2       2018-10-29 [1]
    #>  plyr           1.8.5       2019-12-10 [1]
    #>  prediction     0.3.14      2019-06-17 [1]
    #>  prettyunits    1.1.0       2020-01-09 [1]
    #>  pROC           1.16.1      2020-01-14 [1]
    #>  processx       3.4.1       2019-07-18 [1]
    #>  prodlim        2019.11.13  2019-11-17 [1]
    #>  ps             1.3.0       2018-12-21 [1]
    #>  purrr          0.3.3       2019-10-18 [1]
    #>  R6             2.4.1       2019-11-12 [1]
    #>  randomForest   4.6-14      2018-03-25 [1]
    #>  Rcpp           1.0.3       2019-11-08 [1]
    #>  recipes        0.1.9       2020-01-07 [1]
    #>  remotes        2.1.0       2019-06-24 [1]
    #>  reshape2       1.4.3       2017-12-11 [1]
    #>  rlang          0.4.3       2020-01-22 [1]
    #>  rmarkdown      2.1         2020-01-20 [1]
    #>  rpart          4.1-15      2019-04-12 [1]
    #>  rprojroot      1.3-2       2018-01-03 [1]
    #>  scales         1.1.0       2019-11-18 [1]
    #>  sessioninfo    1.1.1       2018-11-05 [1]
    #>  stringi        1.4.5       2020-01-11 [1]
    #>  stringr        1.4.0       2019-02-10 [1]
    #>  survival       3.1-8       2019-12-03 [2]
    #>  testthat       2.3.1       2019-12-01 [1]
    #>  tibble         2.1.3       2019-06-06 [1]
    #>  tidyselect     0.2.5       2018-10-11 [1]
    #>  timeDate       3043.102    2018-02-21 [1]
    #>  usethis        1.5.1.9000  2020-01-17 [1]
    #>  withr          2.1.2       2018-03-15 [1]
    #>  xfun           0.12        2020-01-13 [1]
    #>  xgboost        0.90.0.2    2019-08-01 [1]
    #>  XML            3.99-0.3    2020-01-20 [1]
    #>  yaml           2.2.0       2018-07-25 [1]
    #>  source                                   
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  Github (r-lib/cli@f786d87)               
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.0)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  Github (christophM/iml@54b2ce2)          
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.0)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.0)                           
    #>  local                                    
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.2)                           
    #>  Github (berndbischl/ParamHelpers@c2d989c)
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.0)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.0)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  Github (r-lib/rlang@624c5c3)             
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.1)                           
    #>  Github (pat-s/usethis@0251102)           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.1)                           
    #>  CRAN (R 3.6.2)                           
    #>  CRAN (R 3.6.0)                           
    #> 
    #> [1] /Users/pjs/Library/R/3.6/library
    #> [2] /Library/Frameworks/R.framework/Versions/3.6/Resources/library