Tags: neural-network, cntk

CNTK neural network with a non-one-hot-vector output (multi-class classifier)


Thank you for the CNTK tool; the examples run pretty fast. For several days I have been trying to set up a simple network, but I can't get it to work. I need a network with 2 inputs and 3 outputs, for example:

|features 0.3 0.5 |labels 0.2 0.7 0.9
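The sample line above is in CNTK's text format (CTF), where each `|name` tag starts a named input stream. The minimal Python sketch below splits such a line into its streams, which makes it easy to see that `labels` is simply three dense real values rather than a one-hot vector. It assumes dense streams only (no leading sequence ID, no sparse `index:value` entries, no `|#` comments); `parse_ctf_line` is a hypothetical helper, not part of CNTK.

```python
def parse_ctf_line(line):
    """Split a dense CNTK text-format line such as
    '|features 0.3 0.5 |labels 0.2 0.7 0.9' into {stream name: values}.
    Assumes no sequence ID, no sparse 'index:value' entries, no comments."""
    streams = {}
    for chunk in line.split("|")[1:]:  # text before the first '|' is ignored
        parts = chunk.split()
        if not parts:
            continue
        name, values = parts[0], parts[1:]
        streams[name] = [float(v) for v in values]
    return streams


sample = "|features 0.3 0.5 |labels 0.2 0.7 0.9"
print(parse_ctf_line(sample))
# {'features': [0.3, 0.5], 'labels': [0.2, 0.7, 0.9]}
```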

The output is not a one-hot vector; the network has to learn the label values 0.2 0.7 0.9. Most examples use a one-hot vector as output, so it is not clear to me how to solve this. I tried to adapt the tutorial with 3 classes, but it does not work: the network does not learn the output correctly. The network I tried is:

BrainScriptNetworkBuilder = {

    SDim = 2     # feature dimension
    H1Dim = 50   # hidden dimension
    H2Dim = 50   # hidden dimension
    LDim = 3     # label dimension (3 real-valued outputs)

    model (features) = {
        W0 = ParameterTensor {(H1Dim:SDim)}  ; b0 = ParameterTensor {H1Dim}
        W1 = ParameterTensor {(H2Dim:H1Dim)} ; b1 = ParameterTensor {H2Dim}
        W2 = ParameterTensor {(LDim:H2Dim)}  ; b2 = ParameterTensor {LDim}

        r1 = ReLU(W0 * features + b0) # hidden layer 1
        r2 = ReLU(W1 * r1       + b1) # hidden layer 2
        z  = ReLU(W2 * r2       + b2) # output layer
    }.z

    # define inputs
    features = Input {SDim, sparse = false}
    labels   = Input {LDim, sparse = false} 

    # apply model to features
    z = model (features)

    # define criteria and output(s)
    ce  = SquareError(labels, z)  # criterion (loss)
    err = SquareError(labels, z)  # additional metric

    # connect to the system. These five variables must be named exactly like this.
    featureNodes    = (features)
    inputNodes      = (labels)
    criterionNodes  = (ce)
    evaluationNodes = (err)
    outputNodes     = (z)
}
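For reference, here is a plain-Python sketch (not CNTK code) of what the model above computes for a single sample: two ReLU hidden layers, a ReLU on the output as in the question, and a SquareError criterion, which to my understanding CNTK defines as the sum of squared differences over the output vector. The weight initialization is a hypothetical small uniform draw, not CNTK's `ParameterTensor` default; each `W` has shape (out, in), matching `ParameterTensor {(outDim:inDim)}`.

```python
import random


def relu(v):
    return [max(0.0, x) for x in v]


def affine(W, b, x):
    # W * x + b with W stored as a list of rows, shape (out, in)
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]


def init(out_dim, in_dim, rng):
    # Hypothetical init: small uniform weights, zero bias (not CNTK's default)
    W = [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return W, [0.0] * out_dim


def model(features, params):
    (W0, b0), (W1, b1), (W2, b2) = params
    r1 = relu(affine(W0, b0, features))  # hidden layer 1
    r2 = relu(affine(W1, b1, r1))        # hidden layer 2
    return relu(affine(W2, b2, r2))      # output layer, ReLU as in the question


def square_error(labels, z):
    # Sum of squared differences over the output vector
    return sum((t - zi) * (t - zi) for t, zi in zip(labels, z))


SDim, H1Dim, H2Dim, LDim = 2, 50, 50, 3
rng = random.Random(42)
params = [init(H1Dim, SDim, rng), init(H2Dim, H1Dim, rng), init(LDim, H2Dim, rng)]
z = model([0.3, 0.5], params)
print(z, square_error([0.2, 0.7, 0.9], z))
```

One detail this makes visible: because `z` passes through ReLU, an output unit whose pre-activation is negative is clamped to 0 and receives no gradient from the squared error. For real-valued targets, a linear output layer (`z = W2 * r2 + b2`) is a common alternative.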

So my question is: how do I set up a network in CNTK whose output is not a one-hot vector?

Thanks for your help.


Solution

  • When your labels are not one-hot vectors, SquareError is a good loss function to minimize, and you can still use it even if some examples happen to have one-hot labels. So I think you are doing everything right; you may just have to tune the learning rate to get it to work well.
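To illustrate the learning-rate point, here is a minimal pure-Python sketch (not CNTK code) that fits the single example from the question with plain SGD on the summed squared error. It assumes a smaller hypothetical network (one ReLU hidden layer of 8 units and a linear output) and hypothetical learning rates and step counts, so only the trend matters: with the same number of steps, a learning rate that is too small leaves the error far higher.

```python
import random


def train(lr, steps=500, hidden=8, seed=0):
    """Fit one example with plain SGD on sum((z - y)^2); return the final loss.
    Hypothetical setup: 2 inputs -> `hidden` ReLU units -> 3 linear outputs."""
    rng = random.Random(seed)
    x = [0.3, 0.5]          # features from the question
    y = [0.2, 0.7, 0.9]     # real-valued, non-one-hot labels
    W_hid = [[rng.uniform(-0.5, 0.5) for _ in x] for _ in range(hidden)]
    b_hid = [0.0] * hidden
    W_out = [[rng.uniform(-0.5, 0.5) for _ in range(hidden)] for _ in y]
    b_out = [0.0] * len(y)

    def forward():
        a = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W_hid, b_hid)]
        h = [max(0.0, v) for v in a]  # ReLU hidden layer
        z = [sum(w * hj for w, hj in zip(row, h)) + b for row, b in zip(W_out, b_out)]
        return a, h, z  # z is linear (no ReLU on the output)

    for _ in range(steps):
        a, h, z = forward()
        dz = [2.0 * (zk - yk) for zk, yk in zip(z, y)]  # gradient of the summed squared error
        # Backpropagate to the hidden layer before W_out is modified.
        dh = [sum(dz[k] * W_out[k][j] for k in range(len(y))) for j in range(hidden)]
        da = [dh[j] if a[j] > 0.0 else 0.0 for j in range(hidden)]  # ReLU derivative
        for k in range(len(y)):
            for j in range(hidden):
                W_out[k][j] -= lr * dz[k] * h[j]
            b_out[k] -= lr * dz[k]
        for j in range(hidden):
            for i in range(len(x)):
                W_hid[j][i] -= lr * da[j] * x[i]
            b_hid[j] -= lr * da[j]

    _, _, z = forward()
    return sum((zk - yk) * (zk - yk) for zk, yk in zip(z, y))


for lr in (0.001, 0.05):
    print(f"lr={lr}: final squared error {train(lr):.3e}")
```

In a CNTK BrainScript config, the corresponding knob is the learning-rate setting in the `SGD` block (for example `learningRatesPerSample`).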