r machine-learning neural-network

Package ‘neuralnet’ in R, rectified linear unit (ReLU) activation function?


I am trying to use activation functions other than the pre-implemented "logistic" and "tanh" in the R package neuralnet. Specifically, I would like to use rectified linear units (ReLU) f(x) = max{x,0}. Please see my code below.

I believe I can use custom functions if defined by (for example)

custom <- function(x) {x*2}

but if I use max(x,0) instead of x*2, R tells me that 'max is not in the derivatives table', and the same happens for the '>' operator. So I am looking for a sensible workaround, since I would have thought numerical differentiation of max wouldn't be an issue in this case.

nn <- neuralnet(
  as.formula(paste("X",paste(names(Z[,2:10]), collapse="+"),sep="~")),
  data=Z[,1:10], hidden=5, err.fct="sse",
  act.fct="logistic", rep=1,
  linear.output=TRUE)

Any ideas? I am a bit confused as I didn't think the neuralnet package would do analytical differentiation.


Solution

  • The internals of the neuralnet package will try to differentiate any function provided to act.fct. You can see this in the package's source code.

    At line 211 you will find the following code block:

    if (is.function(act.fct)) {
        act.deriv.fct <- differentiate(act.fct)
        attr(act.fct, "type") <- "function"
    }
    

    The differentiate function is a more involved wrapper around the deriv function, which you can also see in the source code above. deriv only knows the derivatives of a fixed table of functions, and max is not among them, so it is currently not possible to pass max(0,x) to act.fct. Supporting it would require a special case in the code that recognizes the ReLU and supplies its derivative. It would be a great exercise to get the source code, add this, and submit it to the maintainers (but that may be a bit much).
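
The failure can be reproduced outside the package. The following is a minimal sketch, assuming that base R's D() is a fair stand-in for the deriv call inside differentiate (both rely on the same derivatives table); the names d_linear and d_relu are only for illustration:

```r
# stats::D() and stats::deriv() share one table of differentiable functions.
# An expression built from known operations, such as x * 2, differentiates fine.
d_linear <- D(quote(x * 2), "x")
print(d_linear)

# max() is not in that table, so the call raises an error.
# tryCatch() captures the error message instead of stopping the script.
d_relu <- tryCatch(
  D(quote(max(x, 0)), "x"),
  error = function(e) conditionMessage(e)
)
print(d_relu)
```

The second call should report the same "not in the derivatives table" complaint quoted in the question, which is why any activation function passed to act.fct has to be built from operations that deriv can handle.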

    However, as a sensible workaround, you could use the softplus function, which is a smooth approximation of the ReLU. Your custom function would look like this:

    custom <- function(x) {log(1+exp(x))}
    

    You can view this approximation in R as well:

    softplus <- function(x) log(1+exp(x))
    relu <- function(x) sapply(x, function(z) max(0,z))
    
    x <- seq(from=-5, to=5, by=0.1)
    library(ggplot2)
    library(reshape2)
    
    fits <- data.frame(x=x, softplus = softplus(x), relu = relu(x))
    long <- melt(fits, id.vars="x")
    ggplot(data=long, aes(x=x, y=value, group=variable, colour=variable))+
      geom_line(linewidth=1) +  # on ggplot2 < 3.4.0 use size=1 instead
      ggtitle("ReLU & Softplus") +
      theme(plot.title = element_text(size = 26)) +
      theme(legend.title = element_blank()) +
      theme(legend.text = element_text(size = 18))
    

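
If you would rather check the approximation numerically than by eye, here is a rough sketch in base R. It uses pmax as a vectorised substitute for the sapply-based relu above, and the names gap, max_gap, x_at_max and tail_gap are made up for this example:

```r
softplus <- function(x) log(1 + exp(x))
relu <- function(x) pmax(0, x)  # vectorised max(0, x)

x <- seq(from = -5, to = 5, by = 0.1)

# Pointwise difference between the smooth approximation and the ReLU.
gap <- softplus(x) - relu(x)

max_gap <- max(gap)             # largest difference on the grid
x_at_max <- x[which.max(gap)]   # grid point where it occurs
tail_gap <- gap[length(x)]      # difference at the right edge, x = 5

print(c(max_gap = max_gap, x_at_max = x_at_max, tail_gap = tail_gap))
```

The difference is largest at x = 0, where softplus equals log(2), and it shrinks quickly as x moves away from zero in either direction.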