Running Triton kernels under CUDA Graphs is fundamentally about executing GPU work without the overhead of launching each kernel individually from the CPU. It is not only about speed; it also gives you more predictable latency.
Imagine you have a tiny, lightning-fast kernel that does a single element-wise addition. On a modern GPU, the host-side cost of launching that kernel (Python dispatch, argument handling, the driver call) can exceed the time the kernel itself takes to run. CUDA Graphs capture an entire sequence of kernel launches and memory operations into a single, static graph. The CPU then submits that graph with one launch call, so the per-kernel launch overhead is largely amortized across the whole sequence.
Here’s a simple example. Let’s say we have a Triton kernel for element-wise addition:
```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance handles one BLOCK_SIZE-wide chunk of the vectors.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block so we never read or write past n_elements.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)


def encode(x, y, output, n_elements, block_size=1024):
    # Host-side launcher: computes the grid and issues one kernel launch.
    grid = (triton.cdiv(n_elements, block_size),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=block_size)


N = 1_000_000
x = torch.randn(N, device='cuda')
y = torch.randn(N, device='cuda')
output = torch.empty(N, device='cuda')

# Without CUDA Graphs, every call below is a separate launch issued from Python:
# encode(x, y, output, N)
# output.add_(1)  # imagine another kernel
# encode(x, y, output, N)

# With CUDA Graphs:
stream = torch.cuda.Stream()
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    # Warm up before capturing so Triton JIT-compiles add_kernel now,
    # not while the stream is being recorded.
    encode(x, y, output, N)
torch.cuda.current_stream().wait_stream(stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.stream(stream):
    # Record the graph (capture must happen on a non-default stream).
    graph.capture_begin()
    encode(x, y, output, N)
    # Any further CUDA work issued here becomes part of the graph.
    output.mul_(2)
    graph.capture_end()  # the graph is instantiated here

# Replay the whole recorded sequence as a single launch, as often as needed.
graph.replay()
for _ in range(100):
    graph.replay()
torch.cuda.synchronize()
# output now holds (x + y) * 2
```
The core idea is that instead of the Python interpreter and CUDA driver setting up a new launch command for add_kernel every single time, the entire sequence of operations (kernel launches, memory copies, etc.) is captured once. This captured sequence, the CUDA Graph, is then executed by the GPU directly, bypassing much of the host-side overhead.
How Triton leverages this:
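Triton compiles each kernel to PTX and then to a cubin, which is launched through the CUDA driver. In everyday code you rarely call triton.compile directly. Instead, you index a @triton.jit function with a grid, as in add_kernel[grid](...). The first time it sees a new specialization (for example, new argument dtypes or constexpr values), it compiles the kernel and caches the result; every call then issues the launch. When Triton kernels run inside a CUDA Graph capture, it is these launch commands that get recorded. That is one reason to make the first, compile-triggering call before capture begins.
The torch.cuda.CUDAGraph API is the primary way to achieve this. You call capture_begin() on a non-default CUDA stream, issue your Triton kernel launches, call capture_end(), and then replay() the graph as often as you need. Only replay() benefits from the graph optimization; capture itself is a one-time cost.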
The mental model:
Think of a CUDA Graph as a pre-recorded movie of your GPU operations. Instead of telling the actor (GPU) what to do step-by-step every time (launch kernel A, launch kernel B), you just hit "play" on the movie. The GPU then executes the recorded sequence of actions without needing further instructions from the CPU for each individual step.
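As a plain-Python analogy of this mental model, the hypothetical ToyGraph class below stores "launches" while recording and runs them all when replayed. It is a toy, not how the CUDA driver works internally. It also shows that changing data in place between replays is picked up, while the recorded sequence of steps stays the same.

```python
from typing import Callable, List


class ToyGraph:
    """Hypothetical toy: record-once / replay-many in plain Python (not how CUDA implements it)."""

    def __init__(self) -> None:
        self._steps: List[Callable[[], None]] = []
        self._recording = False

    def capture_begin(self) -> None:
        self._steps.clear()
        self._recording = True

    def capture_end(self) -> None:
        self._recording = False

    def launch(self, step: Callable[[], None]) -> None:
        # While recording, a "launch" is stored instead of run, like a kernel under stream capture.
        if self._recording:
            self._steps.append(step)
        else:
            step()

    def replay(self) -> None:
        # One call runs the whole recorded sequence in order.
        for step in self._steps:
            step()


x = [1.0, 2.0]
y = [10.0, 20.0]
out = [0.0, 0.0]


def add_step() -> None:
    out[:] = [a + b for a, b in zip(x, y)]  # writes into the existing "buffer"


def double_step() -> None:
    out[:] = [2 * v for v in out]


g = ToyGraph()
g.capture_begin()
g.launch(add_step)
g.launch(double_step)
g.capture_end()

g.replay()
print(out)  # [22.0, 44.0]
x[0] = 5.0  # change data in place; the recorded steps are unchanged
g.replay()
print(out)  # [30.0, 44.0]
```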
The key levers you control are:
- What gets captured: CUDA work issued on the capturing stream between graph.capture_begin() and graph.capture_end() is recorded. That includes Triton kernel launches, torch.Tensor operations that dispatch to CUDA, and memory copies. Host-side Python logic is not recorded; it runs once, during capture.
- When it’s replayed: graph.replay() executes the captured operations. The graph is instantiated when capture ends, so replays involve no further compilation. The first replay may still pay a small one-time setup cost, and subsequent replays are cheap.
- Data dependency: The graph records fixed device addresses. If you write new values into the tensors that were used during capture (for example with x.copy_(new_x)), replay computes on the new values. If you instead pass or bind different tensors, replay still reads the original buffers. This is where the efficiency comes from: the computation structure and memory layout are static, but the data in those buffers can change.
- Stream handling: Capture must happen on a side stream rather than the default stream. Any other stream used during capture must be joined back into the capturing stream before capture ends. PyTorch replays the graph on whichever stream is current when you call replay(). If you want other work to run concurrently with a replay, use separate streams with explicit synchronization.
The one thing most people don’t know:
CUDA Graphs are very effective for work that has a fixed computational structure but varying input data. However, nothing is re-recorded on replay, so every decision made on the host during capture is frozen. That covers which kernels were launched and in what order, their grid sizes, and every scalar or constexpr argument (such as n_elements or BLOCK_SIZE). Python control flow that took one path during capture will take that same path on every replay, regardless of later inputs. Branching on tensor values from Python is not even allowed during capture, because it forces a device-to-host synchronization. Branches inside a Triton kernel are different: they execute on the GPU, so a tl.where, or an if on loaded values, is evaluated afresh on every replay. If the structure of the computation genuinely changes, for example because the input size changes, you need to re-capture the graph (or keep one graph per shape) or run without graphs. Missing this distinction leads to silently incorrect results.
The next frontier you’ll likely encounter is managing graph state across multiple replays and understanding how to efficiently update the data that the graph operates on without re-capturing.