.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/ensemble/plot_comparison_ensemble_classifier.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_ensemble_plot_comparison_ensemble_classifier.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_ensemble_plot_comparison_ensemble_classifier.py:

=============================================
Compare ensemble classifiers using resampling
=============================================

Ensemble classifiers have been shown to improve classification performance
compared to a single learner. However, they are also affected by class
imbalance. This example shows the benefit of balancing the training set
before training the learners. We compare the balanced methods with their
non-balanced ensemble counterparts.

The comparison uses the balanced accuracy and the geometric mean, two metrics
widely used in the literature to evaluate models learned on imbalanced
datasets.

.. GENERATED FROM PYTHON SOURCE LINES 15-19

.. code-block:: Python

    # Authors: Guillaume Lemaitre
    # License: MIT

.. GENERATED FROM PYTHON SOURCE LINES 20-22

.. code-block:: Python

    print(__doc__)

.. GENERATED FROM PYTHON SOURCE LINES 23-29

Load an imbalanced dataset
--------------------------

We will load the UCI SatImage dataset, which has an imbalance ratio of 9.3:1
(number of majority samples for each minority sample). The data are then
split into training and testing sets.

.. GENERATED FROM PYTHON SOURCE LINES 29-32

.. code-block:: Python

    from sklearn.model_selection import train_test_split

.. GENERATED FROM PYTHON SOURCE LINES 33-39

.. code-block:: Python

    from imblearn.datasets import fetch_datasets

    satimage = fetch_datasets()["satimage"]
    X, y = satimage.data, satimage.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

.. GENERATED FROM PYTHON SOURCE LINES 40-49

Classification using a single decision tree
--------------------------------------------

We train a decision tree classifier which will be used as a baseline for the
rest of this example.

The results are reported in terms of balanced accuracy and geometric mean,
two metrics widely used in the literature to validate models trained on
imbalanced datasets.

.. GENERATED FROM PYTHON SOURCE LINES 51-57

.. code-block:: Python

    from sklearn.tree import DecisionTreeClassifier

    tree = DecisionTreeClassifier()
    tree.fit(X_train, y_train)
    y_pred_tree = tree.predict(X_test)

.. GENERATED FROM PYTHON SOURCE LINES 58-68

.. code-block:: Python

    from sklearn.metrics import balanced_accuracy_score

    from imblearn.metrics import geometric_mean_score

    print("Decision tree classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_tree):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_tree):.2f}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Decision tree classifier performance:
    Balanced accuracy: 0.76 - Geometric mean 0.74

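The geometric mean reported here is the geometric mean of the per-class
recalls. As a minimal sketch (not part of the original example), it can be
recomputed by hand for this binary problem from the per-class recalls;
``imblearn.metrics.geometric_mean_score`` already does this, including for
the multi-class case.

.. code-block:: Python

    # Sketch: recompute the geometric mean from the per-class recalls.
    # Use imblearn.metrics.geometric_mean_score in practice.
    import numpy as np
    from sklearn.metrics import recall_score

    # Recall of each class (sensitivity and specificity in the binary case).
    recall_per_class = recall_score(y_test, y_pred_tree, average=None)
    gmean_by_hand = np.prod(recall_per_class) ** (1 / len(recall_per_class))
    print(f"Geometric mean computed by hand: {gmean_by_hand:.2f}")
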
.. GENERATED FROM PYTHON SOURCE LINES 69-77

.. code-block:: Python

    import seaborn as sns

    from sklearn.metrics import ConfusionMatrixDisplay

    sns.set_context("poster")

    disp = ConfusionMatrixDisplay.from_estimator(tree, X_test, y_test, colorbar=False)
    _ = disp.ax_.set_title("Decision tree")

.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_001.png
   :alt: Decision tree
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 78-85

Classification using bagging classifier with and without sampling
-------------------------------------------------------------------

Instead of using a single tree, we will check whether an ensemble of decision
trees can alleviate the issue induced by class imbalance. First, we will use
a bagging classifier and its counterpart, which internally uses random
under-sampling to balance each bootstrap sample.

.. GENERATED FROM PYTHON SOURCE LINES 87-100

.. code-block:: Python

    from sklearn.ensemble import BaggingClassifier

    from imblearn.ensemble import BalancedBaggingClassifier

    bagging = BaggingClassifier(n_estimators=50, random_state=0)
    balanced_bagging = BalancedBaggingClassifier(n_estimators=50, random_state=0)

    bagging.fit(X_train, y_train)
    balanced_bagging.fit(X_train, y_train)

    y_pred_bc = bagging.predict(X_test)
    y_pred_bbc = balanced_bagging.predict(X_test)

.. GENERATED FROM PYTHON SOURCE LINES 101-103

Balancing each bootstrap sample significantly increases both the balanced
accuracy and the geometric mean.

.. GENERATED FROM PYTHON SOURCE LINES 105-116

.. code-block:: Python

    print("Bagging classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bc):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_bc):.2f}"
    )
    print("Balanced Bagging classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bbc):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_bbc):.2f}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Bagging classifier performance:
    Balanced accuracy: 0.73 - Geometric mean 0.68
    Balanced Bagging classifier performance:
    Balanced accuracy: 0.86 - Geometric mean 0.86

.. GENERATED FROM PYTHON SOURCE LINES 117-132

.. code-block:: Python

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

    ConfusionMatrixDisplay.from_estimator(
        bagging, X_test, y_test, ax=axs[0], colorbar=False
    )
    axs[0].set_title("Bagging")

    ConfusionMatrixDisplay.from_estimator(
        balanced_bagging, X_test, y_test, ax=axs[1], colorbar=False
    )
    axs[1].set_title("Balanced Bagging")

    fig.tight_layout()

.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_002.png
   :alt: Bagging, Balanced Bagging
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_002.png
   :class: sphx-glr-single-img

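As a side note, the under-sampling performed internally by
``BalancedBaggingClassifier`` can also be made explicit. The sketch below is
only an illustration and assumes an imbalanced-learn version that exposes the
``sampler`` parameter; passing ``RandomUnderSampler`` reproduces the default
behaviour and makes it easy to swap in another sampler.

.. code-block:: Python

    # Sketch (assumes the `sampler` parameter is available): the default
    # behaviour of BalancedBaggingClassifier with the under-sampler made explicit.
    from imblearn.under_sampling import RandomUnderSampler

    explicit_balanced_bagging = BalancedBaggingClassifier(
        n_estimators=50,
        sampler=RandomUnderSampler(),  # balance each bootstrap sample
        random_state=0,
    )
    explicit_balanced_bagging.fit(X_train, y_train)
    y_pred_explicit = explicit_balanced_bagging.predict(X_test)
    print(
        "Balanced accuracy with explicit sampler: "
        f"{balanced_accuracy_score(y_test, y_pred_explicit):.2f}"
    )
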
.. GENERATED FROM PYTHON SOURCE LINES 133-139

Classification using random forest classifier with and without sampling
-------------------------------------------------------------------------

Random forest is another popular ensemble method that usually outperforms
bagging. Here, we use a vanilla random forest and its balanced counterpart,
in which each bootstrap sample is balanced.

.. GENERATED FROM PYTHON SOURCE LINES 141-160

.. code-block:: Python

    from sklearn.ensemble import RandomForestClassifier

    from imblearn.ensemble import BalancedRandomForestClassifier

    rf = RandomForestClassifier(n_estimators=50, random_state=0)
    brf = BalancedRandomForestClassifier(
        n_estimators=50,
        sampling_strategy="all",
        replacement=True,
        bootstrap=False,
        random_state=0,
    )

    rf.fit(X_train, y_train)
    brf.fit(X_train, y_train)

    y_pred_rf = rf.predict(X_test)
    y_pred_brf = brf.predict(X_test)

.. GENERATED FROM PYTHON SOURCE LINES 161-164

Similarly to the previous experiment, the balanced classifier outperforms the
classifier which learns from imbalanced bootstrap samples. In addition, random
forest outperforms the bagging classifier.

.. GENERATED FROM PYTHON SOURCE LINES 166-177

.. code-block:: Python

    print("Random Forest classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rf):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_rf):.2f}"
    )
    print("Balanced Random Forest classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_brf):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_brf):.2f}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Random Forest classifier performance:
    Balanced accuracy: 0.73 - Geometric mean 0.68
    Balanced Random Forest classifier performance:
    Balanced accuracy: 0.87 - Geometric mean 0.87

.. GENERATED FROM PYTHON SOURCE LINES 178-187

.. code-block:: Python

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

    ConfusionMatrixDisplay.from_estimator(rf, X_test, y_test, ax=axs[0], colorbar=False)
    axs[0].set_title("Random forest")

    ConfusionMatrixDisplay.from_estimator(brf, X_test, y_test, ax=axs[1], colorbar=False)
    axs[1].set_title("Balanced random forest")

    fig.tight_layout()

.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_003.png
   :alt: Random forest, Balanced random forest
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_003.png
   :class: sphx-glr-single-img

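The numbers above come from a single stratified split. As an additional check
not present in the original example, the sketch below uses cross-validated
balanced accuracy to verify that the ranking of the two forests is stable
across folds.

.. code-block:: Python

    # Sketch: cross-validated balanced accuracy on the training set, as a
    # robustness check of the single-split comparison above.
    from sklearn.model_selection import cross_val_score

    for name, model in [("Random forest", rf), ("Balanced random forest", brf)]:
        scores = cross_val_score(
            model, X_train, y_train, scoring="balanced_accuracy", cv=5
        )
        print(f"{name}: {scores.mean():.2f} +/- {scores.std():.2f}")
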
.. GENERATED FROM PYTHON SOURCE LINES 188-194

Boosting classifier
-------------------

In the same manner, the easy ensemble classifier is a bag of balanced
AdaBoost classifiers. However, it will be slower to train than random forest
and will achieve worse performance.

.. GENERATED FROM PYTHON SOURCE LINES 196-209

.. code-block:: Python

    from sklearn.ensemble import AdaBoostClassifier

    from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier

    estimator = AdaBoostClassifier(n_estimators=10)
    eec = EasyEnsembleClassifier(n_estimators=10, estimator=estimator)
    eec.fit(X_train, y_train)
    y_pred_eec = eec.predict(X_test)

    rusboost = RUSBoostClassifier(n_estimators=10, estimator=estimator)
    rusboost.fit(X_train, y_train)
    y_pred_rusboost = rusboost.predict(X_test)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /home/circleci/mambaforge/envs/testenv/lib/python3.11/site-packages/sklearn/ensemble/_weight_boosting.py:519: FutureWarning: The SAMME.R algorithm (the default) is deprecated and will be removed in 1.6. Use the SAMME algorithm to circumvent this warning.
      warnings.warn(

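The ``FutureWarning`` above is emitted by scikit-learn once per fitted
AdaBoost estimator because the deprecated ``SAMME.R`` algorithm is still the
default in this environment. Following the suggestion in the warning message,
a sketch of how to avoid it is shown below; on scikit-learn versions where
``SAMME`` is already the only algorithm, no change is needed.

.. code-block:: Python

    # Sketch: request the SAMME algorithm explicitly to silence the
    # deprecation warning raised for the SAMME.R default.
    estimator_samme = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
    eec_samme = EasyEnsembleClassifier(n_estimators=10, estimator=estimator_samme)
    eec_samme.fit(X_train, y_train)
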
.. GENERATED FROM PYTHON SOURCE LINES 210-221

.. code-block:: Python

    print("Easy ensemble classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_eec):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_eec):.2f}"
    )
    print("RUSBoost classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rusboost):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_rusboost):.2f}"
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Easy ensemble classifier performance:
    Balanced accuracy: 0.85 - Geometric mean 0.85
    RUSBoost classifier performance:
    Balanced accuracy: 0.85 - Geometric mean 0.85

.. GENERATED FROM PYTHON SOURCE LINES 222-233

.. code-block:: Python

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

    ConfusionMatrixDisplay.from_estimator(eec, X_test, y_test, ax=axs[0], colorbar=False)
    axs[0].set_title("Easy Ensemble")

    ConfusionMatrixDisplay.from_estimator(
        rusboost, X_test, y_test, ax=axs[1], colorbar=False
    )
    axs[1].set_title("RUSBoost classifier")

    fig.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_004.png
   :alt: Easy Ensemble, RUSBoost classifier
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 38.182 seconds)

**Estimated memory usage:** 238 MB

.. _sphx_glr_download_auto_examples_ensemble_plot_comparison_ensemble_classifier.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_comparison_ensemble_classifier.ipynb <plot_comparison_ensemble_classifier.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_comparison_ensemble_classifier.py <plot_comparison_ensemble_classifier.py>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_