Search code examples
Tags: .net, vectorization, simd, sse

Find index of unaligned int or long in byte array using SIMD


I have a byte sequence that I want to scan for the index of an integer (or long) value. The value can start at any byte offset, not necessarily a multiple of its size. I am specifically interested in the first occurrence, but an example that finds all indexes would also be helpful.

If that's not possible, I guess I need to convert the long into an 8-byte Vector<byte> and then compare the two.

The platform is x86. I can constrain the app to run only in x64 mode.
I need the fastest possible way, so a code snippet would be great.
I know it's an easy question, but I couldn't find an example (in C#, at least).


Solution

  • After reading this blog post, I found a way to do what I wanted. I knew my input would always have enough junk bytes at the end, so I omitted checking the remaining bytes, but the code can be modified to cover all possible inputs.

            // Demo driver: searches the sample buffer for the 8-byte value "12345678".
            // In this haystack "12345678" starts at byte offset 48; the result is discarded.
            static void Main(string[] args)
            {
                ReadOnlySpan<byte> haystack = "12345671asdasdasd1asdasdasd2asdasdasd3asdasdasd_12345678asdasdasd1asdasdasd2asdasdasd3asdasdasd_"u8;
                long key = BitConverter.ToInt64("12345678"u8);
                _ = IndexOf(haystack, key);
            }
    
    /// <summary>
    /// Returns the byte offset of the first occurrence of the 8-byte value
    /// <paramref name="needle"/> in <paramref name="input"/>, at any (unaligned) offset,
    /// or -1 if it does not occur.
    /// </summary>
    /// <remarks>
    /// A match at offset k means the 8 bytes starting at k, read as a native-endian
    /// <see cref="long"/>, equal <paramref name="needle"/>.
    /// With AVX2, each window of 32 candidate offsets [i, i + 32) is covered by eight
    /// unaligned 32-byte loads at i + j (j = 0..7). Each load compares four qwords, at
    /// offsets i + j + 8k (k = 0..3). The eight compare masks are merged into one 32-bit
    /// mask whose bit b stands for offset i + b, so its lowest set bit is the first match.
    /// Offsets not covered by a full window (and all offsets when AVX2 is unavailable)
    /// are checked by a scalar loop.
    /// </remarks>
    /// <param name="input">Bytes to search.</param>
    /// <param name="needle">Value to find.</param>
    /// <returns>Offset of the first match, or -1.</returns>
    public static unsafe int IndexOf(ReadOnlySpan<byte> input, long needle)
    {
        int n = input.Length;
        fixed (byte* pInput = input)
        {
            int i = 0;
            if (Avx2.IsSupported)
            {
                Vector256<long> vecSearch = Vector256.Create(needle);
                // The load at i + 7 reads bytes up to i + 38, so a full window needs i + 39 <= n.
                for (; i <= n - 39; i += 32)
                {
                    uint found = 0;
                    for (int j = 0; j < 8; j++)
                    {
                        Vector256<long> vecInput = Avx.LoadVector256((long*)(pInput + i + j));
                        Vector256<long> mask = Avx2.CompareEqual(vecInput, vecSearch);
                        // A matching qword lane k sets bits 8k..8k+7. Keep only bit 8k + j,
                        // the bit for candidate offset i + j + 8k, so bit order == offset order.
                        found |= (uint)Avx2.MoveMask(mask.AsByte()) & (0x01010101u << j);
                    }
                    if (found != 0)
                    {
                        return i + System.Numerics.BitOperations.TrailingZeroCount(found);
                    }
                }
            }

            // Remaining offsets that still have 8 readable bytes.
            for (; i <= n - 8; i++)
            {
                if (System.Runtime.CompilerServices.Unsafe.ReadUnaligned<long>(pInput + i) == needle)
                {
                    return i;
                }
            }
        }
        return -1;
    }
    

    Here is a 32-bit version:

    /// <summary>
    /// Returns the byte offset of the first occurrence of the 8-byte value
    /// <paramref name="needle"/> in <paramref name="input"/>, at any (unaligned) offset,
    /// or -1 if it does not occur.
    /// </summary>
    /// <remarks>
    /// This variant needs four loads per 32-byte window instead of eight. It first
    /// compares 32-bit lanes against the needle's low 32 bits as a cheap prefilter.
    /// Only when that prefilter hits does it verify full 64-bit matches.
    /// Per-load results are merged into one 32-bit mask whose bit b stands for candidate
    /// offset i + b, so its lowest set bit is the first match. Offsets not covered by a
    /// full window (and all offsets when AVX2 is unavailable) are checked by a scalar loop.
    /// </remarks>
    /// <param name="input">Bytes to search.</param>
    /// <param name="needle">64-bit value to find (native byte order).</param>
    /// <returns>Offset of the first match, or -1.</returns>
    public static unsafe int IndexOf32(ReadOnlySpan<byte> input, long needle)
    {
        int n = input.Length;
        fixed (byte* pInput = input)
        {
            int i = 0;
            if (Avx2.IsSupported)
            {
                // Low 32 bits of the needle = its first 4 bytes in memory on little-endian x86.
                Vector256<int> vecSearch32 = Vector256.Create((int)needle);
                Vector256<long> vecSearch64 = Vector256.Create(needle);
                // The 64-bit load at i + 3 + 4 reads bytes up to i + 38, so a window needs i + 39 <= n.
                for (; i <= n - 39; i += 32)
                {
                    uint found = 0;
                    for (int j = 0; j < 4; j++)
                    {
                        byte* p = pInput + i + j;
                        Vector256<int> vecInput = Avx.LoadVector256((int*)p);
                        // Prefilter: does any dword at p + 4m equal the needle's low 32 bits?
                        if (Avx2.MoveMask(Avx2.CompareEqual(vecInput, vecSearch32).AsByte()) == 0)
                        {
                            continue;
                        }
                        // mask1 verifies offsets p + 8k; mask2 verifies offsets p + 4 + 8k.
                        Vector256<int> mask1 = Avx2.CompareEqual(vecInput.AsInt64(), vecSearch64).AsInt32();
                        Vector256<int> mask2 = Avx2.CompareEqual(Avx.LoadVector256((long*)(p + 4)), vecSearch64).AsInt32();
                        // Odd dword lanes come from mask2, so dword lane m now means offset p + 4m.
                        Vector256<int> blend = Avx2.Blend(mask1, mask2, 0xaa);
                        // Lane m sets bits 4m..4m+3; keep bit 4m + j (candidate offset i + j + 4m).
                        found |= (uint)Avx2.MoveMask(blend.AsByte()) & (0x11111111u << j);
                    }
                    if (found != 0)
                    {
                        return i + System.Numerics.BitOperations.TrailingZeroCount(found);
                    }
                }
            }

            // Remaining offsets that still have 8 readable bytes.
            for (; i <= n - 8; i++)
            {
                if (System.Runtime.CompilerServices.Unsafe.ReadUnaligned<long>(pInput + i) == needle)
                {
                    return i;
                }
            }
        }
        return -1;
    }
    

    Here is a 16-bit version, which I couldn't complete:

    /// <summary>
    /// 16-bit-prefilter variant of the 64-bit needle search over the <c>input</c> field.
    /// Returns the first byte offset k at which the 8 bytes of <c>input</c> starting at k
    /// equal the 64-bit needle held in <c>search</c>, or -1 if there is none.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>input</c>, <c>search16</c> and <c>search</c> are declared outside
    /// this method. For every vector lane to compare against the same value, this assumes:
    /// <list type="bullet">
    /// <item><c>search</c> holds the needle repeated 4 times (32 bytes); the scalar tail
    /// uses its first qword.</item>
    /// <item><c>search16</c> holds the needle's first 2 bytes repeated 16 times.</item>
    /// </list>
    /// TODO confirm both against the declarations.
    /// Each window of 32 candidate offsets [i, i + 32) needs 2 loads (j = 0, 1). Word lane
    /// s of a load at p = i + j corresponds to offset p + 2s; word lanes that pass the
    /// 16-bit prefilter are verified with four 64-bit compares at p, p + 2, p + 4, p + 6.
    /// </remarks>
    public unsafe int AVXShort()
    {
        int n = input.Length;
        fixed (byte* pInput = input)
        fixed (byte* pSearch16 = search16)
        fixed (byte* pSearch64 = search)
        {
            int i = 0;
            if (Avx2.IsSupported)
            {
                var vecSearch16 = Avx.LoadVector256((short*)pSearch16);
                var vecSearch64 = Avx.LoadVector256((long*)pSearch64);
                // The 64-bit load at i + 1 + 6 reads bytes up to i + 38, so a window needs i + 39 <= n.
                for (; i <= n - 39; i += 32)
                {
                    uint found = 0;
                    for (int j = 0; j < 2; j++)
                    {
                        byte* p = pInput + i + j;
                        var vecInput = Avx.LoadVector256((short*)p);
                        if (Avx2.MoveMask(Avx2.CompareEqual(vecInput, vecSearch16).AsByte()) == 0)
                        {
                            continue;
                        }
                        // qR verifies 64-bit candidates at p + 2R + 8k (R, k = 0..3).
                        var q0 = Avx2.CompareEqual(vecInput.AsInt64(), vecSearch64).AsInt16();
                        var q1 = Avx2.CompareEqual(Avx.LoadVector256((long*)(p + 2)), vecSearch64).AsInt16();
                        var q2 = Avx2.CompareEqual(Avx.LoadVector256((long*)(p + 4)), vecSearch64).AsInt16();
                        var q3 = Avx2.CompareEqual(Avx.LoadVector256((long*)(p + 6)), vecSearch64).AsInt16();
                        // Word lane s = 4k + r must come from qr, which makes it offset p + 2s.
                        // vpblendw takes words whose control bit is set (per 128-bit half) from the right operand.
                        var q01 = Avx2.Blend(q0, q1, 0xAA);       // r = 0 from q0, r = 1 from q1 (r = 2, 3 fixed next)
                        var q23 = Avx2.Blend(q2, q3, 0xAA);       // r = 2 from q2, r = 3 from q3
                        var merged = Avx2.Blend(q01, q23, 0xCC);  // r = 2, 3 taken from q23
                        // Word s sets bits 2s, 2s + 1; keep bit 2s + j (candidate offset i + j + 2s).
                        found |= (uint)Avx2.MoveMask(merged.AsByte()) & (0x55555555u << j);
                    }
                    if (found != 0)
                    {
                        return i + System.Numerics.BitOperations.TrailingZeroCount(found);
                    }
                }
            }

            // Remaining offsets that still have 8 readable bytes.
            if (i <= n - 8)
            {
                long needle = System.Runtime.CompilerServices.Unsafe.ReadUnaligned<long>(pSearch64);
                for (; i <= n - 8; i++)
                {
                    if (System.Runtime.CompilerServices.Unsafe.ReadUnaligned<long>(pInput + i) == needle)
                    {
                        return i;
                    }
                }
            }
        }
        return -1;
    }