Tags: c++, arm, simd, neon

Optimizing a for loop with lookup-table using ARM Neon instructions


This is one of my first pieces of Neon code, and I want to know if you have any comments or suggestions to make it run faster; that's why I translated the code down to this low level.

I posted the code on the Code Review site, but I didn't get any reply and mine was the only post with the Neon tag, so I thought I would post it here.

Here is working code. It applies a lookup table to an image. My goal is to make it run faster on an ARM Cortex-A53 CPU.

#include <opencv2/opencv.hpp>
#include <iostream>
#include <vector>
#include <numeric> // For std::iota
#include <array> // Structure to hold cached parameters

#ifdef __ARM_NEON
#include <arm_neon.h>
#else
#error "This code requires ARM NEON support."
#endif

struct Cache {
    std::array<uchar, 256> lut_b;
    std::array<uchar, 256> lut_g;
    std::array<uchar, 256> lut_r;
};

// Compute the lookup tables (identity LUTs here, as a simple example; image is unused)
void compute_data(const cv::Mat& image, Cache& cache) {
    for (int i = 0; i < 256; i++) {
        cache.lut_b[i] = static_cast<uchar>(i);
        cache.lut_g[i] = static_cast<uchar>(i);
        cache.lut_r[i] = static_cast<uchar>(i);
    }
}

// Function to apply lookup table. vy[i] = vtable[vx[i]]
static inline uint8x16_t lookup_neon(const uint8x16x4_t vtable[4], uint8x16_t vx) {
    const uint8x16_t voffset = vmovq_n_u8(64);
    uint8x16_t vy = vqtbl4q_u8(vtable[0], vx);
    vx = vsubq_u8(vx, voffset);
    vy = vqtbx4q_u8(vy, vtable[1], vx);
    vx = vsubq_u8(vx, voffset);
    vy = vqtbx4q_u8(vy, vtable[2], vx);
    vx = vsubq_u8(vx, voffset);
    vy = vqtbx4q_u8(vy, vtable[3], vx);
    return vy;
}
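For readers less familiar with these intrinsics: vqtbl4q_u8 indexes a 64-byte table and yields zero for out-of-range indices, while vqtbx4q_u8 leaves the accumulator untouched for out-of-range indices, so subtracting 64 between steps stitches four 64-entry lookups into one 256-entry lookup. A scalar model of lookup_neon, for illustration only (the names are mine, not part of the build):

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Scalar model of lookup_neon: vqtbl4q_u8 selects table[x] for x < 64 (else 0),
// and vqtbx4q_u8 keeps the previous result for x >= 64. Subtracting 64 between
// quarters makes each quarter handle its own 64-index range exactly once.
uint8_t lookup_model(const std::array<uint8_t, 256>& table, uint8_t idx) {
    uint8_t result = 0;
    uint8_t x = idx;
    for (int quarter = 0; quarter < 4; ++quarter) {
        if (x < 64) {                      // index in range for this quarter
            result = table[quarter * 64 + x];
        }                                  // otherwise keep the previous result
        x = static_cast<uint8_t>(x - 64);  // wraps modulo 256, like vsubq_u8
    }
    return result;
}
```

Note that after the subtraction, small indices wrap to large values (e.g. 10 - 64 becomes 202), so no later quarter overwrites an already-found entry.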

void hist(cv::Mat& image, Cache& cache, bool use_cache) {
    if (!use_cache) {
        compute_data(image, cache);
    }
    // Load cache in registers. (4x4 128-bit registers)
    const uint8x16x4_t vtable_b[4] = {
        vld1q_u8_x4(cache.lut_b.data() + 16 * 4 * 0),
        vld1q_u8_x4(cache.lut_b.data() + 16 * 4 * 1),
        vld1q_u8_x4(cache.lut_b.data() + 16 * 4 * 2),
        vld1q_u8_x4(cache.lut_b.data() + 16 * 4 * 3)
    };
    const uint8x16x4_t vtable_g[4] = {
        vld1q_u8_x4(cache.lut_g.data() + 16 * 4 * 0),
        vld1q_u8_x4(cache.lut_g.data() + 16 * 4 * 1),
        vld1q_u8_x4(cache.lut_g.data() + 16 * 4 * 2),
        vld1q_u8_x4(cache.lut_g.data() + 16 * 4 * 3)
    };
    const uint8x16x4_t vtable_r[4] = {
        vld1q_u8_x4(cache.lut_r.data() + 16 * 4 * 0),
        vld1q_u8_x4(cache.lut_r.data() + 16 * 4 * 1),
        vld1q_u8_x4(cache.lut_r.data() + 16 * 4 * 2),
        vld1q_u8_x4(cache.lut_r.data() + 16 * 4 * 3)
    };
    for (int i = 0; i < image.rows; ++i) {
        uint8_t* row_ptr = image.ptr(i);
        int j = 0;
        // Apply transformation on elements multiple of 16.
        for (; (j + 16) <= image.cols; j += 16) {
            // Load and deinterleave the elements.
            uint8x16x3_t vec = vld3q_u8(row_ptr);
            vec.val[0] = lookup_neon(vtable_b, vec.val[0]);
            vec.val[1] = lookup_neon(vtable_g, vec.val[1]);
            vec.val[2] = lookup_neon(vtable_r, vec.val[2]);
            // Interleave and store the elements.
            vst3q_u8(row_ptr, vec);
            row_ptr += 3 * 16;
        }
        // Apply transformation on leftover elements.
        for (; j < image.cols; ++j) {
            row_ptr[0] = cache.lut_b[row_ptr[0]];
            row_ptr[1] = cache.lut_g[row_ptr[1]];
            row_ptr[2] = cache.lut_r[row_ptr[2]];
            row_ptr += 3;
        }
    }
}

int main(int argc, char** argv) {
    // Open the video file
    cv::VideoCapture cap("video.mp4");
    if (!cap.isOpened()) {
        std::cerr << "Error opening video file" << std::endl;
        return -1;
    }

    // Get the frame rate of the video
    double fps = cap.get(cv::CAP_PROP_FPS);
    int delay = static_cast<int>(1000 / fps);

    // Create a window to display the video
    cv::namedWindow("Processed Video", cv::WINDOW_NORMAL);

    cv::Mat frame;
    Cache cache;
    int frame_count = 0;
    int recompute_interval = 5; // Recompute every 5 frames

    while (true) {
        cap >> frame;
        if (frame.empty()) {
            break;
        }

        // Determine whether to use the cache or recompute the data
        bool use_cache = (frame_count % recompute_interval != 0);

        // Process the frame using cached or recomputed parameters
        hist(frame, cache, use_cache);

        // Display the processed frame
        cv::imshow("Processed Video", frame);

        // Break the loop if 'q' is pressed
        if (cv::waitKey(delay) == 'q') {
            break;
        }

        frame_count++;
    }

    cap.release();
    cv::destroyAllWindows();

    return 0;
}

Solution

  • EDIT: TL;DR -- scroll down for the very last section


    As pointed out in the comments, 3 distinct lookup tables require 48 vector registers, which is far too many -- AArch64 NEON has only 32 -- so the generated code will spill a lot.

    The first thing to consider is whether the LUT computes an elementary function that could be approximated with piecewise linear, quadratic, or perhaps cubic segments: on many older platforms at least, the vtbl and vtbx instructions do not parallelise at the instruction level and take about 12 cycles for a 64-element lookup. Even promoting uint8_t -> float, taking a square root, and converting back to uint8_t is likely feasible for inverse gamma correction.
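To make the approximation idea concrete, here is a hypothetical scalar sketch replacing a 256-entry sqrt-style (inverse gamma) LUT with 16 linear segments; the curve, segment count, and names are illustrative assumptions, not taken from the question:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical reference curve: y = 255 * sqrt(x / 255), a sqrt-style
// inverse gamma correction.
uint8_t exact_curve(uint8_t x) {
    return static_cast<uint8_t>(std::lround(255.0 * std::sqrt(x / 255.0)));
}

// Piecewise-linear approximation with 16 segments of width 16: the inner
// loop only needs a table of 17 endpoint values and a multiply-add,
// instead of a full 256-entry per-pixel lookup.
uint8_t piecewise_linear(uint8_t x) {
    int seg = x >> 4;                  // which of the 16 segments
    int x0 = seg << 4, x1 = x0 + 15;   // segment endpoints
    int y0 = exact_curve(static_cast<uint8_t>(x0));
    int y1 = exact_curve(static_cast<uint8_t>(x1));
    // linear interpolation within the segment (exact at both endpoints)
    return static_cast<uint8_t>(y0 + (y1 - y0) * (x - x0) / 15);
}
```

As with any piecewise fit, the error concentrates where the curve bends most (here, near x = 0); more or non-uniform segments would shrink it.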

    If a LUT is unavoidable, I would probably split the work into 4 passes, loading just one quarter of each table into registers at any given moment.

        uint8x16x4_t table_r = vld1q_u8_x4(table_r_ptr);
        uint8x16x4_t table_g = vld1q_u8_x4(table_g_ptr);
        uint8x16x4_t table_b = vld1q_u8_x4(table_b_ptr);
        table_r_ptr+=64;
        table_g_ptr+=64;
        table_b_ptr+=64;
        int full_blocks = pixels / 16;  // one iteration per 16 interleaved pixels
        do {
           uint8x16x3_t data = vld3q_u8(source); source += 48;
           data.val[0] = vqtbl4q_u8(table_r, data.val[0]);
           data.val[1] = vqtbl4q_u8(table_g, data.val[1]);
           data.val[2] = vqtbl4q_u8(table_b, data.val[2]);
           vst3q_u8(destination, data); destination += 48;
        } while (--full_blocks);
        ...
        // then the 3 other passes
        for (int pass = 64; pass <= 192; pass += 64) {
          // load more tables
          table_r = vld1q_u8_x4(table_r_ptr); table_r_ptr+=64;
          table_g = vld1q_u8_x4(table_g_ptr); table_g_ptr+=64;
          table_b = vld1q_u8_x4(table_b_ptr); table_b_ptr+=64;
          // need to rewind the pointers
          source -= pixels * 3;
          destination -= pixels * 3;
          int full_blocks = pixels / 16;
          do {
             uint8x16x3_t data = vld3q_u8(source); source += 48;
             uint8x16x3_t datad = vld3q_u8(destination);
    
             data.val[0] = vsubq_u8(data.val[0], vdupq_n_u8(pass));
             data.val[1] = vsubq_u8(data.val[1], vdupq_n_u8(pass));
             data.val[2] = vsubq_u8(data.val[2], vdupq_n_u8(pass));
    
             datad.val[0] = vqtbx4q_u8(datad.val[0], table_r, data.val[0]);
             datad.val[1] = vqtbx4q_u8(datad.val[1], table_g, data.val[1]);
             datad.val[2] = vqtbx4q_u8(datad.val[2], table_b, data.val[2]);
             vst3q_u8(destination, datad); destination += 48;
           } while (--full_blocks);
        }
    

    Rethinking -- while that approach is straightforward, it requires writing the data 4 times and reading it 7 times.

    One might instead keep at least some of the tables in registers and use a few fixed registers to reload the changing portion of the lookup tables:

        uint8x16x4_t r0,g0,b0, r1,g1,b1;
        loadFirstHalfOfTables(r0,g0,b0, r1,g1,b1);
        {
           uint8x16x3_t data = vld3q_u8(source); source += 48;
           uint8x16_t tmp = vqtbl4q_u8(r0, data.val[0]);
           // so far we have used 4 registers and have 4 left,
           // but we in fact need 5. Let's reuse one register
           // from the tables to hold the constant 64.
           r0.val[0] = vdupq_n_u8(64);
    
           data.val[0] = vsubq_u8(data.val[0], r0.val[0]);
           tmp = vqtbx4q_u8(tmp, r1, data.val[0]);
           data.val[0] = vsubq_u8(data.val[0], r0.val[0]);
           uint8x16x4_t r2 = vld1q_u8_x4(table_r + 128);
           tmp = vqtbx4q_u8(tmp, r2, data.val[0]);
           data.val[0] = vsubq_u8(data.val[0], r0.val[0]);
           r2 = vld1q_u8_x4(table_r + 192);
           data.val[0] = vqtbx4q_u8(tmp, r2, data.val[0]);
           // do the same for G,B channels
           
    
           r0.val[0] = vld1q_u8(table_r); // restore the repurposed register
           vst3q_u8(destination, data); destination += 48;
        }
    

    The total cost of this approach is 3 data reads, 3 data writes and 25 table reads (from local cache) per block: the two upper quarters of each channel's table are reloaded (2 x 4 x 3 = 24 loads) plus one load to restore the repurposed register, giving 31 128-bit memory-bus accesses.

    The first approach requires 12 data reads (4 passes x 3 source reads), 9 destination reads (passes 2-4) and 12 destination writes, i.e. 33 memory-bus operations in total, where the writes and reads-after-writes can be more costly.

    The second approach, on the other hand, pays for its tight dependency chain: it needs an efficient out-of-order architecture to hide the large latencies (up to 10 cycles) of the memory loads.


    Then it's likely that a third option with three passes would be even better:

    template <int N>
    void pass(uint8_t const *r_or_g_or_b_table, ...) {
       uint8x16x4_t t0,t1,t2,t3; // use 16 registers for R table
       t0 = vld1q_u8_x4(r_or_g_or_b_table);
       t1 = vld1q_u8_x4(r_or_g_or_b_table + 64);
       t2 = vld1q_u8_x4(r_or_g_or_b_table + 128);
       t3 = vld1q_u8_x4(r_or_g_or_b_table + 192);
       auto const k64 = vdupq_n_u8(64);
       auto const k128 = vdupq_n_u8(128);
    
       while (blocks--) {
          auto d = vld3q_u8(source); source += 48;
          auto &tmp = d.val[N];
          auto d1 = vsubq_u8(tmp, k64);
          auto d2 = vsubq_u8(tmp, k128);
          auto d3 = vsubq_u8(d2, k64);   // tmp - 192
          tmp = vqtbl4q_u8(t0, tmp);
          tmp = vqtbx4q_u8(tmp, t1, d1);
          tmp = vqtbx4q_u8(tmp, t2, d2);
          tmp = vqtbx4q_u8(tmp, t3, d3);
          vst3q_u8(destination, d); destination += 48;
       }
    }
    

    The third approach first reads from src and writes to dst, converting just the R channel and copying the G and B channels. The second pass reads from dst and writes back to dst, converting just the G channel, and the third pass processes just the B channel. This costs only 9 reads and 9 writes per block of 16 pixels.

    Moreover, with just 24 vector registers in use, one can also avoid saving/restoring d8..d15 in the function prologue and epilogue, since the AArch64 calling convention leaves exactly 24 vector registers (v0-v7 and v16-v31) free for the callee.