Search code examples
Tags: c++, simd, intrinsics, avx, avx2

How to interleave 3 float vectors into an array with AVX intrinsics C++


I have 3 __m256 vectors x, y, z filled with 8 elements of data each (single precision floats), and I'd like to store them interleaved into memory [x0, y0, z0, x1, y1, z1, ...].

What are the relevant and useful operations to use to store them into a (possibly unaligned) array or std::vector?

The brute-force approach (scalar stores, one element at a time) is slow unless the compiler manages to turn it into vector shuffles:

#include "immintrin.h"
#include <vector>

// Stand-ins for x, y, z: in real code these are the results of earlier
// computation, so they typically already live in registers.
// _mm256_set_ps takes its arguments highest element first, so in lane order
// x = {0.0f, 1.0f, ..., 7.0f}, y = {0.1f, ..., 7.1f}, z = {0.2f, ..., 7.2f}.
__m256 x = _mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f);
__m256 y = _mm256_set_ps(7.1f, 6.1f, 5.1f, 4.1f, 3.1f, 2.1f, 1.1f, 0.1f);
__m256 z = _mm256_set_ps(7.2f, 6.2f, 5.2f, 4.2f, 3.2f, 2.2f, 1.2f, 0.2f);

// Snippet only: result.resize(...) and the loop below are statements, so they
// must be placed inside a function body to compile.
std::vector<float> result;
result.resize(24);
// Brute-force scalar interleave: lane i of x, y, z goes to result[3*i],
// result[3*i + 1] and result[3*i + 2].
// NOTE: subscripting an __m256 (x[i]) relies on the GCC/Clang vector
// extensions; MSVC does not accept it.
for (int i = 0; i < 8; i++)
{
    result[i * 3] = x[i];
    result[i * 3 + 1] = y[i];
    result[i * 3 + 2] = z[i];
} // result = {0.0f, 0.1f, 0.2f, 1.0f, 1.1f, 1.2f, etc..}

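Solution (requires AVX2, because `_mm256_permutevar8x32_ps` is an AVX2 instruction)

Element i of x, y and z ends up at flat index 3i, 3i+1 and 3i+2 of the output, i.e. in lane (3i + k) mod 8 of one of the three output registers. So first permute each input across lanes so that every element already sits in the lane it will occupy, then use two rounds of compile-time blends to pick, for each lane of each output register, the input that belongs there.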

  // Short aliases for the AVX register types used throughout this answer.
    using f256 = __m256;   // 8 packed single-precision floats
    using i256 = __m256i;  // 256-bit integer vector (used here as 8 x int32)

    // Builds an i256 from eight int32 values given in lane order (e0 = lane 0),
    // exactly like _mm256_setr_epi32. This is an inline function rather than the
    // original `#define set8i _mm256_setr_epi32` so the name respects scope and
    // its arguments are type-checked instead of being textually substituted.
    inline i256 set8i(int e0, int e1, int e2, int e3,
                      int e4, int e5, int e6, int e7)
    {
      return _mm256_setr_epi32(e0, e1, e2, e3, e4, e5, e6, e7);
    }
    
    // Full-width shuffle that may cross the two 128-bit halves: lane i of the
    // result is a[choice[i] & 7]. Wraps the AVX2 intrinsic
    // _mm256_permutevar8x32_ps (vpermps).
    inline f256 permute8f(const f256 a, const i256 choice)
    {
      const f256 shuffled = _mm256_permutevar8x32_ps(a, choice);
      return shuffled;
    }
    
    // Compile-time per-lane select: lane i of the result comes from `tr` when the
    // template flag ci is true, otherwise from `fr`. The flags are packed into
    // the 8-bit immediate that _mm256_blend_ps requires (bit i set => take the
    // lane from the second operand, which is `tr` here).
    template<bool c0, bool c1, bool c2, bool c3, bool c4, bool c5, bool c6, bool c7>
    inline f256 select8f(const f256 tr, const f256 fr)
    {
      constexpr int lane_mask =
          (c0 ? 0x01 : 0) | (c1 ? 0x02 : 0) | (c2 ? 0x04 : 0) | (c3 ? 0x08 : 0) |
          (c4 ? 0x10 : 0) | (c5 ? 0x20 : 0) | (c6 ? 0x40 : 0) | (c7 ? 0x80 : 0);
      return _mm256_blend_ps(fr, tr, lane_mask);
    }
    
    // Converts three SoA registers (x = x0..x7, y = y0..y7, z = z0..z7) into
    // three AoS registers holding the interleaved sequence x0 y0 z0 x1 ... z7:
    //   A = x0 y0 z0 x1 y1 z1 x2 y2
    //   B = z2 x3 y3 z3 x4 y4 z4 x5
    //   C = y5 z5 x6 y6 z6 x7 y7 z7
    // Storing A, B, C back to back therefore yields the 24-float interleaved array.
    //
    // The inputs are taken by value, so the outputs may alias them
    // (e.g. vec3_soa_to_aos(x, y, z, x, y, z) is fine).
    //
    // Declared `inline` (like vec3_aos_to_soa) so the definition can live in a
    // header included from several translation units without causing
    // multiple-definition link errors.
    inline void vec3_soa_to_aos(f256& A, f256& B, f256& C,
             const f256 x, const f256 y, const f256 z)
    {
      // Element i of x/y/z belongs at flat index 3i / 3i+1 / 3i+2, i.e. lane
      // (3i + k) mod 8 of A, B or C. These permutations move every element into
      // that lane, so afterwards only per-lane blends are needed.
      const i256 PX = set8i(0, 3, 6, 1, 4, 7, 2, 5);
      const i256 PY = set8i(5, 0, 3, 6, 1, 4, 7, 2);
      const i256 PZ = set8i(2, 5, 0, 3, 6, 1, 4, 7);

      const f256 X = permute8f(x, PX);  // x0 x3 x6 x1 x4 x7 x2 x5
      const f256 Y = permute8f(y, PY);  // y5 y0 y3 y6 y1 y4 y7 y2
      const f256 Z = permute8f(z, PZ);  // z2 z5 z0 z3 z6 z1 z4 z7

      // Stage 1: lanes 0, 3, 6 from the first argument, all others from the
      // second. Lanes 2 and 5 ("--") still hold placeholders.
      const f256 A0 = select8f<1, 0, 0, 1, 0, 0, 1, 0>(X, Y);  // x0 y0 -- x1 y1 -- x2 y2
      const f256 B0 = select8f<1, 0, 0, 1, 0, 0, 1, 0>(Z, X);  // z2 x3 -- z3 x4 -- z4 x5
      const f256 C0 = select8f<1, 0, 0, 1, 0, 0, 1, 0>(Y, Z);  // y5 z5 -- y6 z6 -- y7 z7

      // Stage 2: patch lanes 2 and 5 from the remaining component.
      A = select8f<0, 0, 1, 0, 0, 1, 0, 0>(Z, A0);  // x0 y0 z0 x1 y1 z1 x2 y2
      B = select8f<0, 0, 1, 0, 0, 1, 0, 0>(Y, B0);  // z2 x3 y3 z3 x4 y4 z4 x5
      C = select8f<0, 0, 1, 0, 0, 1, 0, 0>(X, C0);  // y5 z5 x6 y6 z6 x7 y7 z7
    }
    
    // for completeness.... 
    inline void vec3_aos_to_soa(
      const f256 A, const f256 B, const f256 C, 
      f256& x, f256& y, f256& z)
    {
      const f256 X0 = select8f<1, 0, 0, 1, 0, 0, 1, 0>(A, B);
      const f256 Y0 = select8f<1, 0, 0, 1, 0, 0, 1, 0>(C, A);
      const f256 Z0 = select8f<1, 0, 0, 1, 0, 0, 1, 0>(B, C);
    
      const f256 X = select8f<0, 0, 1, 0, 0, 1, 0, 0>(C, X0);
      const f256 Y = select8f<0, 0, 1, 0, 0, 1, 0, 0>(B, Y0);
      const f256 Z = select8f<0, 0, 1, 0, 0, 1, 0, 0>(A, Z0);
    
      const i256 PX = set8i(0, 3, 6, 1, 4, 7, 2, 5);
      const i256 PY = set8i(1, 4, 7, 2, 5, 0, 3, 6);
      const i256 PZ = set8i(2, 5, 0, 3, 6, 1, 4, 7);
    
      // rearrange and output
      x = permute8f(X, PX);
      y = permute8f(Y, PY);
      z = permute8f(Z, PZ);
    }
    
    // Usage example. Snippet only: the call to vec3_soa_to_aos and the
    // result.resize / store statements must be placed inside a function body
    // (e.g. main()) to compile.
    // _mm256_set_ps takes its arguments highest element first, so in lane order
    // x = {0.0f, ..., 7.0f}, y = {0.1f, ..., 7.1f}, z = {0.2f, ..., 7.2f}.
    __m256 x = _mm256_set_ps(7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f);
    __m256 y = _mm256_set_ps(7.1f, 6.1f, 5.1f, 4.1f, 3.1f, 2.1f, 1.1f, 0.1f);
    __m256 z = _mm256_set_ps(7.2f, 6.2f, 5.2f, 4.2f, 3.2f, 2.2f, 1.2f, 0.2f);
    
    // Reuses x, y, z as the outputs. This is safe because vec3_soa_to_aos takes
    // its three inputs by value. Afterwards x, y, z hold the first, second and
    // third groups of 8 interleaved floats.
    vec3_soa_to_aos(x, y, z, x, y, z);
    
    
    // result = {0.0f, 0.1f, 0.2f, 1.0f, 1.1f, 1.2f, ..., 7.0f, 7.1f, 7.2f}
    std::vector<float> result;
    result.resize(24);
    // Unaligned stores (storeu): std::vector<float>::data() is not guaranteed to
    // be 32-byte aligned.
    _mm256_storeu_ps(result.data(), x);
    _mm256_storeu_ps(result.data() + 8, y);
    _mm256_storeu_ps(result.data() + 16, z);