
How to turn one-hot encoded variables into a single factor in R


This post (HERE) discusses how to one-hot encode a single factor variable in R. How can I reverse the process and get a single factor back from a set of one-hot encoded variables?


Solution

  • Here's a solution ...

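    First, one-hot encode carb:

    # Convert carb to a factor so model.matrix() creates one dummy column per level
    mtcars$carb <- factor(mtcars$carb)
    # "- 1" drops the intercept so that every level gets its own column
    df <- as.data.frame(model.matrix(~ carb - 1, mtcars))
    head(df)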
    
    #>                   carb1 carb2 carb3 carb4 carb6 carb8
    #> Mazda RX4             0     0     0     1     0     0
    #> Mazda RX4 Wag         0     0     0     1     0     0
    #> Datsun 710            1     0     0     0     0     0
    #> Hornet 4 Drive        1     0     0     0     0     0
    #> Hornet Sportabout     0     1     0     0     0     0
    #> Valiant               1     0     0     0     0     0
    

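    To reverse the encoding, find the column that holds the 1 in each row. which.max() returns that column's position, not its label (carb6 is the fifth column), so map the position back to the level taken from the column name:

    library(dplyr)
    
    # Level labels are the dummy column names with the "carb" prefix removed
    carb_levels <- sub("^carb", "", grep("^carb", names(df), value = TRUE))
    
    df %>% 
       rowwise() %>% 
       # position of the hot column in this row, translated to its level label
       mutate(remade = carb_levels[which.max(c_across(starts_with("carb")))]) %>%
       ungroup() %>%
       mutate(remade = factor(remade, levels = carb_levels))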
    
    #> # A tibble: 32 x 7
    #>    carb1 carb2 carb3 carb4 carb6 carb8 remade
    #>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct> 
    #>  1     0     0     0     1     0     0 4     
    #>  2     0     0     0     1     0     0 4     
    #>  3     1     0     0     0     0     0 1     
    #>  4     1     0     0     0     0     0 1     
    #>  5     0     1     0     0     0     0 2     
    #>  6     1     0     0     0     0     0 1     
    #>  7     0     0     0     1     0     0 4     
    #>  8     0     1     0     0     0     0 2     
    #>  9     0     1     0     0     0     0 2     
    #> 10     0     0     0     1     0     0 4     
    #> # … with 22 more rows
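
    The same reversal can be sketched in base R without going row by row. This is only an illustration, not part of the original answer: it rebuilds the dummy columns so that it runs on its own, and the names cars, dummies, lvls, hot and remade are made up for the example. The position of the hot column is not necessarily the level label (carb6 is the fifth column), so the sketch maps positions back through the column names.

    ```r
    # Rebuild the dummy columns from the example above (on a copy of mtcars)
    cars <- transform(mtcars, carb = factor(carb))
    dummies <- as.data.frame(model.matrix(~ carb - 1, cars))

    # Level labels are the column names with the "carb" prefix removed
    lvls <- sub("^carb", "", names(dummies))

    # max.col() returns, for each row, the index of the column holding the maximum (the 1)
    hot <- max.col(as.matrix(dummies), ties.method = "first")
    remade <- factor(lvls[hot], levels = lvls)

    # Check the round trip against the original factor
    identical(as.character(remade), as.character(cars$carb))
    ```

    Because max.col() works on the whole matrix at once, this avoids the per-row overhead of rowwise(), which may matter on larger data.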
    

    Here it is as a function, with the option to keep or delete the one-hot encoded columns, à la @KM_83:

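    cold_encode <- function(df, encoded_prefix, keep_dummies = FALSE) {
       # Dummy columns are named <prefix><level>, e.g. carb1, carb2, ..., carb8
       dummies <- names(df)[startsWith(names(df), encoded_prefix)]
       # Strip the prefix to recover the level labels, in column order
       lvls <- substring(dummies, nchar(encoded_prefix) + 1)
       df <- 
          df %>%
          rowwise() %>%
          # which.max() gives the position of the 1; index lvls to get its label
          mutate("{encoded_prefix}" := lvls[which.max(c_across(all_of(dummies)))]) %>%
          ungroup() %>%
          mutate("{encoded_prefix}" := factor(.data[[encoded_prefix]], levels = lvls)) 
       if (!keep_dummies) {
          # Drop exactly the dummy columns that were decoded
          df <- df %>% select(-all_of(dummies))
       }
       return(df)
    }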
    
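    # Assumption: the output below was produced from the dummies bound back
    # onto the remaining mtcars columns, not from df alone
    df_full <- cbind(select(mtcars, -carb), df)
    cold_encode(df_full, "carb")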
    #> # A tibble: 32 x 11
    #>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear carb 
    #>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
    #>  1  21       6  160    110  3.9   2.62  16.5     0     1     4 4    
    #>  2  21       6  160    110  3.9   2.88  17.0     0     1     4 4    
    #>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4 1    
    #>  4  21.4     6  258    110  3.08  3.22  19.4     1     0     3 1    
    #>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3 2    
    #>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3 1    
    #>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3 4    
    #>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4 2    
    #>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4 2    
    #> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4 4    
    #> # … with 22 more rows
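
    One caveat worth illustrating: which.max() returns 1 for a row of all zeros, so a row with no hot column is silently assigned the first level. The sketch below is a hypothetical base R helper (decode_one_hot is not from the original answer, and the toy data is invented for the example) that returns NA for any row that does not contain exactly one 1.

    ```r
    # Hypothetical helper: decode dummy columns, returning NA for rows
    # that are not a valid one-hot row instead of guessing a level
    decode_one_hot <- function(dummies, prefix) {
       m <- as.matrix(dummies)
       # Level labels are the column names with the prefix removed
       lvls <- substring(colnames(m), nchar(prefix) + 1)
       # Index of the hot column in each row
       hot <- max.col(m, ties.method = "first")
       # Rows with zero or several hot columns cannot be decoded
       hot[rowSums(m == 1) != 1] <- NA
       factor(lvls[hot], levels = lvls)
    }

    # Toy data: the third row has no hot column
    toy <- data.frame(carb1 = c(1, 0, 0), carb2 = c(0, 1, 0), carb4 = c(0, 0, 0))
    decoded <- decode_one_hot(toy, "carb")
    decoded
    ```

    Whether NA is the right answer depends on how the dummies were built. If a reference level was dropped during encoding, an all-zero row stands for that level rather than for a missing value.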