Spaces:
Running
Running
| import os, glob | |
| import chromadb | |
| from typing import List | |
| from sentence_transformers import SentenceTransformer, models | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| from utils.constants import CHROMA_DIR, DOCS_DIR, COLLECTION, EMB_MODEL_NAME | |
| from utils.helpers import to_safe_items | |
def get_embedder():
    """Assemble a SentenceTransformer: raw transformer backbone + mean pooling."""
    transformer = models.Transformer(EMB_MODEL_NAME)
    pool = models.Pooling(transformer.get_word_embedding_dimension())
    return SentenceTransformer(modules=[transformer, pool])
def get_chroma():
    """Open the persistent Chroma store and return (client, cosine collection)."""
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    collection = client.get_or_create_collection(
        COLLECTION, metadata={"hnsw:space": "cosine"}
    )
    return client, collection
def embed_texts(model, texts: List[str]):
    """Encode *texts* with *model* and return the embeddings as plain Python lists."""
    vectors = model.encode(texts, convert_to_numpy=True)
    return vectors.tolist()
def seed_index(col, model, root_folder: str,
               chunk_size: int = 1100, chunk_overlap: int = 150) -> int:
    """Index every .txt/.md file under *root_folder* into collection *col*.

    Each file is split into overlapping chunks, embedded with *model*, and
    added to the collection with a per-chunk id derived from the file's
    relative path (so identical filenames in different subfolders do not
    collide).

    Args:
        col: Chroma collection supporting add()/delete().
        model: embedding model passed through to embed_texts().
        root_folder: directory scanned recursively for .txt and .md files.
        chunk_size / chunk_overlap: splitter settings; defaults preserve
            the previously hard-coded values (1100 / 150).

    Returns:
        Number of chunks indexed; 0 when no documents were found.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    paths = []
    for ext in ("**/*.txt", "**/*.md"):
        paths += glob.glob(os.path.join(root_folder, ext), recursive=True)
    ids, docs, metas = [], [], []
    for p in sorted(paths):
        title = os.path.splitext(os.path.basename(p))[0]
        with open(p, "r", encoding="utf-8") as f:
            txt = f.read()
        # Unique id uses the path relative to root_folder to avoid
        # duplicate ids across subfolders/specialties.
        rel = os.path.relpath(p, root_folder).replace(os.sep, "_")
        for i, ch in enumerate(splitter.split_text(txt)):
            ids.append(f"{rel}-{i}")
            docs.append(ch)
            metas.append({"title": title, "source": p})
    if not docs:
        return 0
    embs = embed_texts(model, docs)
    try:
        col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    except Exception:
        # Re-seed path: add() failed (presumably duplicate ids), so delete
        # best-effort and add again. NOTE(review): broad Exception kept
        # deliberately — chromadb's duplicate-id error type varies by
        # version; the second add() is allowed to raise.
        try:
            col.delete(ids=ids)
        except Exception:
            pass
        col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    return len(docs)
def retrieve(col, model, query: str, k: int = 5):
    """Embed *query*, fetch the top-*k* matches from *col*, and normalize them."""
    query_embeddings = embed_texts(model, [query])
    raw = col.query(
        query_embeddings=query_embeddings,
        n_results=k,
        include=["documents", "metadatas", "distances"],
    )
    return to_safe_items(raw)