Why doesn't filter() work properly when there are multiple categories in one column?


I have the following data frame:

library(dplyr)

set.seed(3994)
val <- 1:60
cat <- rep(c("A", "B", "C"), times=c(30, 20, 10))
date <- as.Date(sample(seq(as.Date('2000/01/01'), as.Date('2020/01/01'), by="day"), 60))

df <- data.frame(val, cat, date)
df <- df %>%  
  arrange(cat, val)

I want to remove the top and bottom x percent of the rows in each category, after sorting by value. For instance, the top 2% and bottom 2% of categories "A", "B" and "C". I have the following code:

remove_percentile_by_category <- function(data, measurement_column, category_column, date_column, percent_to_remove) {
  # Sort the data by category and measurement (assuming it's already sorted by category)
  data_sorted <- data %>%
    arrange({{category_column}}, {{measurement_column}})

  # Calculate the number of rows for each category
  summary_data <- data_sorted %>%
    group_by({{category_column}}) %>%
    summarize(n = n())
  print(n)
  # Calculate the number of rows to remove for each category
  summary_data <- summary_data %>%
    mutate(
      n_to_remove = ceiling(n * percent_to_remove / 100)
    )
  print(summary_data)
  # Remove the top and bottom rows for each category
  data_filtered <- data_sorted %>%
    group_by({{category_column}}) %>%
    filter(row_number() > summary_data$n_to_remove & row_number() <= n() - summary_data$n_to_remove) %>%
    ungroup()

  return(data_filtered %>%
           arrange({{category_column}}, {{date_column}}))
}

I know the number of rows to remove is calculated correctly; the problem occurs in the following line:

filter(row_number() > summary_data$n_to_remove & row_number() <= n() - summary_data$n_to_remove)

Question: the line above doesn't apply the correct n_to_remove to each category. For example, it should remove the top and bottom 2 rows (4 rows in total) of category A; instead, it removes only one row from the top and one from the bottom (2 rows in total). What am I doing wrong?

PS: I asked a similar question before. The solution provided there is correct, but it gives me some strange errors because of my data structure, so I gave up on both that solution and the original code from that question. I wrote this new code, which runs fine on my data but doesn't do what it is supposed to do. Any help would be much appreciated.


Solution
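Why the original line misbehaves: `summary_data$n_to_remove` is a plain vector with one element per category, not a value attached to each row. Inside a grouped `filter()`, `row_number()` has one element per row of the current group, so R recycles the shorter vector along those rows (row 1 is compared with the first category's value, row 2 with the second's, row 3 with the first's again, and so on) instead of picking the current category's value. The sketch below reproduces this on a small hypothetical data frame; `toy` and `per_cat` are illustrative names, not part of the question.

```r
library(dplyr, warn.conflicts = FALSE)

# Hypothetical toy data (not the question's df): 6 rows of "A", 4 rows of "B"
toy <- data.frame(cat = rep(c("A", "B"), times = c(6, 4)), val = 1:10)

# One value per category, playing the role of summary_data$n_to_remove:
# remove 2 rows from each end of "A" and 1 row from each end of "B"
per_cat <- c(A = 2, B = 1)

# Buggy pattern: per_cat has length 2, row_number() has the group's length,
# so per_cat is recycled along the rows as 2, 1, 2, 1, ...
buggy <- toy %>%
  group_by(cat) %>%
  filter(row_number() > per_cat & row_number() <= n() - per_cat) %>%
  ungroup()

# Fixed pattern: store the per-category value as a column, so each row is
# compared with the value that belongs to its own category
fixed <- toy %>%
  mutate(n_to_remove = unname(per_cat[cat])) %>%
  group_by(cat) %>%
  filter(row_number() > n_to_remove & row_number() <= n() - n_to_remove) %>%
  ungroup() %>%
  select(-n_to_remove)
```

With these values, "A" should keep rows 3 to 4 and "B" rows 2 to 3, which is what `fixed` returns; `buggy` instead keeps rows 2 to 4 of "A" and only row 2 of "B".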

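  • Here is one approach that makes your code work and avoids creating any additional data frames. The key change is that add_count() and mutate() store n and n_to_remove as columns, so inside the grouped filter() each row is compared with the value for its own category. Note that the .by argument of filter() requires dplyr 1.1.0 or later: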

    library(dplyr, warn.conflicts = FALSE)
    remove_percentile_by_category <- function(data,
                                              measurement_column,
                                              category_column,
                                              date_column,
                                              percent_to_remove) {
      data %>%
        arrange({{ category_column }}, {{ measurement_column }}) %>%
        add_count({{ category_column }}) %>%
        mutate(n_to_remove = ceiling(n * percent_to_remove / 100)) %>%
        filter(row_number() > n_to_remove,
          row_number() <= n() - n_to_remove,
          .by = {{ category_column }}
        ) %>%
        select(-n, -n_to_remove)
    }
    
    count(df, cat)
    #>   cat  n
    #> 1   A 30
    #> 2   B 20
    #> 3   C 10
    df_filtered <- remove_percentile_by_category(df, val, cat, percent_to_remove = 10)
    count(df_filtered, cat)
    #>   cat  n
    #> 1   A 24
    #> 2   B 16
    #> 3   C  8
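If you are on a dplyr version older than 1.1.0, where `filter()` has no `.by` argument, a grouped variant along the following lines should give the same result. It is a sketch, not a drop-in replacement: `trim_percent_by_group` is a hypothetical name, and it computes `ceiling(n() * percent_to_remove / 100)` directly inside the grouped `filter()`, where `n()` is the size of the current group, so no helper columns are needed.

```r
library(dplyr, warn.conflicts = FALSE)

# Hypothetical helper: drop ceiling(n * percent / 100) rows from each end
# of every category after sorting by the measurement within the category
trim_percent_by_group <- function(data, measurement_column, category_column,
                                  percent_to_remove) {
  data %>%
    group_by({{ category_column }}) %>%
    # sort by measurement within each category
    arrange({{ measurement_column }}, .by_group = TRUE) %>%
    # n() is the current group's size, so the cut-off is a single number per group
    filter(
      row_number() > ceiling(n() * percent_to_remove / 100),
      row_number() <= n() - ceiling(n() * percent_to_remove / 100)
    ) %>%
    ungroup()
}

# Toy data with the same category sizes as the question (30, 20, 10 rows)
toy <- data.frame(cat = rep(c("A", "B", "C"), times = c(30, 20, 10)), val = 1:60)
trimmed <- trim_percent_by_group(toy, val, cat, percent_to_remove = 10)
count(trimmed, cat)  # A, B, C should keep 24, 16 and 8 rows
```

Unlike the question's function, neither this sketch nor the answer's version re-sorts by date at the end; add a final arrange() if you need that order.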