.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/ensemble/plot_comparison_ensemble_classifier.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_ensemble_plot_comparison_ensemble_classifier.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_ensemble_plot_comparison_ensemble_classifier.py:


=============================================
Compare ensemble classifiers using resampling
=============================================

Ensemble classifiers have been shown to improve classification performance
compared to a single learner. However, they are still affected by class
imbalance. This example shows the benefit of balancing the training set
before training the learners: each balanced ensemble method is compared with
its non-balanced counterpart.

We make the comparison using the balanced accuracy and the geometric mean,
two metrics widely used in the literature to evaluate models trained on
imbalanced sets.

.. GENERATED FROM PYTHON SOURCE LINES 15-19

.. code-block:: Python


    # Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
    # License: MIT








.. GENERATED FROM PYTHON SOURCE LINES 20-22

.. code-block:: Python

    print(__doc__)








.. GENERATED FROM PYTHON SOURCE LINES 23-29

Load an imbalanced dataset
--------------------------

We will load the UCI SatImage dataset, which has an imbalance ratio of 9.3:1
(number of majority samples per minority sample). The data are then split
into training and testing sets.
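
To get a feel for the imbalance, we can count the samples in each class; a
minimal sketch (not part of the original example) that verifies the 9.3:1
ratio:

.. code-block:: Python

    # Sketch: verify the imbalance ratio by counting samples per class.
    from collections import Counter

    from imblearn.datasets import fetch_datasets

    counts = Counter(fetch_datasets()["satimage"].target)
    majority, minority = max(counts.values()), min(counts.values())
    print(f"Imbalance ratio: {majority / minority:.1f}:1")  # expected ~9.3:1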

.. GENERATED FROM PYTHON SOURCE LINES 29-32

.. code-block:: Python


    from sklearn.model_selection import train_test_split








.. GENERATED FROM PYTHON SOURCE LINES 33-39

.. code-block:: Python

    from imblearn.datasets import fetch_datasets

    satimage = fetch_datasets()["satimage"]
    X, y = satimage.data, satimage.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)








.. GENERATED FROM PYTHON SOURCE LINES 40-49

Classification using a single decision tree
-------------------------------------------

We train a decision tree classifier that will serve as a baseline for the
rest of this example.

The results are reported in terms of balanced accuracy and geometric mean,
two metrics widely used in the literature to validate models trained on
imbalanced sets.
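
For reference, the geometric mean reported here is the geometric mean of the
per-class recalls (for a binary problem, the square root of sensitivity times
specificity). A minimal sketch of this equivalence on toy labels:

.. code-block:: Python

    # Sketch: the geometric mean score equals the geometric mean of the
    # recall obtained on each class.
    import numpy as np
    from sklearn.metrics import recall_score

    from imblearn.metrics import geometric_mean_score

    y_true = [0, 0, 0, 0, 1, 1]
    y_pred = [0, 0, 1, 0, 1, 0]

    per_class_recall = recall_score(y_true, y_pred, average=None)
    assert np.isclose(
        geometric_mean_score(y_true, y_pred), np.sqrt(per_class_recall.prod())
    )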

.. GENERATED FROM PYTHON SOURCE LINES 51-57

.. code-block:: Python

    from sklearn.tree import DecisionTreeClassifier

    tree = DecisionTreeClassifier()
    tree.fit(X_train, y_train)
    y_pred_tree = tree.predict(X_test)








.. GENERATED FROM PYTHON SOURCE LINES 58-68

.. code-block:: Python

    from sklearn.metrics import balanced_accuracy_score

    from imblearn.metrics import geometric_mean_score

    print("Decision tree classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_tree):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_tree):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Decision tree classifier performance:
    Balanced accuracy: 0.74 - Geometric mean 0.71




.. GENERATED FROM PYTHON SOURCE LINES 69-77

.. code-block:: Python

    import seaborn as sns
    from sklearn.metrics import ConfusionMatrixDisplay

    sns.set_context("poster")

    disp = ConfusionMatrixDisplay.from_estimator(tree, X_test, y_test, colorbar=False)
    _ = disp.ax_.set_title("Decision tree")




.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_001.png
   :alt: Decision tree
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_001.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 78-85

Classification using bagging classifier with and without sampling
-----------------------------------------------------------------

Instead of using a single tree, we will check whether an ensemble of decision
trees can alleviate the issue induced by class imbalance. First, we will use
a bagging classifier and its counterpart, which internally uses random
under-sampling to balance each bootstrap sample.
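
Before fitting both classifiers, here is a rough sketch of what "balancing
each bootstrap sample" means; this hand-written loop is only an illustration,
and the actual `BalancedBaggingClassifier` implementation differs in its
details:

.. code-block:: Python

    # Sketch: draw a bootstrap sample, balance it by random under-sampling,
    # then fit one tree per balanced sample (illustration only).
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    from imblearn.under_sampling import RandomUnderSampler

    rng = np.random.RandomState(0)
    trees = []
    for _ in range(10):
        indices = rng.randint(0, X_train.shape[0], size=X_train.shape[0])
        X_resampled, y_resampled = RandomUnderSampler(random_state=0).fit_resample(
            X_train[indices], y_train[indices]
        )
        tree_ = DecisionTreeClassifier(random_state=0).fit(X_resampled, y_resampled)
        trees.append(tree_)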

.. GENERATED FROM PYTHON SOURCE LINES 87-100

.. code-block:: Python

    from sklearn.ensemble import BaggingClassifier

    from imblearn.ensemble import BalancedBaggingClassifier

    bagging = BaggingClassifier(n_estimators=50, random_state=0)
    balanced_bagging = BalancedBaggingClassifier(n_estimators=50, random_state=0)

    bagging.fit(X_train, y_train)
    balanced_bagging.fit(X_train, y_train)

    y_pred_bc = bagging.predict(X_test)
    y_pred_bbc = balanced_bagging.predict(X_test)








.. GENERATED FROM PYTHON SOURCE LINES 101-103

Balancing each bootstrap sample significantly increases both the balanced
accuracy and the geometric mean.

.. GENERATED FROM PYTHON SOURCE LINES 105-116

.. code-block:: Python

    print("Bagging classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bc):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_bc):.2f}"
    )
    print("Balanced Bagging classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bbc):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_bbc):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Bagging classifier performance:
    Balanced accuracy: 0.73 - Geometric mean 0.68
    Balanced Bagging classifier performance:
    Balanced accuracy: 0.86 - Geometric mean 0.86




.. GENERATED FROM PYTHON SOURCE LINES 117-132

.. code-block:: Python

    import matplotlib.pyplot as plt

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
    ConfusionMatrixDisplay.from_estimator(
        bagging, X_test, y_test, ax=axs[0], colorbar=False
    )
    axs[0].set_title("Bagging")

    ConfusionMatrixDisplay.from_estimator(
        balanced_bagging, X_test, y_test, ax=axs[1], colorbar=False
    )
    axs[1].set_title("Balanced Bagging")

    fig.tight_layout()




.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_002.png
   :alt: Bagging, Balanced Bagging
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_002.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 133-139

Classification using random forest classifier with and without sampling
-----------------------------------------------------------------------

Random forest is another popular ensemble method, and it usually outperforms
bagging. Here, we use a vanilla random forest and its balanced counterpart,
in which each bootstrap sample is balanced.

.. GENERATED FROM PYTHON SOURCE LINES 141-160

.. code-block:: Python

    from sklearn.ensemble import RandomForestClassifier

    from imblearn.ensemble import BalancedRandomForestClassifier

    rf = RandomForestClassifier(n_estimators=50, random_state=0)
    brf = BalancedRandomForestClassifier(
        n_estimators=50,
        # Explicitly set the resampling behaviour: under-sample every class,
        # with replacement, and without an additional bootstrap (these match
        # the defaults from imbalanced-learn 0.13 onwards).
        sampling_strategy="all",
        replacement=True,
        bootstrap=False,
        random_state=0,
    )

    rf.fit(X_train, y_train)
    brf.fit(X_train, y_train)

    y_pred_rf = rf.predict(X_test)
    y_pred_brf = brf.predict(X_test)
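
As an aside, class reweighting is a related mitigation that keeps all samples
and weights classes inversely to their frequency. This variant is not part of
the original comparison and may behave differently from resampling; a minimal
sketch for reference:

.. code-block:: Python

    # Sketch: a reweighted (rather than resampled) forest, for reference only.
    from sklearn.ensemble import RandomForestClassifier

    rf_weighted = RandomForestClassifier(
        n_estimators=50, class_weight="balanced", random_state=0
    )
    rf_weighted.fit(X_train, y_train)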








.. GENERATED FROM PYTHON SOURCE LINES 161-164

Similarly to the previous experiment, the balanced classifier outperforms the
classifier that learns from imbalanced bootstrap samples. In addition, the
balanced random forest slightly outperforms the balanced bagging classifier.

.. GENERATED FROM PYTHON SOURCE LINES 166-177

.. code-block:: Python

    print("Random Forest classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rf):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_rf):.2f}"
    )
    print("Balanced Random Forest classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_brf):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_brf):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Random Forest classifier performance:
    Balanced accuracy: 0.73 - Geometric mean 0.68
    Balanced Random Forest classifier performance:
    Balanced accuracy: 0.87 - Geometric mean 0.87




.. GENERATED FROM PYTHON SOURCE LINES 178-187

.. code-block:: Python

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
    ConfusionMatrixDisplay.from_estimator(rf, X_test, y_test, ax=axs[0], colorbar=False)
    axs[0].set_title("Random forest")

    ConfusionMatrixDisplay.from_estimator(brf, X_test, y_test, ax=axs[1], colorbar=False)
    axs[1].set_title("Balanced random forest")

    fig.tight_layout()




.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_003.png
   :alt: Random forest, Balanced random forest
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_003.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 188-194

Boosting classifier
-------------------

In the same manner, the easy ensemble classifier is a bag of balanced
AdaBoost classifiers. However, it is slower to train than a random forest and
achieves slightly worse performance here. We also train a RUSBoost
classifier, which applies random under-sampling at each boosting iteration.
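
Conceptually, the easy ensemble fits one AdaBoost learner on each randomly
under-sampled subset of the training set and aggregates their predictions; a
minimal hand-written sketch of that idea (the actual `EasyEnsembleClassifier`
differs in its details):

.. code-block:: Python

    # Sketch: one AdaBoost learner per under-sampled subset (illustration only).
    from sklearn.ensemble import AdaBoostClassifier

    from imblearn.under_sampling import RandomUnderSampler

    boosters = []
    for seed in range(10):
        X_balanced, y_balanced = RandomUnderSampler(random_state=seed).fit_resample(
            X_train, y_train
        )
        boosters.append(
            AdaBoostClassifier(n_estimators=10, random_state=seed).fit(
                X_balanced, y_balanced
            )
        )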

.. GENERATED FROM PYTHON SOURCE LINES 196-209

.. code-block:: Python

    from sklearn.ensemble import AdaBoostClassifier

    from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier

    estimator = AdaBoostClassifier(n_estimators=10)
    eec = EasyEnsembleClassifier(n_estimators=10, estimator=estimator)
    eec.fit(X_train, y_train)
    y_pred_eec = eec.predict(X_test)

    rusboost = RUSBoostClassifier(n_estimators=10, estimator=estimator)
    rusboost.fit(X_train, y_train)
    y_pred_rusboost = rusboost.predict(X_test)









.. GENERATED FROM PYTHON SOURCE LINES 210-221

.. code-block:: Python

    print("Easy ensemble classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_eec):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_eec):.2f}"
    )
    print("RUSBoost classifier performance:")
    print(
        f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rusboost):.2f} - "
        f"Geometric mean {geometric_mean_score(y_test, y_pred_rusboost):.2f}"
    )





.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Easy ensemble classifier performance:
    Balanced accuracy: 0.83 - Geometric mean 0.83
    RUSBoost classifier performance:
    Balanced accuracy: 0.83 - Geometric mean 0.83




.. GENERATED FROM PYTHON SOURCE LINES 222-233

.. code-block:: Python

    fig, axs = plt.subplots(ncols=2, figsize=(10, 5))

    ConfusionMatrixDisplay.from_estimator(eec, X_test, y_test, ax=axs[0], colorbar=False)
    axs[0].set_title("Easy Ensemble")
    ConfusionMatrixDisplay.from_estimator(
        rusboost, X_test, y_test, ax=axs[1], colorbar=False
    )
    axs[1].set_title("RUSBoost classifier")

    fig.tight_layout()
    plt.show()



.. image-sg:: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_004.png
   :alt: Easy Ensemble, RUSBoost classifier
   :srcset: /auto_examples/ensemble/images/sphx_glr_plot_comparison_ensemble_classifier_004.png
   :class: sphx-glr-single-img






.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 12.394 seconds)

**Estimated memory usage:**  407 MB


.. _sphx_glr_download_auto_examples_ensemble_plot_comparison_ensemble_classifier.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_comparison_ensemble_classifier.ipynb <plot_comparison_ensemble_classifier.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_comparison_ensemble_classifier.py <plot_comparison_ensemble_classifier.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_comparison_ensemble_classifier.zip <plot_comparison_ensemble_classifier.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_