How to pass a scalar parameter to a Metal kernel function?


I am new to Metal. I want to use Metal compute to do some math, so I created a kernel function (shader?), for example:

    kernel void foo(device float *data1,
                    device float *data2,
                    device float *result,
                    int flag,
                    uint index [[thread_position_in_grid]])
    {
      if(flag==SOMETHING)
      {
      }...
    }

How can I encode a scalar value for the flag parameter using MTLComputeCommandEncoder?


Solution

  • You are already most of the way there. As far as binding is concerned, there isn't much difference between a void* buffer with "arbitrary" data and a single int.

    Just make the parameter a reference in the device or constant address space (since it's a flag, I would assume constant is more suitable) and decorate it with a [[buffer(n)]] attribute. Adding the attribute to the other buffer bindings as well improves readability, so the new function signature would look like this:

    kernel void foo(device float *data1 [[buffer(0)]],
                    device float *data2 [[buffer(1)]],
                    device float *result [[buffer(2)]],
                    constant int& flag [[buffer(3)]],
                    uint index [[thread_position_in_grid]])
    

    On the host side, you can bind the value with either setBuffer or setBytes on your MTLComputeCommandEncoder. For a single scalar, setBytes is the easiest option, because it copies the value directly and needs no separate MTLBuffer:

    id<MTLComputeCommandEncoder> encoder = ...
    // ...
    int flag = SomeFlag | SomeOtherFlag;
    [encoder setBytes:&flag length:sizeof(flag) atIndex:3];
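
    Putting the pieces together, here is a minimal end-to-end sketch in Objective-C, not a definitive implementation. It compiles the kernel from a source string at run time and passes the flag with setBytes. The FLAG_ADD constant, the add/multiply kernel body, the sample data, and the variable names are assumptions made up for illustration; only the signature of foo and the setBytes call come from the answer above.

    ```objective-c
    #import <Foundation/Foundation.h>
    #import <Metal/Metal.h>

    // Kernel source, compiled at run time. FLAG_ADD and the body are hypothetical.
    static NSString *const kSource =
        @"#include <metal_stdlib>\n"
        @"using namespace metal;\n"
        @"constant int FLAG_ADD = 1;\n"
        @"kernel void foo(device const float *data1 [[buffer(0)]],\n"
        @"                device const float *data2 [[buffer(1)]],\n"
        @"                device float *result      [[buffer(2)]],\n"
        @"                constant int &flag        [[buffer(3)]],\n"
        @"                uint index [[thread_position_in_grid]])\n"
        @"{\n"
        @"    if ((flag & FLAG_ADD) != 0) {\n"
        @"        result[index] = data1[index] + data2[index];\n"
        @"    } else {\n"
        @"        result[index] = data1[index] * data2[index];\n"
        @"    }\n"
        @"}\n";

    int main(void) {
        @autoreleasepool {
            id<MTLDevice> device = MTLCreateSystemDefaultDevice();
            if (!device) { NSLog(@"No Metal device available"); return 1; }

            // Compile the kernel and build a compute pipeline for it.
            NSError *error = nil;
            id<MTLLibrary> library = [device newLibraryWithSource:kSource options:nil error:&error];
            if (!library) { NSLog(@"Shader compile failed: %@", error); return 1; }
            id<MTLFunction> function = [library newFunctionWithName:@"foo"];
            id<MTLComputePipelineState> pipeline =
                [device newComputePipelineStateWithFunction:function error:&error];
            if (!pipeline) { NSLog(@"Pipeline creation failed: %@", error); return 1; }

            // Sample input data (made up for this sketch).
            float a[] = {1, 2, 3, 4};
            float b[] = {10, 20, 30, 40};
            NSUInteger count = sizeof(a) / sizeof(a[0]);
            id<MTLBuffer> bufA = [device newBufferWithBytes:a length:sizeof(a)
                                                    options:MTLResourceStorageModeShared];
            id<MTLBuffer> bufB = [device newBufferWithBytes:b length:sizeof(b)
                                                    options:MTLResourceStorageModeShared];
            id<MTLBuffer> bufR = [device newBufferWithLength:sizeof(a)
                                                     options:MTLResourceStorageModeShared];

            id<MTLCommandQueue> queue = [device newCommandQueue];
            id<MTLCommandBuffer> commandBuffer = [queue commandBuffer];
            id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
            [encoder setComputePipelineState:pipeline];
            [encoder setBuffer:bufA offset:0 atIndex:0];
            [encoder setBuffer:bufB offset:0 atIndex:1];
            [encoder setBuffer:bufR offset:0 atIndex:2];

            // The scalar: setBytes copies the value, so no MTLBuffer is needed.
            int flag = 1; // matches the hypothetical FLAG_ADD in the kernel
            [encoder setBytes:&flag length:sizeof(flag) atIndex:3];

            // One threadgroup with one thread per element.
            [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1)
                    threadsPerThreadgroup:MTLSizeMake(count, 1, 1)];
            [encoder endEncoding];
            [commandBuffer commit];
            [commandBuffer waitUntilCompleted];

            // Shared storage lets the CPU read the results directly.
            const float *out = bufR.contents;
            for (NSUInteger i = 0; i < count; i++) {
                NSLog(@"result[%lu] = %f", (unsigned long)i, out[i]);
            }
        }
        return 0;
    }
    ```

    On a Mac, this should build with something like clang -fobjc-arc main.m -framework Foundation -framework Metal -framework CoreGraphics. Because flag has the FLAG_ADD bit set, the logged results should be the element-wise sums of the two inputs. setBytes is meant for small, one-off data (Apple documents it for data under 4 KB); for anything larger, or for data reused across many dispatches, creating an MTLBuffer and binding it with setBuffer is the better fit.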