Keras callbacks are TensorFlow’s way of letting you inject custom logic into a model’s training loop without rewriting the entire model.fit() process.
Let’s see this in action. Imagine we’re training a simple Keras model.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import datetime
import os
# 1. Define a simple model
def build_model():
    model = keras.Sequential([
        layers.Dense(64, activation='relu', input_shape=(784,)),
        layers.Dense(10, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
model = build_model()
# 2. Prepare some dummy data
(x_train, y_train), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
y_train = y_train.astype('int32')
# 3. Define our callbacks
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
os.makedirs("models", exist_ok=True)  # ensure the target directory exists before the first save
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="models/best_model.keras",
    save_best_only=True,
    monitor='val_loss'
)
# 4. Train the model with callbacks
history = model.fit(x_train[:10000], y_train[:10000],
                    epochs=10,
                    validation_split=0.2,
                    callbacks=[tensorboard_callback, early_stopping_callback, model_checkpoint_callback])
print("Training finished!")
When you run this, a timestamped subdirectory is created under logs/fit/. Inside, TensorFlow writes training metrics (loss, accuracy, etc.) and, because histogram_freq=1, the distributions of each layer’s weights and biases at every epoch. This data is formatted for TensorBoard, TensorFlow’s visualization tool. Run tensorboard --logdir logs/fit in your terminal from the directory containing your Python script, then open the provided URL (usually http://localhost:6006/) in your browser. You’ll see live graphs of your metrics, histograms of your model’s parameters, and, since graph logging is on by default, a graph of your model’s architecture.
The problem callbacks solve is the "black box" nature of model.fit(). You want to know what’s happening during training: are the weights exploding? Is the loss plateauing? Is the model overfitting? Without callbacks, you’d have to log metrics manually at the end of each epoch or write a complex custom training loop. Callbacks let you hook into specific points of the training lifecycle: on_epoch_begin, on_epoch_end, on_train_batch_begin, on_train_batch_end, on_train_begin, on_train_end, and their test and predict counterparts.
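A minimal sketch of hooking into one of those lifecycle points: the hook signatures below are the real Keras Callback API, but the LossLogger class itself is a hypothetical example, not a built-in.

```python
import tensorflow as tf

# Hypothetical example: record the training loss at the end of each epoch.
class LossLogger(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.epoch_losses = []

    def on_epoch_end(self, epoch, logs=None):
        # `logs` carries the metrics Keras computed for this epoch
        logs = logs or {}
        self.epoch_losses.append(logs.get('loss'))

# The hooks are plain methods, so they can be exercised even without a
# full training run:
logger = LossLogger()
logger.on_epoch_end(0, logs={'loss': 0.42})
print(logger.epoch_losses)  # [0.42]
```

In a real run you would pass `LossLogger()` in the callbacks list to model.fit(), and Keras would invoke the hook for you at the end of every epoch.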
TensorBoard is the most common companion to callbacks because it visualizes the data they generate. The TensorBoard callback itself is a writer that serializes Keras’s internal state (metrics, epoch counts, etc.) into a format TensorBoard can read. histogram_freq=1 tells it to log weight/bias histograms every epoch, and write_graph=True (the default) tells it to also log the computation graph of your model.
EarlyStopping is another crucial callback. It monitors a specified metric (here, val_loss) and stops training if that metric doesn’t improve for a set number of epochs (patience=3). This prevents wasting compute on runs that are clearly not yielding better results and avoids overfitting. monitor='val_loss' is key; it means we’re looking at the validation loss, not the training loss, to detect overfitting.
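One variant worth knowing (an addition to the snippet above, not shown in it): EarlyStopping also accepts restore_best_weights, which rolls the model back to the best epoch’s weights when training stops, instead of keeping the last epoch’s.

```python
import tensorflow as tf

# Same monitor/patience as before, plus restore_best_weights so the
# model ends training with the best epoch's weights.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
print(early_stopping_callback.patience)  # 3
```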
ModelCheckpoint saves the model’s weights (or the entire model) at regular intervals, or whenever a monitored metric improves. save_best_only=True combined with monitor='val_loss' means it will only save the model when the validation loss is lower than any previously recorded validation loss. This ensures you always have access to the best performing version of your model from the training run, rather than just the weights from the very last epoch, which might be worse. The filepath="models/best_model.keras" specifies where to save it, using the modern .keras format.
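After training, the checkpointed file can be reloaded with keras.models.load_model. As a self-contained sketch, using a tiny throwaway model and a hypothetical demo_model.keras path rather than the trained model above:

```python
import os
import tensorflow as tf

os.makedirs("models", exist_ok=True)

# A tiny stand-in model, saved and reloaded the same way the
# checkpointed best_model.keras would be.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])
model.save("models/demo_model.keras")

restored = tf.keras.models.load_model("models/demo_model.keras")
print(restored.count_params())  # same parameter count as the original
```

The .keras format stores the architecture, weights, and compile configuration together, so the restored model is immediately usable for inference or further training.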
The history object returned by model.fit() is itself produced by a built-in callback (keras.callbacks.History). Its history attribute is a dictionary mapping each loss and metric name to a list of per-epoch values for both the training and validation sets. You can plot this directly:
import matplotlib.pyplot as plt
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Val'], loc='upper left')
plt.show()
This lets you see the training progress without even needing TensorBoard, though TensorBoard offers far richer visualizations and interactive features.
A subtle but powerful aspect of callbacks is their ability to interact with each other. For instance, you could write a custom callback that inspects the logs dictionary passed to on_epoch_end and modifies training parameters (like learning rate) based on the observed metrics, potentially achieving adaptive learning rates that are more sophisticated than standard schedulers. The logs dictionary contains keys like 'loss', 'accuracy', 'val_loss', 'val_accuracy', and any custom metrics you’ve defined.
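A hedged sketch of that idea follows: the class name and halving policy are illustrative only (Keras ships ReduceLROnPlateau as the production version of this pattern).

```python
import tensorflow as tf

# Illustrative only: halve the learning rate whenever val_loss fails to
# improve. Keras's built-in ReduceLROnPlateau does this more robustly.
class HalveLROnPlateau(tf.keras.callbacks.Callback):
    def __init__(self, factor=0.5):
        super().__init__()
        self.factor = factor
        self.best = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        val_loss = (logs or {}).get('val_loss')
        if val_loss is None:
            return
        if val_loss < self.best:
            self.best = val_loss
        else:
            # optimizer.learning_rate is a variable we can assign to
            lr = self.model.optimizer.learning_rate
            lr.assign(lr * self.factor)

# Exercising the hook directly, without a full training run:
model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss='mse')
cb = HalveLROnPlateau()
cb.set_model(model)
cb.on_epoch_end(0, logs={'val_loss': 1.0})  # improvement: lr unchanged
cb.on_epoch_end(1, logs={'val_loss': 2.0})  # plateau: lr halved
```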
One thing many people don’t realize is that callbacks execute in the order they appear in the callbacks list passed to model.fit(). If one callback modifies model weights or training state and another reads that state, the order matters. Similarly, a custom learning rate scheduler that fires in on_train_batch_end runs after the optimizer has already applied that batch’s updates, so the new rate only affects subsequent batches.
The next logical step after tracking training runs is to understand how to use these saved models for inference or to fine-tune them on new datasets.