
Incremental by kernel thread


Let’s say that I want to increment a property var incremental:Int32 every time a kernel thread is executed:

//SWIFT
var incremental:Int32 = 0
var incrementalBuffer:MTLBuffer!
var incrementalPointer: UnsafeMutablePointer<Int32>!

init(metalView: MTKView) {
    ...
    incrementalBuffer = Renderer.device.makeBuffer(bytes: &incremental, length: MemoryLayout<Int32>.stride)
    incrementalPointer = incrementalBuffer.contents().bindMemory(to: Int32.self, capacity: 1)
}
func draw(in view: MTKView) {
    ...
    computeCommandEncoder.setComputePipelineState(computePipelineState)
    let width = computePipelineState.threadExecutionWidth
    let threadsPerGroup = MTLSizeMake(width, 1, 1)
    let threadsPerGrid = MTLSizeMake(10, 1, 1)
    computeCommandEncoder.setBuffer(incrementalBuffer, offset: 0, index: 0)
    computeCommandEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerGroup)
    computeCommandEncoder.endEncoding()
    commandBufferCompute.commit()
    commandBufferCompute.waitUntilCompleted()
    
    print(incrementalPointer.pointee)
}

//METAL
kernel void compute_shader (device int& incremental [[buffer(0)]]){
    incremental++;
}

So I expect the output to be:

10
20
30
...

but I get:

1
2
3
...
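One interleaving that produces exactly this 1, 2, 3 pattern is for all 10 threads to load incremental before any of them stores it. The following is a hypothetical, deterministic CPU model in C++ (not Metal code). The names naive_dispatch and kMaxThreads are illustrative. The lockstep ordering, with every load happening before every store, is an assumption about how the GPU scheduled the threads.

```cpp
#include <array>
#include <cassert>

constexpr int kMaxThreads = 32;  // assumed SIMD-group width, illustrative only

// Hypothetical CPU model of one dispatch of the naive kernel. Assumes all
// `threads` run in lockstep in one SIMD-group, so every thread executes the
// load half of ++ before any thread executes the store half.
int naive_dispatch(int& incremental, int threads) {
    assert(threads > 0 && threads <= kMaxThreads);
    std::array<int, kMaxThreads> temp{};
    // Step 1: each thread loads the shared value into its own register.
    for (int t = 0; t < threads; ++t) temp[t] = incremental;
    // Step 2: each thread stores its private copy plus one. They all loaded
    // the same value, so the stores overwrite each other.
    for (int t = 0; t < threads; ++t) incremental = temp[t] + 1;
    return incremental;
}
```

Calling naive_dispatch(x, 10) three times, starting from x = 0, returns 1, 2 and 3. Ten increments that did not overlap would instead give 10, 20 and 30.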

EDIT: After some work based on the answer from @JustSomeGuy, plus help from Caroline at raywenderlich and an Apple engineer, I now have:

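[[kernel]] void compute_shader (device atomic_int& incremental [[buffer(0)]],
                                ushort lid [[thread_position_in_threadgroup]] ){

    threadgroup atomic_int local_atomic;
    // Threadgroup memory starts uninitialized: thread 0 resets the counter,
    // and the barrier keeps other threads from adding before the reset.
    if (lid==0) atomic_store_explicit(&local_atomic, 0, memory_order_relaxed);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    atomic_fetch_add_explicit(&local_atomic, 1, memory_order_relaxed);

    // Wait until every thread in the group has added its 1.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if(lid == 0) {
        // One device-memory atomic per threadgroup publishes the group total.
        int local_non_atomic = atomic_load_explicit(&local_atomic, memory_order_relaxed);
        atomic_fetch_add_explicit(&incremental, local_non_atomic, memory_order_relaxed);
    }
}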

and it works as expected.
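The threadgroup-then-device pattern in the kernel from the EDIT can be mimicked on the CPU. The following is a hypothetical C++ analogue, not Metal code. Each std::thread plays one GPU thread, and a per-group std::atomic<int> plays local_atomic. Joining a group's threads stands in for threadgroup_barrier, which is a much stronger form of synchronization. Groups run one after another here for simplicity, whereas on the GPU they can run concurrently. The function name reduce_dispatch is illustrative.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Hypothetical CPU analogue of the threadgroup-reduction kernel.
int reduce_dispatch(std::atomic<int>& incremental, int grid_size,
                    int threads_per_group) {
    assert(grid_size >= 0 && threads_per_group > 0);
    for (int base = 0; base < grid_size; base += threads_per_group) {
        // As with dispatchThreads, the last group may be partial.
        int group_size = std::min(threads_per_group, grid_size - base);
        std::atomic<int> local{0};  // "threadgroup" counter, reset per group
        std::vector<std::thread> group;
        for (int lid = 0; lid < group_size; ++lid) {
            group.emplace_back([&local] {
                // Cheap "threadgroup" atomic: one per thread.
                local.fetch_add(1, std::memory_order_relaxed);
            });
        }
        for (auto& t : group) t.join();  // "barrier": all adds are finished
        // The "lid == 0" step: one "device" atomic per group adds the total.
        incremental.fetch_add(local.load(std::memory_order_relaxed),
                              std::memory_order_relaxed);
    }
    return incremental.load(std::memory_order_relaxed);
}
```

For example, on a zeroed counter, reduce_dispatch(counter, 10, 32) returns 10 and a second identical call returns 20. The shared counter is touched once per group rather than once per thread, which is the point of the pattern.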


Solution

  • You are seeing this problem because ++ is not atomic. It essentially expands to code like this:

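    auto temp = incremental;   // 1. load the current value
    incremental = temp + 1;    // 2. store the incremented value
    temp;                      // 3. the expression yields the old value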
    

    Because the threads execute in parallel, several of them can load the same value of incremental before any of them stores its result, and their stores then overwrite one another. (Strictly speaking, a number of threads form a SIMD-group that executes in lockstep, but that detail is not important here.)

    Since the access is not atomic, this is a data race and the final value is unpredictable: there is no way to tell which thread observed which value.

    A quick fix is to declare the parameter as device atomic_int& incremental and call atomic_fetch_add_explicit(&incremental, 1, memory_order_relaxed). This makes every access to incremental atomic. memory_order_relaxed means there are no guarantees about how the operation is ordered relative to other memory accesses; that is fine for a counter like this one, where only the final total matters. At the time of writing, memory_order_relaxed was the only memory order supported in MSL. You can read more on this in the Metal Shading Language Specification, section 6.13.

    But this quick fix can be slow, because every thread in the grid has to synchronize on the same location in device memory. A common alternative is for all threads in a threadgroup to update a value in threadgroup memory, after which one thread (or a few) atomically updates device memory. The kernel then looks something like this:

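    kernel void compute_shader (device atomic_int& incremental [[buffer(0)]],
                                threadgroup atomic_int& local [[threadgroup(0)]],
                                ushort lid [[thread_position_in_threadgroup]]) {
        // Host side must reserve the threadgroup buffer, e.g.
        // computeCommandEncoder.setThreadgroupMemoryLength(16, index: 0)
        if (lid == 0) atomic_store_explicit(&local, 0, memory_order_relaxed); // not zero-initialized
        threadgroup_barrier(mem_flags::mem_threadgroup);
        atomic_fetch_add_explicit(&local, 1, memory_order_relaxed);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if(lid == 0) {
            int total = atomic_load_explicit(&local, memory_order_relaxed);
            atomic_fetch_add_explicit(&incremental, total, memory_order_relaxed);
        }
    }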
    

    In other words: every thread in the threadgroup atomically adds 1 to local, waits at threadgroup_barrier until all threads in the group have done so, and then exactly one thread (lid == 0) atomically adds the group's total in local to incremental.

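    atomic_fetch_add_explicit on a threadgroup variable uses threadgroup atomics instead of device (global) atomics, which should be faster, and the device-memory atomic is then performed only once per threadgroup instead of once per thread.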

    You can read the specification mentioned above to learn more; these patterns are mentioned in the samples there.
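For comparison with the threadgroup pattern, here is the quick fix (one relaxed atomic add per thread, directly on the shared counter) modelled on the CPU. This is a hypothetical C++ sketch, not Metal code. std::atomic<int>::fetch_add with std::memory_order_relaxed plays the role of atomic_fetch_add_explicit(&incremental, 1, memory_order_relaxed). Joining the threads stands in for commit() plus waitUntilCompleted() on the host. The name atomic_dispatch is illustrative.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Hypothetical CPU analogue of the "quick fix": every thread performs one
// relaxed atomic add directly on the shared counter.
int atomic_dispatch(std::atomic<int>& incremental, int threads) {
    assert(threads >= 0);
    std::vector<std::thread> grid;
    for (int tid = 0; tid < threads; ++tid) {
        grid.emplace_back([&incremental] {
            // The read-modify-write is indivisible, so no increment is lost,
            // but every thread contends for the same counter.
            incremental.fetch_add(1, std::memory_order_relaxed);
        });
    }
    // Stands in for waiting on the command buffer before reading the result.
    for (auto& t : grid) t.join();
    return incremental.load(std::memory_order_relaxed);
}
```

Three calls of atomic_dispatch(counter, 10) on a zeroed counter return 10, 20 and 30, the output the question expected. The cost is that all threads hit one location, which is what the threadgroup version avoids.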