Reranking and Relevance Scoring

Re-ranking is a crucial technique in multi-stage retrieval pipelines that significantly improves precision. In this lesson, we'll explore re-ranking models, learning-to-rank paradigms, and production implementation patterns.

Why Re-ranking?

Initial retrieval (first stage) optimizes for recall - finding all potentially relevant documents quickly. Re-ranking (second stage) optimizes for precision - putting the best documents at the top.
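
As a quick illustration, here's a toy sketch of recall@k and precision@k (standard definitions; the document ids are made up). A large first-stage k protects recall, while re-ranking is what improves precision at the top:

def recall_at_k(retrieved_ids: list, relevant_ids: set, k: int) -> float:
    """Fraction of all relevant documents found in the top-k results."""
    top_k = set(retrieved_ids[:k])
    return len(top_k & relevant_ids) / len(relevant_ids)

def precision_at_k(retrieved_ids: list, relevant_ids: set, k: int) -> float:
    """Fraction of the top-k results that are relevant."""
    return sum(1 for doc_id in retrieved_ids[:k] if doc_id in relevant_ids) / k

# Toy example: 3 relevant docs scattered through a first-stage ranking
retrieved = ["d7", "d2", "d9", "d1", "d5", "d3"]
relevant = {"d1", "d2", "d3"}

print(recall_at_k(retrieved, relevant, 6))     # 1.0  -- stage 1 found everything
print(precision_at_k(retrieved, relevant, 3))  # 0.33 -- but the top is noisy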

The Two-Stage Retrieval Pipeline

Query
  │
  ▼
┌──────────────────────────────┐
│  Stage 1: Initial Retrieval  │
│  (Bi-encoder, BM25)          │
│  Fast, high recall           │
│  Returns 50-500 candidates   │
└──────────────────────────────┘
  │
  ▼
┌──────────────────────────────┐
│  Stage 2: Re-ranking         │
│  (Cross-encoder, LLM)        │
│  Slower, high precision      │
│  Returns top 5-10            │
└──────────────────────────────┘
  │
  ▼
Final Results

Bi-encoder vs Cross-encoder

Aspect       | Bi-encoder                           | Cross-encoder
-------------|--------------------------------------|-----------------------------------------------
Architecture | Encode query and doc separately      | Encode query-doc pair together
Speed        | Fast (documents pre-encoded)         | Slow (must encode pairs at query time)
Quality      | Good                                 | Excellent (captures fine-grained interactions)
Use Case     | Initial retrieval (millions of docs) | Re-ranking (50-500 candidates)
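
To make the table concrete, here's a minimal sketch contrasting the two scoring paths with sentence-transformers; the checkpoints are common public models chosen purely for illustration:

from sentence_transformers import SentenceTransformer, CrossEncoder, util

query = "What is machine learning?"
doc = "Machine learning is a subfield of AI that learns patterns from data."

# Bi-encoder: query and document embedded independently, so document
# vectors can be precomputed and indexed offline
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
q_emb = bi_encoder.encode(query)
d_emb = bi_encoder.encode(doc)
bi_score = util.cos_sim(q_emb, d_emb).item()

# Cross-encoder: the (query, document) pair is scored jointly, so nothing
# can be precomputed, but token-level interactions are captured
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
cross_score = cross_encoder.predict([(query, doc)])[0]

print(f"bi-encoder cosine: {bi_score:.3f}, cross-encoder score: {cross_score:.3f}")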

Cross-Encoder Re-rankers

Cross-encoders process query and document together, enabling rich attention between tokens from both. This captures nuances that bi-encoders miss.

How Cross-Encoders Work

Input: [CLS] query tokens [SEP] document tokens [SEP]
         │
         ▼
    ┌───────────────┐
    │  Transformer  │
    │  (BERT-like)  │
    └───────────────┘
         │
         ▼
    [CLS] embedding
         │
         ▼
    Linear layer → Relevance score

Implementation

from sentence_transformers import CrossEncoder

class CrossEncoderReranker:
    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        self.model = CrossEncoder(model_name)
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[dict]:
        """
        Re-rank documents by relevance to query.
        
        Args:
            query: The search query
            documents: List of dicts with 'text' and other fields
            top_k: Number of top results to return
        
        Returns:
            Reranked documents with added 'rerank_score'
        """
        if not documents:
            return []
        
        # Create query-document pairs
        pairs = [[query, doc["text"]] for doc in documents]
        
        # Get relevance scores
        scores = self.model.predict(pairs)
        
        # Add scores to documents
        for doc, score in zip(documents, scores):
            doc["rerank_score"] = float(score)
        
        # Sort by rerank score
        ranked = sorted(documents, key=lambda x: x["rerank_score"], reverse=True)
        
        return ranked[:top_k]

# Usage
reranker = CrossEncoderReranker()

# First stage: retrieve candidates (any first-stage retriever from earlier lessons)
candidates = retriever.search("What is machine learning?", k=50)

# Second stage: rerank
final_results = reranker.rerank("What is machine learning?", candidates, top_k=5)

Popular Re-ranking Models

Cross-Encoder Models

Model                                 | Size | Notes
--------------------------------------|------|------------------------------------
cross-encoder/ms-marco-MiniLM-L-6-v2  | 22M  | Fast, good baseline
cross-encoder/ms-marco-MiniLM-L-12-v2 | 33M  | Better quality, still fast
BAAI/bge-reranker-large               | 335M | High quality, multilingual
BAAI/bge-reranker-v2-m3               | 568M | Latest, multilingual, long context

API-Based Re-rankers

# Cohere Rerank
import cohere

co = cohere.Client(api_key="your-api-key")

def cohere_rerank(query: str, documents: list[str], top_k: int = 5):
    response = co.rerank(
        model="rerank-english-v3.0",
        query=query,
        documents=documents,
        top_n=top_k,
        return_documents=True
    )
    
    return [
        {
            "text": result.document.text,
            "score": result.relevance_score,
            "index": result.index
        }
        for result in response.results
    ]

# Jina Reranker
import requests

def jina_rerank(query: str, documents: list[str], top_k: int = 5):
    response = requests.post(
        "https://api.jina.ai/v1/rerank",
        headers={"Authorization": "Bearer your-api-key"},
        json={
            "model": "jina-reranker-v2-base-multilingual",
            "query": query,
            "documents": documents,
            "top_n": top_k
        }
    )
    response.raise_for_status()
    return response.json()["results"]

ColBERT: Late Interaction Re-ranking

ColBERT uses a "late interaction" mechanism that's more efficient than full cross-encoders while maintaining high quality.

MaxSim Scoring

import torch

def maxsim_score(query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor) -> float:
    """
    ColBERT MaxSim scoring.
    
    For each query token, find its maximum similarity to any document token.
    Sum these maximum similarities.
    
    Args:
        query_embeddings: [num_query_tokens, dim]
        doc_embeddings: [num_doc_tokens, dim]
    
    Returns:
        Relevance score
    """
    # Compute all pairwise similarities
    # [query_tokens, doc_tokens]
    similarities = torch.matmul(query_embeddings, doc_embeddings.T)
    
    # For each query token, take max over document tokens
    max_sims, _ = similarities.max(dim=1)
    
    # Sum of max similarities
    return max_sims.sum().item()

# Example with ColBERT
from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig

config = ColBERTConfig(
    nbits=2,
    doc_maxlen=300,
    query_maxlen=32
)

# Index documents ("documents" here is a list of passage strings)
indexer = Indexer(checkpoint="colbert-ir/colbertv2.0", config=config)
indexer.index(name="my_index", collection=documents)

# Search with ColBERT
searcher = Searcher(index="my_index")
results = searcher.search("What is machine learning?", k=10)

Learning-to-Rank Paradigms

1. Pointwise Approach

Train a model to predict absolute relevance score for each document.

import torch.nn.functional as F

# Binary classification: relevant (1) or not (0)
loss = F.binary_cross_entropy_with_logits(predicted_score, actual_relevance)

# Regression: predict graded relevance (e.g. 0-4)
loss = F.mse_loss(predicted_score, actual_relevance_grade)
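
To make this concrete, here's a minimal sketch of one pointwise training step. The standalone linear scorer over random "pair features" is a stand-in for a real cross-encoder head, purely for illustration:

import torch
import torch.nn.functional as F

scorer = torch.nn.Linear(384, 1)             # 384 = example embedding dim
optimizer = torch.optim.Adam(scorer.parameters(), lr=1e-4)

pair_features = torch.randn(8, 384)          # batch of 8 query-doc pair vectors
labels = torch.randint(0, 2, (8,)).float()   # binary relevance labels

predicted_score = scorer(pair_features).squeeze(-1)
loss = F.binary_cross_entropy_with_logits(predicted_score, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()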

2. Pairwise Approach

Train model to correctly order pairs of documents.

import torch
import torch.nn.functional as F

# Given (query, doc_positive, doc_negative),
# ensure score(doc_positive) > score(doc_negative)

def pairwise_loss(score_pos, score_neg, margin=1.0):
    """Margin ranking loss: zero once the positive leads by at least `margin`."""
    return torch.relu(margin - score_pos + score_neg)

# Or using cross-entropy over a softmax of the two scores
def pairwise_softmax_loss(score_pos, score_neg):
    scores = torch.stack([score_pos, score_neg])
    labels = torch.tensor([0])  # index of the positive document
    return F.cross_entropy(scores.unsqueeze(0), labels)

3. Listwise Approach

Optimize the entire ranking list at once.

import torch

def listwise_loss(predicted_scores, true_relevances):
    """
    ListMLE: maximize the likelihood of the correct ranking
    under a Plackett-Luce model over the predicted scores.
    """
    # Order predicted scores by true relevance (best first)
    sorted_indices = torch.argsort(true_relevances, descending=True)
    sorted_scores = predicted_scores[sorted_indices]
    
    # Log probability of picking each item first among those still remaining
    log_probs = []
    for i in range(len(sorted_scores)):
        remaining = sorted_scores[i:]
        log_prob = sorted_scores[i] - torch.logsumexp(remaining, dim=0)
        log_probs.append(log_prob)
    
    return -torch.stack(log_probs).sum()
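
As a sanity check, all three losses can be exercised on toy tensors (the values are arbitrary):

import torch

pos, neg = torch.tensor(2.0), torch.tensor(0.5)
print(pairwise_loss(pos, neg))           # 0: the positive already clears the margin
print(pairwise_softmax_loss(pos, neg))   # small positive loss

scores = torch.tensor([0.9, 0.1, 0.4])
relevance = torch.tensor([2.0, 0.0, 1.0])
print(listwise_loss(scores, relevance))  # negative log-likelihood of the true ordering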

LLM-Based Re-ranking

Use LLMs for re-ranking via relevance scoring or listwise ranking:

Pointwise LLM Scoring

def llm_score_document(query: str, document: str, llm) -> float:
    """Use LLM to score relevance."""
    
    prompt = f"""On a scale of 1-10, rate how relevant this document is to the query.
Respond with ONLY a number.

Query: {query}

Document: {document}

Relevance score:"""
    
    response = llm.invoke(prompt)
    try:
        return float(response.content.strip())
    except ValueError:
        # The model returned something that isn't a number
        return 0.0

def llm_rerank_pointwise(query: str, documents: list[dict], llm, top_k: int = 5):
    """Re-rank using LLM pointwise scoring."""
    
    for doc in documents:
        doc["llm_score"] = llm_score_document(query, doc["text"], llm)
    
    ranked = sorted(documents, key=lambda x: x["llm_score"], reverse=True)
    return ranked[:top_k]

Listwise LLM Ranking

def llm_rerank_listwise(query: str, documents: list[dict], llm, top_k: int = 5):
    """Use LLM to rank documents directly."""
    
    # Format documents with indices
    doc_list = "\n".join([
        f"[{i}] {doc['text'][:500]}..."
        for i, doc in enumerate(documents)
    ])
    
    prompt = f"""Given the query and documents below, rank the documents from 
most relevant to least relevant. Return ONLY the document indices in order, 
separated by commas.

Query: {query}

Documents:
{doc_list}

Ranking (most relevant first):"""
    
    response = llm.invoke(prompt)
    
    # Parse the ranking; fall back to the original order on malformed output
    try:
        indices = [int(x.strip()) for x in response.content.split(",")]
        ranked = [documents[i] for i in indices if i < len(documents)]
        return ranked[:top_k]
    except ValueError:
        return documents[:top_k]

# Pairwise tournament for more reliable LLM ranking
def llm_tournament_rerank(query: str, documents: list[dict], llm, top_k: int = 5):
    """Use pairwise comparisons in tournament style."""
    
    def compare(doc_a: dict, doc_b: dict) -> int:
        prompt = f"""Which document is more relevant to the query?
Answer with ONLY 'A' or 'B'.

Query: {query}

Document A: {doc_a['text'][:300]}

Document B: {doc_b['text'][:300]}

More relevant:"""
        
        response = llm.invoke(prompt).content.strip().upper()
        # Treat anything that doesn't start with 'A' as a vote for B
        return -1 if response.startswith("A") else 1
    
    # Sort using LLM comparison
    from functools import cmp_to_key
    ranked = sorted(documents, key=cmp_to_key(compare))
    return ranked[:top_k]

Instruction-Following Re-ranking

Modern re-rankers can follow custom instructions for domain-specific ranking:

from FlagEmbedding import FlagLLMReranker

# LLM-based BGE reranker; v2-gemma is the instruction-capable checkpoint
reranker = FlagLLMReranker(
    "BAAI/bge-reranker-v2-gemma",
    use_fp16=True
)

# Custom instruction for the legal domain
instruction = "Prioritize legal precedents from higher courts. Prefer recent cases over older ones when relevance is similar."

# Rerank with the instruction prepended to the query
top_k = 5
pairs = [[f"{instruction}\n{query}", doc["text"]] for doc in documents]
scores = reranker.compute_score(pairs)

# Sort by score
ranked = sorted(zip(scores, documents), reverse=True, key=lambda x: x[0])
final_docs = [doc for _, doc in ranked[:top_k]]

Metadata-Aware Re-ranking

from datetime import datetime

class MetadataReranker:
    """Combine semantic relevance with metadata signals."""
    
    def __init__(self, semantic_weight: float = 0.7):
        self.semantic_weight = semantic_weight
        self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    
    def compute_recency_boost(self, doc_date: str, max_boost: float = 0.2) -> float:
        """More recent documents get higher scores."""
        try:
            doc_datetime = datetime.fromisoformat(doc_date)
            days_old = (datetime.now() - doc_datetime).days
            # Linear decay to zero over 365 days
            return max_boost * max(0, 1 - days_old / 365)
        except ValueError:
            # Missing or malformed date: no boost
            return 0.0
    
    def compute_authority_boost(self, doc_metadata: dict, max_boost: float = 0.1) -> float:
        """Boost authoritative sources."""
        authority_sources = {"official", "verified", "primary"}
        if doc_metadata.get("source_type") in authority_sources:
            return max_boost
        return 0.0
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5
    ) -> list[dict]:
        # Get semantic scores
        pairs = [[query, doc["text"]] for doc in documents]
        semantic_scores = self.cross_encoder.predict(pairs)
        
        # Normalize semantic scores to [0, 1]
        max_score = max(semantic_scores)
        min_score = min(semantic_scores)
        range_score = (max_score - min_score) or 1  # guard against all-equal scores
        normalized_semantic = [(s - min_score) / range_score for s in semantic_scores]
        
        # Compute final scores with metadata boosts
        for i, doc in enumerate(documents):
            metadata = doc.get("metadata", {})
            
            recency = self.compute_recency_boost(metadata.get("date", ""))
            authority = self.compute_authority_boost(metadata)
            
            doc["semantic_score"] = normalized_semantic[i]
            doc["recency_boost"] = recency
            doc["authority_boost"] = authority
            
            doc["final_score"] = (
                self.semantic_weight * normalized_semantic[i] +
                recency + authority
            )
        
        ranked = sorted(documents, key=lambda x: x["final_score"], reverse=True)
        return ranked[:top_k]

Production Re-ranking Pipeline

class ProductionReranker:
    """Production-ready re-ranking with fallbacks and caching."""
    
    def __init__(
        self,
        primary_model: str = "BAAI/bge-reranker-large",
        fallback_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        max_input_length: int = 512,
        batch_size: int = 32
    ):
        self.primary = CrossEncoder(primary_model, max_length=max_input_length)
        self.fallback = CrossEncoder(fallback_model, max_length=max_input_length)
        self.batch_size = batch_size
        self.cache = {}  # naive in-process cache; use an LRU or external store in production
    
    def _cache_key(self, query: str, doc_text: str) -> str:
        import hashlib
        content = f"{query}|||{doc_text}"
        return hashlib.md5(content.encode()).hexdigest()
    
    def _batch_predict(self, model, pairs: list[list[str]]) -> list[float]:
        """Batch prediction with error handling."""
        all_scores = []
        
        for i in range(0, len(pairs), self.batch_size):
            batch = pairs[i:i + self.batch_size]
            scores = model.predict(batch)
            all_scores.extend(scores)
        
        return all_scores
    
    def rerank(
        self,
        query: str,
        documents: list[dict],
        top_k: int = 5,
        use_cache: bool = True
    ) -> list[dict]:
        if not documents:
            return []
        
        # Check cache
        cached_scores = {}
        uncached_pairs = []
        uncached_indices = []
        
        for i, doc in enumerate(documents):
            key = self._cache_key(query, doc["text"])
            if use_cache and key in self.cache:
                cached_scores[i] = self.cache[key]
            else:
                uncached_pairs.append([query, doc["text"]])
                uncached_indices.append(i)
        
        # Score uncached pairs
        if uncached_pairs:
            try:
                scores = self._batch_predict(self.primary, uncached_pairs)
            except Exception as e:
                print(f"Primary model failed: {e}, using fallback")
                scores = self._batch_predict(self.fallback, uncached_pairs)
            
            # Update cache and scores
            for idx, score in zip(uncached_indices, scores):
                key = self._cache_key(query, documents[idx]["text"])
                self.cache[key] = score
                cached_scores[idx] = score
        
        # Add scores to documents
        for i, doc in enumerate(documents):
            doc["rerank_score"] = float(cached_scores[i])
        
        # Sort and return
        ranked = sorted(documents, key=lambda x: x["rerank_score"], reverse=True)
        return ranked[:top_k]
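
Usage mirrors the simpler reranker above; the candidate documents here are illustrative placeholders:

reranker = ProductionReranker()

candidates = [
    {"text": "Machine learning is a subfield of AI that learns from data."},
    {"text": "The history of the printing press began in the 15th century."},
]

results = reranker.rerank("What is machine learning?", candidates, top_k=2)
for doc in results:
    print(f"{doc['rerank_score']:.3f}  {doc['text'][:60]}")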

Evaluating Re-rankers

import numpy as np

def evaluate_reranker(
    test_data: list[dict],  # [{query, candidates, relevant_ids}]
    reranker,
    k_values: list[int] = [1, 3, 5, 10]
) -> dict:
    """Evaluate reranker on test data."""
    
    metrics = {f"ndcg@{k}": [] for k in k_values}
    metrics.update({f"mrr@{k}": [] for k in k_values})
    
    for item in test_data:
        query = item["query"]
        candidates = item["candidates"]
        relevant_ids = set(item["relevant_ids"])
        
        # Rerank
        reranked = reranker.rerank(query, candidates, top_k=max(k_values))
        
        for k in k_values:
            top_k_ids = [doc["id"] for doc in reranked[:k]]
            
            # MRR@k
            mrr = 0
            for rank, doc_id in enumerate(top_k_ids, 1):
                if doc_id in relevant_ids:
                    mrr = 1 / rank
                    break
            metrics[f"mrr@{k}"].append(mrr)
            
            # NDCG@k
            dcg = sum(
                1 / np.log2(rank + 2)
                for rank, doc_id in enumerate(top_k_ids)
                if doc_id in relevant_ids
            )
            ideal_dcg = sum(1 / np.log2(i + 2) for i in range(min(k, len(relevant_ids))))
            ndcg = dcg / ideal_dcg if ideal_dcg > 0 else 0
            metrics[f"ndcg@{k}"].append(ndcg)
    
    # Average metrics
    return {k: np.mean(v) for k, v in metrics.items()}
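
A minimal invocation might look like this; the test item is synthetic, purely to show the expected data shape:

test_data = [
    {
        "query": "What is machine learning?",
        "candidates": [
            {"id": "d1", "text": "Machine learning learns patterns from data."},
            {"id": "d2", "text": "The Eiffel Tower is in Paris."},
        ],
        "relevant_ids": ["d1"],
    }
]

reranker = CrossEncoderReranker()  # defined earlier in this lesson
metrics = evaluate_reranker(test_data, reranker, k_values=[1, 2])
print(metrics)  # {"ndcg@1": ..., "mrr@1": ..., "ndcg@2": ..., "mrr@2": ...}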

Key Takeaways

  • Two-stage retrieval is the production standard: fast recall → precise re-ranking
  • Cross-encoders are more accurate than bi-encoders but slower
  • ColBERT offers a good speed/quality trade-off with late interaction
  • API re-rankers (Cohere, Jina) are convenient for production
  • LLM re-ranking is powerful but expensive; use for high-stakes queries
  • Metadata signals (recency, authority) can boost relevance
  • Cache re-ranking scores for frequently queried documents
  • Evaluate with NDCG and MRR to measure ranking quality

In the next lesson, we'll explore hybrid RAG architectures that combine multiple retrieval and generation strategies.