
Bind function arguments in Julia


Does Julia provide something similar to std::bind in C++? I wish to do something along the lines of:

function add(x, y)
  return x + y
end


add45 = bind(add, 4, 5)
add2 = bind(add, _1, 2)
add3 = bind(add, 3, _2)

And if this is possible, does it incur any performance overhead?


Solution

  • As answered here, you can obtain this behavior using higher-order functions in Julia.

    Regarding performance: there should be no overhead. In fact, the compiler can usually inline everything in such a situation and even perform constant propagation, so the code may actually be faster. The const in the other answer here is needed only because that code works in global scope. If all this is used inside a function, const is not required (the function that receives the bound callable as an argument is compiled for its concrete type), so I do not use const in the example below.
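    A minimal sketch of how the three bind calls from the question could be written as anonymous functions (closures). It assumes that bind(add, 3, _2) was meant to fix the first argument to 3 (with std::bind, _2 would actually refer to the second argument of the resulting callable). The helper sum_applied is a hypothetical name used only to show a closure being passed into a function without const.

```julia
# add from the question
add(x, y) = x + y

# Closures standing in for the bind calls in the question
add45 = () -> add(4, 5)  # both arguments fixed; call as add45()
add2 = x -> add(x, 2)    # second argument fixed to 2
add3 = y -> add(3, y)    # first argument fixed to 3 (assumed intent)

# Hypothetical helper: f is called directly in the body, so Julia
# typically specializes sum_applied on the concrete type of f,
# even though add2 itself is a non-const global.
function sum_applied(f, v)
    s = 0
    for x in v
        s += f(x)
    end
    return s
end

println(add45())                # 9
println(add2(1), " ", add3(1))  # 3 4
println(sum_applied(add2, 1:3)) # 12
```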

    Let me give an example with Base.Fix1 and your add function:

    julia> using BenchmarkTools
    
    julia> function add(x, y)
             return x + y
           end
    add (generic function with 1 method)
    
    julia> add2 = Base.Fix1(add, 10)
    (::Base.Fix1{typeof(add), Int64}) (generic function with 1 method)
    
    julia> y = 1:10^6;
    
    julia> @btime add.(10, $y);
      1.187 ms (2 allocations: 7.63 MiB)
    
    julia> @btime $add2.($y);
      1.189 ms (2 allocations: 7.63 MiB)
    
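    Base.Fix1 fixes the first argument of a two-argument function, and its counterpart Base.Fix2 fixes the second one, so together they cover the "one argument fixed" cases from the question (fixing both arguments still needs a zero-argument closure). Because add is commutative, this short sketch uses subtraction to make the difference visible; the names sub_from_10 and minus_10 are purely illustrative.

```julia
# Base.Fix1(f, x) is a callable with Base.Fix1(f, x)(y) == f(x, y)
sub_from_10 = Base.Fix1(-, 10)  # behaves like y -> 10 - y
# Base.Fix2(f, y) is a callable with Base.Fix2(f, y)(x) == f(x, y)
minus_10 = Base.Fix2(-, 10)     # behaves like x -> x - 10

println(sub_from_10(3))  # 7
println(minus_10(3))     # -7

# Both are ordinary callables, so they broadcast like any function
println(sub_from_10.(1:3))  # [9, 8, 7]
```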

    Note that I did not define add2 as const. Since we are in global scope, I therefore need to prefix it with $ to interpolate its value into the benchmarked expression.

    Without the interpolation you would get:

    julia> @btime add2.($y);
      1.187 ms (6 allocations: 7.63 MiB)
    

    The timing and memory use are essentially the same, but there are 6 allocations instead of 2, because in this case add2 is a non-const global variable whose type the compiler cannot rely on.
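    Interpolating with $ is specific to BenchmarkTools. In ordinary code the same issue is usually avoided by declaring the global const (or, from Julia 1.8 on, by giving the global a type annotation). A minimal sketch with illustrative names: both functions return the same values, but running @code_warntype on f_nonconst should show add10 inferred as Any, while f_const is fully inferred.

```julia
add(x, y) = x + y

add10 = Base.Fix1(add, 10)         # non-const global: its type could change later
const cadd10 = Base.Fix1(add, 10)  # const global: its type is fixed

# The compiler cannot assume the type of add10 inside f_nonconst, so the
# call is resolved at run time; cadd10 has a known concrete type.
f_nonconst(v) = add10.(v)
f_const(v) = cadd10.(v)

println(f_nonconst(1:3))  # [11, 12, 13]
println(f_const(1:3))     # [11, 12, 13]
```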

    I work on DataFrames.jl, where the patterns discussed here are very useful. Here is just one example:

    julia> using DataFrames
    
    julia> df = DataFrame(x = 1:5)
    5×1 DataFrame
     Row │ x
         │ Int64
    ─────┼───────
       1 │     1
       2 │     2
       3 │     3
       4 │     4
       5 │     5
    
    julia> filter(:x => <(2.5), df)
    2×1 DataFrame
     Row │ x
         │ Int64
    ─────┼───────
       1 │     1
       2 │     2
    

    This operation picks the rows in which the value of column :x is less than 2.5. The key thing to understand here is what <(2.5) is:

    julia> <(2.5)
    (::Base.Fix2{typeof(<), Float64}) (generic function with 1 method)
    

    So, as you can see, it behaves like the anonymous function x -> x < 2.5: it fixes the second argument of <, which in Julia is just a two-argument function. Julia defines shortcuts like <(2.5) by default for several common comparison operators.
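    The same curried forms work outside DataFrames.jl, for example with Base's own filter on a vector. This short sketch compares <(2.5) with the equivalent anonymous function and also uses two other predefined shortcuts, ==(3) and >=(3); the variable names are illustrative.

```julia
pred = <(2.5)
println(pred isa Base.Fix2)     # true
println(pred(1), " ", pred(3))  # true false

v = 1:5
# Both calls keep the elements that are less than 2.5
println(filter(<(2.5), v))        # [1, 2]
println(filter(x -> x < 2.5, v))  # [1, 2]

# Other curried comparisons defined in Base
println(count(==(3), v))   # 1
println(filter(>=(3), v))  # [3, 4, 5]
```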