Calling macro from within generated function in Julia


I have been experimenting with generated functions in Julia and have run into a strange problem that I do not fully understand. My final goal is to call a macro (more specifically @tullio) from within a generated function, so that I can perform tensor contractions that depend on the input tensors. But I have been having problems, which I narrowed down to calling the macro from within the generated function.

To illustrate the problem, let's consider a very simple example that also fails:

macro my_add(a,b) 
    return :($a + $b)
end

function add_one_expr(x::T) where T
    y = one(T)
    return :( @my_add($x,$y) )
end

@generated function add_one_gen(x::T) where T
    y = one(T)
    return :( @my_add($x,$y) )
end

With these declarations, add_one_expr(2.0) works just as expected and returns the expression

:(@my_add 2.0 1.0)

and eval(add_one_expr(2.0)) correctly evaluates it to 3.0.

However, calling add_one_gen(2.0) throws the following error:

MethodError: no method matching +(::Type{Float64}, ::Float64)

Doing some research, I found that a @generated function effectively runs in two stages, and in the first stage only the types of the arguments are available, not their values. I think this is what is happening here, but I do not fully understand it. It must be some strange interaction between macros and generated functions.

Can someone explain and/or propose a solution? Thank you!


Solution

  • I find it helpful to think of a generated function as having two components: the body, and the generated code it returns (the stuff inside quote ... end). The body runs at compile time, when Julia compiles the function for a particular combination of argument types, so it doesn't "know" the values, only the types. For a generated function taking x::T as an argument, any reference to x in the body therefore points to the type T, not to the value. This can be very confusing. To keep things clear, I recommend that the body refer only to types, never to values.

    Here's a little example:

    julia> @generated function show_val_and_type(x::T) where {T}
               quote
                   println("x is ", x)
                   println("\$x is ", $x)
                   println("T is ", T)
                   println("\$T is ", $T)
               end
           end
    show_val_and_type
    
    julia> show_val_and_type(3)
    x is 3
    $x is Int64
    T is Int64
    $T is Int64
    

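    The interpolated $x means "take the x from the body (which refers to T) and splice it in." That is exactly what happens in the question: $x splices the type Float64 into the returned expression, so the generated code is Float64 + 1.0, which is where the MethodError comes from.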
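    A minimal sketch reproducing the question's failure, with the question's macro and generated function copied unchanged apart from formatting. The try/catch is only there to capture the exception so it can be inspected:

```julia
# The question's macro, left unescaped.
macro my_add(a, b)
    return :($a + $b)
end

# The question's generated function. In the generator body, `x` is bound to the
# argument's type, so `$x` splices `Float64` (not the value 2.0) into the result.
@generated function add_one_gen(x::T) where {T}
    y = one(T)
    return :(@my_add($x, $y))   # generated code for a Float64 argument: Float64 + 1.0
end

# Capture the exception instead of letting it propagate.
err = try
    add_one_gen(2.0)
catch e
    e
end

println(err isa MethodError)   # true: no method matching +(::Type{Float64}, ::Float64)
```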

    If you follow the approach of never referring to values in the body, you can test a generated function by removing the @generated and inspecting the expression the body returns, like this:

    julia> function add_one_gen(x::T) where T
               y = one(T)
               quote
                   @my_add(x,$y)
               end
           end
    add_one_gen
    
    julia> add_one_gen(3)
    quote
        #= REPL[42]:4 =#
        #= REPL[42]:4 =# @my_add x 1
    end
    

    That looks reasonable, but when we put @generated back and call the function, we get

    julia> add_one_gen(3)
    ERROR: UndefVarError: x not defined
    Stacktrace:
     [1] macro expansion
       @ ./REPL[48]:4 [inlined]
     [2] add_one_gen(x::Int64)
       @ Main ./REPL[48]:1
     [3] top-level scope
       @ REPL[49]:1
    

    So let's see what the macro gives us:

    julia> @macroexpand @my_add x 1
    :(Main.x + 1)
    

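    It's pointing to Main.x, which doesn't exist. This is macro hygiene at work: symbols in a macro's output that are not escaped are resolved in the module where the macro is defined (here Main), not in the scope where the macro is called. Wrapping the macro arguments in esc tells the expander to leave them untouched, so x refers to the generated function's argument. So finally, this works: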

    julia> macro my_add(a,b) 
               return :($(esc(a)) + $(esc(b)))
           end
    @my_add
    
    julia> @generated function add_one_gen(x::T) where T
               y = one(T)
               quote
                   @my_add(x,$y)
               end
           end
    add_one_gen
    
    julia> add_one_gen(3)
    4
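
    The UndefVarError above is not specific to generated functions. It is ordinary macro hygiene and shows up in plain functions too. A minimal sketch comparing an unescaped and an escaped version of the macro (the names my_add_unescaped, my_add_escaped, f_unescaped and f_escaped are hypothetical, chosen only for this comparison):

```julia
# Unescaped: `x` in the output is resolved as a global of the macro's module
# (Main.x at the REPL), not as the caller's local variable.
macro my_add_unescaped(a, b)
    return :($a + $b)
end

# Escaped: `esc` makes the caller's expressions pass through unchanged, so `x`
# refers to whatever `x` is in scope where the macro is called.
macro my_add_escaped(a, b)
    return :($(esc(a)) + $(esc(b)))
end

f_unescaped(x) = @my_add_unescaped(x, 1)
f_escaped(x) = @my_add_escaped(x, 1)

println(f_escaped(3))   # 4

# The unescaped version throws UndefVarError, unless some global `x`
# happens to exist in Main.
unescaped_fails = try
    f_unescaped(3)
    false
catch e
    e isa UndefVarError
end
println(unescaped_fails)   # true in a fresh session
```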