The most surprising thing about tackling class imbalance in TensorFlow is that often, the "fix" isn’t about making your dataset look more balanced, but about teaching your model to care more about the rare examples.
Let’s see this in action. Imagine you have a dataset for detecting rare fraudulent transactions. Most transactions are legitimate (class 0), and only a tiny fraction are fraudulent (class 1).
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split
import numpy as np
# Simulate imbalanced data
num_samples = 10000
num_fraud = 100
num_legit = num_samples - num_fraud
# Features (dummy data)
X = np.random.rand(num_samples, 10)
# Labels: 0 for legitimate, 1 for fraud
y = np.zeros(num_samples, dtype=int)
fraud_indices = np.random.choice(num_samples, num_fraud, replace=False)
y[fraud_indices] = 1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
# Define a simple model
inputs = Input(shape=(10,))
x = Dense(32, activation='relu')(inputs)
outputs = Dense(1, activation='sigmoid')(x) # Sigmoid for binary classification
model = Model(inputs=inputs, outputs=outputs)
# --- Method 1: Class Weights ---
# Calculate weights
total_samples = len(y_train)
neg_count = np.sum(y_train == 0)
pos_count = np.sum(y_train == 1)
weight_for_0 = (1 / neg_count) * (total_samples / 2.0)
weight_for_1 = (1 / pos_count) * (total_samples / 2.0)
class_weight = {0: weight_for_0, 1: weight_for_1}
print(f"Class weights: {class_weight}")
# Compile the model (class weights are passed to fit(), not compile())
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
# Train with class weights
history_weighted = model.fit(X_train, y_train,
                             epochs=10,
                             batch_size=32,
                             validation_split=0.2,
                             class_weight=class_weight,
                             verbose=0)  # Suppress verbose output for brevity
print("Model trained with class weights.")
# --- Method 2: Oversampling (using imbalanced-learn for simplicity in demo) ---
# In a real scenario, you'd use libraries like imbalanced-learn
# from imblearn.over_sampling import SMOTE
# smote = SMOTE(random_state=42)
# X_train_smote, y_train_smote = smote.fit_resample(X_train, y_train)
# print(f"Original training data shape: {X_train.shape}, {y_train.shape}")
# print(f"SMOTE oversampled data shape: {X_train_smote.shape}, {y_train_smote.shape}")
# Note: For this example, we'll skip the actual re-training to keep it concise.
# The principle is to upsample the minority class to match the majority class size.
# --- Method 3: Focal Loss ---
def focal_loss(gamma=2., alpha=.25):
    def focal_loss_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        pt = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
        at = tf.where(tf.equal(y_true, 1), alpha, 1 - alpha)
        # Clip to epsilon to prevent log(0)
        epsilon = tf.keras.backend.epsilon()
        pt = tf.clip_by_value(pt, epsilon, 1. - epsilon)
        loss = -at * (1 - pt)**gamma * tf.math.log(pt)
        return tf.reduce_mean(loss)
    return focal_loss_fixed
# Build a fresh model for the focal loss run (re-using the earlier layers
# would share weights with the already-trained model)
inputs_f = Input(shape=(10,))
x_f = Dense(32, activation='relu')(inputs_f)
outputs_f = Dense(1, activation='sigmoid')(x_f)
model_focal = Model(inputs=inputs_f, outputs=outputs_f)
# Compile with Focal Loss
model_focal.compile(optimizer='adam',
                    loss=focal_loss(gamma=2., alpha=0.25),  # Example hyperparameters
                    metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
# Train with Focal Loss
history_focal = model_focal.fit(X_train, y_train,
                                epochs=10,
                                batch_size=32,
                                validation_split=0.2,
                                verbose=0)
print("Model trained with Focal Loss.")
# Evaluate (simplified for demo)
loss_w, acc_w, prec_w, recall_w = model.evaluate(X_test, y_test, verbose=0)
print(f"\nWeighted Model - Test Accuracy: {acc_w:.4f}, Precision: {prec_w:.4f}, Recall: {recall_w:.4f}")
loss_f, acc_f, prec_f, recall_f = model_focal.evaluate(X_test, y_test, verbose=0)
print(f"Focal Loss Model - Test Accuracy: {acc_f:.4f}, Precision: {prec_f:.4f}, Recall: {recall_f:.4f}")
This code snippet demonstrates three common strategies: Class Weights, Oversampling (conceptually shown), and Focal Loss.
The core problem class imbalance presents is that a model trained on imbalanced data will naturally favor the majority class. During training, the loss function is dominated by the errors on the frequent examples. The gradients will push the model to get those right, effectively ignoring the minority class. Accuracy becomes a misleading metric; a model predicting "not fraud" 99% of the time achieves 99% accuracy but is useless for fraud detection.
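To make this concrete, here is a tiny sketch (with made-up label counts) of how a "model" that always predicts the majority class scores:

```python
import numpy as np

# 990 legitimate (0) and 10 fraudulent (1) transactions -- made-up counts
y_true = np.array([0] * 990 + [1] * 10)
# A degenerate "model" that always predicts the majority class
y_pred = np.zeros_like(y_true)

accuracy = np.mean(y_true == y_pred)  # fraction of correct predictions
recall = np.sum((y_true == 1) & (y_pred == 1)) / np.sum(y_true == 1)

print(f"Accuracy: {accuracy:.2f}")  # 0.99 -- looks great
print(f"Recall:   {recall:.2f}")    # 0.00 -- catches zero fraud cases
```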
Class Weights
What it is: You tell the model how much to penalize errors for each class. You assign a higher weight to the minority class.
How it works: When calculating the loss, each sample’s contribution is multiplied by its class weight. For a minority class sample that’s misclassified, the resulting loss is higher, giving the model a stronger signal to correct that error.
Diagnosis/Implementation:
- Calculate class counts:
  neg_count = np.sum(y_train == 0)
  pos_count = np.sum(y_train == 1)
  total_samples = len(y_train)
- Compute weights: a common formula is weight_for_class_X = total_samples / (num_classes * count_for_class_X). For binary classification:
  weight_for_0 = (1 / neg_count) * (total_samples / 2.0)
  weight_for_1 = (1 / pos_count) * (total_samples / 2.0)
  class_weight = {0: weight_for_0, 1: weight_for_1}
- Apply during model.fit():
  model.fit(X_train, y_train, class_weight=class_weight, ...)
  This directly influences the loss calculation within the training loop.
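As a cross-check, scikit-learn ships the same "balanced" weighting heuristic; a minimal sketch with toy label counts (assuming scikit-learn is installed, as it already is for train_test_split):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_train = np.array([0] * 80 + [1] * 20)  # toy labels: 80 negative, 20 positive

# 'balanced' implements n_samples / (n_classes * class_count)
weights = compute_class_weight('balanced', classes=np.array([0, 1]), y=y_train)
class_weight = dict(enumerate(weights))

# The manual formula from the text gives the same result
manual_0 = (1 / 80) * (100 / 2.0)  # 0.625
manual_1 = (1 / 20) * (100 / 2.0)  # 2.5
print(class_weight)  # {0: 0.625, 1: 2.5}
```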
Oversampling and Undersampling
What it is: Modifying the training dataset itself to create a more balanced distribution. Oversampling duplicates minority class samples, while undersampling removes majority class samples.
How it works: By presenting a more balanced dataset to the model, the loss function is no longer skewed by the sheer number of majority class examples. Each class gets roughly equal influence on the gradient updates. Techniques like SMOTE (Synthetic Minority Over-sampling Technique) generate synthetic samples rather than just duplicating existing ones, which can be more effective.
Diagnosis/Implementation:
- Use a library like imbalanced-learn:
  from imblearn.over_sampling import SMOTE
  smote = SMOTE(random_state=42)
  X_train_resampled, y_train_resampled = smote.fit_resample(X_train, y_train)
- Train on the resampled data:
  model.fit(X_train_resampled, y_train_resampled, ...)
  The model sees an equal number of positive and negative examples during training.
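If you would rather avoid an extra dependency, plain random oversampling with NumPy is a minimal sketch of the same idea (it duplicates minority rows rather than synthesizing new ones as SMOTE does):

```python
import numpy as np

rng = np.random.default_rng(42)
X_train = rng.random((100, 10))              # toy features
y_train = np.array([0] * 90 + [1] * 10)      # 90 majority, 10 minority

minority_idx = np.where(y_train == 1)[0]
majority_idx = np.where(y_train == 0)[0]

# Sample minority indices with replacement up to the majority count
resampled_minority = rng.choice(minority_idx, size=len(majority_idx), replace=True)
keep = np.concatenate([majority_idx, resampled_minority])
rng.shuffle(keep)  # avoid feeding the model all-majority then all-minority

X_res, y_res = X_train[keep], y_train[keep]
print(X_res.shape, np.bincount(y_res))  # (180, 10) [90 90]
```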
Focal Loss
What it is: A modification of the standard cross-entropy loss function that down-weights easy, well-classified examples and focuses training on hard-to-classify ones.
How it works: Focal Loss introduces two parameters: alpha (similar to class weights, balancing class importance) and gamma. The gamma parameter is the key. It reshapes the standard cross-entropy loss. When gamma is greater than 0, the loss for well-classified examples (where the predicted probability pt is close to 1 for the correct class) is significantly reduced by the (1 - pt)**gamma term. This means the model’s gradients are primarily driven by misclassified examples, especially those the model is confidently wrong about, which are often the minority class samples.
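A quick numeric check of the modulating factor (1 - pt)**gamma makes this concrete:

```python
# Down-weighting factor (1 - pt)**gamma for gamma = 2
gamma = 2.0
for pt in [0.9, 0.6, 0.1]:
    factor = (1 - pt) ** gamma
    print(f"pt={pt}: factor={factor:.2f}")
# pt=0.9 (easy, well-classified): factor 0.01 -> loss shrunk ~100x
# pt=0.1 (confidently wrong):     factor 0.81 -> loss barely reduced
```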
Diagnosis/Implementation:
- Define the focal loss function:
  def focal_loss(gamma=2., alpha=.25):
      def focal_loss_fixed(y_true, y_pred):
          # ... (implementation as shown in the code block above)
          return tf.reduce_mean(loss)
      return focal_loss_fixed
- Compile the model with the custom loss:
  model.compile(optimizer='adam',
                loss=focal_loss(gamma=2., alpha=0.25),  # Tune gamma and alpha
                metrics=['accuracy'])
  The loss calculation during model.fit() now uses this specialized function.
The most common pitfall is forgetting to re-evaluate your model using metrics that are robust to imbalance, such as Precision, Recall, F1-score, or AUC, rather than just accuracy.
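A sketch of such an evaluation with scikit-learn, using toy stand-ins for y_test and the model's predicted probabilities:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

# Toy ground truth and predicted probabilities -- stand-ins for y_test
# and the output of model.predict(X_test)
y_test = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])
y_prob = np.array([0.1, 0.2, 0.1, 0.4, 0.3, 0.6, 0.7, 0.8, 0.4, 0.9])
y_pred = (y_prob >= 0.5).astype(int)  # threshold probabilities at 0.5

print(f"Precision: {precision_score(y_test, y_pred):.2f}")
print(f"Recall:    {recall_score(y_test, y_pred):.2f}")
print(f"F1:        {f1_score(y_test, y_pred):.2f}")
print(f"ROC AUC:   {roc_auc_score(y_test, y_prob):.2f}")  # AUC uses probabilities, not labels
```

Note that ROC AUC is computed from the raw probabilities, so it is independent of the 0.5 threshold; the other three metrics change if you move that threshold.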