
Is there a way to group rows (especially dummy variables) in the recipes package in R (or mlr3)?


# Packages
library(dplyr)
library(recipes)

# toy dataset, with A being multicolored
df <- tibble(name = c("A", "A", "A", "B", "C"), color = c("green", "yellow", "purple", "green", "blue"))


    #> # A tibble: 5 x 2
    #>   name  color 
    #>   <chr> <chr> 
    #> 1 A     green 
    #> 2 A     yellow
    #> 3 A     purple
    #> 4 B     green 
    #> 5 C     blue

The recipes step works nicely:

dummified_df <- recipe(. ~ ., data = df) %>%
        step_dummy(color, one_hot = TRUE) %>%
        prep(training = df) %>%
        juice()


    #> # A tibble: 5 x 5
    #>   name  color_blue color_green color_purple color_yellow
    #>   <fct>      <dbl>       <dbl>        <dbl>        <dbl>
    #> 1 A              0           1            0            0
    #> 2 A              0           0            0            1
    #> 3 A              0           0            1            0
    #> 4 B              0           1            0            0
    #> 5 C              1           0            0            0

But the result I actually want is the one below, with one observation per row, now that the multicolored item no longer needs several rows.

summarized_dummified_df <- dummified_df %>% 
     group_by(name) %>% 
     summarise_all(~ifelse(max(.) > 0, 1, 0)) %>% 
     ungroup()


    #> # A tibble: 3 x 5
    #>   name  color_blue color_green color_purple color_yellow
    #>   <fct>      <dbl>       <dbl>        <dbl>        <dbl>
    #> 1 A              0           1            1            1
    #> 2 B              0           1            0            0
    #> 3 C              1           0            0            0
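On dplyr 1.0 or later, `summarise_all()` is superseded by `across()`. The sketch below is the same aggregation written with `across()`. It rebuilds `dummified_df` by hand (values copied from the output above) so it runs on its own. It also uses plain `max()`, because the dummy columns are already 0/1, which makes `ifelse(max(.) > 0, 1, 0)` equivalent to `max(.)`.

```r
library(dplyr)

# dummified_df as printed above, rebuilt by hand so this block is self-contained
dummified_df <- tibble(
  name         = factor(c("A", "A", "A", "B", "C")),
  color_blue   = c(0, 0, 0, 0, 1),
  color_green  = c(1, 0, 0, 1, 0),
  color_purple = c(0, 0, 1, 0, 0),
  color_yellow = c(0, 1, 0, 0, 0)
)

# Collapse to one row per name. across() skips the grouping column, and
# max() of a 0/1 column is 1 if any row for that name had the color
summarized_dummified_df <- dummified_df %>%
  group_by(name) %>%
  summarise(across(everything(), max), .groups = "drop")

summarized_dummified_df
```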

Obviously, I can do it this way. But to fully integrate this preprocessing into the tidymodels ecosystem (for instance, with a workflow), it would be much better to collapse the rows directly inside the recipe, since they no longer need to be duplicated once the dummy variables exist.

Is there any tidymodels-sanctioned way to obtain this result?


I also tried to do this with mlr3, to no avail, because I can't find any suitable PipeOp that aggregates rows.

library("mlr3")
library("mlr3pipelines")


task = TaskClassif$new(
  id = "task",
  backend = data.table::data.table(
    name = c("A", "A", "A", "B", "C"),
    color = as.factor(c("green", "yellow", "purple", "green", "blue")),
    price = as.factor(c("low", "low", "low", "high", "low"))
  ),
  target = "price"
)

poe = po("encode")

poe$train(list(task))[[1]]$data()

#>    price name color.blue color.green color.purple color.yellow
#> 1:   low    A          0           1            0            0
#> 2:   low    A          0           0            0            1
#> 3:   low    A          0           0            1            0
#> 4:  high    B          0           1            0            0
#> 5:   low    C          1           0            0            0

I'm looking into writing custom step_ functions or a custom PipeOp, but I still feel I'm missing something, because this type of data doesn't seem that uncommon to me.
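Without a row-aggregating PipeOp, one workaround on the mlr3 side is to collapse the rows before the Task is built. Below is a minimal sketch with data.table. It assumes `price` is constant within each `name` (true in this toy data; otherwise `name + price` would yield several rows per name). The resulting table could then be passed as the `backend` of `TaskClassif$new()`, like the one above. The names `dt`, `value`, `wide` and `color_cols` are illustrative.

```r
library(data.table)

dt <- data.table(
  name  = c("A", "A", "A", "B", "C"),
  color = c("green", "yellow", "purple", "green", "blue"),
  price = c("low", "low", "low", "high", "low")
)

# Marker column whose occurrences dcast() will count
dt[, value := 1L]

# One row per (name, price), one column per color: counts of matching rows,
# with 0 for combinations that never occur
wide <- dcast(dt, name + price ~ color, value.var = "value",
              fun.aggregate = length, fill = 0L)

# Turn counts into 0/1 indicators and add the "color_" prefix used by step_dummy()
color_cols <- setdiff(names(wide), c("name", "price"))
wide[, (color_cols) := lapply(.SD, function(x) as.integer(x > 0)),
     .SDcols = color_cols]
setnames(wide, color_cols, paste0("color_", color_cols))

wide
```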


Solution

  • Dummy or indicator variables are conceptually a one-to-one mapping everywhere I have seen them, not one-to-many, and I think this is why you are running into this problem. Like you, though, I have sometimes wanted a one-to-many mapping for real-world data. I typically handle this in a data-tidying step before starting my model preprocessing workflow, something like this:

    library(tidyverse)
    
    # toy dataset, with A being multicolored
    df <- tibble(name = c("A", "A", "A", "B", "C"), color = c("green", "yellow", "purple", "green", "blue"))
    
    df %>%
      mutate(value = 1) %>%
      pivot_wider(names_from = "color", names_prefix = "color_", values_from = "value", values_fill = 0)
    #> # A tibble: 3 x 5
    #>   name  color_green color_yellow color_purple color_blue
    #>   <chr>       <dbl>        <dbl>        <dbl>      <dbl>
    #> 1 A               1            1            1          0
    #> 2 B               1            0            0          0
    #> 3 C               0            0            0          1
    

    Created on 2020-08-18 by the reprex package (v0.3.0.9001)
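Once the data are collapsed this way, the recipe itself never has to change the number of rows, so it can be used downstream as usual. A minimal sketch, assuming a recent version of recipes (for `bake(new_data = NULL)`) and giving `name` an `"id"` role so it is carried along without being used as a predictor. The names `wide_df`, `rec` and `baked` are illustrative.

```r
library(dplyr)
library(tidyr)
library(recipes)

df <- tibble(name = c("A", "A", "A", "B", "C"),
             color = c("green", "yellow", "purple", "green", "blue"))

# Tidying step done before any preprocessing: one row per name
wide_df <- df %>%
  mutate(value = 1) %>%
  pivot_wider(names_from = "color", names_prefix = "color_",
              values_from = "value", values_fill = 0)

# Recipe built on the already-collapsed data; `name` becomes an identifier
rec <- recipe(~ ., data = wide_df) %>%
  update_role(name, new_role = "id")

# With no steps added, prep() + bake(new_data = NULL) return the training data
baked <- rec %>%
  prep() %>%
  bake(new_data = NULL)

summary(rec)
```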