Tags: c, assembly, bit-manipulation, bit, zig

How can I generate a 256-bit mask?


I have a uint64_t[4] array, and I need to generate a mask such that the array, interpreted as a single 256-bit integer, equals (1 << w) - 1, where w ranges from 1 to 256.

The best thing I have come up with is branchless, but it takes MANY instructions. It is in Zig because Clang doesn't seem to expose LLVM's saturating subtraction. http://localhost:10240/z/g8h1rV

Is there a better way to do this?

var mask: [4]u64 = undefined;
for (mask) |_, i|
    mask[i] = 0xffffffffffffffff;
mask[3] ^= ((u64(1) << @intCast(u6, (inner % 64) + 1)) - 1) << @intCast(u6, 64 - (inner % 64));
mask[2] ^= ((u64(1) << @intCast(u6, (@satSub(u32, inner, 64) % 64) + 1)) - 1) << @intCast(u6, 64 - (inner % 64));
mask[1] ^= ((u64(1) << @intCast(u6, (@satSub(u32, inner, 128) % 64) + 1)) - 1) << @intCast(u6, 64 - (inner % 64));
mask[0] ^= ((u64(1) << @intCast(u6, (@satSub(u32, inner, 192) % 64) + 1)) - 1) << @intCast(u6, 64 - (inner % 64));

Solution

  • Are you targeting x86-64 with AVX2 for 256-bit vectors? I thought that was an interesting case to answer for.

    If so, you can do this in a few instructions using saturating subtraction and a variable count shift.

    x86 SIMD shifts like vpsrlvq saturate the shift count, shifting all the bits out when the count is >= element width, unlike scalar integer shifts, where the shift count is masked (and thus wraps around).

    For the lowest u64 element, starting with all-ones, we need to leave it unmodified when bitpos >= 64, or right-shift it by 64-bitpos for smaller bit positions. Unsigned saturating subtraction looks like the way to go here, as you observed, to create a shift count of 0 for larger bitpos. But x86 only has saturating subtraction for SIMD, not scalar, and only for byte or word elements. If we don't care about bitpos > 256, that's fine: we can keep each shift offset in the 16-bit word at the bottom of its u64 and let 0 - 0 happen in the three words above it. For example, with bitpos = 77 the second-lowest lane computes 128 - 77 = 51 in its low word and 0 - 0 = 0 in the words above it, so the whole qword reads 51, a valid vpsrlvq count.

    Your code looks pretty overcomplicated, creating (1<<n) - 1 and XORing. I think it's a lot easier to just use a variable-count shift on the 0xFFFF...FF elements directly.

    I don't know Zig, so do whatever you have to do to get it to emit asm like this. Hopefully this is useful since you tagged this assembly; it should be easy to translate to intrinsics for C, or Zig if it has them.

    default rel
    section .rodata
    shift_offsets:  dw  64, 128, 192, 256        ; 16-bit elements, to be loaded with zero-extension to 64
    
    section .text
    pos_to_mask256:
        vpmovzxwq   ymm2, [shift_offsets]      ; ymm2 = _mm256_set_epi64x(256, 192, 128, 64)
        vpcmpeqd    ymm1, ymm1, ymm1           ; ymm1 = all-ones
                                      ; set up vector constants, can be hoisted
    
        vmovd         xmm0, edi
        vpbroadcastq  ymm0, xmm0           ; ymm0 = _mm256_set1_epi64x(bitpos)
    
        vpsubusw      ymm0, ymm2, ymm0     ; ymm0 = {256,192,128,64}-bitpos with unsigned saturation
        vpsrlvq       ymm0, ymm1, ymm0     ; mask[i] >>= count, where counts >= 64 create 0s.
    
        ret
    

    If the input integer starts in memory, you can of course efficiently broadcast-load it into a ymm register directly.

    The shift-offsets vector can of course be hoisted out of a loop, as can the all-ones.
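
    In C intrinsics it would look something like the sketch below (mine, untested; everything used is a plain AVX2 intrinsic from immintrin.h, and the function name just mirrors the asm label):

    #include <immintrin.h>

    __m256i pos_to_mask256(unsigned bitpos)
    {
        // vector constants; a compiler hoists these out of loops
        const __m256i offsets  = _mm256_set_epi64x(256, 192, 128, 64);
        const __m256i all_ones = _mm256_set1_epi64x(-1);

        __m256i vpos   = _mm256_set1_epi64x(bitpos);        // vmovd + vpbroadcastq
        // word-wise unsigned saturating subtract: only the low word of each
        // qword is nonzero in either input, so the whole qword equals the
        // saturated 16-bit difference
        __m256i counts = _mm256_subs_epu16(offsets, vpos);  // vpsubusw
        // variable-count shift: counts >= 64 shift out all the bits
        return _mm256_srlv_epi64(all_ones, counts);         // vpsrlvq
    }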


    With input = 77, the high 2 elements are zeroed by shifts of 256-77=179, and 192-77=115 bits. Tested with NASM + GDB for EDI=77, and the result is

    (gdb) p /x $ymm0.v4_int64
    {0xffffffffffffffff, 0x1fff, 0x0, 0x0}
    

    GDB prints the low element first, opposite of Intel notation / diagrams. This vector is actually 0, 0, 0x1fff, 0xffffffffffffffff, i.e. 64+13 = 77 one bits at the bottom, and the rest all zeros. Other test cases:

    • edi=0: mask = all-zero
    • edi=1: mask = 1
    • ... : mask = edi one bits at the bottom, then zeros
    • edi=255: mask = all ones except for the top bit of the top element
    • edi=256: mask = all ones
    • edi>256: mask = all ones. (unsigned subtraction saturates to 0 everywhere.)

    You need AVX2 for the variable-count shifts. psubusb/w is SSE2, so you could consider doing that part with SIMD and then going back to scalar integer for the shifts, or maybe just using SSE2 shifts one element at a time, like psrlq xmm1, xmm0, which takes the low 64 bits of xmm0 as the shift count for all elements of xmm1; see the sketch below.
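
    An untested sketch of that SSE2-only hybrid (the helper name and the offsets argument are just for illustration): psubusw produces the counts, then one psrlq per qword, recombined with unpacks.

    #include <emmintrin.h>

    // one 128-bit half of the mask; 'offsets' holds the two shift offsets
    // (e.g. 64 and 128 for the low half) in the low word of each qword
    __m128i mask128(__m128i offsets, unsigned bitpos)
    {
        __m128i all_ones = _mm_set1_epi64x(-1);
        __m128i vpos     = _mm_set1_epi64x(bitpos);
        __m128i counts   = _mm_subs_epu16(offsets, vpos);   // psubusw
        // psrlq xmm, xmm shifts both qwords by the low 64 bits of the count
        // operand, and counts >= 64 produce 0, so do each half separately
        __m128i lo = _mm_srl_epi64(all_ones, counts);
        __m128i hi = _mm_srl_epi64(all_ones, _mm_unpackhi_epi64(counts, counts));
        return _mm_unpacklo_epi64(lo, hi);                  // {lo.q0, hi.q0}
    }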

    Most ISAs don't have saturating scalar subtraction. Some ARM CPUs do for scalar integer, I think, but x86 doesn't. IDK what you're using.

    On x86 (and many other ISAs) you have 2 problems:

    • keep all-ones for low elements (either modify the shift result, or saturate shift count to 0)
    • produce 0 for high elements above the one containing the top bit of the mask. x86 scalar shifts can't do this at all (the count wraps at the operand width), so you might feed the shift an input of 0 for that case, maybe using cmov to create it based on flags set by sub for 192-w or something like this for the top element:
        int count = 192 - (int)w;                      // top element: bits 192..255
        uint64_t shift_input = count < 0 ? ~0ULL : 0;  // 0 when this element is entirely above the mask
        shift_input >>= count & 63;      // mask the count to avoid UB in C.  Optimizes away on x86 where shr does this anyway.
    

    Hmm, this doesn't handle saturating the subtraction to 0 to keep the all-ones for elements fully inside the mask, though.

    If tuning for ISAs other than x86, maybe look at some other options. Or maybe there's something better on x86 as well. Creating the all-ones or all-zeros shift input with sar reg,63 is an interesting option (broadcast the sign bit of 192-w), but it still leaves the problem of saturating the count to 0 so the lower elements stay all-ones.
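
    For a portable scalar version, an untested sketch that handles both problems per element (element 0 = low 64 bits here; a compiler can turn the conditionals into cmov or the local equivalent):

    #include <stdint.h>

    void pos_to_mask256_scalar(uint64_t mask[4], unsigned w)
    {
        for (int i = 0; i < 4; i++) {
            int count = 64 * (i + 1) - (int)w;          // right-shift count for ~0
            uint64_t input = count >= 64 ? 0 : ~0ULL;   // element fully above the mask: 0
            if (count < 0)
                count = 0;                              // element fully inside: keep all-ones
            mask[i] = input >> (count & 63);            // mask the count to avoid UB in C
        }
    }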