
How to call _mm256_mul_ph from rust?


_mm256_mul_ps is the Intel intrinsic for "Multiply packed single-precision (32-bit) floating-point elements". _mm256_mul_ph is the intrinsic for "Multiply packed half-precision (16-bit) floating-point elements".

I can call _mm256_mul_ps after importing it with use std::arch::x86_64::*; for example:

use std::arch::x86_64::*;

#[inline]
fn mul(v: __m256, w: __m256) -> __m256 {
    unsafe { _mm256_mul_ps(v, w) }
}

However, it appears to be difficult to call _mm256_mul_ph. Can one call _mm256_mul_ph in Rust?


Solution

  • What you are looking for requires the AVX-512 FP16 and AVX-512 VL instruction sets, the former of which does not appear to be supported by Rust at the moment.

    You can potentially create your own intrinsics as needed by using the asm! macro. The assembly for _mm256_mul_ph looks like this:

    vmulph ymm, ymm, ymm
    

    So the equivalent in Rust would be written like so:

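    use std::arch::asm;
    use std::arch::x86_64::__m256i;

    // Stand-in for the missing intrinsic: each __m256i holds 16 packed f16
    // values as raw bits. The avx2 cfg only guarantees that ymm_reg operands
    // are available; the CPU must still support AVX-512 FP16 and AVX-512 VL.
    #[cfg(target_feature = "avx2")]
    unsafe fn _mm256_mul_ph(a: __m256i, b: __m256i) -> __m256i {
        let dst: __m256i;

        // dst[i] = a[i] * b[i] for each of the 16 half-precision lanes
        asm!(
            "vmulph {0}, {1}, {2}",
            out(ymm_reg) dst,
            in(ymm_reg) a,
            in(ymm_reg) b,
            options(pure, nomem, nostack),
        );

        dst
    }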
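
    The function above is still unsafe to call on most machines. Below is a minimal sketch of a guarded call, assuming an x86_64 target; has_avx512fp16_vl and mul_ph_bits are hypothetical helper names. It swaps the compile-time cfg for a function-level #[target_feature(enable = "avx2")] so it builds without extra flags, checks AVX-512 VL with is_x86_feature_detected! (which also checks that the OS has enabled AVX-512 register state), and reads the AVX512-FP16 bit (CPUID leaf 7, sub-leaf 0, EDX bit 23) directly, since there is no avx512fp16 feature name to ask for. The f16 lanes are passed in and out as raw u16 bit patterns.

```rust
use std::arch::asm;
use std::arch::x86_64::{__cpuid_count, __m256i};
use std::mem::transmute;

/// Same vmulph wrapper, but gated per function so the crate builds without
/// `-C target-feature`. AVX2 only unlocks the `ymm_reg` register class; the
/// CPU must still support AVX-512 FP16 + VL before this may be called.
#[target_feature(enable = "avx2")]
unsafe fn _mm256_mul_ph(a: __m256i, b: __m256i) -> __m256i {
    let dst: __m256i;
    unsafe {
        asm!(
            "vmulph {0}, {1}, {2}",
            out(ymm_reg) dst,
            in(ymm_reg) a,
            in(ymm_reg) b,
            options(pure, nomem, nostack),
        );
    }
    dst
}

/// Hypothetical helper: can this CPU and OS execute EVEX-encoded `vmulph ymm`?
fn has_avx512fp16_vl() -> bool {
    // The avx512vl check also verifies that the OS saves AVX-512 state (XCR0).
    if !(is_x86_feature_detected!("avx2") && is_x86_feature_detected!("avx512vl")) {
        return false;
    }
    // AVX512-FP16 is reported in CPUID.(EAX=7, ECX=0):EDX bit 23.
    let leaf7 = unsafe { __cpuid_count(7, 0) };
    (leaf7.edx & (1 << 23)) != 0
}

/// Multiplies sixteen binary16 lanes given as raw bit patterns, or returns
/// `None` if the hardware cannot run `vmulph`.
fn mul_ph_bits(a: [u16; 16], b: [u16; 16]) -> Option<[u16; 16]> {
    if !has_avx512fp16_vl() {
        return None;
    }
    // SAFETY: the required CPU features were detected above, and [u16; 16]
    // and __m256i are both 32 bytes, so transmute only reinterprets bits.
    unsafe {
        let product = _mm256_mul_ph(transmute(a), transmute(b));
        Some(transmute(product))
    }
}
```

    A None result means the caller should take a non-FP16 fallback path instead of executing vmulph.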
    

    To make your own intrinsics for other instructions, make sure you follow the guidelines for Rust's inline assembly and understand exactly what each instruction does. Inline assembly is unsafe, and a mistake in the operands or options can cause undefined behavior.
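
    One way to pin down what an instruction does is to write a portable scalar model of it and compare the asm wrapper against that. Below is a sketch for vmulph on ymm registers, assuming IEEE 754 binary16 lanes stored as raw u16 bit patterns and the default round-to-nearest-even mode; f16_to_f32, f32_to_f16 and mul_ph_ref are hypothetical names. A product of two binary16 values has at most 22 significant bits and stays well inside the f32 exponent range, so it is exact in f32, and a single rounding back to binary16 gives the correctly rounded result. NaN bit patterns (sign and payload) may differ from what the hardware produces.

```rust
/// Converts a binary16 bit pattern to f32 (exact for every input).
fn f16_to_f32(h: u16) -> f32 {
    let sign = if h & 0x8000 != 0 { -1.0f32 } else { 1.0 };
    let exp = i32::from((h >> 10) & 0x1f);
    let frac = f32::from(h & 0x03ff);
    let magnitude = match exp {
        0 => frac * 2f32.powi(-24), // zero or subnormal: frac * 2^-24
        31 if frac == 0.0 => f32::INFINITY,
        31 => f32::NAN,
        // normal: (1 + frac / 2^10) * 2^(exp - 15)
        _ => (1024.0 + frac) * 2f32.powi(exp - 25),
    };
    sign * magnitude
}

/// Rounds an f32 to the nearest binary16 value (ties to even); returns its bits.
fn f32_to_f16(value: f32) -> u16 {
    let x = value.to_bits();
    let sign = (x >> 16) & 0x8000;
    let exp = ((x >> 23) & 0xff) as i32;
    let man = x & 0x007f_ffff;

    if exp == 0xff {
        // Infinity keeps a zero mantissa; any NaN becomes a quiet NaN.
        let nan = if man != 0 { 0x0200 | (man >> 13) } else { 0 };
        return (sign | 0x7c00 | nan) as u16;
    }
    let half_exp = exp - 127 + 15; // re-bias from f32 (127) to f16 (15)
    if half_exp >= 0x1f {
        return (sign | 0x7c00) as u16; // overflow rounds to infinity
    }
    if half_exp <= 0 {
        // The result is subnormal (or zero) in binary16.
        let shift = 14 - half_exp;
        if shift > 24 {
            return sign as u16; // too small: rounds to a signed zero
        }
        let man = man | 0x0080_0000; // make the implicit leading 1 explicit
        let round_bit = 1u32 << (shift - 1);
        // Round up if the round bit is set and either a lower (sticky) bit
        // or the result's lowest bit is set: round half to even.
        let round_up = (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0;
        return (sign | (man >> shift)) as u16 + round_up as u16;
    }
    let bits = sign | ((half_exp as u32) << 10) | (man >> 13);
    let round_bit = 0x1000u32;
    let round_up = (man & round_bit) != 0 && (man & (3 * round_bit - 1)) != 0;
    // A mantissa carry correctly bumps the exponent (up to infinity).
    (bits + round_up as u32) as u16
}

/// Scalar model of `vmulph ymm, ymm, ymm`: lane-wise binary16 multiply.
fn mul_ph_ref(a: [u16; 16], b: [u16; 16]) -> [u16; 16] {
    std::array::from_fn(|i| f32_to_f16(f16_to_f32(a[i]) * f16_to_f32(b[i])))
}
```

    On a machine that does support AVX-512 FP16, the output of the asm wrapper can be compared lane by lane against mul_ph_ref for the same inputs.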

    Caveats:

    • This works only because the underlying machinery (LLVM 17 as of Rust 1.76) supports this instruction set. If you try this method with a brand-new instruction set (or with an older toolchain), it may fail to compile with an "invalid instruction" error.

    • __m256i is used in lieu of a __m256h type because the latter does not exist. __m256i is currently treated as a "bag of bits" (in the documentation's words), so you must keep track yourself that it holds sixteen f16 values.

    • The target_feature = "avx2" condition is woefully inadequate for restricting the function to targets that can actually run it. Rust has no avx512fp16 target feature flag (and _mm256_mul_ph also needs the avx512vl target feature flag, which is not supported either).

      You will need to be careful to compile and run this code only on CPUs that support it (currently Intel Sapphire Rapids, as far as I can tell). Alternatively, it might be better to introduce your own compilation flag; while not perfect, it would hopefully limit the scope for things to go wrong.

      If compiled and executed on an unsupported architecture, you'll receive an "illegal instruction" error (in the best case).
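
    The custom compilation flag suggested in the last caveat could look roughly like this. It is a sketch that assumes a made-up cfg name, fp16_hw, which you set yourself with RUSTFLAGS="--cfg fp16_hw" cargo build only when targeting a machine known to support AVX-512 FP16; fp16_path_compiled is likewise a hypothetical helper. Nothing verifies the hardware, so this only narrows the blast radius. On Rust 1.80 and later, the unexpected_cfgs lint warns about the unknown name unless it is declared, for example with [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fp16_hw)'] } in Cargo.toml.

```rust
// Opt-in gate: `fp16_hw` is a hypothetical cfg name that is only set when
// building with RUSTFLAGS="--cfg fp16_hw". It is NOT checked against the CPU.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2", fp16_hw))]
pub mod fp16 {
    use std::arch::asm;
    use std::arch::x86_64::__m256i;

    /// Multiplies sixteen f16 lanes stored as raw bits in `__m256i`.
    /// Requires AVX-512 FP16 + VL at runtime.
    pub unsafe fn _mm256_mul_ph(a: __m256i, b: __m256i) -> __m256i {
        let dst: __m256i;
        unsafe {
            asm!(
                "vmulph {0}, {1}, {2}",
                out(ymm_reg) dst,
                in(ymm_reg) a,
                in(ymm_reg) b,
                options(pure, nomem, nostack),
            );
        }
        dst
    }
}

/// Reports whether the FP16 path was compiled in, so callers can pick a
/// fallback when the opt-in flag was not given.
pub fn fp16_path_compiled() -> bool {
    cfg!(all(target_arch = "x86_64", target_feature = "avx2", fp16_hw))
}
```

    Code that uses fp16::_mm256_mul_ph must itself be behind the same cfg, since the module does not exist in builds without the flag.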