vLLM can load your custom model architecture, but it’s not as simple as just pointing it at a PyTorch checkpoint. The core challenge is that vLLM’s optimized inference engine expects a specific structure for attention computations, namely PagedAttention. You need to ensure your custom model’s forward pass is compatible with this.
Here’s how you integrate your own architecture:
First, understand that vLLM doesn’t magically infer your model’s layers. You’ll need to subclass vllm.model.base_model.BaseModel and implement the necessary methods to tell vLLM how to process your model’s outputs. The critical method here is _forward_model, where you’ll perform the actual forward pass of your custom architecture.
Let’s say you have a custom transformer model. You’d start by creating a class like this:
```python
from vllm.model.base_model import BaseModel
from vllm.model.weight_manager import WeightManager
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.attention.backends.xformers import XFormersAttention  # or another attention backend

logger = init_logger(__name__)


class MyCustomModel(BaseModel):
    def __init__(self, model_config: ModelConfig, weight_manager: WeightManager):
        super().__init__(model_config, weight_manager)
        # Initialize your custom architecture here. This typically involves
        # loading your PyTorch model weights; for demonstration, assume
        # `self.model` is your nn.Module, already populated by the
        # WeightManager. For example:
        #   self.model = YourCustomModelClass(...)
        #   self.weight_manager.load_weights(self.model)

        # You MUST initialize the attention mechanism: vLLM's PagedAttention
        # is critical for performance. If your model is a standard
        # transformer, you'd use something like:
        self.attention = XFormersAttention(self.model_config)
        # If your model has a fundamentally different attention mechanism,
        # you might need to implement a custom attention class. For most
        # cases, adapting to PagedAttention is the way to go.
```
The WeightManager is crucial. It handles loading your model’s weights. You’ll pass your instantiated PyTorch model to weight_manager.load_weights(). The ModelConfig object will contain details like num_hidden_layers, hidden_size, num_attention_heads, etc., which vLLM uses for its internal optimizations.
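To make the role of those config fields concrete, here is a minimal, framework-free sketch of how shape parameters like these drive derived quantities such as the per-head dimension. The `SimpleConfig` dataclass is purely illustrative, not vLLM's actual `ModelConfig`:

```python
from dataclasses import dataclass


# Illustrative stand-in for the fields a ModelConfig-like object carries;
# this is NOT vLLM's real ModelConfig class.
@dataclass
class SimpleConfig:
    num_hidden_layers: int
    hidden_size: int
    num_attention_heads: int


def head_dim(cfg: SimpleConfig) -> int:
    # Engines derive the per-head dimension from hidden_size and head count;
    # the two must divide evenly.
    assert cfg.hidden_size % cfg.num_attention_heads == 0
    return cfg.hidden_size // cfg.num_attention_heads


cfg = SimpleConfig(num_hidden_layers=32, hidden_size=4096, num_attention_heads=32)
print(head_dim(cfg))  # 128
```

If your architecture's config divides up differently (e.g., grouped-query attention with separate KV head counts), the same idea applies: the engine reads these fields once and sizes its caches and kernels from them.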
The heart of the integration is the _forward_model method. This is where you take the input tensors and pass them through your model.
```python
import torch
from typing import Optional

# SeqData, LoraMetadata, and ForwardOutput are assumed to be importable from
# vLLM's internals, alongside the BaseModel subclass above.


def _forward_model(
    self,
    input_ids: torch.Tensor,
    *,
    kv_cache: Optional[SeqData],
    max_seq_len: int,
    lora_key_name: Optional[str] = None,
    lora_metadata: Optional[LoraMetadata] = None,
    # Add other arguments (attention_mask, position_ids, etc.) if your
    # model needs them.
) -> ForwardOutput:
    # This is where you call your actual PyTorch model's forward pass.
    # The exact inputs/outputs depend on your architecture; for a standard
    # transformer it might look like this.

    # 1. Prepare inputs: embedding lookup, plus positional encodings if
    #    they are not handled inside the model.
    hidden_states = self.model.embed_tokens(input_ids)

    # 2. Iterate through the layers. This is where PagedAttention
    #    integration is key: vLLM's attention mechanism manages the KV
    #    cache, and layer_idx lets it track cache blocks per layer.
    for i in range(self.model_config.num_hidden_layers):
        layer = self.model.layers[i]  # assumes your model exposes .layers
        attn_outputs = self.attention(
            hidden_states,
            kv_cache,
            layer_idx=i,
            # Other arguments like attention_mask, position_ids may be needed.
        )
        hidden_states = attn_outputs.hidden_states
        kv_cache = attn_outputs.kv_cache  # carry the updated cache forward

        # Feed-forward network. If `layer` bundles attention and FFN, call
        # only its FFN sub-module here (e.g. layer.mlp) so attention is not
        # run a second time.
        hidden_states = layer.mlp(hidden_states)

    # 3. Final layer normalization and output projection (e.g., LM head).
    #    This part is highly model-specific.
    hidden_states = self.model.final_layernorm(hidden_states)
    logits = self.model.lm_head(hidden_states)

    # ForwardOutput expects `logits` and potentially the updated `kv_cache`.
    return ForwardOutput(logits=logits, kv_cache=kv_cache)
```
The kv_cache object is vLLM’s mechanism for managing the Key-Value cache efficiently. When you call self.attention(...), it receives the current hidden_states and the kv_cache from the previous step. It computes the attention output and returns the updated kv_cache which you then pass to the next layer or the next token generation step. The layer_idx is vital for PagedAttention to correctly manage the KV cache blocks.
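The cache-threading pattern described above can be demonstrated without any framework at all. In this toy sketch, each "layer" appends to its own cache entry on every step and attends over everything cached so far; the class and method names are illustrative, not vLLM APIs:

```python
class ToyAttention:
    """Toy stand-in for a per-layer attention call that reads and grows a cache."""

    def __call__(self, hidden, kv_cache, layer_idx):
        # Append this step's "keys/values" (here, just the scalar input) to
        # the per-layer cache, mimicking how the cache grows per token.
        kv_cache.setdefault(layer_idx, []).append(hidden)
        # "Attend" over everything cached so far for this layer (mean pool).
        attended = sum(kv_cache[layer_idx]) / len(kv_cache[layer_idx])
        return attended, kv_cache


def forward_step(token_value, kv_cache, num_layers=2):
    # Thread the cache through every layer, exactly as _forward_model does:
    # each layer consumes the cache updated by the previous one.
    attention = ToyAttention()
    hidden = token_value
    for i in range(num_layers):
        hidden, kv_cache = attention(hidden, kv_cache, layer_idx=i)
    return hidden, kv_cache


cache = {}
out1, cache = forward_step(1.0, cache)   # first "token"
out2, cache = forward_step(3.0, cache)   # second "token" sees the cached first
# After two steps, each layer's cache holds two entries.
```

The point of the sketch is the data flow, not the math: the cache object outlives any single forward call, and each layer's slot is addressed by `layer_idx`, which is exactly why that index matters to PagedAttention's block management.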
A key detail is how your custom model’s attention layers are exposed and how you interact with vLLM’s Attention backend. For standard transformer architectures, you’ll likely want to replace your model’s internal attention with vLLM’s PagedAttention implementation. This means your _forward_model will call self.attention instead of self.model.layers[i].attention. The self.attention object handles the PagedAttention logic. You’ll need to ensure your MyCustomModel class is initialized with a compatible Attention backend (e.g., XFormersAttention or PagedAttention if you’re not using xformers).
If your model has a non-standard architecture (e.g., MoE layers, different attention types), you might need to subclass vllm.attention.base_attention.Attention itself and implement the PagedAttention logic for your specific needs. However, for most common transformer variations, adapting the _forward_model to use vLLM’s existing Attention backends is the most straightforward path.
The WeightManager also needs to know how to map your model’s weight names to its internal tensor names. If your model’s weights aren’t named conventionally (e.g., transformer.h.0.attn.q_proj.weight), you might need to provide a custom mapping to the WeightManager or rename weights during loading.
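A remapping like that can be as simple as a prefix-rewrite pass over the checkpoint's state dict before loading. The `rename_weights` helper and the target prefixes below are illustrative assumptions, not a vLLM API:

```python
def rename_weights(state_dict, rules):
    """Return a new state dict with each key rewritten by the first matching
    (old_prefix, new_prefix) rule; keys with no matching rule pass through."""
    renamed = {}
    for name, tensor in state_dict.items():
        for old, new in rules:
            if name.startswith(old):
                name = new + name[len(old):]
                break
        renamed[name] = tensor
    return renamed


# Hypothetical mapping from GPT-2-style names to the names your loader expects.
rules = [
    ("transformer.h.", "model.layers."),
    ("transformer.wte.", "model.embed_tokens."),
]
weights = {
    "transformer.h.0.attn.q_proj.weight": "...",
    "transformer.wte.weight": "...",
}
print(rename_weights(weights, rules))
# {'model.layers.0.attn.q_proj.weight': '...', 'model.embed_tokens.weight': '...'}
```

Whether you apply this rewrite yourself before handing the model to the loader, or supply the mapping to the loader directly, depends on what hooks your loading path exposes.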
Finally, when you instantiate your LLM object, you’ll need to specify your custom model class:
```python
from vllm import LLM, SamplingParams

# Assuming MyCustomModel is defined and imported.
model_path = "/path/to/your/custom/model"
llm = LLM(model=MyCustomModel, model_path=model_path)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
prompts = ["Write a poem about the ocean."]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
The model_path would point to your model’s directory containing the weights and configuration. vLLM’s ModelConfig will attempt to load configuration from standard locations (like config.json), and the WeightManager will then use this config to load weights. If your configuration is non-standard, you might need to manually construct ModelConfig or ensure your model directory has the necessary files.
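For reference, a minimal `config.json` for a decoder-only transformer might look like the following. The field names here follow the Hugging Face convention that most loaders expect; your architecture may need additional or different fields:

```json
{
  "architectures": ["MyCustomModel"],
  "hidden_size": 4096,
  "num_hidden_layers": 32,
  "num_attention_heads": 32,
  "vocab_size": 32000,
  "max_position_embeddings": 4096,
  "torch_dtype": "float16"
}
```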
The most surprising aspect of vLLM custom model integration is how tightly coupled the _forward_model method is to the Attention backend. You’re not just passing inputs to your model; you’re orchestrating the flow of tensors through vLLM’s optimized attention mechanism, ensuring the KV cache is correctly managed at each layer.
The next challenge you’ll likely encounter is optimizing the _forward_model for custom layer types or non-transformer architectures, which might require a deeper dive into vLLM’s Attention backend implementations.