The most surprising thing about logging JAX/Flax training with Weights & Biases is how little of your existing JAX/Flax code you actually need to touch.
Let’s see it in action. Imagine you have a simple Flax model and a standard JAX training loop.
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import wandb

# Define a simple model
class SimpleMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=64)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

# Initialize model and optimizer
key = jax.random.PRNGKey(0)
model = SimpleMLP()
params = model.init(key, jnp.ones([1, 28 * 28]))['params']
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# Dummy data
x_train = jnp.ones([64, 28 * 28])
y_train = jax.random.randint(key, [64], 0, 10)

# Define loss and training step
def loss_fn(params, x, y):
    logits = model.apply({'params': params}, x)
    one_hot_y = jax.nn.one_hot(y, num_classes=10)
    loss = optax.softmax_cross_entropy(logits, one_hot_y).mean()
    return loss

@jax.jit
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# --- W&B Integration ---
run = wandb.init(project="jax-flax-example")

# Training loop
num_epochs = 5
batch_size = 64  # the batch size you'd use when iterating a real dataset

for epoch in range(num_epochs):
    # In a real scenario, you'd iterate over your dataset in batches
    params, opt_state, loss = train_step(params, opt_state, x_train, y_train)
    # Log metrics to W&B
    wandb.log({"epoch": epoch, "loss": loss.item()})
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

run.finish()
```
This code initializes a simple MLP, sets up an Adam optimizer, and defines a basic JAX train_step function. The key W&B integration happens with wandb.init(project="jax-flax-example") before the loop and wandb.log({"epoch": epoch, "loss": loss.item()}) inside the loop. That’s it. W&B automatically handles the synchronization of these logged metrics to your dashboard.
The mental model is straightforward: W&B acts as an external observer and recorder for your training process. You explicitly tell it what data to record using wandb.log(). This data can be anything you can compute or track: scalar metrics like loss and accuracy, hyperparameters, model architecture details (though JAX/Flax often makes this less explicit than, say, PyTorch’s nn.Module definitions), gradients, and even model checkpoints.
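Hyperparameters deserve special mention: rather than logging them repeatedly with `wandb.log()`, you can record them once via the `config` argument to `wandb.init()`. The sketch below illustrates the idea; the specific hyperparameter values are placeholders, and the `wandb` calls are shown as comments since they require an account to run.

```python
# Hyperparameters gathered into one dict so they can be recorded with the run.
# These specific values are illustrative, not taken from the example above.
hyperparams = {
    "learning_rate": 1e-3,
    "batch_size": 64,
    "num_epochs": 5,
}

# Passing config= to wandb.init stores the dict alongside the run, so every
# logged metric is tied to the settings that produced it:
#   run = wandb.init(project="jax-flax-example", config=hyperparams)
# The values are then readable back via run.config, e.g.:
#   optimizer = optax.adam(learning_rate=run.config.learning_rate)
```

Driving the training code from `run.config` (rather than hard-coded constants) is what later lets W&B sweeps vary those values for you.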
The wandb.log() function is the central API. It accepts a dictionary mapping metric names (strings) to data points. For scalars, you just pass the number, and W&B automatically plots each metric over time. You can log multiple metrics in a single call, and they’ll all be recorded against the same step, so they line up on the dashboard.
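For example, you might assemble several related metrics into one dictionary before a single logging call. In this sketch, `compute_accuracy` is a placeholder helper and the loss value and predictions are toy data, not part of the original example:

```python
# Placeholder accuracy helper: fraction of predictions matching labels.
def compute_accuracy(predictions, labels):
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)

# Toy values standing in for real model outputs.
predictions = [3, 1, 4, 1, 5]
labels = [3, 1, 4, 2, 5]

metrics = {
    "train/loss": 0.42,  # placeholder scalar
    "train/accuracy": compute_accuracy(predictions, labels),
}
# wandb.log(metrics)  # both metrics land on the same step together
```

The `train/` prefix is just a naming convention; W&B groups metrics that share a prefix into the same dashboard section.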
You control the logging frequency. In the example, we log once per epoch. For more granular tracking, you could log after each batch or a set number of batches. This is crucial for understanding training dynamics, especially with large datasets.
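A batch-level loop typically gates logging on a global step counter. The sketch below shows the structure; `log_every`, the loop bounds, and the commented-out `train_step`/`wandb.log` calls are illustrative. Note that `wandb.log` accepts a `step=` argument to pin each point to your own counter:

```python
log_every = 2       # log every N batches (an assumed value)
global_step = 0
logged_steps = []   # stands in for the points wandb would receive

num_epochs, batches_per_epoch = 2, 5
for epoch in range(num_epochs):
    for batch in range(batches_per_epoch):
        # params, opt_state, loss = train_step(params, opt_state, x, y)
        if global_step % log_every == 0:
            # wandb.log({"loss": loss.item()}, step=global_step)
            logged_steps.append(global_step)
        global_step += 1
```

Logging every batch gives the finest-grained curves but adds overhead and dashboard noise; a modest `log_every` is usually a good trade-off.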
One thing many people don’t realize is that wandb.log() is asynchronous by default. The call returns almost immediately, and a background W&B process handles the network communication. This is good for performance, since your training loop is never blocked on I/O, but it also means metrics can lag slightly on the dashboard, and a script that crashes mid-run may leave the last few logged points unsynced. Calling run.finish() at the end is the safeguard: it blocks until all buffered data has been uploaded before returning.
Beyond simple metrics, W&B also integrates deeply with hyperparameter sweeps. You can define a sweep configuration in YAML or Python, and W&B will manage the execution of multiple training runs with different hyperparameter combinations, automatically tracking which run corresponds to which settings.
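A sweep configuration can be expressed as a plain Python dict following W&B’s sweep schema (`method`, `metric`, `parameters`). The parameter ranges below are illustrative, and the `wandb.sweep`/`wandb.agent` calls are commented out since running them requires an account:

```python
# Illustrative sweep configuration; the hyperparameter values are assumptions.
sweep_config = {
    "method": "random",  # or "grid", "bayes"
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"values": [1e-2, 1e-3, 1e-4]},
        "batch_size": {"values": [32, 64, 128]},
    },
}

# Registering and running the sweep:
#   sweep_id = wandb.sweep(sweep_config, project="jax-flax-example")
#   wandb.agent(sweep_id, function=train, count=10)
# Each agent invocation calls your train() function with wandb.config
# populated from one sampled hyperparameter combination.
```

This is why reading hyperparameters from `wandb.config` inside the training function pays off: the same code serves both single runs and sweeps.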
The next concept you’ll likely explore is logging gradients and model parameters to visualize their distributions over time.