
Define S4 class inheriting from function


I'm trying to write an S4 class that specifically returns a numeric vector of the same length as the input. I think I'm close; the problem I'm having now is that I can only create new classes from functions that live in my GlobalEnv.

library(S4Vectors)

setClass("TransFunc", contains = c("function"), prototype = function(x) x)

TransFunc <- function(x) {
  if (missing(x)) return(new("TransFunc"))
  new2("TransFunc", x)
}

.TransFunc.validity <- function(object) {
  msg <- NULL
  if (length(formals(object)) > 1) {
    msg <- c(msg, "TransFunc must only have one argument.")
  }
  res1 <- object(1:5)
  res2 <- object(1:6)
  if (length(res1) != 5 || length(res2) != 6) {
    msg <- c(msg, "TransFunc output length must equal input length.")
  }
  if (!class(res1) %in% c("numeric", "integer")) {
    msg <- c(msg, "TransFunc output must be numeric for numeric inputs.")
  }
  if (is.null(msg)) return(TRUE)
  msg
}

setValidity2(Class = "TransFunc", method = .TransFunc.validity)

mysqrt <- TransFunc(function(x) sqrt(x))
mysqrt <- TransFunc(sqrt) ## Errors... why??
## Error in initialize(value, ...) : 
##   'initialize' method returned an object of class “function” instead 
##   of the required class “TransFunc”

The benefit of having a class inherit directly from function is that its objects can be used as regular functions:

mysqrt(1:5)
## [1] 1.000000 1.414214 1.732051 2.000000 2.236068 
body(mysqrt) <- expression(sqrt(x)^2)
mysqrt(1:10)
##  [1]  1  2  3  4  5  6  7  8  9 10
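
A minimal sketch of the same idea using only base R's methods package, assuming a fresh session. It is a simplification: there is no S4Vectors and no validity method, and the object name double_it is purely illustrative. Because the class contains "function", an instance can be called directly.

```r
library(methods)

# Same class definition as above, built with base R only
setClass("TransFunc", contains = "function", prototype = function(x) x)

# Passing a closure as the unnamed argument fills the function data part
double_it <- new("TransFunc", function(x) x * 2)

is(double_it, "TransFunc")  # TRUE
is.function(double_it)      # TRUE
double_it(1:3)              # 2 4 6
```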

Why does it error when passing functions outside the global env?


Solution

  • It does not work for sqrt because sqrt is a primitive function, not an ordinary R closure.
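
    A quick base-R check makes the difference visible: sqrt is a primitive of type "builtin", while an R-level wrapper around it, or a function such as mean.default, is an ordinary closure.

    ```r
    # Primitives are implemented in C and have type "builtin" or "special"
    is.primitive(sqrt)                 # TRUE
    typeof(sqrt)                       # "builtin"

    # A wrapper written in R is an ordinary closure
    is.primitive(function(x) sqrt(x))  # FALSE
    typeof(function(x) sqrt(x))        # "closure"

    # mean.default is a closure as well
    typeof(mean.default)               # "closure"
    ```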

    Many common one-argument functions (sqrt, exp, abs, ...) are primitives, so to demonstrate how your code works with non-primitive functions from the preloaded packages, I cut your validity down (dropping the one-argument and output-length checks):

    # using your class definition and constructor
    .TransFunc.validity <- function(object) {
      msg <- NULL
      res1 <- object(1:5)
      if (!class(res1) %in% c("numeric", "integer")) {
        msg <- c(msg, "TransFunc output must be numeric for numeric inputs.")
      }
      if (is.null(msg)) return(TRUE)
      msg
    }

    setValidity2(Class = "TransFunc", method = .TransFunc.validity)
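
    For reference, your original validity rejects mean.default for two reasons, both easy to check in base R: it has more than one formal argument, and it returns a single value rather than one value per input element.

    ```r
    # More than one formal argument, so the one-argument check fails
    names(formals(mean.default))  # "x" "trim" "na.rm" "..."

    # Returns a single value, so the output-length check fails
    length(mean.default(1:5))     # 1
    ```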
    

    Here are the results for mean.default, the default method of mean:

    mymean <- TransFunc(mean.default)
    mymean(1:5)
    [1] 3
    

    Here is a workaround: give your class its own initialize method that catches primitives and wraps them in closures:

    # I removed the prototype from the class definition; the initialize
    # method below supplies the default function instead
    setClass("TransFunc", contains = c("function"))

    TransFunc <- function(x) {
      if (missing(x)) return(new("TransFunc"))
      new2("TransFunc", x)
    }

    # Keeping your validity, I changed initialize to:
    setMethod("initialize", "TransFunc",
      function(.Object, .Data = function(x) x, ...) {
        if (typeof(.Data) %in% c("builtin", "special")) {
          # wrap the primitive in a one-argument closure
          .Object <- callNextMethod(.Object, function(x) .Data(x), ...)
        } else {
          .Object <- callNextMethod(.Object, .Data, ...)
        }
        .Object
      })
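
    To try the same initialize workaround without S4Vectors, here is a self-contained sketch using only the methods package. Assumptions: setValidity() and new() stand in for S4Vectors' setValidity2() and new2(), the validity check is shortened, and the names prim and bad are illustrative only.

    ```r
    library(methods)

    setClass("TransFunc", contains = "function")

    # Shortened validity: output must be numeric and as long as the input
    setValidity("TransFunc", function(object) {
      res <- object(1:5)
      if (length(res) != 5L) return("TransFunc output length must equal input length.")
      if (!is.numeric(res)) return("TransFunc output must be numeric for numeric inputs.")
      TRUE
    })

    setMethod("initialize", "TransFunc",
      function(.Object, .Data = function(x) x, ...) {
        if (is.primitive(.Data)) {
          # Wrap the primitive in a one-argument closure; the separate name
          # keeps the closure from referring to the reassigned .Data
          prim <- .Data
          .Data <- function(x) prim(x)
        }
        callNextMethod(.Object, .Data, ...)
      })

    mysqrt <- new("TransFunc", sqrt)  # sqrt is wrapped before validation
    mysqrt(c(1, 4, 9))                # 1 2 3
    is(mysqrt, "TransFunc")           # TRUE

    # The validity method still rejects functions that collapse their input
    bad <- tryCatch(new("TransFunc", function(x) sum(x)),
                    error = function(e) "invalid")
    bad                               # "invalid"
    ```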
    

    I got the following results

    mysqrt <- TransFunc(sqrt)
    mysqrt(1:5)
    [1] 1.000000 1.414214 1.732051 2.000000 2.236068
    

    EDIT:
    In the comments, @ekoam proposes a more general version of initialize for your class:

    setMethod("initialize", "TransFunc", function(.Object, ...) {
      maybe_transfunc <- callNextMethod()
      if (is.primitive(maybe_transfunc))
        .Object@.Data <- maybe_transfunc
      else
        .Object <- maybe_transfunc
      .Object
    })
    

    EDIT 2:

    The approach given by @ekoam doesn't maintain the new class. For example:

    mysqrt <- TransFunc(sqrt)
    mysqrt
    # An object of class "TransFunc"
    # function (x)  .Primitive("sqrt")
    mysqrt
    # function (x)  .Primitive("sqrt")
    

    The first proposed method DOES work and maintains the new class. As discussed in the comments, another approach is to catch primitives in the constructor rather than in a custom initialize method:

    library(pryr)
    TransFunc <- function(x) {
      if (missing(x)) return(new("TransFunc"))
      if (is.primitive(x)) {
        f <- function(y) x(y)
        # This line isn't strictly necessary: without it, the wrapper's body
        # prints as 'x(y)', which hides the primitive actually being called
        # unless you use something like pryr::unenclose().
        f <- make_function(formals(f), substitute_q(body(f), environment(f)))
      } else {
        f <- x
      }
      new2("TransFunc", f)
    }
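
    If you would rather avoid the pryr dependency, a similar effect can be sketched in base R by building the call with the primitive object itself in the function position. This is an alternative technique, not a description of how make_function() or substitute_q() work; the prototype-based class and the lack of a validity method are simplifying assumptions.

    ```r
    library(methods)

    setClass("TransFunc", contains = "function", prototype = function(x) x)

    TransFunc <- function(x) {
      if (missing(x)) return(new("TransFunc"))
      if (is.primitive(x)) {
        # Build the closure `function(y) <primitive>(y)`, inlining the
        # primitive object into the body instead of a reference to x
        x <- eval(call("function", as.pairlist(alist(y = )),
                       as.call(list(x, quote(y)))))
      }
      new("TransFunc", x)
    }

    mysqrt <- TransFunc(sqrt)
    is(mysqrt, "TransFunc")  # TRUE
    mysqrt(c(1, 4, 9))       # 1 2 3
    body(mysqrt)             # the body holds the primitive itself, not x
    ```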