
Find dynamic intervals per group with Sparklyr

I have a huge (~10 billion rows) data.frame that looks a bit like this:

data <- data.frame(Person = c(rep("John", 9), rep("Steve", 7), rep("Jane", 4)),
Year = c(1900:1908, 1902:1908, 1905:1908),
Grade = c(c(6,3,4,4,8,5,2,9,7), c(4,3,5,5,6,4,7), c(3,7,2,9)) )

It's a set of 3 Persons, observed in different Years, and we have their Grade for each Year. I would like to create a variable which, for each grade, returns "a simplified grade": the Grade cut into intervals. The difficulty is that the intervals are different for each Person. The interval thresholds for each Person are given in the following list:

list.threshold <- list(John = c(5,7), Steve = 4, Jane = c(3,5,8))

So Steve's grades will be cut into 2 intervals, but Jane's into 4. Here is the desired result (SimpleGrade):

    Person  Year  Grade  SimpleGrade
1:   John   1900    6        1
2:   John   1901    3        0
3:   John   1902    4        0
4:   John   1903    4        0
5:   John   1904    8        2
6:   John   1905    5        1
7:   John   1906    2        0
8:   John   1907    9        2
9:   John   1908    7        2
10:  Steve  1902    4        1
11:  Steve  1903    3        0
12:  Steve  1904    5        1
13:  Steve  1905    5        1
14:  Steve  1906    6        1
15:  Steve  1907    4        1
16:  Steve  1908    7        1
17:  Jane   1905    3        1
18:  Jane   1906    7        2
19:  Jane   1907    2        0
20:  Jane   1908    9        3
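
The rule behind this table can be stated compactly: SimpleGrade is the number of that person's thresholds that are less than or equal to the grade, so each threshold acts as the inclusive lower bound of the next interval. A minimal base R sketch (not part of the original question; it recreates data and list.threshold so it runs on its own) uses findInterval(), which returns exactly that count for a sorted threshold vector:

```r
# Recreate the example data and thresholds from the question
data <- data.frame(Person = c(rep("John", 9), rep("Steve", 7), rep("Jane", 4)),
                   Year = c(1900:1908, 1902:1908, 1905:1908),
                   Grade = c(c(6,3,4,4,8,5,2,9,7), c(4,3,5,5,6,4,7), c(3,7,2,9)),
                   stringsAsFactors = FALSE)
list.threshold <- list(John = c(5,7), Steve = 4, Jane = c(3,5,8))

# findInterval(x, v) counts how many elements of the sorted vector v are <= x,
# so a grade equal to a threshold moves up into the next interval
data$SimpleGrade <- mapply(function(p, g) findInterval(g, list.threshold[[p]]),
                           data$Person, data$Grade, USE.NAMES = FALSE)
print(data)
```

This is only a local reference for checking results on the toy data; it does not address the Spark side of the problem.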

I need a solution in sparklyr because I'm working with a huge Spark table.

In dplyr I would do something like this:


library(dplyr)
data <- group_by(data, Person) %>% 
  mutate(SimpleGrade = cut(Grade, breaks = c(-Inf, list.threshold[[as.character(unique(Person))]], Inf), labels = FALSE, right = FALSE, include.lowest = TRUE) - 1) %>%
  ungroup()
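
Whether intervals are closed on the left or the right only matters for grades that sit exactly on a threshold, such as John's 5 and 7. With right = TRUE, cut() closes intervals on the right, so such a grade lands one interval lower than in the table above; right = FALSE gives left-closed intervals that match the table. A small base R check on John's grades (a sketch, independent of dplyr and Spark):

```r
# John's grades and thresholds from the question
grades <- c(6, 3, 4, 4, 8, 5, 2, 9, 7)
breaks <- c(-Inf, 5, 7, Inf)

# Left-closed intervals [a, b): a grade equal to a threshold moves up
left_closed <- cut(grades, breaks = breaks, labels = FALSE, right = FALSE) - 1
# Right-closed intervals (a, b]: a grade equal to a threshold stays down
right_closed <- cut(grades, breaks = breaks, labels = FALSE, right = TRUE) - 1

print(left_closed)   # matches the SimpleGrade column for John
print(right_closed)  # differs only for the grades 5 and 7
```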

It works, but I'm having trouble converting this solution to sparklyr because the thresholds differ per Person. I think I will have to use the ft_bucketizer function. Here is where I am so far with sparklyr:


spark_tbl <- group_by(spark_tbl, Person) %>%
ft_bucketizer(input_col  = "Grade",
            output_col = "SimpleGrade",
            splits     = c(-Inf, list.threshold[["John"]], Inf))

spark_tbl is simply the Spark table equivalent of data. It works as long as I use a single set of thresholds for everyone, for example only John's.

Thanks a lot, Tom C.


  • Spark ML's Bucketizer applies one set of splits to the whole column, so it can only be used for global operations and won't work for you. Instead, you can create a reference table:

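    library(dplyr)
    # sc is an existing sparklyr connection, e.g. sc <- sparklyr::spark_connect(master = "local")
    ref <- purrr::map2(names(list.threshold), list.threshold,
       function(name, brks) purrr::map2(
         c("-Infinity", brks), c(brks, "Infinity"),
         function(low, high) list(
           name = name, 
           low = low,
           high = high))) %>%
       purrr::flatten() %>% 
       bind_rows() %>% 
       group_by(name) %>%
       # low is still a string here, so this sort is lexicographic (fine for single-digit thresholds)
       arrange(low, .by_group = TRUE) %>%
       mutate(simple_grade = row_number() - 1) %>%
       copy_to(sc, .) %>%
       mutate_at(vars(one_of("low", "high")), as.numeric)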
    # Source: spark<?> [?? x 4]
      name    low  high simple_grade
      <chr> <dbl> <dbl>        <dbl>
    1 Jane   -Inf     3            0
    2 Jane      3     5            1
    3 Jane      5     8            2
    4 Jane      8   Inf            3
    5 John   -Inf     5            0
    6 John      5     7            1
    7 John      7   Inf            2
    8 Steve  -Inf     4            0
    9 Steve     4   Inf            1

    and then left_join it with the data table, keeping only the interval that contains each grade:

    sdf <- copy_to(sc, data)
    simplified <- left_join(sdf, ref, by=c("Person" = "name")) %>%
      filter(Grade >= low & Grade < high) %>%
      select(-low, -high)
    # Source: spark<?> [?? x 4]
       Person  Year Grade simple_grade
       <chr>  <int> <dbl>        <dbl>
     1 John    1900     6            1
     2 John    1901     3            0
     3 John    1902     4            0
     4 John    1903     4            0
     5 John    1904     8            2
     6 John    1905     5            1
     7 John    1906     2            0
     8 John    1907     9            2
     9 John    1908     7            2
    10 Steve   1902     4            1
    # … with more rows
    simplified %>% dbplyr::remote_query_plan()
    == Physical Plan ==
    *(2) Project [Person#132, Year#133, Grade#134, simple_grade#15]
    +- *(2) BroadcastHashJoin [Person#132], [name#12], Inner, BuildRight, ((Grade#134 >= low#445) && (Grade#134 < high#446))
       :- *(2) Filter (isnotnull(Grade#134) && isnotnull(Person#132))
       :  +- InMemoryTableScan [Person#132, Year#133, Grade#134], [isnotnull(Grade#134), isnotnull(Person#132)]
       :        +- InMemoryRelation [Person#132, Year#133, Grade#134], StorageLevel(disk, memory, deserialized, 1 replicas)
       :              +- Scan ExistingRDD[Person#132,Year#133,Grade#134]
       +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, string, true]))
          +- *(1) Project [name#12, cast(low#13 as double) AS low#445, cast(high#14 as double) AS high#446, simple_grade#15]
             +- *(1) Filter ((isnotnull(name#12) && isnotnull(cast(high#14 as double))) && isnotnull(cast(low#13 as double)))
                +- InMemoryTableScan [high#14, low#13, name#12, simple_grade#15], [isnotnull(name#12), isnotnull(cast(high#14 as double)), isnotnull(cast(low#13 as double))]
                      +- InMemoryRelation [name#12, low#13, high#14, simple_grade#15], StorageLevel(disk, memory, deserialized, 1 replicas)
                            +- Scan ExistingRDD[name#12,low#13,high#14,simple_grade#15]
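
The reference-table approach is an equi-join on the person followed by a range filter, low <= Grade < high. In the plan above, the small reference table is broadcast (BroadcastExchange feeding a BroadcastHashJoin with BuildRight), so the large table does not have to be shuffled for the join, and the left join appears as an Inner join because the filter drops unmatched rows anyway. The following base R sketch reproduces the same logic on the toy data so the output can be checked against the expected table. It is an illustration, not the answer's code: merge() and subset() stand in for left_join() and filter(), and the bounds are kept numeric (-Inf/Inf) instead of strings.

```r
# Example data and thresholds from the question
data <- data.frame(Person = c(rep("John", 9), rep("Steve", 7), rep("Jane", 4)),
                   Year = c(1900:1908, 1902:1908, 1905:1908),
                   Grade = c(c(6,3,4,4,8,5,2,9,7), c(4,3,5,5,6,4,7), c(3,7,2,9)),
                   stringsAsFactors = FALSE)
list.threshold <- list(John = c(5,7), Steve = 4, Jane = c(3,5,8))

# Reference table: one row per (person, interval [low, high)) with a 0-based index
ref <- do.call(rbind, lapply(names(list.threshold), function(name) {
  brks <- sort(list.threshold[[name]])
  data.frame(name = name,
             low = c(-Inf, brks),
             high = c(brks, Inf),
             simple_grade = seq_along(c(brks, Inf)) - 1,
             stringsAsFactors = FALSE)
}))

# Equi-join on the person, then keep the single interval containing each grade
joined <- merge(data, ref, by.x = "Person", by.y = "name")
simplified <- subset(joined, Grade >= low & Grade < high, select = -c(low, high))

# merge() sorts by the key, so restore the original row order for comparison
simplified <- simplified[order(match(simplified$Person, unique(data$Person)),
                               simplified$Year), ]
print(simplified)
```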