RAG Systems

0 of 13 lessons completed

Performance Optimization and Caching

Production RAG systems must balance quality with latency and cost. This lesson covers optimization strategies across the entire RAG pipeline, from embedding generation to response caching.

RAG Performance Bottlenecks

Typical RAG Latency Breakdown:
┌─────────────────────────────────────────────────────────────┐
│ Component              │ Latency     │ Cost Impact          │
├─────────────────────────────────────────────────────────────┤
│ Query Embedding        │ 10-50ms     │ Low (small input)    │
│ Vector Search          │ 10-100ms    │ Fixed (per query)    │
│ Re-ranking (optional)  │ 50-200ms    │ Medium               │
│ LLM Generation         │ 500-3000ms  │ High (input+output)  │
└─────────────────────────────────────────────────────────────┘

Total typical latency: 1-4 seconds
Main cost driver: LLM tokens (context + generation)

Caching Strategies

1. Embedding Cache

import hashlib
import redis
import json
import numpy as np

class EmbeddingCache:
    """Cache embeddings to avoid recomputation."""
    
    def __init__(self, redis_url: str = "redis://localhost:6379"):
        self.redis = redis.from_url(redis_url)
        self.prefix = "emb:"
        self.ttl = 86400 * 7  # 7 days
    
    def _key(self, text: str, model: str) -> str:
        content = f"{model}:{text}"
        hash_val = hashlib.sha256(content.encode()).hexdigest()[:16]
        return f"{self.prefix}{hash_val}"
    
    def get(self, text: str, model: str) -> np.ndarray | None:
        key = self._key(text, model)
        cached = self.redis.get(key)
        
        if cached:
            return np.frombuffer(cached, dtype=np.float32)
        return None
    
    def set(self, text: str, model: str, embedding: np.ndarray):
        key = self._key(text, model)
        self.redis.setex(key, self.ttl, embedding.astype(np.float32).tobytes())
    
    def get_or_compute(self, text: str, model: str, compute_fn) -> np.ndarray:
        cached = self.get(text, model)
        if cached is not None:
            return cached
        
        embedding = compute_fn(text)
        self.set(text, model, embedding)
        return embedding

2. Retrieval Cache

from functools import lru_cache
from datetime import datetime, timedelta

class RetrievalCache:
    """Cache retrieval results for repeated queries."""
    
    def __init__(self, max_size: int = 10000, ttl_seconds: int = 3600):
        self.cache = {}
        self.max_size = max_size
        self.ttl = ttl_seconds
    
    def _key(self, query: str, k: int, filters: dict = None) -> str:
        filter_str = json.dumps(filters, sort_keys=True) if filters else ""
        return hashlib.md5(f"{query}:{k}:{filter_str}".encode()).hexdigest()
    
    def get(self, query: str, k: int, filters: dict = None) -> list | None:
        key = self._key(query, k, filters)
        
        if key in self.cache:
            entry = self.cache[key]
            if datetime.now() < entry["expires"]:
                return entry["results"]
            else:
                del self.cache[key]
        
        return None
    
    def set(self, query: str, k: int, results: list, filters: dict = None):
        # LRU eviction
        if len(self.cache) >= self.max_size:
            oldest = min(self.cache.items(), key=lambda x: x[1]["created"])
            del self.cache[oldest[0]]
        
        key = self._key(query, k, filters)
        self.cache[key] = {
            "results": results,
            "created": datetime.now(),
            "expires": datetime.now() + timedelta(seconds=self.ttl)
        }

3. Semantic Query Cache

class SemanticCache:
    """
    Cache based on query similarity, not exact match.
    Similar queries can reuse cached results.
    """
    
    def __init__(self, embedding_model, similarity_threshold: float = 0.95):
        self.embedding_model = embedding_model
        self.threshold = similarity_threshold
        self.cache_entries = []  # [(embedding, query, results, timestamp)]
    
    def get(self, query: str) -> list | None:
        query_emb = self.embedding_model.encode(query)
        
        for emb, cached_query, results, timestamp in self.cache_entries:
            similarity = np.dot(query_emb, emb) / (
                np.linalg.norm(query_emb) * np.linalg.norm(emb)
            )
            
            if similarity >= self.threshold:
                return results
        
        return None
    
    def set(self, query: str, results: list):
        query_emb = self.embedding_model.encode(query)
        self.cache_entries.append((
            query_emb, query, results, datetime.now()
        ))
        
        # Limit cache size
        if len(self.cache_entries) > 10000:
            self.cache_entries = self.cache_entries[-5000:]

4. Response Cache

class ResponseCache:
    """
    Cache full RAG responses for exact query matches.
    Useful for common/repeated queries.
    """
    
    def __init__(self, redis_client, ttl: int = 3600):
        self.redis = redis_client
        self.ttl = ttl
        self.prefix = "rag_response:"
    
    def _key(self, query: str) -> str:
        # Normalize query
        normalized = query.lower().strip()
        return f"{self.prefix}{hashlib.sha256(normalized.encode()).hexdigest()}"
    
    def get(self, query: str) -> dict | None:
        cached = self.redis.get(self._key(query))
        if cached:
            return json.loads(cached)
        return None
    
    def set(self, query: str, response: dict):
        self.redis.setex(
            self._key(query),
            self.ttl,
            json.dumps(response)
        )

Batching and Parallelization

Batch Embedding Generation

import asyncio
from concurrent.futures import ThreadPoolExecutor

class BatchEmbedder:
    """Embed texts in fixed-size batches, with an async wrapper.

    ``embed_async`` offloads the blocking model call to a private thread
    pool. Call ``close()`` when the embedder is no longer needed so that
    pool's threads are released.
    """

    def __init__(self, model, batch_size: int = 64):
        self.model = model
        self.batch_size = batch_size
        # Worker threads used by embed_async; released by close().
        self.executor = ThreadPoolExecutor(max_workers=4)

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed texts in batches for efficiency.

        Returns one embedding per input text, in input order.
        """
        all_embeddings = []

        for i in range(0, len(texts), self.batch_size):
            batch = texts[i:i + self.batch_size]
            embeddings = self.model.encode(
                batch,
                batch_size=self.batch_size,
                show_progress_bar=False,
                convert_to_numpy=True
            )
            # encode returns one row per text; extend keeps a flat list.
            all_embeddings.extend(embeddings)

        return all_embeddings

    async def embed_async(self, texts: list[str]) -> list[np.ndarray]:
        """Async wrapper for non-blocking embedding.

        Runs ``embed_batch`` on ``self.executor`` so the event loop stays free.
        """
        # get_running_loop is the supported way to reach the loop from inside a
        # coroutine; get_event_loop is deprecated for this use.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor,
            self.embed_batch,
            texts
        )

    def close(self) -> None:
        """Shut down the worker pool, waiting for in-flight embeddings."""
        self.executor.shutdown(wait=True)

Parallel Multi-Query Retrieval

import asyncio
from typing import Callable

class ParallelRetriever:
    """Fan a single query out to several retrievers concurrently."""

    def __init__(self, retrievers: list):
        self.retrievers = retrievers

    async def retrieve_parallel(
        self,
        query: str,
        k: int = 10
    ) -> list[list[dict]]:
        """Run multiple retrievers in parallel.

        Each retriever's synchronous ``search(query, k)`` runs in the event
        loop's default thread pool. The result list follows the order of
        ``self.retrievers``, because ``asyncio.gather`` preserves argument
        order.
        """
        # get_running_loop is the supported call inside a coroutine.
        loop = asyncio.get_running_loop()
        futures = [
            loop.run_in_executor(None, retriever.search, query, k)
            for retriever in self.retrievers
        ]
        return await asyncio.gather(*futures)

# Usage
async def hybrid_search(query: str):
    """Send one query to several retrievers concurrently and merge their hits.

    ``bm25_retriever``, ``vector_retriever``, ``reranker_retriever`` and
    ``merge_results`` are not defined in this section; they are expected to be
    provided by the surrounding application.
    """
    retrievers = [bm25_retriever, vector_retriever, reranker_retriever]
    # NOTE(review): a reranker normally scores candidates produced by another
    # retriever rather than searching on its own; confirm reranker_retriever
    # really exposes search(query, k) before running it in parallel here.
    fan_out = ParallelRetriever(retrievers)
    per_retriever_hits = await fan_out.retrieve_parallel(query, k=10)
    return merge_results(per_retriever_hits)

Context Optimization

Context Window Optimization

class ContextOptimizer:
    """Optimize context to fit LLM limits while maximizing relevance."""

    def __init__(
        self,
        max_tokens: int = 4096,
        tokenizer = None
    ):
        self.max_tokens = max_tokens
        self.tokenizer = tokenizer

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in *text*.

        Uses ``self.tokenizer.encode`` when a tokenizer was supplied. Otherwise
        it estimates ceil(1.3 * word count). The result is always an int, so
        it can be used safely in slice arithmetic, and the estimate rounds up
        to err toward over-counting.
        """
        if self.tokenizer is not None:
            return len(self.tokenizer.encode(text))
        # Integer ceil(words * 1.3) with no float rounding noise.
        return (len(text.split()) * 13 + 9) // 10

    def optimize_context(
        self,
        query: str,
        documents: list[dict],
        system_prompt: str = "",
        reserved_for_output: int = 500
    ) -> list[dict]:
        """
        Select and order documents to fit context limit.
        Prioritizes higher-scored documents.

        Documents are taken in descending ``score`` order while they fit. The
        first one that does not fit is truncated (marked ``"truncated": True``)
        if more than 100 tokens remain, and selection then stops.
        """
        # Calculate available tokens
        query_tokens = self.count_tokens(query)
        system_tokens = self.count_tokens(system_prompt)
        available = self.max_tokens - query_tokens - system_tokens - reserved_for_output

        # Sort by relevance score
        sorted_docs = sorted(
            documents,
            key=lambda x: x.get("score", 0),
            reverse=True
        )

        selected = []
        used_tokens = 0

        for doc in sorted_docs:
            doc_tokens = self.count_tokens(doc["text"])

            if used_tokens + doc_tokens <= available:
                selected.append(doc)
                used_tokens += doc_tokens
            else:
                # Try truncating the document
                remaining = available - used_tokens
                if remaining > 100:  # Worth including partial
                    truncated = self.truncate_to_tokens(doc["text"], remaining)
                    doc_copy = {**doc, "text": truncated, "truncated": True}
                    selected.append(doc_copy)
                break

        return selected

    def truncate_to_tokens(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit token limit, preserving sentence boundaries.

        Keeps whole ``". "``-separated sentences while they fit. If not even
        the first sentence fits, falls back to a character cut of about 4
        characters per token.
        """
        kept = []
        tokens = 0

        for sentence in text.split(". "):
            sentence_tokens = self.count_tokens(sentence)
            if tokens + sentence_tokens > max_tokens:
                break
            kept.append(sentence)
            tokens += sentence_tokens

        if not kept:
            return text[:int(max_tokens * 4)]

        truncated = ". ".join(kept)
        # The last kept piece may be the text's final sentence, which already
        # carries its own period; don't produce "..".
        return truncated if truncated.endswith(".") else truncated + "."

Context Compression

class ContextCompressor:
    """Compress context by extracting only relevant information."""

    def __init__(self, llm):
        self.llm = llm

    @staticmethod
    def _clip(text: str, max_length: int) -> str:
        """Clip *text* to at most *max_length* chars without splitting a word."""
        if len(text) <= max_length:
            return text
        head = text[:max_length]
        if not text[max_length].isspace():
            # The cut lands inside a word: back up to the last whitespace.
            split_at = max(head.rfind(" "), head.rfind("\n"))
            if split_at > 0:
                head = head[:split_at]
        return head.rstrip()

    def compress(self, query: str, document: str, max_length: int = 500) -> str:
        """Extract only the parts of document relevant to the query.

        Asks ``self.llm`` for a relevant excerpt and returns its ``content``.
        The content is stripped of surrounding whitespace and, if still too
        long, clipped to ``max_length`` characters at a word boundary.
        """

        prompt = f"""Extract ONLY the sentences from this document that are 
relevant to answering the question. Remove irrelevant information.
Keep the extracted text under {max_length} characters.

Question: {query}

Document:
{document}

Relevant excerpt:"""

        response = self.llm.invoke(prompt)
        # The model may ignore the length instruction, so enforce the limit here.
        return self._clip(response.content.strip(), max_length)

    def compress_batch(
        self,
        query: str,
        documents: list[dict],
        max_length_each: int = 300
    ) -> list[dict]:
        """Compress multiple documents.

        Makes one sequential LLM call per document. Each returned dict copies
        the input dict, replaces ``text``, and records ``original_length`` and
        ``compressed_length`` in characters.
        """
        compressed = []

        for doc in documents:
            compressed_text = self.compress(query, doc["text"], max_length_each)
            compressed.append({
                **doc,
                "text": compressed_text,
                "original_length": len(doc["text"]),
                "compressed_length": len(compressed_text)
            })

        return compressed

Reducing LLM Costs

Model Selection Based on Query Complexity

class AdaptiveLLMSelector:
    """Route queries to appropriate models based on complexity.

    Short, keyword-free questions over a small context go to the "fast" model;
    everything else goes to the "quality" model.
    """

    def __init__(self, models: dict):
        # Expected keys: "fast" (cheaper model) and "quality" (stronger model).
        self.models = models

    def classify_complexity(self, query: str, context_length: int) -> str:
        """Return "fast" or "quality" using simple heuristics.

        A query counts as complex if it has 10+ words, comes with a context of
        1000+ (length units as supplied by the caller), or contains an
        analytical keyword.
        """
        if len(query.split()) >= 10:
            return "quality"
        if context_length >= 1000:
            return "quality"
        lowered = query.lower()
        for marker in ("compare", "analyze", "explain why", "synthesize"):
            if marker in lowered:
                return "quality"
        return "fast"

    def get_model(self, query: str, context_length: int):
        """Return the model registered under the tier chosen for this query."""
        tier = self.classify_complexity(query, context_length)
        return self.models[tier]

Streaming Responses

async def stream_rag_response(query: str, retriever, llm):
    """Stream RAG response for better perceived latency.

    Retrieval is not streamable, so it runs first. ``retriever.search`` is a
    synchronous call, so it is run in a worker thread. That keeps the event
    loop free to serve other requests while documents are fetched.

    Yields:
        The ``content`` of each chunk produced by ``llm.astream``.
    """
    documents = await asyncio.to_thread(retriever.search, query, k=5)
    context = "\n---\n".join(d["text"] for d in documents)

    # Stream generation
    prompt = f"""Answer based on context.
    
Context:
{context}

Question: {query}

Answer:"""

    # Yield tokens as they're generated
    async for chunk in llm.astream(prompt):
        yield chunk.content

# FastAPI endpoint
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/query")
async def query_endpoint(q: str):
    """Stream the RAG answer for query-string parameter ``q`` as plain text.

    ``retriever`` and ``llm`` are module-level names not defined in this
    section; presumably they are created at application startup — confirm.
    """
    token_stream = stream_rag_response(q, retriever, llm)
    return StreamingResponse(token_stream, media_type="text/plain")

Index Optimization

# HNSW Index Tuning for Production
class VectorIndexOptimizer:
    """Optimize vector index parameters."""

    @staticmethod
    def get_hnsw_params(
        dataset_size: int,
        query_latency_target_ms: int = 50
    ) -> dict:
        """
        Recommend HNSW parameters based on requirements.

        M: Number of connections per node (higher = better quality, more memory)
        ef_construction: Build-time search breadth (higher = better quality, slower build)
        ef_search: Query-time search breadth (higher = better quality, slower query)

        ``query_latency_target_ms`` only affects datasets of 100k vectors or
        more. A tighter target lowers ``ef_search``.
        """

        if dataset_size < 100_000:
            return {
                "M": 16,
                "ef_construction": 200,
                "ef_search": 100
            }
        elif dataset_size < 1_000_000:
            return {
                "M": 32,
                "ef_construction": 200,
                "ef_search": 150 if query_latency_target_ms > 30 else 100
            }
        else:
            return {
                "M": 48,
                "ef_construction": 400,
                "ef_search": 200 if query_latency_target_ms > 50 else 100
            }

    @staticmethod
    def estimate_memory_gb(
        num_vectors: int,
        dimensions: int,
        M: int = 16,
        bytes_per_link: int = 4
    ) -> float:
        """Estimate memory usage for an HNSW index with float32 vectors.

        Args:
            num_vectors: Number of indexed vectors.
            dimensions: Vector dimensionality.
            M: HNSW connections per node.
            bytes_per_link: Size of one stored neighbour id. Common HNSW
                implementations (Faiss, hnswlib) use 32-bit ids, hence 4.

        Returns:
            Approximate size in GiB (vectors + base-layer graph links).
        """
        # Vector storage
        vector_bytes = num_vectors * dimensions * 4  # float32

        # HNSW graph (approximate): the base layer stores up to 2*M neighbour
        # ids per node. Upper layers contain few nodes and are ignored here.
        graph_bytes = num_vectors * 2 * M * bytes_per_link

        return (vector_bytes + graph_bytes) / (1024 ** 3)

Monitoring and Observability

import time
from dataclasses import dataclass, field
from typing import Optional
import logging

@dataclass
class RAGMetrics:
    """Measurements recorded for a single RAG query.

    Field order defines the positional constructor, so do not reorder fields.
    """

    query: str  # the user query text
    total_latency_ms: float  # end-to-end time for the query, in milliseconds
    embedding_latency_ms: float  # time spent embedding the query
    retrieval_latency_ms: float  # time spent in the retriever search
    rerank_latency_ms: Optional[float]  # None when no reranking step ran
    generation_latency_ms: float  # time spent in LLM generation
    num_retrieved: int  # documents returned by retrieval, before rerank/trim
    num_tokens_context: int  # token count of the context (InstrumentedRAG estimates it from words)
    num_tokens_output: int  # token count of the generated answer (estimated the same way)
    cache_hit: bool = False  # whether the answer came from a cache; defaults to False

class InstrumentedRAG:
    """RAG system with comprehensive metrics.

    Runs embed -> retrieve (k=20) -> optional rerank (top 5) -> generate,
    timing every stage. It logs one ``RAGMetrics`` record per query under the
    message ``"rag_query"``.
    """

    def __init__(self, components: dict):
        self.embedder = components["embedder"]
        self.retriever = components["retriever"]
        self.reranker = components.get("reranker")  # optional; None disables reranking
        self.llm = components["llm"]
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Rough token estimate: ceil(1.3 * word count), as an int."""
        return (len(text.split()) * 13 + 9) // 10

    def query(self, user_query: str) -> tuple[str, RAGMetrics]:
        """Answer *user_query* and return ``(response, metrics)``.

        Latencies are measured with ``time.perf_counter`` (monotonic, high
        resolution) rather than ``time.time``, whose wall-clock value can jump
        when the system clock is adjusted.
        """
        metrics = {}
        start_total = time.perf_counter()

        # Embedding
        start = time.perf_counter()
        query_embedding = self.embedder.embed(user_query)
        metrics["embedding_latency_ms"] = (time.perf_counter() - start) * 1000

        # Retrieval
        start = time.perf_counter()
        documents = self.retriever.search(query_embedding, k=20)
        metrics["retrieval_latency_ms"] = (time.perf_counter() - start) * 1000
        metrics["num_retrieved"] = len(documents)

        # Reranking (optional)
        if self.reranker is not None:
            start = time.perf_counter()
            documents = self.reranker.rerank(user_query, documents, top_k=5)
            metrics["rerank_latency_ms"] = (time.perf_counter() - start) * 1000
        else:
            documents = documents[:5]
            metrics["rerank_latency_ms"] = None

        # Build context
        context = "\n---\n".join([d["text"] for d in documents])
        # RAGMetrics declares token counts as int; keep the estimate an int.
        metrics["num_tokens_context"] = self._estimate_tokens(context)

        # Generation
        start = time.perf_counter()
        response = self.llm.generate(user_query, context)
        metrics["generation_latency_ms"] = (time.perf_counter() - start) * 1000
        metrics["num_tokens_output"] = self._estimate_tokens(response)

        # Total
        metrics["total_latency_ms"] = (time.perf_counter() - start_total) * 1000
        metrics["query"] = user_query

        rag_metrics = RAGMetrics(**metrics)
        self.logger.info("rag_query", extra=vars(rag_metrics))

        return response, rag_metrics

Production Optimization Checklist

Performance Optimization Checklist:

□ Embedding Layer
  ├── Cache embeddings for repeated texts
  ├── Use batch embedding for ingestion
  ├── Consider smaller embedding models for speed
  └── Use GPU acceleration if available

□ Retrieval Layer
  ├── Tune HNSW parameters (M, ef_construction, ef_search)
  ├── Cache frequent query results
  ├── Use semantic caching for similar queries
  └── Implement parallel multi-index search

□ Re-ranking Layer
  ├── Limit candidates for re-ranking (50-100 max)
  ├── Use lighter re-ranker for low-stakes queries
  ├── Cache re-ranking results
  └── Consider skipping for simple queries

□ Generation Layer
  ├── Optimize context window usage
  ├── Use context compression for long documents
  ├── Route simple queries to faster models
  ├── Stream responses for better UX
  └── Cache responses for common queries

□ Infrastructure
  ├── Use connection pooling for databases
  ├── Deploy embedding models on GPU
  ├── Use Redis for caching layer
  ├── Implement proper monitoring
  └── Set up alerting for latency spikes

Key Takeaways

  • LLM generation is the main latency bottleneck - optimize context size
  • Multi-layer caching - embeddings, retrieval results, full responses
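  • Semantic caching catches similar (not just identical) queries — tune the similarity threshold carefully, since two similar-looking queries can still ask different questions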
  • Batch and parallelize embedding and retrieval operations
  • Compress context to reduce tokens and cost
  • Route queries to appropriate models based on complexity
  • Stream responses for better perceived latency
  • Monitor everything - latency, tokens, cache hit rates

In the next lesson, we'll cover deploying RAG systems to production environments.