Production RAG systems must balance quality with latency and cost. This lesson covers optimization strategies across the entire RAG pipeline, from embedding generation to response caching.
Typical RAG Latency Breakdown:
┌───────────────────────┬────────────┬───────────────────────┐
│ Component             │ Latency    │ Cost Impact           │
├───────────────────────┼────────────┼───────────────────────┤
│ Query Embedding       │ 10-50ms    │ Low (small input)     │
│ Vector Search         │ 10-100ms   │ Fixed (per query)     │
│ Re-ranking (optional) │ 50-200ms   │ Medium                │
│ LLM Generation        │ 500-3000ms │ High (input+output)   │
└───────────────────────┴────────────┴───────────────────────┘
Total typical latency: 1-4 seconds
Main cost driver: LLM tokens (context + generation)
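Because LLM tokens dominate spend, it helps to put a rough number on per-query cost before optimizing anything else. The sketch below is illustrative only; the per-1K-token prices are placeholder assumptions, so substitute your provider's actual rates.

def estimate_query_cost(
    context_tokens: int,
    output_tokens: int,
    input_price_per_1k: float = 0.0005,   # placeholder rate, not real pricing
    output_price_per_1k: float = 0.0015,  # placeholder rate, not real pricing
) -> float:
    """Rough per-query LLM cost: context (input) tokens plus generated tokens."""
    return (context_tokens / 1000) * input_price_per_1k + (output_tokens / 1000) * output_price_per_1k

# e.g. 3,000 context tokens + 400 output tokens
# => 3.0 * 0.0005 + 0.4 * 0.0015 = $0.0021 per query at the placeholder rates

The first optimizations below attack this from the caching side, starting with an embedding cache so repeated texts are never re-embedded.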
import hashlib
import json

import numpy as np
import redis
class EmbeddingCache:
"""Cache embeddings to avoid recomputation."""
def __init__(self, redis_url: str = "redis://localhost:6379"):
self.redis = redis.from_url(redis_url)
self.prefix = "emb:"
self.ttl = 86400 * 7 # 7 days
def _key(self, text: str, model: str) -> str:
content = f"{model}:{text}"
hash_val = hashlib.sha256(content.encode()).hexdigest()[:16]
return f"{self.prefix}{hash_val}"
def get(self, text: str, model: str) -> np.ndarray | None:
key = self._key(text, model)
cached = self.redis.get(key)
if cached:
return np.frombuffer(cached, dtype=np.float32)
return None
def set(self, text: str, model: str, embedding: np.ndarray):
key = self._key(text, model)
self.redis.setex(key, self.ttl, embedding.astype(np.float32).tobytes())
def get_or_compute(self, text: str, model: str, compute_fn) -> np.ndarray:
cached = self.get(text, model)
if cached is not None:
return cached
embedding = compute_fn(text)
self.set(text, model, embedding)
        return embedding

from datetime import datetime, timedelta
class RetrievalCache:
"""Cache retrieval results for repeated queries."""
def __init__(self, max_size: int = 10000, ttl_seconds: int = 3600):
self.cache = {}
self.max_size = max_size
self.ttl = ttl_seconds
def _key(self, query: str, k: int, filters: dict = None) -> str:
filter_str = json.dumps(filters, sort_keys=True) if filters else ""
return hashlib.md5(f"{query}:{k}:{filter_str}".encode()).hexdigest()
def get(self, query: str, k: int, filters: dict = None) -> list | None:
key = self._key(query, k, filters)
if key in self.cache:
entry = self.cache[key]
if datetime.now() < entry["expires"]:
return entry["results"]
else:
del self.cache[key]
return None
def set(self, query: str, k: int, results: list, filters: dict = None):
        # Evict the oldest entry (by creation time) when the cache is full
if len(self.cache) >= self.max_size:
oldest = min(self.cache.items(), key=lambda x: x[1]["created"])
del self.cache[oldest[0]]
key = self._key(query, k, filters)
self.cache[key] = {
"results": results,
"created": datetime.now(),
"expires": datetime.now() + timedelta(seconds=self.ttl)
        }

class SemanticCache:
"""
Cache based on query similarity, not exact match.
Similar queries can reuse cached results.
"""
def __init__(self, embedding_model, similarity_threshold: float = 0.95):
self.embedding_model = embedding_model
self.threshold = similarity_threshold
self.cache_entries = [] # [(embedding, query, results, timestamp)]
def get(self, query: str) -> list | None:
query_emb = self.embedding_model.encode(query)
for emb, cached_query, results, timestamp in self.cache_entries:
similarity = np.dot(query_emb, emb) / (
np.linalg.norm(query_emb) * np.linalg.norm(emb)
)
if similarity >= self.threshold:
return results
return None
def set(self, query: str, results: list):
query_emb = self.embedding_model.encode(query)
self.cache_entries.append((
query_emb, query, results, datetime.now()
))
# Limit cache size
if len(self.cache_entries) > 10000:
            self.cache_entries = self.cache_entries[-5000:]

class ResponseCache:
"""
Cache full RAG responses for exact query matches.
Useful for common/repeated queries.
"""
def __init__(self, redis_client, ttl: int = 3600):
self.redis = redis_client
self.ttl = ttl
self.prefix = "rag_response:"
def _key(self, query: str) -> str:
# Normalize query
normalized = query.lower().strip()
return f"{self.prefix}{hashlib.sha256(normalized.encode()).hexdigest()}"
def get(self, query: str) -> dict | None:
cached = self.redis.get(self._key(query))
if cached:
return json.loads(cached)
return None
def set(self, query: str, response: dict):
self.redis.setex(
self._key(query),
self.ttl,
json.dumps(response)
        )

import asyncio
from concurrent.futures import ThreadPoolExecutor
class BatchEmbedder:
def __init__(self, model, batch_size: int = 64):
self.model = model
self.batch_size = batch_size
self.executor = ThreadPoolExecutor(max_workers=4)
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
"""Embed texts in batches for efficiency."""
all_embeddings = []
for i in range(0, len(texts), self.batch_size):
batch = texts[i:i + self.batch_size]
embeddings = self.model.encode(
batch,
batch_size=self.batch_size,
show_progress_bar=False,
convert_to_numpy=True
)
all_embeddings.extend(embeddings)
return all_embeddings
async def embed_async(self, texts: list[str]) -> list[np.ndarray]:
"""Async wrapper for non-blocking embedding."""
        loop = asyncio.get_running_loop()
return await loop.run_in_executor(
self.executor,
self.embed_batch,
texts
        )

import asyncio
class ParallelRetriever:
def __init__(self, retrievers: list):
self.retrievers = retrievers
async def retrieve_parallel(
self,
query: str,
k: int = 10
) -> list[list[dict]]:
"""Run multiple retrievers in parallel."""
async def run_retriever(retriever, query, k):
            loop = asyncio.get_running_loop()
return await loop.run_in_executor(
None,
lambda: retriever.search(query, k)
)
tasks = [
run_retriever(r, query, k)
for r in self.retrievers
]
results = await asyncio.gather(*tasks)
return results
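
The usage sketch below relies on a merge_results helper that isn't defined in this lesson. One common choice is reciprocal rank fusion (RRF); here is a minimal version that assumes each result dict exposes a stable "id" field.

def merge_results(result_lists: list[list[dict]], rrf_k: int = 60) -> list[dict]:
    """Fuse ranked lists from multiple retrievers with reciprocal rank fusion."""
    # Assumes each result dict has an "id" key; adapt to your document schema.
    scores: dict = {}
    docs: dict = {}
    for results in result_lists:
        for rank, doc in enumerate(results):
            doc_id = doc["id"]
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (rrf_k + rank + 1)
            docs.setdefault(doc_id, doc)
    return [docs[doc_id] for doc_id in sorted(scores, key=scores.get, reverse=True)]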
# Usage
async def hybrid_search(query: str):
parallel_retriever = ParallelRetriever([
bm25_retriever,
vector_retriever,
reranker_retriever
])
results = await parallel_retriever.retrieve_parallel(query, k=10)
    return merge_results(results)

class ContextOptimizer:
"""Optimize context to fit LLM limits while maximizing relevance."""
def __init__(
self,
max_tokens: int = 4096,
tokenizer = None
):
self.max_tokens = max_tokens
self.tokenizer = tokenizer
def count_tokens(self, text: str) -> int:
if self.tokenizer:
return len(self.tokenizer.encode(text))
        return int(len(text.split()) * 1.3)  # rough estimate: ~1.3 tokens per word
def optimize_context(
self,
query: str,
documents: list[dict],
system_prompt: str = "",
reserved_for_output: int = 500
) -> list[dict]:
"""
Select and order documents to fit context limit.
Prioritizes higher-scored documents.
"""
# Calculate available tokens
query_tokens = self.count_tokens(query)
system_tokens = self.count_tokens(system_prompt)
available = self.max_tokens - query_tokens - system_tokens - reserved_for_output
# Sort by relevance score
sorted_docs = sorted(
documents,
key=lambda x: x.get("score", 0),
reverse=True
)
selected = []
used_tokens = 0
for doc in sorted_docs:
doc_tokens = self.count_tokens(doc["text"])
if used_tokens + doc_tokens <= available:
selected.append(doc)
used_tokens += doc_tokens
else:
# Try truncating the document
remaining = available - used_tokens
if remaining > 100: # Worth including partial
truncated = self.truncate_to_tokens(doc["text"], remaining)
doc_copy = {**doc, "text": truncated, "truncated": True}
selected.append(doc_copy)
break
return selected
def truncate_to_tokens(self, text: str, max_tokens: int) -> str:
"""Truncate text to fit token limit, preserving sentence boundaries."""
sentences = text.split(". ")
result = []
tokens = 0
for sentence in sentences:
sentence_tokens = self.count_tokens(sentence)
if tokens + sentence_tokens <= max_tokens:
result.append(sentence)
tokens += sentence_tokens
else:
break
        return (". ".join(result) + ".") if result else text[:max_tokens * 4]  # character fallback (~4 chars/token)

class ContextCompressor:
"""Compress context by extracting only relevant information."""
def __init__(self, llm):
self.llm = llm
def compress(self, query: str, document: str, max_length: int = 500) -> str:
"""Extract only the parts of document relevant to the query."""
prompt = f"""Extract ONLY the sentences from this document that are
relevant to answering the question. Remove irrelevant information.
Keep the extracted text under {max_length} characters.
Question: {query}
Document:
{document}
Relevant excerpt:"""
response = self.llm.invoke(prompt)
return response.content[:max_length]
def compress_batch(
self,
query: str,
documents: list[dict],
max_length_each: int = 300
) -> list[dict]:
"""Compress multiple documents."""
compressed = []
for doc in documents:
compressed_text = self.compress(query, doc["text"], max_length_each)
compressed.append({
**doc,
"text": compressed_text,
"original_length": len(doc["text"]),
"compressed_length": len(compressed_text)
})
        return compressed

class AdaptiveLLMSelector:
"""Route queries to appropriate models based on complexity."""
def __init__(self, models: dict):
self.models = models # {"fast": gpt-3.5, "quality": gpt-4}
def classify_complexity(self, query: str, context_length: int) -> str:
"""Classify query complexity."""
# Simple heuristics
is_simple = (
len(query.split()) < 10 and
context_length < 1000 and
not any(kw in query.lower() for kw in [
"compare", "analyze", "explain why", "synthesize"
])
)
return "fast" if is_simple else "quality"
def get_model(self, query: str, context_length: int):
complexity = self.classify_complexity(query, context_length)
        return self.models[complexity]

async def stream_rag_response(query: str, retriever, llm):
"""Stream RAG response for better perceived latency."""
# Retrieve (not streamable)
documents = retriever.search(query, k=5)
context = "\n---\n".join([d["text"] for d in documents])
# Stream generation
prompt = f"""Answer based on context.
Context:
{context}
Question: {query}
Answer:"""
# Yield tokens as they're generated
async for chunk in llm.astream(prompt):
yield chunk.content
# FastAPI endpoint (assumes `retriever` and `llm` are initialized elsewhere, e.g. at app startup)
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
app = FastAPI()
@app.get("/query")
async def query_endpoint(q: str):
return StreamingResponse(
stream_rag_response(q, retriever, llm),
media_type="text/plain"
    )

# HNSW Index Tuning for Production
class VectorIndexOptimizer:
"""Optimize vector index parameters."""
@staticmethod
def get_hnsw_params(
dataset_size: int,
query_latency_target_ms: int = 50
) -> dict:
"""
Recommend HNSW parameters based on requirements.
M: Number of connections per node (higher = better quality, more memory)
ef_construction: Build-time search breadth (higher = better quality, slower build)
ef_search: Query-time search breadth (higher = better quality, slower query)
"""
if dataset_size < 100_000:
return {
"M": 16,
"ef_construction": 200,
"ef_search": 100
}
elif dataset_size < 1_000_000:
return {
"M": 32,
"ef_construction": 200,
"ef_search": 150 if query_latency_target_ms > 30 else 100
}
else:
return {
"M": 48,
"ef_construction": 400,
"ef_search": 200 if query_latency_target_ms > 50 else 100
}
@staticmethod
def estimate_memory_gb(
num_vectors: int,
dimensions: int,
M: int = 16
) -> float:
"""Estimate memory usage for HNSW index."""
# Vector storage
vector_bytes = num_vectors * dimensions * 4 # float32
# HNSW graph (approximate)
graph_bytes = num_vectors * M * 2 * 8 # 2 layers avg, 8 bytes per link
total_bytes = vector_bytes + graph_bytes
        return total_bytes / (1024 ** 3)
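
To sanity-check the arithmetic, a hypothetical corpus of one million 768-dimensional vectors with M=16 comes out to roughly 3 GB under this estimate:

# Hypothetical sizing example (not a benchmark):
#   vectors: 1_000_000 * 768 dims * 4 bytes ≈ 3.07e9 bytes
#   graph:   1_000_000 * 16 * 2 * 8 bytes   ≈ 0.26e9 bytes
print(VectorIndexOptimizer.estimate_memory_gb(1_000_000, 768, M=16))  # ≈ 3.1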
import time
from dataclasses import dataclass
from typing import Optional
import logging
@dataclass
class RAGMetrics:
query: str
total_latency_ms: float
embedding_latency_ms: float
retrieval_latency_ms: float
rerank_latency_ms: Optional[float]
generation_latency_ms: float
num_retrieved: int
num_tokens_context: int
num_tokens_output: int
cache_hit: bool = False
class InstrumentedRAG:
"""RAG system with comprehensive metrics."""
def __init__(self, components: dict):
self.embedder = components["embedder"]
self.retriever = components["retriever"]
self.reranker = components.get("reranker")
self.llm = components["llm"]
self.logger = logging.getLogger(__name__)
def query(self, user_query: str) -> tuple[str, RAGMetrics]:
metrics = {}
start_total = time.time()
# Embedding
start = time.time()
query_embedding = self.embedder.embed(user_query)
metrics["embedding_latency_ms"] = (time.time() - start) * 1000
# Retrieval
start = time.time()
documents = self.retriever.search(query_embedding, k=20)
metrics["retrieval_latency_ms"] = (time.time() - start) * 1000
metrics["num_retrieved"] = len(documents)
# Reranking (optional)
if self.reranker:
start = time.time()
documents = self.reranker.rerank(user_query, documents, top_k=5)
metrics["rerank_latency_ms"] = (time.time() - start) * 1000
else:
documents = documents[:5]
metrics["rerank_latency_ms"] = None
# Build context
context = "\n---\n".join([d["text"] for d in documents])
        metrics["num_tokens_context"] = int(len(context.split()) * 1.3)  # rough estimate
# Generation
start = time.time()
response = self.llm.generate(user_query, context)
metrics["generation_latency_ms"] = (time.time() - start) * 1000
        metrics["num_tokens_output"] = int(len(response.split()) * 1.3)  # rough estimate
# Total
metrics["total_latency_ms"] = (time.time() - start_total) * 1000
metrics["query"] = user_query
rag_metrics = RAGMetrics(**metrics)
self.logger.info("rag_query", extra=vars(rag_metrics))
        return response, rag_metrics
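
Per-query metrics are most useful once aggregated. A minimal sketch of summarizing a recent window of RAGMetrics objects into the latency percentiles and cache-hit rate that dashboards and alerts typically track; the in-memory window and the latency_percentiles helper are illustrative assumptions, not part of the class above.

import statistics

def latency_percentiles(metrics_window: list[RAGMetrics]) -> dict:
    """Summarize a recent window of RAGMetrics into dashboard-friendly numbers."""
    if not metrics_window:
        return {}
    latencies = sorted(m.total_latency_ms for m in metrics_window)
    p95_index = int(0.95 * (len(latencies) - 1))
    return {
        "p50_ms": statistics.median(latencies),
        "p95_ms": latencies[p95_index],
        "cache_hit_rate": sum(m.cache_hit for m in metrics_window) / len(metrics_window),
    }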

Performance Optimization Checklist:

□ Embedding Layer
├── Cache embeddings for repeated texts
├── Use batch embedding for ingestion
├── Consider smaller embedding models for speed
└── Use GPU acceleration if available
□ Retrieval Layer
├── Tune HNSW parameters (M, ef_construction, ef_search)
├── Cache frequent query results
├── Use semantic caching for similar queries
└── Implement parallel multi-index search
□ Re-ranking Layer
├── Limit candidates for re-ranking (50-100 max)
├── Use lighter re-ranker for low-stakes queries
├── Cache re-ranking results
└── Consider skipping for simple queries
□ Generation Layer
├── Optimize context window usage
├── Use context compression for long documents
├── Route simple queries to faster models
├── Stream responses for better UX
└── Cache responses for common queries
□ Infrastructure
├── Use connection pooling for databases
├── Deploy embedding models on GPU
├── Use Redis for caching layer
├── Implement proper monitoring
└── Set up alerting for latency spikes

In the next lesson, we'll cover deploying RAG systems to production environments.