Tags: keras, deep-learning, nlp, lstm, text-classification

Why does my LSTM for Multi-Label Text Classification underperform?


I'm using a Windows 10 machine. Libraries: Keras with TensorFlow 2.0. Embeddings: GloVe (100 dimensions).

I am trying to implement an LSTM architecture for multi-label text classification.

My problem is that no matter how much fine-tuning I do, the results are really bad.

I am not experienced in practical DL implementations, so I'm asking for your advice.

Below I will state basic information about my dataset and my model so far.

I can't embed images since I am a new member, so they appear as links.

Dataset form + Embeddings form + train-test-split form

Dataset's labels distribution

My Implementation of LSTM

Model's Summary

Model's Accuracy plot

Model's Loss plot

As you can see, my dataset is really small (~6,000 examples), and maybe that's one reason why I cannot achieve better results. Still, I chose it because it's unbiased.

  1. I'd like to know whether there is any fundamental mistake in my code regarding the dimensions, shapes, activation functions, or loss function for multi-label text classification.

  2. What would you recommend to improve my model's results? Any general advice regarding optimizers, methods, number of nodes, layers, dropout, etc. is also very welcome.

The best validation accuracy I have achieved so far is ~0.54, and even though I have tried to raise it, it seems stuck there.


Solution

  • There are many ways to get this wrong, but the most common mistake is letting your model overfit the training data. I suspect that an accuracy of 0.54 means your model selects the most common label (offensive) for almost all cases.
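For reference, the standard Keras setup for multi-label classification is an output layer of independent sigmoid units trained with binary cross-entropy (the sizes below, `vocab_size`, `max_len`, `num_labels`, are illustrative placeholders, not taken from the question):

```python
# Minimal sketch of a multi-label Keras text model.
# All sizes here are illustrative placeholders.
import numpy as np
from tensorflow.keras import layers, models

vocab_size, embed_dim, max_len, num_labels = 5000, 100, 50, 4

model = models.Sequential([
    layers.Embedding(vocab_size, embed_dim),
    layers.LSTM(64),
    # Multi-label: one independent sigmoid per label, NOT softmax.
    layers.Dense(num_labels, activation="sigmoid"),
])
# Multi-label: binary_crossentropy, NOT categorical_crossentropy.
model.compile(optimizer="adam", loss="binary_crossentropy",
              metrics=["accuracy"])

# Predictions are independent per-label probabilities in [0, 1].
probs = model.predict(np.random.randint(0, vocab_size, size=(2, max_len)))
print(probs.shape)  # (2, 4)
```

If your last layer uses softmax with categorical cross-entropy instead, the labels compete with each other, which is wrong for multi-label data where several labels can be active at once.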

    So, consider one of these simple solutions:

    • Create balanced training data: e.g. 400 samples from each class.
    • Or sample balanced batches during training (exactly the same number of examples per label in each training batch).
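The balanced-batch idea can be sketched as follows. This is a hedged illustration with made-up data, using single-label integer classes for simplicity; for multi-label data you would balance on, say, the rarest positive label of each sample:

```python
# Sketch: draw batches containing exactly `per_class` samples of each class.
import numpy as np

def balanced_batches(X, y, classes, per_class, rng=np.random.default_rng(0)):
    """Yield batches with exactly `per_class` samples from each class."""
    by_class = {c: np.flatnonzero(y == c) for c in classes}
    while True:
        # Oversample rare classes (replace=True) so every class contributes
        # the same number of examples to the batch.
        idx = np.concatenate([rng.choice(by_class[c], per_class, replace=True)
                              for c in classes])
        rng.shuffle(idx)
        yield X[idx], y[idx]

# Tiny demo with a skewed label distribution.
X = np.arange(100)
y = np.array([0] * 80 + [1] * 15 + [2] * 5)
xb, yb = next(balanced_batches(X, y, classes=[0, 1, 2], per_class=8))
print(np.bincount(yb))  # [8 8 8]
```

Such a generator can be passed to `model.fit` so that every gradient step sees the classes equally often, even though the underlying dataset is imbalanced.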

    In addition to tracking accuracy and loss, look at precision, recall, and F1, or better yet, plot the area under the precision-recall curve. Different classes may also need different activation thresholds: if you are using a sigmoid on the last layer, one class might perform best with a threshold of 0.2 and another with 0.7.
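Picking a per-class threshold can be done by scanning candidate thresholds and keeping the one that maximizes F1 for each label. A hedged sketch with scikit-learn, where `y_true` and `probs` are illustrative stand-ins for your validation labels and sigmoid outputs:

```python
# Sketch: choose a separate decision threshold per label by F1 score.
# y_true and probs are made-up stand-ins for real validation data.
import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([[1, 0], [1, 1], [0, 1], [1, 0]])    # multi-label targets
probs  = np.array([[0.9, 0.1], [0.6, 0.8], [0.3, 0.75], [0.55, 0.2]])

best = {}
for c in range(y_true.shape[1]):
    # F1 at each candidate threshold for label c.
    scores = {t: f1_score(y_true[:, c], (probs[:, c] >= t).astype(int))
              for t in np.arange(0.1, 0.9, 0.1)}
    best[c] = max(scores, key=scores.get)

print(best)  # each label gets its own threshold
```

On real data you would tune these thresholds on a validation split (never the test set) and then apply them when converting sigmoid probabilities to hard label predictions.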