yuvis committed on
Commit f4c70c8 · verified · 1 Parent(s): 2a3a1de

Upload folder using huggingface_hub

Files changed (36)
  1. src/app/__pycache__/main.cpython-311.pyc +0 -0
  2. src/app/api/__pycache__/routes.cpython-311.pyc +0 -0
  3. src/app/api/routes.py +42 -0
  4. src/app/main.py +14 -0
  5. src/embeddings/__pycache__/embedder.cpython-311.pyc +0 -0
  6. src/embeddings/embedder.py +14 -0
  7. src/eval/__pycache__/hallucination.cpython-311.pyc +0 -0
  8. src/eval/__pycache__/relevancy.cpython-311.pyc +0 -0
  9. src/eval/__pycache__/retrieval_metrics.cpython-311.pyc +0 -0
  10. src/eval/hallucination.py +41 -0
  11. src/eval/relevancy.py +35 -0
  12. src/eval/retrieval_metrics.py +30 -0
  13. src/indexer/__pycache__/bm25_index.cpython-311.pyc +0 -0
  14. src/indexer/__pycache__/faiss_index.cpython-311.pyc +0 -0
  15. src/indexer/bm25_index.py +33 -0
  16. src/indexer/faiss_index.py +24 -0
  17. src/ingestion/__pycache__/chunkers.cpython-311.pyc +0 -0
  18. src/ingestion/__pycache__/cleaner.cpython-311.pyc +0 -0
  19. src/ingestion/__pycache__/readers.cpython-311.pyc +0 -0
  20. src/ingestion/chunkers.py +20 -0
  21. src/ingestion/cleaner.py +6 -0
  22. src/ingestion/ingest.py +83 -0
  23. src/ingestion/readers.py +18 -0
  24. src/llm/__pycache__/llm_client.cpython-311.pyc +0 -0
  25. src/llm/llm_client.py +56 -0
  26. src/pipeline/__pycache__/context_opt.cpython-311.pyc +0 -0
  27. src/pipeline/__pycache__/query_pipeline.cpython-311.pyc +0 -0
  28. src/pipeline/context_opt.py +55 -0
  29. src/pipeline/query_pipeline.py +97 -0
  30. src/reranker/__pycache__/cross_encoder.cpython-311.pyc +0 -0
  31. src/reranker/cross_encoder.py +19 -0
  32. src/retriever/__pycache__/hybrid_retriever.cpython-311.pyc +0 -0
  33. src/retriever/__pycache__/hyde.cpython-311.pyc +0 -0
  34. src/retriever/hybrid_retriever.py +71 -0
  35. src/retriever/hyde.py +21 -0
  36. src/ui/app.py +58 -0
src/app/__pycache__/main.cpython-311.pyc ADDED
Binary file (920 Bytes).
 
src/app/api/__pycache__/routes.cpython-311.pyc ADDED
Binary file (2.56 kB).
 
src/app/api/routes.py ADDED
@@ -0,0 +1,42 @@
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel
+from typing import List, Optional
+from src.pipeline.query_pipeline import QueryPipeline
+
+router = APIRouter()
+_pipeline = None
+
+def get_pipeline():
+    global _pipeline
+    if _pipeline is None:
+        _pipeline = QueryPipeline()
+    return _pipeline
+
+class QueryRequest(BaseModel):
+    query: str
+    top_k_retrieval: Optional[int] = 20
+    top_k_rerank: Optional[int] = 5
+    use_hyde: Optional[bool] = False
+
+class DocResponse(BaseModel):
+    content: str
+    score: float
+
+class QueryResponse(BaseModel):
+    query: str
+    answer: str
+    context: List[tuple]
+
+@router.post("/chat", response_model=QueryResponse)
+async def chat(request: QueryRequest):
+    try:
+        pipe = get_pipeline()
+
+        result = pipe.run(
+            query=request.query,
+            top_k_retrieval=request.top_k_retrieval,
+            top_k_rerank=request.top_k_rerank
+        )
+        return result
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
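For reference, a minimal client-side sketch of calling this route (assuming the app from src/app/main.py is running locally on port 8000 with the /api/v1 prefix; the query text is only an example):

import requests

payload = {"query": "What is the refund policy?", "top_k_retrieval": 20, "top_k_rerank": 5}
resp = requests.post("http://localhost:8000/api/v1/chat", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["answer"])
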
src/app/main.py ADDED
@@ -0,0 +1,14 @@
+from fastapi import FastAPI
+from src.app.api import routes
+import uvicorn
+
+app = FastAPI(title="Enterprise RAG Search API")
+
+app.include_router(routes.router, prefix="/api/v1")
+
+@app.get("/health")
+def health():
+    return {"status": "ok"}
+
+if __name__ == "__main__":
+    uvicorn.run("src.app.main:app", host="0.0.0.0", port=8000, reload=True)
src/embeddings/__pycache__/embedder.cpython-311.pyc ADDED
Binary file (1.3 kB).
 
src/embeddings/embedder.py ADDED
@@ -0,0 +1,14 @@
+from sentence_transformers import SentenceTransformer
+import torch
+
+class Embedder:
+    def __init__(self, model_name: str = "all-MiniLM-L6-v2", device: str = None):
+        if device is None:
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        else:
+            self.device = device
+
+        self.model = SentenceTransformer(model_name, device=self.device)
+
+    def embed(self, texts: list[str]):
+        return self.model.encode(texts, convert_to_numpy=True)
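A quick usage sketch, assuming sentence-transformers is installed and the default MiniLM checkpoint can be downloaded:

from src.embeddings.embedder import Embedder

embedder = Embedder()  # defaults to all-MiniLM-L6-v2, auto-selects CUDA when available
vectors = embedder.embed(["hello world", "enterprise search"])
print(vectors.shape)  # (2, 384) -- 384-dimensional embeddings for this model
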
src/eval/__pycache__/hallucination.cpython-311.pyc ADDED
Binary file (1.89 kB).
 
src/eval/__pycache__/relevancy.cpython-311.pyc ADDED
Binary file (2.32 kB).
 
src/eval/__pycache__/retrieval_metrics.cpython-311.pyc ADDED
Binary file (1.8 kB).
 
src/eval/hallucination.py ADDED
@@ -0,0 +1,41 @@
+from src.llm.llm_client import LLMClient
+
+class HallucinationGrader:
+    def __init__(self, llm_client: LLMClient):
+        self.llm = llm_client
+
+    def grade(self, context: str, answer: str) -> dict:
+        """
+        Returns hallucination score based on token overlap.
+        """
+        # 1. Check for refusal
+        if "not enough information" in answer.lower():
+            return {"score": 0.0, "grounded": True}
+
+        # 2. Key term overlap
+        # Normalize and tokenize
+        def tokenize(text):
+            import re
+            text = text.lower()
+            tokens = re.findall(r'\w+', text)
+            # Remove stopwords
+            stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'is', 'are', 'was', 'were'}
+            return set([t for t in tokens if t not in stop_words])
+
+        answer_tokens = tokenize(answer)
+        context_tokens = tokenize(context)
+
+        if not answer_tokens:
+            return {"score": 0.1, "grounded": True}  # Default for empty answer
+
+        # Calculate overlap
+        intersection = answer_tokens.intersection(context_tokens)
+        overlap_ratio = len(intersection) / len(answer_tokens)
+
+        # Rule: if overlap < 0.25 -> 1.0 (hallucination)
+        # else -> 0.1 (grounded); 0.1 rather than 0.0 is intentional
+
+        if overlap_ratio < 0.25:
+            return {"score": 1.0, "grounded": False}
+        else:
+            return {"score": 0.1, "grounded": True}
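A small illustration of the overlap rule above (the heuristic never calls the LLM, so None is passed for the client; the example strings are made up):

from src.eval.hallucination import HallucinationGrader

grader = HallucinationGrader(llm_client=None)  # llm_client is unused by the overlap heuristic
context = "The warranty covers parts and labor for two years."
print(grader.grade(context, "The warranty covers parts for two years."))   # {'score': 0.1, 'grounded': True}
print(grader.grade(context, "Shipping is free worldwide on all orders."))  # {'score': 1.0, 'grounded': False}
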
src/eval/relevancy.py ADDED
@@ -0,0 +1,35 @@
+from src.llm.llm_client import LLMClient
+
+class RelevancyGrader:
+    def __init__(self, llm_client: LLMClient):
+        self.llm = llm_client
+
+    def grade(self, query: str, answer: str) -> dict:
+        """
+        Returns score (0-1) on whether the answer addresses the query.
+        """
+        system_prompt = "You are a grader assessing if a generated answer is relevant to the user query."
+        user_prompt = f"""
+        User Query: {query}
+        Generated Answer: {answer}
+
+        Does the answer directly address the query?
+        Give a score between 0 and 1, and a boolean label (true/false).
+        Return JSON format: {{"score": 0.9, "relevant": true}}
+        """
+
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt}
+        ]
+
+        try:
+            response = self.llm.chat(messages, response_format={"type": "json_object"})
+            import json
+            data = json.loads(response)
+            # print(f"DEBUG_RELEVANCY_RAW: {response}")
+            return data
+        except Exception as e:
+            print(f"DEBUG_RELEVANCY_ERROR: {e}")
+            print(f"DEBUG_RELEVANCY_RESPONSE_WAS: {locals().get('response', 'Not generated')}")
+            return {"score": 0.5, "relevant": False, "error": str(e)}
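Used together with one of the clients from src/llm/llm_client.py; a rough sketch, assuming a GROQ_API_KEY is set and the model returns well-formed JSON for the json_object response format:

from src.llm.llm_client import GroqClient
from src.eval.relevancy import RelevancyGrader

grader = RelevancyGrader(GroqClient())
result = grader.grade(query="What is our refund window?", answer="Refunds are accepted within 30 days.")
print(result)  # e.g. {'score': 0.9, 'relevant': True}
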
src/eval/retrieval_metrics.py ADDED
@@ -0,0 +1,30 @@
+from typing import List
+
+def mrr_score(relevant_doc_ids: List[str], retrieved_doc_ids: List[str]) -> float:
+    """Reciprocal rank of the first relevant hit (averaging over queries gives MRR)."""
+    for i, doc_id in enumerate(retrieved_doc_ids):
+        if doc_id in relevant_doc_ids:
+            return 1.0 / (i + 1)
+    return 0.0
+
+def recall_at_k(relevant_doc_ids: List[str], retrieved_doc_ids: List[str], k: int) -> float:
+    """Calculates Recall@K"""
+    retrieved_at_k = set(retrieved_doc_ids[:k])
+    relevant_set = set(relevant_doc_ids)
+
+    if not relevant_set:
+        return 0.0
+
+    hits = len(relevant_set.intersection(retrieved_at_k))
+    return hits / len(relevant_set)
+
+def precision_at_k(relevant_doc_ids: List[str], retrieved_doc_ids: List[str], k: int) -> float:
+    """Calculates Precision@K"""
+    retrieved_at_k = set(retrieved_doc_ids[:k])
+    relevant_set = set(relevant_doc_ids)
+
+    if not retrieved_at_k:
+        return 0.0
+
+    hits = len(relevant_set.intersection(retrieved_at_k))
+    return hits / len(retrieved_at_k)
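A quick worked check of these helpers with hypothetical document IDs:

from src.eval.retrieval_metrics import mrr_score, recall_at_k, precision_at_k

relevant = ["d2", "d7"]
retrieved = ["d1", "d2", "d3", "d7", "d9"]

print(mrr_score(relevant, retrieved))          # 0.5   -- first hit is at rank 2
print(recall_at_k(relevant, retrieved, 3))     # 0.5   -- one of the two relevant docs is in the top 3
print(precision_at_k(relevant, retrieved, 3))  # 0.333... -- one hit among the three retrieved
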
src/indexer/__pycache__/bm25_index.cpython-311.pyc ADDED
Binary file (2.94 kB).
 
src/indexer/__pycache__/faiss_index.cpython-311.pyc ADDED
Binary file (2.15 kB).
 
src/indexer/bm25_index.py ADDED
@@ -0,0 +1,33 @@
+import pickle
+from rank_bm25 import BM25Okapi
+from typing import List
+import os
+
+class BM25Index:
+    def __init__(self):
+        self.bm25 = None
+        self.corpus = []
+
+    def build(self, corpus: List[str]):
+        """
+        Builds the BM25 index from a list of documents/chunks.
+        """
+        self.corpus = corpus
+        tokenized_corpus = [doc.split(" ") for doc in corpus]
+        self.bm25 = BM25Okapi(tokenized_corpus)
+
+    def search(self, query: str, top_k: int = 10):
+        if not self.bm25:
+            raise ValueError("Index not built.")
+        tokenized_query = query.split(" ")
+        scores = self.bm25.get_scores(tokenized_query)
+        top_n = self.bm25.get_top_n(tokenized_query, self.corpus, n=top_k)
+        return top_n, scores
+
+    def save(self, path: str):
+        with open(path, 'wb') as f:
+            pickle.dump((self.bm25, self.corpus), f)
+
+    def load(self, path: str):
+        with open(path, 'rb') as f:
+            self.bm25, self.corpus = pickle.load(f)
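A minimal build-and-search sketch over a toy corpus, assuming rank_bm25 is installed:

from src.indexer.bm25_index import BM25Index

index = BM25Index()
index.build(["the cat sat on the mat", "dogs chase cats", "faiss is a vector index"])
docs, scores = index.search("vector index", top_k=2)
print(docs)  # top-2 chunks ranked by BM25 score
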
src/indexer/faiss_index.py ADDED
@@ -0,0 +1,24 @@
+import faiss
+import numpy as np
+import os
+
+class FaissIndex:
+    def __init__(self, dimension: int):
+        self.dimension = dimension
+        self.index = faiss.IndexFlatL2(dimension)
+        # Store metadata/mapping if needed; for now this is just a simple flat index
+
+    def add(self, embeddings: np.ndarray):
+        if embeddings.shape[1] != self.dimension:
+            raise ValueError(f"Embedding dimension mismatch. Expected {self.dimension}, got {embeddings.shape[1]}")
+        self.index.add(embeddings)
+
+    def search(self, query_embedding: np.ndarray, top_k: int = 10):
+        distances, indices = self.index.search(query_embedding, top_k)
+        return distances, indices
+
+    def save(self, path: str):
+        faiss.write_index(self.index, path)
+
+    def load(self, path: str):
+        self.index = faiss.read_index(path)
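A small sketch of adding and querying vectors, using random 384-dimensional embeddings to match the MiniLM model used elsewhere in this commit:

import numpy as np
from src.indexer.faiss_index import FaissIndex

index = FaissIndex(dimension=384)
index.add(np.random.rand(100, 384).astype("float32"))  # faiss expects float32
distances, indices = index.search(np.random.rand(1, 384).astype("float32"), top_k=5)
print(indices[0])  # positions of the 5 nearest vectors by L2 distance
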
src/ingestion/__pycache__/chunkers.cpython-311.pyc ADDED
Binary file (1.94 kB).
 
src/ingestion/__pycache__/cleaner.cpython-311.pyc ADDED
Binary file (495 Bytes).
 
src/ingestion/__pycache__/readers.cpython-311.pyc ADDED
Binary file (1.82 kB).
 
src/ingestion/chunkers.py ADDED
@@ -0,0 +1,20 @@
+from typing import List
+
+class Chunker:
+    def chunk(self, text: str) -> List[str]:
+        raise NotImplementedError
+
+class SlidingWindowChunker(Chunker):
+    def __init__(self, chunk_size: int = 512, overlap: int = 50):
+        self.chunk_size = chunk_size
+        self.overlap = overlap
+
+    def chunk(self, text: str) -> List[str]:
+        words = text.split()
+        chunks = []
+        for i in range(0, len(words), self.chunk_size - self.overlap):
+            chunk_words = words[i : i + self.chunk_size]
+            chunks.append(" ".join(chunk_words))
+            if i + self.chunk_size >= len(words):
+                break
+        return chunks
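A quick illustration of the window/overlap behaviour with small numbers:

from src.ingestion.chunkers import SlidingWindowChunker

chunker = SlidingWindowChunker(chunk_size=5, overlap=2)
text = "one two three four five six seven eight nine ten"
for chunk in chunker.chunk(text):
    print(chunk)
# one two three four five
# four five six seven eight
# seven eight nine ten
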
src/ingestion/cleaner.py ADDED
@@ -0,0 +1,6 @@
+import re
+
+def clean_text(text: str) -> str:
+    # Remove excessive whitespace
+    text = re.sub(r'\s+', ' ', text).strip()
+    return text
src/ingestion/ingest.py ADDED
@@ -0,0 +1,83 @@
+import os
+import glob
+from pathlib import Path
+from typing import List
+import numpy as np
+from tqdm import tqdm
+
+from src.ingestion.readers import get_reader
+from src.ingestion.cleaner import clean_text
+from src.ingestion.chunkers import SlidingWindowChunker
+from src.embeddings.embedder import Embedder
+from src.indexer.bm25_index import BM25Index
+from src.indexer.faiss_index import FaissIndex
+
+DATA_DIR = "data"
+RAW_DIR = os.path.join(DATA_DIR, "raw")
+INDEX_DIR = os.path.join(DATA_DIR, "index")
+
+class IngestionPipeline:
+    def __init__(self):
+        self.chunker = SlidingWindowChunker()
+        self.embedder = Embedder(model_name="all-MiniLM-L6-v2")
+        self.bm25_index = BM25Index()
+        # Dimension for all-MiniLM-L6-v2 is 384
+        self.faiss_index = FaissIndex(dimension=384)
+
+    def run(self):
+        print("Starting ingestion...")
+        files = glob.glob(os.path.join(RAW_DIR, "*.*"))
+
+        all_chunks = []
+        doc_map = []  # To map chunk index back to metadata/content if needed
+
+        # 1. Read, Clean, Chunk
+        print("Processing files...")
+        for file_path in tqdm(files):
+            path = Path(file_path)
+            try:
+                reader = get_reader(path)
+                raw_text = reader.read(path)
+                cleaned_text = clean_text(raw_text)
+                chunks = self.chunker.chunk(cleaned_text)
+
+                for chunk in chunks:
+                    all_chunks.append(chunk)
+                    doc_map.append({"source": str(path), "content": chunk})
+            except Exception as e:
+                print(f"Error processing {path}: {e}")
+
+        print(f"Total chunks generated: {len(all_chunks)}")
+
+        # 2. Build BM25 Index
+        print("Building BM25 Index...")
+        self.bm25_index.build(all_chunks)
+        os.makedirs(INDEX_DIR, exist_ok=True)
+        self.bm25_index.save(os.path.join(INDEX_DIR, "bm25.pkl"))
+
+        # 3. Embed and Build FAISS Index
+        if not os.getenv("DISABLE_FAISS"):
+            print("Embedding chunks and building FAISS Index...")
+            batch_size = 32
+            for i in range(0, len(all_chunks), batch_size):
+                batch = all_chunks[i : i + batch_size]
+                embeddings = self.embedder.embed(batch)
+                self.faiss_index.add(embeddings)
+
+            self.faiss_index.save(os.path.join(INDEX_DIR, "faiss.index"))
+        else:
+            print("Skipping FAISS build due to DISABLE_FAISS environment variable.")
+            # Create a dummy file to satisfy file existence checks if any (though lazy loaded)
+            with open(os.path.join(INDEX_DIR, "faiss.index"), "w") as f:
+                f.write("dummy")
+
+        # Save doc_map (simple persistence for retrieval lookup)
+        import pickle
+        with open(os.path.join(INDEX_DIR, "doc_map.pkl"), "wb") as f:
+            pickle.dump(doc_map, f)
+
+        print("Ingestion complete.")
+
+if __name__ == "__main__":
+    pipeline = IngestionPipeline()
+    pipeline.run()
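A minimal way to trigger ingestion (run from the repository root so the relative data/ paths resolve; expects .txt/.md files under data/raw/):

from src.ingestion.ingest import IngestionPipeline

# Writes bm25.pkl, faiss.index and doc_map.pkl to data/index/
IngestionPipeline().run()
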
src/ingestion/readers.py ADDED
@@ -0,0 +1,18 @@
+from pathlib import Path
+from typing import List, Dict
+
+class DocumentReader:
+    def read(self, file_path: Path) -> str:
+        raise NotImplementedError
+
+class TextReader(DocumentReader):
+    def read(self, file_path: Path) -> str:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read()
+
+def get_reader(file_path: Path) -> DocumentReader:
+    ext = file_path.suffix.lower()
+    if ext in ['.txt', '.md']:
+        return TextReader()
+    # Add PDF/Docx support here later
+    raise ValueError(f"Unsupported file type: {ext}")
src/llm/__pycache__/llm_client.cpython-311.pyc ADDED
Binary file (4.14 kB).
 
src/llm/llm_client.py ADDED
@@ -0,0 +1,56 @@
+import os
+import openai
+from typing import List, Dict, Any
+
+class LLMClient:
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        raise NotImplementedError
+
+class OpenAIClient(LLMClient):
+    def __init__(self, api_key: str = None, model: str = "gpt-4o"):
+        self.client = openai.OpenAI(api_key=api_key or os.getenv("OPENAI_API_KEY"))
+        self.model = model
+
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **kwargs
+        )
+        return response.choices[0].message.content
+
+class VLLMClient(LLMClient):
+    def __init__(self, api_url: str = None, model: str = None):
+        self.api_url = api_url or os.getenv("VLLM_API_URL", "http://localhost:8000/v1")
+        self.model = model or os.getenv("VLLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
+        # vLLM is OpenAI compatible
+        self.client = openai.OpenAI(
+            base_url=self.api_url,
+            api_key="EMPTY"
+        )
+
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **kwargs
+        )
+        return response.choices[0].message.content
+
+class GroqClient(LLMClient):
+    def __init__(self, api_key: str = None, model: str = "llama-3.3-70b-versatile"):
+        self.api_key = api_key or os.getenv("GROQ_API_KEY")
+        self.model = model
+        self.client = openai.OpenAI(
+            base_url="https://api.groq.com/openai/v1",
+            api_key=self.api_key
+        )
+
+    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        # Default behavior
+        response = self.client.chat.completions.create(
+            model=self.model,
+            messages=messages,
+            **kwargs
+        )
+        return response.choices[0].message.content
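All three clients share the same chat(messages, **kwargs) contract, so they are interchangeable in the pipeline. A minimal sketch, assuming the relevant API key or endpoint is configured:

from src.llm.llm_client import GroqClient

llm = GroqClient()  # or OpenAIClient() / VLLMClient(), selected via environment in QueryPipeline
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "Summarize what BM25 scores measure."},
]
print(llm.chat(messages, temperature=0.2))
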
src/pipeline/__pycache__/context_opt.cpython-311.pyc ADDED
Binary file (2.72 kB).
 
src/pipeline/__pycache__/query_pipeline.cpython-311.pyc ADDED
Binary file (4.48 kB).
 
src/pipeline/context_opt.py ADDED
@@ -0,0 +1,55 @@
+import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
+
+def maximal_marginal_relevance(query_embedding: np.ndarray, doc_embeddings: np.ndarray, lambda_mult: float = 0.5, top_k: int = 5):
+    """
+    Selects docs that are relevant to the query but diverse from each other.
+    """
+    if len(doc_embeddings) == 0:
+        return []
+
+    # Simple MMR implementation
+    selected_indices = []
+    candidate_indices = list(range(len(doc_embeddings)))
+
+    for _ in range(top_k):
+        best_score = -np.inf
+        best_idx = -1
+
+        for idx in candidate_indices:
+            # Relevance
+            rel_score = cosine_similarity(query_embedding.reshape(1, -1), doc_embeddings[idx].reshape(1, -1))[0][0]
+
+            # Diversity (similarity to already selected docs)
+            if not selected_indices:
+                div_score = 0
+            else:
+                sims = cosine_similarity(doc_embeddings[idx].reshape(1, -1), doc_embeddings[selected_indices])[0]
+                div_score = np.max(sims)
+
+            mmr_score = lambda_mult * rel_score - (1 - lambda_mult) * div_score
+
+            if mmr_score > best_score:
+                best_score = mmr_score
+                best_idx = idx
+
+        if best_idx != -1:
+            selected_indices.append(best_idx)
+            candidate_indices.remove(best_idx)
+
+    return selected_indices
+
+def deduplicate_docs(docs: list[dict], threshold: float = 0.95) -> list[dict]:
+    """
+    Remove near-duplicates based on content string similarity (simple),
+    or just exact match for now to be fast.
+    """
+    seen = set()
+    unique_docs = []
+    for doc in docs:
+        # Assuming doc is a string or dict with 'content'
+        content = doc if isinstance(doc, str) else doc.get('content', '')
+        if content not in seen:
+            seen.add(content)
+            unique_docs.append(doc)
+    return unique_docs
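A small sketch of MMR selection over toy embeddings (the returned indices refer to rows of doc_embeddings; the vectors are made up):

import numpy as np
from src.pipeline.context_opt import maximal_marginal_relevance, deduplicate_docs

query_emb = np.array([1.0, 0.0])
doc_embs = np.array([[2.0, 1.0],    # index 0: relevant
                     [2.0, 1.0],    # index 1: exact duplicate of index 0
                     [2.0, -1.0]])  # index 2: equally relevant but different
print(maximal_marginal_relevance(query_emb, doc_embs, lambda_mult=0.5, top_k=2))  # [0, 2] -- skips the duplicate
print(deduplicate_docs(["a", "a", "b"]))  # ['a', 'b']
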
src/pipeline/query_pipeline.py ADDED
@@ -0,0 +1,97 @@
+import os
+from typing import Optional
+from src.retriever.hybrid_retriever import HybridRetriever
+from src.retriever.hyde import HyDERetriever
+from src.reranker.cross_encoder import Reranker
+from src.llm.llm_client import OpenAIClient, VLLMClient, GroqClient
+from src.embeddings.embedder import Embedder
+from src.pipeline.context_opt import deduplicate_docs
+
+class QueryPipeline:
+    def __init__(self, use_hyde: bool = False):
+        self.embedder = Embedder()
+        self.retriever = HybridRetriever(
+            bm25_path="data/index/bm25.pkl",
+            faiss_path="data/index/faiss.index",
+            doc_map_path="data/index/doc_map.pkl",
+            embedder=self.embedder
+        )
+
+        # LLM Client Strategy
+        if os.getenv("GROQ_API_KEY"):
+            self.llm = GroqClient()
+        elif os.getenv("VLLM_API_URL"):
+            self.llm = VLLMClient()
+        else:
+            self.llm = OpenAIClient()
+
+        if use_hyde:
+            self.retriever = HyDERetriever(self.llm, self.retriever)
+
+        self.reranker = Reranker()
+
+    def run(self, query: str, top_k_retrieval: int = 20, top_k_rerank: int = 5):
+        # 1. Retrieve
+        print(f"Retrieving for query: {query}")
+        retrieved_docs = self.retriever.search(query, top_k=top_k_retrieval)
+
+        # 2. Deduplicate
+        unique_docs = deduplicate_docs(retrieved_docs)
+
+        # 3. Rerank
+        # Reranker expects strings
+        doc_contents = [d if isinstance(d, str) else d['content'] for d in unique_docs]
+        reranked = self.reranker.rerank(query, doc_contents, top_k=top_k_rerank)
+
+        # 4. Generate
+
+        # Retrieval Score Gate
+        RETRIEVAL_SCORE_THRESHOLD = -4.0
+
+        # reranked is list of (doc, score)
+        if not reranked or reranked[0][1] < RETRIEVAL_SCORE_THRESHOLD:
+            return {
+                "query": query,
+                "answer": "I do not have enough information in the provided documents to answer this question.",
+                "context": [],
+                "retrieval_score": reranked[0][1] if reranked else -99.9,
+                "hallucination_score": 0.0,
+                "groundedness": 1.0
+            }
+
+        context_text = "\n\n".join([doc for doc, score in reranked])
+
+        SYSTEM_PROMPT = """
+        You are an enterprise-grade question answering system.
+
+        Rules:
+        1. Answer strictly using ONLY the provided context.
+        2. DO NOT use prior knowledge or assumptions.
+        3. If the answer is not explicitly stated in the context, respond EXACTLY with:
+        "I do not have enough information in the provided documents to answer this question."
+        4. Do not add explanations, guesses, or external facts.
+        """
+
+        user_prompt = f"""
+        {SYSTEM_PROMPT}
+
+        Context:
+        {context_text}
+
+        Question:
+        {query}
+        """
+
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": user_prompt}
+        ]
+
+        answer = self.llm.chat(messages)
+
+        return {
+            "query": query,
+            "answer": answer,
+            "context": reranked,
+            "retrieval_score": reranked[0][1]
+        }
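An end-to-end usage sketch, assuming the indexes written by src/ingestion/ingest.py exist under data/index/ and one of GROQ_API_KEY, VLLM_API_URL, or OPENAI_API_KEY is set; the query is only an example:

from src.pipeline.query_pipeline import QueryPipeline

pipeline = QueryPipeline(use_hyde=False)
result = pipeline.run("What does the onboarding document say about laptop setup?", top_k_retrieval=20, top_k_rerank=5)
print(result["answer"])
for doc, score in result["context"]:
    print(f"{score:.3f}  {doc[:80]}")
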
src/reranker/__pycache__/cross_encoder.cpython-311.pyc ADDED
Binary file (1.81 kB).
 
src/reranker/cross_encoder.py ADDED
@@ -0,0 +1,19 @@
+from sentence_transformers import CrossEncoder
+
+class Reranker:
+    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+        self.model = CrossEncoder(model_name)
+
+    def rerank(self, query: str, docs: list[str], top_k: int = 5):
+        if not docs:
+            return []
+
+        pairs = [[query, doc] for doc in docs]
+        scores = self.model.predict(pairs).tolist()
+
+        # Combine docs with scores
+        doc_scores = list(zip(docs, scores))
+        # Sort by score descending
+        doc_scores.sort(key=lambda x: x[1], reverse=True)
+
+        return doc_scores[:top_k]
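A short reranking sketch (cross-encoder scores are unbounded logits, higher meaning more relevant, which is what the -4.0 gate in query_pipeline.py relies on; the documents are made up):

from src.reranker.cross_encoder import Reranker

reranker = Reranker()
docs = [
    "Our VPN requires multi-factor authentication.",
    "The cafeteria opens at 8am.",
    "Remote access is granted through the corporate VPN portal.",
]
for doc, score in reranker.rerank("How do I get remote access?", docs, top_k=2):
    print(f"{score:+.2f}  {doc}")
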
src/retriever/__pycache__/hybrid_retriever.cpython-311.pyc ADDED
Binary file (4.55 kB).
 
src/retriever/__pycache__/hyde.cpython-311.pyc ADDED
Binary file (1.81 kB).
 
src/retriever/hybrid_retriever.py ADDED
@@ -0,0 +1,71 @@
+from typing import List, Tuple
+import numpy as np
+from src.indexer.bm25_index import BM25Index
+from src.indexer.faiss_index import FaissIndex
+from src.embeddings.embedder import Embedder
+
+class HybridRetriever:
+    def __init__(self, bm25_path: str, faiss_path: str, doc_map_path: str, embedder: Embedder):
+        self.bm25 = BM25Index()
+        self.bm25.load(bm25_path)
+
+        self.embedder = embedder
+        self.faiss = None
+        import os
+        if not os.getenv("DISABLE_FAISS"):
+            try:
+                self.faiss = FaissIndex(dimension=384)  # adjust dimension if needed
+                self.faiss.load(faiss_path)
+                print("Successfully loaded FAISS index.")
+            except Exception as e:
+                print(f"WARNING: Could not load FAISS index ({e}). Running in BM25-only mode.")
+        else:
+            print("FAISS disabled via environment variable. Running in BM25-only mode.")
+
+        # Load doc map
+        import pickle
+        with open(doc_map_path, 'rb') as f:
+            self.doc_map = pickle.load(f)
+
+    def search(self, query: str, top_k: int = 10, alpha: float = 0.5) -> List[dict]:
+        """
+        Hybrid search using BM25 and dense embeddings.
+        alpha: weight for dense score (0 = pure BM25, 1 = pure dense); currently unused since RRF fusion is applied instead.
+        """
+        # 1. BM25 Search
+        # Scores would need normalization to be combined directly; RRF is safer when scores are not calibrated,
+        # so for simplicity we take the top N from each retriever and fuse with RRF.
+
+        top_n = top_k * 2
+
+        # BM25
+        bm25_docs, bm25_scores = self.bm25.search(query, top_k=top_n)
+
+        scores = {}
+
+        # Process BM25
+        for rank, doc in enumerate(bm25_docs):
+            key = doc
+            scores[key] = scores.get(key, 0) + (1 / (60 + rank))
+
+        # Dense (only if FAISS is loaded)
+        if self.faiss:
+            try:
+                query_emb = self.embedder.embed([query])
+                dense_dists, dense_indices = self.faiss.search(query_emb, top_k=top_n)
+
+                # Merge using Reciprocal Rank Fusion (RRF)
+                # Dense indices refer to doc_map
+                for rank, idx in enumerate(dense_indices[0]):
+                    if idx == -1: continue
+                    doc_data = self.doc_map[idx]
+                    key = doc_data['content']
+                    scores[key] = scores.get(key, 0) + (1 / (60 + rank))
+            except Exception as e:
+                print(f"Error during dense search: {e}")
+
+        # Sort by RRF score
+        sorted_docs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
+
+        return [doc for doc, score in sorted_docs[:top_k]]
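The retriever can also be exercised directly, e.g. to inspect the RRF fusion; a sketch assuming the default ingestion output paths:

from src.embeddings.embedder import Embedder
from src.retriever.hybrid_retriever import HybridRetriever

retriever = HybridRetriever(
    bm25_path="data/index/bm25.pkl",
    faiss_path="data/index/faiss.index",
    doc_map_path="data/index/doc_map.pkl",
    embedder=Embedder(),
)
for chunk in retriever.search("vacation policy", top_k=5):
    print(chunk[:80])
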
src/retriever/hyde.py ADDED
@@ -0,0 +1,21 @@
+from src.llm.llm_client import LLMClient
+
+class HyDERetriever:
+    def __init__(self, llm_client: LLMClient, base_retriever):
+        self.llm = llm_client
+        self.retriever = base_retriever
+
+    def generate_hypothetical_doc(self, query: str) -> str:
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant. Write a hypothetical answer to the user's question. Do not include any explanation, just the answer."},
+            {"role": "user", "content": query}
+        ]
+        return self.llm.chat(messages, temperature=0.7)
+
+    def search(self, query: str, top_k: int = 10):
+        # 1. Generate hypothetical doc
+        hypothetical_doc = self.generate_hypothetical_doc(query)
+        print(f"DEBUG: HyDE Doc: {hypothetical_doc[:100]}...")
+
+        # 2. Retrieve using the hypothetical doc as query
+        return self.retriever.search(hypothetical_doc, top_k=top_k)
src/ui/app.py ADDED
@@ -0,0 +1,58 @@
+import streamlit as st
+import requests
+import json
+print("UI Starting up...")
+
+st.set_page_config(page_title="Enterprise RAG Search", layout="wide")
+
+import os
+
+API_URL = os.getenv("API_URL", "http://localhost:8000/api/v1/chat")
+
+st.title("Enterprise RAG Search")
+
+with st.sidebar:
+    st.header("Configuration")
+    top_k_retrieval = st.slider("Retrieval Top-K", 5, 50, 20)
+    top_k_rerank = st.slider("Rerank Top-K", 1, 10, 5)
+    # use_hyde = st.checkbox("Use HyDE", value=False)
+
+query = st.chat_input("Enter your query...")
+
+if query:
+    st.session_state.messages = st.session_state.get("messages", [])
+    st.session_state.messages.append({"role": "user", "content": query})
+
+# s = requests.Session()
+
+for msg in st.session_state.get("messages", []):
+    with st.chat_message(msg["role"]):
+        st.write(msg["content"])
+
+if query:
+    with st.chat_message("assistant"):
+        with st.spinner("Searching..."):
+            try:
+                payload = {
+                    "query": query,
+                    "top_k_retrieval": top_k_retrieval,
+                    "top_k_rerank": top_k_rerank,
+                    # "use_hyde": use_hyde
+                }
+                response = requests.post(API_URL, json=payload)
+                response.raise_for_status()
+                data = response.json()
+
+                answer = data["answer"]
+                st.write(answer)
+
+                with st.expander("View Context"):
+                    for i, (doc, score) in enumerate(data["context"]):
+                        st.markdown(f"**Relevance Score:** {score:.4f}")
+                        st.text(doc)
+                        st.divider()
+
+                st.session_state.messages.append({"role": "assistant", "content": answer})
+
+            except Exception as e:
+                st.error(f"Error: {e}")