Tags: r, tidyverse, multinomial, tidymodels

broom::tidy fails on multinomial regression


I'm trying to run a multinomial logistic regression in R using tidymodels, but I can't convert the results to a tidy object. Here's a reproducible example using the iris dataset:

library(tidymodels)

# Multinomial  -----------------------------------------------------------------
# recipe
multinom_recipe <-
  recipe(Species ~ Sepal.Length + Sepal.Width + Petal.Width, data = iris) %>%
  step_relevel(Species, ref_level = "setosa")

# model 
multinom_model <-  multinom_reg() %>% 
  set_engine("nnet")

# make workflow
multinom_wf <- 
  workflow() %>% 
  add_model(multinom_model) %>% 
  add_recipe(multinom_recipe) %>% 
  fit(data = iris) %>% 
  tidy()

multinom_wf

The last step throws the following error:

Error in eval(predvars, data, env) : object '..y' not found
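To check where the placeholder ..y in this message comes from, one can stop the pipeline before tidy() and print the call stored inside the underlying nnet fit. This is a minimal sketch that assumes a workflows version providing extract_fit_engine(); the exact call printed depends on the installed parsnip version, so no output is shown here.

```r
library(tidymodels)  # attaches recipes, parsnip, workflows and broom

# Same specification as above, but fitted without the trailing tidy()
multinom_fit <-
  workflow() %>%
  add_model(multinom_reg() %>% set_engine("nnet")) %>%
  add_recipe(
    recipe(Species ~ Sepal.Length + Sepal.Width + Petal.Width, data = iris) %>%
      step_relevel(Species, ref_level = "setosa")
  ) %>%
  fit(data = iris)

# Pull out the underlying nnet::multinom object and look at the call it stored;
# this is the call that gets re-evaluated when the model frame is rebuilt.
engine_fit <- extract_fit_engine(multinom_fit)
print(engine_fit$call)
```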

I thought it was because the output of fit(data = iris) is a workflow object, but the same approach works when I don't use a workflow (and workflows are the whole point of using tidymodels) or when I fit a linear model:

# recipe
linear_recipe <-
  recipe(Sepal.Length ~ Sepal.Width + Petal.Width, data = iris)

# model 
linear_model <-  linear_reg() %>% 
  set_engine("lm")

# make workflow
linear_wf <- 
  workflow() %>% 
  add_model(linear_model) %>% 
  add_recipe(linear_recipe) %>% 
  fit(data = iris) %>% 
  tidy()

linear_wf

Does anyone have an idea what I'm missing, or is this a bug?


Solution

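  • It could be a clash with the call stored in the underlying nnet::multinom fit. The error mentions ..y, which looks like a placeholder name parsnip uses internally for the outcome; when tidy() rebuilds the model frame from that stored call, ..y cannot be found. We could change the call to one that uses the real column names: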

    # overwrite the call stored in the underlying nnet::multinom object
    multinom_wf$fit$fit$fit$call <- quote(nnet::multinom(formula = Species ~ ., data = iris, trace = FALSE))
    multinom_wf %>%
      tidy()
    

    Output:

    # A tibble: 8 x 6
      y.level    term         estimate std.error statistic p.value
      <chr>      <chr>           <dbl>     <dbl>     <dbl>   <dbl>
    1 versicolor (Intercept)      4.17      12.0    0.348  0.728  
    2 versicolor Sepal.Length     1.08      42.0    0.0258 0.979  
    3 versicolor Sepal.Width     -9.13      81.5   -0.112  0.911  
    4 versicolor Petal.Width     20.9       14.0    1.49   0.136  
    5 virginica  (Intercept)    -16.0       12.1   -1.33   0.185  
    6 virginica  Sepal.Length     2.37      42.0    0.0563 0.955  
    7 virginica  Sepal.Width    -13.9       81.5   -0.171  0.864  
    8 virginica  Petal.Width     36.8       14.1    2.61   0.00916
    

    where multinom_wf is the fitted workflow, without the trailing tidy() call:

    multinom_wf <- 
      workflow() %>% 
      add_model(multinom_model) %>% 
      add_recipe(multinom_recipe) %>% 
      fit(data = iris)
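The same workaround can be written without reaching into multinom_wf$fit$fit$fit by hand. This is a sketch, assuming a workflows version that provides extract_fit_engine(); it patches a copy of the engine object rather than the workflow itself, and then calls broom's tidy() on that copy directly.

```r
library(tidymodels)  # attaches recipes, parsnip, workflows and broom

multinom_wf <-
  workflow() %>%
  add_model(multinom_reg() %>% set_engine("nnet")) %>%
  add_recipe(
    recipe(Species ~ Sepal.Length + Sepal.Width + Petal.Width, data = iris) %>%
      step_relevel(Species, ref_level = "setosa")
  ) %>%
  fit(data = iris)

# Copy of the fitted nnet::multinom object held inside the workflow
engine_fit <- extract_fit_engine(multinom_wf)

# Replace the stored call with one that refers to real column names and data,
# mirroring the call used in the answer above
engine_fit$call <- quote(
  nnet::multinom(formula = Species ~ ., data = iris, trace = FALSE)
)

# tidy() dispatches to broom's method for multinom objects
multinom_tidy <- tidy(engine_fit)
print(multinom_tidy)
```

Because R copies on modification, multinom_wf itself keeps its original call; only engine_fit is patched.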