c++ c bit-manipulation simd swar

Subtracting packed 8-bit integers in a 64-bit integer by 1 in parallel, SWAR without hardware SIMD


I have a 64-bit integer that I'm interpreting as an array of 8 packed 8-bit integers. I need to subtract the constant 1 from each packed integer while handling overflow, without the result of one element affecting the result of another element.

I have the code below at the moment and it works, but I need a solution that does the subtraction of each packed 8-bit integer in parallel and doesn't make memory accesses. On x86 I could use a SIMD instruction like psubb, which subtracts packed 8-bit integers in parallel, but the platform I'm coding for (RISC-V in this case) doesn't support SIMD instructions.

So I'm trying to do SWAR (SIMD within a register) to manually cancel out carry propagation between bytes of a uint64_t, doing something equivalent to this:

uint64_t sub(uint64_t arg) {
    uint8_t* packed = (uint8_t*) &arg;

    for (size_t i = 0; i < sizeof(uint64_t); ++i) {
        packed[i] -= 1;
    }

    return arg;
}

I think you could do this with bitwise operators but I'm not sure. I'm looking for a solution that doesn't use SIMD instructions. I'm looking for a solution in C or C++ that's quite portable or just the theory behind it so I can implement my own solution.


Solution

  • If you have a CPU with efficient SIMD instructions, SSE/MMX paddb (_mm_add_epi8) is also viable. Peter Cordes' answer also describes GNU C (gcc/clang) vector syntax, and safety for strict-aliasing UB. I strongly encourage reviewing that answer as well.

    Doing it yourself with uint64_t is fully portable, but still requires care to avoid alignment problems and strict-aliasing UB when accessing a uint8_t array with a uint64_t*. You left that part out of the question by starting with your data in a uint64_t already, but for GNU C a may_alias typedef solves the problem (see Peter's answer for that or memcpy).

    Otherwise you could allocate / declare your data as uint64_t and access it via uint8_t* when you want individual bytes. unsigned char* is allowed to alias anything so that sidesteps the problem for the specific case of 8-bit elements. (If uint8_t exists at all, it's probably safe to assume it's an unsigned char.)
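
The aliasing points above can be sketched in portable C as follows. This is only an illustration: the helper names `load_u64`, `store_u64`, and `byte_of` are hypothetical and do not come from either answer. It assumes the memcpy route mentioned above for moving 8 bytes between a uint8_t array and a uint64_t, and uses an unsigned char* to read individual bytes of a uint64_t.

```c
#include <stdint.h>
#include <string.h>

/* Hypothetical helper: load 8 bytes into a uint64_t without casting the
 * pointer, so there is no alignment or strict-aliasing problem. */
static inline uint64_t load_u64(const uint8_t *p) {
    uint64_t v;
    memcpy(&v, p, sizeof v);
    return v;
}

/* Hypothetical helper: store a uint64_t back into 8 bytes. */
static inline void store_u64(uint8_t *p, uint64_t v) {
    memcpy(p, &v, sizeof v);
}

/* Going the other way: read byte idx (in memory order) of a uint64_t.
 * unsigned char* may alias any object, so this access is allowed. */
static inline uint8_t byte_of(const uint64_t *word, size_t idx) {
    return ((const unsigned char *)word)[idx];
}
```

Compilers typically turn a fixed-size memcpy like this into a single load or store, but that is an optimization, not a guarantee.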


    Note that this is a change from a prior incorrect algorithm (see revision history).

    This is possible without looping for arbitrary subtraction, and gets more efficient for a known constant like 1 in each byte. The main trick is to prevent carry-out from each byte by setting the high bit, then correct the subtraction result.

    We are going to slightly optimize the subtraction technique given here. They define:

    SWAR sub z = x - y
        z = ((x | H) - (y &~H)) ^ ((x ^~y) & H)
    

    with H defined as 0x8080808080808080U (i.e. the MSBs of each packed integer). For a decrement, y is 0x0101010101010101U.
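
As a minimal sketch, the general formula above translates directly into C. The function name `swar_sub8` is an assumption made for illustration; the expression itself is the one quoted above, with H as defined.

```c
#include <stdint.h>

/* Hypothetical name: byte-wise x - y (each byte wraps mod 256). */
static inline uint64_t swar_sub8(uint64_t x, uint64_t y) {
    const uint64_t H = UINT64_C(0x8080808080808080); /* MSB of every byte */
    /* (x | H) - (y & ~H): every byte of the minuend is >= 0x80 and every
     * byte of the subtrahend is <= 0x7F, so no borrow crosses a byte.
     * (x ^ ~y) & H: puts the correct value back into each MSB. */
    return ((x | H) - (y & ~H)) ^ ((x ^ ~y) & H);
}
```

Calling `swar_sub8(x, UINT64_C(0x0101010101010101))` gives the per-byte decrement asked for in the question.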

    We know that y has all of its MSBs clear, so we can skip one of the mask steps (i.e. y & ~H is the same as y in our case). The calculation proceeds as follows:

    1. We set the MSBs of each component of x to 1, so that a borrow cannot propagate past the MSB to the next component. Call this the adjusted input.
    2. We subtract 1 from each component, by subtracting 0x0101010101010101 from the adjusted input. This does not cause inter-component borrows thanks to step 1. Call this the adjusted output.
    3. We need to now correct the MSB of the result. We xor the adjusted output with the inverted MSBs of the original input to finish fixing up the result.
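
The three steps above can be traced one at a time. The sketch below keeps each intermediate value in a struct so it can be inspected; the struct and function names (`dec_trace`, `dec_each_traced`) are hypothetical and exist only for this illustration.

```c
#include <stdint.h>

/* Hypothetical struct: one field per step of the description. */
struct dec_trace {
    uint64_t adjusted_input;  /* step 1: x with every MSB set */
    uint64_t adjusted_output; /* step 2: adjusted input minus 1 per byte */
    uint64_t msb_fix;         /* step 3: inverted MSBs of the original x */
    uint64_t result;          /* adjusted output xor msb_fix */
};

static struct dec_trace dec_each_traced(uint64_t x) {
    const uint64_t H = UINT64_C(0x8080808080808080);
    const uint64_t ONES = UINT64_C(0x0101010101010101);
    struct dec_trace t;
    t.adjusted_input = x | H;
    t.adjusted_output = t.adjusted_input - ONES;
    t.msb_fix = (x ^ H) & H;
    t.result = t.adjusted_output ^ t.msb_fix;
    return t;
}
```

For x = 0x0000000000000100 the adjusted input is 0x8080808080808180, the adjusted output is 0x7f7f7f7f7f7f807f, the MSB fix is 0x8080808080808080, and the result is 0xffffffffffff00ff.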

    The operation can be written as:

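    #include <stdint.h>

    #define U64MASK 0x0101010101010101U  // 1 in every byte
    #define MSBON 0x8080808080808080U    // MSB of every byte
    uint64_t decEach(uint64_t i){
          // set the MSBs, subtract 1 per byte, then fix up each MSB
          return ((i | MSBON) - U64MASK) ^ ((i ^ MSBON) & MSBON);
    }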
    

    Preferably, this is inlined by the compiler (use compiler directives to force this), or the expression is written inline as part of another function.

    Testcases:

    in:  0000000000000000
    out: ffffffffffffffff
    
    in:  f200000015000013
    out: f1ffffff14ffff12
    
    in:  0000000000000100
    out: ffffffffffff00ff
    
    in:  808080807f7f7f7f
    out: 7f7f7f7f7e7e7e7e
    
    in:  0101010101010101
    out: 0000000000000000
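
One way to check these cases is to compare the SWAR expression against the byte-by-byte loop from the question. The sketch below does that; the function names are hypothetical, and the reference loop uses memcpy instead of the question's pointer cast (an assumption made here to stay clear of aliasing concerns).

```c
#include <stdint.h>
#include <string.h>

/* SWAR decrement, same expression as decEach above. */
static uint64_t dec_each_swar(uint64_t x) {
    const uint64_t H = UINT64_C(0x8080808080808080);
    const uint64_t ONES = UINT64_C(0x0101010101010101);
    return ((x | H) - ONES) ^ ((x ^ H) & H);
}

/* Reference: decrement each byte separately, wrapping mod 256. */
static uint64_t dec_each_ref(uint64_t x) {
    unsigned char b[sizeof x];
    memcpy(b, &x, sizeof x);
    for (size_t i = 0; i < sizeof x; ++i) {
        b[i] = (unsigned char)(b[i] - 1);
    }
    memcpy(&x, b, sizeof x);
    return x;
}

/* Returns how many of the listed inputs disagree with the reference. */
static int count_mismatches(void) {
    static const uint64_t inputs[] = {
        UINT64_C(0x0000000000000000), UINT64_C(0xf200000015000013),
        UINT64_C(0x0000000000000100), UINT64_C(0x808080807f7f7f7f),
        UINT64_C(0x0101010101010101),
    };
    int bad = 0;
    for (size_t i = 0; i < sizeof inputs / sizeof inputs[0]; ++i) {
        if (dec_each_swar(inputs[i]) != dec_each_ref(inputs[i])) {
            ++bad;
        }
    }
    return bad;
}
```

Because every byte gets the same treatment, the comparison does not depend on the byte order of the machine.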
    

    Performance details

    Here's the x86_64 assembly for a single invocation of the function. For better performance it should be inlined with the hope that the constants can live in a register as long as possible. In a tight loop where the constants live in a register, the actual decrement takes five instructions: or+not+and+add+xor after optimization. I don't see alternatives that would beat the compiler's optimization.

    uint64_t[rax] decEach(rdi):
        movabs  rcx, -9187201950435737472
        mov     rdx, rdi
        or      rdx, rcx
        movabs  rax, -72340172838076673
        add     rax, rdx
        and     rdi, rcx
        xor     rdi, rcx
        xor     rax, rdi
        ret
    

    With some IACA testing of the following snippet:

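    // Repeat the SWAR dec in a loop as a microbenchmark.
    // Note: the loop body is a borrow-correction variant, not the decEach
    // expression above; it gives wrong bytes for inputs such as
    // 0x0000000000000100.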
    uint64_t perftest(uint64_t dummyArg){
        uint64_t dummyCounter = 0;
        uint64_t i = 0x74656a6d27080100U; // another dummy value.
        while(i ^ dummyArg) {
            IACA_START
            uint64_t naive = i - U64MASK;
            i = naive + ((i ^ naive ^ U64MASK) & U64MASK);
            dummyCounter++;
        }
        IACA_END
        return dummyCounter;
    }
    
    
    

    we can show that on a Skylake machine, the decrement, xor, and compare+jump together take just under 5 cycles per iteration:

    Throughput Analysis Report
    --------------------------
    Block Throughput: 4.96 Cycles       Throughput Bottleneck: Backend
    Loop Count:  26
    Port Binding In Cycles Per Iteration:
    --------------------------------------------------------------------------------------------------
    |  Port  |   0   -  DV   |   1   |   2   -  D    |   3   -  D    |   4   |   5   |   6   |   7   |
    --------------------------------------------------------------------------------------------------
    | Cycles |  1.5     0.0  |  1.5  |  0.0     0.0  |  0.0     0.0  |  0.0  |  1.5  |  1.5  |  0.0  |
    --------------------------------------------------------------------------------------------------
    

    (Of course, on x86-64 you'd just load or movq into an XMM reg for paddb, so it might be more interesting to look at how it compiles for an ISA like RISC-V.)