The most surprising thing about W&B gradient logging is that it’s not just for visualizing gradients; it’s your first line of defense against training instability, often revealing issues before your loss tanks or your metrics flatline.

Let’s see it in action. Imagine you’re training a ResNet-50 and seeing erratic loss spikes. (The demo below uses CIFAR-10 in place of ImageNet so it runs quickly; the workflow is identical.) You’ve got W&B integrated, so you’ve already set up wandb.init() and are logging your model’s gradients. Your training script might look something like this:

import wandb
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10 # Using CIFAR10 for a quicker demo
from torchvision import transforms

# Initialize W&B
wandb.init(project="gradient-debugging-demo", config={
    "learning_rate": 0.01,
    "epochs": 10,
    "batch_size": 64
})

# A ResNet-50 trained from scratch (weights=None means no pre-trained weights;
# the older pretrained= argument is deprecated)
model = resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10) # CIFAR-10 has 10 classes

# Hook the model so W&B records gradient histograms during backward passes
wandb.watch(model, log="gradients", log_freq=10)

# Data loading and preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(Subset(train_dataset, range(1000)), batch_size=wandb.config.batch_size, shuffle=True)

# Optimizer and Loss
optimizer = optim.SGD(model.parameters(), lr=wandb.config.learning_rate, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(wandb.config.epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()

        # Log the loss and per-parameter gradient norms in a single
        # wandb.log call so everything stays aligned on the same step axis
        # (each wandb.log call advances W&B's step counter by default).
        # You can also log the raw gradients for more detailed analysis.
        metrics = {"loss": loss.item()}
        for name, param in model.named_parameters():
            if param.grad is not None:
                metrics[f"grad_norm/{name}"] = torch.norm(param.grad).item()
        wandb.log(metrics)

        optimizer.step()

        running_loss += loss.item()
        if i % 10 == 9:    # print every 10 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 10:.3f}')
            running_loss = 0.0

    print(f"Epoch {epoch+1} finished.")

wandb.finish()

When you run this, your W&B dashboard will show a chart for each grad_norm/ metric, and, because the model is hooked with wandb.watch(model, log="gradients"), a gradients section with per-layer histograms. These aren’t static plots; you can see how the distributions evolve over training steps.

The core problem gradient logging helps solve is training instability, which manifests as exploding or vanishing gradients. Exploding gradients are characterized by extremely large gradient values, often inf or NaN, that cause large, erratic weight updates and a diverging loss. Vanishing gradients, on the other hand, are extremely small values that result in minimal weight updates, preventing the model from learning effectively, especially in the layers farthest from the output of a deep network.

Here’s how it works internally: during the backward pass (loss.backward()), PyTorch computes the gradient of the loss with respect to each parameter. If you call wandb.watch, W&B registers hooks that capture these gradients as they are produced; alternatively, you can read param.grad yourself after loss.backward() and before optimizer.step(), as in the example. Either way, wandb.log() sends the information to your W&B project, where it’s aggregated and visualized. You can log the raw gradient tensors, their norms (as in the example), or histograms of gradient values.

The key levers you control are:

  1. What to log: You decide which parameters’ gradients to track and what metric (norm, mean, max, histogram) to log. Logging norms is a good starting point for detecting magnitude issues.
  2. Logging frequency: You can log gradients every step, every epoch, or at custom intervals. More frequent logging gives finer-grained insights but increases W&B’s data transfer.
  3. Gradient clipping: If you detect exploding gradients, you can use torch.nn.utils.clip_grad_norm_() or torch.nn.utils.clip_grad_value_() before optimizer.step() to cap their magnitude.
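The three levers compose naturally in a training loop. Here’s a minimal sketch (LOG_EVERY and MAX_NORM are illustrative values, and the tiny linear model just stands in for a real network):

```python
import torch
import torch.nn as nn

LOG_EVERY = 10   # lever 2: how often to log
MAX_NORM = 1.0   # lever 3: clipping threshold

model = nn.Linear(8, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(16, 8), torch.randn(16, 4)

for step in range(20):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()

    # Lever 3: cap gradient magnitude; returns the total norm *before* clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_NORM)

    # Levers 1 and 2: choose what to log, and how often
    if step % LOG_EVERY == 0:
        metrics = {f"grad_norm/{n}": p.grad.norm().item()
                   for n, p in model.named_parameters() if p.grad is not None}
        metrics["grad_norm/total_pre_clip"] = float(total_norm)
        # wandb.log(metrics, step=step)  # one call keeps metrics step-aligned
    opt.step()
```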

Now, let’s dive into the common causes of instability and how to spot them with W&B gradient logging.

Common Causes of Training Instability

1. Learning Rate Too High:

  • Diagnosis: You’ll see gradient norms suddenly spike to very large values (e.g., 1e5, 1e10) or even inf/NaN. The loss plot will likely show sharp upward spikes or complete divergence.
  • Command/Check: In your W&B dashboard, go to the "Gradients" tab. Look for any parameter’s gradient norm that suddenly goes from a reasonable value (e.g., 1e-2 to 1e-4) to extremely large numbers or inf.
  • Fix: Reduce the learning rate. For example, if you’re using 0.01, try 0.001 or 0.0001.
  • Why it works: A high learning rate can cause the optimizer to overshoot the minimum of the loss function. With large gradients, each step taken is too big, leading to oscillation or divergence. Reducing the learning rate makes these steps smaller and more controlled.
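As a stopgap while you retune the learning rate, you can also guard the update step in code. This sketch (not part of the script above; the guard pattern is a common defensive idiom, not a W&B feature) skips optimizer.step() whenever any gradient is non-finite, so a single bad batch can’t destroy the weights while you investigate the spike:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
x, y = torch.randn(8, 4), torch.randn(8, 2)

loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Skip the update if any gradient is inf/NaN
grads_finite = all(torch.isfinite(p.grad).all() for p in model.parameters()
                   if p.grad is not None)
if grads_finite:
    opt.step()
opt.zero_grad()
```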

2. Poor Initialization:

  • Diagnosis: Vanishing gradients appear from the very first steps, particularly in the layers farthest from the output or in networks using saturating activations like sigmoid/tanh. Gradient norms for many parameters will be close to zero and may not increase significantly even after many training steps.
  • Command/Check: In the "Gradients" tab, observe the norms of parameters in earlier layers (e.g., layer1.0.conv1.weight) or weights of final linear layers. If they are consistently near zero and don’t change much, it’s a sign.
  • Fix: Use better weight-initialization schemes such as Kaiming (He) initialization for ReLU activations or Xavier (Glorot) initialization for sigmoid/tanh. In PyTorch, layers like nn.Linear and nn.Conv2d have no initialization argument in their constructors; instead, apply functions from torch.nn.init to an already-constructed layer, e.g. nn.init.kaiming_normal_(layer.weight) or nn.init.xavier_uniform_(layer.weight).
  • Why it works: Proper initialization helps ensure that the variance of activations and gradients remains roughly constant across layers, preventing them from shrinking or growing exponentially as they propagate through the network.
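The idiomatic way to re-initialize a whole model is a small function passed to model.apply(), which visits every submodule. A minimal sketch:

```python
import torch
import torch.nn as nn

def init_weights(m: nn.Module) -> None:
    # Kaiming (He) normal suits ReLU nets; use nn.init.xavier_uniform_
    # instead for sigmoid/tanh activations.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model.apply(init_weights)  # applies init_weights to every submodule
```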

3. Exploding Gradients due to Network Architecture (e.g., RNNs, deep CNNs):

  • Diagnosis: Similar to a high learning rate, but more inherent to the network structure. Gradients can spontaneously become very large, even with a moderate learning rate. This is common in Recurrent Neural Networks (RNNs) due to repeated matrix multiplications over time steps.
  • Command/Check: W&B gradient norms will show sudden, massive spikes. Histograms might show a long tail of very large values.
  • Fix: Implement gradient clipping. For example, before optimizer.step(), add:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    (The max_norm value is a hyperparameter; 1.0 is a common starting point).
  • Why it works: Gradient clipping limits the maximum magnitude of the gradients, preventing them from becoming excessively large and destabilizing the training process. It doesn’t prevent learning; it just caps the size of the update step.
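A useful detail: clip_grad_norm_ returns the total norm computed before clipping, so you can log it to see how hard clipping is working. In this sketch, the loss is multiplied by 1e6 purely to force an oversized gradient for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)
x, y = torch.randn(4, 16), torch.randn(4, 16)

loss = nn.functional.mse_loss(model(x), y) * 1e6  # inflated to force huge grads
loss.backward()

# Returns the total norm *before* clipping; afterwards the gradients are
# rescaled in place so their combined norm is at most max_norm.
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post_clip = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
# wandb.log({"grad_norm/pre_clip": float(pre_clip)})  # spot how often clipping fires
```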

4. Activation Functions:

  • Diagnosis: Vanishing gradients, especially in deep networks, can be exacerbated by activation functions like sigmoid or tanh, which have saturated regions where their derivatives are close to zero. If you see gradient norms consistently decreasing through layers or staying very small, this could be a culprit.
  • Command/Check: Examine gradient norms layer by layer. If norms are significantly smaller in layers that follow activations like nn.Sigmoid() or nn.Tanh(), that’s a clue.
  • Fix: Switch to activation functions with non-saturating gradients in their typical operating range, such as ReLU (nn.ReLU()) and its variants (Leaky ReLU, PReLU).
  • Why it works: ReLU and its variants have a derivative of 1 for positive inputs, which helps gradients propagate more effectively through deep networks without diminishing.
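You can see the effect in miniature with a deep stack of identical layers (the depth and widths here are arbitrary): with the same weights, the first layer’s gradient under sigmoid typically comes out orders of magnitude smaller than under ReLU.

```python
import torch
import torch.nn as nn

def first_layer_grad_norm(act: nn.Module) -> float:
    # Build a deep stack of identical small layers with the given activation.
    torch.manual_seed(0)  # same weights and input for both runs
    layers = []
    for _ in range(20):
        layers += [nn.Linear(32, 32), act]
    net = nn.Sequential(*layers)
    net(torch.randn(8, 32)).sum().backward()
    return net[0].weight.grad.norm().item()

sig_norm = first_layer_grad_norm(nn.Sigmoid())
relu_norm = first_layer_grad_norm(nn.ReLU())
```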

5. Batch Normalization Issues:

  • Diagnosis: While batch norm usually helps stabilize training, misconfigurations or using it without sufficient batch size can sometimes lead to noisy or unstable gradients. Gradients might fluctuate wildly between batches.
  • Command/Check: Look for extreme variance in gradient norms from one batch to the next, especially for parameters associated with Batch Norm layers.
  • Fix: Ensure you have a sufficiently large batch size (e.g., 32 or 64) when using Batch Norm. If issues persist, consider alternatives such as Group Norm or Layer Norm, or carefully tune Batch Norm’s eps and momentum parameters.
  • Why it works: Batch Norm normalizes activations by their batch statistics. If the batch size is too small, these statistics can be noisy, leading to unstable gradient estimates.
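When small batches are unavoidable, one alternative is Group Norm, which computes statistics per sample rather than per batch. This sketch (the bn_to_gn helper is illustrative, not a library function) swaps every BatchNorm2d for a GroupNorm over the same channels:

```python
import torch
import torch.nn as nn

def bn_to_gn(module: nn.Module, num_groups: int = 8) -> nn.Module:
    # Recursively replace every BatchNorm2d with a GroupNorm
    # (num_groups must divide the channel count).
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(num_groups, child.num_features))
        else:
            bn_to_gn(child, num_groups)
    return module

net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
net = bn_to_gn(net)
out = net(torch.randn(1, 3, 8, 8))  # works even with batch size 1
```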

6. Numerical Precision Issues (e.g., FP16 training):

  • Diagnosis: When using mixed precision (e.g., torch.cuda.amp), gradients can underflow to zero if the loss is very small or if intermediate calculations result in values below the representable range of FP16. You’ll see gradients suddenly becoming NaN or zero for many parameters.
  • Command/Check: In W&B, check for NaN values in gradient norms or histograms showing a large spike at zero. This is particularly common if you’re not using gradient scaling.
  • Fix: Ensure you are using gradient scaling when training with FP16. torch.cuda.amp.GradScaler handles this automatically: it scales the loss up before backward() and unscales the gradients before the optimizer step.
  • Why it works: Gradient scaling artificially increases the magnitude of gradients during the backward pass, allowing them to be represented accurately in FP16. The final gradients are then scaled back down before the optimizer step.
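A minimal mixed-precision loop with GradScaler might look like this sketch (it falls back to plain FP32 on CPU, where the scaler and autocast are simply disabled):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(32, 10).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
# GradScaler multiplies the loss before backward() so FP16 gradients don't
# underflow; all of its methods are no-ops when enabled=False.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 32, device=device)
y = torch.randint(0, 10, (8,), device=device)
for _ in range(3):
    opt.zero_grad()
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(opt)           # recover true gradients for logging/clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(opt)               # skips the step if gradients are inf/NaN
    scaler.update()
```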

When you address these issues, your next immediate problem will likely be fine-tuning hyperparameters like learning rate schedules or regularization strength to achieve optimal convergence.

Want structured learning?

Take the full Wandb course →