User queries are often imperfect - they may be vague, use different terminology than the documents, or lack context. Query rewriting and expansion techniques transform user queries to improve retrieval quality.
The "semantic gap" between how users express queries and how information is stored in documents is a fundamental challenge in information retrieval:
Add synonyms and related terms to the original query:
from openai import OpenAI
def expand_query(query: str, llm_client: OpenAI) -> str:
    """Return *query* augmented with synonyms and related terms.

    Broadening the query lets lexically different but semantically related
    documents match during retrieval.
    """
    expansion_prompt = f"""Expand the following search query by adding synonyms and
related terms that would help find relevant documents. Return the expanded
query as a single line.
Original query: {query}
Expanded query:"""
    # Low temperature keeps the expansion close to the original intent.
    completion = llm_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": expansion_prompt}],
        temperature=0.3
    )
    return completion.choices[0].message.content
# Example:
# Input: "car insurance rates"
# Output: "car automobile vehicle insurance rates premiums coverage cost pricing"

Transform the query to be clearer or more specific:
def rewrite_query(query: str, context: str | None = None) -> str:
    """Rewrite query to be more specific and search-friendly.

    The optional `context` string (e.g. a conversation summary) is inlined
    into the prompt only when provided; otherwise that line is omitted.
    """
    prompt = f"""Rewrite the following search query to be more specific and
effective for document retrieval. Make it clearer and add relevant context
if helpful.
{f"Context: {context}" if context else ""}
Original query: {query}
Rewritten query:"""
    # NOTE(review): `llm` is a module-level client defined elsewhere in the
    # surrounding document (LangChain-style `.invoke()` interface) — confirm.
    response = llm.invoke(prompt)
    return response.content
# Example:
# Input: "how to fix it"
# Context: "User was discussing Python import errors"
# Output: "How to fix Python ImportError and ModuleNotFoundError issues"

Generate multiple query variations and retrieve for each:
def generate_multi_queries(query: str, n: int = 3) -> list[str]:
    """Generate `n` alternative phrasings of *query*, one per list entry.

    The model is asked for one variation per line; models frequently prefix
    each line with "1.", "2)" or "-" even when not asked to (the expected
    output example below shows exactly that), so enumeration markers are
    stripped before returning.
    """
    import re  # stdlib; local import keeps this snippet self-contained

    prompt = f"""Generate {n} different versions of the following search query.
Each version should approach the topic from a different angle while maintaining
the original intent. Return one query per line.
Original query: {query}
Query variations:"""
    # NOTE(review): `llm` is a module-level client with an `.invoke()` API,
    # defined elsewhere in the surrounding document — confirm.
    response = llm.invoke(prompt)
    lines = response.content.strip().split("\n")
    # Drop leading "1." / "2)" / "-" / "*" markers and surrounding whitespace.
    cleaned = [re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line).strip() for line in lines]
    return [q for q in cleaned if q]
# Example:
# Input: "What is RAG?"
# Output:
# 1. "What is Retrieval Augmented Generation and how does it work?"
# 2. "Explain the RAG architecture for AI systems"
# 3. "How do RAG systems combine retrieval with language models?"
class MultiQueryRetriever:
    """Retrieve with several LLM-generated variations of the user query.

    Results from all variations are merged, keeping the best score seen for
    each document id, so documents matched by any phrasing can surface.
    """

    def __init__(self, base_retriever, llm):
        # base_retriever must expose `.search(query, k=...) -> list[dict]`
        # where each dict has at least "id" and "score" keys — confirm.
        self.retriever = base_retriever
        self.llm = llm

    def retrieve(self, query: str, k: int = 5) -> list:
        """Return up to k documents ranked by their best score across queries."""
        # Generate multiple queries
        queries = generate_multi_queries(query, n=3)
        queries.append(query)  # Include original
        # Retrieve for each query; deduplicate by id, keeping the best score.
        all_docs = {}
        for q in queries:
            results = self.retriever.search(q, k=k)
            for doc in results:
                doc_id = doc["id"]
                if doc_id not in all_docs or doc["score"] > all_docs[doc_id]["score"]:
                    all_docs[doc_id] = doc
        # Return top k by best score
        ranked = sorted(all_docs.values(), key=lambda x: x["score"], reverse=True)
        return ranked[:k]

Generate a hypothetical answer, then use it for retrieval instead of the query:
def hyde_retrieval(query: str, retriever, llm) -> list:
    """Hypothetical Document Embeddings (HyDE) retrieval.

    Rather than searching with the question itself, ask the LLM for a
    hypothetical answer paragraph and search with that — generated answers
    tend to resemble stored documents more closely than questions do.
    """
    hyde_prompt = f"""Write a detailed paragraph that would answer the following
question. Write as if you are creating a document that contains this
information.
Question: {query}
Document paragraph:"""
    # The generated paragraph stands in for the query at search time.
    generated = llm.invoke(hyde_prompt)
    return retriever.search(generated.content)
# Example:
# Query: "What causes the northern lights?"
#
# HyDE generates: "The northern lights, or aurora borealis, are caused by
# charged particles from the sun interacting with gases in Earth's atmosphere.
# When solar wind particles collide with oxygen and nitrogen atoms..."
#
# This hypothetical doc is then used for embedding and retrieval

Generate a more abstract version of the query to find broader context:
def step_back_query(query: str) -> str:
    """Abstract *query* into a broader "step-back" question.

    The more general question retrieves background context that helps
    answer the original, narrower one.
    """
    step_back_prompt = f"""Given a specific question, generate a more general "step-back"
question that would help gather broader context to answer the original question.
Specific question: {query}
Step-back question:"""
    # NOTE: relies on a module-level `llm` client defined elsewhere.
    response = llm.invoke(step_back_prompt)
    return response.content
# Example:
# Original: "Why did my Python 3.9 script fail with asyncio on Windows?"
# Step-back: "How does asyncio work differently across operating systems?"
class StepBackRetriever:
    """Combine results for the original query with a broader step-back query."""

    def __init__(self, base_retriever, llm):
        # base_retriever must expose `.search(query, k=...) -> list[dict]`
        # with an "id" key per document — confirm against callers.
        self.retriever = base_retriever
        self.llm = llm

    def retrieve(self, query: str, k: int = 5):
        """Return up to 2*k docs: original-query hits first, then novel step-back hits."""
        # Get step-back question
        step_back = step_back_query(query)
        # Retrieve for both
        original_results = self.retriever.search(query, k=k)
        stepback_results = self.retriever.search(step_back, k=k)
        # Merge results (prefer original, augment with step-back)
        seen = set()
        results = []
        for doc in original_results:
            results.append(doc)
            seen.add(doc["id"])
        for doc in stepback_results:
            if doc["id"] not in seen and len(results) < k * 2:
                results.append(doc)
        return results[:k*2]  # Return more for broader context

Break complex queries into simpler sub-queries:
def decompose_query(query: str) -> list[str]:
    """Break a complex query into simpler, independently answerable sub-queries.

    Models often number their lines ("1.", "2)") or bullet them even when
    asked for plain lines (the example output below is numbered), so
    enumeration markers are stripped — consistent with generate_multi_queries.
    """
    import re  # stdlib; local import keeps this snippet self-contained

    prompt = f"""Break down the following complex question into simpler
sub-questions that can be answered independently. Return one question per line.
Complex question: {query}
Sub-questions:"""
    # NOTE(review): `llm` is a module-level client defined elsewhere — confirm.
    response = llm.invoke(prompt)
    raw_lines = response.content.strip().split("\n")
    cleaned = [re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line).strip() for line in raw_lines]
    return [q for q in cleaned if q]
# Example:
# Input: "Compare the performance of RAG vs fine-tuning for customer support bots"
#
# Sub-queries:
# 1. "What is the performance of RAG for customer support applications?"
# 2. "What is the performance of fine-tuning for customer support applications?"
# 3. "What are the key metrics for comparing RAG and fine-tuning?"
# 4. "What are the trade-offs between RAG and fine-tuning?"

Before retrieval, classify the query to determine the best strategy:
from enum import Enum
from pydantic import BaseModel
class QueryType(str, Enum):
    """Coarse query categories used to pick a retrieval strategy.

    Inherits from `str` so member values serialize naturally (e.g. to JSON).
    """

    FACTUAL = "factual"                  # Simple fact lookup
    COMPARATIVE = "comparative"          # Compare multiple things
    PROCEDURAL = "procedural"            # How-to questions
    EXPLORATORY = "exploratory"          # Open-ended exploration
    CONVERSATIONAL = "conversational"    # Follow-up in conversation
class QueryAnalysis(BaseModel):
    """Structured result of analyzing a user query (parsed from LLM JSON output)."""

    query_type: QueryType        # which retrieval-strategy family applies
    entities: list[str]          # named entities mentioned in the query
    keywords: list[str]          # salient search terms
    requires_context: bool       # True when conversation history is needed
    suggested_filters: dict      # optional metadata filters for the retriever
def analyze_query(query: str, history: list | None = None) -> QueryAnalysis:
    """Analyze a query with the LLM and parse the structured JSON reply.

    Args:
        query: The raw user query.
        history: Optional conversation history to include in the prompt.

    Returns:
        A validated QueryAnalysis instance.

    Raises:
        pydantic.ValidationError: if the model's reply is not valid JSON
            matching the QueryAnalysis schema.
    """
    prompt = f"""Analyze this search query and provide structured output.
Query: {query}
{f"Conversation history: {history}" if history else ""}
Provide:
1. Query type: factual, comparative, procedural, exploratory, or conversational
2. Key entities mentioned
3. Important keywords
4. Whether it needs conversation context
5. Suggested metadata filters (if any)
Output as JSON:"""
    # NOTE(review): `llm` is a module-level client defined elsewhere — confirm.
    response = llm.invoke(prompt)
    # Models frequently wrap JSON in ```json ... ``` fences; strip them so
    # validation doesn't fail on otherwise-correct output.
    raw = response.content.strip()
    if raw.startswith("```"):
        raw = raw.strip("`")
        if raw.startswith("json"):
            raw = raw[4:]
        raw = raw.strip()
    return QueryAnalysis.model_validate_json(raw)
# Route based on analysis
def route_query(query: str, analysis: QueryAnalysis):
    """Dispatch to a retrieval strategy based on the query classification.

    NOTE(review): `history`, `retriever`, `multi_retrieve` and
    `reformulate_with_context` are free names not defined in this snippet —
    in a real module they must be in scope or passed in; confirm upstream.
    """
    if analysis.query_type == QueryType.COMPARATIVE:
        # Use decomposition for comparative queries
        sub_queries = decompose_query(query)
        return multi_retrieve(sub_queries)
    elif analysis.query_type == QueryType.CONVERSATIONAL:
        # Reformulate with context
        reformulated = reformulate_with_context(query, history)
        return retriever.search(reformulated)
    elif analysis.query_type == QueryType.EXPLORATORY:
        # Use HyDE for exploratory queries
        return hyde_retrieval(query)
    else:
        # Standard retrieval for factual/procedural
        return retriever.search(query)

Multi-turn conversations present unique challenges for RAG retrieval. The latest user message often lacks the context needed for effective retrieval. Here are the main strategies:
The simplest approach - embed only the current message:
def embed_latest_turn(messages: list[dict]) -> list:
    """Retrieve using only the most recent message's text.

    Fast but context-free: follow-ups like "what about the chunking part?"
    lose their referent, so retrieval quality suffers mid-conversation.

    NOTE(review): the original annotated the return as `str`, but the
    function returns the retriever's search results (a list) — annotation
    corrected. `retriever` is a free module-level name assumed to exist.
    """
    latest = messages[-1]["content"]
    return retriever.search(latest)
# Example:
# User: "Tell me about RAG"
# Assistant: "RAG combines retrieval with generation..."
# User: "What about the chunking part?" <- This is embedded
#
# Problem: "the chunking part" lacks context - retrieval may fail

Concatenate the last N turns to provide context:
def embed_concatenated_history(
    messages: list[dict],
    max_turns: int = 3,
    max_tokens: int = 512
) -> list:
    """Search with the last few turns concatenated into one context string.

    Each turn is prefixed with its role so the text distinguishes speakers;
    the combined string is truncated to `max_tokens` before searching.
    """
    # Keep user+assistant pairs, hence the factor of two.
    window = messages[-max_turns * 2:]
    labeled_turns = (
        f"[{m['role'].upper()}]: {m['content']}"
        for m in window
    )
    context_text = " ".join(labeled_turns)
    # NOTE: `truncate_to_tokens` and `retriever` are defined elsewhere.
    context_text = truncate_to_tokens(context_text, max_tokens)
    return retriever.search(context_text)
# Example:
# Embedded: "[USER]: Tell me about RAG [ASSISTANT]: RAG combines... [USER]: What about the chunking part?"
# Now "chunking" has context about RAG

Use an LLM to create a standalone query from the conversation:
def condense_conversation_to_query(
    messages: list[dict],
    max_history: int = 5
) -> str:
    """Collapse a conversation into one standalone retrieval query.

    The LLM rewrites the user's latest need so it no longer depends on the
    chat history (resolving references like "it" or "that part").
    """
    window = messages[-max_history*2:]
    transcript = "\n".join(
        f"{m['role'].title()}: {m['content']}"
        for m in window
    )
    condense_prompt = f"""Given the following conversation, create a standalone search
query that captures what the user is currently looking for. The query should
be self-contained and not require the conversation context.
Conversation:
{transcript}
Standalone search query:"""
    # NOTE: relies on a module-level `llm` client defined elsewhere.
    reply = llm.invoke(condense_prompt)
    return reply.content.strip()
# Example conversation:
# User: "Tell me about RAG"
# Assistant: "RAG is Retrieval Augmented Generation..."
# User: "What databases work best with it?"
#
# Condensed: "What vector databases work best with RAG systems?"

class ConversationalRetriever:
"""Production-ready conversational RAG retriever."""
    def __init__(self, llm, retriever, embedder):
        """Store collaborators; `llm` needs `.invoke()`, `retriever` `.search()`."""
        self.llm = llm
        self.retriever = retriever
        self.embedder = embedder
        # Maps a hash of the message list to its condensed query so repeated
        # retrievals over the same history skip the LLM call.
        self.query_cache = {}
def condense_with_cache(self, messages: list[dict]) -> str:
"""Cache condensed queries to avoid redundant LLM calls."""
cache_key = self._messages_hash(messages)
if cache_key in self.query_cache:
return self.query_cache[cache_key]
condensed = self._condense(messages)
self.query_cache[cache_key] = condensed
return condensed
    def _condense(self, messages: list[dict]) -> str:
        """Rewrite the last user message as a standalone query via the LLM.

        Returns the message verbatim when there is no prior history to
        resolve against (nothing to condense).
        """
        # Only condense if there's conversation history
        if len(messages) <= 1:
            return messages[-1]["content"]
        # NOTE(review): `self._format_messages` is not defined in this
        # snippet — confirm a transcript-formatting helper exists.
        prompt = f"""Rewrite the user's last question as a standalone query.
Include all necessary context from the conversation.
Return ONLY the rewritten query, nothing else.
Conversation:
{self._format_messages(messages)}
Standalone query:"""
        return self.llm.invoke(prompt).content.strip()

Track entities and topics mentioned, embed structured state:
from dataclasses import dataclass, field
from typing import Set
@dataclass
class DialogueState:
    """Running summary of a conversation, used to enrich retrieval queries."""

    # Short label for what the conversation is currently about.
    current_topic: str = ""
    # All entities (people, products, concepts) mentioned so far.
    mentioned_entities: Set[str] = field(default_factory=set)
    # What the user appears to be trying to accomplish.
    user_intent: str = ""
    # Search-relevant terms accumulated across turns.
    context_keywords: Set[str] = field(default_factory=set)

    def to_query(self, current_message: str) -> str:
        """Build an enhanced query: the message plus known topic and entities."""
        segments = [current_message]
        if self.current_topic:
            segments.append(f"topic: {self.current_topic}")
        if self.mentioned_entities:
            entity_list = ', '.join(self.mentioned_entities)
            segments.append(f"entities: {entity_list}")
        return " | ".join(segments)
class StatefulRetriever:
    """Retriever that tracks dialogue state and enriches each query with it."""

    def __init__(self, llm, retriever):
        self.llm = llm
        self.retriever = retriever
        # Fresh state per retriever instance; updated on every message.
        self.state = DialogueState()
def update_state(self, message: str):
"""Extract and update dialogue state from message."""
# Use LLM to extract entities and topic
prompt = f"""Extract from this message:
1. Main topic (1-3 words)
2. Named entities (people, products, concepts)
3. Keywords important for search
Message: {message}
Return as JSON: {{"topic": "...", "entities": [...], "keywords": [...]}}}"""
response = self.llm.invoke(prompt)
data = json.loads(response.content)
self.state.current_topic = data.get("topic", self.state.current_topic)
self.state.mentioned_entities.update(data.get("entities", []))
self.state.context_keywords.update(data.get("keywords", []))
    def retrieve(self, current_message: str) -> list:
        """Update dialogue state with the new message, then search with the
        state-enriched query."""
        self.update_state(current_message)
        enhanced_query = self.state.to_query(current_message)
        return self.retriever.search(enhanced_query)

Different conversation types need different reformulation strategies:
class TaskAwareReformulator:
    """Reformulate queries based on conversation task type."""

    def __init__(self, llm):
        self.llm = llm
        # Maps detected task category -> prompt-building method. Unknown
        # categories fall back to the QA prompt in `reformulate`.
        # NOTE(review): `_qa_prompt` and `_comparison_prompt` are not defined
        # in this snippet — confirm they exist on the full class.
        self.task_prompts = {
            "qa": self._qa_prompt,
            "troubleshooting": self._troubleshooting_prompt,
            "research": self._research_prompt,
            "comparison": self._comparison_prompt
        }
def detect_task(self, messages: list[dict]) -> str:
"""Detect the type of conversation task."""
prompt = f"""Classify this conversation into one category:
- qa: Simple question answering
- troubleshooting: Debugging or problem solving
- research: Exploring a topic in depth
- comparison: Comparing options or alternatives
Conversation: {messages[-3:]}
Category:"""
return self.llm.invoke(prompt).content.strip().lower()
    def reformulate(self, messages: list[dict]) -> str:
        """Build a task-appropriate reformulation using the detected category.

        Falls back to the QA prompt when the detected task is unrecognized.
        """
        task = self.detect_task(messages)
        prompt_fn = self.task_prompts.get(task, self._qa_prompt)
        return prompt_fn(messages)
    def _troubleshooting_prompt(self, messages: list[dict]) -> str:
        """For troubleshooting, include error details and steps tried."""
        # NOTE(review): `self._format` is not defined in this snippet —
        # confirm a transcript-formatting helper exists on the class.
        return f"""The user is troubleshooting an issue. Create a search query that includes:
1. The specific error or problem
2. The technology/system involved
3. What they've already tried
Conversation: {self._format(messages)}
Search query for finding the solution:"""
    def _research_prompt(self, messages: list[dict]) -> str:
        """For research, broaden the query so related concepts are retrieved too."""
        # NOTE(review): `self._format` is not defined in this snippet — confirm.
        return f"""The user is researching a topic. Create a query that:
1. Captures the main concept
2. Includes related subtopics mentioned
3. Would find comprehensive overview content
Conversation: {self._format(messages)}
Research query:"""

| Strategy | Latency | Quality | Best For |
|---|---|---|---|
| Latest Turn Only | Fastest | Poor for follow-ups | First message, complete questions |
| Concatenated History | Fast | Good | Short conversations, quick responses |
| LLM Condensation | +200-500ms | Excellent | Most production use cases |
| Dialogue State | Medium | Excellent | Complex multi-turn, entity tracking |
| Task-Optimized | Slower | Best | Specialized assistants |
class QueryEnhancementPipeline:
    """Complete query enhancement pipeline."""

    def __init__(self, llm, retriever):
        # `llm` must expose `.invoke()`; `retriever` `.search(query, k=...)`.
        self.llm = llm
        self.retriever = retriever
    def enhance(
        self,
        query: str,
        conversation_history: list | None = None,
        enable_hyde: bool = True,
        enable_multi_query: bool = True,
        enable_step_back: bool = False
    ) -> dict:
        """
        Apply query enhancement techniques and retrieve.

        Args:
            query: The raw user query.
            conversation_history: Optional prior messages; when given, a
                condensed standalone query is added to the search set.
            enable_hyde: Also search with a hypothetical generated answer.
            enable_multi_query: Also search with LLM-generated variations.
            enable_step_back: Also search with a broader step-back question.

        Returns dict with:
        - enhanced_queries: list of generated queries
        - results: retrieved documents (top 10 by score)
        """
        enhanced_queries = [query]
        # Handle conversation context
        if conversation_history:
            condensed = condense_conversation_to_query(
                conversation_history + [{"role": "user", "content": query}]
            )
            enhanced_queries.append(condensed)
        # Generate multi-queries
        if enable_multi_query:
            multi = generate_multi_queries(query, n=2)
            enhanced_queries.extend(multi)
        # Add step-back query
        if enable_step_back:
            step_back = step_back_query(query)
            enhanced_queries.append(step_back)
        # Collect all results; the first occurrence of a doc id wins.
        all_results = {}
        for q in enhanced_queries:
            results = self.retriever.search(q, k=5)
            for doc in results:
                if doc["id"] not in all_results:
                    all_results[doc["id"]] = doc
        # If HyDE enabled, also search with hypothetical doc
        if enable_hyde:
            hyde_results = hyde_retrieval(query, self.retriever, self.llm)
            for doc in hyde_results:
                if doc["id"] not in all_results:
                    all_results[doc["id"]] = doc
        # Rank all results (missing scores sort as 0)
        ranked = sorted(
            all_results.values(),
            key=lambda x: x.get("score", 0),
            reverse=True
        )
        return {
            "original_query": query,
            "enhanced_queries": enhanced_queries,
            "results": ranked[:10]
        }

def evaluate_enhancement(
    test_queries: list[dict], # [{query, relevant_doc_ids}]
    baseline_retriever,
    enhanced_retriever,
    k: int = 5
):
    """Compare retrieval with and without enhancement.

    Computes mean Recall@k and MRR over `test_queries` for both retrievers.
    NOTE(review): recall divides by len(relevant) — an empty relevant set
    raises ZeroDivisionError; confirm test items always have relevant docs.
    """
    baseline_metrics = {"recall": [], "mrr": []}
    enhanced_metrics = {"recall": [], "mrr": []}
    for item in test_queries:
        query = item["query"]
        relevant = set(item["relevant_doc_ids"])
        # Baseline retrieval
        baseline_results = baseline_retriever.search(query, k=k)
        baseline_ids = [r["id"] for r in baseline_results]
        # Enhanced retrieval
        enhanced_results = enhanced_retriever.search(query, k=k)
        enhanced_ids = [r["id"] for r in enhanced_results]
        # Compute metrics for both result lists the same way
        for results_ids, metrics in [
            (baseline_ids, baseline_metrics),
            (enhanced_ids, enhanced_metrics)
        ]:
            # Recall@k
            found = len(set(results_ids) & relevant)
            metrics["recall"].append(found / len(relevant))
            # MRR: reciprocal rank of the first relevant hit, else 0
            for rank, doc_id in enumerate(results_ids, 1):
                if doc_id in relevant:
                    metrics["mrr"].append(1 / rank)
                    break
            else:
                metrics["mrr"].append(0)
    return {
        "baseline": {
            "recall@k": sum(baseline_metrics["recall"]) / len(test_queries),
            "mrr": sum(baseline_metrics["mrr"]) / len(test_queries)
        },
        "enhanced": {
            "recall@k": sum(enhanced_metrics["recall"]) / len(test_queries),
            "mrr": sum(enhanced_metrics["mrr"]) / len(test_queries)
        }
    }

In the next lesson, we'll explore re-ranking and relevance scoring to improve precision in RAG systems.