Practical ML: Addressing Class Imbalance

Ever wondered why your machine learning model is failing to detect rare but critical events?

Picture this: You're developing a model to diagnose a rare disease, but despite your best efforts, it keeps missing those crucial cases.

This isn't just frustrating—it's costly, time-consuming, and can have serious real-world consequences.

This issue stems from class imbalance, where your training data is skewed, heavily favoring some classes over others.

Class imbalance can severely impact the model's performance, leaving you with inaccurate predictions and wasted resources.

In this article, we’ll delve into the nature of class imbalance, its impact on machine learning models, and the most effective strategies to tackle it.

By the end, you'll be equipped with practical solutions to ensure your models perform reliably, even with imbalanced data.

Keep reading to learn how to turn this challenge into an opportunity for creating robust, accurate models.

Challenges of Class Imbalance

Class imbalance typically arises in classification tasks where one class has a large number of samples while another has very few.

Let's examine the key reasons behind this:

  • Insufficient signal for minority classes: When the number of samples in the minority classes is significantly lower than the majority classes, the model may not have enough information to learn the distinguishing features of the minority classes. This lack of signal can result in poor generalization and reduced accuracy for those classes. If the training set lacks sufficient instances of the rare classes, the model might assume these classes do not exist.

  • Exploiting simple heuristics: Class imbalance makes it easier for the model to settle into a suboptimal solution by relying on simple heuristics instead of learning the underlying patterns in the data. For example, always predicting the majority class can already yield high accuracy without the model learning anything useful.

  • Asymmetric costs of errors: Class imbalance introduces an asymmetry in the cost of misclassification. Misclassifying a sample from the minority class may have a higher cost than misclassifying a sample from the majority class. If the loss function does not account for this asymmetry, the model will treat all samples equally, leading to suboptimal performance on the minority classes.

Approaches to Handling Class Imbalance

To address the challenges posed by class imbalance, we can employ three main approaches: using appropriate evaluation metrics, data-level methods, and algorithm-level methods.

Using the Right Evaluation Metrics

Choosing the appropriate evaluation metrics is crucial when dealing with class imbalance.

Traditional metrics like overall accuracy and error rate can be misleading in imbalanced scenarios, as they are dominated by the performance on the majority class.

Instead, we should focus on metrics that provide a more comprehensive view of the model's performance, such as:

  • Precision: Measures the proportion of true positive predictions among all positive predictions.

  • Recall (Sensitivity or True Positive Rate): Measures the proportion of true positive predictions among all actual positive samples.

  • F1 Score: The harmonic mean of precision and recall, providing a balanced measure of the model's performance.

Additionally, we can use the Receiver Operating Characteristic (ROC) curve and the Precision-Recall curve to visualize the model's performance at different classification thresholds.

These curves help us understand the trade-off between true positive rate and false positive rate (ROC curve) or precision and recall (Precision-Recall curve), allowing us to select an appropriate threshold based on our specific requirements.

The Precision-Recall curve is particularly useful for imbalanced datasets because, unlike the ROC curve, it does not depend on the number of true negatives, which dominate when the negative class is much larger.
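
For illustration, here is a minimal sketch of computing these metrics with scikit-learn; the label and score arrays below are made up purely for the example:

import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, precision_recall_curve, roc_auc_score

# Toy example: true labels and predicted scores for an imbalanced binary problem
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
y_scores = np.array([0.1, 0.2, 0.1, 0.3, 0.2, 0.1, 0.6, 0.4, 0.8, 0.35])
y_pred = (y_scores >= 0.5).astype(int)  # hard predictions at a 0.5 threshold

print(f"Precision: {precision_score(y_true, y_pred):.2f}")
print(f"Recall:    {recall_score(y_true, y_pred):.2f}")
print(f"F1 score:  {f1_score(y_true, y_pred):.2f}")

# The Precision-Recall curve traces this trade-off across every possible threshold
precisions, recalls, thresholds = precision_recall_curve(y_true, y_scores)
print(f"ROC AUC:   {roc_auc_score(y_true, y_scores):.2f}")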

Data-Level Methods

Data-level methods aim to modify the distribution of the training data to reduce the level of imbalance, making it easier for the model to learn.

Resampling

The most common family of techniques in this category is resampling, which includes oversampling (adding more instances from the minority classes) and undersampling (removing instances from the majority classes):

  • Random Undersampling: Randomly remove instances from the majority class to balance the class distribution.

  • Random Oversampling: Randomly duplicate instances from the minority classes to increase their representation.

It's important to note that when resampling the training data, we should never evaluate the model on the resampled data; the test set must keep the original distribution, otherwise the reported metrics will not reflect real-world performance.
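
Here is a minimal sketch of random oversampling with scikit-learn's resample utility on a synthetic dataset; in practice, the imbalanced-learn library offers dedicated RandomOverSampler and RandomUnderSampler classes for the same purpose:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.utils import resample

# Synthetic imbalanced dataset (roughly 90% class 0, 10% class 1)
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], flip_y=0, random_state=42)

# Separate majority and minority classes
X_maj, y_maj = X[y == 0], y[y == 0]
X_min, y_min = X[y == 1], y[y == 1]

# Random oversampling: draw minority samples with replacement until they match the majority count
# (random undersampling is the mirror image: resample the majority class with replace=False and n_samples=len(y_min))
X_min_over, y_min_over = resample(X_min, y_min, replace=True, n_samples=len(y_maj), random_state=42)

X_balanced = np.vstack([X_maj, X_min_over])
y_balanced = np.concatenate([y_maj, y_min_over])
print(np.bincount(y), "->", np.bincount(y_balanced))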

Two-Phase Learning

Another data-level technique is two-phase learning, where we first train the model on resampled data and then fine-tune it on the original data.

This approach allows the model to learn from a more balanced distribution initially and then adapt to the real-world distribution.
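
A rough sketch of this idea with a small Keras model on synthetic data (the oversampling step and the learning rates are illustrative choices, not prescriptions):

import numpy as np
from sklearn.datasets import make_classification
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam

# Synthetic imbalanced dataset (90% class 0, 10% class 1)
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], flip_y=0, random_state=42)

# Phase 1 data: naively oversample the minority class by repeating its rows
idx_min = np.where(y == 1)[0]
idx_balanced = np.concatenate([np.arange(len(y)), np.repeat(idx_min, 8)])
X_bal, y_bal = X[idx_balanced], y[idx_balanced]

model = Sequential([Dense(32, activation='relu'), Dense(1, activation='sigmoid')])

# Phase 1: train on the roughly balanced, resampled data
model.compile(optimizer=Adam(learning_rate=1e-3), loss='binary_crossentropy')
model.fit(X_bal, y_bal, epochs=10, batch_size=32, verbose=0)

# Phase 2: fine-tune on the original data with a lower learning rate so the model adapts to the true distribution
model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy')
model.fit(X, y, epochs=5, batch_size=32, verbose=0)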

Algorithm-Level Methods

Algorithm-level methods keep the training data distribution intact but modify the learning algorithm to make it more robust to class imbalance.

These methods often involve adjusting the loss function to assign higher weights to the instances we care about, guiding the model to prioritize learning from those instances.

Because the training distribution is left untouched, these methods avoid the drawbacks of resampling, such as discarding useful majority-class data or overfitting to duplicated minority samples.

Cost-Sensitive Learning

The loss function guides the learning process. Adjusting the loss function can help the model focus more on minority classes.

You can assign different misclassification costs to different classes based on a predefined cost matrix.

The model will prioritize minimizing the overall cost rather than just the classification error.
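
As a minimal sketch, we can encode such a cost matrix and, for each sample, pick the class with the lowest expected cost under the model's predicted probabilities; the dataset and the costs below are made up for illustration:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], flip_y=0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# cost_matrix[true_class, predicted_class]: missing a positive (false negative) costs 10x a false alarm
cost_matrix = np.array([[0, 1],
                        [10, 0]])

model = LogisticRegression(max_iter=1000).fit(X_train, y_train)
proba = model.predict_proba(X_test)        # shape (n_samples, n_classes)

# Expected cost of predicting each class, then choose the cheapest one
expected_cost = proba @ cost_matrix        # shape (n_samples, n_classes)
y_pred = np.argmin(expected_cost, axis=1)

print(confusion_matrix(y_test, y_pred))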

Class-Balanced Loss

Assign higher weights to the minority classes in the loss function, making the model focus more on learning from those classes.

The weights can be inversely proportional to the class frequencies, or based on the effective number of samples, which accounts for the overlap among existing samples. In either case, classes with fewer samples receive higher weights, encouraging the model to focus on learning them.

Let's see an example using traditional ML algorithms:

# Import necessary libraries
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report

# Load a sample imbalanced dataset
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2,
                           n_redundant=10, n_clusters_per_class=1,
                           weights=[0.9, 0.1], flip_y=0, random_state=42)

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# Calculate class weights
class_weights = dict(zip(np.unique(y_train), 1. / np.bincount(y_train)))

# Train the logistic regression model with class weights
model = LogisticRegression(class_weight=class_weights)
model.fit(X_train, y_train)

# Predict on the test set
y_pred = model.predict(X_test)

# Evaluate the model: the confusion matrix and per-class precision/recall/F1 are more informative than plain accuracy here
print("Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))

Focal Loss

Dynamically adjust the weights of the loss function based on the model's confidence in its predictions.

This approach encourages the model to focus on learning from the samples it still has difficulty classifying, rather than the easy ones.
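
A minimal sketch of the standard binary focal loss, FL(p_t) = -α(1 - p_t)^γ log(p_t), written as a custom Keras loss; recent TensorFlow versions also ship a built-in BinaryFocalCrossentropy loss:

import tensorflow as tf

def binary_focal_loss(gamma=2.0, alpha=0.25):
    # Returns a Keras-compatible loss for binary classification with a sigmoid output
    def loss_fn(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        eps = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)

        # p_t is the predicted probability of the true class
        p_t = y_true * y_pred + (1.0 - y_true) * (1.0 - y_pred)
        alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)

        # Easy examples (p_t close to 1) are down-weighted by the (1 - p_t)^gamma factor
        return -tf.reduce_mean(alpha_t * tf.pow(1.0 - p_t, gamma) * tf.math.log(p_t))
    return loss_fn

# Usage: model.compile(optimizer='adam', loss=binary_focal_loss(gamma=2.0, alpha=0.25))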

Ensemble Methods

Ensemble methods, which combine multiple models, can also help address class imbalance.

By aggregating predictions from multiple models, ensemble methods can provide more robust performance on imbalanced datasets.
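
As a quick sketch, scikit-learn's random forest combines bagging with class weighting via the class_weight='balanced_subsample' option; libraries such as imbalanced-learn go further with ensembles that resample inside each bag (e.g., BalancedBaggingClassifier):

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], flip_y=0, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# An ensemble of trees; 'balanced_subsample' reweights the classes within each bootstrap sample
forest = RandomForestClassifier(n_estimators=200, class_weight='balanced_subsample', random_state=42)
forest.fit(X_train, y_train)

print(classification_report(y_test, forest.predict(X_test)))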

Data Augmentation

Data augmentation involves creating new training samples by applying transformations to existing data. This can help increase the number of minority class samples.
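
For tabular data, a naive but illustrative form of augmentation is jittering existing minority samples with small Gaussian noise; for images, the analogue is random flips, crops, and rotations, and SMOTE-style interpolation (available in the imbalanced-learn library) is a more principled option. A minimal sketch:

import numpy as np

rng = np.random.default_rng(42)

# X_min: existing minority-class samples (a made-up 5x3 matrix for illustration)
X_min = rng.normal(size=(5, 3))

# Create three jittered copies of each minority sample by adding small Gaussian noise
noise_scale = 0.05 * X_min.std(axis=0)  # noise proportional to each feature's spread
X_aug = np.vstack([X_min + rng.normal(scale=noise_scale, size=X_min.shape) for _ in range(3)])

print(X_min.shape, "->", X_aug.shape)  # (5, 3) -> (15, 3)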

Deep Learning Models and Imbalanced Datasets

Deep learning models, such as convolutional neural networks (CNNs) and recurrent neural networks (RNNs), have the capacity to learn complex patterns from large amounts of data.

However, when faced with class imbalance, these models can still struggle to perform optimally.

One approach to address class imbalance in deep learning models is to incorporate the aforementioned techniques into the training process.

Let's see how to implement class weights using Keras.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical

# Load a sample imbalanced dataset
X, y = make_classification(n_samples=1000, n_features=20, n_informative=2,
                           n_redundant=10, n_clusters_per_class=1,
                           weights=[0.9, 0.1], flip_y=0, random_state=42)

# Split the dataset into training and testing sets using stratification
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

# Convert labels to categorical (one-hot encoding)
y_train_cat = to_categorical(y_train)
y_test_cat = to_categorical(y_test)

# Calculate the class weights inversely proportional to class frequencies in the training set
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
class_weights_dict = dict(enumerate(class_weights))

# Build a simple neural network model
model = Sequential()
model.add(Dense(64, input_dim=X_train.shape[1], activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(y_train_cat.shape[1], activation='softmax'))

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model with class weights
model.fit(X_train, y_train_cat, epochs=50, batch_size=32, class_weight=class_weights_dict, validation_split=0.2)

# Evaluate the model on the test set
loss, accuracy = model.evaluate(X_test, y_test_cat)
print(f'Test Loss: {loss}')
print(f'Test Accuracy: {accuracy}')

# Accuracy alone can be misleading on imbalanced data, so also report per-class precision, recall, and F1
y_pred = np.argmax(model.predict(X_test), axis=1)
print(classification_report(y_test, y_pred))

Beyond class weights, we can apply resampling techniques to the training data before feeding it into the model, or modify the loss function to assign higher weights to the minority classes.

Another approach is to design the architecture of the deep learning model to be more robust to class imbalance.

This can involve using techniques like transfer learning, where a pre-trained model is fine-tuned on the imbalanced dataset, leveraging the learned features to improve performance on the minority classes.
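
A minimal sketch of this pattern for an image task, using a Keras pre-trained backbone; the backbone choice, input size, and class weights here are placeholders rather than recommendations:

import tensorflow as tf

# Pre-trained backbone with ImageNet weights, used as a frozen feature extractor
base = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet',
                                         input_shape=(224, 224, 3), pooling='avg')
base.trainable = False

# New classification head trained on the imbalanced target dataset
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')

# Fine-tune on the imbalanced data, still using class weights to counteract the imbalance
# (train_images and train_labels are placeholders for your own dataset)
# model.fit(train_images, train_labels, class_weight={0: 1.0, 1: 9.0}, epochs=5)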

Additionally, deep learning models can benefit from data augmentation techniques, such as generating synthetic samples of the minority classes using generative adversarial networks (GANs) or applying random transformations to existing samples.

These techniques can help increase the diversity and quantity of the minority class samples, providing the model with more information to learn from.

Conclusion

Class imbalance is a prevalent challenge in machine learning, where the performance of models can be significantly impacted by the distribution of samples across classes.

By understanding the challenges posed by class imbalance and employing appropriate techniques, such as using the right evaluation metrics, data-level methods, and algorithm-level methods, we can mitigate its effects and improve the performance of our models.

Deep learning models, with their ability to learn complex patterns, can benefit from these techniques to better handle imbalanced datasets.

By incorporating resampling, modifying loss functions, and leveraging techniques like transfer learning and data augmentation, deep learning models can be made more robust to class imbalance.

Ultimately, addressing class imbalance requires a combination of careful evaluation, data preprocessing, and algorithmic adjustments.

By applying these techniques and continuously monitoring and refining our models, we can build deep learning systems that perform well even in the presence of class imbalance, enabling more accurate and reliable predictions in real-world scenarios.

PS: If you like this article, share it with others ♻️

Would help a lot ❤️

And feel free to follow me for more articles like this.