
How to use segmentation model output tensor?


I'm trying to run the segmentation model on iOS and I have several questions about how I should properly use the output tensor.

Here is the link to the model I'm using: https://www.tensorflow.org/lite/models/segmentation/overview

When I run this model I get an output tensor with dimensions 1 x 257 x 257 x 21. Why do I get 21 as the last dimension? It looks like we get a score for each class for each pixel. Do we need to take the argmax here to get the correct class value?

But why only 21 classes? I was expecting more. And where can I find which index corresponds to which class? The ImageClassification example has a label.txt with 1001 classes.

Based on the ImageClassification example, I made an attempt to parse the tensor: first I transform it into a Float array of size 1,387,029 (21 x 257 x 257), and then I create an image pixel by pixel using the following code:

    // size = 257  (output height/width)
    // depth = 21  (number of classes)
    // array = flat Float array of 1,387,029 values (257 * 257 * 21)
    for i in 0..<size {
        for j in 0..<size {
            // Collect the 21 class scores for pixel (i, j).
            var scores: [Float] = []
            for k in 0..<depth {
                let index = i * size * depth + j * depth + k
                let score = array[index]
                scores.append(score)
            }
            // The argmax over the class scores is the predicted class.
            if let maxScore = scores.max(),
                let maxClass = scores.firstIndex(of: maxScore) {
                let index = i * size + j

                if maxClass == 0 {
                    pixelBuffer[index] = .blue   // background
                } else if maxClass == 12 {
                    pixelBuffer[index] = .black
                } else {
                    pixelBuffer[index] = .green
                }
            }
        }
    }
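The inner argmax can also be computed without allocating a temporary `scores` array for every pixel. A minimal sketch, assuming the same flat HWC layout as the loop above (`argmaxClass` is an illustrative name, not part of the TensorFlow Lite API):

```swift
// Argmax over the `depth` class scores for pixel (i, j) in a flat
// Float array laid out as [height][width][classes], i.e.
// index = i * size * depth + j * depth + k.
func argmaxClass(array: [Float], i: Int, j: Int, size: Int, depth: Int) -> Int {
    let base = i * size * depth + j * depth
    var best = 0
    var bestScore = array[base]
    for k in 1..<depth {
        if array[base + k] > bestScore {
            bestScore = array[base + k]
            best = k
        }
    }
    return best
}
```

This avoids one array allocation and two linear scans (`max()` plus `firstIndex(of:)`) per pixel, which adds up over 257 x 257 pixels.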

Here is the result I get:

[screenshot: segmentation result produced by the code above]

You can see that the quality is not very good. What have I missed?

The Core ML segmentation model (https://developer.apple.com/machine-learning/models/) works much better on the same example:

[screenshot: segmentation result from the Core ML model]


Solution

  • It seems like your model was trained on the PASCAL VOC dataset, which has 21 classes for segmentation
    (20 object categories plus background). The classes, in index order, are:

    background
    aeroplane
    bicycle
    bird
    boat
    bottle
    bus
    car
    cat
    chair
    cow
    diningtable
    dog
    horse
    motorbike
    person
    pottedplant
    sheep
    sofa
    train
    tvmonitor
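So the last dimension of 21 is one score per VOC class, and the per-pixel argmax index maps directly into this list. A minimal sketch of that mapping in Swift (`vocLabels` and `predictedLabel` are illustrative names, assuming the standard VOC index order shown above):

```swift
// PASCAL VOC segmentation classes, index-aligned with the
// model's 21-channel output (index 0 = background).
let vocLabels = [
    "background", "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

// Given the 21 scores for one pixel, return the predicted class name.
func predictedLabel(for scores: [Float]) -> String? {
    guard let maxIndex = scores.indices.max(by: { scores[$0] < scores[$1] })
    else { return nil }
    return vocLabels[maxIndex]
}
```

Note that in this ordering, index 15 is "person", so a pixel whose highest score is at channel 15 belongs to a person.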