Search code examples
c gcc assembly x86 loop-unrolling

Why does gcc's code-gen for my unrolled loop epilogue look over-complicated?


Thanks for all the comments so far. I am sorry that I used a bad example in my original question, one to which almost everyone would say: "Oh, you should use memcpy!" But that is not what my question is about.

My question is more general: how should manual loop unrolling be done? Consider this example instead, which sums all the elements of an array:

#include <stdlib.h>

/* Returns the sum of the n doubles starting at x.
 * The loop covers the first (n - nr) elements, one element per iteration;
 * when n is odd, the single leftover element is added after the loop. */
double sum (size_t n, double *x) {
  size_t nr = n & 1;                  // 1 when n is odd, 0 when n is even
  double *end = x + (n - nr);         // one past the even-length prefix
  double sum_x = 0.0;
  for (; x < end; x++) sum_x += *x;
  if (nr) sum_x += *x;                // `x == end` here: add the odd tail element
  return sum_x;
  }

The compiler-generated assembly exhibits similar behaviour (to what is shown by the array-copying example in my original question):

sum:                        // initially, `n` in %rdi, `x` in %rsi; result in %xmm0
  movq %rdi, %rcx           // n_copy = n;
  andl $1, %ecx             // nr = n_copy & 1;
  subq %rcx, %rdi           // n -= nr;
  leaq (%rsi,%rdi,8), %rdx  // end = x + n;
  cmpq %rdx, %rsi           // if (x >= end) jump to .L5
  jnb .L5
  movq %rsi, %rax           // x_copy = x;  (the loop walks `x_copy`, `x` stays in %rsi)
  pxor %xmm0, %xmm0         // sum_x = 0.0;
.L3:
  addsd (%rax), %xmm0       // sum_x += *x_copy;
  addq $8, %rax             // x_copy++;
  cmpq %rax, %rdx           // `x_copy` in %rax, `end` in %rdx
  ja .L3                    // loop while (end > x_copy)
  movq %rsi, %rax           // below: `x` is recomputed from `end` and the original `x` (byte arithmetic)
  notq %rax                 // rax = ~x;          (= -x - 1)
  addq %rax, %rdx           // rdx = end - x - 1;
  shrq $3, %rdx             // rdx = rdx / 8;
  leaq 8(%rsi,%rdx,8), %rsi // x = x + 8 * rdx + 8;
.L2:
  testq %rcx, %rcx          // if (nr == 0) jump to .L1
  je .L1
  addsd (%rsi), %xmm0       // sum_x += *x;
.L1:
  ret
.L5:
  pxor %xmm0, %xmm0         // sum_x = 0.0;  (path taken when the loop is skipped)
  jmp .L2

However, if I now schedule the "fractional" part ahead of the main loop (as I later worked out in an answer I posted), the compiler does a much better job.

#include <stdlib.h>

/* Returns the sum of the n doubles starting at x.
 * Variant that handles the odd element BEFORE the loop: when n is odd,
 * x[0] is added first and the loop then starts at x + 1. */
double sum (size_t n, double *x) {
  size_t nr = n & 1;                  // 1 when n is odd, 0 when n is even
  double *end = x + n;                // one past the last element
  double sum_x = 0.0;
  if (nr) sum_x += *x;                // odd element first
  for (x += nr; x < end; x++) sum_x += *x;
  return sum_x;
  }

sum:                        // initially, `n` in %rdi, `x` in %rsi; result in %xmm0
  leaq (%rsi,%rdi,8), %rdx  // end = x + n;
  pxor %xmm0, %xmm0         // sum_x = 0.0;
  andl $1, %edi             // nr = n & 1;  (overwrites `n`; sets ZF for the branch below)
  je .L2                    // if (nr == 0) jump to .L2
  addsd (%rsi), %xmm0       // sum_x += *x;
.L2:
  leaq (%rsi,%rdi,8), %rax  // x_copy = x + nr;
  cmpq %rax, %rdx           // if (end <= x_copy) jump to .L1
  jbe .L1
.L4:
  addsd (%rax), %xmm0       // sum_x += *x_copy;
  addq $8, %rax             // x_copy++;
  cmpq %rax, %rdx           // `x_copy` in %rax, `end` in %rdx
  ja .L4                    // loop while (end > x_copy); nothing is recomputed afterwards
.L1:
  ret

I have only used the compiler flag -O2. So, as Peter said, the compiler-generated assembly should be close to the C source code. The question then is: why does the compiler do better in the latter case?

This is not really a performance-related question. It is just something I stumbled upon (and can't explain) when checking the compiler's assembly output for C code from a C project I have been writing. Thanks again. Thanks to Peter for proposing a better title for the question.


Original question:

The following small C function copies a, a vector of n entries to b. A manual loop unrolling of depth 2 is applied.

#include <stddef.h>

/* Copies the n doubles at a to b, two elements per loop iteration.
 * The loop handles the even-length prefix; when n is odd, the single
 * leftover element is copied after the loop. Uses signed (ptrdiff_t) indices. */
void foo (ptrdiff_t n, double *a, double *b) {
  ptrdiff_t i = 0;
  ptrdiff_t nr = n & 1;     // 1 when n is odd, 0 when n is even
  n -= nr;                  // `n` is an even integer
  while (i < n) {
    b[i] = a[i];
    b[i + 1] = a[i + 1];
    i += 2;
    }                       // `i = n` when the loop ends
  if (nr) b[i] = a[i];      // copy the odd tail element
  }

It gives the following x86-64 assembly under gcc -O2 (any gcc version 5.4+). However, I find the part of the output marked with comments weird. Why does the compiler generate it at all?

foo:                        // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
  movq %rdi, %rcx           // n_copy = n;
  xorl %eax, %eax           // i = 0;
  andl $1, %ecx             // nr = n_copy & 1;
  subq %rcx, %rdi           // n -= nr;
  testq %rdi, %rdi          // if (n <= 0) jump to .L11  (signed compare)
  jle .L11
.L12:
  movsd (%rsi,%rax,8), %xmm0
  movsd %xmm0, (%rdx,%rax,8)    // b[i] = a[i];
  movsd 8(%rsi,%rax,8), %xmm0
  movsd %xmm0, 8(%rdx,%rax,8)   // b[i + 1] = a[i + 1];
  addq $2, %rax             // i += 2;
  cmpq %rax, %rdi           // `i` in %rax, `n` in %rdi
  jg .L12                   // the loop ends, with `i = n`, BELOW IS WEIRD
  subq $1, %rdi             // n = n - 1;
  shrq %rdi                 // n = n / 2;
  leaq 2(%rdi,%rdi), %rax   // i = 2 * n + 2;  (this is just `i = n`, isn't it?)
.L11:
  testq %rcx, %rcx          // if (nr == 0) jump to .L10
  je .L10
  movsd (%rsi,%rax,8), %xmm0
  movsd %xmm0, (%rdx,%rax,8)    // b[i] = a[i];
.L10:
  ret

A similar version using size_t instead of ptrdiff_t gives something similar:

#include <stdlib.h>

/* Copies the n doubles at a to b, two elements per loop iteration.
 * Same structure as the ptrdiff_t version, but with unsigned (size_t)
 * indices: even-length prefix in the loop, odd tail element afterwards. */
void bar (size_t n, double *a, double *b) {
  size_t i = 0;
  size_t nr = n & 1;        // 1 when n is odd, 0 when n is even
  n -= nr;                  // `n` is an even integer
  while (i < n) {
    b[i] = a[i];
    b[i + 1] = a[i + 1];
    i += 2;
    }                       // `i = n` when the loop ends
  if (nr) b[i] = a[i];      // copy the odd tail element
  }

bar:                        // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
  movq %rdi, %rcx           // n_copy = n;
  andl $1, %ecx             // nr = n_copy & 1;
  subq %rcx, %rdi           // n -= nr;  (sets ZF when the result is 0)
  je .L20                   // if (n == 0) jump to .L20; %rdi (= 0) is then used as `i`
  xorl %eax, %eax           // i = 0;
.L21:
  movsd (%rsi,%rax,8), %xmm0
  movsd %xmm0, (%rdx,%rax,8)    // b[i] = a[i];
  movsd 8(%rsi,%rax,8), %xmm0
  movsd %xmm0, 8(%rdx,%rax,8)   // b[i + 1] = a[i + 1];
  addq $2, %rax             // i += 2;
  cmpq %rax, %rdi           // `i` in %rax, `n` in %rdi
  ja .L21                   // the loop ends, with `i = n`, BUT BELOW IS WEIRD
  subq $1, %rdi             // n = n - 1;
  andq $-2, %rdi            // n = n & (-2);
  addq $2, %rdi             // n = n + 2;  (this is just `i = n`, isn't it?)
.L20:
  testq %rcx, %rcx          // if (nr == 0) jump to .L19
  je .L19
  movsd (%rsi,%rdi,8), %xmm0    // the tail index is read from %rdi here, not from %rax
  movsd %xmm0, (%rdx,%rdi,8)    // b[i] = a[i];
.L19:
  ret

And here is another equivalent version,

#include <stdlib.h>

/* Copies the n doubles at a to b, two elements per loop iteration.
 * Pointer-walking variant: both a and b are advanced by 2 per iteration
 * until b reaches b_end; when n is odd, the leftover element is copied
 * after the loop. */
void baz (size_t n, double *a, double *b) {
  size_t nr = n & 1;        // 1 when n is odd, 0 when n is even
  n -= nr;                  // `n` is now even
  double *b_end = b + n;    // one past the even-length prefix of b
  while (b < b_end) {
    b[0] = a[0];
    b[1] = a[1];
    a += 2;
    b += 2;
    }                       // `b = b_end` when the loop ends
  if (nr) b[0] = a[0];      // copy the odd tail element
  }

but the following assembly looks even odder (though produced under -O2). Now n, a and b are all copied, and when the loop ends, we take 5 lines of code just to end up with b_copy = 0?!

baz:                        // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
  movq %rdi, %r8            // n_copy = n;
  andl $1, %r8d             // nr = n_copy & 1;
  subq %r8, %rdi            // n_copy -= nr;
  leaq (%rdx,%rdi,8), %rdi  // b_end = b + n;
  cmpq %rdi, %rdx           // if (b >= b_end) jump to .L31
  jnb .L31
  movq %rdx, %rax           // b_copy = b;
  movq %rsi, %rcx           // a_copy = a;
.L32:
  movsd (%rcx), %xmm0       // xmm0 = a_copy[0];
  addq $16, %rax            // b_copy += 2;
  addq $16, %rcx            // a_copy += 2;
  movsd %xmm0, -16(%rax)    // b_copy[-2] = xmm0;
  movsd -8(%rcx), %xmm0     // xmm0 = a_copy[-1];
  movsd %xmm0, -8(%rax)     // b_copy[-1] = xmm0;
  cmpq %rax, %rdi           // `b_copy` in %rax, `b_end` in %rdi
  ja .L32                   // the loop ends, with `b_copy = b_end`
  movq %rdx, %rax           // b_copy = b;
  notq %rax                 // b_copy = ~b_copy;
  addq %rax, %rdi           // b_end = b_end + b_copy;
  andq $-16, %rdi           // b_end = b_end & (-16);
  leaq 16(%rdi), %rax       // b_copy = b_end + 16;
  // NOTE(review): ~b is -b - 1, so the three lines above compute
  // ((b_end - b - 1) & -16) + 16. This point is only reached after the loop
  // ran, where b_end - b is a positive multiple of 16, so %rax ends up as
  // b_end - b: the byte distance the loop advanced, not 0.
  addq %rax, %rsi           // a += b_copy;   (isn't `b_copy` just 0?)
  addq %rax, %rdx           // b += b_copy;
.L31:
  testq %r8, %r8            // if (nr == 0) jump to .L30
  je .L30
  movsd (%rsi), %xmm0       // xmm0 = a[0];
  movsd %xmm0, (%rdx)       // b[0] = xmm0;
.L30:
  ret

Can anyone explain what the compiler has in mind in all three cases?


Solution

  • Looks like if I unroll the loop in the following manner, a compiler can generate neater code.

    #include <stdlib.h>
    #include <stddef.h>
    
    /* Copies the n doubles at a to b, two elements per loop iteration.
     * The odd element is handled BEFORE the loop: when n is odd, b[0] = a[0]
     * is done first and the loop index then starts at 1 instead of 0. */
    void foo (ptrdiff_t n, double *a, double *b) {
      ptrdiff_t i = n & 1;      // 1 when n is odd, 0 when n is even
      if (i) b[0] = a[0];       // odd element first
      while (i < n) {
        b[i] = a[i];
        b[i + 1] = a[i + 1];
        i += 2;
        }
      }
    
    /* Copies the n doubles at a to b, two elements per loop iteration.
     * Unsigned (size_t) counterpart of the odd-element-first version: when n
     * is odd, b[0] = a[0] is done first and the loop index starts at 1. */
    void bar (size_t n, double *a, double *b) {
      size_t i = n & 1;         // 1 when n is odd, 0 when n is even
      if (i) b[0] = a[0];       // odd element first
      while (i < n) {
        b[i] = a[i];
        b[i + 1] = a[i + 1];
        i += 2;
        }
      }
    
    /* Copies the n doubles at a to b, two elements per loop iteration.
     * Pointer-walking variant with the odd element handled BEFORE the loop:
     * when n is odd, b[0] = a[0] is done first and both pointers are moved
     * past that element, so the loop always copies an even number of
     * elements.
     *
     * n: number of elements to copy
     * a: source array, at least n elements
     * b: destination array, at least n elements */
    void baz (size_t n, double *a, double *b) {
      size_t nr = n & 1;        // 1 when n is odd, 0 when n is even
      double *b_end = b + n;    // one past the last destination element
      if (nr) b[0] = a[0];      // odd element first
      // Both pointers must skip the element copied above. Advancing only `b`
      // makes the loop re-read a[0] and shift every later element by one.
      // NOTE: a compiler listing generated from the version without
      // `a += nr` does not contain this adjustment of `a`.
      a += nr;
      b += nr;
      while (b < b_end) {
        b[0] = a[0];
        b[1] = a[1];
        a += 2;
        b += 2;
        }
      }
    

    foo:                          // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
      movq %rdi, %rax
      andl $1, %eax               // i = n & 1;  (sets ZF for the branch below)
      je .L9                      // if (i == 0) jump to the loop test at .L9
      movsd (%rsi), %xmm0
      movsd %xmm0, (%rdx)         // b[0] = a[0];
      cmpq %rax, %rdi             // if (n <= i) jump to .L11  (signed compare)
      jle .L11
    .L4:
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)  // b[i] = a[i];
      movsd 8(%rsi,%rax,8), %xmm0
      movsd %xmm0, 8(%rdx,%rax,8) // b[i + 1] = a[i + 1];
      addq $2, %rax               // i += 2;
    .L9:
      cmpq %rax, %rdi             // `i` in %rax, `n` in %rdi
      jg .L4                      // loop while (n > i); nothing is recomputed afterwards
    .L11:
      ret
    

    bar:                          // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
      movq %rdi, %rax
      andl $1, %eax               // i = n & 1;  (sets ZF for the branch below)
      je .L20                     // if (i == 0) jump to the loop test at .L20
      movsd (%rsi), %xmm0
      movsd %xmm0, (%rdx)         // b[0] = a[0];
      cmpq %rax, %rdi             // if (n <= i) jump to .L21  (unsigned compare)
      jbe .L21
    .L15:
      movsd (%rsi,%rax,8), %xmm0
      movsd %xmm0, (%rdx,%rax,8)  // b[i] = a[i];
      movsd 8(%rsi,%rax,8), %xmm0
      movsd %xmm0, 8(%rdx,%rax,8) // b[i + 1] = a[i + 1];
      addq $2, %rax               // i += 2;
    .L20:
      cmpq %rax, %rdi             // `i` in %rax, `n` in %rdi
      ja .L15                     // loop while (n > i); nothing is recomputed afterwards
    .L21:
      ret
    

    baz:                          // initially, `n` in %rdi, `a` in %rsi, `b` in %rdx
      leaq (%rdx,%rdi,8), %rcx    // b_end = b + n;
      andl $1, %edi               // nr = n & 1;  (overwrites `n`; sets ZF for the branch below)
      je .L23                     // if (nr == 0) jump to .L23
      movsd (%rsi), %xmm0
      movsd %xmm0, (%rdx)         // b[0] = a[0];
    .L23:
      leaq (%rdx,%rdi,8), %rax    // b_copy = b + nr;
      cmpq %rax, %rcx             // if (b_end <= b_copy) jump to .L22
      jbe .L22
      // NOTE(review): %rsi (`a`) is not advanced by `nr` before the loop, so
      // when nr == 1 the first iteration loads a[0] again and stores it at
      // b[1]; every later element is shifted by one as well.
    .L25:
      movsd (%rsi), %xmm0         // xmm0 = a[0];
      addq $16, %rax              // b_copy += 2;
      addq $16, %rsi              // a += 2;
      movsd %xmm0, -16(%rax)      // b_copy[-2] = xmm0;
      movsd -8(%rsi), %xmm0       // xmm0 = a[-1];
      movsd %xmm0, -8(%rax)       // b_copy[-1] = xmm0;
      cmpq %rax, %rcx             // `b_copy` in %rax, `b_end` in %rcx
      ja .L25                     // loop while (b_end > b_copy)
    .L22:
      ret