Search code examples
performance, rust, x86

How to get efficient floating point maximum in Rust


I was testing how to get the maximum of an array of floating-point numbers:

/// Returns the maximum of the eight values in `n`.
///
/// The elements are combined left to right with `f64::max`, so NaN handling
/// follows whatever `f64::max` does for each pair.
pub fn max(n: [f64;8]) -> f64 {
    // `reduce` only yields `None` for an empty iterator; the array type
    // guarantees eight elements, so the `unwrap` cannot panic.
    IntoIterator::into_iter(n).reduce(|a,b| a.max(b)).unwrap()
}

which gives me (nightly Rust)

        vmovsd  xmm0, qword ptr [rdi + 56]
        vmovsd  xmm1, qword ptr [rdi + 48]
        vmovsd  xmm2, qword ptr [rdi + 40]
        vmovsd  xmm3, qword ptr [rdi + 32]
        vmovsd  xmm4, qword ptr [rdi + 24]
        vmovsd  xmm5, qword ptr [rdi + 16]
        vmovsd  xmm6, qword ptr [rdi]
        vmovsd  xmm7, qword ptr [rdi + 8]
        vcmpunordsd     xmm8, xmm6, xmm6
        vmaxsd  xmm6, xmm7, xmm6
        vblendvpd       xmm6, xmm6, xmm7, xmm8
        vcmpunordsd     xmm7, xmm6, xmm6
        vmaxsd  xmm6, xmm5, xmm6
        vblendvpd       xmm5, xmm6, xmm5, xmm7
        vcmpunordsd     xmm6, xmm5, xmm5
        vmaxsd  xmm5, xmm4, xmm5
        vblendvpd       xmm4, xmm5, xmm4, xmm6
        vcmpunordsd     xmm5, xmm4, xmm4
        vmaxsd  xmm4, xmm3, xmm4
        vblendvpd       xmm3, xmm4, xmm3, xmm5
        vcmpunordsd     xmm4, xmm3, xmm3
        vmaxsd  xmm3, xmm2, xmm3
        vblendvpd       xmm2, xmm3, xmm2, xmm4
        vcmpunordsd     xmm3, xmm2, xmm2
        vmaxsd  xmm2, xmm1, xmm2
        vblendvpd       xmm1, xmm2, xmm1, xmm3
        vcmpunordsd     xmm2, xmm1, xmm1
        vmaxsd  xmm1, xmm0, xmm1
        vblendvpd       xmm0, xmm1, xmm0, xmm2
        ret

So a lot of instructions are spent on NaN handling. I am pretty sure that vmaxsd does the same as f64::max in Rust, but I am not sure whether I am overlooking something.

So I turned to C++ and got

// Returns the maximum of the eight doubles num[0] .. num[7].
// The caller must pass a pointer to at least eight readable elements;
// nothing here checks the length.
double max(double *num) {
    // Despite its name, `sum` holds the running maximum, seeded with num[0].
    double sum = num[0];
    for (int i = 1; i < 8; i++) {
        // The running value is passed as the first argument to std::max.
        sum = std::max(sum, num[i]);
    }
    return sum;
}

which compiles to (on gcc 14.1)

        vmovsd  xmm2, QWORD PTR [rdi]
        vmovsd  xmm1, QWORD PTR [rdi+8]
        vmaxsd  xmm0, xmm1, xmm2
        vmovsd  xmm1, QWORD PTR [rdi+16]
        vmovsd  xmm2, QWORD PTR [rdi+24]
        vmaxsd  xmm1, xmm1, xmm0
        vmaxsd  xmm0, xmm2, xmm1
        vmovsd  xmm2, QWORD PTR [rdi+32]
        vmaxsd  xmm1, xmm2, xmm0
        vmovsd  xmm2, QWORD PTR [rdi+40]
        vmaxsd  xmm0, xmm2, xmm1
        vmovsd  xmm2, QWORD PTR [rdi+48]
        vmaxsd  xmm1, xmm2, xmm0
        vmovsd  xmm0, QWORD PTR [rdi+56]
        vmaxsd  xmm0, xmm0, xmm1
        ret

(no fast-math option, just -O3)

which leads me to believe the assembly from Rust is sub-optimal or the semantics of C++ max and Rust max are different.

Can someone shed some light on this issue? And how could I emit the same code as C++ with Rust here?


Solution

  • The documentation of f64::max tells us:

    If one of the arguments is NaN, then the other argument is returned

    So it only produces NaN when both arguments are NaN.

    But std::max uses < for comparison, which can produce NaN if only one of the operands is NaN. Similarly, MAXSD always returns the second operand when either is NaN, and thus can also return NaN with only one (the second) operand being NaN:

    MAX(SRC1, SRC2)
    {
        IF ((SRC1 = 0.0) and (SRC2 = 0.0)) THEN DEST := SRC2;
            ELSE IF (SRC1 = NaN) THEN DEST := SRC2; FI;
            ELSE IF (SRC2 = NaN) THEN DEST := SRC2; FI;
            ELSE IF (SRC1 > SRC2) THEN DEST := SRC1;
            ELSE DEST := SRC2;
        FI;
    }
    

    So while MAXSD and C++'s std::max have compatible semantics, Rust's f64::max is not compatible:

    std::cout << std::max(nan, 1.0) << " " << std::max(1.0, nan); // → nan 1
    
    println!("{} {}", f64::max(nan, 1.0), f64::max(1.0, nan));    // → 1 1
    

    Using the same semantics in Rust produces equivalent assembly:

    /// Returns the maximum of the eight values in `n`, selecting with a plain
    /// `<` comparison instead of `f64::max`.
    ///
    /// Any comparison involving NaN is false, so the running value `a` is kept
    /// whenever either side is NaN: a NaN in the first position propagates to
    /// the result, while a NaN in a later position is ignored.
    pub fn max(n: [f64;8]) -> f64 {
        // `a` is the running result and `b` the next element. The array always
        // has eight elements, so `reduce` never returns `None`.
        n.into_iter().reduce(|a,b| if a < b { b } else { a }).unwrap()
    }
    
    example::max::h17b765fea01ee3b1:
            movsd   xmm0, qword ptr [rdi + 56]
            movsd   xmm1, qword ptr [rdi + 48]
            movsd   xmm2, qword ptr [rdi + 40]
            movsd   xmm3, qword ptr [rdi + 32]
            movsd   xmm4, qword ptr [rdi + 24]
            movsd   xmm5, qword ptr [rdi + 8]
            maxsd   xmm5, qword ptr [rdi]
            movsd   xmm6, qword ptr [rdi + 16]
            maxsd   xmm6, xmm5
            maxsd   xmm4, xmm6
            maxsd   xmm3, xmm4
            maxsd   xmm2, xmm3
            maxsd   xmm1, xmm2
            maxsd   xmm0, xmm1
            ret