# Vector search with cross-encoder re-ranking, hybrid BM25+vector retrieval,
# incremental index updates, and multiple LLM backends (Ollama local, OpenAI API).
#!/usr/bin/env python3
"""
query_topk_prompt_engine.py

Query a vector store with a custom prompt for research assistance.
Uses BAAI/bge-large-en-v1.5 embeddings and Ollama for generation.

E.M.F. January 2026
Using Claude Sonnet 4.5 to suggest changes
"""
|
import argparse
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from llama_index.core import (
|
|
Settings,
|
|
StorageContext,
|
|
load_index_from_storage,
|
|
)
|
|
from llama_index.core.prompts import PromptTemplate
|
|
from llama_index.core.postprocessor import SimilarityPostprocessor
|
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
from llama_index.llms.ollama import Ollama
|
|
|
|
|
|
# Suppress tokenizer parallelism warnings
# (HuggingFace tokenizers warn about fork-safety when this is unset).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Configuration defaults
DEFAULT_LLM = "command-r7b"  # Ollama model tag used for generation
DEFAULT_EMBED_MODEL = "BAAI/bge-large-en-v1.5"  # HuggingFace embedding model name
DEFAULT_STORAGE_DIR = "./storage_exp"  # persisted llama_index vector store directory
DEFAULT_TOP_K = 15  # number of nodes retrieved per query
DEFAULT_SIMILARITY_CUTOFF = 0.7  # Set to None to disable
def get_prompt_template(max_files: int = 10) -> PromptTemplate:
    """Return the custom prompt template for the query engine.

    Args:
        max_files: Upper bound on the number of distinct files the LLM is
            asked to cite in its answer; interpolated into the prompt text.

    Returns:
        A PromptTemplate whose ``context_str`` and ``query_str`` variables
        are left for llama_index to fill in at query time.
    """
    # NOTE: this is an f-string, so {max_files} is substituted now, while the
    # doubled braces ({{context_str}}, {{query_str}}) survive as single-brace
    # template variables for PromptTemplate to resolve later.
    return PromptTemplate(
        f"""You are an expert research assistant. You are given top-ranked writing excerpts (CONTEXT) and a user's QUERY.

Instructions:
- Base your response *only* on the CONTEXT.
- The snippets are ordered from most to least relevant—prioritize insights from earlier (higher-ranked) snippets.
- Aim to reference *as many distinct* relevant files as possible (up to {max_files}).
- Do not invent or generalize; refer to specific passages or facts only.
- If a passage only loosely matches, deprioritize it.

Format your answer in two parts:

1. **Summary Theme**
Summarize the dominant theme from the relevant context in a few sentences.

2. **Matching Files**
List up to {max_files} matching files. Format each as:
<filename> - <rationale tied to content. Include date or section hints if available.>

CONTEXT:
{{context_str}}

QUERY:
{{query_str}}

Now provide the theme and list of matching files."""
    )
def load_models(
    llm_name: str = DEFAULT_LLM,
    embed_model_name: str = DEFAULT_EMBED_MODEL,
    cache_folder: str = "./models",
    request_timeout: float = 360.0,
    context_window: int = 8000,
):
    """Configure the global llama_index ``Settings`` with an LLM and embedder.

    Args:
        llm_name: Ollama model tag to use for generation.
        embed_model_name: HuggingFace embedding model name.
        cache_folder: Local directory holding the cached embedding model.
        request_timeout: Seconds to wait for an Ollama response.
        context_window: Token context window passed to the Ollama LLM.

    Side effects:
        Mutates ``Settings.llm`` and ``Settings.embed_model``; returns None.
    """
    llm = Ollama(
        model=llm_name,
        request_timeout=request_timeout,
        context_window=context_window,
    )
    Settings.llm = llm

    # local_files_only avoids any network download — the embedding model
    # must already be present under cache_folder.
    embedder = HuggingFaceEmbedding(
        cache_folder=cache_folder,
        model_name=embed_model_name,
        local_files_only=True,
    )
    Settings.embed_model = embedder
def load_query_engine(
    storage_dir: str = DEFAULT_STORAGE_DIR,
    top_k: int = DEFAULT_TOP_K,
    similarity_cutoff: float | None = DEFAULT_SIMILARITY_CUTOFF,
    max_files: int = 10,
):
    """Load the persisted vector index and wrap it in a query engine.

    Args:
        storage_dir: Directory containing the persisted llama_index store.
        top_k: Number of most-similar nodes to retrieve per query.
        similarity_cutoff: Minimum similarity score to keep a node, or
            None to disable filtering.
        max_files: Forwarded to the prompt template's file-listing limit.

    Returns:
        A query engine using the custom research-assistant prompt.

    Raises:
        FileNotFoundError: If ``storage_dir`` does not exist.
    """
    path = Path(storage_dir)
    if not path.exists():
        raise FileNotFoundError(f"Storage directory not found: {storage_dir}")

    index = load_index_from_storage(
        StorageContext.from_defaults(persist_dir=str(path))
    )

    # Attach a similarity filter only when a cutoff was requested; passing
    # None (rather than an empty list) matches the engine's default.
    filters = (
        [SimilarityPostprocessor(similarity_cutoff=similarity_cutoff)]
        if similarity_cutoff is not None
        else None
    )

    return index.as_query_engine(
        similarity_top_k=top_k,
        text_qa_template=get_prompt_template(max_files),
        node_postprocessors=filters,
    )
def get_node_metadata(node) -> dict:
    """Safely extract metadata from a source node.

    llama_index may hand back either a bare node or a NodeWithScore-style
    wrapper, so both ``node.metadata`` and ``node.node.metadata`` are probed.

    Args:
        node: A source node (or wrapper) from a query response.

    Returns:
        The node's metadata dict, or {} when none is found. Always a dict:
        the original returned ``node.node.metadata`` unconditionally, which
        could be None and break callers using ``meta.get(...)``; a falsy
        metadata attribute is now normalized to {}.
    """
    meta = getattr(node, "metadata", None)
    if meta:
        return meta
    inner = getattr(node, "node", None)
    if inner is not None:
        # Normalize a missing or None metadata attribute to an empty dict.
        return getattr(inner, "metadata", None) or {}
    return {}
def print_results(response):
    """Pretty-print the LLM response followed by its ranked source documents."""
    divider = "=" * 60

    print("\n" + divider)
    print("RESPONSE")
    print(divider + "\n")
    print(response.response)

    print("\n" + divider)
    print("SOURCE DOCUMENTS")
    print(divider + "\n")

    # Sources arrive ranked; show a 1-based rank, score, name, and path.
    for rank, node in enumerate(response.source_nodes, start=1):
        meta = get_node_metadata(node)
        score = getattr(node, "score", None)
        shown_score = "N/A" if score is None else f"{score:.3f}"
        print(f"{rank:2}. [{shown_score}] {meta.get('file_name', 'Unknown')}")
        print(f"    Path: {meta.get('file_path', 'Unknown')}")
def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with attributes: query (list[str]), llm (str),
        storage_dir (str), top_k (int), similarity_cutoff (float),
        max_files (int).
    """
    parser = argparse.ArgumentParser(
        description="Query a vector store with a custom research assistant prompt.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python query_topk_prompt_engine.py "What themes appear in the documents?"
python query_topk_prompt_engine.py --top-k 20 --llm llama3.1:8B "Find references to machine learning"
""",
    )
    # nargs="+" lets the query be given unquoted on the shell; main() re-joins
    # the words with spaces.
    parser.add_argument("query", nargs="+", help="The query text")
    parser.add_argument(
        "--llm",
        default=DEFAULT_LLM,
        help=f"Ollama model to use for generation (default: {DEFAULT_LLM})",
    )
    parser.add_argument(
        "--storage-dir",
        default=DEFAULT_STORAGE_DIR,
        help=f"Path to the vector store (default: {DEFAULT_STORAGE_DIR})",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=DEFAULT_TOP_K,
        help=f"Number of similar documents to retrieve (default: {DEFAULT_TOP_K})",
    )
    # A value of 0 (or below) is translated to "filtering disabled" in main().
    parser.add_argument(
        "--similarity-cutoff",
        type=float,
        default=DEFAULT_SIMILARITY_CUTOFF,
        help=f"Minimum similarity score (default: {DEFAULT_SIMILARITY_CUTOFF}, use 0 to disable)",
    )
    parser.add_argument(
        "--max-files",
        type=int,
        default=10,
        help="Maximum files to list in response (default: 10)",
    )
    return parser.parse_args()
def main():
    """CLI entry point: load models and index, run the query, print results."""
    args = parse_args()

    # A cutoff of 0 (or below) means "no similarity filtering".
    cutoff = args.similarity_cutoff if args.similarity_cutoff > 0 else None

    try:
        print(f"Loading models (LLM: {args.llm})...")
        load_models(llm_name=args.llm)

        print(f"Loading index from {args.storage_dir}...")
        engine = load_query_engine(
            storage_dir=args.storage_dir,
            top_k=args.top_k,
            similarity_cutoff=cutoff,
            max_files=args.max_files,
        )

        question = " ".join(args.query)
        # Truncate long queries in the status line only; the full text is
        # still sent to the engine.
        suffix = "..." if len(question) > 100 else ""
        print(f"Querying: {question[:100]}{suffix}")

        print_results(engine.query(question))

    except FileNotFoundError as e:
        # Missing storage directory: report cleanly and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        # Unexpected failures: report, then re-raise with full traceback.
        print(f"Error during query: {e}", file=sys.stderr)
        raise
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()