
Composing two functions in Lua


I just started learning Lua, so what I'm asking might be impossible.

I have a function that takes another function as an argument:

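function adjust_focused_window(fn)
  local win = window.focusedwindow()
  local winframe = win:frame()
  local screenrect = win:screen():frame()
  -- fn receives the window frame and the screen rect and returns the new
  -- frame (plus the screen rect, so that adjusters can be chained)
  local f, s = fn(winframe, screenrect)
  win:setframe(f)
end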

I have several functions that accept these frames and rectangles (showing just one):

function full_height(winframe, screenrect)
  print("called full_height for " .. tostring(winframe))
  local f = {
    x = winframe.x,
    y = screenrect.y,
    w = winframe.w,
    h = screenrect.h,
  }
  return f, screenrect
end
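Because full_height returns the same (frame, screenrect) pair it receives, its output can be fed straight into another adjuster of the same shape. A quick standalone check, with plain tables standing in for the real frame objects (an assumption: the real objects are only assumed to expose x, y, w, h fields, as the code above uses them):

```lua
-- Standalone copy of full_height (print removed) exercised with plain tables.
-- Assumption: real frames expose x, y, w, h fields, as full_height uses them.
local function full_height(winframe, screenrect)
  local f = {
    x = winframe.x,
    y = screenrect.y,
    w = winframe.w,
    h = screenrect.h,
  }
  return f, screenrect   -- same shape as the inputs, so calls can be chained
end

local winframe = { x = 100, y = 200, w = 640, h = 480 }
local screenrect = { x = 0, y = 22, w = 1440, h = 878 }

local f, s = full_height(winframe, screenrect)
print(f.x, f.y, f.w, f.h)  -- prints 100, 22, 640, 878 (tab-separated)
```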

Then, I can do the following:

hotkey.bind(scmdalt, '-', function() adjust_focused_window(full_width) end)

Now, how could I compose several functions and pass the result to adjust_focused_window, without changing its definition? Something like:

hotkey.bind(scmdalt, '=', function() adjust_focused_window(compose(full_width, full_height)) end)

where compose would return a function that accepts the same parameters as full_width and full_height and internally does something like:

full_height(full_width(...))

Solution

  • As mentioned in the comments, to chain two functions together you can just do:

    function compose(f1, f2)
      return function(...) return f1(f2(...)) end
    end
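Note the order: compose(f1, f2) runs f2 first and hands all of its return values to f1, so compose(full_width, full_height) would call full_height first; to get full_height(full_width(...)) as written in the question, the arguments would be compose(full_height, full_width). A small check, where add_one and double are hypothetical stand-ins that, like the window adjusters, take and return two values:

```lua
-- Two-function composition: compose(f1, f2)(...) == f1(f2(...)),
-- so f2 runs first and f1 receives all of f2's return values.
local function compose(f1, f2)
  return function(...) return f1(f2(...)) end
end

-- Hypothetical stand-ins that pass their second argument through unchanged.
local function add_one(a, b) return a + 1, b end
local function double(a, b) return a * 2, b end

local h = compose(add_one, double)   -- add_one(double(...))
print(h(5, "rect"))                  -- prints 11 and rect
```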
    

    But what if you want to chain more than two functions? Is it possible to compose an arbitrary number of functions?

    The answer is a definite yes. Below are three different approaches for implementing this, each with a quick summary of its trade-offs.

    Iterative Table approach

    The idea here is to call each function in the list in turn. The results of each call are saved into a table, which is then unpacked and passed as the arguments to the next call.

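    -- Lua 5.2+ moved unpack into the table library
    local unpack = table.unpack or unpack

    -- Like {...}, but records the count so trailing nils survive
    local function pack(...)
        return {n = select('#', ...), ...}
    end

    function compose1(...)
        local fnchain = check_functions {...}
        return function(...)
            local args = pack(...)
            for _, fn in ipairs(fnchain) do
                args = pack(fn(unpack(args, 1, args.n)))
            end
            return unpack(args, 1, args.n)
        end
    end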
    

    The check_functions helper above simply verifies that every argument is a function and raises an error otherwise; its implementation is omitted for brevity.
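Since the answer leaves check_functions out, here is one hypothetical way it could look; this is an illustrative sketch, not the author's implementation. One caveat of this version: ipairs stops at the first nil, so a nil argument in the middle truncates the list instead of raising an error.

```lua
-- Hypothetical check_functions: returns the table unchanged if every entry
-- is a function, otherwise raises an error. Level 3 blames the caller of
-- composeN, assuming check_functions is called directly from composeN.
local function check_functions(fns)
  for i, fn in ipairs(fns) do
    if type(fn) ~= "function" then
      error(("argument #%d is a %s, expected a function"):format(i, type(fn)), 3)
    end
  end
  return fns
end
```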

    +: Reasonably straightforward. Probably what you'd come up with on a first attempt.

    -: Not very resource-efficient. Every call in the chain allocates a new table to hold its results, which creates a lot of garbage, and you have to pack and unpack the values between calls.
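Packing also has a subtle pitfall: a table constructor such as {fn(...)} can drop trailing nil results, so unpacking it hands fewer values to the next function. A sketch showing the problem and a count-preserving helper (named pack here for illustration; Lua 5.2+ also provides table.pack, which stores the count in the same n field):

```lua
-- A table constructor like {f()} loses trailing nils, so unpack(t)
-- can return fewer values than f did. Storing the count fixes this.
local unpack = table.unpack or unpack   -- Lua 5.2+ moved unpack into table

local function pack(...)
  return { n = select("#", ...), ... }
end

local function returns_two() return 1, nil end

local naive = { returns_two() }
local packed = pack(returns_two())

print(select("#", unpack(naive)))                 -- prints 1
print(select("#", unpack(packed, 1, packed.n)))   -- prints 2
```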

    Y-Combinator Pattern

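    The key insight here is that even though the functions being called aren't recursive, the chain itself can be walked by a recursive helper, recurse, which calls the i-th function and passes its results on to the next step. (Strictly speaking, this is plain recursion through a named local function rather than a true Y combinator.)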

    function compose2(...)
      local fnchain = check_functions {...}
      local function recurse(i, ...)
        if i == #fnchain then return fnchain[i](...) end
        return recurse(i + 1, fnchain[i](...))
      end
      return function(...) return recurse(1, ...) end
    end
    

    +: Doesn't create the temporary tables of the previous approach. The recursive call is carefully written as a proper tail call, so long function chains need no extra stack space. There's a certain elegance to it.
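A self-contained sketch of the recursive version for experimenting: it replaces check_functions with an inline assert and adds an empty-chain guard (returning an identity function), neither of which is in the answer's code. As in the answer, fnchain[1] runs first, so functions are applied left to right.

```lua
-- Self-contained compose2 with an inline type check and an empty-chain guard.
local function compose2(...)
  local fnchain = { ... }
  for i, fn in ipairs(fnchain) do
    assert(type(fn) == "function", "argument #" .. i .. " is not a function")
  end
  if #fnchain == 0 then
    return function(...) return ... end   -- empty chain: identity
  end
  local function recurse(i, ...)
    if i == #fnchain then return fnchain[i](...) end
    return recurse(i + 1, fnchain[i](...))  -- proper tail call: no stack growth
  end
  return function(...) return recurse(1, ...) end
end

-- A chain of 1000 increment functions.
local unpack = table.unpack or unpack
local inc = function(x) return x + 1 end
local fns = {}
for i = 1, 1000 do fns[i] = inc end

local run = compose2(unpack(fns))
print(run(0))  -- prints 1000
```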

    Meta-script generation

    With this last approach, a Lua function generates, at run time, the exact Lua code that performs the desired chain of calls, and then compiles it.

    function compose3(...)
        local luacode = 
        [[
            return function(%s)
                return function(...)
                    return %s
                end
            end
        ]]
        local paramtable = {}
        local fcount = select('#', ...)
        for i = 1, fcount do
            table.insert(paramtable, "P" .. i)
        end
        local paramcode = table.concat(paramtable, ",")
        local callcode = table.concat(paramtable, "(") ..
                         "(...)" .. string.rep(')', fcount - 1)
        luacode = luacode:format(paramcode, callcode)
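        -- loadstring exists in Lua 5.1; from 5.2 on, load accepts a string
        local compile = loadstring or load
        return assert(compile(luacode))()(...)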
    end
    

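    The final line, which compiles luacode and then makes two calls, ()(...), probably needs some explaining. Here I chose to encode the function chain as parameter names (P1, P2, P3, etc.) in the generated script. Compiling the script yields a chunk; the first () runs that chunk, which returns the outer function(P1, P2, ..., Pn), and (...) then calls it with the functions to compose, so the innermost function is what gets returned. The P1, P2, P3 ... Pn parameters become captured upvalues of that innermost function. Note that, like compose above, this applies the functions right to left (Pn runs first), whereas compose1 and compose2 run them left to right. For three functions, the returned function is equivalent to: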

    function(...)
      return P1(P2(P3(...)))
    end
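To see the generated text without compiling anything, the string-building steps of compose3 can be replayed on their own for three functions (fcount is fixed to 3 here purely for illustration):

```lua
-- Replays compose3's string building for three functions.
local fcount = 3
local paramtable = {}
for i = 1, fcount do
  table.insert(paramtable, "P" .. i)
end
local paramcode = table.concat(paramtable, ",")
local callcode = table.concat(paramtable, "(") ..
                 "(...)" .. string.rep(")", fcount - 1)

print(paramcode)  -- prints P1,P2,P3
print(callcode)   -- prints P1(P2(P3(...)))
```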
    

    Note that you could also have done this with setfenv, but I chose this route to avoid the breaking change between Lua 5.1 and 5.2 in how function environments are set (5.2 removed setfenv in favor of the _ENV upvalue).
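For comparison, a sketch of that environment-based alternative, which hides the version split behind a small helper: it uses loadstring plus setfenv on Lua 5.1 and the env argument of load on Lua 5.2+. The names load_with_env and compose3_env are hypothetical, and the functions reach the generated code as globals in a private environment instead of as parameters.

```lua
-- Compiles code with env as its global environment, on Lua 5.1 or 5.2+.
local function load_with_env(code, env)
  if setfenv and loadstring then                    -- Lua 5.1 / LuaJIT
    return setfenv(assert(loadstring(code)), env)   -- setfenv returns the chunk
  end
  return assert(load(code, "=compose", "t", env))   -- Lua 5.2+
end

-- Environment-based variant of compose3: P1..Pn live in env, not in params.
local function compose3_env(...)
  local env, names = {}, {}
  local fcount = select("#", ...)
  for i = 1, fcount do
    names[i] = "P" .. i
    env[names[i]] = (select(i, ...))   -- parentheses keep only the i-th value
  end
  local callcode = table.concat(names, "(") .. "(...)" ..
                   string.rep(")", fcount - 1)
  return load_with_env("return function(...) return " .. callcode .. " end", env)()
end

local composed = compose3_env(function(x) return x + 1 end,
                              function(x) return x * 10 end)
print(composed(4))  -- prints 41 (right to left: (4 * 10) + 1)
```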

    +: Like approach #2, it avoids intermediate tables, and the generated code is an ordinary nested call, so it doesn't abuse the stack.

    -: Needs an extra bytecode compilation step every time compose3 is called.