Search code examples
Tags: r, ggplot2, animation, gganimate

Can you have a rolling window filter in gganimate?


I would like each frame of a scatter plot to be filtered by another vector using a certain bin width, and have the animation roll through those bins. For example, I can do this with:

library(ggplot2)
library(gganimate)

# Example data: the built-in iris data set.
iris <- datasets::iris

# Base scatter plot: petal length against petal width.
g <- ggplot(iris) +
  geom_point(aes(x = Petal.Width, y = Petal.Length))

# Show only the points whose Sepal.Length lies in a window of width 2,
# sliding the window up in steps of 0.5 (one filter per animation state).
g +
  transition_filter(
    transition_length = 1,
    filter_length = 1,
    4 < Sepal.Length & Sepal.Length < 6,
    4.5 < Sepal.Length & Sepal.Length < 6.5,
    5 < Sepal.Length & Sepal.Length < 7,
    5.5 < Sepal.Length & Sepal.Length < 7.5,
    6 < Sepal.Length & Sepal.Length < 8
  )

However, writing out each filter condition is tedious. I would also like to filter a different dataset with a bin width of ~20, stepping through in increments of 1 over a 300-point range, so writing 100+ filters is not practical.

Is there another way to do this?

[image]


Solution

  • A while ago I wanted this exact function but couldn't find anything in gganimate that did it, so I wrote something that would get the job done. Below is what I came up with; I ended up rebuilding gganimate with this function included to avoid using `:::`.

    I wrote this a while ago so I don't recall the exact intention of each argument at the moment of writing it (ALWAYS REMEMBER TO DOCUMENT YOUR CODE).

    Here is what I recall

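    • span : an expression evaluated within each layer's data; its value is the time at which each row appears
    • size : how much data to show at once, as a fraction of the range of span (each row stays visible until span + diff(range(span)) * size)
    • enter_length/exit_length : how long rows take to enter/exit, in the same units as span; I don't recall exactly how they interact with each other or with size/span
    • range : an optional time range (same class as span) to restrict the animation to; by default the full range of the data is used
    • retain_data_order : logical - if TRUE, rows keep their original data order; if FALSE, they are ordered by their span value first (I don't remember why I needed this - sorry!)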
    library(gganimate)
    #> Loading required package: ggplot2
    library(rlang)
    library(tweenr)
    library(stringi)
    
    # gganimate does not export these helpers, so fetch them straight from its
    # namespace (the same lookup that `gganimate:::name` performs).
    get_row_event <-
      get("get_row_event", envir = asNamespace("gganimate"), inherits = FALSE)
    is_placeholder <-
      get("is_placeholder", envir = asNamespace("gganimate"), inherits = FALSE)
    recast_event_times <-
      get("recast_event_times", envir = asNamespace("gganimate"), inherits = FALSE)
    recast_times <-
      get("recast_times", envir = asNamespace("gganimate"), inherits = FALSE)
    
    # Build the expression for the time at which a row leaves the window: its
    # `span` value plus `size` times the full range of `span`. Shared by both
    # setup passes so they always agree on the end event.
    .span_end_expr <- function(span_quo, size_quo) {
      expr(!!span_quo + diff(range(!!span_quo)) * !!size_quo)
    }

    # Rolling-window transition built on gganimate's TransitionEvents: each row
    # starts at its `span` value and ends `size * diff(range(span))` later.
    TransitionSpan <- ggplot2::ggproto("TransitionSpan", TransitionEvents,
      # Split each layer's tweened data into one data frame per frame, using
      # the "<frame>" tag that expand_panel() appends to `group`. Layers
      # without the tag are returned unchanged as a one-element list.
      finish_data = function(self, data, params) {
        lapply(data, function(d) {
          split_panel <- stri_match(d$group, regex = "^(.+)<(.*)>(.*)$")
          if (is.na(split_panel[1])) {
            return(list(d))
          }
          d$group <- match(d$group, unique(d$group))
          empty_d <- d[0, , drop = FALSE]
          # Column 3 is the text between "<" and ">", i.e. the frame number.
          d <- split(d, as.integer(split_panel[, 3]))
          frames <- rep(list(empty_d), params$nframes)
          frames[as.integer(names(d))] <- d
          frames
        })
      },

      # First pass: evaluate the start, end, enter and exit events of every
      # row and build a per-row id from them.
      setup_params = function(self, data, params) {
        params$start <- get_row_event(data, params$span_quo, "start")
        time_class <- if (is_placeholder(params$start)) NULL else params$start$class
        end_expr <- .span_end_expr(params$span_quo, params$size_quo)
        params$end <- get_row_event(data, end_expr, "end", time_class)
        params$enter_length <- get_row_event(
          data, params$enter_length_quo, "enter_length", time_class
        )
        params$exit_length <- get_row_event(
          data, params$exit_length_quo, "exit_length", time_class
        )
        # Flag whether any event is still a placeholder; setup_params2()
        # re-evaluates placeholders with `after = TRUE`.
        params$require_stat <- is_placeholder(params$start) ||
          is_placeholder(params$end) ||
          is_placeholder(params$enter_length) ||
          is_placeholder(params$exit_length)
        static <- lengths(params$start$values) == 0
        params$row_id <- Map(
          function(st, end, en, ex, s) {
            if (s) character(0) else paste(st, end, en, ex, sep = "_")
          },
          st = params$start$values,
          end = params$end$values,
          en = params$enter_length$values,
          ex = params$exit_length$values,
          s = static
        )
        params
      },

      # Second pass: resolve placeholder events, convert event times to frame
      # numbers, and record the window limits shown in each frame.
      setup_params2 = function(self, data, params, row_vars) {
        # transition_span() stores only `span_quo`/`size_quo` (there is no
        # `start_quo`/`end_quo`), so placeholders are re-evaluated from the
        # same expressions setup_params() used.
        if (is_placeholder(params$start)) {
          params$start <- get_row_event(data, params$span_quo, "start", after = TRUE)
        } else {
          params$start$values <- lapply(row_vars$start, as.numeric)
        }
        time_class <- params$start$class
        if (is_placeholder(params$end)) {
          end_expr <- .span_end_expr(params$span_quo, params$size_quo)
          params$end <- get_row_event(data, end_expr, "end", time_class, after = TRUE)
        } else {
          params$end$values <- lapply(row_vars$end, as.numeric)
        }
        if (is_placeholder(params$enter_length)) {
          params$enter_length <- get_row_event(
            data, params$enter_length_quo, "enter_length", time_class, after = TRUE
          )
        } else {
          params$enter_length$values <- lapply(row_vars$enter_length, as.numeric)
        }
        if (is_placeholder(params$exit_length)) {
          params$exit_length <- get_row_event(
            data, params$exit_length_quo, "exit_length", time_class, after = TRUE
          )
        } else {
          params$exit_length$values <- lapply(row_vars$exit_length, as.numeric)
        }
        times <- recast_event_times(
          params$start, params$end, params$enter_length, params$exit_length
        )
        # Window width in time units: `size` times the range of start times.
        params$span_size <- diff(times$start$range) * eval_tidy(params$size_quo)

        # Time range covered by the animation: the user-supplied `range`, or
        # from the earliest entry to the latest exit over all rows.
        time_range <- if (is.null(params$range)) {
          low <- min(unlist(Map(
            function(start, enter) {
              start - (if (length(enter) == 0) 0 else enter)
            },
            start = times$start$values,
            enter = times$enter_length$values
          )))
          high <- max(unlist(Map(
            function(start, end, exit) {
              (if (length(end) == 0) start else end) +
                (if (length(exit) == 0) 0 else exit)
            },
            start = times$start$values,
            end = times$end$values,
            exit = times$exit_length$values
          )))
          c(low, high)
        } else {
          if (!inherits(params$range, time_class)) {
            stop("range must be given in the same class as time", call. = FALSE)
          }
          as.numeric(params$range)
        }
        full_length <- diff(time_range)
        nframes <- params$nframes
        frame_time <- recast_times(
          seq(time_range[1], time_range[2], length.out = nframes),
          time_class
        )
        frame_length <- full_length / nframes

        # Number of frames the window covers, clamped to [1, nframes]. Outside
        # that range the limit indexing below would yield NA entries (0) or
        # vectors of the wrong length (>= nframes).
        rep_frame <- round(params$span_size / frame_length)
        rep_frame <- min(max(rep_frame, 1), nframes)
        frame_index <- seq_len(nframes)
        # Lower limit trails the current frame by (rep_frame - 1) frames and is
        # held at the first frame early on; the upper limit follows the current
        # frame and is held at frame (nframes - rep_frame + 1) at the end.
        lowerl <- frame_time[pmax(1, frame_index - rep_frame + 1)]
        upperl <- frame_time[pmin(frame_index, nframes - rep_frame + 1)]

        # Event times -> frame numbers; enter/exit lengths -> frame counts.
        start <- lapply(times$start$values, function(x) {
          round((nframes - 1) * (x - time_range[1]) / full_length) + 1
        })
        end <- lapply(times$end$values, function(x) {
          if (length(x) == 0) return(numeric())
          round((nframes - 1) * (x - time_range[1]) / full_length) + 1
        })
        enter_length <- lapply(times$enter_length$values, function(x) {
          if (length(x) == 0) return(numeric())
          round(x / frame_length)
        })
        exit_length <- lapply(times$exit_length$values, function(x) {
          if (length(x) == 0) return(numeric())
          round(x / frame_length)
        })

        params$range <- time_range
        params$frame_time <- frame_time
        static <- lengths(start) == 0
        params$row_id <- Map(
          function(st, end, en, ex, s) {
            if (s) character(0) else paste(st, end, en, ex, sep = "_")
          },
          st = start, end = end, en = enter_length, ex = exit_length, s = static
        )
        params$lowerl <- lowerl
        params$upperl <- upperl
        params$frame_span <- upperl - lowerl
        params$frame_info <- data.frame(
          frame_time = frame_time,
          lowerl = lowerl,
          upperl = upperl,
          frame_span = upperl - lowerl
        )
        params$nframes <- nrow(params$frame_info)
        params
      },

      # Tween each row's events over all frames with tweenr::tween_events() and
      # tag every output row's `group` with "<frame>" for finish_data().
      expand_panel = function(self, data, type, id, match, ease, enter, exit,
                              params, layer_index) {
        row_vars <- self$get_row_vars(data)
        if (is.null(row_vars)) {
          return(data)
        }
        data$group <- paste0(row_vars$before, row_vars$after)
        start <- as.numeric(row_vars$start)
        # An NA first value means the event was not supplied: pass NULL on.
        end <- as.numeric(row_vars$end)
        if (is.na(end[1])) end <- NULL
        enter_length <- as.numeric(row_vars$enter_length)
        if (is.na(enter_length[1])) enter_length <- NULL
        exit_length <- as.numeric(row_vars$exit_length)
        if (is.na(exit_length[1])) exit_length <- NULL
        # Carry the start value through tweening so rows can be ordered by it.
        data$.start <- start
        all_frames <- tween_events(
          data, c(ease, "linear"), params$nframes, !!start, !!end,
          c(1, params$nframes), enter, exit, !!enter_length, !!exit_length
        )
        row_order <- if (params$retain_data_order) {
          order(as.numeric(all_frames$.id))
        } else {
          order(all_frames$.start, as.numeric(all_frames$.id))
        }
        all_frames <- all_frames[row_order, ]
        all_frames$group <- paste0(all_frames$group, "<", all_frames$.frame, ">")
        all_frames$.frame <- NULL
        all_frames$.start <- NULL
        all_frames
      }
    )
    
    #' Animate a rolling window over a continuous variable
    #'
    #' Each row starts at its `span` value and ends at
    #' `span + diff(range(span)) * size`, so a window of relative width `size`
    #' sweeps across the range of `span`.
    #'
    #' @param span Unquoted expression evaluated in each layer's data; gives the
    #'   time at which each row appears.
    #' @param size Width of the visible window as a fraction of the range of
    #'   `span`.
    #' @param enter_length,exit_length Optional unquoted expressions giving how
    #'   long each row takes to enter and exit, in the units of `span`.
    #' @param range Optional length-2 vector restricting the animated time range;
    #'   must have the same class as `span`. Defaults to the full data range.
    #' @param retain_data_order If `TRUE`, rows are drawn in their original data
    #'   order; if `FALSE`, they are ordered by their `span` value first.
    #' @return A `TransitionSpan` ggproto object to add to a ggplot.
    transition_span <- function(span, size = 0.5, enter_length = NULL,
                                exit_length = NULL, range = NULL,
                                retain_data_order = TRUE) {
      # expand_panel() uses this flag in a scalar `if`, so reject anything that
      # is not a single TRUE/FALSE up front.
      if (!(isTRUE(retain_data_order) || isFALSE(retain_data_order))) {
        stop("`retain_data_order` must be TRUE or FALSE", call. = FALSE)
      }
      span_quo <- enquo(span)
      size_quo <- enquo(size)
      enter_length_quo <- enquo(enter_length)
      exit_length_quo <- enquo(exit_length)
      gganimate:::require_quo(span_quo, "span")
      ggproto(NULL, TransitionSpan,
        params = list(
          span_quo = span_quo,
          size_quo = size_quo,
          range = range,
          enter_length_quo = enter_length_quo,
          exit_length_quo = exit_length_quo,
          retain_data_order = retain_data_order
        )
      )
    }
    # Colour points by the variable the window slides over; each row stays
    # visible until Sepal.Length + 10% of the range of Sepal.Length.
    # ggplot2's built-in viridis scale avoids depending on the viridis package.
    g <- ggplot(iris) +
      geom_point(aes(x = Petal.Width, y = Petal.Length, color = Sepal.Length)) +
      scale_color_viridis_c()
    a <- g +
      transition_span(Sepal.Length, size = 0.1, enter_length = 1, exit_length = 1)
    animate(a, renderer = gganimate::gifski_renderer())
    

    Created on 2021-08-11 by the reprex package (v2.0.0)