Tags: c#, simd, intrinsics, avx

AVX2 consuming bytes whilst producing uints?


I'm slowly learning SIMD, but there are still some aspects I can't wrap my head around when trying to come up with SIMD solutions to a problem. One of them is when the input elements are smaller than the output elements.

As an example, let's say I have an 8-bit grayscale image, i.e. each pixel is a byte in the range 0-255. I now want to convert that to a pre-multiplied alpha image with a specified colour. So the input is an 8-bit array (8 bits per pixel), but the output is a 32-bit array (32 bits per pixel, RGBA_8888).

So the arrays are not one-to-one: one byte in the grayscale array is converted to 4 bytes in the colour array.

In scalar form, this would look like this:

 using BenchmarkDotNet.Attributes;

 public class Test
 {
     const int ImageSize = 2048;
     const int ImageLength = ImageSize * ImageSize;
     private byte[] _bytesGray = new byte[ImageLength];
     private uint[] _pixelsRGBA = new uint[ImageLength];

     private const byte _colorR = 0xFF;
     private const byte _colorG = 0x01;
     private const byte _colorB = 0x02;
     private const byte _colorA = 0xFF;

     [GlobalSetup]
     public void Setup()
     {
         for (int i = 0; i < ImageLength; i++)
         {
             _bytesGray[i] = (byte)(i + 1);
             _pixelsRGBA[i] = 0;
         }
     }

     [Benchmark]
     public unsafe void GrayscaleToColor_Scalar()
     {
         fixed (byte* bytePtr = _bytesGray)
         fixed (uint* pixelPtr = _pixelsRGBA)
         {
             for (int i = 0; i < ImageLength; ++i)
             {
                 byte value = bytePtr[i];
                 byte r = (byte)((value * _colorR) >> 8);
                 byte g = (byte)((value * _colorG) >> 8);
                 byte b = (byte)((value * _colorB) >> 8);
                 byte a = (byte)((value * _colorA) >> 8);

                 pixelPtr[i] = (uint)(r << 24 | g << 16 | b << 8 | a);
             }
         }
     }
 }

To process this in SIMD form, my thinking is that I want to read the _bytesGray array 32 bytes at a time, so I can take full advantage of the 256-bit vector registers:

fixed (byte* valueBytes = _bytesGray)
{
    for (int i = 0; i < ImageLength; i += 32)
    {
        ...
    }
}

But since the input pixel is a byte and the output pixel is a uint, I think I would then need four vectors, each containing 8 of the 32 bytes, with each byte duplicated 4 times to fill its uint. At that point I could do my multiply:

fixed (byte* valueBytes = _bytesGray)
{
    for (int i = 0; i < ImageLength; i += 32)
    {
        Vector256<byte> bytes = Avx2.LoadVector256(valueBytes + i);

        Vector256<uint> _0_8_grayBytes = ... // get bytes 0-7 and splat each byte to fill a uint
        Vector256<uint> _8_16_grayBytes = ... // get bytes 8-15 and splat each byte to fill a uint
        Vector256<uint> _16_24_grayBytes = ... // get bytes 16-23 and splat each byte to fill a uint
        Vector256<uint> _24_32_grayBytes = ... // get bytes 24-31 and splat each byte to fill a uint
    }
}

But I can't figure out how to express what I want in instructions, or even whether it's the right approach.

How would you go about doing this?


Solution

  • I would do it like this.

    The main idea is to load 8 input bytes at a time with a broadcast load instruction, then selectively move and zero out bytes to expand these 8 bytes into 16 ushort numbers, each scaled by 0x100 and duplicated. Then use 16-bit integer multiplications on two vectors, and finally combine the results into a single 32-byte vector and store it.

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    /// <summary>Create a pair of multipliers for the _mm256_mulhi_epu16 instruction</summary>
    static Vector256<ushort> makeMultipliers( byte low, byte high )
    {
        // Compute multipliers with uint division: ceil( x * 0x8000 / 255 ), which is at most 0x8000 and fits in a ushort
        uint a = low;
        uint b = high;
        a = ( a * 0x8000u + 254u ) / 255u;
        b = ( b * 0x8000u + 254u ) / 255u;
        // Create the pattern as uint scalar
        uint s = a | ( b << 16 );
        // Broadcast the uint, and bit cast the vector to short lanes
        return Vector256.Create( s ).AsUInt16();
    }
    
    // Permutation table which transforms 4 bytes into a 0A0A0B0B0C0C0D0D sequence of 16 bytes (0 = zero byte)
    static ReadOnlySpan<byte> permBytes => new byte[ 32 ]
    {
        // First 16 byte slice uses bytes [ 0 .. 3 ]
        0xFF, 0, 0xFF, 0,  0xFF, 1, 0xFF, 1,  0xFF, 2, 0xFF, 2,  0xFF, 3, 0xFF, 3,
        // The second slice uses bytes [ 4 .. 7 ]
        0xFF, 4, 0xFF, 4,  0xFF, 5, 0xFF, 5,  0xFF, 6, 0xFF, 6,  0xFF, 7, 0xFF, 7,
    };
    
    [Benchmark]
    public unsafe void GrayscaleToColor_Simd()
    {
        Vector256<ushort> mul0 = makeMultipliers( _colorA, _colorG );
        Vector256<ushort> mul1 = makeMultipliers( _colorB, _colorR );
    
        Vector256<byte> perm;
        fixed( byte* ptr = permBytes )
            perm = Avx2.LoadVector256( ptr );
        Vector256<byte> blendMask = Vector256.Create( (ushort)0xFF00 ).AsByte();
    
        fixed( byte* bytePtr = _bytesGray )
        fixed( uint* pixelPtr = _pixelsRGBA )
        {
            byte* bytePtrEnd = bytePtr + ImageLength;
            byte* rsi = bytePtr;
            uint* rdi = pixelPtr;
            // 8 pixels per iteration; ImageLength is a multiple of 8, so no scalar tail loop is needed
            while( rsi < bytePtrEnd )
            {
                // Broadcast 8 bytes from memory
                Vector256<byte> grey32 = Avx2.BroadcastScalarToVector256( (ulong*)rsi ).AsByte();
                rsi += 8;
                // Duplicate bytes, scaling ushort numbers by the factor of 0x100
                Vector256<ushort> grey16 = Avx2.Shuffle( grey32, perm ).AsUInt16();
    
                // Multiply 16-bit numbers, keeping the higher 16 bits of the products
                Vector256<ushort> low = Avx2.MultiplyHigh( grey16, mul0 );
                Vector256<ushort> high = Avx2.MultiplyHigh( grey16, mul1 );
                // Shift bits into correct location,
                // which is low byte for even lanes, high byte for odd ones
                low = Avx2.ShiftRightLogical( low, 7 );
                high = Avx2.ShiftLeftLogical( high, 1 );
    
                // Combine into a single vector and store (the Store extension method needs .NET 7 or later)
                Avx2.BlendVariable( low.AsByte(), high.AsByte(), blendMask )
                    .AsUInt32().Store( rdi );
                rdi += 8;
            }
        }
    }
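For comparison, here is a sketch of the four-vector layout the question asks about, where each group of 8 grey bytes becomes one vector of 8 uint lanes. It is not the method above: instead of splatting bytes, it zero-extends them with Avx2.ConvertToVector256Int32 (vpmovzxbd) and computes each channel with 32-bit multiplies, so it reproduces the question's scalar formula, including its >> 8 rounding, bit for bit. WidenSketch and Convert8 are hypothetical names, and the sketch assumes the caller has already checked Avx2.IsSupported. 32-bit vector multiplies (vpmulld) are generally more expensive than the 16-bit MultiplyHigh used above, so this is mostly useful as a readable reference.

```csharp
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

public static unsafe class WidenSketch
{
    // Hypothetical helper: converts 8 grey bytes at src into 8 pixels at dst,
    // using the question's formula (value * color) >> 8 for every channel.
    // Assumes the caller has checked Avx2.IsSupported.
    public static void Convert8(byte* src, uint* dst, byte r, byte g, byte b, byte a)
    {
        // vpmovzxbd: load 8 bytes and zero-extend each one into its own 32-bit lane
        Vector256<uint> v = Avx2.ConvertToVector256Int32(src).AsUInt32();

        // 32-bit multiplies; products are at most 255 * 255, so >> 8 leaves 8 bits
        Vector256<uint> vr = Avx2.ShiftRightLogical(Avx2.MultiplyLow(v, Vector256.Create((uint)r)), 8);
        Vector256<uint> vg = Avx2.ShiftRightLogical(Avx2.MultiplyLow(v, Vector256.Create((uint)g)), 8);
        Vector256<uint> vb = Avx2.ShiftRightLogical(Avx2.MultiplyLow(v, Vector256.Create((uint)b)), 8);
        Vector256<uint> va = Avx2.ShiftRightLogical(Avx2.MultiplyLow(v, Vector256.Create((uint)a)), 8);

        // Pack as r << 24 | g << 16 | b << 8 | a, the same layout as the scalar code
        Vector256<uint> px = Avx2.Or(
            Avx2.Or(Avx2.ShiftLeftLogical(vr, 24), Avx2.ShiftLeftLogical(vg, 16)),
            Avx2.Or(Avx2.ShiftLeftLogical(vb, 8), va));

        Avx.Store(dst, px);
    }
}
```

In the question's loop, which advances by 32 bytes, this would be called four times per iteration, at offsets i, i + 8, i + 16 and i + 24.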
    

    By the way, you have an interesting numerical issue in your formula. With (byte)((value * _colorR) >> 8) you will never get 255 in the output: even when both input bytes are 255, the result is only 254, because 255 * 255 = 65025 and 65025 >> 8 = 254. Shifting by 8 divides by 256, not by 255.
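A small scalar sketch of the issue, comparing the question's >> 8 with a rounded division by 255. RoundingDemo and its method names are hypothetical.

```csharp
public static class RoundingDemo
{
    // The question's formula: >> 8 divides by 256 instead of 255
    public static byte ScaleShift8(byte value, byte color) => (byte)((value * color) >> 8);

    // Division by 255 with round-to-nearest (255 is odd, so there are no exact ties)
    public static byte ScaleDiv255(byte value, byte color) => (byte)((value * color + 127) / 255);
}
```

ScaleShift8(255, 255) returns 254, while ScaleDiv255(255, 255) returns 255.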

    That’s why I have used different math there. The rounding in my version is less than ideal, because optimal rounding would require twice as many multiplications in the loop: one to apply the input scaling and another to divide by 255. Still, a grey input of 0xFF reproduces the original color exactly.
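For reference, one well-known way to get correctly rounded x * c / 255 without a division is the add-and-shift trick sketched below. It is a swapped-in technique, not the answer's code: it still needs one multiply for x * c, and whether it beats the single MultiplyHigh in the loop above would have to be measured. Every intermediate value stays below 65536, so the same steps could in principle run in ushort lanes with Avx2.MultiplyLow, Avx2.Add and Avx2.ShiftRightLogical. Div255 and MulDiv255 are hypothetical names.

```csharp
public static class Div255
{
    // round(x * c / 255) for 8-bit x and c, using adds and shifts instead of a division.
    // Largest intermediate: 65025 + 128 + 254 = 65407, which fits in 16 bits.
    public static byte MulDiv255(byte x, byte c)
    {
        int t = x * c + 128;              // product plus half of 256
        return (byte)((t + (t >> 8)) >> 8); // (t + t / 256) / 256 approximates t / 255
    }
}
```

The identity is exact for every pair of 8-bit inputs, which the exhaustive loop in the test checks.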

    Update: the idea behind the algorithm is as follows. The complete expression we need to compute is grey * c / 255, which can be viewed as grey * mul, where mul = c / 255 is a fraction ≤ 1.0. Instead of doing floating-point math, my implementation computes that expression with a few cheap integer instructions.

    The _mm256_mulhi_epu16 instruction (Avx2.MultiplyHigh on ushort vectors in C#) computes the following expression for each pair of ushort numbers, using the full 32-bit product: ( a * b ) >> 16

    Note that the instruction can’t scale input numbers by 100%, because that would require b = 0x10000, which doesn’t fit in a ushort (it needs 17 bits). The next best thing is scaling by up to 50%: that only requires b = 0x8000, which fits in the ushort lanes just fine, and the 8-bit result is then taken from bits [ 7 .. 14 ] of the output ushort lanes.
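The fixed-point steps can be written out as follows, with g the grey byte, c the color component and m_c the multiplier that makeMultipliers produces. This is a restatement of the answer's code as math, not an extra step in the loop:

```latex
\begin{aligned}
m_c &= \left\lceil \frac{2^{15} c}{255} \right\rceil
     = \left\lfloor \frac{2^{15} c + 254}{255} \right\rfloor \le 2^{15},
     \qquad m_c = \frac{2^{15} c}{255} + \varepsilon,\; 0 \le \varepsilon < 1 \\
h   &= \left\lfloor \frac{m_c \cdot 2^{8} g}{2^{16}} \right\rfloor
     \le \left\lfloor \frac{2^{15} \cdot 2^{8} \cdot 255}{2^{16}} \right\rfloor = 255 \cdot 2^{7} \\
\mathrm{out} &= \left\lfloor \frac{h}{2^{7}} \right\rfloor
     = \left\lfloor \frac{m_c \, g}{2^{15}} \right\rfloor
     = \left\lfloor \frac{g c}{255} + \frac{\varepsilon g}{2^{15}} \right\rfloor
\end{aligned}
```

Since the extra term εg / 2^15 is below 255 / 32768, the result is the floor of gc / 255 or at most one more, and the largest possible h, 255 · 2^7 = 0x7F80, is an 8-bit value sitting in bits [ 7 .. 14 ].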

    The makeMultipliers function computes two of these scaling coefficients and makes a vector of [ c0, c1 ] ushort pairs replicated over the complete vector. We call it twice because we need four of them, one for each channel of the output image.
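To make the lane layout concrete, here is a scalar model of what one loop iteration produces for a single pixel, i.e. one uint made of two ushort lanes. It restates the vector code one lane at a time and follows the answer's argument order, mul0 = makeMultipliers( _colorA, _colorG ) and mul1 = makeMultipliers( _colorB, _colorR ); LaneModel, Multiplier, MulHi and Pixel are hypothetical names.

```csharp
public static class LaneModel
{
    // Same formula as makeMultipliers, for one channel: ceil(c * 0x8000 / 255)
    static ushort Multiplier(byte c) => (ushort)((c * 0x8000u + 254u) / 255u);

    // One lane of _mm256_mulhi_epu16 (Avx2.MultiplyHigh): high 16 bits of the 32-bit product
    static ushort MulHi(ushort a, ushort b) => (ushort)(((uint)a * b) >> 16);

    // Models the uint the vector loop writes for one grey input byte
    public static uint Pixel(byte grey, byte r, byte g, byte b, byte a)
    {
        // After Avx2.Shuffle with permBytes, both ushort lanes of this uint hold grey << 8
        ushort grey16 = (ushort)(grey << 8);

        // mul0 = [ A, G ], mul1 = [ B, R ] in (even lane, odd lane) order
        ushort low0 = (ushort)(MulHi(grey16, Multiplier(a)) >> 7);  // even lane, low byte
        ushort low1 = (ushort)(MulHi(grey16, Multiplier(g)) >> 7);  // odd lane, low byte
        ushort high0 = (ushort)(MulHi(grey16, Multiplier(b)) << 1); // even lane, high byte
        ushort high1 = (ushort)(MulHi(grey16, Multiplier(r)) << 1); // odd lane, high byte

        // BlendVariable with 0xFF00 takes the high byte of each lane from `high`
        ushort lane0 = (ushort)((low0 & 0x00FF) | (high0 & 0xFF00));
        ushort lane1 = (ushort)((low1 & 0x00FF) | (high1 & 0xFF00));

        // Little-endian: the even lane is the low half of the uint
        return lane0 | ((uint)lane1 << 16);
    }
}
```

For grey = 255 and the question's color (R, G, B, A) = (0xFF, 0x01, 0x02, 0xFF), Pixel returns 0xFF0102FF, i.e. the same r << 24 | g << 16 | b << 8 | a layout as the scalar code.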

    Overall, my version applies the following math to each grey byte and color component:

    static byte computeColor( byte grey, byte component )
    {
        // Computed by makeMultipliers outside of the loop
        uint c32 = component;
        c32 = ( c32 * 0x8000u + 254u ) / 255u;
        ushort c16 = (ushort)c32;
    
        // Computed by moving bytes by 1 with Avx2.Shuffle, which inserts zero bytes
        ushort grey16 = (ushort)( (ushort)grey << 8 );
    
        // Computed by Avx2.MultiplyHigh
        ushort res = (ushort)( ( (uint)c16 * grey16 ) >> 16 );
    
        // Computed by the bitwise shifts and BlendVariable; the high-byte lanes use << 1 and keep bits [ 8 .. 15 ], which selects the same bits as >> 7
        return (byte)( res >> 7 );
    }
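Because there are only 65,536 input pairs, computeColor can be checked exhaustively against correctly rounded division. The harness below is a hypothetical addition: it copies the answer's math and reports the largest absolute difference from round(grey * component / 255). The multiplier is rounded up by less than 1, which adds less than 255 / 32768 to the exact quotient, so the result should always be the floor of grey * component / 255 or one more, and the difference should never exceed 1.

```csharp
using System;

public static class ComputeColorCheck
{
    // Same math as the answer's computeColor
    public static byte ComputeColor(byte grey, byte component)
    {
        uint c32 = component;
        c32 = (c32 * 0x8000u + 254u) / 255u;                      // makeMultipliers
        ushort c16 = (ushort)c32;
        ushort grey16 = (ushort)(grey << 8);                      // byte shuffle
        ushort res = (ushort)(((uint)c16 * grey16) >> 16);        // MultiplyHigh
        return (byte)(res >> 7);                                  // shifts and blend
    }

    // Largest |ComputeColor(g, c) - round(g * c / 255)| over all 8-bit input pairs
    public static int MaxError()
    {
        int max = 0;
        for (int g = 0; g < 256; g++)
        {
            for (int c = 0; c < 256; c++)
            {
                int exact = (g * c + 127) / 255; // round to nearest; no ties because 255 is odd
                int diff = Math.Abs(ComputeColor((byte)g, (byte)c) - exact);
                if (diff > max) max = diff;
            }
        }
        return max;
    }
}
```

Calling Console.WriteLine(ComputeColorCheck.MaxError()) shows how far the cheaper rounding drifts; a component of 0xFF, or a grey input of 0xFF, comes back unchanged.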