I have trained a simple PyTorch neural network on some data, and now wish to test and evaluate it using metrics like accuracy, recall, F1 and precision. I searched the PyTorch documentation thoroughly and could not find any classes or functions for these metrics. I then tried converting the predicted labels and the actual labels to NumPy arrays and using scikit-learn's metrics, but the predicted labels are not 0 or 1 (my labels); instead they are continuous values. Because of this, the scikit-learn metrics don't work. The fast.ai documentation didn't make much sense either: I could not understand which class to inherit from for precision etc. (although I was able to calculate accuracy). Any help would be greatly appreciated.
Usually, in a binary classification setting, your neural network outputs the probability that the event occurs (e.g., if you use a sigmoid activation and a single neuron in the output layer), which is a continuous value between 0 and 1. To evaluate the precision and recall of your model (e.g., with scikit-learn's precision_score and recall_score), you need to convert your model's probabilities into binary labels. This is done by choosing a threshold: probabilities at or above it are labeled 1, the rest 0. (For an overview of thresholding, please take a look at this reference: https://developers.google.com/machine-learning/crash-course/classification/thresholding)
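As a minimal sketch of that conversion (not the only way to do it), the snippet below collects the model's outputs, thresholds them at 0.5 and passes the resulting 0/1 labels to scikit-learn. It assumes a binary classifier with one output per sample and a loader yielding (inputs, labels) batches; `collect_outputs` and `binary_metrics` are hypothetical helper names. It also assumes the network may end without a sigmoid (e.g., trained with `BCEWithLogitsLoss`), in which case its outputs are raw logits and `torch.sigmoid` must be applied before thresholding.

```python
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score


@torch.no_grad()
def collect_outputs(model, loader, device="cpu"):
    """Hypothetical helper: run the model over all (inputs, labels) batches.

    Assumes the model is already on `device` and returns one value per sample,
    i.e. shape (B,) or (B, 1).
    """
    model.eval()  # disable dropout, use running batch-norm statistics
    all_outputs, all_labels = [], []
    for inputs, labels in loader:
        outputs = model(inputs.to(device))
        all_outputs.append(outputs.reshape(-1).cpu())
        all_labels.append(labels.reshape(-1).cpu())
    return torch.cat(all_outputs), torch.cat(all_labels)


def binary_metrics(outputs, targets, threshold=0.5, from_logits=True):
    """Hypothetical helper: threshold continuous outputs into 0/1, then score them."""
    outputs = outputs.detach().cpu()
    # Raw logits need a sigmoid to become probabilities; pass from_logits=False
    # if the model already ends in a sigmoid layer.
    probs = torch.sigmoid(outputs) if from_logits else outputs
    # Probabilities at or above the threshold become class 1, the rest class 0.
    y_pred = (probs >= threshold).long().numpy()
    y_true = targets.detach().cpu().long().numpy()
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
    }


# Toy logits and labels standing in for collect_outputs(model, test_loader).
logits = torch.tensor([2.0, -1.0, 0.5, -3.0, 1.5, -0.2])
labels = torch.tensor([1, 0, 0, 0, 1, 1])
print(binary_metrics(logits, labels))
```

With a real model you would call `outputs, labels = collect_outputs(model, test_loader)` and then `binary_metrics(outputs, labels)`; the 0.5 threshold is only a default and can be replaced by one tuned as described next.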
Scikit-learn's precision_recall_curve (https://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html) is commonly used to understand how precision and recall behave across different probability thresholds. By analysing the precision and recall values per threshold, you can choose the best threshold for your problem: if you want higher precision, you will aim for higher thresholds (e.g., 0.9); if you want balanced precision and recall, you can pick the threshold that gives the best F1 score. A good overview of the topic can be found in the following reference: https://machinelearningmastery.com/threshold-moving-for-imbalanced-classification/
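Both strategies can be sketched with precision_recall_curve as follows. `best_f1_threshold` and `threshold_for_precision` are hypothetical helper names, and the inputs are assumed to be true 0/1 labels plus predicted probabilities (already passed through a sigmoid if the model outputs logits). Ideally the threshold is picked on a validation set and then reused on the test set, so the test metrics are not tuned to the test data.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def best_f1_threshold(y_true, probs):
    """Hypothetical helper: return (threshold, f1) for the threshold with the best F1."""
    precision, recall, thresholds = precision_recall_curve(y_true, probs)
    # precision/recall have one more entry than thresholds: the final
    # (precision=1, recall=0) point has no threshold, so drop it to align arrays.
    p, r = precision[:-1], recall[:-1]
    # F1 = 2PR / (P + R), defined as 0 where P + R == 0.
    f1 = np.divide(2 * p * r, p + r, out=np.zeros_like(p), where=(p + r) > 0)
    best = int(np.argmax(f1))
    return float(thresholds[best]), float(f1[best])


def threshold_for_precision(y_true, probs, min_precision=0.9):
    """Hypothetical helper: among thresholds reaching min_precision, pick the highest recall."""
    precision, recall, thresholds = precision_recall_curve(y_true, probs)
    ok = np.flatnonzero(precision[:-1] >= min_precision)
    if ok.size == 0:
        return None  # no threshold reaches the requested precision
    best = ok[np.argmax(recall[:-1][ok])]
    return float(thresholds[best])


# Toy validation labels and probabilities (illustrative only).
y_val = np.array([0, 0, 1, 1])
p_val = np.array([0.1, 0.4, 0.35, 0.8])
print(best_f1_threshold(y_val, p_val))
print(threshold_for_precision(y_val, p_val, min_precision=0.9))
```

scikit-learn treats a score as positive when it is greater than or equal to the threshold, so the returned value can be used directly as `probs >= threshold` when converting test-set probabilities to labels.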
I hope this helps.