Tags: machine-learning, classification, data-analysis, logistic-regression

How many learning curves should I plot for a multi-class logistic regression classifier?


If we have K classes, do I have to plot K learning curves? It seems impossible to me to calculate the train/validation error against all K theta vectors at once.

To clarify, a learning curve is a plot of the training and cross-validation/test set error (or cost) against training set size. This plot should allow you to see whether increasing the training set size improves performance. More generally, the learning curve lets you identify whether your algorithm suffers from a bias (underfitting) or a variance (overfitting) problem.
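As an illustration, here is a minimal sketch of such a plot. It assumes scikit-learn and matplotlib; the synthetic make_classification dataset, the LogisticRegression model, and the use of 1 - accuracy as the error are illustrative choices, not anything prescribed by the question.

```python
# Minimal learning-curve sketch (assumes scikit-learn and matplotlib are installed).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

# Synthetic 3-class problem standing in for the asker's K-class data.
X, y = make_classification(n_samples=2000, n_features=20, n_informative=10,
                           n_classes=3, random_state=0)

# learning_curve refits the model on increasing fractions of the training data
# and cross-validates each fit, which is exactly what a learning curve needs.
train_sizes, train_scores, val_scores = learning_curve(
    LogisticRegression(max_iter=1000), X, y,
    train_sizes=np.linspace(0.1, 1.0, 8), cv=5, scoring="accuracy")

# Convert accuracy into error so the plot matches the "error vs. training set size" framing.
train_error = 1.0 - train_scores.mean(axis=1)
val_error = 1.0 - val_scores.mean(axis=1)

plt.plot(train_sizes, train_error, "o-", label="training error")
plt.plot(train_sizes, val_error, "o-", label="cross-validation error")
plt.xlabel("training set size")
plt.ylabel("error (1 - accuracy)")
plt.legend()
plt.show()
```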


Solution

  • It depends. Learning curves do not concern themselves with the number of classes. As you said, a learning curve is a plot of training set and test set error, where that error is a single numerical value. That is all a learning curve is.

    That error can be anything you want: accuracy, precision, recall, F1 score, etc. (or even MAE, MSE and others for regression).

    However, the metric you choose has to be one that actually applies to your specific problem, and that choice in turn affects how you should use learning curves.

    Accuracy is well defined for any number of classes, so if you use this, a single plot should suffice.

    Precision and recall, however, are defined only for binary problems. You can generalize them to an extent by considering, for each class x, the binary problem with classes x and not-x (the one-vs-rest approach). In that case, you will probably want to plot a learning curve for each class; this also helps you identify problems affecting particular classes. A sketch of this is given at the end of this answer.

    If you want to read more about performance metrics, I like this paper a lot.
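The per-class idea above might look like the following sketch, again assuming scikit-learn. The one_vs_rest_recall helper, the choice of recall as the metric, and the synthetic 3-class dataset are all illustrative assumptions; precision or any other binary metric could be plugged in the same way.

```python
# Per-class learning curves: for each class c, the multi-class classifier is scored
# on the binary problem "c vs. not c" using recall (one-vs-rest relabelling).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.model_selection import learning_curve

# Synthetic 3-class dataset standing in for the asker's K-class data.
X, y = make_classification(n_samples=2000, n_features=20, n_informative=10,
                           n_classes=3, random_state=0)


def one_vs_rest_recall(class_label):
    """Build a scorer: recall for the binary problem 'class_label vs. the rest'."""
    def score(estimator, X_eval, y_eval):
        y_pred = estimator.predict(X_eval)
        return recall_score((y_eval == class_label).astype(int),
                            (y_pred == class_label).astype(int))
    return score


# One learning curve per class, all drawn on the same axes.
for c in np.unique(y):
    train_sizes, train_scores, val_scores = learning_curve(
        LogisticRegression(max_iter=1000), X, y,
        train_sizes=np.linspace(0.2, 1.0, 5), cv=5,
        scoring=one_vs_rest_recall(c))
    plt.plot(train_sizes, val_scores.mean(axis=1), "o-",
             label=f"class {c} (validation)")

plt.xlabel("training set size")
plt.ylabel("one-vs-rest recall")
plt.legend()
plt.show()
```

Plotting the per-class validation curves on one figure makes it easy to spot a class whose curve stays flat or low while the others improve with more data.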