
Squared Quaternion using AVX


Does anyone know how to vectorize this function using AVX?

void cuadradoYSumaNormal(quaternion* a, quaternion* b, quaternion* c) {
    c->w = a->w*a->w - a->x*a->x - a->y*a->y - a->z*a->z + b->w;
    c->x = 2.*a->w*a->x + b->x;
    c->y = 2.*a->w*a->y + b->y;
    c->z = 2.*a->w*a->z + b->z;
}

I can assume that a, b and c have unit length.

quaternion is the following struct:

struct quaternion{
  double w;
  double x;
  double y;
  double z;
};

What the function must do is square the quaternion *a (using quaternion multiplication rules), then add the quaternion *b and store the result in *c.
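To confirm that the formula above really computes a*a + b, it can be compared with the general Hamilton product. This is a minimal sketch that is not part of the question: hamilton, square_plus_reference, nearly_equal and check_formula are hypothetical names introduced here, and the question's function is copied with const-qualified inputs.

```cpp
#include <cassert>
#include <cmath>

struct quaternion {
    double w, x, y, z;
};

// General Hamilton product p*q for p = w + x*i + y*j + z*k.
quaternion hamilton(const quaternion& p, const quaternion& q) {
    return {
        p.w*q.w - p.x*q.x - p.y*q.y - p.z*q.z,
        p.w*q.x + p.x*q.w + p.y*q.z - p.z*q.y,
        p.w*q.y - p.x*q.z + p.y*q.w + p.z*q.x,
        p.w*q.z + p.x*q.y - p.y*q.x + p.z*q.w
    };
}

// Slow reference: c = a*a + b via the general product.
quaternion square_plus_reference(const quaternion& a, const quaternion& b) {
    quaternion s = hamilton(a, a);
    return {s.w + b.w, s.x + b.x, s.y + b.y, s.z + b.z};
}

// The question's shortcut (inputs made const): with p == q the cross terms
// y*z - z*y, z*x - x*z and x*y - y*x cancel, leaving w^2 - |v|^2 and 2*w*v.
void cuadradoYSumaNormal(const quaternion* a, const quaternion* b, quaternion* c) {
    c->w = a->w*a->w - a->x*a->x - a->y*a->y - a->z*a->z + b->w;
    c->x = 2.*a->w*a->x + b->x;
    c->y = 2.*a->w*a->y + b->y;
    c->z = 2.*a->w*a->z + b->z;
}

// True if every component of p and q differs by at most tol.
bool nearly_equal(const quaternion& p, const quaternion& q, double tol = 1e-12) {
    return std::fabs(p.w - q.w) <= tol && std::fabs(p.x - q.x) <= tol &&
           std::fabs(p.y - q.y) <= tol && std::fabs(p.z - q.z) <= tol;
}

// Debug check: the shortcut must agree with the general product.
void check_formula(const quaternion& a, const quaternion& b) {
    quaternion c{};
    cuadradoYSumaNormal(&a, &b, &c);
    assert(nearly_equal(c, square_plus_reference(a, b)));
}
```

For example, with a = {1, 2, 3, 4} and b = {0.5, 0.25, 0.125, 1} both versions give {-27.5, 4.25, 6.125, 9}. The absolute tolerance in nearly_equal is only meant for small test values like these.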


Solution

    This solution works only if a has unit length, i.e., aw^2 + ax^2 + ay^2 + az^2 == 1.

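    In that case ax^2 + ay^2 + az^2 == 1 - aw^2, so c->w = aw^2 - (1 - aw^2) + b->w = 2*a->w*a->w - 1.0 + b->w. Apart from the -1.0 in the w component, every component of the result is then 2*a->w times the matching component of a, plus the matching component of b, which makes the function far easier to vectorize. The multiplication by 2 can be done by adding a (or a->w) to itself. To shorten the latency chain, the -1.0 should be added to b->w, since that addition does not depend on a and can run in parallel with the multiplication. Possible implementation: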

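    #include <immintrin.h>

    // Returns [value, 0, 0, 0] in memory order (w, x, y, z);
    // _mm256_set_pd takes its arguments from the highest lane down.
    inline __m256d unit(double value = 1.0)
    {
        return _mm256_set_pd(0,0,0,value);
    }
    
    // Requires AVX and assumes that a has unit length.
    void cuadradoYSumaNormal_avx(quaternion* a, quaternion* b, quaternion* c) {
    
        __m256d aw = _mm256_broadcast_sd(&a->w);   // [aw, aw, aw, aw]
        __m256d a_ = _mm256_loadu_pd(&a->w);       // [aw, ax, ay, az]
        __m256d b_ = _mm256_loadu_pd(&b->w);       // [bw, bx, by, bz]
    
        // aw * 2a = [2aw^2, 2aw*ax, 2aw*ay, 2aw*az] = a*a + [1, 0, 0, 0]
        __m256d a_squared_plus_one = _mm256_mul_pd(aw, _mm256_add_pd(a_,a_));
        // b + [-1, 0, 0, 0] does not depend on a, so it overlaps with the multiply
        __m256d c_ = _mm256_add_pd(a_squared_plus_one, _mm256_add_pd(b_, unit(-1.0)));
    
        _mm256_storeu_pd(&c->w, c_);
    }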
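    The AVX version depends on two assumptions that can be checked without AVX hardware: the struct must store w, x, y, z as four contiguous doubles (that is what the unaligned loads and the store read and write), and a must have unit length. The portable sketch below is only an illustration of those assumptions; is_unit and square_plus_unit_lanes are hypothetical names, and the loop models each 64-bit lane of the vector code with scalar arithmetic.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

struct quaternion {
    double w, x, y, z;
};

// _mm256_loadu_pd(&a->w) and _mm256_storeu_pd(&c->w, ...) touch four
// consecutive doubles starting at w, so the layout must be exactly w, x, y, z.
static_assert(sizeof(quaternion) == 4 * sizeof(double), "unexpected padding");
static_assert(offsetof(quaternion, x) == 1 * sizeof(double), "x must follow w");
static_assert(offsetof(quaternion, y) == 2 * sizeof(double), "y must follow x");
static_assert(offsetof(quaternion, z) == 3 * sizeof(double), "z must follow y");

// True if |a|^2 = aw^2 + ax^2 + ay^2 + az^2 is within tol of 1.
bool is_unit(const quaternion& a, double tol = 1e-12) {
    return std::fabs(a.w*a.w + a.x*a.x + a.y*a.y + a.z*a.z - 1.0) <= tol;
}

// Scalar model of the AVX data flow, one loop iteration per 64-bit lane:
//   c[i] = aw * (a[i] + a[i]) + (b[i] + m[i]),  m = [-1, 0, 0, 0]
quaternion square_plus_unit_lanes(const quaternion& a, const quaternion& b) {
    assert(is_unit(a) && "the 2*aw*aw - 1 shortcut needs |a| == 1");
    const double av[4] = {a.w, a.x, a.y, a.z};
    const double bv[4] = {b.w, b.x, b.y, b.z};
    const double m[4] = {-1.0, 0.0, 0.0, 0.0};  // what unit(-1.0) holds
    double cv[4];
    for (int i = 0; i < 4; ++i) {
        // lane 0: 2*aw^2 - 1 + bw; lanes 1..3: 2*aw*a[i] + b[i]
        cv[i] = a.w * (av[i] + av[i]) + (bv[i] + m[i]);
    }
    return {cv[0], cv[1], cv[2], cv[3]};
}
```

    For a = {0.5, 0.5, 0.5, 0.5} and b = {1, 0, 0, 0} it returns {0.5, 0.5, 0.5, 0.5}, the same as the question's scalar formula. For a non-unit input such as {1, 2, 3, 4} the w lane would come out as 1 + b.w instead of -28 + b.w, which is why the sketch asserts is_unit(a) first.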
    

    If FMA is available in addition to AVX, some of the additions and multiplications can be fused into

    (aw * a + [-0.5,0,0,0]) * 2.0 + b
    

    This takes just two FMAs (plus one broadcast and a few loads). Possible implementation:

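    // Requires AVX and FMA, assumes that a has unit length, and reuses unit() from above.
    void cuadradoYSumaNormal_fma(quaternion* a, quaternion* b, quaternion* c) {
    
        __m256d aw = _mm256_broadcast_sd(&a->w);
        __m256d a_ = _mm256_loadu_pd(&a->w);
        __m256d b_ = _mm256_loadu_pd(&b->w);
    
        // aw * a + [-0.5, 0, 0, 0] = [aw^2 - 0.5, aw*ax, aw*ay, aw*az] = (a*a) / 2
        __m256d a_squared_half = _mm256_fmadd_pd(aw, a_, unit(-0.5));
        // 2 * ((a*a) / 2) + b
        __m256d c_ = _mm256_fmadd_pd(a_squared_half, _mm256_set1_pd(2.0), b_);
    
        _mm256_storeu_pd(&c->w, c_);
    }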
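    One way to test both intrinsic versions against the question's scalar function is sketched below. It assumes GCC or Clang on x86-64: the __attribute__((target(...))) annotations enable AVX/FMA for those functions only, so the file compiles without -mavx/-mfma, and __builtin_cpu_supports skips any version the running CPU cannot execute. check_vector_versions is a hypothetical helper, not part of the answer.

```cpp
#include <immintrin.h>
#include <cassert>
#include <cmath>

struct quaternion {
    double w, x, y, z;
};

// Scalar reference from the question.
void cuadradoYSumaNormal(quaternion* a, quaternion* b, quaternion* c) {
    c->w = a->w*a->w - a->x*a->x - a->y*a->y - a->z*a->z + b->w;
    c->x = 2.*a->w*a->x + b->x;
    c->y = 2.*a->w*a->y + b->y;
    c->z = 2.*a->w*a->z + b->z;
}

// The answer's code; the target attributes (GCC/Clang) enable AVX/FMA for
// these functions only, so no global -mavx/-mfma flag is needed.
__attribute__((target("avx")))
static inline __m256d unit(double value = 1.0) {
    return _mm256_set_pd(0, 0, 0, value);  // memory order: [value, 0, 0, 0]
}

__attribute__((target("avx")))
void cuadradoYSumaNormal_avx(quaternion* a, quaternion* b, quaternion* c) {
    __m256d aw = _mm256_broadcast_sd(&a->w);
    __m256d a_ = _mm256_loadu_pd(&a->w);
    __m256d b_ = _mm256_loadu_pd(&b->w);
    __m256d a_squared_plus_one = _mm256_mul_pd(aw, _mm256_add_pd(a_, a_));
    __m256d c_ = _mm256_add_pd(a_squared_plus_one, _mm256_add_pd(b_, unit(-1.0)));
    _mm256_storeu_pd(&c->w, c_);
}

__attribute__((target("avx,fma")))
void cuadradoYSumaNormal_fma(quaternion* a, quaternion* b, quaternion* c) {
    __m256d aw = _mm256_broadcast_sd(&a->w);
    __m256d a_ = _mm256_loadu_pd(&a->w);
    __m256d b_ = _mm256_loadu_pd(&b->w);
    __m256d a_squared_half = _mm256_fmadd_pd(aw, a_, unit(-0.5));
    __m256d c_ = _mm256_fmadd_pd(a_squared_half, _mm256_set1_pd(2.0), b_);
    _mm256_storeu_pd(&c->w, c_);
}

// Hypothetical test helper: runs each version the CPU supports on (a, b) and
// asserts that it agrees with the scalar reference to within tol.
// a must have unit length, as the answer requires.
void check_vector_versions(quaternion a, quaternion b, double tol = 1e-12) {
    quaternion ref{}, out{};
    cuadradoYSumaNormal(&a, &b, &ref);
    auto matches = [&](const quaternion& q) {
        return std::fabs(q.w - ref.w) <= tol && std::fabs(q.x - ref.x) <= tol &&
               std::fabs(q.y - ref.y) <= tol && std::fabs(q.z - ref.z) <= tol;
    };
    if (__builtin_cpu_supports("avx")) {
        cuadradoYSumaNormal_avx(&a, &b, &out);
        assert(matches(out));
    }
    if (__builtin_cpu_supports("avx") && __builtin_cpu_supports("fma")) {
        cuadradoYSumaNormal_fma(&a, &b, &out);
        assert(matches(out));
    }
}
```

    The tolerance is needed because the vector versions compute the w component as 2*aw*aw - 1 rather than aw^2 - ax^2 - ay^2 - az^2, so the two can differ in the last few bits even for a unit-length a.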