Upload folder using huggingface_hub
- src/app/__pycache__/main.cpython-311.pyc +0 -0
- src/app/api/__pycache__/routes.cpython-311.pyc +0 -0
- src/app/api/routes.py +42 -0
- src/app/main.py +14 -0
- src/embeddings/__pycache__/embedder.cpython-311.pyc +0 -0
- src/embeddings/embedder.py +14 -0
- src/eval/__pycache__/hallucination.cpython-311.pyc +0 -0
- src/eval/__pycache__/relevancy.cpython-311.pyc +0 -0
- src/eval/__pycache__/retrieval_metrics.cpython-311.pyc +0 -0
- src/eval/hallucination.py +41 -0
- src/eval/relevancy.py +35 -0
- src/eval/retrieval_metrics.py +30 -0
- src/indexer/__pycache__/bm25_index.cpython-311.pyc +0 -0
- src/indexer/__pycache__/faiss_index.cpython-311.pyc +0 -0
- src/indexer/bm25_index.py +33 -0
- src/indexer/faiss_index.py +24 -0
- src/ingestion/__pycache__/chunkers.cpython-311.pyc +0 -0
- src/ingestion/__pycache__/cleaner.cpython-311.pyc +0 -0
- src/ingestion/__pycache__/readers.cpython-311.pyc +0 -0
- src/ingestion/chunkers.py +20 -0
- src/ingestion/cleaner.py +6 -0
- src/ingestion/ingest.py +83 -0
- src/ingestion/readers.py +18 -0
- src/llm/__pycache__/llm_client.cpython-311.pyc +0 -0
- src/llm/llm_client.py +56 -0
- src/pipeline/__pycache__/context_opt.cpython-311.pyc +0 -0
- src/pipeline/__pycache__/query_pipeline.cpython-311.pyc +0 -0
- src/pipeline/context_opt.py +55 -0
- src/pipeline/query_pipeline.py +97 -0
- src/reranker/__pycache__/cross_encoder.cpython-311.pyc +0 -0
- src/reranker/cross_encoder.py +19 -0
- src/retriever/__pycache__/hybrid_retriever.cpython-311.pyc +0 -0
- src/retriever/__pycache__/hyde.cpython-311.pyc +0 -0
- src/retriever/hybrid_retriever.py +71 -0
- src/retriever/hyde.py +21 -0
- src/ui/app.py +58 -0
src/app/__pycache__/main.cpython-311.pyc
ADDED
Binary file (920 Bytes)

src/app/api/__pycache__/routes.cpython-311.pyc
ADDED
Binary file (2.56 kB)
src/app/api/routes.py
ADDED
@@ -0,0 +1,42 @@
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel
+from typing import List, Optional
+from src.pipeline.query_pipeline import QueryPipeline
+
+router = APIRouter()
+_pipeline = None
+
+def get_pipeline():
+    global _pipeline
+    if _pipeline is None:
+        _pipeline = QueryPipeline()
+    return _pipeline
+
+class QueryRequest(BaseModel):
+    query: str
+    top_k_retrieval: Optional[int] = 20
+    top_k_rerank: Optional[int] = 5
+    use_hyde: Optional[bool] = False  # accepted in the request but not yet forwarded to the pipeline
+
+class DocResponse(BaseModel):
+    content: str
+    score: float
+
+class QueryResponse(BaseModel):
+    query: str
+    answer: str
+    context: List[tuple]
+
+@router.post("/chat", response_model=QueryResponse)
+async def chat(request: QueryRequest):
+    try:
+        pipe = get_pipeline()
+
+        result = pipe.run(
+            query=request.query,
+            top_k_retrieval=request.top_k_retrieval,
+            top_k_rerank=request.top_k_rerank
+        )
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
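For reference, a minimal Python client for this route (a sketch; it assumes the FastAPI app from src/app/main.py below is running locally, so the endpoint is http://localhost:8000/api/v1/chat):

import requests

payload = {
    "query": "What is the refund policy?",
    "top_k_retrieval": 20,
    "top_k_rerank": 5,
}
resp = requests.post("http://localhost:8000/api/v1/chat", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()
print(data["answer"])
for doc, score in data["context"]:          # context is a list of (chunk, rerank score) pairs
    print(f"{score:.3f}  {doc[:80]}")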
src/app/main.py
ADDED
@@ -0,0 +1,14 @@
+from fastapi import FastAPI
+from src.app.api import routes
+import uvicorn
+
+app = FastAPI(title="Enterprise RAG Search API")
+
+app.include_router(routes.router, prefix="/api/v1")
+
+@app.get("/health")
+def health():
+    return {"status": "ok"}
+
+if __name__ == "__main__":
+    uvicorn.run("src.app.main:app", host="0.0.0.0", port=8000, reload=True)
src/embeddings/__pycache__/embedder.cpython-311.pyc
ADDED
Binary file (1.3 kB)
src/embeddings/embedder.py
ADDED
@@ -0,0 +1,14 @@
+from sentence_transformers import SentenceTransformer
+import torch
+
+class Embedder:
+    def __init__(self, model_name: str = "all-MiniLM-L6-v2", device: str = None):
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        self.model = SentenceTransformer(model_name, device=self.device)
+
+    def embed(self, texts: list[str]):
+        return self.model.encode(texts, convert_to_numpy=True)
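A minimal usage sketch (sentence-transformers downloads all-MiniLM-L6-v2 on first use; the model produces 384-dimensional vectors):

from src.embeddings.embedder import Embedder

embedder = Embedder()                                  # defaults to all-MiniLM-L6-v2
vectors = embedder.embed(["refund policy", "shipping times"])
print(vectors.shape)                                   # (2, 384)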
src/eval/__pycache__/hallucination.cpython-311.pyc
ADDED
Binary file (1.89 kB)

src/eval/__pycache__/relevancy.cpython-311.pyc
ADDED
Binary file (2.32 kB)

src/eval/__pycache__/retrieval_metrics.cpython-311.pyc
ADDED
Binary file (1.8 kB)
src/eval/hallucination.py
ADDED
@@ -0,0 +1,41 @@
+from src.llm.llm_client import LLMClient
+
+class HallucinationGrader:
+    def __init__(self, llm_client: LLMClient):
+        self.llm = llm_client  # kept for interface parity; this grader is heuristic and never calls the LLM
+
+    def grade(self, context: str, answer: str) -> dict:
+        """
+        Returns a hallucination score based on token overlap between answer and context.
+        """
+        # 1. Check for refusal
+        if "not enough information" in answer.lower():
+            return {"score": 0.0, "grounded": True}
+
+        # 2. Key term overlap
+        # Normalize and tokenize
+        def tokenize(text):
+            import re
+            text = text.lower()
+            tokens = re.findall(r'\w+', text)
+            # Remove stopwords
+            stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'is', 'are', 'was', 'were'}
+            return set([t for t in tokens if t not in stop_words])
+
+        answer_tokens = tokenize(answer)
+        context_tokens = tokenize(context)
+
+        if not answer_tokens:
+            return {"score": 0.1, "grounded": True}  # Default for empty answer
+
+        # Calculate overlap
+        intersection = answer_tokens.intersection(context_tokens)
+        overlap_ratio = len(intersection) / len(answer_tokens)
+
+        # Rule: if overlap < 0.25 -> 1.0 (hallucination)
+        # Else -> 0.1 (grounded); 0.1 is used deliberately as the "grounded" floor
+
+        if overlap_ratio < 0.25:
+            return {"score": 1.0, "grounded": False}
+        else:
+            return {"score": 0.1, "grounded": True}
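To illustrate the overlap rule, a small sketch with made-up strings (the grader never calls the LLM, so it can be constructed without a client):

from src.eval.hallucination import HallucinationGrader

grader = HallucinationGrader(llm_client=None)          # heuristic path only
context = "Refunds are processed within 14 days of the return request."
grounded = grader.grade(context, "Refunds are processed within 14 days.")
ungrounded = grader.grade(context, "Shipping is free worldwide on all orders.")
print(grounded)    # {'score': 0.1, 'grounded': True}   -- full token overlap with the context
print(ungrounded)  # {'score': 1.0, 'grounded': False}  -- overlap ratio below 0.25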
src/eval/relevancy.py
ADDED
@@ -0,0 +1,35 @@
+from src.llm.llm_client import LLMClient
+
+class RelevancyGrader:
+    def __init__(self, llm_client: LLMClient):
+        self.llm = llm_client
+
+    def grade(self, query: str, answer: str) -> dict:
+        """
+        Returns score (0-1) on whether the answer addresses the query.
+        """
+        system_prompt = "You are a grader assessing if a generated answer is relevant to the user query."
+        user_prompt = f"""
+        User Query: {query}
+        Generated Answer: {answer}
+
+        Does the answer directly address the query?
+        Give a score between 0 and 1, and a boolean label (true/false).
+        Return JSON format: {{"score": 0.9, "relevant": true}}
+        """
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt}
+        ]
+
+        try:
+            response = self.llm.chat(messages, response_format={"type": "json_object"})
+            import json
+            data = json.loads(response)
+            # print(f"DEBUG_RELEVANCY_RAW: {response}")
+            return data
+        except Exception as e:
+            print(f"DEBUG_RELEVANCY_ERROR: {e}")
+            print(f"DEBUG_RELEVANCY_RESPONSE_WAS: {locals().get('response', 'Not generated')}")
+            return {"score": 0.5, "relevant": False, "error": str(e)}
src/eval/retrieval_metrics.py
ADDED
@@ -0,0 +1,30 @@
+from typing import List
+
+def mrr_score(relevant_doc_ids: List[str], retrieved_doc_ids: List[str]) -> float:
+    """Reciprocal rank of the first relevant hit (averaged over queries, this gives MRR)."""
+    for i, doc_id in enumerate(retrieved_doc_ids):
+        if doc_id in relevant_doc_ids:
+            return 1.0 / (i + 1)
+    return 0.0
+
+def recall_at_k(relevant_doc_ids: List[str], retrieved_doc_ids: List[str], k: int) -> float:
+    """Calculates Recall@K"""
+    retrieved_at_k = set(retrieved_doc_ids[:k])
+    relevant_set = set(relevant_doc_ids)
+
+    if not relevant_set:
+        return 0.0
+
+    hits = len(relevant_set.intersection(retrieved_at_k))
+    return hits / len(relevant_set)
+
+def precision_at_k(relevant_doc_ids: List[str], retrieved_doc_ids: List[str], k: int) -> float:
+    """Calculates Precision@K"""
+    retrieved_at_k = set(retrieved_doc_ids[:k])
+    relevant_set = set(relevant_doc_ids)
+
+    if not retrieved_at_k:
+        return 0.0
+
+    hits = len(relevant_set.intersection(retrieved_at_k))
+    return hits / len(retrieved_at_k)
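A small worked example with hypothetical document ids:

from src.eval.retrieval_metrics import mrr_score, recall_at_k, precision_at_k

relevant = ["d2", "d7"]
retrieved = ["d5", "d2", "d9", "d7", "d1"]

print(mrr_score(relevant, retrieved))          # 0.5    (first relevant hit at rank 2)
print(recall_at_k(relevant, retrieved, 3))     # 0.5    (1 of the 2 relevant docs in the top 3)
print(precision_at_k(relevant, retrieved, 3))  # 0.333  (1 of the 3 retrieved docs is relevant)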
src/indexer/__pycache__/bm25_index.cpython-311.pyc
ADDED
Binary file (2.94 kB)

src/indexer/__pycache__/faiss_index.cpython-311.pyc
ADDED
Binary file (2.15 kB)
src/indexer/bm25_index.py
ADDED
@@ -0,0 +1,33 @@
+import pickle
+from rank_bm25 import BM25Okapi
+from typing import List
+import os
+
+class BM25Index:
+    def __init__(self):
+        self.bm25 = None
+        self.corpus = []
+
+    def build(self, corpus: List[str]):
+        """
+        Builds the BM25 index from a list of documents/chunks.
+        """
+        self.corpus = corpus
+        tokenized_corpus = [doc.split(" ") for doc in corpus]
+        self.bm25 = BM25Okapi(tokenized_corpus)
+
+    def search(self, query: str, top_k: int = 10):
+        if not self.bm25:
+            raise ValueError("Index not built.")
+        tokenized_query = query.split(" ")
+        scores = self.bm25.get_scores(tokenized_query)
+        top_n = self.bm25.get_top_n(tokenized_query, self.corpus, n=top_k)
+        return top_n, scores
+
+    def save(self, path: str):
+        with open(path, 'wb') as f:
+            pickle.dump((self.bm25, self.corpus), f)
+
+    def load(self, path: str):
+        self.bm25, self.corpus = pickle.load(f := open(path, 'rb')) if False else pickle.load(open(path, 'rb'))
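A usage sketch (note that tokenization is a plain whitespace split, so matching is on exact terms):

from src.indexer.bm25_index import BM25Index

index = BM25Index()
index.build([
    "refunds are processed within 14 days",
    "standard shipping takes 3 to 5 business days",
])
top_docs, scores = index.search("refunds processed", top_k=1)
print(top_docs[0])   # the refunds chunk; `scores` holds BM25 scores for the whole corpus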
src/indexer/faiss_index.py
ADDED
@@ -0,0 +1,24 @@
+import faiss
+import numpy as np
+import os
+
+class FaissIndex:
+    def __init__(self, dimension: int):
+        self.dimension = dimension
+        self.index = faiss.IndexFlatL2(dimension)
+        # Store metadata/mapping if needed; for now this is just a flat L2 index
+
+    def add(self, embeddings: np.ndarray):
+        if embeddings.shape[1] != self.dimension:
+            raise ValueError(f"Embedding dimension mismatch. Expected {self.dimension}, got {embeddings.shape[1]}")
+        self.index.add(embeddings)
+
+    def search(self, query_embedding: np.ndarray, top_k: int = 10):
+        distances, indices = self.index.search(query_embedding, top_k)
+        return distances, indices
+
+    def save(self, path: str):
+        faiss.write_index(self.index, path)
+
+    def load(self, path: str):
+        self.index = faiss.read_index(path)
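A usage sketch with random vectors (FAISS expects float32 arrays; the dimension here is arbitrary):

import numpy as np
from src.indexer.faiss_index import FaissIndex

index = FaissIndex(dimension=4)
vectors = np.random.rand(10, 4).astype("float32")   # FAISS works on float32
index.add(vectors)
distances, indices = index.search(vectors[:1], top_k=3)
print(indices[0])   # nearest neighbours of the first vector (itself first, distance 0)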
src/ingestion/__pycache__/chunkers.cpython-311.pyc
ADDED
Binary file (1.94 kB)

src/ingestion/__pycache__/cleaner.cpython-311.pyc
ADDED
Binary file (495 Bytes)

src/ingestion/__pycache__/readers.cpython-311.pyc
ADDED
Binary file (1.82 kB)
src/ingestion/chunkers.py
ADDED
@@ -0,0 +1,20 @@
+from typing import List
+
+class Chunker:
+    def chunk(self, text: str) -> List[str]:
+        raise NotImplementedError
+
+class SlidingWindowChunker(Chunker):
+    def __init__(self, chunk_size: int = 512, overlap: int = 50):
+        self.chunk_size = chunk_size
+        self.overlap = overlap
+
+    def chunk(self, text: str) -> List[str]:
+        words = text.split()
+        chunks = []
+        for i in range(0, len(words), self.chunk_size - self.overlap):
+            chunk_words = words[i : i + self.chunk_size]
+            chunks.append(" ".join(chunk_words))
+            if i + self.chunk_size >= len(words):
+                break
+        return chunks
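A quick illustration with deliberately tiny window sizes so the overlap is visible (the ingestion defaults are 512 words with 50 words of overlap):

from src.ingestion.chunkers import SlidingWindowChunker

chunker = SlidingWindowChunker(chunk_size=5, overlap=2)
text = "one two three four five six seven eight nine ten"
for chunk in chunker.chunk(text):
    print(chunk)
# one two three four five
# four five six seven eight
# seven eight nine ten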
src/ingestion/cleaner.py
ADDED
@@ -0,0 +1,6 @@
+import re
+
+def clean_text(text: str) -> str:
+    # Remove excessive whitespace
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text
src/ingestion/ingest.py
ADDED
@@ -0,0 +1,83 @@
+import os
+import glob
+from pathlib import Path
+from typing import List
+import numpy as np
+from tqdm import tqdm
+
+from src.ingestion.readers import get_reader
+from src.ingestion.cleaner import clean_text
+from src.ingestion.chunkers import SlidingWindowChunker
+from src.embeddings.embedder import Embedder
+from src.indexer.bm25_index import BM25Index
+from src.indexer.faiss_index import FaissIndex
+
+DATA_DIR = "data"
+RAW_DIR = os.path.join(DATA_DIR, "raw")
+INDEX_DIR = os.path.join(DATA_DIR, "index")
+
+class IngestionPipeline:
+    def __init__(self):
+        self.chunker = SlidingWindowChunker()
+        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
+        self.bm25_index = BM25Index()
+        # Dimension for all-MiniLM-L6-v2 is 384
+        self.faiss_index = FaissIndex(dimension=384)
+
+    def run(self):
+        print("Starting ingestion...")
+        files = glob.glob(os.path.join(RAW_DIR, "*.*"))
+
+        all_chunks = []
+        doc_map = []  # To map chunk index back to metadata/content if needed
+
+        # 1. Read, clean, chunk
+        print("Processing files...")
+        for file_path in tqdm(files):
+            path = Path(file_path)
+            try:
+                reader = get_reader(path)
+                raw_text = reader.read(path)
+                cleaned_text = clean_text(raw_text)
+                chunks = self.chunker.chunk(cleaned_text)
+
+                for chunk in chunks:
+                    all_chunks.append(chunk)
+                    doc_map.append({"source": str(path), "content": chunk})
+            except Exception as e:
+                print(f"Error processing {path}: {e}")
+
+        print(f"Total chunks generated: {len(all_chunks)}")
+
+        # 2. Build BM25 index
+        print("Building BM25 Index...")
+        self.bm25_index.build(all_chunks)
+        os.makedirs(INDEX_DIR, exist_ok=True)
+        self.bm25_index.save(os.path.join(INDEX_DIR, "bm25.pkl"))
+
+        # 3. Embed and build FAISS index
+        if not os.getenv("DISABLE_FAISS"):
+            print("Embedding chunks and building FAISS Index...")
+            batch_size = 32
+            for i in range(0, len(all_chunks), batch_size):
+                batch = all_chunks[i : i + batch_size]
+                embeddings = self.embedder.embed(batch)
+                self.faiss_index.add(embeddings)
+
+            self.faiss_index.save(os.path.join(INDEX_DIR, "faiss.index"))
+        else:
+            print("Skipping FAISS build due to DISABLE_FAISS environment variable.")
+            # Create a dummy file to satisfy file existence checks if any (the index is lazily loaded)
+            with open(os.path.join(INDEX_DIR, "faiss.index"), "w") as f:
+                f.write("dummy")
+
+        # Save doc_map (simple persistence for retrieval lookup)
+        import pickle
+        with open(os.path.join(INDEX_DIR, "doc_map.pkl"), "wb") as f:
+            pickle.dump(doc_map, f)
+
+        print("Ingestion complete.")
+
+if __name__ == "__main__":
+    pipeline = IngestionPipeline()
+    pipeline.run()
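To build the indexes locally, a sketch of the entry point (assumes data/raw/ contains .txt or .md files and that it is run from the repository root so that src is importable):

# python -m src.ingestion.ingest          (module entry point), or equivalently:
from src.ingestion.ingest import IngestionPipeline

IngestionPipeline().run()
# Writes data/index/bm25.pkl, data/index/faiss.index and data/index/doc_map.pkl,
# the exact paths that HybridRetriever and QueryPipeline later load from.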
src/ingestion/readers.py
ADDED
@@ -0,0 +1,18 @@
+from pathlib import Path
+from typing import List, Dict
+
+class DocumentReader:
+    def read(self, file_path: Path) -> str:
+        raise NotImplementedError
+
+class TextReader(DocumentReader):
+    def read(self, file_path: Path) -> str:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read()
+
+def get_reader(file_path: Path) -> DocumentReader:
+    ext = file_path.suffix.lower()
+    if ext in ['.txt', '.md']:
+        return TextReader()
+    # Add PDF/Docx support here later
+    raise ValueError(f"Unsupported file type: {ext}")
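The trailing comment leaves room for more formats. As one possible extension, a hedged sketch of a PDF reader built on pypdf (an assumed extra dependency, not part of this upload):

from pathlib import Path
from pypdf import PdfReader                        # assumed dependency, not in current requirements
from src.ingestion.readers import DocumentReader

class PDFReader(DocumentReader):
    def read(self, file_path: Path) -> str:
        pdf = PdfReader(str(file_path))
        # extract_text() can return None for image-only pages, hence the `or ""`
        return "\n".join(page.extract_text() or "" for page in pdf.pages)

# get_reader() would then need a branch such as:  if ext == ".pdf": return PDFReader()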
src/llm/__pycache__/llm_client.cpython-311.pyc
ADDED
Binary file (4.14 kB)
src/llm/llm_client.py
ADDED
@@ -0,0 +1,56 @@
+import os
+import openai
+from typing import List, Dict, Any
+
+class LLMClient:
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        raise NotImplementedError
+
+class OpenAIClient(LLMClient):
+    def __init__(self, api_key: str = None, model: str = "gpt-4o"):
+        self.client = openai.OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
+        self.model = model
+
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **kwargs
+        )
+        return response.choices[0].message.content
+
+class VLLMClient(LLMClient):
+    def __init__(self, api_url: str = None, model: str = None):
+        self.api_url = api_url or os.getenv("VLLM_API_URL", "http://localhost:8000/v1")
+        self.model = model or os.getenv("VLLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
+        # vLLM is OpenAI compatible
+        self.client = openai.OpenAI(
+            base_url=self.api_url,
+            api_key="EMPTY"
+        )
+
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **kwargs
+        )
+        return response.choices[0].message.content
+
+class GroqClient(LLMClient):
+    def __init__(self, api_key: str = None, model: str = "llama-3.3-70b-versatile"):
+        self.api_key = api_key or os.getenv("GROQ_API_KEY")
+        self.model = model
+        self.client = openai.OpenAI(
+            base_url="https://api.groq.com/openai/v1",
+            api_key=self.api_key
+        )
+
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        # Default behavior
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **kwargs
+        )
+        return response.choices[0].message.content
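All three clients expose the same chat() interface, so callers can swap backends. A minimal usage sketch (assumes GROQ_API_KEY or OPENAI_API_KEY is set in the environment):

import os
from src.llm.llm_client import GroqClient, OpenAIClient

llm = GroqClient() if os.getenv("GROQ_API_KEY") else OpenAIClient()
reply = llm.chat([
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Say hello in five words."},
])
print(reply)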
src/pipeline/__pycache__/context_opt.cpython-311.pyc
ADDED
Binary file (2.72 kB)

src/pipeline/__pycache__/query_pipeline.cpython-311.pyc
ADDED
Binary file (4.48 kB)
src/pipeline/context_opt.py
ADDED
@@ -0,0 +1,55 @@
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+def maximal_marginal_relevance(query_embedding: np.ndarray, doc_embeddings: np.ndarray, lambda_mult: float = 0.5, top_k: int = 5):
+    """
+    Selects docs that are relevant to the query but diverse from each other.
+    """
+    if len(doc_embeddings) == 0:
+        return []
+
+    # Simple MMR implementation
+    selected_indices = []
+    candidate_indices = list(range(len(doc_embeddings)))
+
+    for _ in range(top_k):
+        best_score = -np.inf
+        best_idx = -1
+
+        for idx in candidate_indices:
+            # Relevance
+            rel_score = cosine_similarity(query_embedding.reshape(1, -1), doc_embeddings[idx].reshape(1, -1))[0][0]
+
+            # Diversity (similarity to already selected docs)
+            if not selected_indices:
+                div_score = 0
+            else:
+                sims = cosine_similarity(doc_embeddings[idx].reshape(1, -1), doc_embeddings[selected_indices])[0]
+                div_score = np.max(sims)
+
+            mmr_score = lambda_mult * rel_score - (1 - lambda_mult) * div_score
+
+            if mmr_score > best_score:
+                best_score = mmr_score
+                best_idx = idx
+
+        if best_idx != -1:
+            selected_indices.append(best_idx)
+            candidate_indices.remove(best_idx)
+
+    return selected_indices
+
+def deduplicate_docs(docs: list[dict], threshold: float = 0.95) -> list[dict]:
+    """
+    Remove duplicates by exact content match; the threshold parameter is
+    reserved for a future similarity-based version.
+    """
+    seen = set()
+    unique_docs = []
+    for doc in docs:
+        # doc may be a plain string or a dict with a 'content' key
+        content = doc if isinstance(doc, str) else doc.get('content', '')
+        if content not in seen:
+            seen.add(content)
+            unique_docs.append(doc)
+    return unique_docs
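A toy check of the MMR selection with made-up 2-D vectors; index 1 is a near-duplicate of index 0, so with lambda_mult=0.5 the second pick is the more diverse index 2 rather than the duplicate:

import numpy as np
from src.pipeline.context_opt import maximal_marginal_relevance

query = np.array([1.0, 0.0])
docs = np.array([
    [0.98, 0.17],    # highly relevant
    [0.98, 0.21],    # near-duplicate of the first
    [0.77, -0.64],   # relevant but pointing away from the first pick
])
print(maximal_marginal_relevance(query, docs, lambda_mult=0.5, top_k=2))
# [0, 2] -- the near-duplicate (index 1) is penalised by the diversity term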
src/pipeline/query_pipeline.py
ADDED
@@ -0,0 +1,97 @@
+import os
+from typing import Optional
+from src.retriever.hybrid_retriever import HybridRetriever
+from src.retriever.hyde import HyDERetriever
+from src.reranker.cross_encoder import Reranker
+from src.llm.llm_client import OpenAIClient, VLLMClient, GroqClient
+from src.embeddings.embedder import Embedder
+from src.pipeline.context_opt import deduplicate_docs
+
+class QueryPipeline:
+    def __init__(self, use_hyde: bool = False):
+        self.embedder = Embedder()
+        self.retriever = HybridRetriever(
+            bm25_path="data/index/bm25.pkl",
+            faiss_path="data/index/faiss.index",
+            doc_map_path="data/index/doc_map.pkl",
+            embedder=self.embedder
+        )
+
+        # LLM client strategy
+        if os.getenv("GROQ_API_KEY"):
+            self.llm = GroqClient()
+        elif os.getenv("VLLM_API_URL"):
+            self.llm = VLLMClient()
+        else:
+            self.llm = OpenAIClient()
+
+        if use_hyde:
+            self.retriever = HyDERetriever(self.llm, self.retriever)
+
+        self.reranker = Reranker()
+
+    def run(self, query: str, top_k_retrieval: int = 20, top_k_rerank: int = 5):
+        # 1. Retrieve
+        print(f"Retrieving for query: {query}")
+        retrieved_docs = self.retriever.search(query, top_k=top_k_retrieval)
+
+        # 2. Deduplicate
+        unique_docs = deduplicate_docs(retrieved_docs)
+
+        # 3. Rerank
+        # Reranker expects strings
+        doc_contents = [d if isinstance(d, str) else d['content'] for d in unique_docs]
+        reranked = self.reranker.rerank(query, doc_contents, top_k=top_k_rerank)
+
+        # 4. Generate
+
+        # Retrieval score gate
+        RETRIEVAL_SCORE_THRESHOLD = -4.0
+
+        # reranked is a list of (doc, score) tuples
+        if not reranked or reranked[0][1] < RETRIEVAL_SCORE_THRESHOLD:
+            return {
+                "query": query,
+                "answer": "I do not have enough information in the provided documents to answer this question.",
+                "context": [],
+                "retrieval_score": reranked[0][1] if reranked else -99.9,
+                "hallucination_score": 0.0,
+                "groundedness": 1.0
+            }
+
+        context_text = "\n\n".join([doc for doc, score in reranked])
+
+        SYSTEM_PROMPT = """
+        You are an enterprise-grade question answering system.
+
+        Rules:
+        1. Answer strictly using ONLY the provided context.
+        2. DO NOT use prior knowledge or assumptions.
+        3. If the answer is not explicitly stated in the context, respond EXACTLY with:
+        "I do not have enough information in the provided documents to answer this question."
+        4. Do not add explanations, guesses, or external facts.
+        """
+
+        user_prompt = f"""
+        {SYSTEM_PROMPT}
+
+        Context:
+        {context_text}
+
+        Question:
+        {query}
+        """
+
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": user_prompt}
+        ]
+
+        answer = self.llm.chat(messages)
+
+        return {
+            "query": query,
+            "answer": answer,
+            "context": reranked,
+            "retrieval_score": reranked[0][1]
+        }
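End to end, the pipeline can also be exercised without the API layer (a sketch; it assumes the indexes produced by src/ingestion/ingest.py exist under data/index/ and that one of the LLM environment variables is set):

from src.pipeline.query_pipeline import QueryPipeline

pipeline = QueryPipeline()            # pass use_hyde=True to route queries through HyDE
result = pipeline.run("What is the refund policy?", top_k_retrieval=20, top_k_rerank=5)
print(result["answer"])
print(result["retrieval_score"])      # top cross-encoder score; below -4.0 triggers the refusal path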
src/reranker/__pycache__/cross_encoder.cpython-311.pyc
ADDED
Binary file (1.81 kB)
src/reranker/cross_encoder.py
ADDED
@@ -0,0 +1,19 @@
+from sentence_transformers import CrossEncoder
+
+class Reranker:
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+        self.model = CrossEncoder(model_name)
+
+    def rerank(self, query: str, docs: list[str], top_k: int = 5):
+        if not docs:
+            return []
+
+        pairs = [[query, doc] for doc in docs]
+        scores = self.model.predict(pairs).tolist()
+
+        # Combine docs with scores
+        doc_scores = list(zip(docs, scores))
+        # Sort by score descending
+        doc_scores.sort(key=lambda x: x[1], reverse=True)
+
+        return doc_scores[:top_k]
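A usage sketch (the cross-encoder downloads on first use; its scores are raw logits rather than probabilities, which is why the query pipeline gates on a threshold of -4.0 instead of a 0-1 value):

from src.reranker.cross_encoder import Reranker

reranker = Reranker()
candidates = [
    "Standard shipping takes 3 to 5 business days.",
    "Refunds are processed within 14 days of the return request.",
]
for doc, score in reranker.rerank("How long do refunds take?", candidates, top_k=2):
    print(f"{score:+.2f}  {doc}")
# The refund sentence should score highest; unrelated passages typically score well below zero.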
src/retriever/__pycache__/hybrid_retriever.cpython-311.pyc
ADDED
Binary file (4.55 kB)

src/retriever/__pycache__/hyde.cpython-311.pyc
ADDED
Binary file (1.81 kB)
src/retriever/hybrid_retriever.py
ADDED
@@ -0,0 +1,71 @@
+from typing import List, Tuple
+import numpy as np
+from src.indexer.bm25_index import BM25Index
+from src.indexer.faiss_index import FaissIndex
+from src.embeddings.embedder import Embedder
+
+class HybridRetriever:
+    def __init__(self, bm25_path: str, faiss_path: str, doc_map_path: str, embedder: Embedder):
+        self.bm25 = BM25Index()
+        self.bm25.load(bm25_path)
+
+        self.embedder = embedder
+        # FAISS is optional; fall back to BM25-only retrieval when it is unavailable
+        self.faiss = None
+        import os
+        if not os.getenv("DISABLE_FAISS"):
+            try:
+                self.faiss = FaissIndex(dimension=384)  # adjust dimension if needed
+                self.faiss.load(faiss_path)
+                print("Successfully loaded FAISS index.")
+            except Exception as e:
+                print(f"WARNING: Could not load FAISS index ({e}). Running in BM25-only mode.")
+        else:
+            print("FAISS disabled via environment variable. Running in BM25-only mode.")
+
+        # Load doc map
+        import pickle
+        with open(doc_map_path, 'rb') as f:
+            self.doc_map = pickle.load(f)
+
+    def search(self, query: str, top_k: int = 10, alpha: float = 0.5) -> List[dict]:
+        """
+        Hybrid search using BM25 and dense embeddings.
+        alpha: reserved for weighted score fusion; ranking currently uses Reciprocal Rank Fusion (RRF).
+        """
+        # 1. BM25 search
+        # Scores would need normalization to combine directly; RRF is safer when scores are not calibrated,
+        # so we take the top N from each retriever and fuse ranks.
+
+        top_n = top_k * 2
+
+        # BM25
+        bm25_docs, bm25_scores = self.bm25.search(query, top_k=top_n)
+
+        scores = {}
+
+        # Process BM25
+        for rank, doc in enumerate(bm25_docs):
+            key = doc
+            scores[key] = scores.get(key, 0) + (1 / (60 + rank))
+
+        # Dense (only if FAISS is loaded)
+        if self.faiss:
+            try:
+                query_emb = self.embedder.embed([query])
+                dense_dists, dense_indices = self.faiss.search(query_emb, top_k=top_n)
+
+                # Merge using Reciprocal Rank Fusion (RRF)
+                # Dense indices refer to doc_map
+                for rank, idx in enumerate(dense_indices[0]):
+                    if idx == -1: continue
+                    doc_data = self.doc_map[idx]
+                    key = doc_data['content']
+                    scores[key] = scores.get(key, 0) + (1 / (60 + rank))
+            except Exception as e:
+                print(f"Error during dense search: {e}")
+
+        # Sort by RRF score
+        sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
+
+        return [doc for doc, score in sorted_docs[:top_k]]
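For reference, a tiny illustration of the Reciprocal Rank Fusion accumulation used above (ranks are 0-based and the common constant 60 is hard-coded):

# A document found 3rd by BM25 and 1st by the dense index accumulates two contributions
bm25_rank, dense_rank = 2, 0
rrf = 1 / (60 + bm25_rank) + 1 / (60 + dense_rank)
print(round(rrf, 4))   # 0.0328 -- outranks a document that only one retriever returned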
src/retriever/hyde.py
ADDED
@@ -0,0 +1,21 @@
+from src.llm.llm_client import LLMClient
+
+class HyDERetriever:
+    def __init__(self, llm_client: LLMClient, base_retriever):
+        self.llm = llm_client
+        self.retriever = base_retriever
+
+    def generate_hypothetical_doc(self, query: str) -> str:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant. Write a hypothetical answer to the user's question. Do not include any explanation, just the answer."},
+            {"role": "user", "content": query}
+        ]
+        return self.llm.chat(messages, temperature=0.7)
+
+    def search(self, query: str, top_k: int = 10):
+        # 1. Generate hypothetical doc
+        hypothetical_doc = self.generate_hypothetical_doc(query)
+        print(f"DEBUG: HyDE Doc: {hypothetical_doc[:100]}...")
+
+        # 2. Retrieve using the hypothetical doc as the query
+        return self.retriever.search(hypothetical_doc, top_k=top_k)
src/ui/app.py
ADDED
@@ -0,0 +1,58 @@
+import streamlit as st
+import requests
+import json
+print("UI Starting up...")
+
+st.set_page_config(page_title="Enterprise RAG Search", layout="wide")
+
+import os
+
+API_URL = os.getenv("API_URL", "http://localhost:8000/api/v1/chat")
+
+st.title("Enterprise RAG Search")
+
+with st.sidebar:
+    st.header("Configuration")
+    top_k_retrieval = st.slider("Retrieval Top-K", 5, 50, 20)
+    top_k_rerank = st.slider("Rerank Top-K", 1, 10, 5)
+    # use_hyde = st.checkbox("Use HyDE", value=False)
+
+query = st.chat_input("Enter your query...")
+
+if query:
+    st.session_state.messages = st.session_state.get("messages", [])
+    st.session_state.messages.append({"role": "user", "content": query})
+
+    # s = requests.Session()
+
+for msg in st.session_state.get("messages", []):
+    with st.chat_message(msg["role"]):
+        st.write(msg["content"])
+
+if query:
+    with st.chat_message("assistant"):
+        with st.spinner("Searching..."):
+            try:
+                payload = {
+                    "query": query,
+                    "top_k_retrieval": top_k_retrieval,
+                    "top_k_rerank": top_k_rerank,
+                    # "use_hyde": use_hyde
+                }
+                response = requests.post(API_URL, json=payload)
+                response.raise_for_status()
+                data = response.json()
+
+                answer = data["answer"]
+                st.write(answer)
+
+                with st.expander("View Context"):
+                    for i, (doc, score) in enumerate(data["context"]):
+                        st.markdown(f"**Relevance Score:** {score:.4f}")
+                        st.text(doc)
+                        st.divider()
+
+                st.session_state.messages.append({"role": "assistant", "content": answer})
+
+            except Exception as e:
+                st.error(f"Error: {e}")