Tags: c#, unity-game-engine, directx, hlsl, compute-shader

Unity: Compute Shader to calculate closest point to each vertex


I have a mesh, and an array of points. I want to calculate, for each vertex, the index of the closest point in the array. I have a routine that works:

    for (int i = 0; i < vertexPositions.Length; i++)
    {
        float minDist = 100000.0f;
        int index = 0;
        float dist;
        for (int a = 0; a < pointPositions.Length; a++)
        {
            dist = (vertexPositions[i] - pointPositions[a]).sqrMagnitude;
            if (dist < minDist)
            {
                minDist = dist;
                index = a;
            }
        }
        vertexParameter[i] = index;
    }

The vertexParameter array contains the desired result. This routine is very slow if there are many vertices, so I wanted to make a Compute Shader that does the exact same thing. But I'm a beginner at Compute Shaders…

This is my Compute Shader code:

#pragma kernel ClosestPoint

struct vertexData
{
    float3 position;
    int parameter;
};
struct pointData
{
    float3 position;
    float parameter;
};

RWStructuredBuffer<vertexData> vertex;
StructuredBuffer<pointData> point;


[numthreads(32, 1, 1)]
void ClosestPoint(uint3 id : SV_DispatchThreadID)
{
    int index = 0;
    float dist;
    float minDist = 1000.0f;
    for (uint i = 0; i < point.Length; i++)
    {
        dist = distance(point[i].position, vertex[id.x].position);
        if (dist < minDist)
        {
            minDist = dist;
            index = i;
        }

    }
    vertex[id.x].parameter = index;
}

I don't know why, but this code gives erroneous results. The results change if I modify the ThreadGroups in the Dispatch call, so I suppose it might be due to some synchronization issues…?

In case it's needed, this is the script code that calls the shader:

    vertex = new ComputeBuffer(vertices.Length, System.Runtime.InteropServices.Marshal.SizeOf(typeof(vertexData)));
    vertex.SetData(vertices);

    point = new ComputeBuffer(points.Length, System.Runtime.InteropServices.Marshal.SizeOf(typeof(pointData)));
    point.SetData(points);

    shader.SetBuffer(kernelHandle, "vertex", vertex);
    shader.SetBuffer(kernelHandle, "point", point);
    shader.Dispatch(kernelHandle, 1, 1, 1);
    vertex.GetData(vertices);
    for (int i = 0; i < vertexParameter.Length; i++)
    {
        vertexParameter[i] = vertices[i].parameter;
    }
    vertex.Release();
    point.Release();

Solution

  • I believe you're mistaken about the relationship between the threadGroups arguments of your Dispatch() call and [numthreads()] in your kernel declaration.

    The result of shader.Dispatch(kernelHandle, vertices.Length, 1, 1); combined with [numthreads(32,1,1)] is not "many thread groups, each with a single thread"; it is vertices.Length thread groups, each with 32 threads.

    Your kernel will thus be invoked 32 * vertices.Length times, with id.x running from 0 up to 32 * vertices.Length - 1. You get the correct result with the code from your comment because, whatever happens when a thread reads and writes vertex[id.x] with id.x out of bounds, it doesn't change the fact that the in-range threads have computed all of the correct results and stored them in the right place.

    To avoid wasting time, then, set threadGroupsX in your Dispatch() call to ceil(vertices.Length/32) (pseudocode; note that in C#, vertices.Length / 32 is integer division and already truncates).

    You could also add something like

    if (id.x >= vertexLength) return;
    

    in your kernel, because unless the vertex count happens to be a multiple of 32, some threads in the last group will be out of bounds. In practice, though, that probably wouldn't help with performance or safety: the reads and writes beyond vertices.Length will be essentially no-ops, while the extra branch in the kernel could incur a cost. I suspect the difference is insignificant either way in this case, and a statement like that can make the logic clearer to human readers, but it does mean the extra boilerplate of passing in another uniform.

    Incidentally, if it makes sense in your application, you may also want to use an AsyncGPUReadbackRequest to avoid stalling your code on vertex.GetData(vertices);. You might have written it that way in the question for brevity (which, as you may notice, is not always my strong point).
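To make the 32 * vertices.Length figure concrete, here is a small plain C# sketch (no Unity needed) that walks every invocation of a one-dimensional Dispatch(groupsX, 1, 1) for a kernel declared with [numthreads(32, 1, 1)], using the standard relation SV_DispatchThreadID = SV_GroupID * numthreads + SV_GroupThreadID. The vertex count of 1000 and the helper name CountInvocations are hypothetical, chosen only for illustration.

```csharp
using System;

// Threads per group along X, matching [numthreads(32, 1, 1)] in the kernel.
const int ThreadsPerGroupX = 32;

// Walks every invocation of a 1D Dispatch(groupsX, 1, 1) and counts how many run
// and how many get an id.x past the end of a buffer with bufferLength elements.
(int invocations, int outOfBounds) CountInvocations(int groupsX, int bufferLength)
{
    int invocations = 0, outOfBounds = 0;
    for (int groupId = 0; groupId < groupsX; groupId++)
    {
        for (int threadInGroup = 0; threadInGroup < ThreadsPerGroupX; threadInGroup++)
        {
            // SV_DispatchThreadID.x = SV_GroupID.x * numthreads.x + SV_GroupThreadID.x
            int idX = groupId * ThreadsPerGroupX + threadInGroup;
            invocations++;
            if (idX >= bufferLength) outOfBounds++;
        }
    }
    return (invocations, outOfBounds);
}

int vertexCount = 1000; // hypothetical mesh size
var perVertex = CountInvocations(vertexCount, vertexCount); // Dispatch(vertices.Length, 1, 1)
var single = CountInvocations(1, vertexCount);              // Dispatch(1, 1, 1), as in the question

Console.WriteLine($"Dispatch({vertexCount}, 1, 1): {perVertex.invocations} invocations, {perVertex.outOfBounds} out of bounds");
Console.WriteLine($"Dispatch(1, 1, 1): {single.invocations} invocations, so only the first {single.invocations} vertices are written");
```

With the question's Dispatch(kernelHandle, 1, 1, 1), only 32 threads run, so only vertices 0 to 31 get a parameter written; that fits with the results changing whenever the thread-group count is changed.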
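The ceil(vertices.Length/32) pseudocode needs a little care in C#, because vertices.Length / 32 is integer division and truncates before any rounding could happen. A minimal sketch, assuming a hypothetical helper named ThreadGroupsFor, uses integer ceiling division instead:

```csharp
using System;

// Smallest number of groups such that groups * threadsPerGroup >= itemCount,
// computed with integer ceiling division (no floating point involved).
int ThreadGroupsFor(int itemCount, int threadsPerGroup)
{
    if (threadsPerGroup <= 0)
        throw new ArgumentOutOfRangeException(nameof(threadsPerGroup));
    return (itemCount + threadsPerGroup - 1) / threadsPerGroup;
}

// Naive integer version of the pseudocode: truncates, so the partial last group is lost.
int truncated = 1000 / 32;              // 31 groups -> only 992 threads
int groups = ThreadGroupsFor(1000, 32); // 32 groups -> 1024 threads

Console.WriteLine($"truncated: {truncated} groups, ceiling: {groups} groups");
foreach (int n in new[] { 1, 32, 33, 1000 })
    Console.WriteLine($"{n} vertices -> {ThreadGroupsFor(n, 32)} groups of 32 threads");
```

In Unity, Mathf.CeilToInt(vertices.Length / 32f) gives the same result for vertex counts in the usual range, and ComputeShader.GetKernelThreadGroupSizes can read the 32 back from the kernel instead of hard-coding it; the result would then go into shader.Dispatch(kernelHandle, groups, 1, 1).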
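The if (id.x >= vertexLength) return; guard can also be illustrated on the CPU. The following is a plain C# emulation of the kernel, not GPU code: each (group, thread) pair of the loop stands in for one invocation, value tuples stand in for vertexData and pointData, and vertexLength plays the role of the extra uniform (on the GPU side it would be set with something like shader.SetInt("vertexLength", vertices.Length)). It compares squared distances, as the question's CPU routine does with sqrMagnitude, and starts minDist at float.MaxValue rather than at a fixed cap. The name EmulateClosestPoint and the sample data are hypothetical.

```csharp
using System;

// Threads per group along X, matching [numthreads(32, 1, 1)].
const int ThreadsPerGroupX = 32;

// Emulates Dispatch(groupsX, 1, 1): each (groupId, threadInGroup) pair is one invocation.
// vertexLength stands in for the uniform the bounds guard would need on the GPU.
void EmulateClosestPoint(int groupsX,
                         (float x, float y, float z)[] verts,
                         (float x, float y, float z)[] pts,
                         int[] result,
                         int vertexLength)
{
    for (int groupId = 0; groupId < groupsX; groupId++)
    {
        for (int threadInGroup = 0; threadInGroup < ThreadsPerGroupX; threadInGroup++)
        {
            int id = groupId * ThreadsPerGroupX + threadInGroup; // SV_DispatchThreadID.x
            if (id >= vertexLength) continue;                    // the kernel's early "return"

            int index = 0;
            float minDist = float.MaxValue;
            for (int i = 0; i < pts.Length; i++)
            {
                // Squared distance keeps the same ordering as distance(), without a sqrt.
                float dx = pts[i].x - verts[id].x;
                float dy = pts[i].y - verts[id].y;
                float dz = pts[i].z - verts[id].z;
                float dist = dx * dx + dy * dy + dz * dz;
                if (dist < minDist)
                {
                    minDist = dist;
                    index = i;
                }
            }
            result[id] = index;
        }
    }
}

// 33 vertices along the X axis, and two points at x = 0 and x = 10.
var vertices = new (float x, float y, float z)[33];
for (int v = 0; v < vertices.Length; v++) vertices[v] = (v, 0f, 0f);
var points = new (float x, float y, float z)[] { (0f, 0f, 0f), (10f, 0f, 0f) };
var parameter = new int[vertices.Length];

// Ceiling division: 2 groups, i.e. 64 threads for 33 vertices.
int groups = (vertices.Length + ThreadsPerGroupX - 1) / ThreadsPerGroupX;
EmulateClosestPoint(groups, vertices, points, parameter, vertices.Length);
Console.WriteLine(string.Join(",", parameter));
```

Unlike the GPU case described above, where out-of-range reads and writes are essentially no-ops, a C# array throws IndexOutOfRangeException, so in this emulation the guard is required rather than optional for the 31 surplus threads of the second group.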