Re-ranking is a crucial technique in multistage retrieval pipelines that significantly improves precision. In this lesson, we'll explore re-ranking models, learning-to-rank paradigms, and production implementation patterns.
Initial retrieval (first stage) optimizes for recall - finding all potentially relevant documents quickly. Re-ranking (second stage) optimizes for precision - putting the best documents at the top.
Query
│
▼
┌─────────────────────────────┐
│ Stage 1: Initial Retrieval │
│ (Bi-encoder, BM25) │
│ Fast, high recall │
│ Returns 50-500 candidates │
└─────────────────────────────┘
│
▼
┌─────────────────────────────┐
│ Stage 2: Re-ranking │
│ (Cross-encoder, LLM) │
│ Slower, high precision │
│ Returns top 5-10 │
└─────────────────────────────┘
│
▼
Final Results| Aspect | Bi-encoder | Cross-encoder |
|---|---|---|
| Architecture | Encode query and doc separately | Encode query-doc pair together |
| Speed | Fast (documents pre-encoded) | Slow (must encode pairs at query time) |
| Quality | Good | Excellent (captures fine-grained interactions) |
| Use Case | Initial retrieval (millions of docs) | Re-ranking (50-500 candidates) |
Cross-encoders process query and document together, enabling rich attention between tokens from both. This captures nuances that bi-encoders miss.
Input: [CLS] query tokens [SEP] document tokens [SEP]
│
▼
┌──────────────┐
│ Transformer │
│ (BERT-like) │
└──────────────┘
│
▼
[CLS] embedding
│
▼
Linear layer → Relevance score
from sentence_transformers import CrossEncoder
class CrossEncoderReranker:
    """Second-stage reranker backed by a sentence-transformers cross-encoder."""

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        # The cross-encoder scores each (query, document) pair jointly.
        self.model = CrossEncoder(model_name)

    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[dict]:
        """
        Re-rank documents by relevance to query.

        Args:
            query: The search query
            documents: List of dicts with 'text' and other fields
            top_k: Number of top results to return

        Returns:
            Reranked documents with added 'rerank_score'
        """
        if not documents:
            return []
        # Score every (query, text) pair in a single batched call.
        candidate_pairs = [[query, entry["text"]] for entry in documents]
        relevance_scores = self.model.predict(candidate_pairs)
        # Annotate each document in place with its score.
        for entry, value in zip(documents, relevance_scores):
            entry["rerank_score"] = float(value)
        # Highest score first; keep only the requested number.
        by_score = sorted(documents, key=lambda d: d["rerank_score"], reverse=True)
        return by_score[:top_k]
# Usage
# Instantiate once and reuse; loading the model is the expensive step.
reranker = CrossEncoderReranker()
# First stage: retrieve candidates
# NOTE(review): `retriever` is assumed to be defined earlier in the lesson — confirm.
candidates = retriever.search("What is machine learning?", k=50)
# Second stage: rerank
final_results = reranker.rerank("What is machine learning?", candidates, top_k=5)
| Model | Size | Notes |
|---|---|---|
| cross-encoder/ms-marco-MiniLM-L-6-v2 | 22M | Fast, good baseline |
| cross-encoder/ms-marco-MiniLM-L-12-v2 | 33M | Better quality, still fast |
| BAAI/bge-reranker-large | 335M | High quality, multilingual |
| BAAI/bge-reranker-v2-m3 | 568M | Latest, multilingual, long context |
# Cohere Rerank
import cohere
co = cohere.Client(api_key="your-api-key")
def cohere_rerank(query: str, documents: list[str], top_k: int = 5):
    """Re-rank `documents` against `query` with the Cohere Rerank API.

    Returns up to `top_k` dicts carrying the document text, its relevance
    score, and its position in the original candidate list.
    """
    api_response = co.rerank(
        model="rerank-english-v3.0",
        query=query,
        documents=documents,
        top_n=top_k,
        return_documents=True
    )
    reranked = []
    for hit in api_response.results:
        reranked.append({
            "text": hit.document.text,
            "score": hit.relevance_score,
            "index": hit.index
        })
    return reranked
# Jina Reranker
import requests
def jina_rerank(query: str, documents: list[str], top_k: int = 5):
# Re-rank candidates via the hosted Jina Reranker REST endpoint;
# returns the provider's "results" list (scored, sorted candidates).
response = requests.post(
"https://api.jina.ai/v1/rerank",
# NOTE(review): replace the placeholder API key before running.
headers={"Authorization": "Bearer your-api-key"},
json={
"model": "jina-reranker-v2-base-multilingual",
"query": query,
"documents": documents,
# top_n caps how many scored documents the API returns.
"top_n": top_k
}
)
return response.json()["results"]
ColBERT uses a "late interaction" mechanism that's more efficient than full cross-encoders while maintaining high quality.
import torch
def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> float:
    """
    ColBERT MaxSim scoring.

    Each query token contributes the largest similarity it achieves against
    any document token; the relevance score is the sum of those maxima.

    Args:
        query_embeddings: [num_query_tokens, dim]
        doc_embeddings: [num_doc_tokens, dim]

    Returns:
        Relevance score
    """
    # Pairwise token similarities: [num_query_tokens, num_doc_tokens].
    token_sims = query_embeddings @ doc_embeddings.T
    # Best-matching document token for every query token.
    per_query_max = token_sims.max(dim=1).values
    # Aggregate into a single scalar relevance score.
    return float(per_query_max.sum())
# Example with ColBERT
from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig
config = ColBERTConfig(
# nbits: bits per dimension when compressing token embeddings —
# presumably 2-bit residual compression; confirm against ColBERT docs.
nbits=2,
# Truncation lengths (in tokens) for documents and queries.
doc_maxlen=300,
query_maxlen=32
)
# Index documents
# NOTE(review): `documents` must be defined earlier (the passage collection).
indexer = Indexer(checkpoint="colbert-ir/colbertv2.0", config=config)
indexer.index(name="my_index", collection=documents)
# Search with ColBERT
searcher = Searcher(index="my_index")
results = searcher.search("What is machine learning?", k=10)
Train a model to predict absolute relevance score for each document.
# Pointwise learning-to-rank: score each (query, doc) pair independently.
# NOTE(review): pseudo-code — `binary_cross_entropy` and `mse` stand in
# for real loss functions (e.g. torch.nn.functional equivalents).
# Binary classification: relevant or not
loss = binary_cross_entropy(predicted_score, actual_relevance)
# Regression: predict graded relevance (0-4)
loss = mse(predicted_score, actual_relevance_grade)
Train model to correctly order pairs of documents.
# Pairwise objective: given (query, doc_positive, doc_negative),
# push score(doc_positive) above score(doc_negative) by at least `margin`.
def pairwise_loss(score_pos, score_neg, margin=1.0):
    """Margin ranking loss: zero once the positive leads by `margin`."""
    shortfall = margin - (score_pos - score_neg)
    return torch.clamp(shortfall, min=0.0)
# Or using cross-entropy on softmax
# Treats the pair as a 2-way classification whose correct class is the positive doc.
def pairwise_softmax_loss(score_pos, score_neg):
# Stack the two scores into one logits vector: [positive, negative].
scores = torch.stack([score_pos, score_neg])
labels = torch.tensor([0]) # Index of positive
# NOTE(review): `cross_entropy` used below is presumably
# torch.nn.functional.cross_entropy — confirm the import exists.
return cross_entropy(scores.unsqueeze(0), labels)
Optimize the entire ranking list at once.
def listwise_loss(predicted_scores, true_relevances):
"""
ListMLE: Maximize likelihood of correct ranking.

Args:
predicted_scores: 1-D tensor of model scores, one per document.
true_relevances: 1-D tensor of ground-truth relevance grades, same length.

Returns:
Negative log-likelihood of the ground-truth ordering (returned on the
line following this loop).
"""
# Sort by true relevance
sorted_indices = torch.argsort(true_relevances, descending=True)
sorted_scores = predicted_scores[sorted_indices]
# Compute log probability of this ordering
# Plackett-Luce factorization: at each step, softmax of the chosen item
# over the items not yet placed.
log_probs = []
for i in range(len(sorted_scores)):
# Candidates not yet placed at step i (position i itself included).
remaining = sorted_scores[i:]
# log softmax of the chosen item over the remaining pool.
log_prob = sorted_scores[i] - torch.logsumexp(remaining, dim=0)
log_probs.append(log_prob)
return -sum(log_probs)
Use LLMs for re-ranking via relevance scoring or listwise ranking:
def llm_score_document(query: str, document: str, llm) -> float:
    """Use LLM to score relevance.

    Args:
        query: The search query.
        document: Candidate document text.
        llm: Chat model exposing `.invoke(prompt)` and returning an object
            with a `.content` string.

    Returns:
        The parsed numeric score, or 0.0 when the reply is not a number.
    """
    prompt = f"""On a scale of 1-10, rate how relevant this document is to the query.
Respond with ONLY a number.
Query: {query}
Document: {document}
Relevance score:"""
    response = llm.invoke(prompt)
    try:
        return float(response.content.strip())
    # Narrowed from a bare `except`: only swallow parse/shape failures so
    # genuine errors (e.g. inside the LLM call) still surface.
    except (ValueError, AttributeError, TypeError):
        return 0.0
def llm_rerank_pointwise(query: str, documents: list[dict], llm, top_k: int = 5):
    """Re-rank using LLM pointwise scoring.

    Scores each document independently via `llm_score_document`, stores the
    result under 'llm_score' (mutating the input dicts), and returns the
    `top_k` highest-scoring documents.
    """
    for doc in documents:
        doc["llm_score"] = llm_score_document(query, doc["text"], llm)
    ranked = sorted(documents, key=lambda x: x["llm_score"], reverse=True)
    return ranked[:top_k]


def llm_rerank_listwise(query: str, documents: list[dict], llm, top_k: int = 5):
    """Use LLM to rank documents directly.

    Presents all candidates in one prompt and asks the LLM for a
    comma-separated index ordering. Falls back to the original retrieval
    order when the reply cannot be parsed.
    """
    # Format documents with indices (texts truncated to bound prompt size).
    doc_list = "\n".join([
        f"[{i}] {doc['text'][:500]}..."
        for i, doc in enumerate(documents)
    ])
    prompt = f"""Given the query and documents below, rank the documents from
most relevant to least relevant. Return ONLY the document indices in order,
separated by commas.
Query: {query}
Documents:
{doc_list}
Ranking (most relevant first):"""
    response = llm.invoke(prompt)
    # Parse ranking; narrowed from a bare `except` so real errors surface.
    try:
        indices = [int(x.strip()) for x in response.content.split(",")]
    except (ValueError, AttributeError):
        # Unparseable reply: keep the retrieval order.
        return documents[:top_k]
    # LLMs sometimes repeat or invent indices: drop duplicates and anything
    # out of range (the original `i < len` guard let negative indices wrap).
    seen = set()
    ranked = []
    for i in indices:
        if 0 <= i < len(documents) and i not in seen:
            seen.add(i)
            ranked.append(documents[i])
    return ranked[:top_k]
# Pairwise tournament for more reliable LLM ranking
def llm_tournament_rerank(query: str, documents: list[dict], llm, top_k: int = 5):
"""Use pairwise comparisons in tournament style."""
# Comparator for sorted(): asks the LLM which of two docs better matches
# the query. Returns -1 to rank doc_a first, 1 to rank doc_b first —
# never 0, so any malformed reply counts as a vote for B.
def compare(doc_a: dict, doc_b: dict) -> int:
prompt = f"""Which document is more relevant to the query?
Answer with ONLY 'A' or 'B'.
Query: {query}
Document A: {doc_a['text'][:300]}
Document B: {doc_b['text'][:300]}
More relevant:"""
response = llm.invoke(prompt).content.strip().upper()
return -1 if response == "A" else 1
# Sort using LLM comparison
# NOTE(review): O(n log n) LLM calls per query — costly; cap candidates.
from functools import cmp_to_key
ranked = sorted(documents, key=cmp_to_key(compare))
return ranked[:top_k]
Modern re-rankers can follow custom instructions for domain-specific ranking:
from FlagEmbedding import FlagReranker
# BGE Reranker with instructions
reranker = FlagReranker(
"BAAI/bge-reranker-v2-gemma",
# use_fp16 halves memory and speeds inference at slight precision cost —
# presumably requires GPU support; confirm for the target hardware.
use_fp16=True
)
# Custom instruction for legal domain
instruction = "Prioritize legal precedents from higher courts. Prefer recent cases over older ones when relevance is similar."
# Rerank with instruction
# The instruction is prepended to the query side of every (query, doc) pair.
# NOTE(review): `query`, `documents`, and `top_k` are assumed defined by earlier snippets.
pairs = [[f"{instruction}\n{query}", doc["text"]] for doc in documents]
scores = reranker.compute_score(pairs)
# Sort by score
# key=x[0] sorts on the score alone, so dicts are never compared.
ranked = sorted(zip(scores, documents), reverse=True, key=lambda x: x[0])
final_docs = [doc for _, doc in ranked[:top_k]]
from datetime import datetime
class MetadataReranker:
"""Combine semantic relevance with metadata signals."""
def __init__(self, semantic_weight: float = 0.7):
# Weight applied to the normalized cross-encoder score; recency and
# authority boosts are added on top unweighted.
self.semantic_weight = semantic_weight
self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def compute_recency_boost(self, doc_date: str, max_boost: float = 0.2) -> float:
"""More recent documents get higher scores."""
try:
# Expects an ISO-8601 date string (e.g. "2024-01-31").
doc_datetime = datetime.fromisoformat(doc_date)
days_old = (datetime.now() - doc_datetime).days
# Linear decay over 365 days
# NOTE(review): a future-dated doc makes days_old negative, so the
# boost can exceed max_boost — confirm whether that is intended.
return max_boost * max(0, 1 - days_old / 365)
# NOTE(review): bare except also hides TypeError etc.; prefer `except ValueError`.
except:
return 0.0
def compute_authority_boost(self, doc_metadata: dict, max_boost: float = 0.1) -> float:
"""Boost authoritative sources."""
# Fixed allow-list of source types considered authoritative.
authority_sources = {"official", "verified", "primary"}
if doc_metadata.get("source_type") in authority_sources:
return max_boost
return 0.0
def rerank(
self,
query: str,
documents: list[dict],
top_k: int = 5
) -> list[dict]:
# Get semantic scores
pairs = [[query, doc["text"]] for doc in documents]
semantic_scores = self.cross_encoder.predict(pairs)
# Normalize semantic scores to 0-1
# NOTE(review): empty `documents` makes max()/min() raise — confirm callers pre-filter.
max_score = max(semantic_scores)
min_score = min(semantic_scores)
# `or 1` guards against division by zero when all scores are equal.
range_score = max_score - min_score or 1
normalized_semantic = [(s - min_score) / range_score for s in semantic_scores]
# Compute final scores with metadata boosts
# Mutates the input dicts, attaching the per-signal breakdown for debugging.
for i, doc in enumerate(documents):
metadata = doc.get("metadata", {})
recency = self.compute_recency_boost(metadata.get("date", ""))
authority = self.compute_authority_boost(metadata)
doc["semantic_score"] = normalized_semantic[i]
doc["recency_boost"] = recency
doc["authority_boost"] = authority
doc["final_score"] = (
self.semantic_weight * normalized_semantic[i] +
recency + authority
)
ranked = sorted(documents, key=lambda x: x["final_score"], reverse=True)
return ranked[:top_k]
class ProductionReranker:
"""Production-ready re-ranking with fallbacks and caching."""
def __init__(
self,
primary_model: str = "BAAI/bge-reranker-large",
fallback_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
max_input_length: int = 512,
batch_size: int = 32
):
# Both models load eagerly so the fallback is warm if the primary fails.
self.primary = CrossEncoder(primary_model, max_length=max_input_length)
self.fallback = CrossEncoder(fallback_model, max_length=max_input_length)
self.batch_size = batch_size
# In-process score cache: md5(query|||doc) -> score.
# NOTE(review): unbounded — consider an LRU for long-lived services.
self.cache = {}
def _cache_key(self, query: str, doc_text: str) -> str:
# md5 is used for compactness, not security.
import hashlib
content = f"{query}|||{doc_text}"
return hashlib.md5(content.encode()).hexdigest()
def _batch_predict(self, model, pairs: list[list[str]]) -> list[float]:
"""Score pairs in fixed-size batches; returns one score per pair."""
all_scores = []
# Slice into batches to bound memory use at query time.
for i in range(0, len(pairs), self.batch_size):
batch = pairs[i:i + self.batch_size]
scores = model.predict(batch)
all_scores.extend(scores)
return all_scores
def rerank(
self,
query: str,
documents: list[dict],
top_k: int = 5,
use_cache: bool = True
) -> list[dict]:
if not documents:
return []
# Check cache
# cached_scores maps document index -> score (cached or freshly computed).
cached_scores = {}
uncached_pairs = []
uncached_indices = []
for i, doc in enumerate(documents):
key = self._cache_key(query, doc["text"])
if use_cache and key in self.cache:
cached_scores[i] = self.cache[key]
else:
uncached_pairs.append([query, doc["text"]])
uncached_indices.append(i)
# Score uncached pairs
if uncached_pairs:
try:
scores = self._batch_predict(self.primary, uncached_pairs)
# Broad catch is deliberate: any primary-model failure falls back.
except Exception as e:
# NOTE(review): prefer a logger over print in production code.
print(f"Primary model failed: {e}, using fallback")
scores = self._batch_predict(self.fallback, uncached_pairs)
# Update cache and scores
# The cache is written even when use_cache=False; the flag gates reads only.
for idx, score in zip(uncached_indices, scores):
key = self._cache_key(query, documents[idx]["text"])
self.cache[key] = score
cached_scores[idx] = score
# Add scores to documents
for i, doc in enumerate(documents):
doc["rerank_score"] = float(cached_scores[i])
# Sort and return
ranked = sorted(documents, key=lambda x: x["rerank_score"], reverse=True)
return ranked[:top_k]
def evaluate_reranker(
test_data: list[dict], # [{query, candidates, relevant_ids}]
reranker,
# NOTE(review): mutable default is shared across calls — read-only here, but a tuple is safer.
k_values: list[int] = [1, 3, 5, 10]
) -> dict:
"""Evaluate reranker on test data."""
# NOTE(review): requires `import numpy as np`, which no snippet shows — confirm it exists earlier.
metrics = {f"ndcg@{k}": [] for k in k_values}
metrics.update({f"mrr@{k}": [] for k in k_values})
for item in test_data:
query = item["query"]
candidates = item["candidates"]
relevant_ids = set(item["relevant_ids"])
# Rerank
# Rerank once at the largest cutoff, then slice for each k.
reranked = reranker.rerank(query, candidates, top_k=max(k_values))
for k in k_values:
top_k_ids = [doc["id"] for doc in reranked[:k]]
# MRR@k
# Reciprocal rank of the FIRST relevant hit (0 if none in the top k).
mrr = 0
for rank, doc_id in enumerate(top_k_ids, 1):
if doc_id in relevant_ids:
mrr = 1 / rank
break
metrics[f"mrr@{k}"].append(mrr)
# NDCG@k
# Binary-gain DCG: each relevant hit at 0-based rank r contributes 1/log2(r+2).
dcg = sum(
1 / np.log2(rank + 2)
for rank, doc_id in enumerate(top_k_ids)
if doc_id in relevant_ids
)
# Ideal DCG: all relevant docs stacked at the top positions.
ideal_dcg = sum(1 / np.log2(i + 2) for i in range(min(k, len(relevant_ids))))
ndcg = dcg / ideal_dcg if ideal_dcg > 0 else 0
metrics[f"ndcg@{k}"].append(ndcg)
# Average metrics
# Macro-average: mean over queries for each metric.
return {k: np.mean(v) for k, v in metrics.items()}
In the next lesson, we'll explore hybrid RAG architectures that combine multiple retrieval and generation strategies.