Tags: r, random-forest, predict, recipe, tidymodels

Predict with step_naomit and retain ID using tidymodels


I am trying to retain an ID on each row when predicting with a Random Forest model so that I can merge the predictions back onto the original data frame. I am using step_naomit in the recipe, which removes the rows with missing data when I bake the training data, but it also removes the records with missing data from the testing data. Unfortunately, I don't have an ID to easily tell which records were removed, so I can't accurately merge the predictions back on.

I have tried adding an ID column to the original data, but bake removes any variable not included in the formula (and I don't want to include the ID in the formula). I also thought I might be able to retain the row.names from the original table to merge on, but it appears the row.names are reset upon baking as well.
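
For example, even with a small made-up data frame (just to illustrate what I mean), the baked output loses both the extra column and the original row positions:

require(tidyverse)
require(tidymodels)

df <- tibble(id = 1:3, x = c(1, NA, 3), y = c(2, 4, 6))

rec <- recipe(y ~ x, data = df) %>%  ## id IS NOT IN THE FORMULA
    step_naomit(all_numeric()) %>%
    prep()

bake(rec, new_data = df)
## A 2 x 2 TIBBLE: NO id COLUMN, THE NA ROW IS DROPPED, AND THE
## ROW NAMES ARE RESET TO 1..n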

I realize I could remove the NA values prior to the recipe to solve this problem, but then what is the point of step_naomit in the recipe? I also tried skip = TRUE in step_naomit, but then I get an error for missing data when fitting the model (only for random forest). Am I missing something in tidymodels that would allow me to retain all the rows prior to baking?
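
For reference, the skip = TRUE attempt was essentially the same recipe with the step flagged to be skipped at bake time; the NA rows then survive baking, so it is the random forest fit that errors on them:

require(tidyverse)
require(tidymodels)
require(ranger)

df <- tibble(x = c(1, NA, 3), y = c(2, 4, 6))

rec_skip <- recipe(y ~ x, data = df) %>%
    step_naomit(all_numeric(), skip = TRUE) %>%
    prep()

bake(rec_skip, new_data = df)  ## STILL 3 ROWS, NA RETAINED
## ranger(y ~ x, data = bake(rec_skip, new_data = df))  ## FAILS ON THE MISSING VALUE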

See example:


## R 3.6.1 ON WINDOWS 10 MACHINE

require(tidyverse)
require(tidymodels)
require(ranger)

set.seed(123)

temp <- iris %>%
    dplyr::mutate(Petal.Width = case_when(
        round(Sepal.Width) %% 2 == 0 ~ NA_real_, ## INTRODUCE NA VALUES
        TRUE ~ Petal.Width))

mySplit <- rsample::initial_split(temp, prop = 0.8)

myRecipe <- function(dataFrame) {
    recipes::recipe(Petal.Width ~ ., data = dataFrame) %>%
        step_naomit(all_numeric()) %>%
        prep(training = dataFrame)
}

myPred <- function(mySplit,myRecipe) {

    train_set <- training(mySplit)
    test_set <- testing(mySplit)

    train_prep <- myRecipe(train_set)

    analysis_processed <- bake(train_prep, new_data = train_set)

    model <- rand_forest(
            mode = "regression",
            mtry = 3,
            trees = 50) %>%
        set_engine("ranger", importance = 'impurity') %>%
        fit(Sepal.Width ~ ., data=analysis_processed)

    test_processed <- bake(train_prep, new_data = test_set)

    test_processed %>%
        bind_cols(myPrediction = unlist(predict(model,new_data=test_processed))) 

}

getPredictions <- myPred(mySplit,myRecipe)

nrow(getPredictions)

##  21 ROWS

max(as.numeric(row.names(getPredictions)))

##  21

nrow(testing(mySplit))

##  29 ROWS

max(as.numeric(row.names(testing(mySplit))))

##  150

Solution

  • To be able to keep track of which observations were removed, we need to give the original dataset an id variable.

    temp <- iris %>%
        dplyr::mutate(Petal.Width = case_when(
            round(Sepal.Width) %% 2 == 0 ~ NA_real_, ## INTRODUCE NA VALUES
            TRUE ~ Petal.Width),
            id = row_number()) #<<<<
    

    Then we use update_role() twice: first to designate id as an "id variable", then to mark everything else as a predictor, so the id doesn't become part of the modeling process. Because id now has a role in the recipe, bake() keeps it in the output, and it travels along with the predictions. Everything else works like before. Below is the fully updated code, with #<<<< marking my changes.

    require(tidyverse)
    #> Loading required package: tidyverse
    require(tidymodels)
    #> Loading required package: tidymodels
    #> Registered S3 method overwritten by 'xts':
    #>   method     from
    #>   as.zoo.xts zoo
    #> ── Attaching packages ───────────────────── tidymodels 0.0.3 ──
    #> ✔ broom     0.5.2     ✔ recipes   0.1.7
    #> ✔ dials     0.0.3     ✔ rsample   0.0.5
    #> ✔ infer     0.5.0     ✔ yardstick 0.0.4
    #> ✔ parsnip   0.0.4
    #> ── Conflicts ──────────────────────── tidymodels_conflicts() ──
    #> ✖ scales::discard() masks purrr::discard()
    #> ✖ dplyr::filter()   masks stats::filter()
    #> ✖ recipes::fixed()  masks stringr::fixed()
    #> ✖ dplyr::lag()      masks stats::lag()
    #> ✖ dials::margin()   masks ggplot2::margin()
    #> ✖ dials::offset()   masks stats::offset()
    #> ✖ yardstick::spec() masks readr::spec()
    #> ✖ recipes::step()   masks stats::step()
    require(ranger)
    #> Loading required package: ranger
    
    set.seed(1234)
    
    temp <- iris %>%
        dplyr::mutate(Petal.Width = case_when(
            round(Sepal.Width) %% 2 == 0 ~ NA_real_, ## INTRODUCE NA VALUES
            TRUE ~ Petal.Width),
            id = row_number()) #<<<<
    
    mySplit <- rsample::initial_split(temp, prop = 0.8)
    
    myRecipe <- function(dataFrame) {
        recipes::recipe(Petal.Width ~ ., data = dataFrame) %>%
            update_role(id, new_role = "id variable") %>%  #<<<<
            update_role(-id, new_role = 'predictor') %>%   #<<<<
            step_naomit(all_numeric()) %>%
            prep(training = dataFrame)
    }
    
    myPred <- function(mySplit,myRecipe) {
    
        train_set <- training(mySplit)
        test_set <- testing(mySplit)
    
        train_prep <- myRecipe(train_set)
    
        analysis_processed <- bake(train_prep, new_data = train_set)
    
        model <- rand_forest(
                mode = "regression",
                mtry = 3,
                trees = 50) %>%
            set_engine("ranger", importance = 'impurity') %>%
            fit(Sepal.Width ~ ., data=analysis_processed)
    
        test_processed <- bake(train_prep, new_data = test_set)
    
        test_processed %>%
            bind_cols(myPrediction = unlist(predict(model,new_data=test_processed))) 
    
    }
    
    getPredictions <- myPred(mySplit, myRecipe)
    
    getPredictions
    #> # A tibble: 23 x 7
    #>    Sepal.Length Sepal.Width Petal.Length Petal.Width Species     id myPrediction
    #>           <dbl>       <dbl>        <dbl>       <dbl> <fct>    <int>        <dbl>
    #>  1          4.6         3.1          1.5         0.2 setosa       4         3.24
    #>  2          4.3         3            1.1         0.1 setosa      14         3.04
    #>  3          5.1         3.4          1.5         0.2 setosa      40         3.22
    #>  4          5.9         3            4.2         1.5 versico…    62         2.98
    #>  5          6.7         3.1          4.4         1.4 versico…    66         2.92
    #>  6          6           2.9          4.5         1.5 versico…    79         3.03
    #>  7          5.7         2.6          3.5         1   versico…    80         2.79
    #>  8          6           2.7          5.1         1.6 versico…    84         3.12
    #>  9          5.8         2.6          4           1.2 versico…    93         2.79
    #> 10          6.2         2.9          4.3         1.3 versico…    98         2.88
    #> # … with 13 more rows
    
    # removed ids
    setdiff(testing(mySplit)$id, getPredictions$id)
    #> [1]   5  28  47  70  90 132
    

    Created on 2019-11-26 by the reprex package (v0.3.0)
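
    If the end goal is to merge the predictions back onto the original data frame, the id makes that join straightforward. The snippet below is my own sketch on top of the reprex above (it assumes temp and getPredictions as defined there): a left_join by id attaches myPrediction, and any row without a prediction (the training rows, plus the test rows that step_naomit() removed) simply gets NA.

    temp %>%
        left_join(getPredictions %>% select(id, myPrediction), by = "id")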