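"""Seed and query a Chroma vector index built over local .txt/.md documents.

Embeddings come from a SentenceTransformer assembled from EMB_MODEL_NAME plus
a mean-pooling layer; documents are chunked with RecursiveCharacterTextSplitter.
"""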
import glob
import os
from typing import List

import chromadb
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, models

from utils.constants import CHROMA_DIR, DOCS_DIR, COLLECTION, EMB_MODEL_NAME
from utils.helpers import to_safe_items


def get_embedder():
    # Build a SentenceTransformer manually: a transformer backbone plus a
    # mean-pooling layer sized to its word-embedding dimension.
    word = models.Transformer(EMB_MODEL_NAME)
    pooling = models.Pooling(word.get_word_embedding_dimension())
    return SentenceTransformer(modules=[word, pooling])


def get_chroma():
    # Persistent on-disk client; cosine distance for the HNSW index.
    client = chromadb.PersistentClient(path=CHROMA_DIR)
    col = client.get_or_create_collection(COLLECTION, metadata={"hnsw:space": "cosine"})
    return client, col


def embed_texts(model, texts: List[str]):
    # Encode to numpy, then to plain Python lists, which is what Chroma expects.
    return model.encode(texts, convert_to_numpy=True).tolist()


def seed_index(col, model, root_folder: str) -> int:
    # Chunk every .txt/.md file under root_folder and add the chunks, with
    # embeddings and per-file metadata, to the collection. Returns chunk count.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1100, chunk_overlap=150)
    paths = []
    for ext in ("**/*.txt", "**/*.md"):
        paths += glob.glob(os.path.join(root_folder, ext), recursive=True)
    ids, docs, metas = [], [], []
    for p in sorted(paths):
        title = os.path.splitext(os.path.basename(p))[0]
        with open(p, "r", encoding="utf-8") as f:
            txt = f.read()
        # Unique id uses the path relative to root_folder so the same filename
        # in different specialty folders cannot collide.
        rel = os.path.relpath(p, root_folder).replace(os.sep, "_")
        chunks = splitter.split_text(txt)
        for i, ch in enumerate(chunks):
            ids.append(f"{rel}-{i}")
            docs.append(ch)
            metas.append({"title": title, "source": p})
    if not docs:
        return 0
    embs = embed_texts(model, docs)
    try:
        col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    except Exception:
        # Ids may already exist from a previous run: delete them and re-add.
        try:
            col.delete(ids=ids)
        except Exception:
            pass
        col.add(ids=ids, documents=docs, metadatas=metas, embeddings=embs)
    return len(docs)


def retrieve(col, model, query: str, k: int = 5):
    # Embed the query once and run a top-k similarity search.
    q_emb = embed_texts(model, [query])[0]
    res = col.query(
        query_embeddings=[q_emb],
        n_results=k,
        include=["documents", "metadatas", "distances"],
    )
    return to_safe_items(res)
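

# A minimal usage sketch, not part of the original file: it assumes the
# constants in utils.constants point at real paths and a valid model name,
# and that .txt/.md documents exist under DOCS_DIR.
if __name__ == "__main__":
    model = get_embedder()
    _, col = get_chroma()
    added = seed_index(col, model, DOCS_DIR)
    print(f"indexed {added} chunks")
    print(retrieve(col, model, "example query", k=3))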