Tags: r, dplyr, vectorization, mutate

Vectorized function in dplyr::mutate and logical operators


I am trying to vectorize a function for use in dplyr::mutate, but for the life of me I can't get it to work. This is what I have been doing:

library(dplyr)  # for %>%, as_tibble() and mutate(); stringr must be installed too

str_to_seq <- Vectorize(function(x) {
  
  # This function converts text format year ranges (e.g. "1970 - 1979") to 
  # numeric ranges. Handily works with single values and edge cases such as 
  # "- 1920".
  
  res <- stringr::str_extract_all(x, "\\d+") %>% 
    unlist() %>% 
    {seq(dplyr::first(.), dplyr::last(.))}
  
  return(res)
  
}, vectorize.args = "x", SIMPLIFY = FALSE)

year <- c(1970, 1980, 1990, 2000, 2010, 2020)
agegroup <- "1950 - 1959"

testt <- expand.grid(agegroup = agegroup, year = year, stringsAsFactors = FALSE)

testt %>% 
  as_tibble() %>% 
  dplyr::mutate(
    yearminus50 = year - 50,
    statement = all(yearminus50 >= str_to_seq(agegroup)))

Computing the statement column fails with this error:

Error in `dplyr::mutate()`:
ℹ In argument: `statement = all(yearminus50 >= str_to_seq(agegroup))`.
Caused by error:
! 'list' object cannot be coerced to type 'double'
Run `rlang::last_trace()` to see where the error occurred.

I can't get str_to_seq to return plain vectors; its output is a list.
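A minimal way to see the shape of the problem, sketched with a base-R stand-in for str_to_seq. The name str_to_seq_base is hypothetical, and it uses regmatches()/gregexpr() instead of stringr so it runs without extra packages. Because Vectorize() is built on mapply(), SIMPLIFY = FALSE guarantees a list with one element per input string, and comparing a numeric vector with that list raises the same coercion error as in the question.

```r
# Hypothetical base-R stand-in for str_to_seq: same Vectorize(..., SIMPLIFY = FALSE)
# wrapper, so it returns the same shape (a list) as the question's function.
str_to_seq_base <- Vectorize(function(x) {
  n <- as.integer(regmatches(x, gregexpr("[0-9]+", x))[[1]])
  seq(n[1], n[length(n)])
}, vectorize.args = "x", SIMPLIFY = FALSE)

out <- str_to_seq_base(c("1950 - 1959", "1950 - 1959"))
class(out)    # a list: mapply() with SIMPLIFY = FALSE never simplifies
lengths(out)  # one 10-year sequence per input string

# Comparing a numeric vector with this list fails, as in the question's error
res <- tryCatch(c(1920, 1930) >= out, error = function(e) conditionMessage(e))
```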

statement should be c(FALSE, FALSE, FALSE, FALSE, TRUE, TRUE), as this brute-force code shows:

all(year[1] - 50 >= unlist(str_to_seq(agegroup)[[1]]))
all(year[2] - 50 >= unlist(str_to_seq(agegroup)[[1]]))
all(year[3] - 50 >= unlist(str_to_seq(agegroup)[[1]]))
all(year[4] - 50 >= unlist(str_to_seq(agegroup)[[1]]))
all(year[5] - 50 >= unlist(str_to_seq(agegroup)[[1]]))
all(year[6] - 50 >= unlist(str_to_seq(agegroup)[[1]]))
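The six checks above can be collapsed into a single vapply() call. This sketch hard-codes the sequence that str_to_seq("1950 - 1959") returns as 1950:1959, an assumption made so the block runs on its own:

```r
year <- c(1970, 1980, 1990, 2000, 2010, 2020)
ages <- 1950:1959  # assumed value of str_to_seq("1950 - 1959")[[1]]

# One all() per year: is year - 50 at or above every year in the age group?
expected <- vapply(year - 50, function(y) all(y >= ages), logical(1))
expected  # FALSE FALSE FALSE FALSE TRUE TRUE
```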

How can I change my code so that the line statement = all(yearminus50 >= str_to_seq(agegroup)) works?

Many thanks.


Solution

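  • The problem is not with your function; it's with the expectation that all(..) will work with a list-column. str_to_seq returns a list (one sequence per row), so yearminus50 >= str_to_seq(agegroup) fails, and even if it didn't, all() would collapse the whole column into a single TRUE/FALSE instead of one value per row. We need to sapply (or similar) over the list that str_to_seq returns.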

    However, if this comparison is "all" you need, we can extract just the maximum year from each agegroup and compare against that, since all(x >= s) is TRUE exactly when x >= max(s):

    testt |>
      mutate(
        yearminus50 = year - 50,
        statement = yearminus50 >=
          sapply(strsplit(agegroup, "[- ]+"), function(z) max(as.integer(z)))
      )
    #      agegroup year yearminus50 statement
    # 1 1950 - 1959 1970        1920     FALSE
    # 2 1950 - 1959 1980        1930     FALSE
    # 3 1950 - 1959 1990        1940     FALSE
    # 4 1950 - 1959 2000        1950     FALSE
    # 5 1950 - 1959 2010        1960      TRUE
    # 6 1950 - 1959 2020        1970      TRUE
    

    (Technically, since all the numbers here have four digits, one could skip as.integer and take the max of the strings, à la yearminus50 >= sapply(strsplit(agegroup, "[- ]+"), max), which returns the same results here: the comparison coerces yearminus50 to character, and four-digit strings sort the same way as the numbers. I prefer to keep my number operations within the number realm, though, for those rare occasions when we're looking at age groups from before 1000 AD ;-)

    But if you need str_to_seq for other purposes anyway, take the max of each element of the list it returns:

    testt |>
      mutate(
        yearminus50 = year - 50,
        statement = yearminus50 >= sapply(str_to_seq(agegroup), max)
      )
    #      agegroup year yearminus50 statement
    # 1 1950 - 1959 1970        1920     FALSE
    # 2 1950 - 1959 1980        1930     FALSE
    # 3 1950 - 1959 1990        1940     FALSE
    # 4 1950 - 1959 2000        1950     FALSE
    # 5 1950 - 1959 2010        1960      TRUE
    # 6 1950 - 1959 2020        1970      TRUE
    

    FWIW, here is a faster variant; in this benchmark it is roughly 9x faster by median time:

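    # Split "1950 - 1959" into c("1950", "1959") and call seq("1950", "1959");
    # assumes every string is a full "start - end" range.
    str_to_seq2 <- function(x) strsplit(x, "[- ]+") |> lapply(function(z) do.call(seq, as.list(z)))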
    bench::mark(yours = str_to_seq(testt$agegroup), mine = str_to_seq2(testt$agegroup), check = FALSE)
    # # A tibble: 2 × 13
    #   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory time               gc                  
    #   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list> <list>             <list>              
    # 1 yours       382.3µs  431.6µs     2271.        NA     10.8  1048     5      461ms <NULL> <NULL> <bench_tm [1,053]> <tibble [1,053 × 3]>
    # 2 mine         43.1µs   49.8µs    19360.        NA     20.4  8561     9      442ms <NULL> <NULL> <bench_tm [8,570]> <tibble [8,570 × 3]>
    

    (This relative performance holds even when testt has 60Ki rows.)
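Two further options, sketched in base R so they run without dplyr or stringr. The helper str_to_seq3 is hypothetical (it is not part of the answer above): like the question's function, it takes the first and last number in each string, so single years ("1950") and open ranges ("- 1920") work, which the strsplit()-based str_to_seq2 does not handle. It assumes every string contains at least one number. The first option keeps all() but applies it once per row with mapply(); the second relies on all(y >= s) being TRUE exactly when y >= max(s).

```r
# Hypothetical helper (not from the answer): the first and last number in each
# string become the ends of an integer sequence, so "1950 - 1959", "1950" and
# "- 1920" all work. Assumes every string contains at least one number.
str_to_seq3 <- function(x) {
  lapply(regmatches(x, gregexpr("[0-9]+", x)), function(n) {
    n <- as.integer(n)
    seq(n[1], n[length(n)])
  })
}

agegroup    <- c("1950 - 1959", "1950", "- 1920")
yearminus50 <- c(1955, 1950, 1910)

# Option 1: keep all(), but evaluate it once per row
opt1 <- mapply(function(y, s) all(y >= s), yearminus50, str_to_seq3(agegroup))

# Option 2: compare against each sequence's maximum instead
opt2 <- yearminus50 >= vapply(str_to_seq3(agegroup), max, numeric(1))

opt1  # FALSE TRUE FALSE
opt2  # FALSE TRUE FALSE
```

Inside the question's pipeline, the first option would read statement = mapply(function(y, s) all(y >= s), yearminus50, str_to_seq(agegroup)) within mutate(); mapply() pairs each row's yearminus50 with that row's sequence, which is the per-row behaviour all() alone cannot give.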