
Remove highly correlated variables from multi correlated data


I have a very large data frame with more than 200 variables.
I run cor() to get the correlation matrix and then call caret::findCorrelation(x, cutoff = 0.8) to find the highly correlated variables. Afterwards, I want to remove the highly correlated variables from my data one by one until none remain.
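findCorrelation returns only the columns it suggests removing, so it can help to look at which pairs actually exceed the cutoff. The following is a minimal base R sketch that uses the built-in iris data as a stand-in for the real data frame; `high_pairs` is an illustrative name, not part of caret.

```r
# Stand-in data: the four numeric columns of the built-in iris data set
cm <- cor(iris[, 1:4])
cutoff <- 0.8

# Look only at the upper triangle so each pair is reported once
idx <- which(abs(cm) >= cutoff & upper.tri(cm), arr.ind = TRUE)

# One row per pair whose absolute correlation reaches the cutoff
high_pairs <- data.frame(
  var1 = rownames(cm)[idx[, "row"]],
  var2 = colnames(cm)[idx[, "col"]],
  r    = cm[idx]
)
high_pairs
```

A variable that appears in many rows of `high_pairs` is correlated with many others, which is the situation described next.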

Because some variables are highly correlated with more than 40 other variables, I ran a feature importance analysis (using Boruta() from the Boruta package) to determine the importance of these variables. Based on that importance, I started removing one variable at a time, beginning with the variable with the lowest mean importance (lowest meanImp).

This is the process I tried to code. The first iteration takes the variable with the lowest meanImp and checks whether it is highly correlated (abs(correlation) >= 0.8); if so, it is removed from the data, otherwise it is kept. I then update the correlation matrix by running cor() again and move on to the variable with the second lowest meanImp, and so on until no highly correlated variable is left in the data.

This is the code I used:

remove_highly_correlated <- function(data, confirmed_sorted, cutoff = 0.80) {
  removed_vars <- character(0)
  
  while (TRUE) {
    removed_this_iteration <- FALSE  
    
    for (i in 1:nrow(confirmed_sorted)) {
      var <- as.character(confirmed_sorted$variables[i])
      
      if (var %in% colnames(data)) {
        cor_matrix <- cor(data)  
        hc <- findCorrelation(cor_matrix, cutoff = cutoff)
        
        if (length(hc) > 0 && var %in% rownames(cor_matrix)[hc]) {
          message(paste("Variable", var, "is highly correlated, removing..."))
          removed_vars <- c(removed_vars, var)  
          data <- data[, -which(colnames(data) == var)]  
          removed_this_iteration <- TRUE  
          break 
        }
      }
    }
    
    if (!removed_this_iteration) {
      break  
    }
  }
  
  return(list(data = data, removed_vars = removed_vars))
}

result <- remove_highly_correlated(data5, confirmed_sorted)

data5_filtered <- as.data.frame(result$data)
removed_variables <- as.data.frame(result$removed_vars)

cor_d5 = round(cor(data5_filtered), 4)
hc = findCorrelation(cor_d5, cutoff = 0.80, verbose = TRUE)

where confirmed_sorted is the data frame with the mean importances and data5 is my main data (a short version of both is provided below).

This code only partly worked: after it finished, I tested whether my final data still had highly correlated variables, and 6 variables that are in confirmed_sorted were still reported as highly correlated.
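A check that does not go through findCorrelation can help confirm whether any pair is really still above the cutoff. The sketch below is base R only; `max_abs_cor` is a hypothetical helper name and the toy data are made up for illustration.

```r
# Largest absolute off-diagonal correlation in a numeric data frame
max_abs_cor <- function(df) {
  cm <- abs(cor(df))
  diag(cm) <- 0   # ignore each variable's correlation with itself
  max(cm)
}

# Toy data: y is a monotone function of x, z alternates in sign
toy <- data.frame(x = 1:10, y = (1:10)^2, z = rep(c(1, -1), 5))

max_abs_cor(toy) > 0.8                 # TRUE: x and y are strongly correlated
max_abs_cor(toy[, c("x", "z")]) > 0.8  # FALSE once y is dropped
```

If this value is below the cutoff for the filtered data, no pair exceeds it, whatever a later findCorrelation call on a rounded matrix reports.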

The data provided below are only an example and will not reproduce the leftover highly correlated variables.

blue <- c(0.57, 0.76, 0.78, 0.53, 0.26, 0.27, 0.32, 0.20, 0.63, 0.68, 0.69, 0.69, 0.35, 0.51, 0.39, 0.57, 0.67, 0.63, 0.66, 0.61, 0.54, 0.51, 0.56, 0.59, 0.52, 0.40, 0.39, 0.46, 0.82, 0.84, 0.83, 0.52, 0.59, 0.70, 0.61, 0.83)
red <- c(0.14, 0.11, 0.15, 0.17, 0.18, 0.17, 0.16, 0.07, 0.07, 0.11, 0.12, 0.10, 0.27, 0.19, 0.23, 0.19, 0.10, 0.11, 0.09, 0.10, 0.17, 0.23, 0.23, 0.22, 0.24, 1.00, 0.88, 0.64, 0.11, 0.12, 0.14, 0.56, 0.54, 0.36, 0.53, 0.13)
purple <- c(0.80, 0.84, 0.79, 0.76, 0.75, 0.76, 0.77, 0.59, 0.90, 0.84, 0.83, 0.86, 0.64, 0.73, 0.68, 0.73, 0.85, 0.83, 0.86, 0.85, 0.76, 0.69, 0.69, 0.71, 0.68, 0.00, 0.09, 0.28, 0.84, 0.83, 0.80, 0.35, 0.36, 0.55, 0.38, 0.81)
pink <- c(0.67, 0.73, 0.66, 0.62, 0.63, 0.63, 0.64, 0.49, 0.84, 0.74, 0.73, 0.78, 0.51, 0.58, 0.52, 0.59, 0.74, 0.72, 0.77, 0.75, 0.64, 0.54, 0.55, 0.57, 0.54, 0.00, 0.06, 0.18, 0.74, 0.73, 0.68, 0.25, 0.24, 0.41, 0.26, 0.69)
orange <- c(0.14, 0.11, 0.15, 0.17, 0.18, 0.17, 0.16, 0.07, 0.07, 0.11, 0.12, 0.10, 0.27, 0.19, 0.23, 0.19, 0.10, 0.11, 0.09, 0.10, 0.17, 0.23, 0.23, 0.22, 0.24, 1.00, 0.88, 0.64, 0.11, 0.12, 0.14, 0.56, 0.54, 0.36, 0.53, 0.13)
yellow <- c(0.20, 0.16, 0.21, 0.24, 0.25, 0.24, 0.23, 0.41, 0.10, 0.16, 0.17, 0.14, 0.36, 0.27, 0.32, 0.27, 0.15, 0.17, 0.14, 0.15, 0.24, 0.31, 0.31, 0.29, 0.32, 1.00, 0.91, 0.72, 0.16, 0.17, 0.20, 0.65, 0.64, 0.45, 0.62, 0.19)

data5 <- data.frame(blue, red, purple, pink, orange, yellow)

variables <- c("yellow", "purple", "blue", "green", "pink", "orange", "red")
meanImp <- c(10.07, 9.40, 9.31, 7.51, 7.49, 6.82, 6.65)

confirmed_sorted<- data.frame(variables, meanImp)

Assuming I am not missing anything in the function code, any ideas why I still have highly correlated variables in my data? Any other method for removing highly correlated variables from a data set is also welcome. Thank you.


Solution

  • Break the function into simpler functions. Each of them does one thing only.
    As discussed in the comments to the question, Boruta is first called to determine the importance of the variables still in the data set. Then the highly correlated ones are found and the one with the lowest mean importance is removed. This repeats until no variable is removed.

    The tests are run with data sets taken from the examples in help("Boruta").

    library(caret)
    library(Boruta)
    
    # X - regressors, can be a data.frame
    # Y - response, a 1-dim vector
    runBoruta <- function(X, Y, verbose = FALSE, ...) {
      vars <- names(X)
      Brt <- Boruta(X, Y, ...)
      if(verbose) {
        message("Running 'Boruta' algorithm")
        print(Brt)
      }
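      # The last row of ImpHistory holds the importances from the final run;
      # rejected attributes are -Inf there and are filtered out below.
      # (attStats(Brt)$meanImp would give the mean over all runs instead.)
      i_meanImp <- nrow(Brt$ImpHistory)
      meanImp <- Brt$ImpHistory[i_meanImp, vars]
      meanImp <- meanImp[is.finite(meanImp)] |> sort()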
      meanImp |>
        as.data.frame() |> 
        cbind(variables = names(meanImp))
    }
    remove_one_var <- function(X, confirmed_sorted, cutoff = 0.80) {
      cor_matrix <- cor(X)
      hc <- findCorrelation(cor_matrix, cutoff = cutoff)
      if (length(hc) > 0) {
        i <- confirmed_sorted$meanImp[hc] |> which.min()
        v <- confirmed_sorted$variables[ hc[i] ]
        message(paste("Variable", v, "is highly correlated, removing..."))
        # drop = FALSE keeps X a data frame even if a single column remains
        X <- X[, -hc[i], drop = FALSE]
      } else v <- character(0L)
      return(list(data = X, removed = v))
    }
    remove_highly_correlated <- function(data, resp, cutoff, verbose = FALSE) {
      work <- data[names(data) != resp]
      Y <- data[[resp]]
      rmvd <- character(0L)
      repeat {
        confirmed_sorted <- runBoruta(work, Y, pValue = 0.05, verbose = verbose)
        work <- work[confirmed_sorted$variables]
        result <- remove_one_var(work, confirmed_sorted, cutoff)
        if(length(result$removed) == 0) break
        work <- result$data
        rmvd <- c(rmvd, result$removed)
      }
      list(data = result$data, removed_vars = rmvd)
    }
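
The question also invites other approaches. One possibility, sketched here under stated assumptions rather than as a definitive method, skips findCorrelation and works pairwise: find the most strongly correlated pair, drop the member with the lower importance, and repeat. The function name `drop_by_importance` and the importance values are made up for illustration. The sketch assumes numeric columns with no missing values or constant columns, and a fixed importance vector (importance is not recomputed after each removal, unlike the Boruta loop above).

```r
# Greedy, importance-guided filter (illustrative; not caret's algorithm).
# data: numeric data frame; imp: named numeric vector of importances
drop_by_importance <- function(data, imp, cutoff = 0.8) {
  stopifnot(all(names(data) %in% names(imp)))
  removed <- character(0)
  repeat {
    if (ncol(data) < 2) break
    cm <- abs(cor(data))
    diag(cm) <- 0
    if (max(cm) < cutoff) break   # nothing left at or above the cutoff
    # Most strongly correlated pair still present
    pair <- which(cm == max(cm), arr.ind = TRUE)[1, ]
    vars <- colnames(cm)[pair]
    # Drop whichever member of the pair has the lower importance
    loser <- vars[which.min(imp[vars])]
    removed <- c(removed, loser)
    data <- data[, colnames(data) != loser, drop = FALSE]
  }
  list(data = data, removed_vars = removed)
}

# Made-up importances for the numeric iris columns
imp <- c(Sepal.Length = 3, Sepal.Width = 2, Petal.Length = 5, Petal.Width = 4)
res <- drop_by_importance(iris[, 1:4], imp, cutoff = 0.8)
res$removed_vars
names(res$data)
```

Because the loop stops only when the largest absolute correlation falls below the cutoff, the returned data cannot contain a pair at or above it. The price is that a variable is judged only against its single strongest partner in each pass.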
    

    Tests

    These are the tests I ran.

    1st test

    In the first test with data set srx,

    • Boruta discards 3 features as not important, N1, N2, N3;
    • Then findCorrelation gets one highly correlated feature to remove, nA;
    • and then the function iterates until no more highly correlated features are found.

    data(srx, package = "Boruta")
    srx[] <- lapply(srx, as.integer)
    cutoff <- 0.8
    
    set.seed(2024)
    result1 <- remove_highly_correlated(srx, "Y", cutoff = cutoff, verbose = TRUE)
    #> Running 'Boruta' algorithm
    #> Boruta performed 18 iterations in 0.282912 secs.
    #>  5 attributes confirmed important: A, AnB, AoB, B, nA;
    #>  3 attributes confirmed unimportant: N1, N2, N3;
    #> Variable nA is highly correlated, removing...
    #> Running 'Boruta' algorithm
    #> Boruta performed 7 iterations in 0.0775032 secs.
    #>  4 attributes confirmed important: A, AnB, AoB, B;
    #>  No attributes deemed unimportant.
    

    2nd test

    
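    # Note: sample() runs before set.seed() below, so the Nonsense columns
    # (and possibly the result) differ between runs
    iris.extended <- data.frame(iris, apply(iris[,-5], 2L, sample))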
    names(iris.extended)[6:9] <- paste("Nonsense", 1:4, sep = "")
    
    set.seed(2024)
    result2 <- remove_highly_correlated(iris.extended, "Species", cutoff = cutoff)
    #> Variable Petal.Length is highly correlated, removing...
    #> Variable Petal.Width is highly correlated, removing...
    
    result2$data |> head()
    #>   Nonsense2 Sepal.Width Sepal.Length
    #> 1       3.1         3.5          5.1
    #> 2       4.1         3.0          4.9
    #> 3       3.8         3.2          4.7
    #> 4       3.0         3.1          4.6
    #> 5       3.1         3.6          5.0
    #> 6       2.9         3.9          5.4
    result2$removed_vars
    #> [1] "Petal.Length" "Petal.Width"
    

    Created on 2024-03-21 with reprex v2.1.0