
Machine learning multi-classification: Why use 'one-hot' encoding instead of a number


I'm currently working on a classification problem with TensorFlow. I'm new to the world of machine learning, and there is something I don't understand.

I have successfully trained models that output the y tensor like this:

y = [0,0,1,0]
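The target y = [0,0,1,0] is the class index 2 (counting from zero) written as a one-hot vector. Below is a minimal pure-Python sketch of that conversion. `one_hot` is a hypothetical helper written for illustration; in TensorFlow, `tf.one_hot(indices, depth)` performs the same mapping on tensors.

```python
def one_hot(index, num_classes):
    """Return a list with a 1 at position `index` and 0 everywhere else (hypothetical helper)."""
    if not 0 <= index < num_classes:
        raise ValueError("index out of range")
    vector = [0] * num_classes
    vector[index] = 1
    return vector

# Class index 2 (zero-based) out of 4 classes gives the target shown above.
y = one_hot(2, 4)
print(y)  # [0, 0, 1, 0]
```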

But I can't understand the principle behind it...

Why not just train the same model to output a class number, such as y = 3 or y = 4?

This seems much more flexible: I can imagine a multi-class classification problem with 2 million possible classes, and it would be much more efficient to output a single number between 0 and 2,000,000 than a tensor of 2,000,000 items for every result.
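To put the storage concern in numbers, here is a small worked calculation. It assumes each label value is stored as a 32-bit number (4 bytes); that width is an assumption for illustration, not something stated in the question.

```python
NUM_CLASSES = 2_000_000
BYTES_PER_VALUE = 4  # assumption: 32-bit floats for one-hot entries, 32-bit int for an index

one_hot_label_bytes = NUM_CLASSES * BYTES_PER_VALUE  # one value per class
integer_label_bytes = BYTES_PER_VALUE                # a single class index

print(one_hot_label_bytes)  # 8000000, i.e. about 8 MB per example
print(integer_label_bytes)  # 4
```

This only concerns how the target labels are stored; it says nothing yet about what the network's output layer should look like.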

What am I missing?


Solution

  • Ideally, you could train your model to classify input instances by producing a single output, something like:

    y=1 means input=dog, y=2 means input=airplane. That approach, however, brings a lot of problems:

    1. How do I interpret the output y=1.5?
    2. Why am I trying to regress a number as if I were working with continuous data, when in reality I'm working with discrete data?

    In fact, what you are doing is treating a multi-class classification problem like a regression problem. This is logically wrong (unless you're doing binary classification; in that case, a positive and a negative output are all you need).
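A small sketch of why regressing a class number misbehaves. The class ids below follow the answer's dog/airplane example, with a hypothetical third class "cat" added; squared error is used as a typical regression loss (an assumption, since the answer names no specific loss).

```python
# Hypothetical integer encoding of three classes.
CLASS_IDS = {"dog": 1, "airplane": 2, "cat": 3}

def squared_error(predicted, target):
    return (predicted - target) ** 2

true_id = CLASS_IDS["dog"]
# Mistaking a dog for an airplane vs. mistaking it for a cat:
err_airplane = squared_error(CLASS_IDS["airplane"], true_id)  # (2 - 1)^2 = 1
err_cat = squared_error(CLASS_IDS["cat"], true_id)            # (3 - 1)^2 = 4
print(err_airplane, err_cat)  # 1 4

# A prediction that falls between two ids maps to no class at all.
prediction = 1.5
print(prediction in CLASS_IDS.values())  # False
```

The regression loss punishes "cat" four times more than "airplane" purely because of the arbitrary numbering, and an in-between output such as 1.5 has no meaning.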

    To avoid these (and other) issues, we use a final layer with one neuron per class and train the network to give a high activation to the neuron of the correct class.

    The one-hot encoding represents the fact that you want to force your network to have a single high-activation output when a certain input is present.
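One common way to enforce this (an assumption here, since the answer does not name a loss) is a softmax over the final layer followed by cross-entropy against the one-hot target. The pure-Python sketch below shows that only the true class contributes to the loss; in TensorFlow, `tf.nn.softmax` and `tf.keras.losses.CategoricalCrossentropy` cover the same computation.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(one_hot_target, probs):
    # Only the term where the target is 1 contributes to the sum.
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs) if t)

logits = [2.0, 0.5, 0.1]  # hypothetical raw outputs for dog, airplane, cat
probs = softmax(logits)
target = [1, 0, 0]        # one-hot target for "dog"

loss = cross_entropy(target, probs)
# The loss equals -log of the probability assigned to the true class,
# so it shrinks only as the "dog" activation dominates the others.
assert math.isclose(loss, -math.log(probs[0]))
print(round(sum(probs), 6))  # 1.0
```

Because the loss reduces to -log p[true_index], the target can be kept as a plain integer index; `tf.keras.losses.SparseCategoricalCrossentropy` accepts integer labels this way, while the output layer still has one unit per class.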

    Thus, every input=dog will have 1, 0, 0 as its target output, and so on.

    In this way, you're correctly treating a discrete classification problem, producing an output that is discrete and easy to interpret. In fact, you'll always extract the output neuron with the highest activation using tf.argmax: even if your network hasn't learned to produce a perfect one-hot encoding, you'll still be able to pick the most likely correct output without ambiguity.
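A minimal sketch of that last step in pure Python: pick the index of the largest activation, even when the output is not a clean one-hot vector. The class order and output values are hypothetical; on tensors, `tf.argmax(output, axis=-1)` plays the same role.

```python
def argmax(values):
    """Index of the largest value; returns the first such index on ties."""
    return max(range(len(values)), key=lambda i: values[i])

CLASSES = ["dog", "airplane", "cat"]  # hypothetical class order
output = [0.15, 0.70, 0.15]           # not a perfect one-hot vector
predicted = argmax(output)
print(CLASSES[predicted])  # airplane
```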