
how to plot the correct labels in glmnet?


Consider this example

library(dplyr)
library(tibble)
library(glmnet)
library(quanteda)

dtrain <- tibble(text = c("Chinese Beijing Chinese",
                          "Chinese Chinese Shanghai",
                          "this is china",
                          "china is here",
                          "hello china",
                          "Chinese Beijing Chinese",
                          "Chinese Chinese Shanghai",
                          "this is china",
                          "china is here",
                          "hello china",
                          "Kyoto Japan",
                          "Tokyo Japan Chinese",
                          "Kyoto Japan",
                          "Tokyo Japan Chinese",
                          "Kyoto Japan",
                          "Tokyo Japan Chinese",
                          "Kyoto Japan",
                          "Tokyo Japan Chinese",
                          "japan"),
                 class = c(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0))

I use quanteda to get a document-term matrix from this data frame:

dtm <- quanteda::dfm(quanteda::tokens(dtrain$text))
> dtm
Document-feature matrix of: 19 documents, 11 features (78.5% sparse).
19 x 11 sparse Matrix of class "dfm"
        features
docs     chinese beijing shanghai this is china here hello kyoto japan tokyo
  text1        2       1        0    0  0     0    0     0     0     0     0
  text2        2       0        1    0  0     0    0     0     0     0     0
  text3        0       0        0    1  1     1    0     0     0     0     0
  text4        0       0        0    0  1     1    1     0     0     0     0
  text5        0       0        0    0  0     1    0     1     0     0     0

I can fit a lasso regression using glmnet easily:

fit <- glmnet(dtm, y = as.factor(dtrain$class), alpha = 1, family = 'binomial')

However, plot(fit) does not label the curves with the feature names of the dtm matrix, and I only see three curves. What is wrong here?



Solution

  • As far as I understand it, the plot shows the coefficient path of every word, and only the words the lasso actually selects (nonzero coefficients) move away from zero. In your case these are words 9-11, which according to the dtm table are kyoto, japan and tokyo, hence the three curves. The default plot method of glmnet cannot put word names on the curves: with label = TRUE it only prints the variable index numbers. Instead, you can use plot_glmnet() from the plotmo package, which labels the curves with the column names of the matrix:

    library(dplyr)
    library(tibble)
    library(glmnet)
    library(quanteda)
    library(plotmo)
    dtrain <- tibble(text = c("Chinese Beijing Chinese",
                              "Chinese Chinese Shanghai",
                              "this is china",
                              "china is here",
                              "hello china",
                              "Chinese Beijing Chinese",
                              "Chinese Chinese Shanghai",
                              "this is china",
                              "china is here",
                              "hello china",
                              "Kyoto Japan",
                              "Tokyo Japan Chinese",
                              "Kyoto Japan",
                              "Tokyo Japan Chinese",
                              "Kyoto Japan",
                              "Tokyo Japan Chinese",
                              "Kyoto Japan",
                              "Tokyo Japan Chinese",
                              "japan"),
                     class = c(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0))
    
    
    dtm <- quanteda::dfm(quanteda::tokens(dtrain$text))
    fit <- glmnet(dtm, y = as.factor(dtrain$class), alpha = 1, family = 'binomial')
    plot_glmnet(fit, label=3)            # label the 3 biggest final coefs
    

I hope this is what you were asking for. Cheers!
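If you would rather stay with base glmnet graphics, here is a rough sketch that does the labelling by hand. It assumes recent glmnet and quanteda versions (tokens() is called before dfm()), and the variable names (texts, y_class, cf, nonzero, beta, l1, final, keep) are only illustrative. The idea: fit$beta has one row per word, so its row names are the words, and the x-axis of plot(fit, xvar = "norm") is the L1 norm of each column of fit$beta, so text() can place each name at the right end of its curve. The sketch also lists the words with a nonzero coefficient at the smallest lambda, which should account for the few visible curves.

```r
# Sketch (not the only way): label glmnet coefficient paths with word names
# using base graphics. Assumes recent glmnet and quanteda versions.
library(glmnet)
library(quanteda)

# Same toy data as in the question, written compactly
texts <- c(rep(c("Chinese Beijing Chinese", "Chinese Chinese Shanghai",
                 "this is china", "china is here", "hello china"), 2),
           rep(c("Kyoto Japan", "Tokyo Japan Chinese"), 4),
           "japan")
y_class <- c(rep(1, 11), rep(0, 8))

dtm <- dfm(tokens(texts))
# Plain sparse matrix; the column names (the words) are kept
x <- as(dtm, "dgCMatrix")
fit <- glmnet(x, y = as.factor(y_class), alpha = 1, family = "binomial")

# Words with a nonzero coefficient at the smallest lambda on the path
cf <- as.matrix(coef(fit, s = min(fit$lambda)))
nonzero <- setdiff(rownames(cf)[cf[, 1] != 0], "(Intercept)")
print(nonzero)

# fit$beta: one row per word, one column per lambda value
beta <- as.matrix(fit$beta)
l1 <- colSums(abs(beta))      # x-axis used by plot(fit, xvar = "norm")
final <- beta[, ncol(beta)]   # coefficients at the end of the path
keep <- final != 0            # only label curves that end away from zero

op <- par(mar = c(5, 4, 4, 6))  # widen the right margin for the labels
plot(fit, xvar = "norm", label = FALSE)
text(x = max(l1), y = final[keep], labels = rownames(beta)[keep],
     pos = 4, cex = 0.8, xpd = TRUE)
par(op)
```

Labelling only the curves that end away from zero keeps the plot readable; plot_glmnet(fit, label = 3) from plotmo does something similar automatically.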