# ssearch/archived/query_rewrite_hyde.py
# Eric e9fc99ddc6 Initial commit: RAG pipeline for semantic search over personal journal archive
# Vector search with cross-encoder re-ranking, hybrid BM25+vector retrieval,
# incremental index updates, and multiple LLM backends (Ollama local, OpenAI API).
# 2026-02-20 06:02:28 -05:00
#
# 126 lines
# 4.3 KiB
# Python

# query_rewrite_hyde.py
# Run a query on a vector store
#
# Latest experiment to include query rewriting using HyDE (Hypothetical Document Embeddings)
# The goal is to reduce the semantic gap between the query and the indexed documents
# This version implements a custom prompt and uses the build_exp.py vector store
# Based on query_exp.py
#
# E.M.F. July 2025
from llama_index.core import (
StorageContext,
load_index_from_storage,
ServiceContext,
Settings,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.core.prompts import PromptTemplate
from llama_index.core.indices.query.query_transform import HyDEQueryTransform
from llama_index.core.query_engine.transform_query_engine import TransformQueryEngine
import os
# Globals
# Disable HuggingFace tokenizers parallelism BEFORE the embedding model (and its
# tokenizer) is created below, so the setting is in effect from the start rather
# than only after the tokenizer has already been initialised.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Embedding model used in vector store (this should match the one in build_exp.py
# or equivalent, so query embeddings live in the same space as the stored ones)
# embed_model = HuggingFaceEmbedding(model_name="all-mpnet-base-v2")
embed_model = HuggingFaceEmbedding(cache_folder="./models", model_name="BAAI/bge-large-en-v1.5")
# LLM model (Ollama model name) used for the HyDE query transform and for generation
llm = "llama3.1:8B"
# Other models tried:
# llm = "deepseek-r1:8B"
# llm = "gemma3:1b"
# Custom prompt for the query engine (used as the text-QA template).
# {context_str} is filled with the retrieved snippets and {query_str} with the
# user's query. The file list asks for "up to 10" entries, consistent with the
# instructions above it: the 15 retrieved chunks may come from fewer than 10
# distinct files, and demanding exactly 10 would push the model to pad the list
# with loose or invented matches.
PROMPT = PromptTemplate(
    """You are an expert research assistant. You are given top-ranked writing excerpts (CONTEXT) and a user's QUERY.
Instructions:
- Base your response *only* on the CONTEXT.
- The snippets are ordered from most to least relevant—prioritize insights from earlier (higher-ranked) snippets.
- Aim to reference *as many distinct* relevant files as possible (up to 10).
- Do not invent or generalize; refer to specific passages or facts only.
- If a passage only loosely matches, deprioritize it.
Format your answer in two parts:
1. **Summary Theme**
Summarize the dominant theme from the relevant context in a few sentences.
2. **Matching Files**
Make a list of up to 10 matching files. The format for each should be:
<filename> — <rationale tied to content. Include date or section hints if available.>
CONTEXT:
{context_str}
QUERY:
{query_str}
Now provide the theme and list of matching files."""
)
#
# Main program routine
#
def main():
    """Run an interactive HyDE-augmented search loop over the persisted index.

    Sets the global llama_index ``Settings`` (local Ollama LLM and the
    module-level embedding model), loads the index persisted in
    ``./storage_exp``, wraps a custom-prompt query engine in a HyDE query
    transform, then repeatedly reads a query from stdin and prints the answer
    followed by each source document's file name, file path and score.

    The loop ends when the user types 'exit' or 'quit' (any case), or on
    end-of-input / Ctrl-C at the prompt. Blank input is ignored.
    """
    # Use a local model (served by Ollama) for the HyDE rewrite and the answer
    Settings.llm = Ollama(
        model=llm,
        request_timeout=360.0,
        context_window=8000,
    )
    # Load embedding model (same as used for vector store)
    Settings.embed_model = embed_model

    # Load persisted vector store + metadata
    storage_context = StorageContext.from_defaults(persist_dir="./storage_exp")
    index = load_index_from_storage(storage_context)

    # Build regular query engine with custom prompt
    base_query_engine = index.as_query_engine(
        similarity_top_k=15,  # pull wide
        text_qa_template=PROMPT,  # custom prompt
        # Options tried:
        # response_mode="compact",  # concise synthesis
        # node_postprocessors=[  # needs: from llama_index.core.postprocessor import SimilarityPostprocessor
        #     SimilarityPostprocessor(similarity_cutoff=0.75),  # keep strong hits; makes result count flexible
        # ],
    )

    # HyDE ("Hypothetical Document Embeddings") generates a hypothetical
    # document from the query and uses it to augment the query.
    # The original query is included as well: include_original=True gave
    # better similarity values in testing.
    hyde_transform = HyDEQueryTransform(llm=Settings.llm, include_original=True)
    # Built once, outside the loop: it is created from the same arguments for
    # every query, so rebuilding it per iteration was wasted work.
    query_engine = TransformQueryEngine(base_query_engine, query_transform=hyde_transform)

    while True:
        try:
            q = input("\nEnter a search topic or question (or 'exit'): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C at the prompt: exit cleanly instead of with a traceback
            print()
            break
        if q.lower() in ("exit", "quit"):
            break
        if not q:
            # Nothing to search for; skip the slow HyDE + LLM round-trip
            continue
        print()
        # Rewrites the query with HyDE, performs the similarity search,
        # then applies the prompt to generate the response
        response = query_engine.query(q)
        # Print the query response and source documents
        print(response.response)
        print("\nSource documents:")
        for node in response.source_nodes:
            meta = getattr(node, "metadata", None) or node.node.metadata
            print(meta.get("file_name"), "---", meta.get("file_path"), getattr(node, "score", None))


if __name__ == "__main__":
    main()