The most surprising thing about serving embedding tables with Triton is that the biggest performance bottleneck usually isn't model inference itself, but the way you load and access the embedding data.

Let’s see this in action. Imagine you have a recommendation model that uses user and item embeddings. These embeddings are huge – potentially gigabytes of data. You’re serving this model with Triton Inference Server.

Here’s a simplified Python client making a request:

import tritonclient.http as httpclient
import numpy as np

# Assume Triton is running on localhost:8000
client = httpclient.InferenceServerClient(url="localhost:8000")

# Example input: user_id and item_id
user_id = np.array([[12345]], dtype=np.int64)
item_id = np.array([[67890]], dtype=np.int64)

inputs = [
    httpclient.InferInput("user_id", user_id.shape, "INT64"),
    httpclient.InferInput("item_id", item_id.shape, "INT64"),
]

inputs[0].set_data_from_numpy(user_id)
inputs[1].set_data_from_numpy(item_id)

# Requested outputs: user and item embeddings
outputs = [
    httpclient.InferRequestedOutput("user_embedding"),
    httpclient.InferRequestedOutput("item_embedding"),
]

# Perform inference
results = client.infer(
    model_name="my_recommender_model",
    inputs=inputs,
    outputs=outputs,
)

# Process results (e.g., concatenate embeddings, pass to downstream model)
user_emb = results.as_numpy("user_embedding")
item_emb = results.as_numpy("item_embedding")

print("User Embedding:", user_emb)
print("Item Embedding:", item_emb)

In this scenario, when my_recommender_model is invoked, Triton needs to fetch the corresponding embedding vectors for user_id=12345 and item_id=67890. This is where the magic (and potential pain) happens. Triton doesn’t inherently know how to "look up" these IDs in a massive embedding table. It relies on a mechanism to provide this data.

The core problem Triton solves here is bridging the gap between your model’s computational graph (which expects fixed-size tensors for inference) and your potentially massive, sparse embedding data. You can’t just load gigabytes of embeddings directly into the model’s input tensors because they’d be too large and most of it would be unused. Instead, you need a system that can efficiently retrieve specific embedding vectors based on input IDs.
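Conceptually, that retrieval is just row indexing into a 2D array. Here is a minimal sketch with made-up table sizes and IDs, showing how a batch of IDs selects its embedding vectors in a single NumPy operation:

```python
import numpy as np

# Hypothetical embedding table: 1,000 unique IDs, 128-dim vectors
embedding_table = np.random.rand(1000, 128).astype(np.float32)

# A batch of incoming IDs selects the matching rows at once
ids = np.array([12, 345, 678])
vectors = embedding_table[ids]  # NumPy fancy indexing

print(vectors.shape)  # (3, 128)
```

Only the three requested rows are materialized; the rest of the table is never touched, which is exactly the access pattern a lookup layer needs to preserve.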

Triton handles this through its model repository, specifically by allowing custom ensembles or Python backends. For embedding tables, the most common and flexible approach is the Python backend.

Here’s how a Python Backend typically works for embeddings:

  1. Model Definition: You define a Triton model (e.g., my_recommender_model) in your model repository. This model’s configuration points to a Python script.
  2. Python Script: This script contains the logic for handling requests.
    • It defines the model’s inputs (user_id, item_id) and outputs (user_embedding, item_embedding).
    • Crucially, it loads the embedding table data once when the model is initialized. This data could be in various formats: NumPy arrays, memory-mapped files, or even loaded into an in-memory database or key-value store.
    • When a request arrives, the script receives the input IDs. It then uses these IDs to perform a lookup in the loaded embedding table.
    • The retrieved embedding vectors are then formatted as NumPy arrays and returned as Triton model outputs.

Let’s look at a snippet of what that Python Backend script might look like:

# my_recommender_model/1/model.py
import numpy as np
import triton_python_backend_utils as pb_utils

class TritonPythonModel:
    def initialize(self, args):
        # Load embedding table ONCE during initialization.
        # This is the critical performance step.
        # Example: loading from a large NumPy file
        self.embedding_table = np.load("embeddings.npy")  # shape: (num_unique_ids, embedding_dim)
        self.embedding_dim = self.embedding_table.shape[1]

        # Other potential loading mechanisms:
        # - memory mapping: np.memmap("embeddings.mmap", dtype=np.float32, mode='r', shape=(N, D))
        # - database/KV store: redis_client.get("embedding_key")

    def execute(self, requests):
        responses = []
        for request in requests:
            user_ids = pb_utils.get_input_tensor_by_name(request, "user_id").as_numpy()
            item_ids = pb_utils.get_input_tensor_by_name(request, "item_id").as_numpy()

            # Perform lookups.
            # user_ids and item_ids arrive as 2D arrays, e.g., [[12345], [67890]],
            # so flatten them before using them as row indices.
            user_embeddings = self.embedding_table[user_ids.flatten()]
            item_embeddings = self.embedding_table[item_ids.flatten()]

            # Wrap the results as Triton output tensors
            user_embedding_output = pb_utils.Tensor("user_embedding", user_embeddings)
            item_embedding_output = pb_utils.Tensor("item_embedding", item_embeddings)

            responses.append(
                pb_utils.InferenceResponse(
                    output_tensors=[user_embedding_output, item_embedding_output]
                )
            )
        return responses

    def finalize(self):
        pass  # Clean up resources if necessary

The key here is that np.load("embeddings.npy") or any other loading mechanism happens once in initialize. Subsequent execute calls directly access the already-loaded data. This avoids the massive overhead of reading from disk or network for every single request.

The most counterintuitive part of optimizing this is realizing that the np.load or initial data loading can itself be slow if the embedding table is truly massive and unoptimized. If you’re loading a 100GB .npy file, initialize might take minutes, which is unacceptable for dynamic model loading. This is why techniques like memory mapping (np.memmap) or using specialized key-value stores (like RocksDB or Redis) that can serve individual keys extremely quickly are often preferred over monolithic .npy files. The Python backend then acts as a thin layer to fetch from these optimized stores.
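To make the memory-mapping idea concrete, here is a small self-contained sketch (file name, table size, and IDs are all illustrative). It writes a table to disk once, then opens it with np.memmap so that opening is near-instant regardless of file size and only the rows actually indexed get paged in:

```python
import os
import tempfile

import numpy as np

N, D = 10_000, 128  # illustrative table dimensions

# One-time export step (in practice done offline, not at serving time)
path = os.path.join(tempfile.mkdtemp(), "embeddings.mmap")
table = np.random.rand(N, D).astype(np.float32)
table.tofile(path)

# At serving time: map the file instead of loading it wholesale.
# No bulk read happens here; pages are faulted in lazily on access.
mmap_table = np.memmap(path, dtype=np.float32, mode="r", shape=(N, D))

ids = np.array([7, 42, 9999])
vectors = np.asarray(mmap_table[ids])  # copy just the selected rows into RAM

print(vectors.shape)  # (3, 128)
```

The trade-off is that the first access to a cold page still pays a disk read, so a KV store with its own caching layer can win when the working set of hot IDs is small relative to the table.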

If you’re serving a model that requires embeddings, the next challenge you’ll face is managing multiple, potentially overlapping, embedding tables for different models or different types of features within the same model.
