Mixed precision training lets you use 16-bit floating-point numbers (fp16) instead of the usual 32-bit (fp32) for some operations during neural network training on GPUs.

Here’s how it looks in action with a simple Keras model:

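```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Activation, Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential

# Enable mixed precision BEFORE building the model: each layer reads the
# global policy when it is created, so layers built earlier stay in float32
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Define a simple model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    Flatten(),
    Dense(10),
    # Keep the final softmax in float32 for numerical stability
    Activation('softmax', dtype='float32'),
])

# Compile the model. Under mixed_float16, compile() wraps the optimizer in a
# LossScaleOptimizer, so loss scaling is automatic when training with fit()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Dummy data for demonstration
x_train = np.random.rand(100, 28, 28, 1).astype(np.float32)
y_train = np.random.randint(0, 10, 100)

# Train the model
model.fit(x_train, y_train, epochs=1)

# Each layer computes in its compute dtype but stores weights in its variable dtype
for layer in model.layers:
    print(layer.name, layer.compute_dtype, layer.variable_dtype)

# The Conv2D kernel itself is stored in float32
print(model.layers[0].weights[0].dtype)
```

For Conv2D, Flatten, and Dense, the output shows a compute dtype of float16 and a variable dtype of float32, while the final Activation layer uses float32 for both. The kernel prints as float32: under mixed precision, weights are stored in float32 and cast to float16 only during computation.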

The core problem mixed precision solves is the computational and memory cost of fp32 on GPUs. Modern NVIDIA GPUs with Tensor Cores (compute capability 7.0 and higher) are significantly faster at fp16 matrix multiplications and convolutions than at their fp32 equivalents. By using fp16 for the bulk of the computation, you can achieve:

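  • Faster Training: fp16 math on Tensor Cores has several times the throughput of fp32. End-to-end speedups depend on how much of the model is dominated by matrix multiplications and convolutions.
  • Reduced Memory Usage: an fp16 value takes half the memory of an fp32 value. Because mixed precision keeps the weights in fp32, the savings come mostly from activations, which leaves room for larger batch sizes or bigger models.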
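As a quick check of the memory arithmetic, this NumPy snippet sizes the activation tensor produced by the example’s Conv2D layer (28x28 input, 3x3 kernel, no padding, 32 filters) for one batch. The batch size of 32 is an assumption; it is Keras’s default batch size for fit().

```python
import numpy as np

# Conv2D output for one batch: 28x28 input, 3x3 kernel, "valid" padding
# -> 26x26 spatial positions, 32 filters; batch size 32 is an assumption
shape = (32, 26, 26, 32)

act32 = np.zeros(shape, dtype=np.float32)
act16 = np.zeros(shape, dtype=np.float16)

print(act32.nbytes, act16.nbytes)  # 2768896 1384448
```

The fp16 tensor needs exactly half the bytes. In a real model this saving applies to every layer’s output that is kept for the backward pass.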

Here’s the mental model:

  1. Automatic Casting: Under the mixed_float16 policy, each Keras layer casts its inputs to fp16 and runs its computations (e.g., convolutions, matrix multiplications) in fp16.
  2. Loss Scaling: To keep small gradients from underflowing to zero in fp16, the loss is multiplied by a scale factor before backpropagation, and the resulting gradients are divided by the same factor before the weight update. With Model.fit this is automatic: under mixed_float16, Model.compile wraps the optimizer in a tf.keras.mixed_precision.LossScaleOptimizer, which adjusts the scale dynamically.
  3. FP32 Master Weights: For numerical stability, the weights (variables) are stored in fp32. Inside each layer call they are cast to fp16 for the forward and backward passes, and the optimizer applies its updates to the fp32 variables. This is also managed by the policy.
  4. FP16 Intermediate Activations: Because layers compute in fp16, their outputs (the intermediate activations) are fp16 as well, which saves memory.
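The underflow problem behind loss scaling is easy to reproduce with NumPy. This is a simplified illustration: it scales a single gradient value directly, whereas real loss scaling multiplies the loss, which scales every gradient by the same factor through backpropagation. The gradient value is made up; the scale of 2**15 is chosen because it matches the initial scale of Keras’s dynamic loss scaling.

```python
import numpy as np

grad = 1e-8  # a small gradient value, representable in fp32
print(np.float16(grad))  # 0.0: below half of fp16's smallest subnormal (~6e-8)

# Scale up before casting to fp16, as loss scaling effectively does
scale = 2.0 ** 15
scaled = np.float16(grad * scale)  # about 3.28e-4, comfortably representable
print(scaled)

# Unscale in fp32 before the weight update
recovered = np.float32(scaled) / np.float32(scale)
print(recovered)  # close to the original 1e-8
```

Without scaling the gradient is lost entirely. With scaling it survives the fp16 round trip with only a small relative rounding error.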

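The tf.keras.mixed_precision.set_global_policy('mixed_float16') call is the key step. It must run before the model is built, because each layer reads the global policy when it is created. With this policy:

  • Weights (variables) are stored in float32.
  • Layer inputs are cast to float16.
  • Computations (like Conv2D, matmul) run in float16. The policy does not check the hardware: the speedup requires Tensor Cores, older GPUs may see little or no benefit, and on CPUs mixed precision can run significantly slower.
  • Layer outputs are float16.
  • The optimizer updates the float32 weights.
  • Model.compile wraps the optimizer so that loss scaling happens automatically.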
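To make these bullets concrete, here is a NumPy sketch of one dense layer’s forward pass under a mixed_float16-style policy. dense_forward_mixed is a hypothetical function, not Keras’s implementation. The weights stay in float32, and a float16 copy is made only for the computation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Variables ("master weights") are stored in float32
kernel = rng.standard_normal((784, 10)).astype(np.float32)
bias = np.zeros(10, dtype=np.float32)


def dense_forward_mixed(x):
    """Hypothetical sketch of a Dense layer under a mixed_float16-style policy."""
    # Cast the input to the compute dtype (float16)
    x16 = x.astype(np.float16)
    # Cast temporary float16 copies of the float32 weights for the computation
    k16 = kernel.astype(np.float16)
    b16 = bias.astype(np.float16)
    # The matmul runs in float16 and the output stays float16
    return x16 @ k16 + b16


x = rng.random((32, 784), dtype=np.float32)
y = dense_forward_mixed(x)
print(y.dtype, kernel.dtype)  # float16 float32
```

The stored kernel is never overwritten with its fp16 copy, so the optimizer always works on the full-precision values.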

The optimizer (like Adam) updates the fp32 master weights. The gradients it receives are fp32, because backpropagation passes through each weight’s fp32-to-fp16 cast, so small updates are applied without being rounded away in fp16.
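Why keep an fp32 copy at all? Adjacent fp16 values near 1.0 are roughly 0.0005 to 0.001 apart, so a small update can vanish completely. This minimal NumPy illustration uses a made-up update size of 1e-4 (learning rate times gradient).

```python
import numpy as np

update = 1e-4  # hypothetical learning_rate * gradient

w16 = np.float16(1.0)
w32 = np.float32(1.0)

# Apply the same update to an fp16 weight and to an fp32 weight
w16_new = w16 - np.float16(update)
w32_new = w32 - np.float32(update)

print(w16_new)  # 1.0: the update is rounded away
print(w32_new)  # slightly below 1.0: the update is kept
```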

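The key insight is that most computations do not need fp32, but some do, such as reductions and the softmax that feeds the loss. The policy works at the level of layers rather than picking out individual operations: a layer under mixed_float16 computes in fp16 unless you override its dtype. To keep a numerically sensitive layer in fp32, pass dtype='float32' when you create it, for example Activation('softmax', dtype='float32') for the model’s output.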
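A small NumPy illustration of why the output end of the model is kept in fp32. A probability that fp32 represents easily underflows to zero in fp16, which turns a cross-entropy term into infinity. The probability value is made up, and this is a simplified picture rather than how Keras actually computes its loss.

```python
import numpy as np

# Hypothetical predicted probability for the true class (a badly wrong prediction)
p_true = 1e-8

with np.errstate(divide='ignore'):
    # In fp16, p_true underflows to 0, so -log(p) becomes inf
    loss16 = -np.log(np.float16(p_true))

# In fp32 the same term is finite (about 18.42)
loss32 = -np.log(np.float32(p_true))

print(loss16, loss32)
```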

To see which dtypes a layer uses, check its compute_dtype and variable_dtype attributes (or its dtype_policy). The dtype attribute of the weights only shows the float32 storage type.

The next concept you’ll likely explore is how to control the dtype of specific layers manually, or how to apply loss scaling yourself with tf.keras.mixed_precision.LossScaleOptimizer when you write a custom training loop instead of calling Model.fit.
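To give a rough idea of what such an optimizer does on each step, here is a NumPy sketch of dynamic loss scaling. DynamicLossScaler and its methods are hypothetical and are not a TensorFlow API. The policy it follows mirrors how TensorFlow documents LossScaleOptimizer’s dynamic mode: skip the step and halve the scale when gradients overflow, double it after a run of finite steps, starting from 2**15 with 2000 growth steps. The real implementation differs in detail.

```python
import numpy as np


class DynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling (not a TensorFlow API)."""

    def __init__(self, initial_scale=2.0 ** 15, growth_steps=2000):
        self.scale = initial_scale        # factor the loss is multiplied by
        self.growth_steps = growth_steps  # finite steps required before growing
        self.good_steps = 0               # consecutive steps with finite gradients

    def scale_loss(self, loss):
        # Multiplying the loss scales every gradient by the same factor
        return loss * self.scale

    def unscale(self, scaled_grads):
        """Return fp32 unscaled gradients, or None if the step must be skipped."""
        grads = [g.astype(np.float32) / self.scale for g in scaled_grads]
        if not all(np.all(np.isfinite(g)) for g in grads):
            # Overflow (inf or nan): skip this update and halve the scale
            self.scale /= 2.0
            self.good_steps = 0
            return None
        self.good_steps += 1
        if self.good_steps >= self.growth_steps:
            # Many finite steps in a row: try a larger scale
            self.scale *= 2.0
            self.good_steps = 0
        return grads


# Usage with hand-made fp16 "scaled gradients"
scaler = DynamicLossScaler(growth_steps=2)
print(scaler.unscale([np.array([np.inf], dtype=np.float16)]), scaler.scale)
print(scaler.unscale([np.array([2.0], dtype=np.float16)]), scaler.scale)
```

The first call overflows, so it returns None and halves the scale to 2**14. The second call returns a finite fp32 gradient and leaves the scale unchanged until enough finite steps have accumulated.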
