
Let's consider a simple reduction, such as a dot product:

pub fn add(a:&[f32], b:&[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0, |c,(x,y)| c+x*y)
}

Using rustc 1.68 with -C opt-level=3 -C target-feature=+avx2,+fma, I get:

.LBB0_5:
        vmovss  xmm1, dword ptr [rdi + 4*rsi]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi]
        vmovss  xmm2, dword ptr [rdi + 4*rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        vmulss  xmm1, xmm2, dword ptr [rdx + 4*rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        vmovss  xmm1, dword ptr [rdi + 4*rsi + 8]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi + 8]
        vaddss  xmm0, xmm0, xmm1
        vmovss  xmm1, dword ptr [rdi + 4*rsi + 12]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi + 12]
        lea     rax, [rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        mov     rsi, rax
        cmp     rcx, rax
        jne     .LBB0_5

which is a scalar implementation with loop unrolling that does not even contract the mul+add into FMAs. Going from this code to SIMD code should be easy, so why does rustc not optimize this?

If I replace f32 with i32, I get the desired auto-vectorization:

.LBB0_5:
        vmovdqu ymm4, ymmword ptr [rdx + 4*rax]
        vmovdqu ymm5, ymmword ptr [rdx + 4*rax + 32]
        vmovdqu ymm6, ymmword ptr [rdx + 4*rax + 64]
        vmovdqu ymm7, ymmword ptr [rdx + 4*rax + 96]
        vpmulld ymm4, ymm4, ymmword ptr [rdi + 4*rax]
        vpaddd  ymm0, ymm4, ymm0
        vpmulld ymm4, ymm5, ymmword ptr [rdi + 4*rax + 32]
        vpaddd  ymm1, ymm4, ymm1
        vpmulld ymm4, ymm6, ymmword ptr [rdi + 4*rax + 64]
        vpmulld ymm5, ymm7, ymmword ptr [rdi + 4*rax + 96]
        vpaddd  ymm2, ymm4, ymm2
        vpaddd  ymm3, ymm5, ymm3
        add     rax, 32
        cmp     r8, rax
        jne     .LBB0_5
Peter Cordes
Unlikus


This is because floating-point addition is not associative: in general a+(b+c) != (a+b)+c. Summing up floats is therefore a serial task, because the compiler will not reorder ((a+b)+c)+d into (a+b)+(c+d). The latter can be vectorized, the former cannot.
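To make the ordering issue concrete, here is a small sketch (the helper names `sum_sequential` and `sum_two_lanes` are hypothetical). `sum_two_lanes` keeps separate accumulators for even and odd indices, which is the kind of reordering a 2-lane SIMD loop performs, and for the same input it gives a different result than the strict left-to-right fold.

```rust
/// Strict left-to-right sum: ((x0 + x1) + x2) + x3 ...
/// This is the order a plain `fold` over f32 specifies.
fn sum_sequential(xs: &[f32]) -> f32 {
    xs.iter().fold(0.0, |acc, &x| acc + x)
}

/// Sum with two independent accumulators (even and odd indices),
/// combined at the end. This mimics the reassociation a 2-lane
/// SIMD loop would perform.
fn sum_two_lanes(xs: &[f32]) -> f32 {
    let mut lanes = [0.0f32; 2];
    for (i, &x) in xs.iter().enumerate() {
        lanes[i % 2] += x;
    }
    lanes[0] + lanes[1]
}

fn main() {
    // Associativity already fails for three numbers.
    let (a, b, c) = (1e8f32, -1e8f32, 1.0f32);
    assert_eq!((a + b) + c, 1.0);
    assert_eq!(a + (b + c), 0.0); // b + c rounds back to -1e8

    let xs = [1e8f32, 1.0, -1e8, 1.0];
    println!("sequential: {}", sum_sequential(&xs)); // prints 1
    println!("two lanes:  {}", sum_two_lanes(&xs)); // prints 2
}
```

Near `1e8` the spacing between adjacent `f32` values is 8, so adding `1.0` to it is rounded away; whether that happens depends entirely on the order of the additions, which is why the compiler may not reorder them on its own.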

In most cases the programmer does not care about the small differences caused by a different summation order.

GCC and Clang provide the -fassociative-math flag, which allows the compiler to reorder floating-point operations for performance.

rustc does not provide an equivalent, and as far as I know LLVM does not accept any flags that would change this behavior either.

On nightly Rust you can use #![feature(core_intrinsics)] and the fast-math intrinsics to get this optimization:

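#![feature(core_intrinsics)]
pub fn add(a:&[f32], b:&[f32]) -> f32 {
    // fadd_fast lets LLVM treat the addition as associative, so it can
    // split the sum across SIMD lanes. It is unsafe because it may assume
    // its inputs are finite (no NaN or infinity).
    unsafe {
        a.iter().zip(b.iter()).fold(0.0, |c,(x,y)| std::intrinsics::fadd_fast(c,x*y))
    }
}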

This version still does not use FMA. To get FMA instructions as well, the multiplication must also use a fast intrinsic:

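#![feature(core_intrinsics)]
pub fn add(a:&[f32], b:&[f32]) -> f32 {
    // Both the multiply and the add carry fast-math flags, which also
    // allows LLVM to fuse them into FMA instructions.
    unsafe {
        a.iter().zip(b.iter()).fold(0.0, |c,(&x,&y)| {
            std::intrinsics::fadd_fast(c, std::intrinsics::fmul_fast(x, y))
        })
    }
}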
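For comparison, and not part of the original answer: stable Rust can request a fused multiply-add explicitly with `f32::mul_add`. The sketch below (the name `dot_fma` is hypothetical) should get one FMA per element when compiled with `+fma`; without that target feature, `mul_add` may fall back to a much slower software routine. It does not solve the vectorization problem, because the additions still happen in strict left-to-right order.

```rust
/// Dot product on stable Rust with an explicit fused multiply-add per element.
/// `x.mul_add(y, c)` computes `x * y + c` with a single rounding.
/// The accumulation order is still strictly sequential, so this alone
/// does not let the compiler vectorize the loop.
pub fn dot_fma(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0, |c, (&x, &y)| x.mul_add(y, c))
}

fn main() {
    let a = [1.0f32, 2.0, 3.0];
    let b = [4.0f32, 5.0, 6.0];
    println!("{}", dot_fma(&a, &b)); // prints 32
}
```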

I am not aware of a stable Rust solution that does not involve explicit SIMD intrinsics.
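One commonly used stable workaround, not part of the original answer and offered here only as a sketch, is to write the reassociation out by hand: keep several independent accumulators so that the source itself specifies the order a SIMD loop would use, and the compiler no longer needs permission to reorder anything. The function name `dot_lanes` and the lane count of 8 (one AVX2 register of `f32`) are illustrative assumptions. Whether LLVM actually vectorizes the inner loop should be confirmed in the generated assembly, and the result can differ slightly from the sequential fold.

```rust
/// Dot product with 8 independent accumulators (hypothetical helper).
/// The summation order is spelled out explicitly, so no reassociation
/// is needed for the compiler to use 8-wide vector adds.
pub fn dot_lanes(a: &[f32], b: &[f32]) -> f32 {
    const LANES: usize = 8;
    let n = a.len().min(b.len()); // like zip(), stop at the shorter slice
    let (a, b) = (&a[..n], &b[..n]);

    let mut acc = [0.0f32; LANES];
    let a_chunks = a.chunks_exact(LANES);
    let b_chunks = b.chunks_exact(LANES);
    // Elements that do not fill a whole chunk of LANES.
    let (a_rem, b_rem) = (a_chunks.remainder(), b_chunks.remainder());

    for (ca, cb) in a_chunks.zip(b_chunks) {
        // Each lane accumulates its own partial sum.
        for i in 0..LANES {
            acc[i] += ca[i] * cb[i];
        }
    }

    // Handle the leftover elements sequentially.
    let tail: f32 = a_rem.iter().zip(b_rem).map(|(x, y)| x * y).sum();

    // Combine the lanes at the end (this is the reassociated order).
    acc.iter().sum::<f32>() + tail
}

fn main() {
    let a: Vec<f32> = (1..=20).map(|i| i as f32).collect();
    let b = vec![1.0f32; 20];
    println!("{}", dot_lanes(&a, &b)); // prints 210
}
```

If `+fma` is enabled, replacing the inner statement with `acc[i] = ca[i].mul_add(cb[i], acc[i]);` should allow vector FMA as well; without FMA in the target features, `mul_add` can become a slow per-element call, so the plain multiply and add is the safer default.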

Unlikus
  • I was surprised Rust doesn't seem to have any `-ffast-math` equivalent to allow such optimizations. https://docs.rs/fast_fp/latest/fast_fp/ translates to C and builds part of your code with `clang` (presumably `-O3 -ffast-math -march=whatever`). In Rust itself, people suggest that "unsafe" math optimizations could potentially violate memory-safety. (Perhaps if you used FP math to compute an array index? Or indirectly via comparing a result.) So people have suggested that's why Rust in general doesn't have a fast-math option. – Peter Cordes Apr 19 '23 at 14:38
  • Oh I see, `std::intrinsics::fadd_fast` (https://doc.rust-lang.org/std/intrinsics/fn.fadd_fast.html) does an FP add which the compiler is allowed to pretend is associative. That's cool. Is there no option to enable contraction of `a*b + c` into `fma(a,b,c)`, like `clang -ffp-contract=fast` (or just `on` for within a single expression)? That's an arguably weird feature of C that's treated as being strictly IEEE-compliant, separate from fast-math stuff. – Peter Cordes Apr 19 '23 at 14:42