This post HERE discusses how to one-hot encode a single factor variable in R. How can I do the reverse, i.e. recover a single factor from a set of one-hot-encoded indicator variables?
Here's a solution ...
First, one-hot encode `carb`:
# One-hot encode `carb`: make it a factor, then build an intercept-free
# model matrix so that every observed level gets its own 0/1 column.
mtcars[["carb"]] <- as.factor(mtcars[["carb"]])
df <- as.data.frame(model.matrix(object = ~ carb - 1, data = mtcars))
head(df)
#> carb1 carb2 carb3 carb4 carb6 carb8
#> Mazda RX4 0 0 0 1 0 0
#> Mazda RX4 Wag 0 0 0 1 0 0
#> Datsun 710 1 0 0 0 0 0
#> Hornet 4 Drive 1 0 0 0 0 0
#> Hornet Sportabout 0 1 0 0 0 0
#> Valiant 1 0 0 0 0 0
We can then select the one-hot-encoded columns and collapse them back into a single factor:
library(dplyr)
# Take the level labels from the dummy column names ("carb6" -> "6") rather
# than using which.max()'s column position: there are no carb5/carb7 columns,
# so the position alone would relabel carb6/carb8 rows as levels 5/6.
carb_levels <- sub("^carb", "", names(select(df, starts_with("carb"))))
df %>%
  rowwise() %>%
  mutate(remade = carb_levels[which.max(c_across(starts_with("carb")))]) %>%
  ungroup() %>%
  mutate(remade = factor(remade, levels = carb_levels))
#> # A tibble: 32 x 7
#> carb1 carb2 carb3 carb4 carb6 carb8 remade
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 0 0 0 1 0 0 4
#> 2 0 0 0 1 0 0 4
#> 3 1 0 0 0 0 0 1
#> 4 1 0 0 0 0 0 1
#> 5 0 1 0 0 0 0 2
#> 6 1 0 0 0 0 0 1
#> 7 0 0 0 1 0 0 4
#> 8 0 1 0 0 0 0 2
#> 9 0 1 0 0 0 0 2
#> 10 0 0 0 1 0 0 4
#> # … with 22 more rows
Here it is as a function, with an option to keep or drop the one-hot-encoded columns (à la @KM_83):
#' Collapse one-hot (dummy) columns back into a single factor
#'
#' Dummy columns are the columns whose names start with `encoded_prefix`
#' (case-sensitive), excluding a column named exactly `encoded_prefix`.
#' Each level label is the part of the column name after the prefix
#' (e.g. `carb6` -> `"6"`), so gaps in numeric levels and non-numeric
#' levels are recovered correctly.
#'
#' @param df A data frame containing the dummy columns.
#' @param encoded_prefix String prefix shared by the dummy columns; also used
#'   as the name of the decoded factor column.
#' @param keep_dummies If `FALSE` (default), the dummy columns are dropped.
#' @return A tibble with a factor column named `encoded_prefix`, whose levels
#'   are the dummy-column suffixes in column order. For each row the first
#'   column holding the row maximum wins (as with `which.max()`). Rows with
#'   no positive dummy value, or with an `NA` dummy value, decode to `NA`.
cold_encode <- function(df, encoded_prefix, keep_dummies = FALSE) {
  all_names <- names(df)
  # A column named exactly `encoded_prefix` is the output name, not a dummy.
  dummy_cols <- setdiff(
    all_names[startsWith(all_names, encoded_prefix)],
    encoded_prefix
  )
  if (length(dummy_cols) == 0) {
    stop("No dummy columns start with '", encoded_prefix, "'.", call. = FALSE)
  }
  level_labels <- substring(dummy_cols, nchar(encoded_prefix) + 1)

  dummies <- as.matrix(df[dummy_cols])
  # ties.method = "first" mirrors which.max(); rows containing NA give NA.
  hot <- max.col(dummies, ties.method = "first")
  # A row with no positive entry encodes no level; don't invent one.
  hot[rowSums(dummies > 0, na.rm = TRUE) == 0] <- NA_integer_

  decoded <- factor(level_labels[hot], levels = level_labels)
  out <- dplyr::mutate(dplyr::as_tibble(df), !!encoded_prefix := decoded)
  if (!keep_dummies) {
    # Drop exactly the columns that were decoded, no regex guessing.
    out <- dplyr::select(out, -dplyr::all_of(dummy_cols))
  }
  out
}
# `df` holds only the dummy columns, so attach them to the rest of mtcars
# (minus the original `carb`) to reproduce the output shown below.
cold_encode(bind_cols(select(mtcars, -carb), df), "carb")
#> # A tibble: 32 x 11
#> mpg cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 21 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 21 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 24.4 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4
#> # … with 22 more rows