The goal is to calculate a weighted average by group over a rolling window of 3 rows, with weights 3, 2, 1 going from the most recent row to the oldest. This is similar to the question here, but the weights are not given by a column. I would also really like to use frollsum(), because I'm working with a lot of data and need it to be performant.
I have a solution using frollapply():
library(data.table)
# Your data
set.seed(1)
DT <- data.table(group = rep(c(1, 2), each = 10), value = round(runif(n = 20, 1, 5)))
weights <- 1:3
k <- 3
weighted_average <- function(x) {
  sum(x * weights[1:length(x)]) / sum(weights[1:length(x)])
}
# Apply rolling weighted average
DT[, wtavg := shift(frollapply(value, k, weighted_average, align = "right", fill = NA)),
   by = group]
DT
#> group value wtavg
#> 1: 1 2 NA
#> 2: 1 2 NA
#> 3: 1 3 NA
#> 4: 1 5 2.500000
#> 5: 1 2 3.833333
#> 6: 1 5 3.166667
#> 7: 1 5 4.000000
#> 8: 1 4 4.500000
#> 9: 1 4 4.500000
#> 10: 1 1 4.166667
#> 11: 2 2 NA
#> 12: 2 2 NA
#> 13: 2 4 NA
#> 14: 2 3 3.000000
#> 15: 2 4 3.166667
#> 16: 2 3 3.666667
#> 17: 2 4 3.333333
#> 18: 2 5 3.666667
#> 19: 2 3 4.333333
#> 20: 2 4 3.833333
Created on 2023-11-27 with reprex v2.0.2
This is probably not the optimal way (I would look into Rcpp), but you can get a significant speed-up simply by calling frollsum() three times:
shift((frollsum(value, 3) + frollsum(value, 2) + frollsum(value, 1)) / 6)
Note that frollsum(value, 1) could be replaced by value.
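The reason this works: at row i, frollsum(value, 3) contributes v[i-2] + v[i-1] + v[i], frollsum(value, 2) contributes v[i-1] + v[i], and value contributes v[i], so the total is 1 * v[i-2] + 2 * v[i-1] + 3 * v[i], which divided by 6 is the weighted average. A minimal sketch that checks this against the frollapply() version on the question's sample data (the column names ref and fast are only illustrative):

```r
library(data.table)

set.seed(1)
DT <- data.table(group = rep(c(1, 2), each = 10),
                 value = round(runif(n = 20, 1, 5)))

weights <- 1:3
weighted_average <- function(x) {
  sum(x * weights[seq_along(x)]) / sum(weights[seq_along(x)])
}

# Reference result: the frollapply() approach from the question
DT[, ref := shift(frollapply(value, 3, weighted_average)), by = group]

# frollsum() approach: windows of width 3, 2 and 1 stack up to weights 1, 2, 3
DT[, fast := shift((frollsum(value, 3) + frollsum(value, 2) + value) / 6),
   by = group]

# Both columns should agree, including the leading NAs in each group
stopifnot(isTRUE(all.equal(DT$ref, DT$fast)))
```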
Another (seemingly) faster and simpler alternative:
(shift(value, 3) + 2 * shift(value, 2) + 3 * shift(value, 1)) / 6
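If the weights might change, the same shift() idea can be written for an arbitrary weight vector. This is only a sketch: roll_wavg_lag is a hypothetical helper name, not part of data.table, and it assumes w[1] is the weight of the most recent preceding row.

```r
library(data.table)

# Hypothetical helper: weighted average of the previous length(w) rows,
# where w[1] weights the most recent previous row, w[2] the one before, etc.
roll_wavg_lag <- function(x, w) {
  lagged <- shift(x, seq_along(w))              # x lagged by 1, 2, ..., length(w)
  if (!is.list(lagged)) lagged <- list(lagged)  # shift() returns a vector for a single lag
  Reduce(`+`, Map(`*`, lagged, w)) / sum(w)
}

# First group from the question's printed output
x <- c(2, 2, 3, 5, 2, 5, 5, 4, 4, 1)
res <- roll_wavg_lag(x, c(3, 2, 1))

# Row 4 averages rows 1 to 3: (1 * 2 + 2 * 2 + 3 * 3) / 6 = 2.5
stopifnot(all(is.na(res[1:3])), isTRUE(all.equal(res[4], 2.5)))
```

Per group, this would be used as DT[, wtavg := roll_wavg_lag(value, c(3, 2, 1)), by = group].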
Benchmarking
set.seed(1)
n = 1000000
groups = 1:1000
DT <- data.table(group = rep(groups, each = n/length(groups)), value = round(runif(n = n, 1, 5)))
bench::mark(
  A = {
    DT[, shift(frollapply(value, k, weighted_average, align = "right", fill = NA)),
       by = group]
  },
  B = {
    DT[, shift((frollsum(value, 3) + frollsum(value, 2) + value) / 6),
       by = group]
  }
)
# expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time
# 1 A 1.75s 1.75s 0.570 46.7MB 34.2 1 60 1.75s
# 2 B 80.38ms 85.72ms 9.24 77.4MB 12.9 5 7 541.13ms