Tags: ios, swift, neural-network, conv-neural-network, metal

MPSCNN Weight Ordering


The Metal Performance Shaders framework supports building your own convolutional neural networks. For instance, creating an MPSCNNConvolution requires a 4D weight tensor as an init parameter, and that tensor is passed as a 1D Float pointer:

init(device: MTLDevice,
  convolutionDescriptor: MPSCNNConvolutionDescriptor,
  kernelWeights: UnsafePointer<Float>,
  biasTerms: UnsafePointer<Float>?,
  flags: MPSCNNConvolutionFlags)

The documentation has this to say about the 4D tensor:

The layout of the filter weight is arranged so that it can be reinterpreted as a 4D tensor (array) weight[outputChannels][kernelHeight][kernelWidth][inputChannels/groups]

Unfortunately, that information doesn't tell me how to flatten a 4D array into a one-dimensional Float pointer.

I tried ordering the weights the way the BNNS counterpart requires, but without luck.

How do I properly represent the 4D tensor (array) as a 1D Float pointer (array)?

PS: I tried arranging it like a C array and passing a pointer to the flattened array, but that didn't work either.
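For reference, the documented notation weight[outputChannels][kernelHeight][kernelWidth][inputChannels/groups] describes ordinary row-major (C-order) layout, so the input channel is the index that varies fastest in the flat buffer. The sketch below illustrates that layout with plain Python lists, assuming groups == 1; the helper names flat_offset and flatten_ohwi are hypothetical, not part of any Apple API.

```python
# Sketch: row-major (C-order) flattening of weight[oC][kH][kW][iC].
# Assumes groups == 1; flat_offset/flatten_ohwi are hypothetical helpers.


def flat_offset(o, ky, kx, i, kernel_height, kernel_width, input_channels):
    """Offset of weight[o][ky][kx][i] in the flat Float buffer."""
    return ((o * kernel_height + ky) * kernel_width + kx) * input_channels + i


def flatten_ohwi(weight):
    """Flatten a nested list indexed as weight[o][ky][kx][i] into a 1D list.

    Iterating the nested lists outermost-first yields C order, so the
    input channel (last index) varies fastest.
    """
    return [value
            for per_output in weight
            for row in per_output
            for pixel in row
            for value in pixel]


if __name__ == "__main__":
    out_c, k_h, k_w, in_c = 2, 3, 3, 4
    # Give each element a unique value that encodes its 4D index.
    weight = [[[[o * 1000 + ky * 100 + kx * 10 + i
                 for i in range(in_c)]
                for kx in range(k_w)]
               for ky in range(k_h)]
              for o in range(out_c)]
    flat = flatten_ohwi(weight)
    # Every element lands exactly at the offset given by the formula.
    for o in range(out_c):
        for ky in range(k_h):
            for kx in range(k_w):
                for i in range(in_c):
                    off = flat_offset(o, ky, kx, i, k_h, k_w, in_c)
                    assert flat[off] == weight[o][ky][kx][i]
    print(len(flat), flat[:4])
```

In other words, a C array declared with these four dimensions already has the right memory layout; what matters is that the values are placed at the right [o][ky][kx][i] positions to begin with.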

UPDATE

@RhythmicFistman: This is how I stored it in a plain array, which I can convert to an UnsafePointer<Float> (but it doesn't work):

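// Flat buffer laid out as weight[outputChannels][kernelHeight][kernelWidth][inputChannels]
var output = Array<Float>(repeating: 0, count: weights.count)

for o in 0..<outputChannels {
    for ky in 0..<kernelHeight {
        for kx in 0..<kernelWidth {
            for i in 0..<inputChannels {
                // Row-major (C-order) offset: the input channel varies fastest
                let offset = ((o * kernelHeight + ky) * kernelWidth + kx) * inputChannels + i
                output[offset] = ...
            }
        }
    }
}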
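The destination offset in the loop above matches the documented layout; what also has to be right is where each value is read from. Assuming the source weights come from TensorFlow in [kH kW iC oC] order (the order the solution to this question works from), the repacking can be sketched with pure index arithmetic. hwio_to_ohwi is a hypothetical helper, and the source offset formula assumes the TensorFlow buffer is itself stored row-major.

```python
# Sketch: repack a flat buffer in TensorFlow's [kH][kW][iC][oC] order into
# the [oC][kH][kW][iC] order MPSCNN documents. hwio_to_ohwi is hypothetical.


def hwio_to_ohwi(src, kernel_height, kernel_width, input_channels, output_channels):
    expected = kernel_height * kernel_width * input_channels * output_channels
    if len(src) != expected:
        raise ValueError(f"expected {expected} weights, got {len(src)}")
    dst = [0.0] * expected
    for o in range(output_channels):
        for ky in range(kernel_height):
            for kx in range(kernel_width):
                for i in range(input_channels):
                    # Destination: same row-major offset as the Swift loop above.
                    dst_off = ((o * kernel_height + ky) * kernel_width + kx) * input_channels + i
                    # Source: TensorFlow layout, where the output channel varies fastest.
                    src_off = ((ky * kernel_width + kx) * input_channels + i) * output_channels + o
                    dst[dst_off] = src[src_off]
    return dst


# 1x1 kernel, 2 input channels, 3 output channels; each value encodes 10*i + o.
print(hwio_to_ohwi([0, 1, 2, 10, 11, 12], 1, 1, 2, 3))
```

Reading with the source formula and writing with the destination formula is the flat-buffer equivalent of transposing the axes, so the same loop structure carries over to Swift unchanged.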

Solution

  • OK, so I figured it out. Here are the two Python functions I use to reorder the weights of my convolution and fully connected layers:

    import numpy as np

    # Shape required for MPSCNN: [oC kH kW iC]
    # TensorFlow order is:       [kH kW iC oC]
    def convshape(a):
        a = np.swapaxes(a, 2, 3)  # -> [kH kW oC iC]
        a = np.swapaxes(a, 1, 2)  # -> [kH oC kW iC]
        a = np.swapaxes(a, 0, 1)  # -> [oC kH kW iC]
        return a
    
    # Fully connected weights only need the two axes swapped
    # (TensorFlow [in out] -> [out in])
    def fullshape(a):
        a = np.swapaxes(a, 0, 1)
        return a
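To double-check the chain of swaps in convshape, the sketch below tracks axis labels through the same three calls using plain lists instead of NumPy, and derives the single permutation that would perform the same reordering in one step. swap_axes and convshape_axes are hypothetical helpers that mirror the functions above.

```python
# Sketch: follow axis labels through the swapaxes calls used in convshape().


def swap_axes(axes, i, j):
    """Return a copy of the axis-label list with positions i and j swapped."""
    axes = list(axes)
    axes[i], axes[j] = axes[j], axes[i]
    return axes


def convshape_axes(axes):
    # Mirrors convshape(): swap (2, 3), then (1, 2), then (0, 1).
    axes = swap_axes(axes, 2, 3)
    axes = swap_axes(axes, 1, 2)
    axes = swap_axes(axes, 0, 1)
    return axes


tf_axes = ["kH", "kW", "iC", "oC"]
mps_axes = convshape_axes(tf_axes)
# For each destination axis, the index of the source axis it comes from:
# this is the `axes` argument np.transpose would need for the same result.
perm = tuple(tf_axes.index(name) for name in mps_axes)
print(mps_axes, perm)  # ['oC', 'kH', 'kW', 'iC'] (3, 0, 1, 2)
```

So convshape(a) is equivalent to np.transpose(a, (3, 0, 1, 2)). Because both np.swapaxes and np.transpose return views rather than copies, it is safest to materialize a C-contiguous float32 copy (for example np.ascontiguousarray(a, dtype=np.float32)) before handing the raw memory to Metal.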