SPLADE is a neural retrieval model that, unlike dense embeddings, uses sparse, interpretable vectors that look like TF-IDF but are learned.
Imagine you’re searching a document corpus. Traditionally, you might use TF-IDF, which scores a word in a document by how often it appears there and how rare it is across the whole corpus. Or you might use dense embeddings (such as those from BERT or Sentence-BERT), which represent entire sentences or documents as fixed-size vectors whose similarity is measured by the dot product or cosine similarity.
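For reference, here is a minimal TF-IDF sketch in pure Python. It assumes whitespace tokenization, raw counts for TF, and an unsmoothed log(N / df) for IDF. Real libraries differ: scikit-learn's TfidfVectorizer, for example, smooths the IDF by default. The helper name tfidf_vectors is hypothetical. Each document becomes a sparse term-to-weight mapping, which is the same shape of vector SPLADE produces.

```python
import math
from collections import Counter

def tfidf_vectors(docs):
    """Return one sparse TF-IDF vector ({term: weight}) per document.

    Assumes whitespace tokenization, raw counts for TF and log(N / df) for IDF.
    """
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    # Document frequency: in how many documents each term appears.
    df = Counter(term for tokens in tokenized for term in set(tokens))
    vectors = []
    for tokens in tokenized:
        tf = Counter(tokens)
        weights = {t: c * math.log(n_docs / df[t]) for t, c in tf.items()}
        # Drop zero weights (terms found in every document) to keep it sparse.
        vectors.append({t: w for t, w in weights.items() if w > 0})
    return vectors

docs = ["the fox jumps", "the dog sleeps", "the dog and the fox"]
vecs = tfidf_vectors(docs)
print({t: round(w, 3) for t, w in vecs[0].items()})  # {'fox': 0.405, 'jumps': 1.099}
```

The word "the" gets no weight because it occurs in every document. That weighting is fixed by corpus statistics, which is exactly what SPLADE replaces with learned weights.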
SPLADE offers a third way. It’s a neural model, meaning it learns from data, but it produces sparse vectors. Like TF-IDF vectors, these have a dimension for every term in the vocabulary. For BERT-based SPLADE models, that vocabulary is BERT’s roughly 30,000 WordPiece tokens. Most of these dimensions are zero. The non-zero dimensions represent the "importance" of a term for that specific document. This importance is learned by the neural network to optimize for retrieval, and it can include related terms that do not appear in the text at all (term expansion).
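To make "a dimension for every term, but most are zero" concrete, the sketch below stores a vocabulary-sized vector as parallel index/value lists. It assumes the 30,522-token BERT-base vocabulary. The helper to_sparse, the token ids, and the weights are all hypothetical and chosen only for illustration.

```python
VOCAB_SIZE = 30522  # assumed: size of the BERT-base WordPiece vocabulary

def to_sparse(dense):
    """Keep only the non-zero dimensions as parallel index/value lists."""
    indices = [i for i, w in enumerate(dense) if w != 0.0]
    values = [dense[i] for i in indices]
    return indices, values

# A toy vocabulary-sized vector with three non-zero weights
# (hypothetical token ids and values, not real model output).
dense = [0.0] * VOCAB_SIZE
dense[4419] = 1.8
dense[3899] = 1.2
dense[9790] = 0.4

indices, values = to_sparse(dense)
print(indices, values)  # [3899, 4419, 9790] [1.2, 1.8, 0.4]
print(f"{len(indices) / VOCAB_SIZE:.4%} of dimensions are non-zero")
```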
Let’s see it in action. We’ll use the splade-cocondenser-ensembledistil model from Hugging Face.
First, install the necessary libraries:
pip install transformers torch numpy
Now, let’s encode a few documents and a query.
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
# Load the SPLADE model with its masked-language-modeling (MLM) head, plus its tokenizer.
# The MLM head is what maps each input token to a score over the whole vocabulary.
model_name = "naver/splade-cocondenser-ensembledistil"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
# Documents to search
documents = [
    "The quick brown fox jumps over the lazy dog.",
    "A fast, agile fox leaps across a sleeping canine.",
    "How to train your dog for agility.",
    "The history of the fox as a wild animal.",
]
# Query
query = "fox and dog agility"
# Function to get SPLADE embeddings
def get_splade_embedding(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        # logits has shape (1, seq_len, vocab_size): one score per vocabulary term
        # for every input token position.
        logits = model(**inputs).logits
    # SPLADE activation: log(1 + ReLU(logit)), with padding positions masked out.
    weights = torch.log1p(torch.relu(logits)) * inputs["attention_mask"].unsqueeze(-1)
    # Max-pool over token positions to get one weight per vocabulary term.
    # The result is a vocab_size-long vector in which most entries are zero.
    # We keep it as a dense tensor for simplicity; a real system would store
    # only the non-zero (index, value) pairs.
    return torch.max(weights, dim=1).values.squeeze(0)
# Encode documents
doc_embeddings = [get_splade_embedding(doc, model, tokenizer) for doc in documents]
# Encode the query
query_embedding = get_splade_embedding(query, model, tokenizer)
# Calculate similarity (dot product for sparse vectors is efficient)
# In a real scenario, you'd use sparse vector libraries for this.
# Here, we use the dense representations for simplicity.
scores = [torch.dot(query_embedding, doc_embedding) for doc_embedding in doc_embeddings]
# Print scores
for i, score in enumerate(scores):
    print(f"Document {i+1} Score: {score.item():.4f}")
# Example output format. The numbers below are illustrative placeholders, not a recorded run; real scores, and possibly the ranking, depend on the model version:
# Document 1 Score: 3.1234
# Document 2 Score: 2.5678
# Document 3 Score: 1.8901
# Document 4 Score: 0.9876
If the model ranks the documents as in the illustrative output above, the pattern is instructive. The first document shares "fox" and "dog" with the query. Because SPLADE can assign weight to expansion terms that never appear in the text, it may also pick up some weight on terms related to agility. The second document paraphrases the first with different words ("agile", "canine"), and expansion plausibly links those to "agility" and "dog", so it also scores well. The third document mentions "dog" and "agility" but not "fox," and the fourth mentions only "fox." This is the core idea: SPLADE learns which terms, including ones absent from the text, are important for matching a query to a document.
The Mental Model: Learned Term Weights
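At its heart, SPLADE is a transformer model (like BERT) that keeps its masked-language-modeling head. For every input token, that head produces a score for every term in the vocabulary. SPLADE passes these scores through log(1 + ReLU(·)) and pools them over the input positions to get one learned weight per vocabulary term. SPLADE v2 and the model used above use max pooling; the original SPLADE summed. Unlike traditional sparse methods such as TF-IDF, where weights are fixed by frequency statistics, SPLADE’s weights are contextual and learned to maximize retrieval performance.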
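The pooling step can be sketched in pure Python on toy numbers. It assumes the max-pooled variant, w_j = max_i log(1 + ReLU(logit_ij)) over non-padding positions i. The function name splade_pool and the logits are hypothetical. Because log(1 + ReLU(x)) is never negative, initializing every weight to 0.0 and taking running maxima is safe.

```python
import math

def splade_pool(logits, attention_mask):
    """Turn per-token MLM logits into one weight per vocabulary term.

    logits: one row per token position, each a list of vocab_size logits.
    attention_mask: 1 for real tokens, 0 for padding.
    Computes w_j = max over real positions i of log(1 + relu(logit_ij)).
    """
    vocab_size = len(logits[0])
    weights = [0.0] * vocab_size
    for row, mask in zip(logits, attention_mask):
        if not mask:
            continue  # padding positions contribute nothing
        for j, logit in enumerate(row):
            w = math.log1p(max(logit, 0.0))  # ReLU, then log saturation
            weights[j] = max(weights[j], w)
    return weights

# Toy example: 3 token positions (the last is padding), vocabulary of 4 terms.
logits = [
    [2.0, -1.0, 0.0, 0.5],
    [-3.0, 1.0, -0.5, 3.0],
    [9.0, 9.0, 9.0, 9.0],  # padding row, ignored via the mask
]
mask = [1, 1, 0]
print([round(w, 3) for w in splade_pool(logits, mask)])  # [1.099, 0.693, 0.0, 1.386]
```

Term 2 never gets a positive logit at a real position, so its weight stays exactly zero. That is where the sparsity ultimately comes from.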
The key to getting sparse vectors is regularization during training. Alongside the ranking loss, SPLADE adds a sparsity-inducing penalty on the term weights. A plain L1 norm is the simplest choice. The SPLADE papers instead use the FLOPS regularizer, which penalizes the squared average weight of each term across a batch and therefore hits terms that are active in many documents hardest. Combined with the ReLU, which clamps negative scores to zero, this forces the model to select the most salient terms and assign them non-zero weights, while pushing the weights of less relevant terms to zero.
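Here is a minimal sketch contrasting the two penalties on hypothetical toy batches of weight vectors. It is pure Python, and the names l1_penalty and flops_penalty are illustrative. In training, one of these would be multiplied by a λ and added to the ranking loss.

```python
def l1_penalty(batch):
    """Mean L1 norm of the weight vectors in a batch."""
    return sum(sum(abs(w) for w in vec) for vec in batch) / len(batch)

def flops_penalty(batch):
    """FLOPS regularizer: sum over terms of the squared mean weight in the batch.

    Terms active in many documents are penalized most, which pushes the
    model toward rarely used, discriminative terms.
    """
    n = len(batch)
    vocab_size = len(batch[0])
    return sum((sum(vec[j] for vec in batch) / n) ** 2 for j in range(vocab_size))

# Two toy batches with the same total weight (hypothetical values):
spread = [[1.0, 0.0], [0.0, 1.0]]  # each document uses a different term
shared = [[1.0, 0.0], [1.0, 0.0]]  # both documents use the same term
print(l1_penalty(spread), l1_penalty(shared))        # 1.0 1.0
print(flops_penalty(spread), flops_penalty(shared))  # 0.5 1.0
```

L1 cannot tell the two batches apart. FLOPS charges more when every document leans on the same term, because such a term would create long posting lists and slow down retrieval.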
The output of the SPLADE model for a given text is a mapping from vocabulary terms to their learned importance scores. This is often represented as a dictionary or a sparse matrix where dimensions correspond to terms in the vocabulary. When you query, you do the same for the query text, and then you compute a similarity score (typically a dot product) between the query vector and each document vector.
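With the dictionary representation, the dot product only needs the terms both vectors share. The sketch below uses a hypothetical helper, sparse_dot, and made-up weights that are not real model output.

```python
def sparse_dot(q, d):
    """Dot product of two sparse vectors stored as {term: weight} dicts.

    Only terms present in both vectors contribute, so the cost depends on
    the number of non-zero entries, not on the vocabulary size.
    """
    if len(q) > len(d):
        q, d = d, q  # iterate over the smaller vector
    return sum(w * d[t] for t, w in q.items() if t in d)

# Hypothetical learned weights, for illustration only.
query = {"fox": 1.5, "dog": 1.2, "agility": 1.8}
doc = {"dog": 1.4, "train": 0.9, "agility": 2.0, "agile": 0.6}
print(round(sparse_dot(query, doc), 2))  # 5.28
```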
The Levers You Control
- Model Choice: There are different SPLADE variants trained on various datasets and with different architectures. splade-cocondenser-ensembledistil is a good general-purpose model. For specific domains, you might fine-tune a SPLADE model on your own data.
- Tokenization: SPLADE uses a transformer tokenizer, which means it performs subword tokenization (e.g., "agility" might be split into "agil" and "##ity"). The learned weights are associated with these subword tokens.
- Regularization Parameters: If you are training or fine-tuning a SPLADE model, the regularization strength is crucial. This is the λ that weights the sparsity penalty, and SPLADE v2 sets it separately for queries and documents. It directly controls the sparsity of the output vectors: higher regularization leads to sparser vectors, usually at some cost in effectiveness.
- Scoring Function: While dot product is common for sparse vectors, you might explore other similarity measures depending on the nature of your embeddings and data.
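The subword split mentioned under Tokenization comes from WordPiece. Below is a minimal sketch of its greedy longest-match-first rule for a single, already pre-tokenized word. The vocabulary is tiny and hypothetical, chosen to reproduce the "agil" / "##ity" example. BERT's real vocabulary may split "agility" differently or keep it whole, and real tokenizers also lowercase and split punctuation first.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word.

    Continuation pieces carry the "##" prefix, as in BERT's vocabulary.
    Returns [unk] if the word cannot be fully covered by the vocabulary.
    """
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it.
        while end > start:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]
        pieces.append(match)
        start = end
    return pieces

# Hypothetical toy vocabulary; BERT's real one has about 30k entries.
vocab = {"agil", "##ity", "fox", "dog", "##s"}
print(wordpiece("agility", vocab))  # ['agil', '##ity']
print(wordpiece("dogs", vocab))     # ['dog', '##s']
print(wordpiece("cat", vocab))      # ['[UNK]']
```

Since SPLADE's weights live on these pieces, a term like "agility" may be represented by weights on several subword dimensions rather than one.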
The most surprising thing about SPLADE is how it bridges the gap between interpretability and neural performance. You get sparse vectors that, when inspected, show you why a document was retrieved (e.g., "this document was retrieved because of the strong weight on 'fox' and 'agility'"). This is a stark contrast to dense embeddings, where the reasons for similarity are opaque, encoded in hundreds of floating-point numbers. SPLADE’s learned weights can be seen as a form of learned keyword extraction that is optimized for search.
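The interpretability described above falls out of the dot product itself. A sparse score is a sum of per-term products, so you can list which terms contributed most. The helper explain_match and the weights below are hypothetical. The document vector includes an expansion term ("agility") that its text need not contain.

```python
def explain_match(q, d, top_k=3):
    """List the terms that contribute most to a sparse dot-product score."""
    contributions = {t: w * d[t] for t, w in q.items() if t in d}
    ranked = sorted(contributions.items(), key=lambda item: item[1], reverse=True)
    return ranked[:top_k]

# Hypothetical weights for illustration; real ones come from the model.
query = {"fox": 1.5, "dog": 1.2, "agility": 1.8}
doc = {"fox": 2.1, "dog": 1.0, "quick": 0.7, "agility": 0.3}
for term, contrib in explain_match(query, doc):
    print(f"{term}: {contrib:.2f}")
# fox: 3.15
# dog: 1.20
# agility: 0.54
```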
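When working with SPLADE, you’ll eventually need to consider how to efficiently store and query these sparse vectors at scale. Because only a few dimensions are non-zero, the natural data structure is an inverted index, the same one used for BM25. Lucene-based toolkits such as Anserini and Pyserini can index SPLADE weights as term impacts, or you can build a custom inverted index. Dense-vector libraries such as FAISS are designed for dense embeddings rather than sparse ones.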
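Here is a minimal in-memory inverted index for {term: weight} vectors, with term-at-a-time scoring. The InvertedIndex class is hypothetical, written only to show the idea. Production engines add posting-list compression, weight quantization, and dynamic pruning (e.g. WAND or MaxScore) on top of this basic structure.

```python
from collections import defaultdict

class InvertedIndex:
    """Minimal in-memory inverted index for {term: weight} sparse vectors."""

    def __init__(self):
        self.postings = defaultdict(list)  # term -> [(doc_id, weight), ...]

    def add(self, doc_id, vec):
        for term, weight in vec.items():
            if weight > 0:
                self.postings[term].append((doc_id, weight))

    def search(self, query_vec, top_k=10):
        # Term-at-a-time scoring: only posting lists for query terms are read.
        scores = defaultdict(float)
        for term, q_weight in query_vec.items():
            for doc_id, d_weight in self.postings.get(term, ()):
                scores[doc_id] += q_weight * d_weight
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]

# Hypothetical document vectors, for illustration only.
index = InvertedIndex()
index.add(0, {"fox": 2.0, "dog": 1.0})
index.add(1, {"dog": 1.5, "agility": 2.0})
index.add(2, {"fox": 1.0, "history": 1.2})
print(index.search({"dog": 1.0, "agility": 1.0}))  # [(1, 3.5), (0, 1.0)]
```

Document 2 is never touched by this query because it shares no terms with it. That skipping is what makes sparse retrieval cheap.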