Search code examples
Tags: c++ · clang · vectorization · simd · auto-vectorization

Why can't clang vectorise this loop over a std::span, writing results to a std::array?


Why won't clang 17.0.1 vectorise the loop in the following function:

// One NLMS update step: signal[0..N-1] is the regressor x and signal[N] the
// desired value d; the weights are updated as
//     w += mu / (eps + x·x) * (d - w·x) * x.
void adapt(std::span<const F, N + 1> signal)
{
    const F *x = signal.data();
    const F true_val = signal.back();
    // w.data() rather than w.begin(): std::array iterators are not guaranteed
    // to be raw pointers, but dot_prod takes const F*.
    const F y = dot_prod<F, N>(x, w.data());
    const F e = true_val - y;
    const F dot = dot_prod<F, N>(x, x);
    const F nu = mu / (eps + dot);
    // NOTE: the compiler cannot prove that w and signal do not overlap, so it
    // must either guard a vectorised loop with a runtime overlap check or keep
    // the loop scalar (clang 17 keeps it scalar for N = 32).
    for (size_t i = 0; i < N; i++)
    {
        w[i] += nu * e * signal[i];
    }
}

There is no loop-carried dependency and no issue with floating-point associativity, and GCC 13.2 has no trouble vectorising it.

Here's a link to the full code on Compiler Explorer.

The context is that I am trying to optimise my code to use a vectorised dot product. std::inner_product usually compiles to a scalar loop. Vectorising the sum requires reassociating the floating-point additions, and the compiler is only allowed to do that under -ffast-math. However, I want -ffast-math to apply to a single function only, and I was looking for a way to do this that works with both clang and GCC. While looking at the output, I noticed that clang doesn't vectorise this other loop.

Here is the full C++:

#include <array>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <span>

#pragma float_control(precise, off, push)
template <typename F, size_t N> 
__attribute__((optimize("-ffast-math"))) 
constexpr inline F dot_prod(const F *a, const F *b)
{
    F acc = 0.0f;
    for(size_t i = 0; i < N; i++)
        acc += a[i] * b[i];
    return acc;
}
#pragma float_control(pop)

// Normalised least-mean-squares (NLMS) adaptive linear predictor with N taps.
//
// A call to predict() with a window of N + 1 consecutive samples (presumably
// ordered oldest first — confirm with callers) works in two steps:
//   1. It adapts the weights so that signal[0..N-1] predicts signal[N].
//   2. It predicts the following sample from signal[1..N].
template <typename F, size_t N> class NlmsFilter
{
    static constexpr F mu = 0.5;   // step-size numerator: nu = mu / (eps + x·x)
    static constexpr F eps = 1.0;  // added to the input energy in nu's denominator
    std::array<F, N> w = {};       // filter weights, zero-initialised

  public:
    // If `signal` holds exactly N + 1 samples: adapts the weights on it, then
    // returns the prediction w · signal[1..N].
    // Otherwise no adaptation happens:
    //   * an empty signal yields zero;
    //   * a shorter signal returns its last sample unchanged.
    // Precondition (asserted): signal.size() <= N + 1.
    F predict(std::span<const F> signal)
    {
        assert(signal.size() <= N + 1);
        if (signal.size() == N + 1)
        {
            auto signal_static = signal.template subspan<0, N + 1>();
            adapt(signal_static);
            // Predict from the last N samples of the window (drop signal[0]).
            // w.data(): std::array iterators need not be raw pointers.
            return dot_prod<F, N>(signal_static.data() + 1, w.data());
        }
        else if (signal.empty())
        {
            return F{};
        }
        else
        {
            return signal.back();
        }
    }

    // One NLMS update step: signal[0..N-1] is the regressor x and signal[N]
    // the desired value d; the weights are updated as
    //     w += mu / (eps + x·x) * (d - w·x) * x.
    void adapt(std::span<const F, N + 1> signal)
    {
        const F *x = signal.data();
        const F true_val = signal.back();
        // w.data() rather than w.begin(): std::array iterators are not
        // guaranteed to be raw pointers, but dot_prod takes const F*.
        const F y = dot_prod<F, N>(x, w.data());
        const F e = true_val - y;
        const F dot = dot_prod<F, N>(x, x);
        const F nu = mu / (eps + dot);
        // NOTE: the compiler cannot prove that w and signal do not overlap, so
        // it must either guard a vectorised loop with a runtime overlap check
        // or keep the loop scalar (clang 17 keeps it scalar for N = 32).
        for (size_t i = 0; i < N; i++)
        {
            w[i] += nu * e * signal[i];
        }
    }
};

// Explicit instantiation definition: makes the compiler emit code for every
// member of NlmsFilter<float, 32>.
template class NlmsFilter<float, 32>;

Here is the assembly output from clang:

NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>): # @NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>)
        push    r14
        push    rbx
        push    rax
        cmp     rdx, 34
        jae     .LBB0_6
        test    rdx, rdx
        je      .LBB0_2
        mov     rbx, rsi
        cmp     rdx, 33
        jne     .LBB0_4
        mov     r14, rdi
        mov     rsi, rbx
        call    NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>)
        add     rbx, 4
        mov     rdi, rbx
        mov     rsi, r14
        add     rsp, 8
        pop     rbx
        pop     r14
        jmp     float dot_prod<float, 32ul>(float const*, float const*)   # TAILCALL
.LBB0_2:
        vxorps  xmm0, xmm0, xmm0
        add     rsp, 8
        pop     rbx
        pop     r14
        ret
.LBB0_4:
        vmovss  xmm0, dword ptr [rbx + 4*rdx - 4] # xmm0 = mem[0],zero,zero,zero
        add     rsp, 8
        pop     rbx
        pop     r14
        ret
.LBB0_6:
        mov     edi, offset .L.str
        mov     esi, offset .L.str.1
        mov     ecx, offset .L__PRETTY_FUNCTION__.NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>)
        mov     edx, 27
        call    __assert_fail
.LCPI1_0:
        .long   0x3f800000                      # float 1
.LCPI1_1:
        .long   0x3f000000                      # float 0.5
NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>): # @NlmsFilter<float, 32ul>::adapt(std::span<float const, 33ul>)
        push    r14
        push    rbx
        push    rax
        mov     r14, rsi
        mov     rbx, rdi
        vmovss  xmm0, dword ptr [rsi + 128]     # xmm0 = mem[0],zero,zero,zero
        vmovss  dword ptr [rsp + 4], xmm0       # 4-byte Spill
        mov     rdi, rsi
        mov     rsi, rbx
        call    float dot_prod<float, 32ul>(float const*, float const*)
        vmovss  xmm1, dword ptr [rsp + 4]       # 4-byte Reload
        vsubss  xmm0, xmm1, xmm0
        vmovss  dword ptr [rsp + 4], xmm0       # 4-byte Spill
        mov     rdi, r14
        mov     rsi, r14
        call    float dot_prod<float, 32ul>(float const*, float const*)
        vaddss  xmm0, xmm0, dword ptr [rip + .LCPI1_0]
        vmovss  xmm1, dword ptr [rip + .LCPI1_1] # xmm1 = mem[0],zero,zero,zero
        vdivss  xmm0, xmm1, xmm0
        vmulss  xmm0, xmm0, dword ptr [rsp + 4] # 4-byte Folded Reload
        vmulss  xmm1, xmm0, dword ptr [r14]
        vaddss  xmm1, xmm1, dword ptr [rbx]
        vmovss  dword ptr [rbx], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 4]
        vaddss  xmm1, xmm1, dword ptr [rbx + 4]
        vmovss  dword ptr [rbx + 4], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 8]
        vaddss  xmm1, xmm1, dword ptr [rbx + 8]
        vmovss  dword ptr [rbx + 8], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 12]
        vaddss  xmm1, xmm1, dword ptr [rbx + 12]
        vmovss  dword ptr [rbx + 12], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 16]
        vaddss  xmm1, xmm1, dword ptr [rbx + 16]
        vmovss  dword ptr [rbx + 16], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 20]
        vaddss  xmm1, xmm1, dword ptr [rbx + 20]
        vmovss  dword ptr [rbx + 20], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 24]
        vaddss  xmm1, xmm1, dword ptr [rbx + 24]
        vmovss  dword ptr [rbx + 24], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 28]
        vaddss  xmm1, xmm1, dword ptr [rbx + 28]
        vmovss  dword ptr [rbx + 28], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 32]
        vaddss  xmm1, xmm1, dword ptr [rbx + 32]
        vmovss  dword ptr [rbx + 32], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 36]
        vaddss  xmm1, xmm1, dword ptr [rbx + 36]
        vmovss  dword ptr [rbx + 36], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 40]
        vaddss  xmm1, xmm1, dword ptr [rbx + 40]
        vmovss  dword ptr [rbx + 40], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 44]
        vaddss  xmm1, xmm1, dword ptr [rbx + 44]
        vmovss  dword ptr [rbx + 44], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 48]
        vaddss  xmm1, xmm1, dword ptr [rbx + 48]
        vmovss  dword ptr [rbx + 48], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 52]
        vaddss  xmm1, xmm1, dword ptr [rbx + 52]
        vmovss  dword ptr [rbx + 52], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 56]
        vaddss  xmm1, xmm1, dword ptr [rbx + 56]
        vmovss  dword ptr [rbx + 56], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 60]
        vaddss  xmm1, xmm1, dword ptr [rbx + 60]
        vmovss  dword ptr [rbx + 60], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 64]
        vaddss  xmm1, xmm1, dword ptr [rbx + 64]
        vmovss  dword ptr [rbx + 64], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 68]
        vaddss  xmm1, xmm1, dword ptr [rbx + 68]
        vmovss  dword ptr [rbx + 68], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 72]
        vaddss  xmm1, xmm1, dword ptr [rbx + 72]
        vmovss  dword ptr [rbx + 72], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 76]
        vaddss  xmm1, xmm1, dword ptr [rbx + 76]
        vmovss  dword ptr [rbx + 76], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 80]
        vaddss  xmm1, xmm1, dword ptr [rbx + 80]
        vmovss  dword ptr [rbx + 80], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 84]
        vaddss  xmm1, xmm1, dword ptr [rbx + 84]
        vmovss  dword ptr [rbx + 84], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 88]
        vaddss  xmm1, xmm1, dword ptr [rbx + 88]
        vmovss  dword ptr [rbx + 88], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 92]
        vaddss  xmm1, xmm1, dword ptr [rbx + 92]
        vmovss  dword ptr [rbx + 92], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 96]
        vaddss  xmm1, xmm1, dword ptr [rbx + 96]
        vmovss  dword ptr [rbx + 96], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 100]
        vaddss  xmm1, xmm1, dword ptr [rbx + 100]
        vmovss  dword ptr [rbx + 100], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 104]
        vaddss  xmm1, xmm1, dword ptr [rbx + 104]
        vmovss  dword ptr [rbx + 104], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 108]
        vaddss  xmm1, xmm1, dword ptr [rbx + 108]
        vmovss  dword ptr [rbx + 108], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 112]
        vaddss  xmm1, xmm1, dword ptr [rbx + 112]
        vmovss  dword ptr [rbx + 112], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 116]
        vaddss  xmm1, xmm1, dword ptr [rbx + 116]
        vmovss  dword ptr [rbx + 116], xmm1
        vmulss  xmm1, xmm0, dword ptr [r14 + 120]
        vaddss  xmm1, xmm1, dword ptr [rbx + 120]
        vmovss  dword ptr [rbx + 120], xmm1
        vmulss  xmm0, xmm0, dword ptr [r14 + 124]
        vaddss  xmm0, xmm0, dword ptr [rbx + 124]
        vmovss  dword ptr [rbx + 124], xmm0
        add     rsp, 8
        pop     rbx
        pop     r14
        ret
float dot_prod<float, 32ul>(float const*, float const*):          # @float dot_prod<float, 32ul>(float const*, float const*)
        vmovups ymm0, ymmword ptr [rsi]
        vmovups ymm1, ymmword ptr [rsi + 32]
        vmovups ymm2, ymmword ptr [rsi + 64]
        vmulps  ymm2, ymm2, ymmword ptr [rdi + 64]
        vmovups ymm3, ymmword ptr [rsi + 96]
        vmulps  ymm0, ymm0, ymmword ptr [rdi]
        vaddps  ymm0, ymm0, ymm2
        vmulps  ymm2, ymm3, ymmword ptr [rdi + 96]
        vmulps  ymm1, ymm1, ymmword ptr [rdi + 32]
        vaddps  ymm1, ymm1, ymm2
        vaddps  ymm0, ymm0, ymm1
        vextractf128    xmm1, ymm0, 1
        vaddps  xmm0, xmm0, xmm1
        vpermilpd       xmm1, xmm0, 1           # xmm1 = xmm0[1,0]
        vaddps  xmm0, xmm0, xmm1
        vmovshdup       xmm1, xmm0              # xmm1 = xmm0[1,1,3,3]
        vaddss  xmm0, xmm0, xmm1
        vzeroupper
        ret
NlmsFilter<float, 32ul>::mu:
        .long   0x3f000000                      # float 0.5

NlmsFilter<float, 32ul>::eps:
        .long   0x3f800000                      # float 1

.L.str:
        .asciz  "signal.size() <= N + 1"

.L.str.1:
        .asciz  "/app/example.cpp"

.L__PRETTY_FUNCTION__.NlmsFilter<float, 32ul>::predict(std::span<float const, 18446744073709551615ul>):
        .asciz  "F NlmsFilter<float, 32>::predict(std::span<const F>) [F = float, N = 32]"

Solution

  • The problem is that clang can't be sure whether w (partially) aliases signal, i.e. whether it overlaps the data that signal points to.

    In this case GCC emits a runtime check that compares the w.data() and signal.data() pointers. If they are far enough apart that the accessed ranges cannot overlap, it runs a vectorized loop; otherwise it falls back to a scalar loop. Clang probably judges that the cost of the additional check is not justified by the potential performance gain (it does emit the check once the array size is increased).

    You can tell GCC/Clang that pointed-to memory will not overlap using the __restrict keyword (a non-standard extension in C++; C has the standard restrict keyword). With it you promise the compiler that this memory is not also accessed through another pointer, so passing overlapping ranges becomes undefined behaviour. However, it can only be applied to pointers, not directly to a std::span (as far as I know). You can work around this by providing a helper function that takes its output address as a __restrict-qualified plain pointer, e.g.:

    // Adds factor * in[i] to out[i] for i in [0, N), i.e. out += factor * in.
    //
    // `out` is __restrict-qualified: the caller promises that the N elements
    // written through it are not accessed through `in`. Overlapping ranges are
    // undefined behaviour. That promise lets the compiler vectorise the loop
    // without a runtime overlap check.
    //
    // It accumulates (+=) so it can replace the loop `w[i] += nu * e * signal[i]`.
    // Plain assignment would discard the previous weights.
    template<typename F, size_t N>
    inline void scale(F * __restrict out, F const* in, F factor)
    {
        for(size_t i=0; i < N; ++i)
        {
            out[i] += factor * in[i];
        }
    }
    

    Then, instead of your loop, call it like this:

    scale<F, N>(w.data(), signal.data(), nu*e);
    

    Modified Godbolt link: https://godbolt.org/z/WWbdz489e