import os, glob
import chromadb
from typing import List
from sentence_transformers import SentenceTransformer, models
from langchain_text_splitters import RecursiveCharacterTextSplitter
from utils.constants import CHROMA_DIR, DOCS_DIR, COLLECTION, EMB_MODEL_NAME
from utils.helpers import to_safe_items

def get_embedder():
    """Build the sentence embedder as an explicit two-stage pipeline.

    Assembles a token-level transformer (EMB_MODEL_NAME) followed by a
    pooling layer sized to the transformer's output dimension, rather than
    loading a prepackaged sentence-transformers checkpoint.
    """
    token_encoder = models.Transformer(EMB_MODEL_NAME)
    pooler = models.Pooling(token_encoder.get_word_embedding_dimension())
    return SentenceTransformer(modules=[token_encoder, pooler])

def get_chroma():
    """Open (or create) the persistent Chroma store and its collection.

    Returns:
        A ``(client, collection)`` tuple. The collection is configured for
        cosine distance via the ``"hnsw:space"`` metadata key.
    """
    chroma_client = chromadb.PersistentClient(path=CHROMA_DIR)
    collection = chroma_client.get_or_create_collection(
        COLLECTION,
        metadata={"hnsw:space": "cosine"},
    )
    return chroma_client, collection

def embed_texts(model, texts: List[str]):
    """Encode *texts* with *model*, returning embeddings as plain Python lists."""
    vectors = model.encode(texts, convert_to_numpy=True)
    return vectors.tolist()

def seed_index(col, model, root_folder: str) -> int:
    """Chunk every .txt/.md file under *root_folder*, embed, and index in Chroma.

    Args:
        col: Chroma collection to write into.
        model: embedder passed to ``embed_texts``.
        root_folder: directory scanned recursively for documents.

    Returns:
        Number of chunks written (0 when no documents were found).
    """
    splitter = RecursiveCharacterTextSplitter(chunk_size=1100, chunk_overlap=150)
    paths = []
    for ext in ("**/*.txt", "**/*.md"):
        paths += glob.glob(os.path.join(root_folder, ext), recursive=True)
    ids, docs, metas = [], [], []
    for p in sorted(paths):
        title = os.path.splitext(os.path.basename(p))[0]
        with open(p, "r", encoding="utf-8") as f:
            txt = f.read()
        # Unique id uses the path relative to root_folder so identically
        # named files in different subfolders do not collide.
        rel = os.path.relpath(p, root_folder).replace(os.sep, "_")
        for i, ch in enumerate(splitter.split_text(txt)):
            ids.append(f"{rel}-{i}")
            docs.append(ch)
            metas.append({"title": title, "source": p})
    if not docs:
        return 0
    embs = embed_texts(model, docs)
    # upsert overwrites any existing chunks with matching ids, making
    # re-seeding idempotent without the previous add -> broad-except ->
    # delete -> re-add fallback, which could mask unrelated errors.
    col.upsert(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    return len(docs)

def retrieve(col, model, query: str, k: int = 5):
    """Embed *query* and return the top-*k* matching chunks as safe items."""
    query_vec = embed_texts(model, [query])[0]
    raw = col.query(
        query_embeddings=[query_vec],
        n_results=k,
        include=["documents", "metadatas", "distances"],
    )
    return to_safe_items(raw)