Looking for a faster way to compute logSumExp across a multidimensional array


I have a line in some R code I am writing that is quite slow. It applies logSumExp across a 4-dimensional array using the apply() function. Are there ways to speed it up?

Reprex (this may take 10 seconds or more to run):

library(microbenchmark)
library(matrixStats)

array4d <- array(runif(5*500*50*5, -1, 0),
                 dim = c(5, 500, 50, 5))
microbenchmark(
    result <- apply(array4d, c(1,2,3), logSumExp)
)

Any advice appreciated!
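For reference, with MARGIN = c(1, 2, 3) the apply() call above computes, for each index triple (i, j, k), the log-sum-exp over the fourth dimension (of length 5 in the reprex):

$$
\text{result}_{ijk} = \log \sum_{l=1}^{5} \exp\left(X_{ijkl}\right)
$$

Evaluated literally, exp() can overflow for large entries or underflow to 0 for very negative ones. The usual remedy, the log-sum-exp trick, factors out the slice maximum $m_{ijk} = \max_l X_{ijkl}$:

$$
\log \sum_{l=1}^{5} \exp\left(X_{ijkl}\right) = m_{ijk} + \log \sum_{l=1}^{5} \exp\left(X_{ijkl} - m_{ijk}\right)
$$

For finite inputs every shifted exponent is at most 0, and the term with $X_{ijkl} = m_{ijk}$ contributes exactly $\exp(0) = 1$, so the inner sum is at least 1 and its log is finite.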


Solution

  • The otherwise great solution from @Miff was causing my code to crash with certain datasets because infinite values were being produced. I eventually traced this to underflow in exp() (very negative inputs become exactly 0, and log(0) is -Inf; large positive inputs can likewise overflow to Inf). Both problems can be avoided with the 'log-sum-exp trick': https://www.xarg.org/2016/06/the-log-sum-exp-trick-in-machine-learning/

    Taking inspiration from @Miff's code and R's apply() function, I wrote a new function that is much faster than the apply() version while avoiding the underflow issue. It is not quite as fast as @Miff's solution (res2 in the benchmark below), however. Posting it in case it helps others.

    apply_logSumExp <- function (X) {
        # requires library(matrixStats) for colMaxs()
        MARGIN <- c(1, 2, 3)  # margins fixed, as other dims have not been tested
        d <- dim(X)           # dimensions of X
        dl <- length(d)       # number of dimensions
        ds <- seq_len(dl)     # dimension indices 1..dl
        d.call <- d[-MARGIN]  # extents of the dims summed over (not in MARGIN)
        d.ans <- d[MARGIN]    # dims of the answer array
        s.call <- ds[-MARGIN] # indices of the summed dims (used by aperm)
        s.ans <- ds[MARGIN]   # indices of the kept dims (used by aperm)
        d2 <- prod(d.ans)     # number of elements in the result

        # Permute X so that the dims summed over come first ...
        newX <- aperm(X, c(s.call, s.ans))
        # ... then reshape to a matrix: one row per summed element,
        # one column per cell of the result
        dim(newX) <- c(prod(d.call), d2)

        # log-sum-exp trick: subtract each column's max before exp()
        maxes <- colMaxs(newX)
        ans <- maxes + log(colSums(exp(sweep(newX, 2, maxes, "-"))))
        ans <- array(ans, d.ans)

        return(ans)
    }
    
    > microbenchmark(
    +     res1 <- apply(array4d, c(1,2,3), logSumExp),
    +     res2 <- log(rowSums(exp(array4d), dims=3)),
    +     res3 <- apply_logSumExp(array4d)
    + )
    Unit: milliseconds
                                              expr        min         lq       mean    median        uq       max neval cld
     res1 <- apply(array4d, c(1, 2, 3), logSumExp) 176.286670 213.882443 247.420334 236.44593 267.81127 486.41072   100   c
      res2 <- log(rowSums(exp(array4d), dims = 3))   4.664907   5.821601   7.588448   5.97765   7.47814  30.58002   100 a
                  res3 <- apply_logSumExp(array4d)  12.119875  14.673011  19.635265  15.20385  18.30471  90.59859   100  b
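A further sketch, not from the original thread and not benchmarked here: because MARGIN = c(1, 2, 3) keeps every dimension except the last, the aperm() step may be unnecessary. In R's column-major layout, reshaping the array to a matrix with prod(d[-n]) rows already puts each (i, j, k) slice in one row, so the max can be subtracted by plain vector recycling instead of sweep(). The function name logSumExp_lastdim is hypothetical, and it uses only base R (pmax() in place of matrixStats::rowMaxs()).

```r
# Hypothetical helper (not from the thread): numerically stable log-sum-exp
# over the LAST dimension of an array, intended to match
# apply(X, seq_len(length(dim(X)) - 1), logSumExp), using base R only.
logSumExp_lastdim <- function(X) {
  d <- dim(X)
  n <- length(d)
  stopifnot(n >= 2)

  # Reshape (no copy of the ordering needed): row r holds one (i, j, k) slice,
  # column l holds X[, , , l], thanks to column-major storage.
  M <- X
  dim(M) <- c(prod(d[-n]), d[n])

  # Row-wise maximum, taken column by column with pmax().
  shift <- M[, 1]
  for (l in seq_len(d[n])[-1]) shift <- pmax(shift, M[, l])

  # A row that is all -Inf (or contains +Inf) would give Inf - Inf = NaN
  # after shifting; use 0 there so the result comes out as -Inf (or +Inf).
  shift[!is.finite(shift)] <- 0

  # M - shift recycles shift down each column, i.e. subtracts shift[r]
  # from every element of row r, so all exponents are <= 0 for finite rows.
  ans <- shift + log(rowSums(exp(M - shift)))
  dim(ans) <- d[-n]
  ans
}
```

With array4d from the question, logSumExp_lastdim(array4d) can be compared against the apply() result with all.equal() (that comparison needs matrixStats loaded for logSumExp). Unlike log(rowSums(exp(X), dims = 3)), it stays finite when every entry of a slice is around -1000.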