Triton Inference Server’s TorchScript (LibTorch) backend is a surprisingly efficient way to serve PyTorch models, and its real power lies in how it lets you bypass Python entirely at inference time.

Let’s see it in action. Imagine you have a simple PyTorch model trained to classify MNIST digits.

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1600, 128)  # 64 channels * 5 * 5 spatial after two conv/pool stages
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return torch.log_softmax(x, dim=1)

model = SimpleCNN()
# Assume model is already trained
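Before converting anything, it’s worth sanity-checking the tensor shapes with a quick forward pass. Here’s the same architecture expressed as an nn.Sequential (just for the check, not part of the deployment): the spatial size shrinks 28 → 26 → 13 → 11 → 5, so the flattened feature count is 64 · 5 · 5 = 1600.

```python
import torch
import torch.nn as nn

# Same layer stack as SimpleCNN, written as a Sequential for a quick shape check
net = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, 1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1600, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

with torch.no_grad():
    out = net(torch.randn(1, 1, 28, 28))  # one fake 28x28 grayscale digit
print(out.shape)  # torch.Size([1, 10]) -- one score per digit class
```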

Now, to deploy this with Triton, we need to convert it to TorchScript. This "freezes" the model’s computation graph, making it portable and independent of the original Python code.

# Switch to eval mode, then compile and save the model as TorchScript
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("simple_cnn.pt")

This simple_cnn.pt file is what Triton will load, conventionally renamed model.pt and placed inside a numbered version subdirectory of the model’s folder in your model repository. Next to those version directories, you create a config.pbtxt file that describes the model:
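Triton discovers models through a fixed repository layout: one directory per model, numbered version subdirectories beneath it, and the TorchScript file conventionally named model.pt. For this example the repository would look like:

```
model_repository/
└── simple_cnn/
    ├── config.pbtxt
    └── 1/
        └── model.pt   <- the renamed simple_cnn.pt
```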

name: "simple_cnn"
platform: "pytorch_libtorch"
max_batch_size: 8
input [
  {
    name: "input__0"
    data_type: TYPE_FP32
    dims: [ 1, 28, 28 ]
  }
]
output [
  {
    name: "output__0"
    data_type: TYPE_FP32
    dims: [ 10 ]
  }
]

The platform: "pytorch_libtorch" line tells Triton to use its TorchScript (LibTorch) backend. The input and output definitions specify the expected tensor names, data types, and shapes. Note that dims omit the batch dimension: because max_batch_size is greater than zero, Triton treats the first dimension of every input and output as a variable batch dimension and will accept requests with batches of up to 8.
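Since the dims omit the batch dimension, you can also let Triton form batches server-side by merging concurrent requests; that is one extra stanza in config.pbtxt (the queue delay value here is illustrative):

```
dynamic_batching {
  max_queue_delay_microseconds: 100
}
```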

When a client sends a request, Triton doesn’t execute any Python code. It directly loads the simple_cnn.pt file using the LibTorch C++ runtime, feeds the input tensor, and returns the output tensor. This eliminates Python interpreter overhead, GIL contention, and dependencies, leading to significantly lower latency and higher throughput, especially for models with complex control flow or large numbers of small operations.
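Concretely, a client talks to Triton over the KServe v2 protocol, by default HTTP on port 8000 at /v2/models/simple_cnn/infer. A request body would look roughly like this (the pixel data is abbreviated; a real request carries all 1 × 28 × 28 values):

```
{
  "inputs": [
    {
      "name": "input__0",
      "shape": [1, 1, 28, 28],
      "datatype": "FP32",
      "data": [0.0, 0.0, ...]
    }
  ]
}
```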

The torch.jit.script function compiles the model’s Python source into an intermediate representation (IR) that LibTorch can execute; this IR is what gets serialized into the .pt file. The alternative, torch.jit.trace, instead records the operations executed for one example input — simpler, but it bakes in any data-dependent control flow, which is why script generally produces more robust and optimizable TorchScript. Triton then leverages LibTorch’s just-in-time (JIT) compiler to further optimize this IR before inference.
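The difference is easiest to see on a toy module with a data-dependent branch (this Flip module is purely illustrative, not part of the MNIST model):

```python
import torch
import torch.nn as nn

# Toy module whose output depends on a data-dependent branch
class Flip(nn.Module):
    def forward(self, x):
        if x.sum() > 0:       # this branch depends on the input values
            return x
        return -x

m = Flip()
example = torch.ones(1, 4)               # sum > 0, so tracing takes the first branch
traced = torch.jit.trace(m, example)     # records that single execution path
scripted = torch.jit.script(m)           # compiles both branches

neg = -torch.ones(1, 4)
print(traced(neg))    # tensor([[-1., -1., -1., -1.]]) -- the branch was baked in
print(scripted(neg))  # tensor([[1., 1., 1., 1.]]) -- control flow preserved
```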

The "magic" happens in how Triton orchestrates the LibTorch C++ library. When you specify pytorch_libtorch as the platform, Triton initializes a LibTorch runtime and loads your *.pt file into it. The input and output names in config.pbtxt are crucial because the backend maps them to the arguments and return values of your TorchScript module’s forward method. Triton converts between its internal tensor representation and LibTorch’s torch::Tensor on the way in and out, allowing seamless integration without any Python glue code.

One common point of confusion is the naming of input and output tensors. The PyTorch backend uses the convention &lt;name&gt;__&lt;index&gt;, where the numeric suffix gives the positional order of the forward arguments and return values — the first input is input__0, a second input would be input__1, and so on — regardless of what the arguments are called in your Python code. If the mapping is ever unclear, inspecting the compiled module with print(scripted_model.graph) or print(scripted_model.code) shows the forward signature, but it is the positional index in the configured name, not the Python argument name, that Triton actually uses.

Understanding how Triton handles different PyTorch model types and the nuances of TorchScript compilation is key to unlocking its full potential.

Want structured learning?

Take the full Triton course →