I'm trying to build a stitched Metal kernel. My shader code is:
// Stitchable graph node: yields the sum of its two scalar inputs.
[[stitchable]] float add(float a, float b) {
    const float sum = a + b;
    return sum;
}
// Stitchable graph node: reads element `index` from a buffer in the
// constant address space.
[[stitchable]] float load(constant float *a, uint32_t index) {
    return *(a + index);
}
// Stitchable graph node: writes `value` into element `index` of a buffer in
// the device address space. Produces no result.
[[stitchable]] void store(device float *a, float value, uint32_t index) {
    device float *slot = a + index;
    *slot = value;
}
// Declaration of the visible function produced by the stitching graph
// "two_inputs": c[tid] = a[tid] + b[tid]. The output buffer needs an explicit
// element type (`device float *`) to match the `store` node's parameter.
[[visible]] void two_inputs(constant float *a, constant float *b, device float *c, uint32_t tid);
The driver code is:
// Set up the device, a command queue, and the default library, which contains
// the [[stitchable]] functions `add`, `load`, and `store`.
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
id<MTLCommandQueue> queue = [device newCommandQueue];
id<MTLLibrary> library = [device newDefaultLibrary];
NSArray<id<MTLFunction>> *functions = @[
    [library newFunctionWithName:@"add"],
    [library newFunctionWithName:@"load"],
    [library newFunctionWithName:@"store"]
];

// Input nodes refer to the arguments of the stitched function
// two_inputs(a, b, c, tid) by argument index.
MTLFunctionStitchingInputNode *srcA = [[MTLFunctionStitchingInputNode alloc] initWithArgumentIndex:0];
MTLFunctionStitchingInputNode *srcB = [[MTLFunctionStitchingInputNode alloc] initWithArgumentIndex:1];
MTLFunctionStitchingInputNode *srcC = [[MTLFunctionStitchingInputNode alloc] initWithArgumentIndex:2];
MTLFunctionStitchingInputNode *srcI = [[MTLFunctionStitchingInputNode alloc] initWithArgumentIndex:3];

// Build c[tid] = a[tid] + b[tid]. Each node's name must match the stitchable
// function it invokes.
MTLFunctionStitchingFunctionNode *loadA =
    [[MTLFunctionStitchingFunctionNode alloc] initWithName:@"load"
                                                 arguments:@[srcA, srcI]
                                       controlDependencies:@[]];
MTLFunctionStitchingFunctionNode *loadB =
    [[MTLFunctionStitchingFunctionNode alloc] initWithName:@"load"
                                                 arguments:@[srcB, srcI]
                                       controlDependencies:@[]];
MTLFunctionStitchingFunctionNode *add =
    [[MTLFunctionStitchingFunctionNode alloc] initWithName:@"add"
                                                 arguments:@[loadA, loadB]
                                       controlDependencies:@[]];
MTLFunctionStitchingFunctionNode *storeC =
    [[MTLFunctionStitchingFunctionNode alloc] initWithName:@"store"
                                                 arguments:@[srcC, add, srcI]
                                       controlDependencies:@[]];

// Every function node, including the side-effecting store, goes in `nodes`.
// two_inputs returns void, so the graph has no output node.
MTLFunctionStitchingGraph *graph =
    [[MTLFunctionStitchingGraph alloc] initWithFunctionName:@"two_inputs"
                                                      nodes:@[loadA, loadB, add, storeC]
                                                 outputNode:nil
                                                 attributes:@[]];

MTLStitchedLibraryDescriptor *graphDescriptor = [[MTLStitchedLibraryDescriptor alloc] init];
graphDescriptor.functions = functions;
graphDescriptor.functionGraphs = @[graph];

// Pass the descriptor itself (not its `functions` array), and judge success
// by the returned library rather than by the error pointer.
NSError *error = nil;
id<MTLLibrary> graphLibrary = [device newLibraryWithStitchedDescriptor:graphDescriptor
                                                                 error:&error];
if (graphLibrary == nil) {
    NSLog(@"Failed to create stitched library: %@", error);
}
This causes the Metal compiler to fail with the following error:
Compiler failed with XPC_ERROR_CONNECTION_INTERRUPTED
MTLCompiler: Compilation failed with XPC_ERROR_CONNECTION_INTERRUPTED on 1 try
...
Error Domain=MTLLibraryErrorDomain Code=3 "Compiler encountered an internal error" UserInfo={NSLocalizedDescription=Compiler encountered an internal error}
I'm running this on an M1 Mac.
It turns out that the Metal compiler was crashing because of this line:
// As originally written (crashes the compiler): storeC is passed as outputNode but is not listed in nodes.
MTLFunctionStitchingGraph *graph = [[MTLFunctionStitchingGraph alloc] initWithFunctionName:@"two_inputs" nodes:@[loadA, loadB, add] outputNode:storeC attributes:@[]];
should be
// storeC is listed in nodes like every other function node; two_inputs returns void, so there is no output node.
MTLFunctionStitchingGraph *graph = [[MTLFunctionStitchingGraph alloc] initWithFunctionName:@"two_inputs"
                                                                                      nodes:@[loadA, loadB, add, storeC]
                                                                                 outputNode:nil
                                                                                 attributes:@[]];
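instead. The store node's result is not consumed by any other node, so it has to be listed in `nodes` explicitly, and because `two_inputs` returns `void` there is no output node (`nil`). Also make sure that `loadB` reads from `srcB` (not `srcA`), that the add and store nodes are created with the names `@"add"` and `@"store"` (not `@"load"`), that the `c` parameter of `two_inputs` is declared as `device float *c`, and that `newLibraryWithStitchedDescriptor:error:` is passed `graphDescriptor` itself rather than `graphDescriptor.functions`.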