Search code examples
c++halide

How to make Halide use the sliding window optimization?


I'm learning Halide and I'm struggling with the scheduling part. I'm trying to make Halide do the same thing as a hand-coded implementation of the algorithm. I don't want to parallelize it, only vectorize it, but first I want to understand how to make Halide do a simple sliding window. I tried different things:

  • splitting the sum in x and y and then different scheduling
  • defining the sum in recursive form (snippet at the end)

But I can't get it to produce anything similar. It boils down to a simple variation of a mean filter. So, how do I schedule Halide code to actually do a sliding window like the original code?

This is the code (the hand-coded CPU version first, then my Halide generator):

// Hand-coded reference: for every pixel in rows yMin..yMax and columns
// xMin..xMax, writes image(y, x) * area - S into lmr, where S is the sum of a
// lmrWidth x lmrWidth window of image and area = lmrWidth * lmrWidth.
//
// The window sum is built separably with running sums:
//   * vertSum[x] holds the sum of lmrWidth consecutive rows of column x and is
//     slid down one row at a time;
//   * sum holds the sum of lmrWidth consecutive vertSum entries and is slid
//     along the row.
// Near the borders the window stops sliding, so edge pixels reuse the nearest
// fully-inside window instead of a window centred on themselves.
//
// image   - read through ptr<uint8_t> / at<uint8_t>.
// vertSum - scratch buffer; resized to image.cols and overwritten.
// lmr     - written through at<int>; only the xMin..xMax / yMin..yMax region is touched.
// lmrSize - window half-width; lmrWidth = 2*lmrSize + 1.
//
// NOTE(review): no bounds are checked here. The code appears to assume
// yMin + lmrWidth <= image.rows, xMin + lmrWidth <= image.cols and
// xMax - xMin >= 2*lmrSize - confirm against callers.
void lmrCPU(const cv::Mat& image, std::vector<uint16_t>& vertSum, int xMin, int xMax, int yMin, int yMax, int lmrSize, cv::Mat& lmr ) {
    int lmrWidth = 2*lmrSize + 1;
    int area = lmrWidth*lmrWidth;

    vertSum.resize(image.cols);

    // Zero contents of vertSum
    memset(vertSum.data(), 0, vertSum.size()*sizeof(uint16_t));

    // Prime the column sums with the first lmrWidth rows (yMin .. yMin + 2*lmrSize),
    // i.e. the vertical window centred on row yMin + lmrSize.
    for (int yy = yMin; yy < yMin + lmrWidth; ++yy)
    {
        for (int x=0; x<image.cols; x++) vertSum[x] += unsigned(image.ptr<uint8_t>(yy)[x]);
    }

    for (int y = yMin; y <= yMax; y++)
    {
        // Slide the vertical window down by one row: add the row entering at the
        // bottom, drop the row leaving at the top. Skipped for the first
        // lmrSize + 1 rows (the primed window is reused) and for the last lmrSize
        // rows (the last window is reused).
        if (y > yMin + lmrSize && y <= yMax - lmrSize)
        {
            for (int x = 0; x < image.cols; x++)
                vertSum[x] += unsigned(image.at<uint8_t>(y + lmrSize, x)) - unsigned(image.at<uint8_t>(y - lmrSize - 1, x));
        }

        unsigned sum = 0;

        // xx is the column entering the horizontal window, xxBack the column leaving it.
        int xx, x, xxBack;
        // Horizontal window over columns xMin .. xMin + 2*lmrSize.
        for (xx = xMin; xx < xMin + lmrWidth; xx++) sum += vertSum[xx];

        // Left border: the first lmrSize columns all use that first window.
        for (x = xMin; x < xMin + lmrSize; x++)
            lmr.at<int>(y, x) = int(image.at<uint8_t>(y, x)) * area - sum;

        // The sliding loop adds the entering column before writing, so remove the
        // right-most column here; the first iteration adds it back.
        sum -= vertSum[xMin + 2*lmrSize]; // take off ready for next loop

        // Interior: add entering column, write the result, then drop the leaving column.
        for (x = xMin + lmrSize, xxBack = x - lmrSize, xx = x + lmrSize; x < xMax - lmrSize; x++, xx++, xxBack++)
        {
            sum += vertSum[xx];
            lmr.at<int>(y, x) = int(image.at<uint8_t>(y,x))*area - sum;
            sum -= vertSum[xxBack];
        }

        // Complete the window for the column where the sliding loop stopped; x and
        // xx keep their values from that loop.
        sum += vertSum[xx];

        // Right border: the remaining columns up to xMax all reuse this last window.
        for ( ; x <= xMax; x++)
            lmr.at<int>(y, x) = int(image.at<uint8_t>(y,x))*area - sum;
    }
}
#include "Halide.h"

namespace {
    // Halide generator that expresses the LMR filter with recursive (running-sum)
    // update definitions: lmr = input * area - (box sum of input), where the box
    // sum is built from a vertical running sum (vertSum) followed by a
    // horizontal running sum (sumLmr).
    class LMR : public Halide::Generator<LMR> {
    public:
        // 2-D, 8-bit unsigned input image.
        ImageParam input{UInt(8), 2, "input"};

        // Runtime parameters: the processed region and the window half-width.
        Param<int32_t> xMin{"xMin"}, xMax{"xMax"};
        Param<int32_t> yMin{"yMin"}, yMax{"yMax"};
        Param<int32_t> lmrSize{"lmrSize"};

        Var x{"x"}, y{"y"};

        // Builds the pipeline and its schedule; returns the output Func `lmr`.
        Func build() {
            // Both are Halide expressions because lmrSize is a runtime Param.
            auto lmrWidth = 2*lmrSize + 1;
            auto area = lmrWidth*lmrWidth;

            // Widened views of the input used by the arithmetic below.
            Halide::Func input_int32 ("input_int32");
            input_int32(x, y) = Halide::cast<int32_t>(input(x, y));

            Halide::Func input_uint16 ("input_uint16");
            input_uint16(x, y) = Halide::cast<uint16_t>(input(x, y));

            // Coordinates clamped to the processed region (used to read the input).
            Halide::Expr clamped_x = Halide::clamp(x, xMin, xMax);
            Halide::Expr clamped_y = Halide::clamp(y, yMin, yMax);

            // Coordinates clamped to the positions whose whole window lies inside
            // the region (used to read the box sum).
            Halide::Expr lmr_x = Halide::clamp(x, xMin+lmrSize, xMax-lmrSize);
            Halide::Expr lmr_y = Halide::clamp(y, yMin+lmrSize, yMax-lmrSize);

            // Window offsets: starts at -lmrSize and spans lmrWidth values.
            Halide::RDom box (-lmrSize, lmrWidth, "box");

            // Vertical running sum. The pure definition is undef; only the rows
            // written by the update definitions below carry values.
            Halide::Func vertSum ("vertSum");
            vertSum(x, y) = Halide::undef<uint16_t>();
            {
                // Rows after the first fully-inside row: starts at yMin+lmrSize+1
                // and spans yMax-yMin-2*lmrSize rows.
                Halide::RDom ry (yMin+lmrSize+1, yMax-yMin-2*lmrSize, "ry");
                // Seed row yMin+lmrSize: zero it, then accumulate the full window.
                vertSum(x, yMin+lmrSize) = Halide::cast<uint16_t>(0);//Halide::sum(input_uint16(x, yMin+lmrSize+box), "sum_y");
                // NOTE(review): this line reads `input` directly while the next
                // one goes through input_uint16 - confirm this is intended.
                vertSum(x, yMin+lmrSize) += input(x, yMin+lmrSize+box);
                // Each later row = previous row + entering input row - leaving input row.
                vertSum(x, ry) = vertSum(x, ry-1) + input_uint16(x, ry+lmrSize) - input_uint16(x, ry-1-lmrSize);
            }

            // Horizontal running sum over vertSum, built the same way along x.
            // NOTE(review): declared as uint16_t, whereas lmrCPU accumulates the
            // 2-D sum in `unsigned`; 255 * lmrWidth^2 exceeds 65535 once
            // lmrWidth >= 17 - confirm the intended range of lmrSize.
            Halide::Func sumLmr ("sumLmr");
            sumLmr(x, y) = Halide::undef<uint16_t>();
            {
                // Columns after the first fully-inside column.
                Halide::RDom rx (xMin+lmrSize+1, xMax-xMin-2*lmrSize, "rx");
                // Seed column xMin+lmrSize: zero it, then accumulate the full window.
                sumLmr(xMin+lmrSize, y) = Halide::cast<uint16_t>(0);//Halide::sum(vertSum(xMin+lmrSize+box, y), "sum_x");
                sumLmr(xMin+lmrSize, y) += vertSum(xMin+lmrSize+box, y);
                // Each later column = previous column + entering vertSum - leaving vertSum.
                sumLmr(rx, y) = sumLmr(rx-1, y) + vertSum(rx+lmrSize, y) - vertSum(rx-1-lmrSize, y);
            }
            // Output: centre pixel times window area minus the box sum, in int32.
            Halide::Func lmr ("lmr");
            lmr(x, y) = input_int32(clamped_x, clamped_y)*area - Halide::cast<int32_t>(sumLmr(lmr_x, lmr_y));

            // Schedule request for vertSum: storage at root level, computation
            // inside lmr's y loop, storage folded along y with factor 1.
            vertSum
                .fold_storage(y, 1)
                .store_root()
                .compute_at(lmr, y);

            // Schedule request for sumLmr: storage inside lmr's y loop,
            // computation inside lmr's x loop, storage folded along x with factor 1.
            sumLmr
                .fold_storage(x, 1)
                .store_at(lmr, y)
                .compute_at(lmr, x);

            return lmr;
        }
    };

    HALIDE_REGISTER_GENERATOR(LMR, "lmr")
}

This is the output of lmr.print_loop_nest();

store vertSum:
  produce lmr:
    for y:
      produce vertSum:
        for y:
          for x:
            vertSum(...) = ...
        for x:
          vertSum(...) = ...
        for x:
          for box:
            vertSum(...) = ...
        for x:
          for ry:
            vertSum(...) = ...
      consume vertSum:
        store sumLmr:
          for x:
            produce sumLmr:
              for y:
                for x:
                  sumLmr(...) = ...
              for y:
                sumLmr(...) = ...
              for y:
                for box:
                  sumLmr(...) = ...
              for y:
                for rx:
                  sumLmr(...) = ...
            consume sumLmr:
              lmr(...) = ...

Solution

  • I figured out what the problem was: Halide doesn't know how to optimize the code when the window size is a runtime variable. With lmrWidth fixed to a compile-time constant, it generates the correct code. The final code looks like this:

    #include "Halide.h"
    
    namespace {
        // Halide generator for the LMR filter:
        //   output(x, y) = input(x, y) * area - (sum of a width x width window of input)
        // with area = width * width. The window sum is computed separably:
        // sum_y adds `width` vertically adjacent input pixels, sum_x adds `width`
        // horizontally adjacent sum_y values.
        class LMR : public Halide::Generator<LMR> {
        public:
            // 2-D, 8-bit unsigned input image.
            ImageParam input{UInt(8), 2, "input"};

            // Runtime parameters: the processed region and the window half-width.
            Param<int32_t> xMin{"xMin"}, xMax{"xMax"};
            Param<int32_t> yMin{"yMin"}, yMax{"yMax"};
            Param<int32_t> lmrSize{"lmrSize"};

            // Window width, fixed when the generator runs. It has to be a plain
            // C++ int (not an expression of the runtime lmrSize Param) because it
            // bounds the C++ loops in build() that unroll the window sums.
            // Defaults to 11, the value that used to be hard-coded. The upper
            // bound of 257 keeps the 1-D sum in sum_y within uint16_t
            // (255 * 257 = 65535).
            // The runtime lmrSize passed to the generated filter is expected to
            // equal (lmrWidth - 1) / 2; nothing here enforces that.
            GeneratorParam<int> lmrWidth{"lmrWidth", 11, 1, 257};

            Var x{"x"}, y{"y"};

            // Builds the pipeline and its schedule; returns the output Func.
            Func build() {
                const int width = lmrWidth;
                const int area = width*width;

                // Coordinates clamped to the processed region (used to read the input).
                Halide::Expr clamped_x = Halide::clamp(x, xMin, xMax);
                Halide::Expr clamped_y = Halide::clamp(y, yMin, yMax);

                // Coordinates clamped to the positions whose whole window lies
                // inside the region (used to read the window sum).
                Halide::Expr lmr_x = Halide::clamp(x, xMin+lmrSize, xMax-lmrSize);
                Halide::Expr lmr_y = Halide::clamp(y, yMin+lmrSize, yMax-lmrSize);

                // Vertical sum of `width` input pixels starting at row y - lmrSize,
                // accumulated in uint16_t. Every term is cast explicitly so the
                // result type does not depend on implicit promotion.
                Halide::Func sum_y ("sum_y");
                {
                    Halide::Expr sum = Halide::cast<uint16_t>(input(x, y-lmrSize));
                    for (int i = 1; i < width; i++) {
                        sum = sum + Halide::cast<uint16_t>(input(x, y-lmrSize+i));
                    }
                    sum_y(x, y) = sum;
                }

                // Horizontal sum of `width` sum_y values starting at column
                // x - lmrSize. Accumulated in int32_t: the full 2-D sum can reach
                // 255 * width * width, which no longer fits in 16 bits once
                // width >= 17.
                Halide::Func sum_x ("sum_x");
                {
                    Halide::Expr sum = Halide::cast<int32_t>(sum_y(x-lmrSize, y));
                    for (int i = 1; i < width; i++) {
                        sum = sum + Halide::cast<int32_t>(sum_y(x-lmrSize+i, y));
                    }
                    sum_x(x, y) = sum;
                }

                // Pass-through wrapper kept so the stage name "sumLmr" still exists.
                Halide::Func sumLmr("sumLmr");
                sumLmr(x, y) = sum_x(x, y);

                Halide::Func output("output");
                output(x, y) = Halide::cast<int32_t>(input(clamped_x, clamped_y))*area - Halide::cast<int32_t>(sumLmr(lmr_x, lmr_y));

                // Schedule: sum_x is computed at root level, sum_y inside sum_x's
                // y loop, and the output is vectorized along x in groups of 16.
                sum_x.compute_root();
                sum_y.compute_at(sum_x, y);
                output.vectorize(x, 16);

                return output;
            }
        };

        HALIDE_REGISTER_GENERATOR(LMR, "lmr")
    }