Logging metrics from multiple GPUs in a distributed training setup can feel like trying to get a single, coherent story from a room full of people shouting at once.

Here’s what that looks like in practice, using PyTorch’s DistributedDataParallel (DDP) and Weights & Biases.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import wandb
import os

# --- Basic Setup ---
# This part assumes you've already set up your distributed environment
# (e.g., using torch.distributed.launch or torchrun)
# RANK = int(os.environ['RANK'])
# WORLD_SIZE = int(os.environ['WORLD_SIZE'])
# MASTER_ADDR = os.environ['MASTER_ADDR']
# MASTER_PORT = os.environ['MASTER_PORT']

# For demonstration, fall back to a single-process group when no launcher
# (e.g., torchrun) has set the environment. setdefault leaves launcher-provided
# values untouched. The fallback world size is 1 because init_process_group
# blocks until WORLD_SIZE processes have joined.
os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
if not dist.is_initialized():
    dist.init_process_group("nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE']))

RANK = dist.get_rank()
WORLD_SIZE = dist.get_world_size()

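# --- Model and Data ---
# Use the local rank (the GPU index on this node) as the device index; the
# global RANK only matches it on a single node. torchrun sets LOCAL_RANK.
LOCAL_RANK = int(os.environ.get('LOCAL_RANK', RANK))
device = torch.device(f'cuda:{LOCAL_RANK}')
torch.cuda.set_device(device)

model = nn.Linear(10, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Wrap model with DDP
model = nn.parallel.DistributedDataParallel(model, device_ids=[LOCAL_RANK])

# Dummy data - in a real scenario, this would come from a DataLoader
# that uses a DistributedSampler
dummy_data = torch.randn(16, 10, device=device)
dummy_labels = torch.randint(0, 2, (16,), device=device)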

# --- Weights & Biases Initialization ---
# Initialize wandb on rank 0 only, so the whole job produces a single run.
# (Initializing on every rank creates one run per process, which is less
# common for metrics.)
if RANK == 0:
    wandb.init(project="wandb-distributed-demo",
               config={"learning_rate": 0.01, "epochs": 5, "batch_size": 16})
    print("Wandb initialized on rank 0.")
else:
    # Other ranks never call wandb.init or wandb.log. They only take part in
    # the collective operations that produce the values rank 0 logs.
    print(f"Rank {RANK} is running. Wandb logging will be handled by rank 0.")

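# --- Training Loop ---
# wandb.config only exists on rank 0 (the only rank that called wandb.init),
# so read the epoch count from a plain constant that every rank can see.
NUM_EPOCHS = 5  # matches "epochs" in the wandb config above
for epoch in range(NUM_EPOCHS):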
    optimizer.zero_grad()
    outputs = model(dummy_data)
    loss = criterion(outputs, dummy_labels)

    # In DDP, loss.backward() is automatically synchronized.
    loss.backward()
    optimizer.step()

    # --- Logging Metrics ---
    # Each rank holds only the loss for its own share of the data. Every rank
    # must take part in the collective, so gather *before* the rank check.
    # NCCL collectives need CUDA tensors, so the loss stays on the GPU.
    loss_tensor = loss.detach().clone()

    # One placeholder per rank; all_gather fills slot i with rank i's loss.
    all_losses = [torch.zeros_like(loss_tensor) for _ in range(WORLD_SIZE)]
    dist.all_gather(all_losses, loss_tensor)

    # Only rank 0 logs, to avoid duplicate entries in the run.
    if RANK == 0:
        # Stack the per-rank scalars and average them for the global mean.
        mean_loss = torch.stack(all_losses).mean().item()

        wandb.log({"epoch": epoch, "loss": mean_loss, "gpu_loss": loss.item()})
        print(f"Rank {RANK}: Epoch {epoch}, Loss: {mean_loss:.4f}, GPU Loss: {loss.item():.4f}")
    else:
        print(f"Rank {RANK}: Epoch {epoch}, Loss calculated (not logged).")


if RANK == 0:
    wandb.finish()
    print("Wandb finished on rank 0.")

The core idea is that wandb.log has no awareness of other processes: each call records only the values held by the process that makes it. If every GPU calls wandb.log independently, you’ll get duplicate logs, and none of them will represent the global state of your training. The standard practice is to have one process (usually rank 0) responsible for all wandb.log calls.

To get metrics that represent the entire distributed batch (e.g., the average loss across all GPUs), you need to combine the per-rank values with a collective such as dist.all_gather, dist.all_reduce, or dist.reduce before logging.
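When only the average is needed, a single all-reduce involves less bookkeeping than gathering a list. The following is a minimal sketch rather than the code from the example: `global_mean` is a hypothetical helper name, and it assumes a process group is already initialized and that the tensor lives on a device the backend supports (CUDA for NCCL).

```python
import torch
import torch.distributed as dist


def global_mean(local_value: torch.Tensor) -> float:
    """Average a scalar tensor across all ranks (hypothetical helper).

    Every rank must call this; the returned value is the same on all of them.
    """
    # Work on a copy so the caller's tensor is not modified in place.
    value = local_value.detach().clone()
    # Sum across ranks, then divide by the number of ranks.
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return (value / dist.get_world_size()).item()
```

Call it on every rank, for example `mean_loss = global_mean(loss)`, and keep the wandb.log call itself behind the rank 0 check.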

In the example above, we perform dist.all_gather on the loss tensor. Each GPU sends its computed loss to every other GPU, so every rank ends up with the full list of per-rank losses. Rank 0 then averages that list and logs the mean. We also log gpu_loss to show the individual loss on rank 0.

This pattern ensures that:

  1. No duplicate logs: Only rank 0 calls wandb.log.
  2. Synchronized metrics: Distributed communication primitives ensure that the logged metrics reflect the state across all participating GPUs.
  3. Correct averaging: You can aggregate per-GPU results to get a global view.
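Point 3 has one catch: averaging the per-GPU means is only exact when every rank processed the same number of samples. With an uneven final batch, a safer approach is to aggregate a loss sum and a sample count, then divide once at the end. The sketch below shows just that arithmetic in plain Python, with the per-rank values passed in as a list; in a real job the two totals would come from a collective such as dist.all_reduce. The function name and input format are illustrative assumptions.

```python
from typing import Sequence, Tuple


def sample_weighted_mean(per_rank: Sequence[Tuple[float, int]]) -> float:
    """Global mean loss from (mean_loss, num_samples) pairs, one per rank."""
    # Recover each rank's loss sum, then divide by the total sample count.
    total_loss = sum(mean * count for mean, count in per_rank)
    total_count = sum(count for _, count in per_rank)
    if total_count == 0:
        raise ValueError("no samples were processed on any rank")
    return total_loss / total_count


# Made-up numbers: rank 0 saw 16 samples (mean loss 0.50),
# rank 1 saw only 4 samples (mean loss 1.00).
per_rank = [(0.50, 16), (1.00, 4)]
naive = sum(mean for mean, _ in per_rank) / len(per_rank)
weighted = sample_weighted_mean(per_rank)
print(f"naive={naive:.2f} weighted={weighted:.2f}")
# prints: naive=0.75 weighted=0.60
```

The naive mean of means overweights the small batch on rank 1; the weighted version matches what a single GPU would have reported for all 20 samples.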

You’ll notice that both wandb.init and wandb.log are guarded by if RANK == 0:, so only one process ever talks to Weights & Biases. The dist.all_gather call, however, must be made by all ranks; a collective that only some ranks enter will hang while it waits for the others.
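As the number of logging call sites grows, the rank check can be factored into a decorator. This is one possible sketch, not part of the wandb API: `rank_zero_only` and `log_metrics` are hypothetical names, the rank is read from the RANK environment variable that torchrun sets, and a plain list stands in for wandb.log so the snippet runs without an account.

```python
import functools
import os
from typing import Callable


def rank_zero_only(fn: Callable) -> Callable:
    """Run fn only in the process whose RANK environment variable is 0."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Default to "0" so single-process runs (no launcher) still log.
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper


@rank_zero_only
def log_metrics(metrics: dict, sink: list) -> None:
    # Stand-in for wandb.log(metrics): record the metrics in a list.
    sink.append(metrics)
```

Only logging belongs behind such a guard. Collectives like dist.all_gather must stay outside it so that every rank reaches them.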

The next thing you’ll likely want to tackle is synchronizing hyperparameters or other configurations across all ranks, ensuring consistency even if they were set differently on initialization.
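One way to approach that, sketched here under assumptions, is to treat rank 0's configuration as the source of truth and broadcast it to the other ranks with dist.broadcast_object_list. `broadcast_config` is a hypothetical helper name; it assumes an initialized process group and a picklable config, and with the NCCL backend it also assumes each process has already selected its GPU via torch.cuda.set_device.

```python
import torch.distributed as dist


def broadcast_config(config: dict, src: int = 0) -> dict:
    """Return rank `src`'s config on every rank (hypothetical helper)."""
    # Only the source rank's entry matters; other ranks pass a placeholder
    # that broadcast_object_list overwrites in place.
    holder = [config if dist.get_rank() == src else None]
    dist.broadcast_object_list(holder, src=src)
    return holder[0]
```

Every rank calls `config = broadcast_config(config)` once after the process group is set up, before the values are used to build the optimizer or data loaders.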
