TensorFlow’s XLA compiler can make your TPU training dramatically faster, but it’s not a magic bullet: it’s a sophisticated tool, and getting the most out of it requires understanding what it does.

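Let’s see XLA in action. Imagine we’re training a small BERT-like model, and we want to measure the speed difference with and without XLA. On a TPU, TensorFlow always compiles the computation with XLA, so a with/without comparison like this one is meant to run on a CPU or GPU, where XLA is optional.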

import time
import tensorflow as tf

BATCH_SIZE, SEQ_LEN, VOCAB_SIZE = 32, 128, 30522

# Dummy data for demonstration
input_ids = tf.random.uniform((BATCH_SIZE, SEQ_LEN), maxval=VOCAB_SIZE, dtype=tf.int32)
attention_mask = tf.ones((BATCH_SIZE, SEQ_LEN), dtype=tf.int32)
token_type_ids = tf.zeros((BATCH_SIZE, SEQ_LEN), dtype=tf.int32)
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "token_type_ids": token_type_ids,
}
# Dummy labels for the loss, drawn once so both runs do identical work
labels = tf.random.uniform((BATCH_SIZE,), maxval=VOCAB_SIZE, dtype=tf.int32)

# A small custom model that mimics BERT's structure (not a pre-trained BERT)
class BertLikeModel(tf.keras.Model):
    def __init__(self, vocab_size=VOCAB_SIZE, hidden_size=768, num_hidden_layers=2,
                 num_attention_heads=12, intermediate_size=3072,
                 max_position_embeddings=512, type_vocab_size=2):
        super().__init__()
        self.word_embeddings = tf.keras.layers.Embedding(vocab_size, hidden_size)
        self.position_embeddings = tf.keras.layers.Embedding(max_position_embeddings, hidden_size)
        self.token_type_embeddings = tf.keras.layers.Embedding(type_vocab_size, hidden_size)
        self.embedding_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12)
        self.dropout = tf.keras.layers.Dropout(0.1)  # no trainable weights, so it can be reused
        # One layer per Transformer block in each list (unique layers, unique names)
        self.attention = [tf.keras.layers.MultiHeadAttention(
            num_heads=num_attention_heads, key_dim=hidden_size // num_attention_heads)
            for _ in range(num_hidden_layers)]
        self.attention_norm = [tf.keras.layers.LayerNormalization(epsilon=1e-12)
                               for _ in range(num_hidden_layers)]
        self.intermediate = [tf.keras.layers.Dense(intermediate_size, activation="gelu")
                             for _ in range(num_hidden_layers)]
        self.output_dense = [tf.keras.layers.Dense(hidden_size)
                             for _ in range(num_hidden_layers)]
        self.output_norm = [tf.keras.layers.LayerNormalization(epsilon=1e-12)
                            for _ in range(num_hidden_layers)]
        self.predictions = tf.keras.layers.Dense(vocab_size, activation="softmax")

    def call(self, inputs, training=False):
        ids = inputs["input_ids"]
        positions = tf.range(tf.shape(ids)[1])  # (seq_len,), broadcast over the batch
        x = (self.word_embeddings(ids)
             + self.position_embeddings(positions)
             + self.token_type_embeddings(inputs["token_type_ids"]))
        x = self.dropout(self.embedding_norm(x), training=training)
        # MultiHeadAttention expects a (batch, query_len, key_len) mask; a
        # (batch, 1, key_len) mask broadcasts over all query positions.
        mask = tf.cast(inputs["attention_mask"][:, tf.newaxis, :], tf.bool)
        for i in range(len(self.attention)):
            # Simplified Transformer block: self-attention, then feed-forward
            attn = self.attention[i](x, x, attention_mask=mask, training=training)
            x = self.attention_norm[i](x + self.dropout(attn, training=training))
            ffn = self.output_dense[i](self.intermediate[i](x))
            x = self.output_norm[i](x + self.dropout(ffn, training=training))
        # Dummy objective: predict a token id from the first ([CLS]) position
        return self.predictions(x[:, 0, :])

def build_bert_like_model(**kwargs):
    return BertLikeModel(**kwargs)

# Instantiate the model, optimizer, and loss
model = build_bert_like_model()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# tf.config.optimizer.set_jit(True) only affects code inside tf.function, so an
# eager Python loop would not be XLA-compiled at all. Instead, trace the same
# step twice: once as a plain graph, once compiled with XLA.
graph_step = tf.function(train_step)
xla_step = tf.function(train_step, jit_compile=True)

def benchmark(step_fn, label, steps=5):
    step_fn(inputs, labels)  # warm-up: tracing (and, for XLA, compilation) happens here
    start = time.perf_counter()
    for _ in range(steps):
        loss = step_fn(inputs, labels)
    loss.numpy()  # wait for the device to finish before stopping the clock
    print(f"Time for {steps} steps {label}: {time.perf_counter() - start:.4f} seconds\n")

# --- Without XLA ---
print("--- Training without XLA ---")
benchmark(graph_step, "without XLA")

# --- With XLA ---
print("--- Training with XLA ---")
benchmark(xla_step, "with XLA")

The most surprising truth about TensorFlow’s XLA (Accelerated Linear Algebra) compiler is that it doesn’t just "run faster"; it fundamentally rewrites your TensorFlow graph to execute operations more efficiently, often fusing them into single, highly optimized kernels.

When you enable XLA, TensorFlow doesn’t just execute your tf.function as is. Instead, it takes the computation graph defined within that function, analyzes it, and then compiles it into a highly optimized executable specific to your hardware (like TPUs). This compilation process involves several key optimizations:

  1. Operator Fusion: XLA can fuse multiple operations (like element-wise additions, multiplications, and activations) into a single, larger operation. A separate kernel for each operation would read its inputs from memory, compute, and write the result back, only for the next kernel to read it again. XLA instead performs all these steps within a single kernel on the accelerator, without writing the intermediate results back to memory. This drastically reduces memory-bandwidth bottlenecks and kernel launch overhead.
  2. Constant Folding: Computations that involve only constants are performed at compile time, not runtime.
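  3. Memory Layout Optimization: XLA chooses how each tensor is laid out in memory (for example, the order of its dimensions) to suit the operations that consume it and the target hardware.
  4. Parallelism Strategy: XLA maps the computation onto the accelerator’s parallel hardware, such as a TPU core’s matrix and vector units. Spreading work across multiple TPU cores is coordinated with tf.distribute (for example, TPUStrategy replicates the compiled step on each core).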
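You can watch these passes work on a small function by asking TensorFlow for the HLO (XLA’s intermediate representation) before and after optimization. The sketch below relies on experimental_get_compiler_ir, an experimental tf.function method whose stage names and output format may change between TensorFlow versions. scale_shift_tanh is a hypothetical example function. It compiles for whatever device runs it (typically the CPU), so the fusions you see will differ from what a TPU build would produce.

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def scale_shift_tanh(x):
    # 2.0 + 3.0 involves only constants: a constant-folding candidate.
    scale = tf.constant(2.0) + tf.constant(3.0)
    # A chain of element-wise ops: a fusion candidate.
    return tf.tanh(x * scale + 1.0)

x = tf.random.normal((1024,))

# Compiler IR for this input signature, before and after XLA's optimization passes
ir = scale_shift_tanh.experimental_get_compiler_ir(x)
hlo_before = ir(stage="hlo")
hlo_after = ir(stage="optimized_hlo")

def count_fusions(hlo_text):
    # Fusion instructions appear in HLO text as "%name = <shape> fusion(...)"
    return sum("fusion(" in line for line in hlo_text.splitlines())

print("fusion instructions before optimization:", count_fusions(hlo_before))
print("fusion instructions after optimization: ", count_fusions(hlo_after))
print(hlo_after)
```

Comparing the two dumps shows what the optimization passes changed. The exact instruction names vary by backend and version.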

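This means that the first time your tf.function runs with XLA enabled, you pay a compilation cost. Subsequent calls reuse the compiled executable and should be significantly faster, provided XLA found substantial optimizations. XLA compiles for fixed tensor shapes, so calling the function with a new input shape (for example, a different batch size or sequence length) triggers another compilation.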
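One way to observe this one-time cost is to time the first and later calls of a jit_compile=True function. In this sketch, the function projection and the shapes are arbitrary, hypothetical choices for illustration. It also shows that a new input shape causes a fresh trace, and with it a fresh XLA compilation, because XLA compiles for static shapes. Absolute timings depend entirely on your hardware.

```python
import time
import tensorflow as tf

@tf.function(jit_compile=True)
def projection(x, w):
    # Hypothetical example; any XLA-compiled function behaves the same way.
    return tf.tanh(tf.matmul(x, w))

w = tf.random.normal((512, 512))

def timed_call(x):
    start = time.perf_counter()
    projection(x, w).numpy()  # .numpy() blocks until the result is ready
    return time.perf_counter() - start

x = tf.random.normal((64, 512))
first = timed_call(x)    # traces the function and XLA-compiles it for shape (64, 512)
second = timed_call(x)   # same shape: reuses the compiled executable
third = timed_call(tf.random.normal((128, 512)))  # new batch size: traced and compiled again

print(f"first call (with compilation): {first * 1e3:.1f} ms")
print(f"second call (cached):          {second * 1e3:.1f} ms")
print(f"new shape (recompiled):        {third * 1e3:.1f} ms")
print("traces so far:", projection.experimental_get_tracing_count())
```

In a training loop, the practical consequence is to keep batch and sequence shapes fixed, for example by padding to a constant length and dropping the last partial batch, so that compilation happens only once.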

Here’s how you enable XLA explicitly, and what a typical TPU setup looks like:

import tensorflow as tf

# On TPUs, TensorFlow always compiles the training computation with XLA.
# The two switches below control XLA explicitly and matter mostly on CPUs and GPUs.

# Option 1: enable XLA auto-clustering globally. This only affects code
# inside tf.function; eagerly executed operations are not compiled.
tf.config.optimizer.set_jit(True)

# Option 2 (more granular): compile one specific function with XLA.
# It uses the model, optimizer, and loss_fn created below (looked up at call time).
@tf.function(jit_compile=True)
def train_step_xla(inputs, labels):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# When using TPUs, you'd typically connect to the TPU system and create a strategy:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = build_bert_like_model(...) # Your model
    optimizer = tf.keras.optimizers.Adam(...)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(...)
    # Compile the model if using the Keras API
    model.compile(optimizer=optimizer, loss=loss_fn)

# model.fit() then runs its training step on the TPU cores, compiled by XLA.
# In a custom training loop, dispatch the step with strategy.run(...) from
# inside a tf.function; that function is likewise XLA-compiled for the TPU.
# (On CPU/GPU, model.compile(..., jit_compile=True) is the Keras-level switch.)

The core idea is that XLA transforms a series of small, separate TensorFlow operations into a single, larger fused computation. For example, a sequence like:

y = tf.nn.relu(tf.matmul(x, w) + b)

might be compiled by XLA into a single kernel that performs the matrix multiplication, adds the bias, and applies the ReLU activation all in one go, without materializing intermediate results to memory. This is especially powerful on hardware like TPUs that have specialized matrix multiplication units and high memory bandwidth.
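To check this on your own machine, you can compare the XLA-compiled expression against plain graph execution and list the fusion instructions in its optimized HLO. This sketch assumes a CPU or GPU runtime and uses the experimental experimental_get_compiler_ir method. Whether the matrix multiplication itself lands inside a fusion depends on the backend: a CPU build may keep the dot as a separate instruction and fuse only the bias add and ReLU. Treat the printed list as informative rather than guaranteed.

```python
import numpy as np
import tensorflow as tf

def dense_relu(x, w, b):
    return tf.nn.relu(tf.matmul(x, w) + b)

plain = tf.function(dense_relu)                    # graph execution, no XLA
fused = tf.function(dense_relu, jit_compile=True)  # XLA-compiled

x = tf.random.normal((256, 128))
w = tf.random.normal((128, 64))
b = tf.random.normal((64,))

y_plain = plain(x, w, b).numpy()
y_xla = fused(x, w, b).numpy()
print("max abs difference:", np.max(np.abs(y_plain - y_xla)))

# Optimized HLO for this input signature: list the fusion instructions XLA created
optimized = fused.experimental_get_compiler_ir(x, w, b)(stage="optimized_hlo")
fusion_lines = [line.strip() for line in optimized.splitlines() if "fusion(" in line]
print(f"{len(fusion_lines)} fusion instruction(s) in the optimized HLO:")
for line in fusion_lines:
    print("  ", line)
```

The two results should agree up to floating-point rounding. Fusion changes where intermediate values live, not what is computed.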

The compilation process itself takes time and can become a bottleneck. If your model is very small or you only run a handful of steps, the one-time compilation cost might outweigh the runtime benefits, leading to slower execution overall. XLA shines when you have computationally intensive graphs that can be significantly optimized through fusion and parallelization.

One aspect that often surprises people is that not every TensorFlow operation can be compiled by XLA. XLA’s coverage is extensive, but some ops have no XLA implementation. These include certain custom operations, ops that call back into Python, dynamic shapes or control flow that XLA can’t resolve at compile time, and specific TF features. How TensorFlow reacts depends on how you enabled XLA. With global auto-clustering (tf.config.optimizer.set_jit(True)), unsupported ops are simply left out of the compiled clusters. The graph is then split into XLA-compiled pieces and regular TensorFlow ops, which can quietly erode the speedup. With @tf.function(jit_compile=True), the whole function must compile, so an unsupported op produces an error instead. You can check whether global auto-clustering is on with tf.config.optimizer.get_jit(). Debugging inside XLA-compiled functions, for example with tf.print, can also be tricky.
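A concrete way to hit this is an op that runs arbitrary Python, such as tf.numpy_function, which has no XLA kernel. The sketch below uses hypothetical function names. It assumes that with jit_compile=True, such a function fails to compile and raises an error instead of falling back, while the same code under a plain tf.function runs normally. The exact exception subclass and message vary between TensorFlow versions, so the code catches the base class tf.errors.OpError.

```python
import numpy as np
import tensorflow as tf

def double_in_numpy(a):
    # Arbitrary Python/NumPy code runs on the host; XLA has no kernel for it.
    return (a * 2).astype(np.float32)

def step(x):
    y = tf.numpy_function(double_in_numpy, [x], tf.float32)
    return tf.reduce_sum(y)

x = tf.ones((4,))

# Plain graph execution: TensorFlow runs the NumPy callback as a regular op.
plain_result = tf.function(step)(x)
print("tf.function result:", plain_result.numpy())

# jit_compile=True requires the whole function to compile with XLA.
try:
    tf.function(step, jit_compile=True)(x)
    compile_failed = False
except tf.errors.OpError as err:  # exact subclass and message vary by version
    compile_failed = True
    print("jit_compile=True failed with", type(err).__name__)
```

Under global auto-clustering, the same op would instead just run outside the compiled clusters.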

Ultimately, XLA is about transforming your model’s computation into a more hardware-friendly, fused representation. The next step after optimizing your training speed with XLA is often delving into how to profile XLA’s performance to pinpoint exactly which parts of your graph are being fused and which are not.
