import streamlit as st
import pandas as pd
import numpy as np
import jieba
import requests
import os
import sys
import subprocess
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity

# ================= 1. Global config & CSS injection =================
API_KEY = os.getenv("SILICONFLOW_API_KEY")
API_BASE = "https://api.siliconflow.cn/v1"
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
DATA_FILENAME = "comsol_embedded.parquet"
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"

st.set_page_config(
    page_title="COMSOL Dark Expert",
    page_icon="🌌",
    layout="wide",
    initial_sidebar_state="expanded"
)

# --- Inject custom CSS (keeps the earlier dark aesthetic) ---
st.markdown("""
""", unsafe_allow_html=True)

# ================= 2. Core logic (data & RAG) =================
if not API_KEY:
    st.error("⚠️ No API Key detected. Please configure `SILICONFLOW_API_KEY` under Settings -> Secrets.")
    st.stop()


def download_with_curl(url, output_path):
    """Download via curl, spoofing a desktop browser UA (some hosts reject default agents)."""
    try:
        cmd = [
            "curl", "-L",
            "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
            "-o", output_path,
            "--fail",
            url
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Curl failed: {result.stderr}")
        return True
    except Exception as e:
        print(f"Curl download error: {e}")
        return False


def get_data_file_path():
    """Return a local path to the parquet file, downloading it if no copy is found."""
    possible_paths = [
        DATA_FILENAME,
        os.path.join("/app", DATA_FILENAME),
        os.path.join("processed_data", DATA_FILENAME),
        os.path.join("src", DATA_FILENAME),
        os.path.join("..", DATA_FILENAME),
        "/tmp/" + DATA_FILENAME
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    # No local copy: download to /app, falling back to /tmp on read-only filesystems.
    download_target = "/app/" + DATA_FILENAME
    try:
        os.makedirs(os.path.dirname(download_target), exist_ok=True)
    except Exception:
        download_target = "/tmp/" + DATA_FILENAME

    status_container = st.empty()
    status_container.info("📡 Connecting to the neural network... (downloading core data)")

    # Try curl first; fall back to requests if curl is unavailable or fails.
    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        r = requests.get(DATA_URL, headers=headers, stream=True)
        r.raise_for_status()
        with open(download_target, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        status_container.empty()
        return download_target
    except Exception as e:
        st.error(f"❌ Data link interrupted. Error: {e}")
        st.stop()


class FullRetriever:
    def __init__(self, parquet_path):
        try:
            self.df = pd.read_parquet(parquet_path)
        except Exception as e:
            st.error(f"Memory Matrix Load Failed: {e}")
            st.stop()
        self.documents = self.df['content'].tolist()
        self.embeddings = np.stack(self.df['embedding'].values)
        self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
        # Build the reranker URL and headers once here instead of on every call.
        self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
        self.rerank_url = f"{API_BASE}/rerank"

    def _get_emb(self, q):
        try:
            return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding
        except Exception:
            # Fall back to a zero vector whose width matches the stored embeddings;
            # a hard-coded 1024 would break cosine_similarity if the model's
            # output dimension differs.
            return [0.0] * self.embeddings.shape[1]

    def hybrid_search(self, query: str, top_k=5):
        # 1. Vector recall: top 100 by cosine similarity.
        q_emb = self._get_emb(query)
        vec_scores = cosine_similarity([q_emb], self.embeddings)[0]
        vec_idx = np.argsort(vec_scores)[-100:][::-1]
        # 2. Keyword recall: top 100 by BM25 over jieba tokens.
        kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
        # 3. RRF fusion: score = sum of 1/(60 + rank) across both rankings.
        fused = {}
        for r, i in enumerate(vec_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        for r, i in enumerate(kw_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x: x[1], reverse=True)[:50]]
        c_docs = [self.documents[i] for i in c_idxs]
        # 4. Rerank the fused top 50 down to top_k via the rerank API.
        try:
            payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k}
            resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10)
            resp.raise_for_status()
            results = resp.json().get('results', [])
        except Exception:
            # Reranker unreachable: keep the RRF order with zeroed scores.
            results = [{"index": i, "relevance_score": 0.0} for i in range(min(top_k, len(c_docs)))]
        final_res = []
        context = ""
        for i, item in enumerate(results):
            orig_idx = c_idxs[item['index']]
            row = self.df.iloc[orig_idx]
            final_res.append({
                "score": item['relevance_score'],
                "filename": row['filename'],
                "content": row['content']
            })
            context += f"[Document {i+1}]: {row['content']}\n\n"
        return final_res, context
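
# --- Aside: a minimal, self-contained sketch of the RRF math used in
# hybrid_search above. `_rrf_fuse` and the toy rankings are illustrative only,
# not part of the app, and can be deleted. With k=60, a document ranked 1st in
# one list and 3rd in the other scores 1/61 + 1/63 ≈ 0.0323.
def _rrf_fuse(rankings, k=60):
    """Fuse several best-first index rankings; returns indices by fused score."""
    fused = {}
    for ranking in rankings:
        for r, i in enumerate(ranking):
            fused[i] = fused.get(i, 0.0) + 1.0 / (k + r + 1)
    return sorted(fused, key=fused.get, reverse=True)
# Example: _rrf_fuse([[2, 0, 1], [0, 1, 2]]) ranks doc 0 first, since it places
# 2nd in the first (vector) list and 1st in the second (keyword) list.
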
@st.cache_resource
def load_engine():
    real_path = get_data_file_path()
    return FullRetriever(real_path)


# ================= 3. Main UI =================
def main():
    st.markdown("""