I am learning about AVX intrinsics in C and I am wondering how this works.
I know that I can do something like this:
__m256 evens = _mm256_set_ps(2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0);
Here I am storing eight 32-bit floats, so this is 256 bits.
But suppose I am writing a linear algebra library. How do I then work with vectors of arbitrary length? For example, how do I fit 10 32-bit floats into an AVX vector?
If you can provide some examples, I would tremendously appreciate it.
Well, this is a broad question with a simple answer: Use loops and deal with element counts that are not multiples of your vector width as required.
Since an overly broad question invites an overly broad answer and I couldn't find a good reference on SO, maybe it is worthwhile to write up how to do vectorized loops.
As an example, consider a simple out[i] = in[i] * scale loop. Of course a compiler is perfectly capable of vectorizing this on its own, but it serves us well as a demonstration. The basic implementation looks like this:
#include <immintrin.h>
#include <stddef.h>
/* using ptrdiff_t */
void vector_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
/* broadcast factor to all vector elements */
const __m256 vscale = _mm256_set1_ps(scale);
ptrdiff_t i;
for(i = 0; i + 8 <= n; i += 8) {
/* process 8 elements at a time */
__m256 cur = _mm256_loadu_ps(in + i);
cur = _mm256_mul_ps(cur, vscale);
_mm256_storeu_ps(out + i, cur);
}
/* deal with last 0-7 elements */
for(; i < n; ++i)
out[i] = in[i] * scale;
}
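To tie this back to the 10 floats from the question: the 8-wide loop above runs once for elements 0-7 and the scalar loop picks up elements 8 and 9. The following portable sketch replays only the loop bookkeeping, without any intrinsics, so the split can be checked on any machine. vector_part is a hypothetical helper name, not something from a library.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: replays the loop bookkeeping of vector_scale
   and returns how many elements the vector loop covers. The remaining
   n - result elements are handled by the scalar cleanup loop. */
static ptrdiff_t vector_part(ptrdiff_t n, ptrdiff_t width)
{
    ptrdiff_t i;
    assert(width > 0);
    for(i = 0; i + width <= n; i += width) {
        /* one _mm256_loadu_ps / _mm256_mul_ps / _mm256_storeu_ps step */
    }
    return i;
}
```

For example, vector_part(10, 8) returns 8, leaving 2 elements for the scalar loop, while for n = 7 the vector loop does not run at all.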
Vectorizing this loop has one consequence compared to the plain C version: input and output must not overlap in such a way that storing the vector in one iteration overwrites inputs that one of the following iterations still has to read. They may be direct aliases of one another (in == out) but you cannot have out == in + 1 or similar. Auto-vectorized code would contain checks and fallbacks for that, but if you vectorize manually, you normally don't bother and simply define this as a requirement.
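The hazard can be reproduced without AVX. In the sketch below, the hypothetical blocked_scale imitates the vector loop by reading a whole block of 8 inputs before writing any of its outputs, which is what a load, multiply, store sequence does. This is an emulation under that assumption, not the intrinsic code itself. With out == in both versions agree; with out == in + 1 they produce different results.

```c
#include <assert.h>
#include <stddef.h>

/* Plain C version: one element at a time. */
static void scalar_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
    ptrdiff_t i;
    for(i = 0; i < n; ++i)
        out[i] = in[i] * scale;
}

/* Stand-in for the 8-wide loop: each block of 8 inputs is read
   completely before any of its outputs is written, like a vector
   load followed by a vector store. Assumes n is a multiple of 8. */
static void blocked_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
    ptrdiff_t i;
    int k;
    assert(n % 8 == 0);
    for(i = 0; i < n; i += 8) {
        float tmp[8];
        for(k = 0; k < 8; ++k)
            tmp[k] = in[i + k] * scale;
        for(k = 0; k < 8; ++k)
            out[i + k] = tmp[k];
    }
}

static void demo_overlap(void)
{
    float a[17], b[17];
    int j;
    /* in-place (out == in): both versions agree */
    for(j = 0; j < 17; ++j) a[j] = b[j] = 1.0f;
    scalar_scale(a, a, 16, 2.0f);
    blocked_scale(b, b, 16, 2.0f);
    for(j = 0; j < 16; ++j)
        assert(a[j] == b[j]);
    /* out == in + 1: the results diverge */
    for(j = 0; j < 17; ++j) a[j] = b[j] = 1.0f;
    scalar_scale(a + 1, a, 16, 2.0f);
    blocked_scale(b + 1, b, 16, 2.0f);
    assert(a[16] == 65536.0f); /* scalar: each result feeds the next input */
    assert(b[16] == 2.0f);     /* blocked: b[16] is computed from the original b[15] */
}
```

The scalar version chains each result into the next input while the blocked version does not, so the two disagree for this layout.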
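If you want to prevent the compiler from emitting these checks and fallbacks and also make the requirement part of the function declaration, use a restrict pointer. Strictly speaking, restrict is stronger than necessary here: in C it also rules out the in == out case, since in would then read objects that are modified through out.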
void vector_scale(
float* restrict out, const float* in, ptrdiff_t n, float scale);
Alternatively, you can do the same checks in your manually vectorized code.
static inline _Bool pointers_interfere(
const void* out, const void* in, size_t vectorsize)
{
ptrdiff_t distance = (const char*) out - (const char*) in;
return distance > 0 && (size_t) distance < vectorsize;
}
void vector_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
const __m256 vscale = _mm256_set1_ps(scale);
ptrdiff_t i = 0;
if(pointers_interfere(out, in, sizeof(__m256))) {
for(; i < n; ++i) // unvectorized fallback
out[i] = in[i] * scale;
return;
}
...
}
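One way to see which layouts the check rejects is to probe it with pointers into a single buffer, which also keeps the pointer subtraction well defined. The sketch below repeats pointers_interfere as static inline so it compiles on its own, and uses 32 in place of sizeof(__m256) so it needs no AVX header; probe_layouts is a hypothetical name.

```c
#include <assert.h>
#include <stddef.h>

/* Copy of the check above, made static so it links on its own. */
static inline _Bool pointers_interfere(
    const void* out, const void* in, size_t vectorsize)
{
    ptrdiff_t distance = (const char*) out - (const char*) in;
    return distance > 0 && (size_t) distance < vectorsize;
}

/* Probes some layouts within one buffer. 32 stands in for sizeof(__m256). */
static void probe_layouts(void)
{
    float buf[32];
    assert(!pointers_interfere(buf, buf, 32));      /* out == in: allowed */
    assert(pointers_interfere(buf + 1, buf, 32));   /* out == in + 1: rejected */
    assert(pointers_interfere(buf + 7, buf, 32));   /* still within one vector */
    assert(!pointers_interfere(buf + 8, buf, 32));  /* out == in + 8: not flagged */
    assert(!pointers_interfere(buf, buf + 1, 32));  /* out behind in: harmless */
}
```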
Some things to look out for with this check:
- out == in is not treated as interference, so in-place operation still takes the vectorized path.
- out == in + 8 is not treated as interference either. If out == in + 8 && n > 8, the result would still be nonsense; it's just the same nonsense you would get from the plain C version.
- out - in is not well defined by the C standard if both pointers point to different arrays. However, we don't program for an abstract machine but specifically for x86-64, where we can assume a flat memory model. Keep this in mind when reusing the code on other platforms where this assumption may not hold, e.g. OpenCL kernels.

I chose a loop of the form for(i = 0; i + 8 <= n; i += 8) since it is easy to read, hard to mess up (especially with unsigned types) and reasonably efficient. But it is not the most efficient form, because the condition forces the compiler to compute i+8 before the loop body when the body itself only needs i. A naive loop construction would waste a register on holding i+8. To understand how the compiler optimizes this and how we can help it, we have to distinguish between a loop counter of a signed type such as ptrdiff_t and one of an unsigned type such as size_t.
For signed types, the compiler is free to assume that i+8 never over- or underflows, and it can also use signed comparisons. However, the compiler will not rewrite the condition into i < n - 7. This would fail if n = PTRDIFF_MIN (or close to it), where n-7 would wrap around. Instead, GCC transforms the condition into something like i < ((n - 8) & -8) + 8. Things get messier if the loop doesn't start at 0, as will be the case in some optimized loops further down. Therefore I don't recommend doing this optimization manually.
For unsigned types, fewer optimizations are possible, and the wrap-around hazard occurs if n < 8. The compiler then keeps i+8 as the loop counter and addresses every memory access with a fixed negative offset, e.g. [rsi-32+rax*4] instead of [rsi+rax*4], or it emits the naive code outlined above that wastes extra instructions and registers.
In both signed and unsigned cases, we can help the compiler along with some explicit checks. Something like this should do the trick:
void vector_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
const __m256 vscale = _mm256_set1_ps(scale);
ptrdiff_t i = 0;
if(n > 7) {
for(; i < n - 7; i += 8) {
__m256 cur = _mm256_loadu_ps(in + i);
cur = _mm256_mul_ps(cur, vscale);
_mm256_storeu_ps(out + i, cur);
}
}
for(; i < n; ++i)
out[i] = in[i] * scale;
}
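To see why the n > 7 guard matters once the counter is unsigned: for n < 7 the bound n - 7 wraps around to a huge value, and an unguarded loop would run far past the end of the arrays. The sketch below assumes a size_t counter (the code above uses ptrdiff_t) and only counts iterations instead of touching memory; guarded_iterations is a hypothetical helper.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Guarded 8-wide loop with an unsigned counter; counts iterations
   instead of touching memory. */
static size_t guarded_iterations(size_t n)
{
    size_t i = 0, iters = 0;
    if(n > 7) {
        for(; i < n - 7; i += 8)
            ++iters; /* one vector iteration would run here */
    }
    return iters;
}

static void demo_guard(void)
{
    /* Without the guard, n = 3 would give the bound 3 - 7, which wraps. */
    assert((size_t) 3 - 7 == SIZE_MAX - 3);
    assert(guarded_iterations(3) == 0);
    assert(guarded_iterations(7) == 0);
    assert(guarded_iterations(8) == 1);
    assert(guarded_iterations(23) == 2);
    assert(guarded_iterations(24) == 3);
}
```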
The slightly awkward if(n > 7) for(…) construct can be replaced with a more compact but harder to read if() do {} while() as shown below. I will not use it for the remainder of the answer in order to keep the logic easy to follow.
if(n > 7) do {
__m256 cur = _mm256_loadu_ps(in + i);
cur = _mm256_mul_ps(cur, vscale);
_mm256_storeu_ps(out + i, cur);
} while((i += 8) < n - 7);
for(; i < n; ++i)
out[i] = in[i] * scale;
If you use signed loop counters like I do here, you probably get better code if you simply write if(n <= 0) return; early in the function; then you can use a simple i < n - 7 condition without fear of wrap-around. However, for the benefit of the casual copy-and-paster, I make sure that all code below remains valid even if the loop counters are changed to unsigned types.
With SSE it used to be very important that memory accesses were aligned to 16-byte boundaries. AVX generally handles unaligned access gracefully, but you may still see performance improvements with aligned accesses. For AVX-512, alignment is again more important. Whether aligning the loop is worth the extra code and runtime needs to be tested. Here I show how to do it. Since we have two arrays, we can usually only align the accesses to one of them. When in doubt, it is usually better to align the output, but this may also depend on the specific case and hardware.
Computing the aligned position is a prime candidate for micro-optimizations. The routine I show below is reasonably efficient and written to be readable and reusable. C++ templates or direct inline code can probably save some instructions. The idea is always the same: you round the memory address up to the next multiple of the vector alignment (an already aligned address stays unchanged). Then you compute how many scalar elements lie before that position, and you make sure that this position isn't beyond the end of the array.
#include <stdint.h>
/* using uintptr_t */
static inline void* next_aligned(const void* ptr, size_t alignment, const void* end)
{
const uintptr_t pos = (uintptr_t) ptr;
const uintptr_t alignedpos = (pos + alignment - 1) & -alignment;
void* rtrn = (void*) alignedpos;
return rtrn > end ? (void*) end : rtrn;
}
void vector_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
const __m256 vscale = _mm256_set1_ps(scale);
const ptrdiff_t aligned =
(float*) next_aligned(out, sizeof(__m256), out + n) - out;
ptrdiff_t i;
for(i = 0; i < aligned; ++i)
out[i] = in[i] * scale;
if(n > 7) {
for(; i < n - 7; i += 8) {
__m256 cur = _mm256_loadu_ps(in + i);
cur = _mm256_mul_ps(cur, vscale);
_mm256_store_ps(out + i, cur);
}
}
for(; i < n; ++i)
out[i] = in[i] * scale;
}
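The value of aligned can be checked on its own with a buffer of known alignment. The sketch below copies next_aligned as static inline and wraps it in a hypothetical head_count helper; it assumes a C11 compiler for _Alignas and, like the code above, a flat address space where converting pointers to uintptr_t preserves their ordering.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Same rounding as next_aligned above; alignment must be a power of two. */
static inline void* next_aligned(const void* ptr, size_t alignment, const void* end)
{
    const uintptr_t pos = (uintptr_t) ptr;
    const uintptr_t alignedpos = (pos + alignment - 1) & -(uintptr_t) alignment;
    void* rtrn = (void*) alignedpos;
    return rtrn > end ? (void*) end : rtrn;
}

/* Scalar elements to process before out + i is 32-byte aligned. */
static ptrdiff_t head_count(float* out, ptrdiff_t n)
{
    return (float*) next_aligned(out, 32, out + n) - out;
}

static void demo_head_count(void)
{
    _Alignas(32) float buf[64];           /* starts on a 32-byte boundary */
    assert(head_count(buf, 64) == 0);     /* already aligned */
    assert(head_count(buf + 3, 61) == 5); /* next boundary is buf + 8 */
    assert(head_count(buf + 7, 57) == 1);
    assert(head_count(buf + 3, 2) == 2);  /* clamped to the end of a short array */
}
```

With buf on a 32-byte boundary, buf + 3 sits 12 bytes in, so 5 scalar elements are processed before the first aligned store at buf + 8.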
Instead of scalar loops, we can use smaller vector sizes to deal with the misaligned front and the remaining tail elements. This has the added benefit of stopping the compiler from auto-vectorizing these short loops. Whether this is worth it needs to be tested. For example, Oracle's Java compiler seems to think it is not worth it.
void vector_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
const __m256 vscale = _mm256_set1_ps(scale);
const ptrdiff_t aligned =
(float*) next_aligned(out, sizeof(__m256), out + n) - out;
ptrdiff_t i, tail;
if(n <= 0) /* otherwise bit-tests fail if negative */
return;
if((i = aligned & 1) != 0)
out[0] = in[0] * scale;
if(aligned & 2) {
/* process 2 elements */
__m128 cur = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) (in + i));
cur = _mm_mul_ps(cur, _mm256_castps256_ps128(vscale));
_mm_storel_pi((__m64*) (out + i), cur);
i += 2;
}
if(aligned & 4) {
/* process 4 elements */
__m128 cur = _mm_loadu_ps(in + i);
cur = _mm_mul_ps(cur, _mm256_castps256_ps128(vscale));
_mm_storeu_ps(out + i, cur); /* unaligned: out + i is only 16-byte aligned if aligned was not clamped to n */
i += 4;
}
if(n > 7) {
for(; i < n - 7; i += 8) {
/* process 8 elements at a time */
__m256 cur = _mm256_loadu_ps(in + i);
cur = _mm256_mul_ps(cur, vscale);
_mm256_store_ps(out + i, cur);
}
}
tail = n - i;
if(tail & 4) {
__m128 cur = _mm_loadu_ps(in + i); /* only out is aligned, in may not be */
cur = _mm_mul_ps(cur, _mm256_castps256_ps128(vscale));
_mm_store_ps(out + i, cur);
i += 4;
}
if(tail & 2) {
__m128 cur = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) (in + i));
cur = _mm_mul_ps(cur, _mm256_castps256_ps128(vscale));
_mm_storel_pi((__m64*) (out + i), cur);
i += 2;
}
if(tail & 1) {
out[i] = in[i] * scale;
++i;
}
}
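The head is processed as 1, then 2, then 4 elements because each step doubles the alignment of out + i. The hypothetical check below simulates every possible float offset k of out within a 32-byte block, assuming out is at least 4-byte aligned and the array is long enough that aligned is not clamped to n. It confirms that the 2-element step starts on an 8-byte boundary, the 4-element step on a 16-byte boundary and the main loop on a 32-byte boundary.

```c
#include <assert.h>

/* k: offset of out, in floats, past the previous 32-byte boundary.
   Returns 1 if every head step starts at the expected alignment. */
static int head_steps_aligned(int k)
{
    int aligned = (8 - k) % 8;   /* floats until the next 32-byte boundary */
    int i = aligned & 1;         /* optional single element */
    if(aligned & 2) {
        if((k + i) % 2 != 0)     /* 2-element step starts on 8 bytes */
            return 0;
        i += 2;
    }
    if(aligned & 4) {
        if((k + i) % 4 != 0)     /* 4-element step starts on 16 bytes */
            return 0;
        i += 4;
    }
    return (k + i) % 8 == 0;     /* main loop starts on 32 bytes */
}

static void demo_head_steps(void)
{
    int k;
    for(k = 0; k < 8; ++k)
        assert(head_steps_aligned(k));
}
```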
Instead of using smaller vectors, we can apply a simple trick: we do one misaligned iteration for the first 8 entries. Then, instead of skipping forward 8 elements, we skip forward to the aligned position. This means the first aligned iteration will partially overlap with the unaligned iteration, but this is much cheaper than the partial processing we did above. We can then do the same for the tail with an unaligned iteration starting at n - 8. This style of processing has three more requirements compared to the loops above:
- There must be at least 8 elements, so smaller inputs need a separate code path (the special case at the start of the function below).
- Input and output must not alias, not even as in == out, because the overlapping iterations would read in[i] after out[i] has already been written. Hence the restrict pointer.
- The position within the vector register must not matter (this can occur with shuffle instructions, for example).
void vector_scale(
float* restrict out, const float* in, ptrdiff_t n, float scale)
{
const __m256 vscale = _mm256_set1_ps(scale);
const ptrdiff_t aligned =
(float*) next_aligned(out + 1, sizeof(__m256), out + n) - out;
ptrdiff_t i, tail = n - 8;
if(n < 8) {
/* special case for small inputs */
if(n <= 0)
return;
if((i = n & 1) != 0)
out[0] = in[0] * scale;
if(n & 2) {
__m128 cur = _mm_loadl_pi(_mm_undefined_ps(), (const __m64*) (in + i));
cur = _mm_mul_ps(cur, _mm256_castps256_ps128(vscale));
_mm_storel_pi((__m64*) (out + i), cur);
i += 2;
}
if(n & 4) {
__m128 cur = _mm_loadu_ps(in + i);
cur = _mm_mul_ps(cur, _mm256_castps256_ps128(vscale));
_mm_storeu_ps(out + i, cur);
i += 4;
}
return;
}
{ /* first iteration, potentially misaligned */
__m256 cur = _mm256_loadu_ps(in);
cur = _mm256_mul_ps(cur, vscale);
_mm256_storeu_ps(out, cur);
}
for(i = aligned; i < tail; i += 8) {
__m256 cur = _mm256_loadu_ps(in + i);
cur = _mm256_mul_ps(cur, vscale);
_mm256_store_ps(out + i, cur);
}
{ /* last iteration */
__m256 cur = _mm256_loadu_ps(in + tail);
cur = _mm256_mul_ps(cur, vscale);
_mm256_storeu_ps(out + tail, cur);
}
}
Please note the slight change from next_aligned(out, …) to next_aligned(out + 1, …). Otherwise a properly aligned output array would result in one fully redundant iteration.
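The overlapping scheme can be modeled in portable C to confirm that every element is covered for every length and offset. In this sketch, the hypothetical step8 stands in for one load, multiply, store sequence, and aligned is passed in directly instead of being derived from the address of out. The last part of the check shows why the restrict requirement matters: used in place, the overlapping elements are scaled twice.

```c
#include <assert.h>
#include <stddef.h>

/* Stand-in for one 8-wide vector step: load 8, multiply, store 8. */
static void step8(float* out, const float* in, float scale)
{
    float tmp[8];
    int k;
    for(k = 0; k < 8; ++k)
        tmp[k] = in[k] * scale;
    for(k = 0; k < 8; ++k)
        out[k] = tmp[k];
}

/* Overlapping head/tail scheme for n >= 8. In the real code aligned
   comes from next_aligned(out + 1, ...); here it is a parameter. */
static void overlapped_scale(float* out, const float* in, ptrdiff_t n,
                             ptrdiff_t aligned, float scale)
{
    ptrdiff_t i, tail = n - 8;
    assert(n >= 8 && aligned >= 1 && aligned <= 8);
    step8(out, in, scale);               /* first, possibly misaligned */
    for(i = aligned; i < tail; i += 8)
        step8(out + i, in + i, scale);   /* aligned main loop */
    step8(out + tail, in + tail, scale); /* last, overlaps its predecessor */
}

static void demo_overlapped(void)
{
    float in[40], out[40], buf[10];
    ptrdiff_t n, a, j;
    for(j = 0; j < 40; ++j)
        in[j] = (float) j;
    /* Out of place: every element is correct for every n and offset. */
    for(n = 8; n <= 40; ++n)
        for(a = 1; a <= 8; ++a) {
            for(j = 0; j < 40; ++j)
                out[j] = -1.0f;
            overlapped_scale(out, in, n, a, 2.0f);
            for(j = 0; j < n; ++j)
                assert(out[j] == 2.0f * in[j]);
        }
    /* In place: elements 2-7 are scaled twice. */
    for(j = 0; j < 10; ++j)
        buf[j] = 1.0f;
    overlapped_scale(buf, buf, 10, 4, 2.0f);
    assert(buf[0] == 2.0f && buf[2] == 4.0f && buf[9] == 2.0f);
}
```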
The sample above covers simple transformations, but what about reductions? Let's explore a simple vector dot product: sum(left[i] * right[i]). As above, we start simple and then optimize the code iteratively. I assume we have AVX and FMA instructions available for this one.
float vector_dot(const float* left, const float* right, ptrdiff_t n)
{
__m256 vsum = _mm256_setzero_ps();
__m128 high_half, low_half;
float headsum, tailsum = 0.f;
ptrdiff_t i = 0;
if(n > 7) {
for(; i < n - 7; i += 8) {
const __m256 lefti = _mm256_loadu_ps(left + i);
const __m256 righti = _mm256_loadu_ps(right + i);
vsum = _mm256_fmadd_ps(lefti, righti, vsum);
}
}
/* Reduce 8 elements to 4 */
high_half = _mm256_extractf128_ps(vsum, 1);
low_half = _mm256_castps256_ps128(vsum);
low_half = _mm_add_ps(low_half, high_half);
/* Reduce 4 to 2 */
high_half = _mm_movehl_ps(low_half, low_half);
low_half = _mm_add_ps(low_half, high_half);
/* Reduce 2 to 1 */
high_half = _mm_movehdup_ps(low_half);
low_half = _mm_add_ss(low_half, high_half);
headsum = _mm_cvtss_f32(low_half);
/* Process last 0-7 elements */
for(; i < n; ++i)
tailsum += left[i] * right[i];
return headsum + tailsum;
}
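The reduction at the end of vector_dot adds lanes in a fixed tree pattern. The scalar model below spells out which lanes are combined at each step; hsum8_model is a hypothetical name and the comments map each step to the intrinsics used above. Note that this order differs from a left-to-right sum, so the result can differ from the scalar loop in the last bits.

```c
#include <assert.h>

/* Scalar model of the horizontal sum at the end of vector_dot.
   v[0..7] are the lanes of vsum. */
static float hsum8_model(const float v[8])
{
    float q[4], d[2];
    int k;
    for(k = 0; k < 4; ++k)
        q[k] = v[k] + v[k + 4];  /* _mm256_extractf128_ps + _mm_add_ps: 8 -> 4 */
    for(k = 0; k < 2; ++k)
        d[k] = q[k] + q[k + 2];  /* _mm_movehl_ps + _mm_add_ps: 4 -> 2 */
    return d[0] + d[1];          /* _mm_movehdup_ps + _mm_add_ss: 2 -> 1 */
}

static void demo_hsum(void)
{
    const float v[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    assert(hsum8_model(v) == 36.0f);
}
```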
The main loop of our reduction has one critical issue: each FMA instruction depends on the result of the previous iteration. On most systems, FMA has a latency of 4 cycles (though older CPUs such as AMD Zen2 or Intel Broadwell use 5 cycles). This means that at best we get one iteration every 4 cycles instead of one per cycle. To avoid this, we have to interleave at least 4 FMA dependency chains, basically keeping four partial sums. The compiler will not do this on its own, since it changes the order of the floating-point additions, which compilers only do with options such as -ffast-math. What makes this extra painful is that you cannot rely on inner loops for this unless you know your compiler unrolls them (GCC doesn't).
float vector_dot(const float* left, const float* right, ptrdiff_t n)
{
__m256 vsum1 = _mm256_setzero_ps();
__m256 vsum2 = vsum1, vsum3 = vsum1, vsum4 = vsum1;
__m128 high_half, low_half;
float headsum, tailsum = 0.f;
ptrdiff_t i = 0;
if(n > 31) {
for(; i < n - 31; i += 32) {
__m256 lefti = _mm256_loadu_ps(left + i);
__m256 righti = _mm256_loadu_ps(right + i);
vsum1 = _mm256_fmadd_ps(lefti, righti, vsum1);
lefti = _mm256_loadu_ps(left + i + 8);
righti = _mm256_loadu_ps(right + i + 8);
vsum2 = _mm256_fmadd_ps(lefti, righti, vsum2);
lefti = _mm256_loadu_ps(left + i + 16);
righti = _mm256_loadu_ps(right + i + 16);
vsum3 = _mm256_fmadd_ps(lefti, righti, vsum3);
lefti = _mm256_loadu_ps(left + i + 24);
righti = _mm256_loadu_ps(right + i + 24);
vsum4 = _mm256_fmadd_ps(lefti, righti, vsum4);
}
}
if(n - i >= 16) {
__m256 lefti = _mm256_loadu_ps(left + i);
__m256 righti = _mm256_loadu_ps(right + i);
vsum1 = _mm256_fmadd_ps(lefti, righti, vsum1);
lefti = _mm256_loadu_ps(left + i + 8);
righti = _mm256_loadu_ps(right + i + 8);
vsum2 = _mm256_fmadd_ps(lefti, righti, vsum2);
i += 16;
}
if(n - i >= 8) {
__m256 lefti = _mm256_loadu_ps(left + i);
__m256 righti = _mm256_loadu_ps(right + i);
vsum1 = _mm256_fmadd_ps(lefti, righti, vsum1);
i += 8;
}
/* Reduce 32 to 16 */
vsum1 = _mm256_add_ps(vsum1, vsum3);
vsum2 = _mm256_add_ps(vsum2, vsum4);
/* Reduce 16 to 8 */
vsum1 = _mm256_add_ps(vsum1, vsum2);
/* Reduce 8 elements to 4 */
high_half = _mm256_extractf128_ps(vsum1, 1);
low_half = _mm256_castps256_ps128(vsum1);
low_half = _mm_add_ps(low_half, high_half);
/* Reduce 4 to 2 */
high_half = _mm_movehl_ps(low_half, low_half);
low_half = _mm_add_ps(low_half, high_half);
/* Reduce 2 to 1 */
high_half = _mm_movehdup_ps(low_half);
low_half = _mm_add_ss(low_half, high_half);
headsum = _mm_cvtss_f32(low_half);
/* Process last 0-7 elements */
for(; i < n; ++i)
tailsum += left[i] * right[i];
return headsum + tailsum;
}
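The same idea of independent accumulators can be shown with plain scalars. In this sketch (dot_4chains is a hypothetical name) each of the four sums plays the role of one __m256 accumulator, with one float per accumulator instead of eight, so the four multiply-adds within one iteration do not depend on each other. The final combination uses the same (1 + 3) + (2 + 4) pairing as the vector code.

```c
#include <assert.h>
#include <stddef.h>

/* Four independent accumulators; each one only depends on itself. */
static float dot_4chains(const float* left, const float* right, ptrdiff_t n)
{
    float s1 = 0.f, s2 = 0.f, s3 = 0.f, s4 = 0.f;
    ptrdiff_t i = 0;
    if(n > 3) {
        for(; i < n - 3; i += 4) {
            s1 += left[i] * right[i];
            s2 += left[i + 1] * right[i + 1];
            s3 += left[i + 2] * right[i + 2];
            s4 += left[i + 3] * right[i + 3];
        }
    }
    for(; i < n; ++i)   /* remaining 0-3 elements */
        s1 += left[i] * right[i];
    return (s1 + s3) + (s2 + s4);
}

static void demo_dot(void)
{
    float left[11], right[11];
    int j;
    for(j = 0; j < 11; ++j) {
        left[j] = (float) (j + 1);
        right[j] = 1.0f;
    }
    assert(dot_4chains(left, right, 11) == 66.0f);
    assert(dot_4chains(left, right, 3) == 6.0f);
    assert(dot_4chains(left, right, 0) == 0.0f);
}
```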
The following optimizations are possible but will not be covered in detail:
- Instead of the scalar loop, the last 0-7 elements can be merged into the horizontal reduction using smaller vectors, for example like this:
ptrdiff_t tail;
if(n <= 0) /* check so bit tests don't fail with negative n */
return 0.f;
... /* main loop */
tail = n - i;
/* Reduce 8 elements to 4 */
high_half = _mm256_extractf128_ps(vsum1, 1);
low_half = _mm256_castps256_ps128(vsum1);
low_half = _mm_add_ps(low_half, high_half);
if(tail & 4) {
__m128 lefti = _mm_loadu_ps(left + i);
__m128 righti = _mm_loadu_ps(right + i);
low_half = _mm_fmadd_ps(lefti, righti, low_half);
i += 4;
}
if(tail & 2) {
__m128 lefti = _mm_loadl_pi(_mm_setzero_ps(), (const __m64*) (left + i));
__m128 righti = _mm_loadl_pi(_mm_setzero_ps(), (const __m64*) (right + i));
low_half = _mm_fmadd_ps(lefti, righti, low_half);
i += 2;
}
/* Reduce 4 to 2 */
high_half = _mm_movehl_ps(low_half, low_half);
low_half = _mm_add_ps(low_half, high_half);
/* Reduce 2 to 1 */
high_half = _mm_movehdup_ps(low_half);
low_half = _mm_add_ss(low_half, high_half);
float sum = _mm_cvtss_f32(low_half);
if(tail & 1)
sum += left[i] * right[i];
return sum;
- Aligning the loads to one of the two inputs. Without AVX, this allows movaps instructions or folding an aligned load into an arithmetic instruction.

With AVX-512's mask registers, things get a lot easier. The alignment and tail loops can simply be implemented via masked loads and stores.
void vector_scale(float* out, const float* in, ptrdiff_t n, float scale)
{
const __m512 vscale = _mm512_set1_ps(scale);
const ptrdiff_t aligned =
(float*) next_aligned(out + 1, sizeof(__m512), out + n) - out;
__mmask16 mask;
ptrdiff_t i = aligned, tail;
__m512 cur;
if(n <= 0)
return;
/* first iteration: elements up to the first 64-byte aligned position */
mask = _cvtu32_mask16((1u << (unsigned) aligned) - 1);
cur = _mm512_maskz_loadu_ps(mask, in);
cur = _mm512_mul_ps(cur, vscale);
_mm512_mask_storeu_ps(out, mask, cur);
if(n > 15) {
for(; i < n - 15; i += 16) {
cur = _mm512_loadu_ps(in + i); /* only out is aligned, in may not be */
cur = _mm512_mul_ps(cur, vscale);
_mm512_store_ps(out + i, cur);
}
}
tail = n - i;
mask = _cvtu32_mask16((1u << (unsigned) tail) - 1);
cur = _mm512_maskz_loadu_ps(mask, in + i);
cur = _mm512_mul_ps(cur, vscale);
/* unaligned store: out + i is not 64-byte aligned if aligned was clamped to n */
_mm512_mask_storeu_ps(out + i, mask, cur);
}
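The bookkeeping of the masked version can be checked without AVX-512 hardware. The hypothetical plan_avx512 below mirrors its control flow but only records the head mask, the number of full 16-element iterations and the tail mask. It assumes aligned is in the range 1 to 16 and no larger than n, which is what next_aligned(out + 1, 64, out + n) - out yields for float-aligned arrays. The check confirms that head lanes, full iterations and tail lanes always add up to n.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

struct avx512_plan {
    uint32_t head_mask;
    ptrdiff_t full_iters;
    uint32_t tail_mask;
};

/* Mirrors the loop structure of the AVX-512 vector_scale. */
static struct avx512_plan plan_avx512(ptrdiff_t n, ptrdiff_t aligned)
{
    struct avx512_plan p = {0, 0, 0};
    ptrdiff_t i = aligned;
    assert(n > 0 && aligned >= 1 && aligned <= 16 && aligned <= n);
    p.head_mask = (1u << (unsigned) aligned) - 1;
    if(n > 15) {
        for(; i < n - 15; i += 16)
            ++p.full_iters;
    }
    p.tail_mask = (1u << (unsigned) (n - i)) - 1;
    return p;
}

/* Number of active lanes in a mask. */
static int lanes(uint32_t m)
{
    int c = 0;
    for(; m; m &= m - 1)
        ++c;
    return c;
}

static void demo_plan(void)
{
    ptrdiff_t n, a;
    struct avx512_plan p = plan_avx512(40, 5);
    assert(p.head_mask == 0x1Fu && p.full_iters == 2 && p.tail_mask == 0x7u);
    for(n = 1; n <= 64; ++n)
        for(a = 1; a <= 16 && a <= n; ++a) {
            p = plan_avx512(n, a);
            assert(lanes(p.head_mask) + 16 * p.full_iters + lanes(p.tail_mask) == n);
        }
}
```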
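One thing to look out for when adapting this code is the mask generation, where ones stands for the number of active lanes (aligned or tail above). If you need a __mmask32, the intermediate 1 << ones needs to be a 64-bit type to avoid overflowing if ones == 32, and for __mmask64 you would need __int128. Especially in the second case, a better option is __mmask64 mask = _cvtu64_mask64((uint64_t) -1 >> (64 - ones)). Note that this shift is undefined for ones == 0 (a shift by 64), which can happen for the tail, so that case needs separate handling.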
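Both mask formulas can be wrapped in small helpers that stay well defined for every lane count. This is a sketch: low_mask32 and low_mask64 are hypothetical names, and in the real code their results would be passed to _cvtu32_mask32 or _cvtu64_mask64 (AVX-512BW).

```c
#include <assert.h>
#include <stdint.h>

/* Mask with the low `ones` bits set, for 0 <= ones <= 64. The shift
   form would shift by 64 for ones == 0, which is undefined behavior,
   so that case is handled explicitly. */
static uint64_t low_mask64(unsigned ones)
{
    assert(ones <= 64);
    return ones == 0 ? 0 : UINT64_MAX >> (64 - ones);
}

/* 32-lane version: computing 1 << ones in 64 bits keeps ones == 32 legal. */
static uint32_t low_mask32(unsigned ones)
{
    assert(ones <= 32);
    return (uint32_t) ((UINT64_C(1) << ones) - 1);
}
```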
As usual with AVX-512, choosing 512-bit or 256-bit vectors requires testing and knowledge of the specific hardware: 512-bit vectors may or may not reduce the clock rate, potentially affecting other parts of the code, and they may have other limiting factors.