Tags: c++, sse, avx, avx512

What is the "correct" way to go from avx/sse masks to avx512 masks?


I have some existing avx/sse masks that I got the old way:

auto mask_sse = _mm_cmplt_ps(a, b);
auto mask_avx = _mm_cmp_ps(a, b, _CMP_LT_OQ);   // predicate 17: less-than, ordered, quiet

In some circumstances, when mixing old avx code with new avx512 code, I want to convert these old-style masks into the new avx512 __mmask8 or __mmask16 types.

I tried this:

auto mask_avx512 = _mm_cmp_ps_mask(mask_sse, _mm_setzero_ps(), _CMP_NGE_UQ /*25: not-ge, unordered, quiet*/);

and it seems to work for plain outputs of comparisons, but I don't think it would correctly handle positive NaNs, which SSE4.1 _mm_blendv_ps would have treated as false (it only reads the sign bit).
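
For example (a hypothetical mask value, not one produced by a comparison): +qNaN lanes have the sign bit clear, so _mm_blendv_ps treats them as false, but the unordered NGE_UQ compare reports them as true:

#include <immintrin.h>
#include <cstdio>

int main()
{
    auto odd_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7FC00000)); // +qNaN in every lane
    auto a = _mm_set1_ps(1.0f);
    auto c = _mm_set1_ps(2.0f);

    auto v1 = _mm_blendv_ps(a, c, odd_mask);   // sign bits clear: keeps a
    auto k  = _mm_cmp_ps_mask(odd_mask, _mm_setzero_ps(), _CMP_NGE_UQ); // NaN is unordered: all true
    auto v2 = _mm_mask_blend_ps(k, a, c);      // takes c: disagrees with blendv

    printf("%f vs %f\n", _mm_cvtss_f32(v1), _mm_cvtss_f32(v2)); // prints 1.000000 vs 2.000000
    return 0;
}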

There's also good old _mm_movemask_ps, but that puts the mask all the way out in a general-purpose register, and I would need to chain it with _cvtu32_mask8 to pull it back into one of the dedicated mask registers.
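
Something like this sketch, reusing mask_sse from above (and assuming AVX-512DQ for _cvtu32_mask8):

__mmask8 mask_from_gpr = _cvtu32_mask8((unsigned)_mm_movemask_ps(mask_sse)); // vmovmskps eax, xmm; kmovb k1, eax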

Is there a cleaner way to just directly pull the sign bit out of an old style mask into one of the k registers?

Example Code:

Here's an example program doing this sort of mask conversion the first way I mentioned above:

#include "x86intrin.h"
#include <cassert>
#include <cstdio>

int main()
{
    auto a = _mm_set_ps(-1, 0, 1, 2);
    auto c = _mm_set_ps(3, 4, 5, 6);

    auto sse_mask    = _mm_cmplt_ps(a, _mm_setzero_ps());
    auto avx512_mask = _mm_cmp_ps_mask(sse_mask, _mm_setzero_ps(), _CMP_NGE_UQ); // predicate 25

    alignas(16) float v1[4];
    alignas(16) float v2[4];
    _mm_store_ps(v1, _mm_blendv_ps(a, c, sse_mask));
    _mm_store_ps(v2, _mm_mask_blend_ps(avx512_mask, a, c));

    assert(v1[0] == v2[0]);
    assert(v1[1] == v2[1]);
    assert(v1[2] == v2[2]);
    assert(v1[3] == v2[3]);
    return 0;
}

Solution

  • Use an AVX-512 compare intrinsic to get an AVX-512 mask in the first place (like _mm_cmp_ps_mask); that's going to be significantly more efficient than comparing into a vector and then converting it, unless the compiler optimizes away this inefficiency for you. (Consider using a wrapper library like Agner Fog's VCL to try to abstract away the difference. The VCL licence changed recently from GPL to Apache.)
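
    As a minimal sketch of that wrapper approach, assuming VCL's Vec4f / Vec4fb / select API (from memory; check the VCL manual for details):

       #include "vectorclass.h"   // Agner Fog's VCL

       // Compiled with AVX-512VL+DQ enabled, Vec4fb is a compact mask type,
       // so this becomes vcmpps into a k register + vblendmps.
       Vec4f blend_lt(Vec4f a, Vec4f b, Vec4f c) {
           Vec4fb m = a < b;        // boolean vector: k register or SSE vector, per target
           return select(m, c, a);  // per-element m ? c : a
       }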


    But if you really need this (e.g. as a stop-gap before you finish optimizing), you don't need an FP compare. _mm_cmp_ps in C produces a __m128 result, but it's not really a vector of floats¹. It's all-one-bits / all-zero-bits. You just want the bits, so you're looking for the AVX-512 equivalent of vmovmskps, but into a k register instead of GP integer. i.e. VPMOVD2M k, x/y/zmm for 32-bit source elements.

       __m128 cmpvec = _mm_cmplt_ps(v, _mm_setzero_ps());
       __mmask8 cmpmask = _mm_movepi32_mask(_mm_castps_si128(cmpvec));   // <---- vpmovd2m
    
    // equivalent to comparing into a mask in the first place:
       __mmask8 cmpmask = _mm_cmp_ps_mask(v, _mm_setzero_ps(), _CMP_LT_OQ);
    
    // equivalent to (if I got this right)
       __mmask8 cmpmask = _mm_fpclass_ps_mask(v, 0x40 | 0x10);  // negative | negative_inf
    
    

    https://uops.info/ is down right now, otherwise I'd check latency and execution ports of VPMOVD2M vs. VCMPPS into mask (for an UNORD predicate) vs. VFPCLASSPS.


    Footnote 1: You could use AVX-512 vfpclassps into a mask, or even compare the mask vector against itself with a vcmpps predicate like UNORD to detect NaN-or-not. But I think those are slower.
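
    For example (assuming cmpvec is a compare result, so each lane is either an all-ones NaN bit-pattern or +0.0):

       __mmask8 cmpmask = _mm_cmp_ps_mask(cmpvec, cmpvec, _CMP_UNORD_Q);  // true only for the NaN (all-ones) lanes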


    I would need to chain it with a _cvtu32_mask8 to pull it back into one of the dedicated mask registers.

    The way compilers currently do things, __mmask8 is just a typedef for unsigned char, and __mmask16 is unsigned short. They're freely convertible without intrinsics, for good or ill. But in asm, it takes a kmovb k1, eax instruction to get the data from a GP reg to a k mask reg, and that instruction can only run on port 5 in current CPUs.
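
    For example, this already compiles with no conversion intrinsic (relying on those current typedefs, not on anything guaranteed):

       __mmask8 cmpmask = _mm_movemask_ps(cmpvec);   // implicit int -> unsigned char; vmovmskps eax, xmm, then kmovb k1, eax when used as a mask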