Search code examples
Tags: ios, xcode, shader, metal, metalkit

How do I define a Metal shader with a dynamic buffer declaration?


I have this Metal shader:

// Vertex layout read by vertexShader: a packed 3-component float position
// followed by a packed 4-component unsigned-char color.
struct InVertex {
  packed_float3 pos;
  packed_uchar4 color;
};

// Vertex function: fetches the InVertex at index vertexId from the array bound
// at vertex buffer index 0. The rest of the body is elided in the question.
vertex ProjectedVertex vertexShader(const device InVertex *vertexArray [[buffer(0)]],
                                    const unsigned int vertexId [[vertex_id]]){

  InVertex in = vertexArray[vertexId];
  ....

}

However, I would like to make the buffer declaration "dynamic", i.e. I would like my shader to be able to handle different buffer types, for example:

// Layout 1: packed 3-component float position + packed 4-component uchar color.
struct InVertex1 {
  packed_float3 pos;
  packed_uchar4 color;
};

// Layout 2: full (16-byte aligned) float4 position and float4 color.
struct InVertex2 {
  float4 pos;
  float4 color;
};

// Layout 3: float4 position, float4 tangent and float4 color.
struct InVertex3 {
  float4 pos;
  float4 tangent;
  float4 color;
};

etc.

So I would like something like this:

// Pseudo-code sketch from the question: `????` stands for "one of the InVertexN
// types", to be chosen at run time from vertexType.
// NOTE(review): as written this is not valid MSL. A comma is missing after
// [[vertex_id]], the conditions need parentheses and `==` instead of `=`, and a
// [[buffer(n)]] argument must be a pointer or reference
// (e.g. `constant int &vertexType [[buffer(1)]]`).
vertex ProjectedVertex vertexShader(const device ???? *vertexArray [[buffer(0)]],
                                    const unsigned int vertexId [[vertex_id]]
                                    const device int vertexType [[buffer(1)]] ){

  if vertexType = InVertex1Type {
    ... handle InVertex1 type ...
  }
  else if vertexType = InVertex2Type {  
    ... handle InVertex2 type ...
  }
  else if vertexType = InVertex3Type {  
    ... handle InVertex3 type ...
  }

}

Solution

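  • The Metal Shading Language is a C++14-based specification with extensions and restrictions. Taking this into account, you can do the following. (Note that this approach passes a single instance of each vertex layout in one constant buffer and lets the shader pick one at run time; it does not by itself read a per-vertex array indexed by vertex_id.)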

    First, create a header file called ShaderTypes.h:

    //  Header containing types and enum constants shared between Metal shaders and Swift/ObjC source.
    //  The CPU writes these structs into an MTLBuffer that the shader reads directly, so every
    //  type here must have the same size and member offsets on both sides.
    #ifndef ShaderTypes_h
    #define ShaderTypes_h
    
    // Provides vector_float4 (and the host-side packed vector types). It can be included
    // from both .metal and C/Objective-C files, so this header no longer depends on
    // what the including file happened to import first.
    #include <simd/simd.h>
    
    #ifndef __METAL_VERSION__
    /// 96-bit 3 component float vector type (host-side mirror of MSL packed_float3,
    /// which simd does not provide).
    typedef struct __attribute__ ((packed)) packed_float3 {
        float x;
        float y;
        float z;
    } packed_float3;
    #endif
    
    /// Compact vertex: packed float3 position followed by packed uchar4 color.
    typedef struct InVertex1
    {
        packed_float3 pos;
        packed_uchar4 color;
    } InVertex1;
    
    /// Vertex with float4 position and float4 color.
    typedef struct InVertex2
    {
        vector_float4 pos;
        vector_float4 color;
    } InVertex2;
    
    /// Vertex with float4 position, tangent and color.
    typedef struct InVertex3
    {
        vector_float4 pos;
        vector_float4 tangent;
        vector_float4 color;
    } InVertex3;
    
    /// Selects which member of dynamicStruct the shader reads.
    /// The explicit `int` underlying type pins the enum to 4 bytes on both CPU and GPU,
    /// and the typedef lets plain C/Objective-C spell the type as `VertexType`.
    typedef enum VertexType : int {InVertex1Type = 0, InVertex2Type = 1, InVertex3Type = 2} VertexType;
    
    /// One instance of every vertex layout plus a tag telling the shader which one to use.
    typedef struct
    {
        // `struct X X;` rather than `X X;`: the elaborated type name stays unambiguous
        // even though the member reuses it, which strict C++ compilers require.
        struct InVertex1 InVertex1;
        struct InVertex2 InVertex2;
        struct InVertex3 InVertex3;
        VertexType vertexType;
    } dynamicStruct;
    
    #endif /* ShaderTypes_h */
    

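    In your renderer class, add the following (the file must be compiled as Objective-C++, i.e. a .mm file, because it assigns brace-enclosed lists such as `{0, 0, 0, 0}` to struct members):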

        // Include header shared between C code here, which executes Metal API commands, and .metal files
        #import "ShaderTypes.h"
    
        // Shared-storage buffer holding a single dynamicStruct; drawInMTKView binds it at
        // vertex buffer index 0. The brace assignments in UpdateBuffer are C++ syntax, so
        // this file has to be compiled as Objective-C++ (.mm).
        id <MTLBuffer> _dynamicBuffer;

        // Create your dynamic buffer: room for exactly one dynamicStruct in memory shared
        // between CPU and GPU.
        void InitBuffer(id<MTLDevice> device)
        {
            _dynamicBuffer = [device newBufferWithLength:sizeof(dynamicStruct) options:MTLResourceStorageModeShared];
        }

        // Update your dynamic buffer: write vertex data and select which struct the shader uses.
        // Does nothing when no buffer exists yet (InitBuffer not called, or allocation failed):
        // messaging nil would return a NULL contents pointer and the writes below would crash.
        void UpdateBuffer(void)
        {
            if (_dynamicBuffer == nil) {
                return;
            }

            dynamicStruct *ds = (dynamicStruct *)_dynamicBuffer.contents;

            ds->InVertex1.color = {0, 0, 0, 0};
            ds->InVertex2.pos = {0, 1, 1, 1};
            ds->InVertex3.tangent = {1, 1, 1, 1};
            // Select specific struct; the unqualified enumerator is valid in both C and C++.
            ds->vertexType = InVertex2Type;
        }
    
    /// MTKViewDelegate draw callback (most of the body is elided here).
    /// Binds _dynamicBuffer at vertex buffer index 0, matching the [[ buffer(0) ]]
    /// attribute on vertexShader's dynamicStruct argument.
    /// NOTE(review): _dynamicBuffer uses shared storage and is written by the CPU. If it is
    /// rewritten every frame while earlier frames are still executing on the GPU, consider
    /// one buffer per in-flight frame — confirm how often UpdateBuffer runs.
    - (void)drawInMTKView:(nonnull MTKView *)view
    {
    
        ...
    
         // Pass your dynamic buffer to the shader.
        [renderEncoder setVertexBuffer:_dynamicBuffer offset:0 atIndex:0];
    
        ...
    }
    

    Finally, in your shader file (.metal):

    // Including header shared between this Metal shader code and Swift/C code executing Metal API commands
    #import "ShaderTypes.h"
    
    // Vertex function: reads one dynamicStruct from constant buffer index 0 and copies the
    // InVertexN member selected by its vertexType tag into a local (rest of body elided).
    // NOTE(review): vertexId is never used, so every vertex of the draw sees the same
    // struct; reading a per-vertex array would require binding an array and indexing it
    // with vertexId. The parameter named `dynamicStruct` hides the type of the same name
    // inside this function. v1/v2/v3 stay uninitialized on the branches that don't assign them.
    vertex ProjectedVertex vertexShader(constant dynamicStruct & dynamicStruct[[ buffer(0) ]],
                                        const unsigned int vertexId [[vertex_id]])
    {
    
        InVertex1 v1;
        InVertex2 v2;
        InVertex3 v3;
    
        // Copy only the member the host selected via vertexType.
        if(dynamicStruct.vertexType == VertexType::InVertex1Type)
        {
            v1 = dynamicStruct.InVertex1;
        }
        else if(dynamicStruct.vertexType == VertexType::InVertex2Type)
        {
            v2 = dynamicStruct.InVertex2;
        }
        else if(dynamicStruct.vertexType == VertexType::InVertex3Type)
        {
            v3 = dynamicStruct.InVertex3;
        }
      ....
    
    }