marginaleffects::plot_predictions, returning geom_point when I want geom_line


I have the following code:

#R version 4.2.2

library(dplyr)
library(ggplot2) # needed for theme_bw()
library(marginaleffects) #version 0.12.0

model <- df %>% 
  filter(!is.na(election)) %>% 
  lm(as.numeric(election) ~ as.numeric(cat)*age_cat, data = .)

marginaleffects::plot_predictions(model, condition = c("cat", "age_cat")) + theme_bw()

Which produces a graph something like the one below (I removed a bunch of aesthetic changes from the code above for clarity):

This is essentially a geom_point layered on top of a geom_errorbar, whereas what I want is the equivalent of geom_smooth. All the documentation and tutorials/examples I can find online show plot_predictions producing the equivalent of geom_smooth. Does anyone know where I'm going wrong?

I'd be happy to share the data that goes into this plot, but in its raw form it's 400k observations, so I figure the best way is to share the intermediate output that marginaleffects produces before plotting. If anyone can advise how to produce that, or suggest a better way of sharing data, I'd be very happy to oblige!

EDIT:

I've filtered the data down to two categories of age_cat and only kept two observations per value of cat and age_cat. The below will produce a silly-looking plot, but hopefully it makes my problem minimally reproducible:

df <- structure(list(election = c("0", "1", "1", "0", "1", 
"1", "1", "0", "1", "0", "0", "1", "1", "1", "0", "1", "1", "0", 
"1", "0", "1", "0", "1", "1", "0", "1", "0", "1", "1", "1", "0", 
"1", "1", "1", "0", "0", "1", "1", "0", "0", "1", "0", "0", "1", 
"1", "0", "0", "1", "1", "1", "1", "0", "0", "0", "0", "0", "0", 
"0", "1", "1", "1", "1", "0", "1", "0", "1", "0", "1", "1", "1", 
"0", "1", "0", "0", "1", "0", "0", "1", "1", "0", "1", "1", "0", 
"0", "1", "0", "0", "0", "0", "0", "0", "1", "1", "0", "0", "0", 
"0", "0", "1", "0", "0", "0", "0", "0", "1", "0", "0", "1", "1", 
"0", "0", "1", "0", "0", "0", "0", "0", "1", "0", "0", "1", "0", 
"0", "0"), cat = structure(c(0, 0, 0, 0, 1, 1, 
1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 
6, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 
11, 11, 12, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 
15, 15, 16, 16, 16, 16, 17, 17, 17, 17, 18, 18, 18, 18, 19, 19, 
19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 23, 23, 
23, 23, 24, 24, 24, 24, 25, 25, 25, 25, 26, 26, 26, 26, 27, 27, 
27, 27, 28, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30), class = "difftime", units = "days"), 
    age_cat = structure(c(1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 
    1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 
    1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 
    2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 
    2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 
    1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 
    1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 
    2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 
    2L, 2L, 1L, 1L, 2L, 2L, 1L, 1L, 2L, 2L), levels = c("18 - 24", 
    "25 - 34", "35 - 44", "45 - 54", "55 - 64", "65+"), class = "factor")), class = c("grouped_df", 
"tbl_df", "tbl", "data.frame"), row.names = c(NA, -124L), groups = structure(list(
    cat = structure(c(0, 0, 1, 1, 2, 2, 3, 3, 4, 
    4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 
    13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 
    20, 21, 21, 22, 22, 23, 23, 24, 24, 25, 25, 26, 26, 27, 27, 
    28, 28, 29, 29, 30, 30), class = "difftime", units = "days"), 
    age_cat = structure(c(1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 
    2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 
    1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 
    2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L, 
    1L, 2L, 1L, 2L, 1L, 2L, 1L, 2L), levels = c("18 - 24", "25 - 34", 
    "35 - 44", "45 - 54", "55 - 64", "65+"), class = "factor"), 
    .rows = structure(list(1:2, 3:4, 5:6, 7:8, 9:10, 11:12, 13:14, 
        15:16, 17:18, 19:20, 21:22, 23:24, 25:26, 27:28, 29:30, 
        31:32, 33:34, 35:36, 37:38, 39:40, 41:42, 43:44, 45:46, 
        47:48, 49:50, 51:52, 53:54, 55:56, 57:58, 59:60, 61:62, 
        63:64, 65:66, 67:68, 69:70, 71:72, 73:74, 75:76, 77:78, 
        79:80, 81:82, 83:84, 85:86, 87:88, 89:90, 91:92, 93:94, 
        95:96, 97:98, 99:100, 101:102, 103:104, 105:106, 107:108, 
        109:110, 111:112, 113:114, 115:116, 117:118, 119:120, 
        121:122, 123:124), ptype = integer(0), class = c("vctrs_list_of", 
    "vctrs_vctr", "list"))), row.names = c(NA, -62L), class = c("tbl_df", 
"tbl", "data.frame"), .drop = TRUE))

Solution

  • plot_predictions() uses the following geoms by default:

    • geom_line() when the x-axis variable is recognized as numeric
    • geom_pointrange() when the x-axis is categorical (factor, character) or not recognized as numeric

    In your example dataset, the cat variable (on the x-axis) is of type difftime, which is not recognized as numeric:

    is.numeric(df$cat)
    # [1] FALSE
    
    class(df$cat)
    # [1] "difftime"
    

    There are two solutions in this case:

    1. Convert cat to numeric with as.numeric() in df before fitting the model, so that plot_predictions() sees a numeric x-axis variable.
    2. Call plot_predictions() with the draw = FALSE argument, and feed the output data frame to ggplot() to draw the lines yourself.
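
    The two options above can be sketched roughly as follows. This is a minimal illustration, not the asker's actual pipeline: the `toy` data frame and the `cat_num` column are hypothetical stand-ins for `df`, and for brevity both options are shown on the same model. The column names used in option 2 (`estimate`, `conf.low`, `conf.high`) are what I would expect from recent marginaleffects versions; check `names(pred)` on your installation before relying on them.

```r
library(marginaleffects)
library(ggplot2)

# Hypothetical toy data standing in for the asker's `df`;
# `cat` is deliberately stored as a difftime column.
set.seed(1)
toy <- data.frame(
  election = rbinom(200, 1, 0.5),
  cat      = as.difftime(sample(0:30, 200, replace = TRUE), units = "days"),
  age_cat  = factor(sample(c("18 - 24", "25 - 34"), 200, replace = TRUE))
)

# Option 1: convert the difftime column to a plain numeric column
# BEFORE fitting, then put that numeric column on the x-axis.
toy$cat_num <- as.numeric(toy$cat, units = "days")
model <- lm(election ~ cat_num * age_cat, data = toy)
p1 <- plot_predictions(model, condition = c("cat_num", "age_cat")) +
  theme_bw()

# Option 2: ask for the underlying data instead of a plot,
# then build the line-and-ribbon plot by hand with ggplot2.
pred <- plot_predictions(model, condition = c("cat_num", "age_cat"),
                         draw = FALSE)
p2 <- ggplot(pred, aes(x = cat_num, y = estimate,
                       colour = age_cat, fill = age_cat)) +
  geom_ribbon(aes(ymin = conf.low, ymax = conf.high),
              alpha = 0.2, colour = NA) +
  geom_line() +
  theme_bw()
```

    Printing `p1` or `p2` draws the plot. Option 1 is the less invasive change; option 2 is mainly useful when you want full control over the geoms or when converting the column before fitting is not practical.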

    BTW, I am the marginaleffects maintainer. If you feel that the package behavior should be different for difftime columns, please open a GitHub issue for discussion.