Tags: c#, simd, intrinsics, avx2

AVX2 computation on a byte array


I'm fairly new to SIMD, but I have been experimenting with how image processing can be sped up on the CPU. (I appreciate that the GPU is better suited to this, but this is more of a learning exercise.)

I wanted to perform a simple multiply-add on an 8-bit grayscale image, i.e. a byte[] array.

I implemented both a scalar version and a SIMD version, and was surprised by the results: the SIMD version is actually quite a bit slower.

I suspected this might be because byte is not a native element type for the AVX instructions, as I believe they work much faster on 32-bit values (float, int, etc.). So I also implemented a version using int. It does show better performance, but not by much.
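For context, a Vector256 always holds 256 bits, so the element type decides how many values each instruction processes. A tiny sketch of the lane counts (LaneCounts is a hypothetical helper, not part of the benchmark):

```csharp
using System.Runtime.Intrinsics;

static class LaneCounts
{
    // Elements per 256-bit vector: 32 / sizeof(T), independent of the hardware.
    public static int BytesPerVector => Vector256<byte>.Count; // 32 lanes of 8 bits
    public static int IntsPerVector => Vector256<int>.Count;   // 8 lanes of 32 bits
}
```

So, per instruction, a byte vector covers four times as many pixels as an int vector.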

Also, given that images are generally stored as bytes, e.g. 8-bit grayscale or 32-bit 8888 RGBA, it wouldn't make sense to convert the image to ints, as anything gained from the speed-up would be lost converting to and from bytes again. In my particular case, the output needs to be bytes.

So is there a way to get better performance with the byte version? And, as a second question, how would you efficiently handle the bytes overflowing, i.e. is there an efficient way to clamp to 255 instead of rolling over?

BenchmarkDotNet results for a 2048 × 2048 image:

| Method       | Mean     | Error     | StdDev    | Median   |
|------------- |---------:|----------:|----------:|---------:|
| Scalar_Bytes | 3.835 ms | 0.0766 ms | 0.1565 ms | 3.830 ms |
| Vector_Bytes | 5.351 ms | 0.0970 ms | 0.1227 ms | 5.324 ms |
| Scalar_Ints  | 3.210 ms | 0.0641 ms | 0.0811 ms | 3.200 ms |
| Vector_Ints  | 1.298 ms | 0.0259 ms | 0.0706 ms | 1.277 ms |

using System.Runtime.Intrinsics;
using BenchmarkDotNet.Attributes;
using BenchmarkDotNet.Running;

public class Tests
{
    const int Count = 2048 * 2048;
    private byte[] _bytes = new byte[Count];
    private int[] _ints = new int[Count];

    [GlobalSetup]
    public void Setup()
    {
        _bytes = new byte[Count];

        for (int i = 0; i < Count; i++)
        {
            _bytes[i] = (byte)i;
        }

        _ints = new int[Count];

        for (int i = 0; i < Count; i++)
        {
            _ints[i] = i;
        }
    }

    [Benchmark]
    public void Scalar_Bytes()
    {
        for (int i = 0; i < Count; i++)
        {
            _bytes[i] = (byte)((_bytes[i] * 2) + 26);
        }
    }

    [Benchmark]
    public unsafe void Vector_Bytes()
    {
        int offset = Vector256<byte>.Count;
        fixed (byte* ptr = _bytes)
        {
            var add = Vector256.Create<byte>(26);
            for (int i = 0; i < Count; i += offset)
            {
                var v = Vector256.Load<byte>(ptr + i);
                v *= 2;
                v += add;

                Vector256.Store(v, ptr + i);
            }
        }
    }

    [Benchmark]
    public void Scalar_Ints()
    {
        for (int i = 0; i < Count; i++)
        {
            _ints[i] = ((_ints[i] * 2) + 26);
        }
    }

    [Benchmark]
    public unsafe void Vector_Ints()
    {
        int offset = Vector256<int>.Count;
        fixed (int* ptr = _ints)
        {
            var add = Vector256.Create<int>(26);
            for (int i = 0; i < Count; i += offset)
            {
                var v = Vector256.Load<int>(ptr + i);
                v *= 2;
                v += add;

                Vector256.Store(v, ptr + i);
            }
        }
    }
}

internal class Program
{
    static void Main(string[] args)
    {
        BenchmarkRunner.Run<Tests>();
    }
}
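A plain scalar version of the clamped multiply-add asked about above can serve as a reference to check the vector versions against. This is a minimal sketch. ScalarReference and MultiplyAddClamped are hypothetical names. The intermediate value is computed in int so it cannot wrap before it is clamped.

```csharp
using System;

static class ScalarReference
{
    // dst = clamp(src * multiplier + add, 0, 255), computed in int so nothing wraps.
    public static void MultiplyAddClamped(Span<byte> pixels, int multiplier, int add)
    {
        for (int i = 0; i < pixels.Length; i++)
        {
            int value = pixels[i] * multiplier + add;
            pixels[i] = (byte)Math.Clamp(value, 0, 255);
        }
    }
}
```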

Solution

  • is there a way to get better performance with the byte version?

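    Yes, by using the appropriate AVX2 instructions instead of the operators defined by the Vector256<byte> struct. x86 has no SIMD instruction for multiplying bytes, so v *= 2 cannot map to a single instruction, which is most likely why Vector_Bytes is slower than the scalar loop. Multiplying by 2 is the same as adding the vector to itself, and byte addition does have a native instruction (vpaddb).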

    Try the following version (Avx2 is in the System.Runtime.Intrinsics.X86 namespace):

    /// <summary>AVX2 optimized version of <see cref="Vector_Bytes" /></summary>
    [Benchmark]
    public unsafe void vectorBytesOpt()
    {
        Vector256<byte> add = Vector256.Create( (byte)26 );
        fixed( byte* ptr = _bytes )
        {
            byte* rsiEnd = ptr + Count;
            for( byte* rsi = ptr; rsi < rsiEnd; rsi += 32 )
            {
                // Load 32 bytes
                Vector256<byte> v = Vector256.Load( rsi );
                // Multiplication by 2 is equal to adding with itself
                v = Avx2.Add( v, v );
                // Add that extra number
                v = Avx2.Add( v, add );
                // Store 32 bytes
                v.Store( rsi );
            }
        }
    }
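The loop above assumes that Count is a multiple of 32 and that the CPU supports AVX2. Below is a minimal sketch of the same doubling trick that also handles a leftover tail and falls back to scalar code when AVX2 is missing. It replaces the pointer arithmetic with MemoryMarshal.Cast, so no unsafe block is needed. DoubleAddWrapping and Run are hypothetical names.

```csharp
using System;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

static class DoubleAddWrapping
{
    // pixels[i] = (byte)(pixels[i] * 2 + add), wrapping on overflow like Scalar_Bytes.
    public static void Run(Span<byte> pixels, byte add)
    {
        int i = 0;
        if (Avx2.IsSupported)
        {
            // View the whole 32-byte chunks as vectors (no copy, no unsafe code).
            Span<Vector256<byte>> blocks = MemoryMarshal.Cast<byte, Vector256<byte>>(pixels);
            Vector256<byte> addVec = Vector256.Create(add);
            for (int b = 0; b < blocks.Length; b++)
            {
                Vector256<byte> v = blocks[b];
                v = Avx2.Add(v, v);              // x * 2 == x + x (vpaddb, wraps mod 256)
                blocks[b] = Avx2.Add(v, addVec); // + add, also wrapping
            }
            i = blocks.Length * Vector256<byte>.Count;
        }
        // Scalar tail, or the whole array when AVX2 is unavailable.
        for (; i < pixels.Length; i++)
            pixels[i] = (byte)(pixels[i] * 2 + add);
    }
}
```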
    

  • how would you efficiently handle the issue of the bytes overflowing?

    Luckily, AVX2 also has an instruction that adds unsigned bytes with saturation (vpaddusb, exposed as Avx2.AddSaturate). Here’s another version which clamps your bytes to 255 instead of rolling over.

    /// <summary>Another version which uses saturation</summary>
    [Benchmark]
    public unsafe void vectorBytesOptSat()
    {
        Vector256<byte> add = Vector256.Create( (byte)26 );
        fixed( byte* ptr = _bytes )
        {
            byte* rsiEnd = ptr + Count;
            for( byte* rsi = ptr; rsi < rsiEnd; rsi += 32 )
            {
                Vector256<byte> v = Vector256.Load( rsi );
                v = Avx2.AddSaturate( v, v );
                v = Avx2.AddSaturate( v, add );
                v.Store( rsi );
            }
        }
    }
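Saturating twice, once for the doubling and once for the addition, gives the same result as clamping the exact value once. Both steps are monotonic, and once a value reaches 255 it stays there. Below is a small scalar check of that claim. It models Avx2.AddSaturate on bytes as min(a + b, 255). SaturationCheck and its methods are hypothetical names.

```csharp
using System;

static class SaturationCheck
{
    // Scalar model of vpaddusb (Avx2.AddSaturate on bytes): unsigned add clamped to 255.
    static byte AddSat(byte a, byte b) => (byte)Math.Min(a + b, 255);

    // True if AddSat(AddSat(x, x), add) == min(x * 2 + add, 255) for every byte x.
    public static bool TwoStepMatchesSingleClamp(byte add)
    {
        for (int x = 0; x < 256; x++)
        {
            byte b = (byte)x;
            byte twoStep = AddSat(AddSat(b, b), add);
            byte exact = (byte)Math.Min(x * 2 + add, 255);
            if (twoStep != exact) return false;
        }
        return true;
    }
}
```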
    

    Here’s the output from your test on my computer, with a Ryzen 7 8700G processor and the .NET 8.0 runtime.

    | Method            | Mean        | Error     | StdDev    |
    |------------------ |------------:|----------:|----------:|
    | Scalar_Bytes      | 1,722.81 µs | 32.646 µs | 30.537 µs |
    | Vector_Bytes      | 3,570.63 µs | 26.369 µs | 24.666 µs |
    | vectorBytesOpt    |    43.72 µs |  0.434 µs |  0.406 µs |
    | vectorBytesOptSat |    44.18 µs |  0.238 µs |  0.223 µs |
    | Scalar_Ints       | 1,728.04 µs | 16.708 µs | 15.629 µs |
    | Vector_Ints       |   256.91 µs |  7.591 µs | 22.023 µs |
    

    Update: proper multiplication (by any constant from 0 to 127, not just 2) is possible, but way more computationally expensive. Here’s another version which does that; it takes about 63.5 µs on my computer.

    [Benchmark]
    public unsafe void vectorBytesMulSat()
    {
        // Must be in 0..127: vpmaddubsw treats it as a signed byte, and negative values break the clamping below
        sbyte multiplier = 2;
    
        // Create a pair of constant vectors for _mm256_maddubs_epi16
        ushort mulLowScalar = (ushort)( multiplier );
        ushort mulHighScalar = (ushort)( mulLowScalar << 8 );
        Vector256<sbyte> mulLow = Vector256.Create( mulLowScalar ).As<ushort, sbyte>();
        Vector256<sbyte> mulHigh = Vector256.Create( mulHighScalar ).As<ushort, sbyte>();
        // Another constant to implement saturation of these products
        Vector256<ushort> maxProduct = Vector256.Create( (ushort)0xFF );
        // Final addition
        Vector256<byte> add = Vector256.Create( (byte)26 );
    
        fixed( byte* ptr = _bytes )
        {
            byte* rsiEnd = ptr + Count;
            for( byte* rsi = ptr; rsi < rsiEnd; rsi += 32 )
            {
                // Load 32 bytes
                Vector256<byte> v = Vector256.Load( rsi );
                // Multiply byte * sbyte, separately for even / odd bytes
                Vector256<ushort> low = Avx2.MultiplyAddAdjacent( v, mulLow ).As<short, ushort>();
                Vector256<ushort> high = Avx2.MultiplyAddAdjacent( v, mulHigh ).As<short, ushort>();
                // Saturate these products
                low = Avx2.Min( low, maxProduct );
                high = Avx2.Min( high, maxProduct );
                // Combine back into bytes
                high = Avx2.ShiftLeftLogical( high, 8 );
                v = Avx2.Or( low, high ).As<ushort, byte>();
                // Add the final constant using saturation, and store
                v = Avx2.AddSaturate( v, add );
                v.Store( rsi );
            }
        }
    }
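To see why mulLow and mulHigh pick out the even and odd bytes, here is a scalar model of a single 16-bit lane of vpmaddubsw. It assumes little-endian byte order, which x86 uses, so 0x0002 is stored as the bytes [2, 0] and 0x0200 as [0, 2]. This illustrates the arithmetic under those assumptions and is not how the intrinsic is implemented. MaddubsModel and its methods are hypothetical names.

```csharp
using System;

static class MaddubsModel
{
    // One 16-bit lane of vpmaddubsw (Avx2.MultiplyAddAdjacent): unsigned bytes (a0, a1)
    // times signed bytes (b0, b1), summed with signed 16-bit saturation.
    static short MaddubsLane(byte a0, byte a1, sbyte b0, sbyte b1)
    {
        int sum = a0 * b0 + a1 * b1;
        return (short)Math.Clamp(sum, short.MinValue, short.MaxValue);
    }

    // Applies the mulLow / mulHigh patterns to one byte pair; multiplier must be in 0..127.
    public static (byte Even, byte Odd) MultiplyPair(byte even, byte odd, sbyte multiplier)
    {
        short low = MaddubsLane(even, odd, multiplier, 0);  // mulLow lane bytes:  [m, 0]
        short high = MaddubsLane(even, odd, 0, multiplier); // mulHigh lane bytes: [0, m]
        // Products are in 0..255*127, so clamping to 255 mirrors Avx2.Min(…, maxProduct).
        return ((byte)Math.Min((int)low, 255), (byte)Math.Min((int)high, 255));
    }
}
```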
    

    Update 2: here’s another version which also does proper multiplication and is slightly faster on my computer, about 58.9 µs.

    // Permutation table to fix order of bytes after _mm256_packus_epi16( even, odd )
    static ReadOnlySpan<byte> permuteBytes => new byte[ 16 ]
    {
        0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
    };
    
    // Distinct name so both multiplication versions can coexist in the Tests class
    [Benchmark]
    public unsafe void vectorBytesMulSatPack()
    {
        // Must be in 0..127: vpmaddubsw treats it as a signed byte
        sbyte multiplier = 2;
    
        // Create a pair of constant vectors for _mm256_maddubs_epi16
        ushort mulLowScalar = (ushort)( multiplier );
        ushort mulHighScalar = (ushort)( mulLowScalar << 8 );
        Vector256<sbyte> mulLow = Vector256.Create( mulLowScalar ).As<ushort, sbyte>();
        Vector256<sbyte> mulHigh = Vector256.Create( mulHighScalar ).As<ushort, sbyte>();
        // Create a vector to permute bytes after saturation
        Vector256<byte> perm;
        fixed( byte* ptr = permuteBytes )
            perm = Avx2.BroadcastVector128ToVector256( ptr );
        // Final addition
        Vector256<byte> add = Vector256.Create( (byte)26 );
    
        fixed( byte* ptr = _bytes )
        {
            byte* rsiEnd = ptr + Count;
            for( byte* rsi = ptr; rsi < rsiEnd; rsi += 32 )
            {
                // Load 32 bytes
                Vector256<byte> v = Vector256.Load( rsi );
                // Multiply byte * sbyte, separately for even / odd bytes
                Vector256<short> low = Avx2.MultiplyAddAdjacent( v, mulLow );
                Vector256<short> high = Avx2.MultiplyAddAdjacent( v, mulHigh );
                // Pack and saturate these products
                v = Avx2.PackUnsignedSaturate( low, high );
                // Fix order of bytes after the packing
                v = Avx2.Shuffle( v, perm );
                // Add the final constant using saturation, and store
                v = Avx2.AddSaturate( v, add );
                v.Store( rsi );
            }
        }
    }
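vpackuswb and vpshufb both operate within each 128-bit half of the register. That is why a single 16-entry table is enough and why it is broadcast to both halves. Below is a scalar model of one half. It assumes even holds the products of bytes 0, 2, …, 14 and odd holds those of bytes 1, 3, …, 15, and it shows the table putting the bytes back in their original order. PackShuffleModel and PackThenShuffle are hypothetical names.

```csharp
using System;

static class PackShuffleModel
{
    // Same table as permuteBytes: output byte k comes from packed byte Perm[k].
    static readonly byte[] Perm = { 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15 };

    // Models one 128-bit lane of PackUnsignedSaturate(even, odd) followed by Shuffle(…, perm).
    public static byte[] PackThenShuffle(short[] even, short[] odd)
    {
        // vpackuswb: saturate each signed 16-bit value to 0..255;
        // 'even' fills bytes 0..7 and 'odd' fills bytes 8..15.
        var packed = new byte[16];
        for (int j = 0; j < 8; j++)
        {
            packed[j] = (byte)Math.Clamp((int)even[j], 0, 255);
            packed[j + 8] = (byte)Math.Clamp((int)odd[j], 0, 255);
        }
        // vpshufb within the lane: result[k] = packed[Perm[k]] (no index has its high bit set).
        var result = new byte[16];
        for (int k = 0; k < 16; k++)
            result[k] = packed[Perm[k]];
        return result;
    }
}
```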