Tags: r, statistics, classification, data-analysis, tidymodels

How to obtain the coefficients of a parsnip multinomial logistic regression model?


I fit a multinomial logistic regression model to predict species in the iris dataset using the tidymodels framework.

library(tidymodels)

iris.lr = multinom_reg(
  mode="classification",
  penalty=NULL,
  mixture=NULL
) %>%
  set_engine("glmnet")

iris.fit = iris.lr %>%
  fit(Species ~. , data = iris)

I would then like to look at the coefficients of my model and write out the formula. My understanding is that I should get this from iris.fit.

Printing iris.fit shows a 100-row table with the columns Df, %Dev, and Lambda, but the iris dataset has only 4 predictors. How do I translate this output into coefficients?


Solution

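  • With penalty = NULL, glmnet fits the model along a whole path of lambda (penalty) values, and each row of the printed Df / %Dev / Lambda table is one of those values, not a predictor. You can get the coefficients for every lambda on the path in a single data frame using the tidy() function. As the output below shows, each row is one coefficient for one class at one lambda, the intercept appears with an empty term name, and coefficients that were shrunk to zero are not listed.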

    library(tidymodels)
    #> ── Attaching packages ────────────────────────────────────────── tidymodels 0.1.0 ──
    #> ✓ broom     0.5.6      ✓ recipes   0.1.12
    #> ✓ dials     0.0.6      ✓ rsample   0.0.6 
    #> ✓ dplyr     0.8.5      ✓ tibble    3.0.1 
    #> ✓ ggplot2   3.3.0      ✓ tune      0.1.0 
    #> ✓ infer     0.5.1      ✓ workflows 0.1.1 
    #> ✓ parsnip   0.1.1      ✓ yardstick 0.0.6 
    #> ✓ purrr     0.3.4
    #> ── Conflicts ───────────────────────────────────────────── tidymodels_conflicts() ──
    #> x purrr::discard()  masks scales::discard()
    #> x dplyr::filter()   masks stats::filter()
    #> x dplyr::lag()      masks stats::lag()
    #> x ggplot2::margin() masks dials::margin()
    #> x recipes::step()   masks stats::step()
    
    iris_lr <- multinom_reg(
      mode = "classification",
      penalty = NULL,
      mixture = NULL
    ) %>%
      set_engine("glmnet")
    
    iris_fit <- iris_lr %>%
      fit(Species ~ . , data = iris)
    
    library(broom)
    
    tidy(iris_fit)
    #> # A tibble: 839 x 6
    #>    class      term            step  estimate lambda dev.ratio
    #>    <chr>      <chr>          <dbl>     <dbl>  <dbl>     <dbl>
    #>  1 setosa     ""                 1  6.41e-16  0.435 -1.21e-15
    #>  2 versicolor ""                 1 -1.62e-15  0.435 -1.21e-15
    #>  3 virginica  ""                 1  9.81e-16  0.435 -1.21e-15
    #>  4 setosa     ""                 2  2.44e- 1  0.396  6.56e- 2
    #>  5 setosa     "Petal.Length"     2 -9.85e- 2  0.396  6.56e- 2
    #>  6 versicolor ""                 2 -1.22e- 1  0.396  6.56e- 2
    #>  7 virginica  ""                 2 -1.22e- 1  0.396  6.56e- 2
    #>  8 setosa     ""                 3  4.62e- 1  0.361  1.20e- 1
    #>  9 setosa     "Petal.Length"     3 -1.89e- 1  0.361  1.20e- 1
    #> 10 versicolor ""                 3 -2.31e- 1  0.361  1.20e- 1
    #> # … with 829 more rows
    

    Created on 2020-05-14 by the reprex package (v0.3.0)
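
To write out a formula you usually want the coefficients at one lambda rather than the whole path. The sketch below is one way to do that, not necessarily what the answer's author had in mind: it calls glmnet directly on the iris data and uses coef() with the s argument. The value s = 0.01 is an arbitrary illustrative penalty, not a tuned one, and the object names (path_fit, coefs_at_lambda, coef_matrix) are made up for this example.

```r
library(glmnet)

# Fit a multinomial model over glmnet's default path of lambda values,
# which is the same kind of fit the parsnip "glmnet" engine produces.
x <- as.matrix(iris[, 1:4])
y <- iris$Species
path_fit <- glmnet(x, y, family = "multinomial")

# coef() with s = <lambda> returns a list with one sparse column
# matrix per class. s = 0.01 is an arbitrary value for illustration.
coefs_at_lambda <- coef(path_fit, s = 0.01)

# Bind the per-class columns into one 5 x 3 matrix:
# rows are the intercept plus the 4 predictors, columns are the classes.
coef_matrix <- do.call(cbind, lapply(coefs_at_lambda, as.matrix))
colnames(coef_matrix) <- names(coefs_at_lambda)
coef_matrix
```

With the parsnip fit from the answer, the same call should work on the underlying glmnet object, which parsnip stores in iris_fit$fit, for example coef(iris_fit$fit, s = 0.01). Each column of the matrix defines a linear score for one class (intercept plus the sum of coefficient times predictor), and the predicted probability of a class is exp(score for that class) divided by the sum of exp(score) over all three classes.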