
How to pass a float array to a Metal kernel function?


I don't want to use a texture1d_array. Can I simply pass an array of floats? I will write into it inside my kernel function.


Solution

  • To write to an array of floats inside a kernel function, supply a buffer argument to the kernel. The parameter should have the type device float * and carry a buffer attribute, [[buffer(n)]], specifying which slot of the buffer argument table it occupies:

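    #include <metal_stdlib>
    using namespace metal;

    kernel void my_kernel(device float *data [[buffer(0)]],
                          uint threadIndex [[thread_position_in_grid]])
    {
        // Placeholder computation: replace with the value for this element
        data[threadIndex] = float(threadIndex) * 2.0f;
    }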
    

    To create such a buffer in your app code, ask your Metal device to allocate it:

    // Room for dataCount floats; empty options select the default shared storage mode
    let buffer = device.makeBuffer(length: MemoryLayout<Float>.stride * dataCount,
                                   options: [])!
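Since the question is about passing a float array, the kernel may also need to start from existing values. In that case the buffer can be created as a copy of a Swift array. This is a minimal sketch, assuming a Metal-capable Mac or iOS device. The name makeFloatBuffer is a hypothetical helper. The Metal call that copies the bytes is makeBuffer(bytes:length:options:).

```swift
import Metal

// Hypothetical helper: creates a shared-storage MTLBuffer holding a copy of `values`.
// Returns nil for an empty array or if the device cannot allocate the buffer.
func makeFloatBuffer(device: MTLDevice, from values: [Float]) -> MTLBuffer? {
    guard !values.isEmpty else { return nil }
    // Byte length = per-element stride times element count.
    let length = MemoryLayout<Float>.stride * values.count
    // makeBuffer(bytes:length:options:) copies `length` bytes from the array's storage,
    // so the array does not need to outlive the buffer.
    return values.withUnsafeBytes { raw in
        device.makeBuffer(bytes: raw.baseAddress!, length: length, options: [])
    }
}

if let device = MTLCreateSystemDefaultDevice(),
   let buffer = makeFloatBuffer(device: device, from: [1, 2, 3, 4]) {
    print("Allocated \(buffer.length) bytes")
}
```

Because the bytes are copied at creation time, later changes to the Swift array are not seen by the kernel, and the kernel's writes do not change the array.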
    

    On the Mac, you might instead create the buffer with the .storageModeManaged option, in which the CPU and GPU may maintain separate copies of the buffer. A managed buffer does not automatically propagate the values your kernel writes back to CPU-readable memory. After the kernel, encode a blit command encoder's synchronize(resource:) call to copy them back from GPU memory. On iOS, managed buffers don't exist. No synchronization is required there beyond the usual care: never read a location while something else is writing to it.
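Here is a minimal sketch of that synchronization step. It assumes macOS and a command buffer in which the kernel has already been encoded. The name encodeReadback is a hypothetical helper. For buffers that are not managed, it encodes nothing.

```swift
import Metal

// Hypothetical helper: call after encoding the compute pass, before commit().
// Managed buffers (macOS only) may keep separate CPU and GPU copies, so the GPU's
// writes must be copied back explicitly; shared buffers need no extra command.
func encodeReadback(of buffer: MTLBuffer, in commandBuffer: MTLCommandBuffer) {
#if os(macOS)
    if buffer.storageMode == .managed,
       let blit = commandBuffer.makeBlitCommandEncoder() {
        // Makes the GPU's modifications visible to the CPU once the command buffer completes.
        blit.synchronize(resource: buffer)
        blit.endEncoding()
    }
#endif
}
```

The managed buffer itself would be created with device.makeBuffer(length: byteCount, options: .storageModeManaged). The CPU should read contents() only after the command buffer has completed.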

    When you're ready to dispatch your compute work, bind the buffer to your compute command encoder at the same index used in the kernel's [[buffer(0)]] attribute:

    computeCommandEncoder.setBuffer(buffer, offset: 0, index: 0)
    

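    Dispatch a grid sized to your data, typically one thread per element, so the kernel never writes past the end of the buffer. Encode any other work you need in your command buffer, including the synchronization command for managed buffers. Then end encoding and commit the command buffer. Make sure it has completed before reading the buffer's contents, for example with waitUntilCompleted() or a completion handler.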
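The following minimal end-to-end sketch puts these pieces together. It assumes a Mac or iOS device whose GPU supports non-uniform threadgroup sizes, which dispatchThreads requires. To stay self-contained, it compiles a hypothetical version of my_kernel from a source string using makeLibrary(source:options:); the placeholder computation is index times 2. An app would more typically load the kernel with makeDefaultLibrary(). The name runKernel is also hypothetical.

```swift
import Metal

// Hypothetical kernel: same signature as my_kernel, with a placeholder computation.
let kernelSource = """
#include <metal_stdlib>
using namespace metal;

kernel void my_kernel(device float *data [[buffer(0)]],
                      uint threadIndex [[thread_position_in_grid]])
{
    data[threadIndex] = float(threadIndex) * 2.0f;
}
"""

// Runs my_kernel over `dataCount` elements and returns a copy of the results,
// or nil if no Metal device (or queue/buffer/encoder) is available.
func runKernel(dataCount: Int) throws -> [Float]? {
    guard dataCount > 0,
          let device = MTLCreateSystemDefaultDevice(),
          let queue = device.makeCommandQueue(),
          let buffer = device.makeBuffer(length: MemoryLayout<Float>.stride * dataCount,
                                         options: [])
    else { return nil }

    // Compile the kernel and build a compute pipeline for it.
    let library = try device.makeLibrary(source: kernelSource, options: nil)
    guard let function = library.makeFunction(name: "my_kernel") else { return nil }
    let pipeline = try device.makeComputePipelineState(function: function)

    guard let commandBuffer = queue.makeCommandBuffer(),
          let encoder = commandBuffer.makeComputeCommandEncoder()
    else { return nil }
    encoder.setComputePipelineState(pipeline)
    encoder.setBuffer(buffer, offset: 0, index: 0)  // index matches [[buffer(0)]]

    // Exactly one thread per element, so the kernel never writes past the end.
    let grid = MTLSize(width: dataCount, height: 1, depth: 1)
    let group = MTLSize(width: min(pipeline.maxTotalThreadsPerThreadgroup, dataCount),
                        height: 1, depth: 1)
    encoder.dispatchThreads(grid, threadsPerThreadgroup: group)
    encoder.endEncoding()

    commandBuffer.commit()
    commandBuffer.waitUntilCompleted()  // block until the GPU has finished writing

    // Shared storage: the CPU can read the results directly once the GPU is done.
    let pointer = buffer.contents().bindMemory(to: Float.self, capacity: dataCount)
    return Array(UnsafeBufferPointer(start: pointer, count: dataCount))
}

if let results = try runKernel(dataCount: 8) {
    print(results)
}
```

Using waitUntilCompleted() keeps the sketch simple but blocks the calling thread. In an app, addCompletedHandler(_:) avoids stalling the main thread.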

    To read the results, wrap the buffer's contents() pointer in an UnsafeMutableBufferPointer of the appropriate element type, which lets you treat the buffer like any other collection:

    let data = UnsafeMutableBufferPointer<Float>(start: buffer.contents().assumingMemoryBound(to: Float.self),
                                                 count: dataCount)
    // Iterate over the elements of data, e.g. with for-in, map, or reduce
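Because contents() is just an UnsafeMutableRawPointer, the reading pattern can be tried without a GPU. In this hedged sketch, a plain raw allocation stands in for the buffer's memory. That memory has not been bound to a type yet, so the sketch uses bindMemory(to:capacity:) rather than assumingMemoryBound(to:). The names summarize and demo are hypothetical.

```swift
// Hypothetical helper: UnsafeMutableBufferPointer is a RandomAccessCollection,
// so ordinary Sequence/Collection APIs such as reduce and max work on it.
func summarize(_ data: UnsafeMutableBufferPointer<Float>) -> (sum: Float, max: Float?) {
    return (sum: data.reduce(0, +), max: data.max())
}

// Hypothetical demo: a raw allocation stands in for buffer.contents().
func demo() -> (sum: Float, max: Float?, copy: [Float]) {
    let dataCount = 4
    let raw = UnsafeMutableRawPointer.allocate(byteCount: MemoryLayout<Float>.stride * dataCount,
                                               alignment: MemoryLayout<Float>.alignment)
    defer { raw.deallocate() }

    // Fresh memory has no bound type yet, so bind it to Float before typed access.
    let typed = raw.bindMemory(to: Float.self, capacity: dataCount)
    let data = UnsafeMutableBufferPointer(start: typed, count: dataCount)

    // Play the role of the kernel: element i receives i * 2.
    for i in data.indices {
        data[i] = Float(i) * 2
    }

    let stats = summarize(data)
    // Copy into an Array so the values outlive the underlying memory.
    return (sum: stats.sum, max: stats.max, copy: Array(data))
}

let result = demo()
print(result.sum, result.copy)  // 12.0 [0.0, 2.0, 4.0, 6.0]
```

Copying into an Array, as demo does, is a simple way to keep the results after the MTLBuffer is released or reused for the next dispatch.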