Inlining an RVar into a Var stage in Halide


I’m trying to inline an RVar into a regular Var stage.

Let’s say I have an input vector of some size. I would like to multiply each of its elements by 2 and also sum all of its elements.

Something like:

Input<Buffer<int>> vector {"vector", 1};
Output<Buffer<int>> output {"output", 1};
Output<Func> sum_elements {"sum_elements", Int(32), 0};

Var i;
RDom r(0, vector.length());
output(i) = 2 * vector(i);
sum_elements() = 0;
sum_elements() += vector(r.x);

I would like to know whether it's possible to schedule sum_elements to be computed alongside output, even if that is not optimal:

output.compute_root(); // or any other form of schedule

// somehow create a link between r.x and i, and then:

sum_elements.update().compute_with(output, i);

So the equivalent C would be:

for (int i = 0; i < vector.length(); ++i)
{
    output(i) = 2 * vector(i);
    sum_elements() += vector(i);
}
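
For reference, the fused loop above can be written as plain, self-contained C++ without Halide. This is only a sketch of the loop nest the schedule is meant to produce, not Halide's generated code; the function name doubled_and_sum_fused is hypothetical and introduced for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of Halide: one pass over `input` that fills
// `output` with doubled values and returns the sum of the inputs.
int doubled_and_sum_fused(const std::vector<int>& input, std::vector<int>& output) {
    // The caller provides the output storage, as with a Halide output buffer.
    assert(output.size() == input.size());
    int sum_elements = 0;
    for (std::size_t i = 0; i < input.size(); ++i) {
        output[i] = 2 * input[i];   // the pure stage: output(i) = 2 * vector(i)
        sum_elements += input[i];   // the reduction: sum_elements() += vector(r.x)
    }
    return sum_elements;
}
```

Both results come out of a single traversal of the input, which is the behavior the question is trying to obtain from the schedule.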

Solution

  • An alternative approach that might be acceptable is to use the RDom to also define the output:

    Input<Buffer<int>> vector {"vector", 1};
    Output<Buffer<int>> output {"output", 1};
    Output<Func> sum_elements {"sum_elements", Int(32), 0};
    
    Var i;
    RDom r(0, vector.length());
    output(i) = 0;
    output(r.x) = 2 * vector(r.x) + 0 * output(r.x - 1);
    sum_elements() = 0;
    sum_elements() += vector(r.x);
    
    output.compute_root();
    sum_elements.compute_root();
    sum_elements.update().compute_with(output.update(), r.x);
    

    Note that without the no-op 0 * output(r.x - 1), the above would fail with the following error message:

    Invalid compute_with: types of dim 0 of output.s1(r5$x is PureRVar)
    and sum_elements.s1(r5$x is ImpureRVar) do not match.
    

    There is probably a more elegant way to fix this, but the no-op term is effectively removed during compilation. Its only apparent effect is to turn the use of the RVar in the update of output into something impure, so that the dimension types of the two update stages match (my knowledge of the difference between PureRVar and ImpureRVar is limited, unfortunately).
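
To illustrate why the extra term does not change the results, here is a plain C++ sketch (no Halide) of what the two definitions of output compute when they share a loop with the sum. It assumes that output(-1) reads the value 0 from the pure definition output(i) = 0; the function name doubled_with_noop_term is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of Halide. It emulates:
//   output(i)   = 0;
//   output(r.x) = 2 * vector(r.x) + 0 * output(r.x - 1);
//   sum_elements() += vector(r.x);
// with both updates sharing one loop over r.
int doubled_with_noop_term(const std::vector<int>& input, std::vector<int>& output) {
    assert(output.size() == input.size());
    // Pure definition: every element starts at 0.
    for (std::size_t i = 0; i < output.size(); ++i) output[i] = 0;
    int sum_elements = 0;
    for (std::size_t r = 0; r < input.size(); ++r) {
        // Assumption: output(-1) is 0, as given by the pure definition.
        const int previous = (r == 0) ? 0 : output[r - 1];
        output[r] = 2 * input[r] + 0 * previous;  // the no-op term adds nothing
        sum_elements += input[r];
    }
    return sum_elements;
}
```

Whatever the previous element holds, it is multiplied by 0, so each output element is still twice the corresponding input. That is consistent with the remark above that the term is effectively removed during compilation.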