Can I avoid calculating same elements when using split() in Halide?


I have a question about the behavior of split() in Halide.

When I use split(), elements at the edge are computed twice if the size of the computed region is not a multiple of the split factor. For example, with a region of size 10 and a split factor of 4, Halide computes the elements [0,1,2,3], [4,5,6,7], and [6,7,8,9], as shown in the trace_stores() output below.

Is there a way to compute only the elements [8,9] in the last iteration of the inner loop produced by split()?

sample code:

#include "Halide.h"
using namespace Halide;

#define INPUT_SIZE 10
int main(int argc, char** argv) {
    Func f("f");
    Var x("x");
    f(x) = x;

    Var xi("xi");
    f.split(x, x, xi, 4); 

    f.trace_stores();
    Image<int32_t> out = f.realize(INPUT_SIZE);
    return 0;
}

trace_stores() results:

Store f.0(0) = 0
Store f.0(1) = 1
Store f.0(2) = 2
Store f.0(3) = 3
Store f.0(4) = 4
Store f.0(5) = 5
Store f.0(6) = 6
Store f.0(7) = 7
Store f.0(6) = 6
Store f.0(7) = 7
Store f.0(8) = 8
Store f.0(9) = 9
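
The trace is consistent with a loop nest whose last tile is shifted back so that it ends exactly at the edge of the region. The following is a minimal plain C++ sketch of that pattern, not Halide's actual generated code; the helper name shifted_split_stores is hypothetical, and it assumes the extent is at least as large as the split factor.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper (not part of Halide): lists the x coordinates stored by
// a split loop nest whose last tile is shifted back inside the region.
// Assumes extent >= factor.
std::vector<int> shifted_split_stores(int extent, int factor) {
    std::vector<int> stores;
    // Number of outer iterations: ceil(extent / factor).
    int outer_extent = (extent + factor - 1) / factor;
    for (int xo = 0; xo < outer_extent; ++xo) {
        // Clamp the tile start so the tile never runs past the region.
        int base = std::min(xo * factor, extent - factor);
        for (int xi = 0; xi < factor; ++xi) {
            stores.push_back(base + xi);
        }
    }
    return stores;
}
```

Calling shifted_split_stores(10, 4) yields 0 through 7 followed by 6, 7, 8, 9, the same order as the stores in the trace.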

Solution

  • It's possible but ugly. Halide generally assumes that it can re-evaluate points in a Func arbitrarily and that inputs don't alias with outputs, so it's always safe to recompute a few values near the edge.

    The fact that this matters is a bad sign. There might be other ways to achieve what you're trying to do.

    Anyway, the workaround is to use explicit RDoms to tell Halide precisely what to iterate over:

    // Leave the pure definition undefined so that only the update stages store values
    f(x) = undef<int>();
    
    // An update stage that does the vectorized part:
    Expr w = (input.width()/4)*4;
    RDom r(0, w);
    f(r) = something;
    f.update(0).vectorize(r, 4);
    
    // An update stage that does the tail end, starting at w (RDom takes min, extent):
    RDom r2(w, input.width() - w);
    f(r2) = something;
    f.update(1); // Don't vectorize the tail end
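
The iteration pattern that the two update stages are meant to describe can be sketched in plain C++ as follows. This is only an illustration of the main-plus-tail decomposition, not Halide code; the helper name main_and_tail_stores is hypothetical.

```cpp
#include <vector>

// Hypothetical helper (not part of Halide): lists the x coordinates visited
// when the region is covered by full tiles plus a scalar tail, so that no
// element is computed twice.
std::vector<int> main_and_tail_stores(int width, int factor) {
    std::vector<int> stores;
    // Largest multiple of factor that does not exceed width.
    int w = (width / factor) * factor;
    // Main part: full tiles of size factor (the part that would be vectorized).
    for (int base = 0; base < w; base += factor) {
        for (int xi = 0; xi < factor; ++xi) {
            stores.push_back(base + xi);
        }
    }
    // Tail: the remaining width - w elements, one at a time.
    for (int x = w; x < width; ++x) {
        stores.push_back(x);
    }
    return stores;
}
```

For a width of 10 and a factor of 4, this visits 0 through 7 in two full tiles and then 8 and 9 in the tail. Newer Halide releases also accept a TailStrategy argument to split(), for example TailStrategy::GuardWithIf, which guards the inner loop body instead of shifting the last tile; whether that is available depends on the Halide version in use.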