Example of topic classification in text documents

This example shows how to balance the text data before training a classifier.

Note that in this example the data are only slightly imbalanced; in other data sets, the imbalance ratio can be much more significant.

# Authors: Guillaume Lemaitre <[email protected]>
# License: MIT

from collections import Counter

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import make_pipeline as make_pipeline_imb
from imblearn.metrics import classification_report_imbalanced

print(__doc__)

Setting up the data set

We use a part of the 20 newsgroups data set by loading 4 topics. Using the scikit-learn loader, the data are split into a training and a testing set.

Note that class #3 is the minority class: it has 377 training samples, against 593 for the majority class (class #2).

categories = ['alt.atheism', 'talk.religion.misc',
              'comp.graphics', 'sci.space']
newsgroups_train = fetch_20newsgroups(subset='train',
                                      categories=categories)
newsgroups_test = fetch_20newsgroups(subset='test',
                                     categories=categories)

X_train = newsgroups_train.data
X_test = newsgroups_test.data

y_train = newsgroups_train.target
y_test = newsgroups_test.target

print('Training class distributions summary: {}'.format(Counter(y_train)))
print('Test class distributions summary: {}'.format(Counter(y_test)))

Out:

Training class distributions summary: Counter({2: 593, 1: 584, 0: 480, 3: 377})
Test class distributions summary: Counter({2: 394, 1: 389, 0: 319, 3: 251})
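The imbalance can be quantified directly from these counts. The following minimal sketch hard-codes the training counts from the output above, instead of downloading the data again, and computes the ratio between the largest and the smallest class. The variable names are illustrative and are not part of the original example.

```python
from collections import Counter

# Training class counts copied from the output above.
train_counts = Counter({2: 593, 1: 584, 0: 480, 3: 377})

# most_common() sorts the classes from the most to the least frequent.
majority_class, majority_count = train_counts.most_common()[0]
minority_class, minority_count = train_counts.most_common()[-1]

imbalance_ratio = majority_count / minority_count
print('Majority class {} / minority class {}: {:.2f}'.format(
    majority_class, minority_class, imbalance_ratio))
```

A ratio close to 1 means the classes are nearly balanced; here the majority class is about 1.6 times larger than the minority class.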

The usual scikit-learn pipeline

A typical scikit-learn pipeline combines a TF-IDF vectorizer with a multinomial naive Bayes classifier. A classification report summarizes the results on the testing set.

As expected, the recall of class #3 is low, mainly because of the class imbalance.

pipe = make_pipeline(TfidfVectorizer(), MultinomialNB())
pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)

print(classification_report_imbalanced(y_test, y_pred))

Out:

                   pre       rec       spe        f1       geo       iba       sup

          0       0.67      0.94      0.86      0.79      0.90      0.82       319
          1       0.96      0.92      0.99      0.94      0.95      0.90       389
          2       0.87      0.98      0.94      0.92      0.96      0.92       394
          3       0.97      0.36      1.00      0.52      0.60      0.33       251

avg / total       0.87      0.84      0.94      0.82      0.88      0.78      1353
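To help read the report, the sketch below recomputes the geo column, assuming it is the per-class geometric mean of recall (rec) and specificity (spe). The inputs are the rounded values copied from the report above, so this is only a consistency check and not the way imbalanced-learn computes the score internally.

```python
from math import sqrt

# (recall, specificity, reported geo) per class, copied from the report above.
report = {
    0: (0.94, 0.86, 0.90),
    1: (0.92, 0.99, 0.95),
    2: (0.98, 0.94, 0.96),
    3: (0.36, 1.00, 0.60),
}

# Geometric mean of recall and specificity, rounded like the report.
recomputed_geo = {label: round(sqrt(rec * spe), 2)
                  for label, (rec, spe, _) in report.items()}

for label, (rec, spe, geo) in sorted(report.items()):
    print('class {}: reported geo = {:.2f}, recomputed = {:.2f}'.format(
        label, geo, recomputed_geo[label]))
```

For class #3, the near-perfect specificity cannot compensate for the low recall, which is why its geometric mean is the lowest of the four classes.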

Balancing the classes before classification

To improve the prediction of class #3, we can balance the classes before training the naive Bayes classifier. We therefore use a RandomUnderSampler to equalize the number of samples in all classes before training.

It is also important to note that we use the make_pipeline function implemented in imbalanced-learn so that the sampler is handled properly.
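To make the idea of random under-sampling concrete, here is a simplified, standard-library-only sketch. It assumes the default behaviour of reducing every class to the size of the smallest one by sampling without replacement. The helper random_under_sample_indices and the toy labels are hypothetical and only illustrate the principle; this is not the imbalanced-learn implementation.

```python
import random
from collections import Counter, defaultdict


def random_under_sample_indices(labels, seed=0):
    """Return sorted indices keeping the same number of samples per class.

    Hypothetical helper for illustration; not the imbalanced-learn code.
    """
    rng = random.Random(seed)

    # Group the sample indices by class label.
    indices_per_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_per_class[label].append(index)

    # Every class is reduced to the size of the smallest one.
    n_keep = min(len(indices) for indices in indices_per_class.values())

    kept = []
    for indices in indices_per_class.values():
        # Sample without replacement inside each class.
        kept.extend(rng.sample(indices, n_keep))
    return sorted(kept)


# Toy labels reproducing the training class counts printed earlier.
y_toy = [2] * 593 + [1] * 584 + [0] * 480 + [3] * 377
kept_indices = random_under_sample_indices(y_toy)
balanced_counts = Counter(y_toy[i] for i in kept_indices)
print('Balanced class distribution: {}'.format(balanced_counts))
```

After under-sampling, each class keeps 377 samples, the size of the minority class, so the classifier is trained on fewer documents overall.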

pipe = make_pipeline_imb(TfidfVectorizer(),
                         RandomUnderSampler(),
                         MultinomialNB())

pipe.fit(X_train, y_train)
y_pred = pipe.predict(X_test)

Although the overall results are similar, resampling corrects the poor recall of class #3 (0.36 without resampling, 0.73 with it), at the cost of lower recall for the other classes and lower precision for class #3. Overall, the results are slightly better.

print(classification_report_imbalanced(y_test, y_pred))

Out:

                   pre       rec       spe        f1       geo       iba       sup

          0       0.68      0.90      0.87      0.78      0.89      0.79       319
          1       0.98      0.84      0.99      0.90      0.91      0.82       389
          2       0.94      0.89      0.98      0.91      0.93      0.86       394
          3       0.80      0.73      0.96      0.76      0.83      0.68       251

avg / total       0.87      0.85      0.95      0.85      0.90      0.80      1353
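The trade-off between the two reports can be checked with a little arithmetic. The sketch below copies the rounded per-class recall and support values from both reports and compares them, assuming that the avg / total row is a support-weighted mean (an assumption that is consistent with the printed numbers). All names are illustrative.

```python
# Per-class support and recall, copied from the two reports above.
support = [319, 389, 394, 251]
recall_before = [0.94, 0.92, 0.98, 0.36]  # usual scikit-learn pipeline
recall_after = [0.90, 0.84, 0.89, 0.73]   # pipeline with RandomUnderSampler


def weighted_mean(values, weights):
    """Mean of the values, weighted by the number of samples per class."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def macro_mean(values):
    """Unweighted mean: every class counts the same."""
    return sum(values) / len(values)


# Change in recall for each class after resampling.
recall_change = [round(after - before, 2)
                 for before, after in zip(recall_before, recall_after)]

weighted_before = weighted_mean(recall_before, support)
weighted_after = weighted_mean(recall_after, support)

print('Per-class recall change: {}'.format(recall_change))
print('Weighted recall: {:.2f} -> {:.2f}'.format(
    weighted_before, weighted_after))
print('Macro recall: {:.2f} -> {:.2f}'.format(
    macro_mean(recall_before), macro_mean(recall_after)))
```

The weighted means round to 0.84 and 0.85, matching the rec column of the avg / total rows, while the unweighted mean, which gives the minority class the same weight as the others, shows a larger gain.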

Total running time of the script: ( 0 minutes 29.402 seconds)
