
Deep learning models yielding high training accuracy but poor performance on testing data in binary text classification


I've been encountering a perplexing issue while working on a binary text classification task. Despite experimenting with numerous deep learning models, including various architectures and hyperparameters, I consistently observe high training accuracy, typically ranging from 97% to 99%. However, when I evaluate these models on unseen testing data, their performance significantly deteriorates.

To address this, I tried traditional (non-deep) machine learning models as an alternative. Surprisingly, models such as Random Forest performed as well as or better than the deep learning models, reaching around 97% accuracy on both the training and the testing data. I then tried several other machine learning algorithms, and Logistic Regression turned out to be the best fit for my use case.

Despite these findings, I am still puzzled as to why the deep learning models, which reach impressive training accuracy, fail to generalize to unseen data. Could someone explain the likely reasons for this discrepancy? Are there common pitfalls or deep-learning-specific considerations that I might be overlooking? Any insights or suggestions would be greatly appreciated.


Solution

  • The problem you are facing is probably overfitting. Overfitting occurs when a model fits its training data too closely (in the extreme, memorizing it exactly), so it cannot make accurate predictions on unseen data. The telltale sign is exactly what you describe: very high training accuracy combined with much lower accuracy on held-out data. You can learn more about it here.
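A practical way to confirm overfitting is to hold out part of the training data for validation (in Keras, for example, by passing `validation_split=0.2` to `model.fit`) and compare per-epoch training and validation accuracy. The sketch below assumes a dict shaped like Keras's `History.history`, with `accuracy` and `val_accuracy` keys (present when the model is compiled with `metrics=["accuracy"]`); the helper `generalization_gap`, its 0.05 threshold, and the numbers in the example are hypothetical.

```python
from typing import Dict, List


def generalization_gap(history: Dict[str, List[float]], threshold: float = 0.05) -> dict:
    """Compare final training and validation accuracy from a History-like dict.

    `history` is assumed to map "accuracy" and "val_accuracy" to per-epoch
    lists, as Keras's `History.history` does. The threshold is an arbitrary
    illustrative cut-off, not a standard value.
    """
    train_acc = history["accuracy"][-1]
    val_acc = history["val_accuracy"][-1]
    gap = train_acc - val_acc
    val = history["val_accuracy"]
    # 0-based epoch where validation accuracy peaked; later epochs typically
    # only improve the fit to the training set.
    best_epoch = max(range(len(val)), key=lambda i: val[i])
    return {
        "train_acc": train_acc,
        "val_acc": val_acc,
        "gap": gap,
        "best_val_epoch": best_epoch,
        "likely_overfitting": gap > threshold,
    }


# Hypothetical per-epoch values, shaped like `model.fit(...).history`.
hist = {
    "accuracy":     [0.80, 0.90, 0.95, 0.98, 0.99],
    "val_accuracy": [0.78, 0.82, 0.81, 0.79, 0.76],
}
report = generalization_gap(hist)
print(report)
```

A large final gap together with a validation peak early in training, as in this made-up history, is the classic overfitting signature.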

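    Other models may be doing better because deep learning models are particularly prone to overfitting: they have many trainable parameters and can memorize the training set, especially when the model is too complex for the amount of training data available. Random Forest, by contrast, averages the predictions of many trees trained on bootstrap samples, and this averaging has a built-in regularizing effect that reduces overfitting. Logistic Regression is a linear model with far fewer parameters, so it has much less capacity to memorize the training data.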
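The averaging argument can be sketched with a toy simulation. Each hypothetical ensemble member is modelled as the true value plus independent Gaussian noise, and the ensemble prediction is their mean. This is a deliberate simplification: the trees in a real Random Forest are correlated, so the variance reduction is smaller than the ideal 1/N seen here, but the direction of the effect is the same.

```python
import random
import statistics


def simulate(n_members: int, n_trials: int = 2000, noise: float = 1.0, seed: int = 0) -> float:
    """Return the variance of an averaged prediction around a true value of 0.

    Each member is (simplistically) the true value plus independent Gaussian
    noise; the ensemble prediction is the mean of its members.
    """
    rng = random.Random(seed)
    preds = []
    for _ in range(n_trials):
        members = [rng.gauss(0.0, noise) for _ in range(n_members)]
        preds.append(sum(members) / n_members)
    return statistics.pvariance(preds)


single = simulate(1)
ensemble = simulate(50)
print(f"single model variance:      {single:.3f}")
print(f"50-member average variance: {ensemble:.3f}")  # close to single / 50 in this idealized setup
```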

    Possible solutions include:

    1. Collecting more data: more (and more varied) training examples help the model generalize.
    2. Simplifying your model: try fewer layers, fewer units per layer, or a less complex architecture.
    3. Regularization: there are plenty of methods, such as L1/L2 weight penalties, dropout, or feature-level regularization, which constrain the model so it cannot simply memorize the training set.
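Item 3 can be illustrated without any deep learning framework. The sketch below trains a plain logistic regression by gradient descent, with and without an L2 penalty on the weights, on a tiny hypothetical data set. On separable data the unpenalized weights keep growing, while the penalty (`lam`) pulls them toward zero and keeps them bounded. In Keras, the analogous knobs would be `kernel_regularizer=tf.keras.regularizers.l2(...)` on a `Dense` layer or an added `tf.keras.layers.Dropout` layer; the function names here are illustrative, not a library API.

```python
import math
from typing import List, Sequence, Tuple


def train_logreg(X: Sequence[Sequence[float]], y: Sequence[int], lam: float,
                 lr: float = 0.5, epochs: int = 500) -> Tuple[List[float], float]:
    """Batch gradient descent on mean binary cross-entropy + (lam / 2) * ||w||^2.

    The bias is not penalized, which is the usual convention.
    """
    n, d = len(X), len(X[0])
    w = [0.0] * d
    b = 0.0
    for _ in range(epochs):
        grad_w = [0.0] * d
        grad_b = 0.0
        for xi, yi in zip(X, y):
            z = sum(wj * xj for wj, xj in zip(w, xi)) + b
            p = 1.0 / (1.0 + math.exp(-z))
            err = p - yi  # derivative of cross-entropy with respect to z
            for j in range(d):
                grad_w[j] += err * xi[j] / n
            grad_b += err / n
        # The L2 penalty adds lam * w to the gradient, pulling weights toward zero.
        w = [wj - lr * (gj + lam * wj) for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def l2_norm(w: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in w))


# Tiny hypothetical, linearly separable data set.
X = [[1.0, 0.0], [0.9, 0.2], [0.1, 1.0], [0.0, 0.8]]
y = [1, 1, 0, 0]

w_plain, _ = train_logreg(X, y, lam=0.0)
w_reg, _ = train_logreg(X, y, lam=0.1)
print(f"||w|| without penalty: {l2_norm(w_plain):.2f}")
print(f"||w|| with L2 penalty: {l2_norm(w_reg):.2f}")
```

Smaller weights mean a smoother decision function that relies less on any single feature, which is the same mechanism a weight penalty applies to a neural network's layers.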

    There are other ways to avoid overfitting, such as early stopping or data augmentation; it is worth trying a few to see which works best for your project.
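Early stopping is easy to reason about in isolation: watch the validation loss and stop once it has not improved for a few epochs. The hypothetical helper below mimics the core rule of Keras's `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=...)`; the real callback has more options (for example `restore_best_weights=True` to roll back to the best epoch), and the loss values are made up.

```python
from typing import Sequence, Tuple


def early_stopping_epoch(val_losses: Sequence[float], patience: int = 3,
                         min_delta: float = 0.0) -> Tuple[int, int]:
    """Return (stop_epoch, best_epoch), both 0-based.

    Training stops once `patience` consecutive epochs fail to improve the best
    validation loss by more than `min_delta`. If that never happens, training
    runs to the last epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            # New best: remember it and reset the patience counter.
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


# Hypothetical validation losses: they improve, then rise as the model overfits.
losses = [0.60, 0.45, 0.38, 0.36, 0.37, 0.40, 0.44, 0.50]
print(early_stopping_epoch(losses, patience=3))  # (6, 3)
```

With `patience=3`, training would stop at epoch 6 while the best weights came from epoch 3, which is exactly the situation where restoring the best weights matters.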