TensorFlow’s checkpointing and early stopping callbacks address two different problems. Checkpointing protects you from losing training progress, and early stopping guards against overfitting. When they are used together, they often interact in ways that confuse users. The result can be unexpected behavior and missed training opportunities.

Let’s see this in action. Imagine we’re training a simple Keras model to classify MNIST digits. We’ll set up a checkpoint to save the best model based on validation accuracy and then use early stopping to halt training if that accuracy doesn’t improve for a few epochs.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
import numpy as np

# Load and preprocess data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Build a simple model
def build_model():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(10, activation="softmax"),
    ])
    return model

# Compile the model
model = build_model()
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Define callbacks
checkpoint_filepath = "/tmp/checkpoint/best_model.keras"
tf.io.gfile.makedirs("/tmp/checkpoint")  # Ensure the directory exists before the first save
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=False,  # Save the entire model
    monitor="val_accuracy",
    mode="max",
    save_best_only=True,
)

early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    patience=3,  # Stop if no improvement for 3 epochs
    mode="max",
    restore_best_weights=True,  # Put the best epoch's weights back on the model when training ends
)

# Train the model
print("Training model...")
history = model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=20,
    validation_split=0.2,
    callbacks=[model_checkpoint_callback, early_stopping_callback],
)

print("\nTraining finished.")
# With restore_best_weights=True, the model object already holds the best epoch's weights,
# so loading the checkpoint is optional here. It is loaded explicitly to demonstrate the
# on-disk copy, which you would need if restore_best_weights were False.
print("Loading best model from checkpoint...")
best_model = keras.models.load_model(checkpoint_filepath)
loss, acc = best_model.evaluate(x_test, y_test, verbose=0)
print(f"Test accuracy of the best model: {acc:.4f}")

These callbacks address a trade-off. You want to train long enough to explore the loss landscape thoroughly, but not so long that the model overfits. Checkpointing saves snapshots of the model, either at regular intervals or only when a monitored metric improves. Early stopping monitors a metric such as validation accuracy or loss. It halts training once that metric has failed to improve for a specified number of consecutive epochs, called the patience. This keeps the model from continuing to fit noise in the training set, which would hurt its generalization to unseen data.
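The patience rule can be sketched in plain Python. The function below is a simplified re-implementation of the bookkeeping EarlyStopping performs in mode="max". It is not the Keras source. The function name simulate_early_stopping and the validation-accuracy values are hypothetical. The example also shows the "missed opportunity" risk: with patience=3, training halts at epoch 5 and never sees the improvement at epoch 6.

```python
def simulate_early_stopping(val_scores, patience, min_delta=0.0):
    """Simplified model of EarlyStopping(mode="max") bookkeeping (not Keras code).

    Returns (best_epoch, stopped_epoch) using 0-based epochs, or
    (best_epoch, None) if patience never ran out.
    """
    best = float("-inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, score in enumerate(val_scores):
        wait += 1
        # An epoch counts as an improvement only if it beats the best by more than min_delta.
        if score - min_delta > best:
            best, best_epoch, wait = score, epoch, 0
        elif wait >= patience:
            # Keras sets model.stop_training = True here; later epochs never run.
            return best_epoch, epoch
    return best_epoch, None


# Hypothetical val_accuracy curve: it plateaus, then would have improved at epoch 6.
history = [0.90, 0.95, 0.97, 0.96, 0.965, 0.968, 0.99]
print(simulate_early_stopping(history, patience=3))  # (2, 5)
print(simulate_early_stopping(history, patience=4))  # (6, None)
```

Keras numbers epochs from 1 in its progress output, so stopped epoch 5 in this sketch corresponds to the sixth epoch shown during training. A larger patience costs extra epochs but gives a late improvement a chance to appear.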

When the two are used together, you can configure ModelCheckpoint with save_best_only=True so that it saves the model only when the monitored metric reaches a new best. In our example, that metric is val_accuracy. EarlyStopping monitors the same metric. When EarlyStopping triggers, val_accuracy has not improved for patience consecutive epochs, which means it has plateaued or started to decrease. If restore_best_weights=True is set, EarlyStopping reverts the model’s weights to those from the epoch that achieved the best monitored value. That is the same epoch ModelCheckpoint last saved, provided three conditions hold: both callbacks monitor the same metric, they use the same mode, and EarlyStopping’s min_delta is left at its default of 0. A nonzero min_delta can make the two diverge.
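Here is a small sketch of why min_delta matters, assuming patience is large enough that every listed epoch runs. It models the two comparison rules, not the actual callback classes. ModelCheckpoint(save_best_only=True) writes on any strict improvement, while EarlyStopping only counts improvements larger than min_delta. The helper names and the scores are hypothetical.

```python
def checkpoint_save_epochs(val_scores):
    """Epochs at which ModelCheckpoint(save_best_only=True, mode="max") would
    write the file: every strict improvement over the best value so far."""
    best, saves = float("-inf"), []
    for epoch, score in enumerate(val_scores):
        if score > best:
            best = score
            saves.append(epoch)
    return saves


def early_stopping_best_epoch(val_scores, min_delta=0.0):
    """Epoch whose weights EarlyStopping(mode="max", restore_best_weights=True)
    would keep; an improvement must exceed min_delta to count."""
    best, best_epoch = float("-inf"), None
    for epoch, score in enumerate(val_scores):
        if score - min_delta > best:
            best, best_epoch = score, epoch
    return best_epoch


# Hypothetical val_accuracy values (0-based epochs).
scores = [0.90, 0.95, 0.955, 0.95]
print(checkpoint_save_epochs(scores)[-1])                 # 2: epoch stored in the file
print(early_stopping_best_epoch(scores))                  # 2: same epoch restored in memory
print(early_stopping_best_epoch(scores, min_delta=0.01))  # 1: a different epoch
```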

Here’s the critical interaction. ModelCheckpoint saves the best model to disk during training, while EarlyStopping with restore_best_weights=True reverts the in-memory model to its best state once training ends. With restore_best_weights=False, which is the EarlyStopping default, the model object keeps the weights from the final epoch. When early stopping triggers, that final epoch is patience epochs past the best one. In that case you must explicitly load the file saved by ModelCheckpoint with keras.models.load_model(checkpoint_filepath) to get the best-performing version. With restore_best_weights=True, the model object already holds the best weights. An explicit load is therefore unnecessary if you keep working with that object.
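The following sketch answers one question: which epoch’s weights does the model object hold after fit()? It assumes mode="max" and min_delta=0, and it is a simplified model rather than the Keras implementation. The validation-accuracy values are invented. The restore_only_if_stopped parameter is a made-up switch for this sketch, not a Keras argument. It imitates older tf.keras releases, which restored the best weights only when early stopping actually triggered.

```python
def weights_after_fit(val_scores, patience, restore_best_weights,
                      restore_only_if_stopped=False):
    """Return the 0-based epoch whose weights the model holds after fit().

    Simplified model of EarlyStopping(mode="max", min_delta=0), not Keras code.
    By default it follows recent Keras, which restores the best weights when
    training ends whether or not patience ran out. The hypothetical switch
    restore_only_if_stopped=True imitates older tf.keras releases, which
    restored only when early stopping actually triggered.
    """
    best, best_epoch, wait = float("-inf"), 0, 0
    last_epoch, stopped = -1, False
    for epoch, score in enumerate(val_scores):
        last_epoch = epoch
        wait += 1
        if score > best:
            best, best_epoch, wait = score, epoch, 0
        elif wait >= patience:
            stopped = True  # EarlyStopping sets model.stop_training = True
            break
    if restore_best_weights and (stopped or not restore_only_if_stopped):
        return best_epoch
    return last_epoch


history = [0.90, 0.95, 0.97, 0.96, 0.965, 0.968, 0.99]  # hypothetical val_accuracy
print(weights_after_fit(history, 3, restore_best_weights=True))   # 2
print(weights_after_fit(history, 3, restore_best_weights=False))  # 5

short_run = [0.90, 0.97, 0.96]  # patience never runs out in these 3 epochs
print(weights_after_fit(short_run, 3, True))                                # 1
print(weights_after_fit(short_run, 3, True, restore_only_if_stopped=True))  # 2
```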

The most surprising consequence is that restore_best_weights=True effectively makes ModelCheckpoint redundant if your only goal is to have the best model available after training. When configured to restore weights, EarlyStopping keeps an in-memory copy of the best weights it has seen and sets them back on the model when training ends. It never writes a file to disk. You can therefore often get the same outcome with EarlyStopping(restore_best_weights=True) alone and omit ModelCheckpoint entirely, as long as you don’t need intermediate versions or a saved file of the best model. One version caveat applies. Older tf.keras releases restored the best weights only when early stopping actually triggered. If training ran through every epoch in those releases, the model kept its final-epoch weights.

However, using ModelCheckpoint alongside EarlyStopping(restore_best_weights=True) provides a robust safety net. EarlyStopping’s copy of the best weights lives only in memory and is lost if the process crashes. ModelCheckpoint, by contrast, leaves a persistent file that survives an unstable training environment and can be loaded and analyzed independently. In this example that file holds the full model, including optimizer state. ModelCheckpoint can also monitor a different metric than EarlyStopping does. For example, it can save on the lowest val_loss while EarlyStopping watches val_accuracy.
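When the two callbacks monitor different metrics, the checkpoint file and the restored in-memory model can end up holding different epochs. The sketch below demonstrates this on a dict shaped like the history.history returned by model.fit(). The numbers are invented, and best_epoch is an illustrative helper, not a Keras function.

```python
def best_epoch(values, mode):
    """0-based epoch a save_best_only / restore_best_weights tracker would keep.

    Ties keep the earliest epoch, matching a strict-improvement comparison.
    """
    pick = min if mode == "min" else max
    return pick(range(len(values)), key=values.__getitem__)


# Hypothetical per-epoch logs, shaped like history.history after model.fit().
history_logs = {
    "val_loss":     [0.30, 0.22, 0.20, 0.21, 0.23],
    "val_accuracy": [0.91, 0.94, 0.95, 0.96, 0.95],
}
file_epoch = best_epoch(history_logs["val_loss"], "min")          # ModelCheckpoint on val_loss
memory_epoch = best_epoch(history_logs["val_accuracy"], "max")    # EarlyStopping on val_accuracy
print(file_epoch, memory_epoch)  # 2 3
```

In this hypothetical run, loading the checkpoint file gives epoch 2, while the model object after fit() holds epoch 3. Decide which criterion you trust before choosing which copy to evaluate or ship.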

The next problem you’ll run into is managing the filepath for ModelCheckpoint across different runs or when you want to save multiple checkpoints (e.g., one per epoch, or just the best).
