MirroredStrategy is the simplest way to get started with distributed training in TensorFlow, but it’s surprisingly easy to set up incorrectly, leading to silent performance degradation or outright failure.
Let’s see it in action. Imagine you have a standard Keras model:
```python
import tensorflow as tf

def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

model = build_model()
```
To distribute this model across multiple GPUs on a single machine using MirroredStrategy, you’d do this:
```python
# Detect and use all available GPUs (falls back to the CPU if none are visible)
strategy = tf.distribute.MirroredStrategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

# Build the model and optimizer within the strategy's scope
with strategy.scope():
    mirrored_model = build_model()
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    mirrored_model.compile(optimizer=optimizer,
                           loss='sparse_categorical_crossentropy',
                           metrics=['accuracy'])

# Load MNIST as example training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
y_train = y_train.astype('int32')

# Prepare dataset for distributed training: batch with the GLOBAL batch size
BATCH_SIZE_PER_REPLICA = 64
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = (train_dataset.shuffle(10000)
                 .repeat()
                 .batch(global_batch_size)
                 .prefetch(tf.data.AUTOTUNE))  # overlap input prep with training

# Train the model
epochs = 1
steps_per_epoch = len(x_train) // global_batch_size
mirrored_model.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)
```
The core idea is that MirroredStrategy creates a replica of your model on each available GPU. During training, gradients are computed on each replica, then synchronized across all replicas (typically averaged), and finally applied to update the model weights on each replica. This ensures all model replicas remain identical.
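The replicate, synchronize, stay-identical cycle can be observed directly with a small sketch. Each replica computes a different hypothetical "gradient" (just its replica id plus one), and `ReplicaContext.all_reduce` then gives every replica the same synchronized value. Using `tf.distribute.ReduceOp.MEAN` here is a simplification chosen for illustration, not necessarily the exact reduction the optimizer uses internally. On a machine without GPUs the strategy runs a single CPU replica, so you will see only one value.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_step():
    ctx = tf.distribute.get_replica_context()
    # Hypothetical per-replica "gradient": 1.0 on replica 0, 2.0 on replica 1, ...
    local_grad = tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0
    # All-reduce: every replica receives the same averaged value
    synced_grad = ctx.all_reduce(tf.distribute.ReduceOp.MEAN, local_grad)
    return local_grad, synced_grad

@tf.function
def run_on_all_replicas():
    local, synced = strategy.run(replica_step)
    # Unpack the per-replica values into plain tuples of tensors
    return (strategy.experimental_local_results(local),
            strategy.experimental_local_results(synced))

local_results, synced_results = run_on_all_replicas()
local_values = [float(v) for v in local_results]
synced_values = [float(v) for v in synced_results]
print("before sync:", local_values)
print("after sync: ", synced_values)
```

Before the all-reduce the values differ per replica; afterwards they are identical, which is the property that keeps the mirrored weights in lockstep.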
The strategy.scope() is crucial. Any variables created within this context (your model's weights and the optimizer's state) are created as mirrored variables: one copy per replica, placed on the devices the strategy manages and kept in sync by the strategy.
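A quick way to check whether a variable was actually created under the strategy is `strategy.extended.variable_created_in_scope`. The minimal sketch below uses plain `tf.Variable` objects (hypothetical stand-ins for model weights) to compare a variable created outside the scope with one created inside it, and counts the per-replica copies of the mirrored one.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Created outside the scope: an ordinary, single-device variable
outside_var = tf.Variable(1.0)

with strategy.scope():
    # Created inside the scope: a mirrored variable, one copy per replica
    inside_var = tf.Variable(1.0)
    scope_active = tf.distribute.has_strategy()

scope_active_after = tf.distribute.has_strategy()
created_inside = strategy.extended.variable_created_in_scope(inside_var)
created_outside = strategy.extended.variable_created_in_scope(outside_var)
# One component per replica for the mirrored variable
num_copies = len(strategy.experimental_local_results(inside_var))

print("inside scope:", created_inside, "| outside scope:", created_outside)
print("copies of inside_var:", num_copies)
```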
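The global_batch_size is also key. You choose the batch size per replica and multiply it by strategy.num_replicas_in_sync to get the global batch size that the dataset is batched with; during training, each global batch is split evenly across the replicas. This keeps the per-replica batch, and therefore the work each GPU does per step, constant as you add GPUs, while the global batch (the number of examples behind each weight update) grows in proportion to the number of replicas.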
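To see that split, the sketch below batches a toy dataset (a hypothetical stand-in built with `tf.data.Dataset.range`) with the global batch size, distributes it with `strategy.experimental_distribute_dataset`, and inspects the size of each replica's share of the first batch. `drop_remainder=True` is an assumption made here so every batch divides evenly.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 64
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Hypothetical toy data: the integers 0..1023, batched with the GLOBAL size
dataset = tf.data.Dataset.range(1024).batch(global_batch_size, drop_remainder=True)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Take the first global batch and look at each replica's share of it
first_batch = next(iter(dist_dataset))
local_batches = strategy.experimental_local_results(first_batch)
per_replica_sizes = [int(t.shape[0]) for t in local_batches]

print("global batch:", global_batch_size)
print("per-replica batches:", per_replica_sizes)
```

With two GPUs, for example, the global batch would be 128 and each replica would see 64 examples per step.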
The real work happens in compile and fit. Because the model and optimizer were created under strategy.scope(), fit picks up the strategy and handles distribution for you: it splits each global batch across the replicas, runs the forward and backward passes in parallel, scales the loss so that the aggregated gradient corresponds to the whole global batch, all-reduces the gradients, and applies the same update on every replica.
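A common pitfall is forgetting to use strategy.scope() when creating your model or optimizer. Variables created outside the scope are ordinary, non-mirrored variables placed on the default device (the first GPU if one is visible, otherwise the CPU), so MirroredStrategy has no replicas to manage and training effectively runs on a single device. Mixing objects created inside and outside the scope can also raise an error at compile or fit time.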
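Another subtle issue arises with custom training loops. If you're not using model.fit, you have to distribute the dataset yourself with strategy.experimental_distribute_dataset, execute the per-replica training step with strategy.run, and scale each replica's loss by the global batch size (for example with tf.nn.compute_average_loss). Gradient synchronization happens when optimizer.apply_gradients is called inside strategy.run; per-replica results such as the loss can then be combined with strategy.reduce. Inside the step function, tf.distribute.get_replica_context() gives you access to the current replica if you need lower-level control.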
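The custom-loop steps above can be sketched as follows. This is a minimal sketch, not a definitive implementation: the random 64-example dataset, the 4-feature/3-class shapes, the single `Dense` layer, and SGD with a 0.1 learning rate are all hypothetical choices made to keep it small and runnable on a CPU.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
PER_REPLICA_BATCH = 8
GLOBAL_BATCH = PER_REPLICA_BATCH * strategy.num_replicas_in_sync

# Hypothetical toy data: 64 examples, 4 features, 3 classes
features = tf.random.uniform((64, 4))
labels = tf.random.uniform((64,), maxval=3, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(
    GLOBAL_BATCH, drop_remainder=True)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# Model and optimizer variables must be created inside the scope
with strategy.scope():
    model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                                 tf.keras.layers.Dense(3)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

def train_step(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=y, logits=logits)
        # Divide by the GLOBAL batch size, not the per-replica one
        loss = tf.nn.compute_average_loss(per_example_loss,
                                          global_batch_size=GLOBAL_BATCH)
    grads = tape.gradient(loss, model.trainable_variables)
    # Gradients are all-reduced across replicas as part of apply_gradients
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

@tf.function
def distributed_train_step(inputs):
    per_replica_losses = strategy.run(train_step, args=(inputs,))
    # Summing the scaled per-replica losses yields the global mean loss
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

losses = [float(distributed_train_step(batch)) for batch in dist_dataset]
print("loss per step:", losses)
```

Wrapping `strategy.run` in a `tf.function` is the usual pattern; calling it eagerly also works but is slower.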
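tf.distribute.get_replica_context().all_reduce is the fundamental operation behind gradient synchronization. When the optimizer applies gradients under MirroredStrategy, they are all-reduced with tf.distribute.ReduceOp.SUM; the result equals the average gradient only because each replica's loss has already been divided by the global batch size (Keras does this for you in fit, and tf.nn.compute_average_loss does it in a custom loop). Understanding this reduction is vital if you ever implement custom gradient aggregation logic: dividing by the per-replica batch size instead would inflate the gradients by a factor equal to the number of replicas.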
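The arithmetic behind this can be checked without TensorFlow. The plain-Python sketch below uses a hypothetical one-weight linear model with squared-error loss, splits an 8-example global batch across 2 simulated replicas, and SUM-reduces the per-replica gradients, once with global-batch scaling and once with (incorrect) per-replica scaling.

```python
import math

def grad_squared_error(w, x, y):
    """Gradient with respect to w of the per-example loss (w * x - y) ** 2."""
    return 2.0 * (w * x - y) * x

# Hypothetical data and weight
w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.0, 13.9, 16.2]

num_replicas = 2
global_batch_size = len(xs)
per_replica_batch = global_batch_size // num_replicas

# Split the global batch into one contiguous shard per simulated replica
shards = [
    list(zip(xs[i * per_replica_batch:(i + 1) * per_replica_batch],
             ys[i * per_replica_batch:(i + 1) * per_replica_batch]))
    for i in range(num_replicas)
]

# Reference: gradient of the mean loss over the whole global batch
reference = sum(grad_squared_error(w, x, y) for x, y in zip(xs, ys)) / global_batch_size

# Correct: each replica scales by the GLOBAL batch size, then a SUM all-reduce
scaled_by_global = sum(
    sum(grad_squared_error(w, x, y) for x, y in shard) / global_batch_size
    for shard in shards
)

# Pitfall: each replica averages over its LOCAL batch, then a SUM all-reduce
scaled_by_local = sum(
    sum(grad_squared_error(w, x, y) for x, y in shard) / per_replica_batch
    for shard in shards
)

print(math.isclose(scaled_by_global, reference))                  # True
print(math.isclose(scaled_by_local, num_replicas * reference))    # True
```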
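When performance gains plateau, or even reverse, as you add GPUs, communication overhead is often the bottleneck: gradients must be synchronized at every step, and when each replica does little computation per step, that fixed cost dominates. This can happen if your model is very small, if the per-replica batch is small, or if the input pipeline cannot feed the GPUs fast enough (for example, because of slow host-to-device data transfer). In such cases, you might need to investigate more advanced strategies or optimize your data pipeline.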
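On the data-pipeline side, a common starting point is to parallelize preprocessing, cache the result, and prefetch so the next batch is prepared while the current step runs. The sketch below applies these standard `tf.data` options to hypothetical random stand-in data shaped like flattened MNIST; whether they help depends on where your bottleneck actually is, so it is worth profiling before and after.

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
BATCH_SIZE_PER_REPLICA = 64
global_batch_size = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

# Hypothetical stand-in for MNIST: 1,000 random 784-pixel images and labels
images = tf.random.uniform((1000, 784), maxval=256, dtype=tf.int32)
labels = tf.random.uniform((1000,), maxval=10, dtype=tf.int32)

def normalize(x, y):
    # Convert raw pixel values to floats in [0, 1]
    return tf.cast(x, tf.float32) / 255.0, y

dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .map(normalize, num_parallel_calls=tf.data.AUTOTUNE)  # parallel preprocessing
    .cache()                                               # reuse preprocessed data
    .shuffle(1000)
    .repeat()
    .batch(global_batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)                            # overlap input and compute
)

x_batch, y_batch = next(iter(dataset))
print(x_batch.shape, x_batch.dtype)
```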
Once you’ve successfully set up MirroredStrategy and your training is running across multiple GPUs, the next logical step is to explore training across multiple machines, which typically involves MultiWorkerMirroredStrategy.