The most surprising thing about W&B’s confusion matrix and PR curve logging is that they aren’t just static images; they’re live, interactive components that update with every new metric logged, revealing the evolving performance of your model in real-time.

Let’s see this in action. Imagine you’re training a multi-class image classifier. You’d typically log your predictions and ground truth labels for each batch or epoch.

import wandb
import torch
from sklearn.metrics import precision_recall_curve, auc
import numpy as np

# Assume you have your model, dataloader, etc.
# model = YourModel()
# test_loader = YourDataLoader()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# Initialize W&B
wandb.init(project="pr-curve-confusion-matrix-demo")

# Dummy data and model for demonstration
num_classes = 5
num_samples = 100
y_true = np.random.randint(0, num_classes, num_samples)
y_scores = np.random.rand(num_samples, num_classes)
y_scores = y_scores / y_scores.sum(axis=1, keepdims=True)  # Normalize rows into probabilities

# In a real scenario, you'd be iterating through your data loader
# and getting these y_true and y_scores from your model's predictions.
# For this example, we'll just use dummy data.

# Log the confusion matrix. wandb.plot.confusion_matrix computes the
# matrix for you from class labels, so there's no need to precompute it.
wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(
    probs=None,  # Alternatively, pass the (n_samples, n_classes) probability matrix here
    preds=np.argmax(y_scores, axis=1),
    y_true=y_true,
    class_names=[f"Class {i}" for i in range(num_classes)]
)})

# Log the PR curve. wandb.plot.pr_curve expects class labels and an
# (n_samples, n_classes) probability matrix, and draws one curve per class.
wandb.log({"pr_curve": wandb.plot.pr_curve(
    y_true=y_true,
    y_probas=y_scores,
    labels=[f"Class {i}" for i in range(num_classes)]
)})

# For a micro-average PR curve, compute it yourself with scikit-learn:
# flatten one-hot labels and the score matrix so every (sample, class)
# pair counts as one binary decision.
y_true_one_hot = np.eye(num_classes)[y_true]
precision, recall, thresholds = precision_recall_curve(
    y_true_one_hot.ravel(), y_scores.ravel()
)
pr_auc = auc(recall, precision)  # Note: auc(x=recall, y=precision) is the standard order

# Log the micro-average curve as a custom line plot, plus its AUC as a scalar
pr_table = wandb.Table(
    data=[[float(r), float(p)] for r, p in zip(recall, precision)],
    columns=["recall", "precision"],
)
wandb.log({
    "micro_pr_curve": wandb.plot.line(
        pr_table, "recall", "precision",
        title=f"Micro-Average PR Curve (AUC = {pr_auc:.3f})",
    ),
    "micro_pr_auc": pr_auc,
})

wandb.finish()

The wandb.plot.confusion_matrix and wandb.plot.pr_curve functions are your primary tools. They take raw prediction data and labels (or probabilities) and render interactive plots directly in your W&B dashboard. The confusion matrix visually highlights where your model is making mistakes – which classes are being confused for others. The PR curve, especially the micro-averaged one, gives you a single metric (AUC) that summarizes performance across all classes, showing the trade-off between precision and recall at various probability thresholds.
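To see that precision-recall trade-off concretely, here is a minimal scikit-learn sketch on a small, made-up binary example (the labels and scores are dummy values, not from the run above): each threshold in the returned arrays pairs one precision value with one recall value.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Dummy binary labels and scores for illustration
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_scores = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2, 0.9, 0.6])

precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

# Raising the threshold generally trades recall for precision
for p, r, t in zip(precision, recall, thresholds):
    print(f"threshold={t:.2f}  precision={p:.2f}  recall={r:.2f}")
```

Note that scikit-learn returns one more (precision, recall) pair than thresholds: the final point is always (1.0, 0.0), anchoring the curve.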

The magic happens because W&B doesn’t just store a static image. When you log these plots, W&B’s frontend processes the data and renders an interactive visualization. This means you can hover over cells in the confusion matrix to see the exact counts, or hover along the PR curve to see the corresponding precision and recall values at specific thresholds. This interactivity is crucial for deep dives into model performance.

Internally, W&B renders these plots as custom charts. The wandb.plot module acts as a convenient wrapper: it takes your NumPy arrays (or PyTorch tensors), performs the necessary transformations (computing the confusion matrix from predictions and labels, or the per-class precision/recall points for PR curves), packs the result into a wandb.Table, and pairs it with a Vega-Lite chart specification that the dashboard renders interactively. For the confusion matrix, it expects y_true plus either preds (predicted class labels) or probs (an (n_samples, n_classes) probability matrix, from which predictions are derived via argmax). For the PR curve, it needs y_true and y_probas (per-class probabilities or scores).

The key levers you control are the data you feed into these logging functions and the class-name arguments (class_names for the confusion matrix, labels for the PR curve). Providing accurate names makes the plots immediately interpretable. For multi-class PR curves you can log per-class curves, a micro-average (computed over flattened one-hot labels and scores), or a macro-average (the unweighted mean of per-class curves). The choice depends on what aspect of performance you want to emphasize: micro-averaging gives equal weight to each instance, so frequent classes dominate the summary; macro-averaging gives equal weight to each class, which is useful when performance on rare classes matters as much as on common ones.
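A macro-average is straightforward to compute yourself: calculate a PR AUC per class with scikit-learn, then take the unweighted mean. A minimal sketch on dummy data (the variable names like per_class_auc are illustrative, not part of any API):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, auc

num_classes = 5
rng = np.random.default_rng(0)
y_true = rng.integers(0, num_classes, 100)
y_scores = rng.random((100, num_classes))

y_true_one_hot = np.eye(num_classes)[y_true]

# One PR AUC per class, then the unweighted mean (macro-average)
per_class_auc = []
for c in range(num_classes):
    precision, recall, _ = precision_recall_curve(y_true_one_hot[:, c], y_scores[:, c])
    per_class_auc.append(auc(recall, precision))

macro_pr_auc = float(np.mean(per_class_auc))
# You could then log this alongside the micro-average:
# wandb.log({"macro_pr_auc": macro_pr_auc})
print(f"macro PR AUC: {macro_pr_auc:.3f}")
```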

A detail that trips many people up is how wandb.plot.pr_curve handles multi-class data: given y_true class labels and an (n_samples, n_classes) y_probas matrix, it plots one curve per class (optionally restricted via classes_to_plot); it does not compute micro- or macro-averages for you. If you want an averaged curve, calculate the metrics yourself with scikit-learn and log the resulting precision, recall, and AUC values, for example as a wandb.Table rendered with wandb.plot.line.
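Computing per-class curves yourself comes down to collecting (class, recall, precision) rows, which map directly onto a wandb.Table. A minimal sketch on dummy data (the wandb calls at the end are shown as comments; "per_class_pr" is an arbitrary panel name):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

num_classes = 3
rng = np.random.default_rng(1)
y_true = rng.integers(0, num_classes, 60)
y_scores = rng.random((60, num_classes))
y_true_one_hot = np.eye(num_classes)[y_true]

# One (class, recall, precision) row per point on each per-class curve
rows = []
for c in range(num_classes):
    precision, recall, _ = precision_recall_curve(y_true_one_hot[:, c], y_scores[:, c])
    rows.extend((f"Class {c}", float(r), float(p)) for p, r in zip(precision, recall))

# Logged as a grouped line plot, one stroke (line color) per class:
# table = wandb.Table(data=rows, columns=["class", "recall", "precision"])
# wandb.log({"per_class_pr": wandb.plot.line(table, "recall", "precision", stroke="class")})
print(f"{len(rows)} curve points across {num_classes} classes")
```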

The next concept you’ll likely encounter is logging ROC curves and understanding their relationship with PR curves, especially in the context of imbalanced datasets.
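As a preview of why that relationship matters: on a heavily imbalanced dataset, ROC AUC can look comfortable while average precision (the PR-side summary) is much lower, because ROC AUC is propped up by the abundant true negatives. A minimal scikit-learn sketch with synthetic data (the ~2% positive rate and score distributions are made up for illustration):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

rng = np.random.default_rng(42)
n = 20000

# Heavily imbalanced binary problem: roughly 2% positives
y_true = (rng.random(n) < 0.02).astype(int)
# A moderately useful scorer: positives drawn from a shifted distribution
y_scores = rng.normal(loc=y_true.astype(float), scale=1.0)

roc_auc = roc_auc_score(y_true, y_scores)
ap = average_precision_score(y_true, y_scores)

# ROC AUC is inflated by the many easy true negatives;
# average precision reflects performance on the rare positive class.
print(f"ROC AUC: {roc_auc:.3f}  Average precision: {ap:.3f}")
```

The gap between the two numbers is exactly why PR curves are the recommended lens for rare-event problems.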

Want structured learning?

Take the full Wandb course →