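Weights & Biases (W&B) integrations for Keras and PyTorch Lightning plug into each framework’s callback system. They log more than metrics: they record, side by side with your loss curves, the decisions that schedulers and callbacks make during training, such as learning rate reductions, checkpoints, and early stops.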

Let’s see this in action with a quick PyTorch Lightning example. Imagine you’re training a model and want to automatically adjust the learning rate when the validation loss plateaus.

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import wandb
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# --- Setup ---
# Load a pre-trained model and tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Create dummy data
texts = ["This is a positive review.", "This is a negative review."] * 100
labels = [1, 0] * 100
encoded_inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

dataset = TensorDataset(encoded_inputs["input_ids"], encoded_inputs["attention_mask"], torch.tensor(labels))
train_loader = DataLoader(dataset, batch_size=4)
val_loader = DataLoader(dataset, batch_size=4) # Using same for demo

# --- PyTorch Lightning Module ---
class TextClassifier(pl.LightningModule):
    def __init__(self, model, learning_rate=1e-5):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate
        self.save_hyperparameters(ignore=["model"]) # Don't store the nn.Module itself as a hyperparameter

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask)

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch
        outputs = self(input_ids, attention_mask)
        loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch
        outputs = self(input_ids, attention_mask)
        loss = torch.nn.functional.cross_entropy(outputs.logits, labels)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        # ReduceLROnPlateau is a scheduler, not a callback; Lightning steps it with the monitored metric
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=2,
            min_lr=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "monitor": "val_loss" # Crucial: monitor validation loss
        }

# --- W&B Integration ---
# Initialize W&B run
run = wandb.init(project="pl-keras-callback-demo", job_type="training")

# Instantiate W&B logger
wandb_logger = WandbLogger(project="pl-keras-callback-demo", log_model="all")

# Instantiate other callbacks
lr_monitor = LearningRateMonitor(logging_interval='step') # Logs LR every step
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", filename="{epoch}-{val_loss:.2f}")

# --- Training ---
model_pl = TextClassifier(model)
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="auto", # Use GPU if available
    devices=1,
    logger=wandb_logger,
    callbacks=[lr_monitor, checkpoint_callback] # With log_model="all", WandbLogger uploads these checkpoints
)

trainer.fit(model_pl, train_loader, val_loader)

# Finish W&B run
run.finish()

In this PyTorch Lightning example, the configure_optimizers method is key. It returns a dictionary containing an optimizer, an lr_scheduler, and a monitor key naming the metric the scheduler should watch, here "val_loss". During trainer.fit, PyTorch Lightning passes the val_loss logged in validation_step to torch.optim.lr_scheduler.ReduceLROnPlateau once per epoch. With patience=2, the scheduler tolerates two epochs without improvement and multiplies the learning rate by factor=0.1 if the third shows no improvement either, never going below min_lr=1e-6. The LearningRateMonitor callback reads the current learning rate from the optimizer and sends it to the WandbLogger, so each reduction appears in the run next to train_loss, val_loss, and any other metrics you log.
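To make the patience rule concrete, here is a minimal pure-Python sketch of the plateau logic. It is a simplified stand-in rather than PyTorch's implementation: the class name PlateauRule is hypothetical, and the threshold and cooldown options of ReduceLROnPlateau are left out.

```python
class PlateauRule:
    """Simplified 'min'-mode plateau rule (hypothetical helper, not the torch class)."""

    def __init__(self, lr, factor=0.1, patience=2, min_lr=1e-6):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        # An epoch counts as "bad" when it does not beat the best loss so far.
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only once MORE than `patience` bad epochs have accumulated.
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0
        return self.lr


rule = PlateauRule(lr=1e-5)
val_losses = [1.0, 0.8, 0.8, 0.8, 0.8, 0.7]
lr_history = [rule.step(loss) for loss in val_losses]
print(lr_history)
```

The learning rate stays at its initial value through the first two stalled epochs and drops on the third, which is the kind of step you would expect to see in the learning rate chart of the run.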

The core problem the W&B integrations solve is bridging the gap between your training loop and W&B’s monitoring tools. Instead of manually writing code to send metrics to W&B after each epoch, you declare your intent (e.g., "reduce LR if val_loss plateaus") with the framework’s own schedulers and callbacks in Keras or PyTorch Lightning, and the W&B integration records what happened. The framework makes the decision; W&B makes it visible. Your training script stays focused on the model’s logic, while the callbacks handle the operational aspects.

For Keras, the integration is similarly seamless. You add WandbCallback to the callbacks list passed to model.fit():

import tensorflow as tf
from tensorflow import keras
import wandb
from wandb.integration.keras import WandbCallback

# Assuming you have a Keras model `model`, data `x_train`, `y_train`, etc.

# Initialize W&B run
run = wandb.init(project="keras-callback-demo", job_type="training")

# Define Keras callbacks
# WandbCallback logs metrics, hyperparameters, and the model architecture.
# EarlyStopping and ReduceLROnPlateau are separate Keras callbacks that run alongside it.
keras_callbacks = [
    WandbCallback(
        log_weights=True, # Log histograms of layer weights
        log_gradients=True, # Log histograms of gradients (requires training_data)
        training_data=(x_train, y_train), # Data used to compute the logged gradients
        save_model=True, # Save the best model based on validation loss
        validation_steps=50 # Only used when validation_data is a generator
    ),
    keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=2, min_lr=1e-6)
]

# Compile the model (if not already compiled)
# model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Train the model
# model.fit(x_train, y_train, epochs=20, validation_data=(x_val, y_val), callbacks=keras_callbacks)

# Finish W&B run
run.finish()

Here, WandbCallback does not wrap Keras’s built-in callbacks; it runs alongside them in the same callbacks list and adds W&B-specific logging. log_weights=True and log_gradients=True send histograms of the model’s weights and gradients to W&B. EarlyStopping and ReduceLROnPlateau keep making their own decisions, and because WandbCallback records what Keras reports at the end of each epoch, the effects of those decisions show up in the run: the logged curves end where training stopped, and the loss curve reflects each learning rate adjustment. This lets you work out in the W&B UI when training stopped and how the run behaved around each adjustment.

The mental model to hold is that callbacks are hooks into the training lifecycle. The framework calls them at fixed events (such as on_epoch_end or on_train_batch_end), and each callback acts according to its configuration. Framework callbacks such as EarlyStopping and ReduceLROnPlateau actively modify training behavior, while the W&B callbacks observe the same events and send the resulting data to W&B. Used together, they give you control and a record of that control in one place.
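The hook model can be sketched in a few lines of plain Python. Everything here is hypothetical and only illustrates the pattern: Callback, RecordingLogger, EarlyStop, and fit are stand-ins, not Keras, Lightning, or W&B classes, and the shared state dictionary is an assumption made to keep the example small.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_epoch_end(self, epoch, logs, state):
        pass


class RecordingLogger(Callback):
    """Stands in for a W&B-style logger: it only records what it is given."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, logs, state):
        self.history.append({"epoch": epoch, **logs})


class EarlyStop(Callback):
    """Stands in for a framework callback that changes training behavior."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def on_epoch_end(self, epoch, logs, state):
        if logs["val_loss"] < self.best:
            self.best = logs["val_loss"]
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience:
                state["stop"] = True  # Ask the loop to stop after this epoch


def fit(val_losses, callbacks):
    """Toy training loop: one precomputed validation loss per epoch."""
    state = {"stop": False}
    for epoch, val_loss in enumerate(val_losses):
        logs = {"val_loss": val_loss}
        for cb in callbacks:  # Every callback sees the same event
            cb.on_epoch_end(epoch, logs, state)
        if state["stop"]:
            break
    return state


logger = RecordingLogger()
final_state = fit([0.9, 0.7, 0.8, 0.75, 0.6], [logger, EarlyStop(patience=2)])
print([entry["epoch"] for entry in logger.history])
```

The logger never decides anything, yet its history ends exactly where EarlyStop halted the loop, which is the relationship between W&B logging and framework callbacks described above.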

The most surprising part is how much you can offload to these integrations. Once a run is initialized, W&B records system resource utilization (CPU, GPU, RAM) during training. You can also log model predictions on validation batches and store datasets or model checkpoints as versioned W&B Artifacts, all with minimal code changes within your existing Keras or PyTorch Lightning structure. This gives you a much richer view of the training process than metrics alone.

The next step is to explore custom W&B callbacks, allowing you to define entirely new training behaviors and logging strategies tailored to your specific research needs.
