# # # import streamlit as st # # # import pandas as pd # # # import numpy as np # # # import jieba # # # import requests # # # import os # # # import sys # # # import subprocess # # # from openai import OpenAI # # # from rank_bm25 import BM25Okapi # # # from sklearn.metrics.pairwise import cosine_similarity # # # # ================= 1. 全局配置与 CSS注入 ================= # # # API_KEY = os.getenv("SILICONFLOW_API_KEY") # # # API_BASE = "https://api.siliconflow.cn/v1" # # # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" # # # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" # # # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" # # # DATA_FILENAME = "comsol_embedded.parquet" # # # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" # # # st.set_page_config( # # # page_title="COMSOL Dark Expert", # # # page_icon="🌌", # # # layout="wide", # # # initial_sidebar_state="expanded" # # # ) # # # # --- 注入自定义 CSS (保持之前的审美) --- # # # st.markdown(""" # # # # # # """, unsafe_allow_html=True) # # # # ================= 2. 
核心逻辑(数据与RAG) ================= # # # if not API_KEY: # # # st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") # # # st.stop() # # # def download_with_curl(url, output_path): # # # try: # # # cmd = [ # # # "curl", "-L", # # # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", # # # "-o", output_path, # # # "--fail", # # # url # # # ] # # # result = subprocess.run(cmd, capture_output=True, text=True) # # # if result.returncode != 0: raise Exception(f"Curl failed: {result.stderr}") # # # return True # # # except Exception as e: # # # print(f"Curl download error: {e}") # # # return False # # # def get_data_file_path(): # # # possible_paths = [ # # # DATA_FILENAME, os.path.join("/app", DATA_FILENAME), # # # os.path.join("processed_data", DATA_FILENAME), # # # os.path.join("src", DATA_FILENAME), # # # os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME # # # ] # # # for path in possible_paths: # # # if os.path.exists(path): return path # # # download_target = "/app/" + DATA_FILENAME # # # try: os.makedirs(os.path.dirname(download_target), exist_ok=True) # # # except: download_target = "/tmp/" + DATA_FILENAME # # # status_container = st.empty() # # # status_container.info("📡 正在接入神经元网络... 
(下载核心数据中)") # # # if download_with_curl(DATA_URL, download_target): # # # status_container.empty() # # # return download_target # # # try: # # # headers = {'User-Agent': 'Mozilla/5.0'} # # # r = requests.get(DATA_URL, headers=headers, stream=True) # # # r.raise_for_status() # # # with open(download_target, 'wb') as f: # # # for chunk in r.iter_content(chunk_size=8192): f.write(chunk) # # # status_container.empty() # # # return download_target # # # except Exception as e: # # # st.error(f"❌ 数据链路中断。Error: {e}") # # # st.stop() # # # class FullRetriever: # # # def __init__(self, parquet_path): # # # try: self.df = pd.read_parquet(parquet_path) # # # except Exception as e: st.error(f"Memory Matrix Load Failed: {e}"); st.stop() # # # self.documents = self.df['content'].tolist() # # # self.embeddings = np.stack(self.df['embedding'].values) # # # self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents]) # # # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) # # # # Reranker 初始化移到这里,减少重复调用 # # # self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"} # # # self.rerank_url = f"{API_BASE}/rerank" # # # def _get_emb(self, q): # # # try: return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding # # # except: return [0.0] * 1024 # # # def hybrid_search(self, query: str, top_k=5): # # # # 1. Vector # # # q_emb = self._get_emb(query) # # # vec_scores = cosine_similarity([q_emb], self.embeddings)[0] # # # vec_idx = np.argsort(vec_scores)[-100:][::-1] # # # # 2. Keyword # # # kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1] # # # # 3. 
RRF Fusion # # # fused = {} # # # for r, i in enumerate(vec_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1) # # # for r, i in enumerate(kw_idx): fused[i] = fused.get(i, 0) + 1/(60+r+1) # # # c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x:x[1], reverse=True)[:50]] # # # c_docs = [self.documents[i] for i in c_idxs] # # # # 4. Rerank # # # try: # # # payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k} # # # resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10) # # # results = resp.json().get('results', []) # # # except: # # # results = [{"index": i, "relevance_score": 0.0} for i in range(len(c_docs))][:top_k] # # # final_res = [] # # # context = "" # # # for i, item in enumerate(results): # # # orig_idx = c_idxs[item['index']] # # # row = self.df.iloc[orig_idx] # # # final_res.append({ # # # "score": item['relevance_score'], # # # "filename": row['filename'], # # # "content": row['content'] # # # }) # # # context += f"[文档{i+1}]: {row['content']}\n\n" # # # return final_res, context # # # @st.cache_resource # # # def load_engine(): # # # real_path = get_data_file_path() # # # return FullRetriever(real_path) # # # # ================= 3. UI 主程序 ================= # # # def main(): # # # st.markdown(""" # # #
# # #
🌌
# # #
# # #
COMSOL DARK EXPERT
# # #
# # # NEURAL SIMULATION ASSISTANT V4.1 Fixed # # #
# # #
# # #
# # # """, unsafe_allow_html=True) # # # retriever = load_engine() # # # with st.sidebar: # # # st.markdown("### ⚙️ 控制台") # # # top_k = st.slider("检索深度", 1, 10, 4) # # # temp = st.slider("发散度", 0.0, 1.0, 0.3) # # # st.markdown("---") # # # if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True): # # # st.session_state.messages = [] # # # st.session_state.current_refs = [] # # # st.rerun() # # # if "messages" not in st.session_state: st.session_state.messages = [] # # # if "current_refs" not in st.session_state: st.session_state.current_refs = [] # # # col_chat, col_evidence = st.columns([0.65, 0.35], gap="large") # # # # ------------------ 处理输入源 ------------------ # # # # 我们定义一个变量 user_input,不管它来自按钮还是输入框 # # # user_input = None # # # with col_chat: # # # # 1. 如果历史为空,显示快捷按钮 # # # if not st.session_state.messages: # # # st.markdown("##### 💡 初始化提问序列 (Starter Sequence)") # # # c1, c2, c3 = st.columns(3) # # # # 点击按钮直接赋值给 user_input # # # if c1.button("🌊 流固耦合接口设置"): # # # user_input = "怎么设置流固耦合接口?" # # # elif c2.button("⚡ 低频电磁场网格"): # # # user_input = "低频电磁场网格划分有哪些技巧?" # # # elif c3.button("📉 求解器不收敛"): # # # user_input = "求解器不收敛通常怎么解决?" # # # # 2. 渲染历史消息 # # # for msg in st.session_state.messages: # # # with st.chat_message(msg["role"]): # # # st.markdown(msg["content"]) # # # # 3. 
处理底部输入框 (如果有按钮输入,这里会被跳过,因为 user_input 已经有值了) # # # if not user_input: # # # user_input = st.chat_input("输入指令或物理参数问题...") # # # # ------------------ 统一处理消息追加 ------------------ # # # if user_input: # # # st.session_state.messages.append({"role": "user", "content": user_input}) # # # # 强制刷新以立即在 UI 上显示用户的提问(对于按钮点击尤为重要) # # # st.rerun() # # # # ------------------ 统一触发生成 (修复的核心) ------------------ # # # # 检查:如果有消息,且最后一条是 User 发的,说明需要 Assistant 回答 # # # if st.session_state.messages and st.session_state.messages[-1]["role"] == "user": # # # # 获取最后一条用户消息 # # # last_query = st.session_state.messages[-1]["content"] # # # with col_chat: # 确保在聊天栏显示 # # # with st.spinner("🔍 正在扫描向量空间..."): # # # refs, context = retriever.hybrid_search(last_query, top_k=top_k) # # # st.session_state.current_refs = refs # # # system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。 # # # 要求: # # # 1. 语气专业、客观,逻辑严密。 # # # 2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。 # # # 3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。 # # # 参考文档: # # # {context} # # # """ # # # with st.chat_message("assistant"): # # # resp_cont = st.empty() # # # full_resp = "" # # # client = OpenAI(base_url=API_BASE, api_key=API_KEY) # # # try: # # # stream = client.chat.completions.create( # # # model=GEN_MODEL_NAME, # # # messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:], # 除去当前的System # # # temperature=temp, # # # stream=True # # # ) # # # for chunk in stream: # # # txt = chunk.choices[0].delta.content # # # if txt: # # # full_resp += txt # # # resp_cont.markdown(full_resp + " ▌") # # # resp_cont.markdown(full_resp) # # # st.session_state.messages.append({"role": "assistant", "content": full_resp}) # # # except Exception as e: # # # st.error(f"Neural Generation Failed: {e}") # # # # ------------------ 渲染右侧证据栏 ------------------ # # # with col_evidence: # # # st.markdown("### 📚 神经记忆 (Evidence)") # # # if st.session_state.current_refs: # # # for i, ref in enumerate(st.session_state.current_refs): # # # score = 
ref['score'] # # # score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c" # # # with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)): # # # st.markdown(f""" # # #
# # # Relevance: # # # {score:.4f} # # #
# # # """, unsafe_allow_html=True) # # # st.code(ref['content'], language="text") # # # else: # # # st.info("等待输入指令以检索知识库...") # # # st.markdown(""" # # #
# # # Waiting for query signal...
# # # Index Status: Ready
# # # Awaiting Input # # #
# # # """, unsafe_allow_html=True) # # # if __name__ == "__main__": # # # main() # # import streamlit as st # # import pandas as pd # # import numpy as np # # import jieba # # import requests # # import os # # import time # # import json # # import re # # import random # # import subprocess # # from openai import OpenAI # # from rank_bm25 import BM25Okapi # # from sklearn.metrics.pairwise import cosine_similarity # # from typing import List, Dict, Tuple, Any # # # ================= 1. 全局配置与样式 ================= # # # API 配置 (从 HF 环境变量获取) # # API_BASE = "https://api.siliconflow.cn/v1" # # API_KEY = os.getenv("SILICONFLOW_API_KEY") # # # 模型名称配置 # # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" # # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" # # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" # # QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" # # SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" # # # 预置问题池 # # PRESET_QUESTIONS = [ # # "如何设置流固耦合接口?", # # "求解器不收敛怎么办?", # # "网格划分有哪些技巧?", # # "如何定义随时间变化的边界条件?", # # "计算结果如何导出数据?", # # "什么是完美匹配层 (PML)?", # # "低频电磁场仿真注意事项", # # "如何提高瞬态计算速度?", # # "参数化扫描如何设置?", # # "多物理场耦合的收敛性优化" # # ] # # # 数据文件配置 # # DATA_FILENAME = "comsol_embedded.parquet" # # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" # # # 页面配置 # # st.set_page_config( # # page_title="COMSOL RAG 策略控制台", # # page_icon="🎛️", # # layout="wide", # # initial_sidebar_state="expanded" # # ) # # # 自定义CSS样式 # # st.markdown(""" # # # # """, unsafe_allow_html=True) # # # ================= 2. 
数据下载工具 (HF 适配) ================= # # def download_with_curl(url, output_path): # # """使用 curl 下载文件,增加鲁棒性""" # # try: # # cmd = [ # # "curl", "-L", # # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", # # "-o", output_path, # # "--fail", # # url # # ] # # result = subprocess.run(cmd, capture_output=True, text=True) # # if result.returncode != 0: # # print(f"Curl stderr: {result.stderr}") # # return False # # return True # # except Exception as e: # # print(f"Curl download error: {e}") # # return False # # def get_data_file_path(): # # """获取数据文件路径,如果不存在则自动下载""" # # # 优先检查本地可能存在的路径 # # possible_paths = [ # # DATA_FILENAME, # # os.path.join("/app", DATA_FILENAME), # # os.path.join("processed_data", DATA_FILENAME), # # os.path.join(os.getcwd(), DATA_FILENAME) # # ] # # for path in possible_paths: # # if os.path.exists(path): # # return path # # # 如果都没找到,准备下载 # # # HF Spaces 通常在 /home/user/app 下运行,直接下载到当前目录 # # download_target = os.path.join(os.getcwd(), DATA_FILENAME) # # status_container = st.empty() # # status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)") # # # 尝试 Curl 下载 # # if download_with_curl(DATA_URL, download_target): # # status_container.empty() # # return download_target # # # 降级尝试 Requests 下载 # # try: # # headers = {'User-Agent': 'Mozilla/5.0'} # # r = requests.get(DATA_URL, headers=headers, stream=True) # # r.raise_for_status() # # with open(download_target, 'wb') as f: # # for chunk in r.iter_content(chunk_size=8192): # # f.write(chunk) # # status_container.empty() # # return download_target # # except Exception as e: # # st.error(f"❌ 数据下载失败。Error: {e}") # # st.stop() # # # ================= 3. 
核心 RAG 控制器 ================= # # class RAGController: # # """RAG系统控制器 - 实现策略矩阵""" # # def __init__(self): # # """初始化控制器""" # # if not API_KEY: # # st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") # # st.stop() # # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) # # self.df = None # # self.documents = [] # # self.embeddings = None # # self.bm25 = None # # self.filenames = [] # # self._load_data() # # def _load_data(self): # # """加载COMSOL文档数据""" # # real_path = get_data_file_path() # # try: # # # 加载数据 # # self.df = pd.read_parquet(real_path) # # self.documents = self.df['content'].tolist() # # self.filenames = self.df['filename'].tolist() # # # 加载向量嵌入 # # self.embeddings = np.stack(self.df['embedding'].values) # # # 初始化BM25 # # tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents] # # self.bm25 = BM25Okapi(tokenized_corpus) # # st.success(f"✅ 成功加载 {len(self.documents)} 条文档") # # except Exception as e: # # st.error(f"❌ 数据加载失败: {str(e)}") # # st.stop() # # def get_embedding(self, text: str) -> List[float]: # # """获取文本向量嵌入""" # # try: # # resp = self.client.embeddings.create( # # model=EMBEDDING_MODEL, # # input=[text] # # ) # # return resp.data[0].embedding # # except Exception as e: # # st.warning(f"向量获取失败: {e}") # # return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback # # def expand_query(self, query: str) -> Tuple[str, float]: # # """查询扩展 - 使用LLM优化查询""" # # prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。 # # 要求: # # 1. 补充COMSOL专业术语(物理场、模块、边界条件等) # # 2. 保持问题核心意图不变 # # 3. 
输出简洁,仅返回改写后的查询 # # 用户问题: {query} # # 专业查询:""" # # try: # # start_time = time.time() # # resp = self.client.chat.completions.create( # # model=QE_MODEL_NAME, # # messages=[{"role": "user", "content": prompt}], # # temperature=0.3 # # ) # # expanded = resp.choices[0].message.content.strip() # # elapsed = time.time() - start_time # # return expanded, elapsed # # except Exception as e: # # print(f"QE Error: {e}") # # return query, 0 # # def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: # # """向量检索""" # # q_vec = self.get_embedding(query) # # similarities = cosine_similarity([q_vec], self.embeddings)[0] # # top_indices = np.argsort(similarities)[-top_k:][::-1] # # return [(idx, similarities[idx]) for idx in top_indices] # # def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: # # """BM25关键词检索""" # # tokenized_query = jieba.lcut(query.lower()) # # scores = self.bm25.get_scores(tokenized_query) # # top_indices = np.argsort(scores)[-top_k:][::-1] # # return [(idx, scores[idx]) for idx in top_indices] # # def reciprocal_rank_fusion(self, vector_results: List[Tuple[int, float]], # # bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]: # # """RRF融合算法""" # # scores = {} # # for rank, (idx, score) in enumerate(vector_results): # # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) # # for rank, (idx, score) in enumerate(bm25_results): # # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) # # return scores # # def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]: # # """使用重排序模型""" # # if not documents: return [], 0 # # url = f"{API_BASE}/rerank" # # headers = { # # "Authorization": f"Bearer {API_KEY}", # # "Content-Type": "application/json" # # } # # # 截断文档内容以符合 Context Window # # docs_content = [doc["content"][:2048] for doc in documents] # # payload = { # # "model": RERANK_MODEL, # # "query": query, # # "documents": docs_content, # # 
"top_n": top_n # # } # # try: # # start_time = time.time() # # response = requests.post(url, headers=headers, json=payload, timeout=20) # # elapsed = time.time() - start_time # # if response.status_code == 200: # # results = response.json().get("results", []) # # reranked_docs = [] # # for result in results: # # original_doc = documents[result["index"]] # # original_doc["rerank_score"] = result["relevance_score"] # # original_doc["final_score"] = result["relevance_score"] # # reranked_docs.append(original_doc) # # return reranked_docs, elapsed # # else: # # print(f"Rerank API Error: {response.text}") # # return documents[:top_n], elapsed # # except Exception as e: # # print(f"Rerank Exception: {e}") # # return documents[:top_n], 0 # # def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]: # # """执行策略矩阵""" # # start_time = time.time() # # result = { # # 'original_query': query, # # 'final_query': query, # # 'documents': [], # # 'steps': [], # # 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0}, # # 'strategy_tags': [] # # } # # # 1. 查询扩展 # # if config['use_qe']: # # expanded_q, qe_time = self.expand_query(query) # # result['final_query'] = expanded_q # # result['metrics']['qe_time'] = qe_time # # result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**") # # result['strategy_tags'].append("QE") # # # 2. 
检索 # # retrieval_start = time.time() # # query_to_search = result['final_query'] # # if config['strategy'] == 'Vector': # # results = self.vector_search(query_to_search) # # result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选") # # result['strategy_tags'].append("Vector") # # elif config['strategy'] == 'BM25': # # results = self.bm25_search(query_to_search) # # result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选") # # result['strategy_tags'].append("BM25") # # elif config['strategy'] == 'Hybrid': # # vec_results = self.vector_search(query_to_search) # # bm25_results = self.bm25_search(query_to_search) # # fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results) # # results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) # # results = [(idx, score) for idx, score in results] # # result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选") # # result['strategy_tags'].extend(["Vector", "BM25"]) # # result['metrics']['retrieval_time'] = time.time() - retrieval_start # # # 3. 构建候选列表 # # recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k'] # # top_results = results[:recall_k] # # documents = [] # # for idx, score in top_results: # # documents.append({ # # 'content': self.documents[idx], # # 'filename': self.filenames[idx], # # 'retrieval_score': score, # # 'final_score': score, # # 'type': 'retrieval' # # }) # # # 4. 
重排序 # # if config['use_rerank']: # # reranked_docs, rerank_time = self.rerank_documents( # # result['final_query'], documents, config['top_k'] # # ) # # result['documents'] = reranked_docs # # result['metrics']['rerank_time'] = rerank_time # # result['steps'].append(f"⚖️ 重排序 ({rerank_time:.2f}s): 精选 Top-{config['top_k']}") # # result['strategy_tags'].append("Rerank") # # else: # # result['documents'] = documents[:config['top_k']] # # result['metrics']['total_time'] = time.time() - start_time # # result['steps'].append(f"⏱️ 总耗时: {result['metrics']['total_time']:.2f}s") # # return result # # def generate_suggestions(controller, query: str, answer: str) -> List[str]: # # """生成3个后续引导问题""" # # prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。 # # 用户问题:{query} # # 专家回答:{answer[:800]}... # # 要求: # # 1. 问题简短(15字以内)。 # # 2. 紧扣当前话题。 # # 3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。 # # 4. 不要包含任何 Markdown 标记。 # # """ # # try: # # resp = controller.client.chat.completions.create( # # model=SUGGEST_MODEL_NAME, # # messages=[{"role": "user", "content": prompt}], # # temperature=0.5 # # ) # # content = resp.choices[0].message.content.strip() # # match = re.search(r'\[.*\]', content, re.DOTALL) # # if match: # # sugs = json.loads(match.group()) # # return sugs[:3] # # return [] # # except Exception as e: # # print(f"Suggestion Error: {e}") # # return [] # # def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str: # # """流式生成回答""" # # if not documents: # # return "抱歉,没有找到相关的文档来回答您的问题。" # # context_text = "\n\n".join([f"[文档{i+1}] {doc['content'][:800]}..." for i, doc in enumerate(documents)]) # # system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。 # # 要求: # # 1. 语气专业,使用COMSOL术语。 # # 2. 物理公式使用 LaTeX(如 $E=mc^2$)。 # # 3. 
如果文档信息不足,请如实告知,不要编造。 # # 【参考文档】: # # {context_text} # # """ # # # 构建历史记录 # # keep_messages = max_rounds * 2 # # history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else [] # # api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}] # # try: # # response = controller.client.chat.completions.create( # # model=GEN_MODEL_NAME, # # messages=api_messages, # # temperature=0.3, # # stream=True # # ) # # answer = "" # # placeholder = st.empty() # # for chunk in response: # # if chunk.choices[0].delta.content: # # answer += chunk.choices[0].delta.content # # placeholder.markdown(answer + "▌") # # placeholder.markdown(answer) # # return answer # # except Exception as e: # # return f"生成遇到错误: {e}" # # # ================= 4. 初始化与组件渲染 ================= # # @st.cache_resource(show_spinner="🚀 正在初始化 RAG 引擎...") # # def initialize_controller(): # # return RAGController() # # def render_strategy_matrix(): # # st.markdown('

🎯 策略矩阵配置

', unsafe_allow_html=True) # # st.markdown("""
# #

⚙️ 参数调节:控制检索片段数量和模型记忆深度。

# #
""", unsafe_allow_html=True) # # col1, col2 = st.columns(2) # # with col1: # # use_qe = st.toggle("🔄 查询扩展 (QE)", value=False) # # use_rerank = st.toggle("⚖️ 深度重排序 (Rerank)", value=True) # # max_history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数") # # with col2: # # strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2) # # top_k = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量") # # return {'use_qe': use_qe, 'strategy': strategy, 'use_rerank': use_rerank, 'top_k': top_k, 'max_history_rounds': max_history_rounds} # # def render_metrics(metrics): # # st.markdown("### 📊 性能指标") # # cols = st.columns(4) # # with cols[0]: st.metric("查询扩展", f"{metrics['qe_time']:.2f}s" if metrics['qe_time']>0 else "N/A", delta="QE" if metrics['qe_time']>0 else None) # # with cols[1]: st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s") # # with cols[2]: st.metric("重排序", f"{metrics['rerank_time']:.2f}s" if metrics['rerank_time']>0 else "N/A", delta="Rerank" if metrics['rerank_time']>0 else None) # # with cols[3]: st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡") # # def render_documents(documents, strategy_tags): # # st.markdown("### 📄 检索结果") # # if not documents: # # st.warning("未找到相关文档") # # return # # tags_html = "".join([f'{t}' for t in strategy_tags]) # Simplified for brevity, use full logic if copying # # # Manual mapping for safety # # html_tags = "" # # for tag in strategy_tags: # # cls = "tag-vec" if tag=="Vector" else "tag-bm25" if tag=="BM25" else "tag-qe" if tag=="QE" else "tag-rerank" # # html_tags += f'{tag}' # # st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True) # # for i, doc in enumerate(documents): # # score = doc.get('final_score', 0) # # with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {doc['filename'][:40]}...", expanded=i<2): # # st.code(doc['content'], language="markdown") # # # ================= 5. 
主程序 ================= # # def main(): # # # 状态初始化 # # if "messages" not in st.session_state: st.session_state.messages = [] # # if "last_result" not in st.session_state: st.session_state.last_result = None # # if "suggestions" not in st.session_state: st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3) # # if "prompt_trigger" not in st.session_state: st.session_state.prompt_trigger = None # # # 加载控制器 # # controller = initialize_controller() # # # 侧边栏 # # with st.sidebar: # # config = render_strategy_matrix() # # st.markdown("---") # # if st.button("🗑️ 清空当前对话", use_container_width=True): # # st.session_state.messages = [] # # st.session_state.last_result = None # # st.rerun() # # # 主界面布局 # # main_col, debug_col = st.columns([0.6, 0.4], gap="large") # # with main_col: # # st.markdown("### 💬 智能仿真问答") # # # 1. 历史消息 # # for msg in st.session_state.messages: # # with st.chat_message(msg["role"]): # # st.markdown(msg["content"]) # # # 2. 建议区 # # if st.session_state.suggestions: # # st.markdown("##### 💡 您可能想问:") # # cols = st.columns(3) # # for i, sug in enumerate(st.session_state.suggestions): # # if cols[i].button(sug, use_container_width=True, key=f"sug_{i}"): # # st.session_state.prompt_trigger = sug # # st.rerun() # # # 3. 输入处理 # # user_input = None # # if st.session_state.prompt_trigger: # # user_input = st.session_state.prompt_trigger # # st.session_state.prompt_trigger = None # # else: # # user_input = st.chat_input("请输入您关于 COMSOL 的问题...") # # # 4. 
执行逻辑 # # if user_input: # # st.session_state.messages.append({"role": "user", "content": user_input}) # # with st.chat_message("user"): st.markdown(user_input) # # # 检索 # # with st.spinner("🔍 检索知识库中..."): # # result = controller.execute_strategy(user_input, config) # # st.session_state.last_result = result # # # 生成 # # with st.chat_message("assistant"): # # answer = generate_answer( # # controller, user_input, result['documents'], # # st.session_state.messages, config['max_history_rounds'] # # ) # # st.session_state.messages.append({"role": "assistant", "content": answer}) # # # 生成新建议 # # new_sugs = generate_suggestions(controller, user_input, answer) # # st.session_state.suggestions = new_sugs if new_sugs else random.sample(PRESET_QUESTIONS, 3) # # st.rerun() # # with debug_col: # # st.markdown("### 🔍 系统调试视图") # # if st.session_state.last_result: # # res = st.session_state.last_result # # st.info(f"当前查询: {res.get('original_query', 'N/A')}") # # render_metrics(res['metrics']) # # with st.expander("🔧 检索链路详情", expanded=True): # # for step in res['steps']: st.markdown(f"- {step}") # # render_documents(res['documents'], res['strategy_tags']) # # else: # # st.info("等待交互...") # # if __name__ == "__main__": # # main() # import streamlit as st # import pandas as pd # import numpy as np # import jieba # import requests # import os # import time # import json # import re # import random # import subprocess # from openai import OpenAI # from rank_bm25 import BM25Okapi # from sklearn.metrics.pairwise import cosine_similarity # from typing import List, Dict, Tuple, Any # # ================= 1. 
全局配置与样式 ================= # # API 配置 (从 HF 环境变量获取) # API_BASE = "https://api.siliconflow.cn/v1" # API_KEY = os.getenv("SILICONFLOW_API_KEY") # # 模型名称配置 # EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B" # RERANK_MODEL = "Qwen/Qwen3-Reranker-4B" # GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2" # # QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" # # SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct" # QE_MODEL_NAME = "MiniMaxAI/MiniMax-M2" # SUGGEST_MODEL_NAME = "MiniMaxAI/MiniMax-M2" # # 预置问题池 # PRESET_QUESTIONS = [ # "如何设置流固耦合接口?", # "求解器不收敛怎么办?", # "网格划分有哪些技巧?", # "如何定义随时间变化的边界条件?", # "计算结果如何导出数据?", # "什么是完美匹配层 (PML)?", # "低频电磁场仿真注意事项", # "如何提高瞬态计算速度?", # "参数化扫描如何设置?", # "多物理场耦合的收敛性优化" # ] # # 数据文件配置 # DATA_FILENAME = "comsol_embedded.parquet" # DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet" # # 页面配置 # st.set_page_config( # page_title="COMSOL RAG 策略控制台", # page_icon="🎛️", # layout="wide", # initial_sidebar_state="expanded" # ) # # 自定义CSS样式 (适配深色/浅色模式) # st.markdown(""" # # """, unsafe_allow_html=True) # # ================= 2. 
数据下载工具 (HF 适配) ================= # def download_with_curl(url, output_path): # """使用 curl 下载文件,增加鲁棒性""" # try: # cmd = [ # "curl", "-L", # "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", # "-o", output_path, # "--fail", # url # ] # result = subprocess.run(cmd, capture_output=True, text=True) # if result.returncode != 0: # print(f"Curl stderr: {result.stderr}") # return False # return True # except Exception as e: # print(f"Curl download error: {e}") # return False # def get_data_file_path(): # """获取数据文件路径,如果不存在则自动下载""" # # 优先检查本地可能存在的路径 # possible_paths = [ # DATA_FILENAME, # os.path.join("/app", DATA_FILENAME), # os.path.join("processed_data", DATA_FILENAME), # os.path.join(os.getcwd(), DATA_FILENAME) # ] # for path in possible_paths: # if os.path.exists(path): # return path # # 如果都没找到,准备下载 # # HF Spaces 通常在 /home/user/app 下运行,直接下载到当前目录 # download_target = os.path.join(os.getcwd(), DATA_FILENAME) # status_container = st.empty() # status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)") # # 尝试 Curl 下载 # if download_with_curl(DATA_URL, download_target): # status_container.empty() # return download_target # # 降级尝试 Requests 下载 # try: # headers = {'User-Agent': 'Mozilla/5.0'} # r = requests.get(DATA_URL, headers=headers, stream=True) # r.raise_for_status() # with open(download_target, 'wb') as f: # for chunk in r.iter_content(chunk_size=8192): # f.write(chunk) # status_container.empty() # return download_target # except Exception as e: # st.error(f"❌ 数据下载失败。Error: {e}") # st.stop() # # ================= 3. 
核心 RAG 控制器 ================= # class RAGController: # """RAG系统控制器 - 实现策略矩阵""" # def __init__(self): # """初始化控制器""" # if not API_KEY: # st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。") # st.stop() # self.client = OpenAI(base_url=API_BASE, api_key=API_KEY) # self.df = None # self.documents = [] # self.embeddings = None # self.bm25 = None # self.filenames = [] # self._load_data() # def _load_data(self): # """加载COMSOL文档数据""" # real_path = get_data_file_path() # try: # # 加载数据 # self.df = pd.read_parquet(real_path) # self.documents = self.df['content'].tolist() # self.filenames = self.df['filename'].tolist() # # 加载向量嵌入 # self.embeddings = np.stack(self.df['embedding'].values) # # 初始化BM25 # tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents] # self.bm25 = BM25Okapi(tokenized_corpus) # st.success(f"✅ 成功加载 {len(self.documents)} 条文档") # except Exception as e: # st.error(f"❌ 数据加载失败: {str(e)}") # st.stop() # def get_embedding(self, text: str) -> List[float]: # """获取文本向量嵌入""" # try: # resp = self.client.embeddings.create( # model=EMBEDDING_MODEL, # input=[text] # ) # return resp.data[0].embedding # except Exception as e: # st.warning(f"向量获取失败: {e}") # return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback # def expand_query(self, query: str) -> Tuple[str, float]: # """查询扩展 - 使用LLM优化查询""" # prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。 # 要求: # 1. 补充COMSOL专业术语(物理场、模块、边界条件等) # 2. 保持问题核心意图不变 # 3. 
输出简洁,仅返回改写后的查询 # 用户问题: {query} # 专业查询:""" # try: # start_time = time.time() # resp = self.client.chat.completions.create( # model=QE_MODEL_NAME, # messages=[{"role": "user", "content": prompt}], # temperature=0.3 # ) # expanded = resp.choices[0].message.content.strip() # elapsed = time.time() - start_time # return expanded, elapsed # except Exception as e: # print(f"QE Error: {e}") # return query, 0 # def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: # """向量检索""" # q_vec = self.get_embedding(query) # similarities = cosine_similarity([q_vec], self.embeddings)[0] # top_indices = np.argsort(similarities)[-top_k:][::-1] # return [(idx, similarities[idx]) for idx in top_indices] # def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: # """BM25关键词检索""" # tokenized_query = jieba.lcut(query.lower()) # scores = self.bm25.get_scores(tokenized_query) # top_indices = np.argsort(scores)[-top_k:][::-1] # return [(idx, scores[idx]) for idx in top_indices] # def reciprocal_rank_fusion(self, vector_results: List[Tuple[int, float]], # bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]: # """RRF融合算法""" # scores = {} # for rank, (idx, score) in enumerate(vector_results): # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) # for rank, (idx, score) in enumerate(bm25_results): # scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) # return scores # def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]: # """使用重排序模型""" # if not documents: return [], 0 # url = f"{API_BASE}/rerank" # headers = { # "Authorization": f"Bearer {API_KEY}", # "Content-Type": "application/json" # } # # 截断文档内容以符合 Context Window # docs_content = [doc["content"][:2048] for doc in documents] # payload = { # "model": RERANK_MODEL, # "query": query, # "documents": docs_content, # "top_n": top_n # } # try: # start_time = time.time() # response = requests.post(url, headers=headers, 
json=payload, timeout=20) # elapsed = time.time() - start_time # if response.status_code == 200: # results = response.json().get("results", []) # reranked_docs = [] # for result in results: # original_doc = documents[result["index"]] # original_doc["rerank_score"] = result["relevance_score"] # original_doc["final_score"] = result["relevance_score"] # reranked_docs.append(original_doc) # return reranked_docs, elapsed # else: # print(f"Rerank API Error: {response.text}") # return documents[:top_n], elapsed # except Exception as e: # print(f"Rerank Exception: {e}") # return documents[:top_n], 0 # def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]: # """执行策略矩阵""" # start_time = time.time() # result = { # 'original_query': query, # 'final_query': query, # 'documents': [], # 'steps': [], # 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0}, # 'strategy_tags': [] # } # # 1. 查询扩展 # if config['use_qe']: # expanded_q, qe_time = self.expand_query(query) # result['final_query'] = expanded_q # result['metrics']['qe_time'] = qe_time # result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**") # result['strategy_tags'].append("QE") # # 2. 
检索 # retrieval_start = time.time() # query_to_search = result['final_query'] # if config['strategy'] == 'Vector': # results = self.vector_search(query_to_search) # result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选") # result['strategy_tags'].append("Vector") # elif config['strategy'] == 'BM25': # results = self.bm25_search(query_to_search) # result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选") # result['strategy_tags'].append("BM25") # elif config['strategy'] == 'Hybrid': # vec_results = self.vector_search(query_to_search) # bm25_results = self.bm25_search(query_to_search) # fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results) # results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) # results = [(idx, score) for idx, score in results] # result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选") # result['strategy_tags'].extend(["Vector", "BM25"]) # result['metrics']['retrieval_time'] = time.time() - retrieval_start # # 3. 构建候选列表 # recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k'] # top_results = results[:recall_k] # documents = [] # for idx, score in top_results: # documents.append({ # 'content': self.documents[idx], # 'filename': self.filenames[idx], # 'retrieval_score': score, # 'final_score': score, # 'type': 'retrieval' # }) # # 4. 
重排序 # if config['use_rerank']: # reranked_docs, rerank_time = self.rerank_documents( # result['final_query'], documents, config['top_k'] # ) # result['documents'] = reranked_docs # result['metrics']['rerank_time'] = rerank_time # result['steps'].append(f"⚖️ 重排序 ({rerank_time:.2f}s): 精选 Top-{config['top_k']}") # result['strategy_tags'].append("Rerank") # else: # result['documents'] = documents[:config['top_k']] # result['metrics']['total_time'] = time.time() - start_time # result['steps'].append(f"⏱️ 总耗时: {result['metrics']['total_time']:.2f}s") # return result # def generate_suggestions(controller, query: str, answer: str) -> List[str]: # """生成3个后续引导问题""" # prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。 # 用户问题:{query} # 专家回答:{answer[:800]}... # 要求: # 1. 问题简短(15字以内)。 # 2. 紧扣当前话题。 # 3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。 # 4. 不要包含任何 Markdown 标记。 # """ # try: # resp = controller.client.chat.completions.create( # model=SUGGEST_MODEL_NAME, # messages=[{"role": "user", "content": prompt}], # temperature=0.5 # ) # content = resp.choices[0].message.content.strip() # match = re.search(r'\[.*\]', content, re.DOTALL) # if match: # sugs = json.loads(match.group()) # return sugs[:3] # return [] # except Exception as e: # print(f"Suggestion Error: {e}") # return [] # def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str: # """流式生成回答""" # if not documents: # return "抱歉,没有找到相关的文档来回答您的问题。" # context_text = "\n\n".join([f"[文档{i+1}] {doc['content'][:800]}..." for i, doc in enumerate(documents)]) # system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。 # 要求: # 1. 语气专业,使用COMSOL术语。 # 2. 物理公式使用 LaTeX(如 $E=mc^2$)。 # 3. 
如果文档信息不足,请如实告知,不要编造。 # 【参考文档】: # {context_text} # """ # # 构建历史记录 # keep_messages = max_rounds * 2 # history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else [] # api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}] # try: # response = controller.client.chat.completions.create( # model=GEN_MODEL_NAME, # messages=api_messages, # temperature=0.3, # stream=True # ) # answer = "" # placeholder = st.empty() # for chunk in response: # if chunk.choices[0].delta.content: # answer += chunk.choices[0].delta.content # placeholder.markdown(answer + "▌") # placeholder.markdown(answer) # return answer # except Exception as e: # return f"生成遇到错误: {e}" # # ================= 4. 初始化与组件渲染 ================= # @st.cache_resource(show_spinner="🚀 正在初始化 RAG 引擎...") # def initialize_controller(): # return RAGController() # def render_strategy_matrix(): # st.markdown('

🎯 策略矩阵配置

', unsafe_allow_html=True) # st.markdown("""
#

⚙️ 参数调节:控制检索片段数量和模型记忆深度。

#
""", unsafe_allow_html=True) # col1, col2 = st.columns(2) # with col1: # use_qe = st.toggle("🔄 查询扩展 (QE)", value=False) # use_rerank = st.toggle("⚖️ 深度重排序 (Rerank)", value=True) # max_history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数") # with col2: # strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2) # top_k = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量") # return {'use_qe': use_qe, 'strategy': strategy, 'use_rerank': use_rerank, 'top_k': top_k, 'max_history_rounds': max_history_rounds} # def render_metrics(metrics): # st.markdown("### 📊 性能指标") # cols = st.columns(4) # with cols[0]: st.metric("查询扩展", f"{metrics['qe_time']:.2f}s" if metrics['qe_time']>0 else "N/A", delta="QE" if metrics['qe_time']>0 else None) # with cols[1]: st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s") # with cols[2]: st.metric("重排序", f"{metrics['rerank_time']:.2f}s" if metrics['rerank_time']>0 else "N/A", delta="Rerank" if metrics['rerank_time']>0 else None) # with cols[3]: st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡") # def render_documents(documents, strategy_tags): # st.markdown("### 📄 检索结果") # if not documents: # st.warning("未找到相关文档") # return # tags_html = "".join([f'{t}' for t in strategy_tags]) # Simplified for brevity, use full logic if copying # # Manual mapping for safety # html_tags = "" # for tag in strategy_tags: # cls = "tag-vec" if tag=="Vector" else "tag-bm25" if tag=="BM25" else "tag-qe" if tag=="QE" else "tag-rerank" # html_tags += f'{tag}' # st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True) # for i, doc in enumerate(documents): # score = doc.get('final_score', 0) # with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {doc['filename'][:40]}...", expanded=i<2): # st.code(doc['content'], language="markdown") # # ================= 5. 
主程序 ================= # def main(): # # 状态初始化 # if "messages" not in st.session_state: st.session_state.messages = [] # if "last_result" not in st.session_state: st.session_state.last_result = None # if "suggestions" not in st.session_state: st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3) # if "prompt_trigger" not in st.session_state: st.session_state.prompt_trigger = None # # 加载控制器 # controller = initialize_controller() # # 侧边栏 # with st.sidebar: # config = render_strategy_matrix() # st.markdown("---") # if st.button("🗑️ 清空当前对话", use_container_width=True): # st.session_state.messages = [] # st.session_state.last_result = None # st.rerun() # # 主界面布局 # main_col, debug_col = st.columns([0.6, 0.4], gap="large") # with main_col: # st.markdown("### 💬 智能仿真问答") # # 1. 历史消息 # for msg in st.session_state.messages: # with st.chat_message(msg["role"]): # st.markdown(msg["content"]) # # 2. 建议区 # if st.session_state.suggestions: # st.markdown("##### 💡 您可能想问:") # cols = st.columns(3) # for i, sug in enumerate(st.session_state.suggestions): # if cols[i].button(sug, use_container_width=True, key=f"sug_{i}"): # st.session_state.prompt_trigger = sug # st.rerun() # # 3. 输入处理 # user_input = None # if st.session_state.prompt_trigger: # user_input = st.session_state.prompt_trigger # st.session_state.prompt_trigger = None # else: # user_input = st.chat_input("请输入您关于 COMSOL 的问题...") # # 4. 
import streamlit as st
import pandas as pd
import numpy as np
import jieba
import requests
import os
import time
import json
import re
import random
import subprocess
import logging
import psutil
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Tuple, Any

# ================= 0. Logging & memory monitoring =================
# Log output is visible in the HF Space "Logs" tab.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def log_memory():
    """Log the current RSS memory usage of this process.

    Returns:
        float: resident set size in MB (also written to the logger).
    """
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    res_mem = mem_info.rss / (1024 * 1024)
    logger.info(f"💾 Current Memory Usage: {res_mem:.2f} MB")
    return res_mem
# ================= 1. Global configuration & styles =================
# API configuration (read from the HF environment variables).
API_BASE = "https://api.siliconflow.cn/v1"
API_KEY = os.getenv("SILICONFLOW_API_KEY")

# Model name configuration.
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
QE_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"
SUGGEST_MODEL_NAME = "Qwen/Qwen3-Next-80B-A3B-Instruct"

# Preset question pool used to seed the suggestion buttons.
PRESET_QUESTIONS = [
    "如何设置流固耦合接口?",
    "求解器不收敛怎么办?",
    "网格划分有哪些技巧?",
    "如何定义随时间变化的边界条件?",
    "计算结果如何导出数据?",
    "什么是完美匹配层 (PML)?",
    "低频电磁场仿真注意事项",
    "如何提高瞬态计算速度?",
    "参数化扫描如何设置?",
    "多物理场耦合的收敛性优化"
]

# Data file configuration.
DATA_FILENAME = "comsol_embedded.parquet"
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"

# Page configuration.
st.set_page_config(
    page_title="COMSOL RAG 策略控制台",
    page_icon="🎛️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS (dark/light mode aware).
# NOTE(review): the original CSS payload was lost in the source mangling;
# the call is preserved as found — restore the stylesheet from VCS history.
st.markdown(""" """, unsafe_allow_html=True)


# ================= 2. Data download utilities (HF adapted) =================
def download_with_curl(url, output_path):
    """Download `url` to `output_path` using curl with a browser User-Agent.

    Returns:
        bool: True on success, False on any failure (failure is logged,
        never raised, so callers can fall back to `requests`).
    """
    try:
        cmd = [
            "curl", "-L",
            "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
            "-o", output_path,
            "--fail",
            url
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.warning(f"Curl stderr: {result.stderr}")
            return False
        return True
    except Exception as e:
        logger.error(f"Curl download error: {e}")
        return False


def get_data_file_path():
    """Locate the parquet data file, downloading it if absent.

    Checks a few likely local paths first; otherwise downloads into the
    current working directory (HF Spaces usually run in /home/user/app).
    Calls `st.stop()` if every download attempt fails.
    """
    possible_paths = [
        DATA_FILENAME,
        os.path.join("/app", DATA_FILENAME),
        os.path.join("processed_data", DATA_FILENAME),
        os.path.join(os.getcwd(), DATA_FILENAME)
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    download_target = os.path.join(os.getcwd(), DATA_FILENAME)
    status_container = st.empty()
    status_container.info("📡 正在接入神经元网络... (下载核心数据中,首次运行可能需要几十秒)")

    # Primary path: curl (more robust against simple bot filtering).
    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target

    # Fallback: streamed requests download.
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # FIX: a missing timeout could hang the whole app forever on a
        # stalled connection; 60s covers a slow first-time fetch.
        r = requests.get(DATA_URL, headers=headers, stream=True, timeout=60)
        r.raise_for_status()
        with open(download_target, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        status_container.empty()
        return download_target
    except Exception as e:
        st.error(f"❌ 数据下载失败。Error: {e}")
        st.stop()
# ================= 3. Core RAG controller =================
class RAGController:
    """RAG system controller — implements the strategy matrix."""

    def __init__(self):
        """Validate the API key, create the OpenAI client and load data.

        Calls `st.stop()` when no API key is configured.
        """
        if not API_KEY:
            st.error("⚠️ 未检测到 API Key。请在 Space Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
            st.stop()
        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
        self.df = None          # raw parquet frame (embedding column dropped after stacking)
        self.documents = []     # chunk texts, aligned with self.embeddings rows
        self.embeddings = None  # np.ndarray (n_docs, dim)
        self.bm25 = None        # BM25Okapi index over jieba-tokenized documents
        self.filenames = []     # source filename per chunk
        self._load_data()

    def _load_data(self):
        """Load COMSOL document data, build the embedding matrix and BM25 index.

        Memory usage is logged at every stage; any failure stops the app.
        """
        real_path = get_data_file_path()
        try:
            logger.info("🚀 Initializing RAG Controller...")
            log_memory()

            logger.info(f"📂 Loading parquet from: {real_path}")
            self.df = pd.read_parquet(real_path)
            self.documents = self.df['content'].tolist()
            self.filenames = self.df['filename'].tolist()
            logger.info(f"✅ Dataframe loaded. Shape: {self.df.shape}")
            log_memory()

            logger.info("🧠 Stacking embeddings matrix...")
            self.embeddings = np.stack(self.df['embedding'].values)
            logger.info(f"✅ Embeddings stacked. Shape: {self.embeddings.shape}")
            # FIX: the DataFrame still held a second full copy of every
            # embedding vector; drop it so only the stacked matrix remains.
            self.df.drop(columns=['embedding'], inplace=True)
            log_memory()

            logger.info("🔍 Initializing BM25 index (Tokenizing)...")
            tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents]
            self.bm25 = BM25Okapi(tokenized_corpus)
            logger.info("✅ BM25 Index ready.")
            log_memory()

            st.success(f"✅ 成功加载 {len(self.documents)} 条文档")
        except Exception as e:
            logger.error(f"❌ Critical error during data load: {str(e)}", exc_info=True)
            st.error(f"❌ 数据加载失败: {str(e)}")
            st.stop()
Shape: {self.embeddings.shape}") log_memory() # 初始化BM25 logger.info("🔍 Initializing BM25 index (Tokenizing)...") tokenized_corpus = [jieba.lcut(str(doc).lower()) for doc in self.documents] self.bm25 = BM25Okapi(tokenized_corpus) logger.info("✅ BM25 Index ready.") log_memory() st.success(f"✅ 成功加载 {len(self.documents)} 条文档") except Exception as e: logger.error(f"❌ Critical error during data load: {str(e)}", exc_info=True) st.error(f"❌ 数据加载失败: {str(e)}") st.stop() def get_embedding(self, text: str) -> List[float]: """获取文本向量嵌入""" try: resp = self.client.embeddings.create( model=EMBEDDING_MODEL, input=[text] ) return resp.data[0].embedding except Exception as e: st.warning(f"向量获取失败: {e}") return [0.0] * 2560 # Qwen3-Embedding-4B dimension fallback def expand_query(self, query: str) -> Tuple[str, float]: """查询扩展 - 使用LLM优化查询""" prompt = f"""你是COMSOL仿真专家。请将用户的口语化问题改写为专业的检索查询。 要求: 1. 补充COMSOL专业术语(物理场、模块、边界条件等) 2. 保持问题核心意图不变 3. 输出简洁,仅返回改写后的查询 用户问题: {query} 专业查询:""" try: start_time = time.time() resp = self.client.chat.completions.create( model=QE_MODEL_NAME, messages=[{"role": "user", "content": prompt}], temperature=0.3 ) expanded = resp.choices[0].message.content.strip() elapsed = time.time() - start_time logger.info(f"🔧 QE completed in {elapsed:.2f}s") return expanded, elapsed except Exception as e: logger.error(f"❌ QE Error: {e}") return query, 0 def vector_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: """向量检索""" q_vec = self.get_embedding(query) similarities = cosine_similarity([q_vec], self.embeddings)[0] top_indices = np.argsort(similarities)[-top_k:][::-1] return [(idx, similarities[idx]) for idx in top_indices] def bm25_search(self, query: str, top_k: int = 100) -> List[Tuple[int, float]]: """BM25关键词检索""" tokenized_query = jieba.lcut(query.lower()) scores = self.bm25.get_scores(tokenized_query) top_indices = np.argsort(scores)[-top_k:][::-1] return [(idx, scores[idx]) for idx in top_indices] def reciprocal_rank_fusion(self, vector_results: 
List[Tuple[int, float]], bm25_results: List[Tuple[int, float]], k: int = 60) -> Dict[int, float]: """RRF融合算法""" scores = {} for rank, (idx, score) in enumerate(vector_results): scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) for rank, (idx, score) in enumerate(bm25_results): scores[idx] = scores.get(idx, 0) + 1.0 / (k + rank + 1) return scores def rerank_documents(self, query: str, documents: List[Dict], top_n: int) -> Tuple[List[Dict], float]: """使用重排序模型""" if not documents: return [], 0 url = f"{API_BASE}/rerank" headers = { "Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json" } # 截断文档内容以符合 Context Window docs_content = [doc["content"][:2048] for doc in documents] payload = { "model": RERANK_MODEL, "query": query, "documents": docs_content, "top_n": top_n } try: start_time = time.time() # 设置 timeout 为 15 秒,防止长时间挂起导致 WebSocket 断开 response = requests.post(url, headers=headers, json=payload, timeout=15) elapsed = time.time() - start_time if response.status_code == 200: results = response.json().get("results", []) reranked_docs = [] for result in results: original_doc = documents[result["index"]] original_doc["rerank_score"] = result["relevance_score"] original_doc["final_score"] = result["relevance_score"] reranked_docs.append(original_doc) return reranked_docs, elapsed else: logger.warning(f"Rerank API Error: {response.text}") return documents[:top_n], elapsed except requests.exceptions.Timeout: logger.warning("⚠️ Rerank API timed out, falling back to original order.") return documents[:top_n], 0 except Exception as e: logger.error(f"❌ Rerank error: {e}") return documents[:top_n], 0 def execute_strategy(self, query: str, config: Dict[str, Any]) -> Dict[str, Any]: """执行策略矩阵""" start_time = time.time() result = { 'original_query': query, 'final_query': query, 'documents': [], 'steps': [], 'metrics': {'qe_time': 0, 'retrieval_time': 0, 'rerank_time': 0, 'total_time': 0}, 'strategy_tags': [] } # 1. 
查询扩展 if config['use_qe']: expanded_q, qe_time = self.expand_query(query) result['final_query'] = expanded_q result['metrics']['qe_time'] = qe_time result['steps'].append(f"🧠 查询扩展 ({qe_time:.2f}s): {query} → **{expanded_q}**") result['strategy_tags'].append("QE") # 2. 检索 retrieval_start = time.time() query_to_search = result['final_query'] if config['strategy'] == 'Vector': results = self.vector_search(query_to_search) result['steps'].append(f"🔍 向量检索: 找到 {len(results)} 个候选") result['strategy_tags'].append("Vector") elif config['strategy'] == 'BM25': results = self.bm25_search(query_to_search) result['steps'].append(f"🔍 BM25检索: 找到 {len(results)} 个候选") result['strategy_tags'].append("BM25") elif config['strategy'] == 'Hybrid': vec_results = self.vector_search(query_to_search) bm25_results = self.bm25_search(query_to_search) fused_scores = self.reciprocal_rank_fusion(vec_results, bm25_results) results = sorted(fused_scores.items(), key=lambda x: x[1], reverse=True) results = [(idx, score) for idx, score in results] result['steps'].append(f"🔍 混合检索: Vector + BM25 → {len(results)} 个融合候选") result['strategy_tags'].extend(["Vector", "BM25"]) result['metrics']['retrieval_time'] = time.time() - retrieval_start # 3. 构建候选列表 recall_k = config['top_k'] * 3 if config['use_rerank'] else config['top_k'] top_results = results[:recall_k] documents = [] for idx, score in top_results: documents.append({ 'content': self.documents[idx], 'filename': self.filenames[idx], 'retrieval_score': score, 'final_score': score, 'type': 'retrieval' }) # 4. 
def generate_suggestions(controller, query: str, answer: str) -> List[str]:
    """Ask the suggestion model for up to three follow-up questions.

    Returns a list of question strings; an empty list on any failure or
    when no JSON array can be found in the model output.
    """
    prompt = f"""基于以下技术问答,预测用户可能感兴趣的3个后续COMSOL专业问题。
用户问题:{query}
专家回答:{answer[:800]}...
要求:
1. 问题简短(15字以内)。
2. 紧扣当前话题。
3. 严格输出 JSON 字符串数组格式,例如:["问题1", "问题2", "问题3"]。
4. 不要包含任何 Markdown 标记。
"""
    try:
        response = controller.client.chat.completions.create(
            model=SUGGEST_MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.5
        )
        raw = response.choices[0].message.content.strip()
        # Tolerate chatter around the array: grab the outermost [...] span.
        bracketed = re.search(r'\[.*\]', raw, re.DOTALL)
        if bracketed is None:
            return []
        return json.loads(bracketed.group())[:3]
    except Exception as e:
        logger.error(f"Suggestion Error: {e}")
        return []
def generate_answer(controller, query: str, documents: List[Dict], history: List[Dict], max_rounds: int) -> str:
    """Stream the final answer into the Streamlit chat area.

    Args:
        controller: RAGController providing the OpenAI client.
        query: the user's question for this turn.
        documents: retrieved (possibly reranked) context chunks.
        history: full chat history INCLUDING the just-appended user turn
            (the last entry is sliced off before sending).
        max_rounds: number of past conversation rounds to include.

    Returns:
        The complete answer text (or an error/fallback message).
    """
    if not documents:
        return "抱歉,没有找到相关的文档来回答您的问题。"

    context_text = "\n\n".join([f"[文档{i+1}] {doc['content'][:800]}..." for i, doc in enumerate(documents)])
    system_prompt = f"""你是一个COMSOL Multiphysics仿真专家。请基于提供的文档回答用户问题。
要求:
1. 语气专业,使用COMSOL术语。
2. 物理公式使用 LaTeX(如 $E=mc^2$)。
3. 如果文档信息不足,请如实告知,不要编造。
【参考文档】:
{context_text}
"""
    # Build trimmed history: one round = user + assistant message.
    keep_messages = max_rounds * 2
    history_to_send = history[:-1][-keep_messages:] if keep_messages > 0 else []
    api_messages = [{"role": "system", "content": system_prompt}] + history_to_send + [{"role": "user", "content": query}]

    try:
        response = controller.client.chat.completions.create(
            model=GEN_MODEL_NAME,
            messages=api_messages,
            temperature=0.3,
            stream=True
        )
        answer = ""
        placeholder = st.empty()
        for chunk in response:
            # FIX: some OpenAI-compatible backends emit keep-alive/usage
            # chunks with an empty `choices` list; indexing [0] blindly
            # raised IndexError and killed the stream mid-answer.
            if chunk.choices and chunk.choices[0].delta.content:
                answer += chunk.choices[0].delta.content
                placeholder.markdown(answer + "▌")
        placeholder.markdown(answer)
        return answer
    except Exception as e:
        return f"生成遇到错误: {e}"


# ================= 4. Initialization & component rendering =================
@st.cache_resource(show_spinner="🚀 正在初始化 RAG 引擎...")
def initialize_controller():
    """Build the RAGController once per process (cached by Streamlit)."""
    return RAGController()
def render_strategy_matrix():
    """Render the sidebar strategy-matrix controls.

    Returns:
        dict with keys `use_qe`, `strategy`, `use_rerank`, `top_k`,
        `max_history_rounds` — consumed by `execute_strategy` / `main`.
    """
    # NOTE(review): the original HTML wrapper markup was stripped when this
    # source was mangled; these headers reconstruct a minimal equivalent —
    # restore the exact markup from VCS history if styling matters.
    st.markdown('<div class="matrix-title">🎯 策略矩阵配置</div>', unsafe_allow_html=True)
    st.markdown(
        """<div class="matrix-hint">⚙️ 参数调节:控制检索片段数量和模型记忆深度。</div>""",
        unsafe_allow_html=True
    )
    col1, col2 = st.columns(2)
    with col1:
        use_qe = st.toggle("🔄 查询扩展 (QE)", value=False)
        use_rerank = st.toggle("⚖️ 深度重排序 (Rerank)", value=True)
        max_history_rounds = st.slider("🧠 记忆轮数", 0, 50, 10, help="发给模型的对话历史轮数")
    with col2:
        strategy = st.radio("🔍 检索策略", ["Vector", "BM25", "Hybrid"], index=2)
        top_k = st.slider("📊 检索数量", 1, 50, 10, help="从知识库召回的片段数量")
    return {'use_qe': use_qe, 'strategy': strategy, 'use_rerank': use_rerank, 'top_k': top_k, 'max_history_rounds': max_history_rounds}


def render_metrics(metrics):
    """Render the four per-stage timing metrics as Streamlit metric cards."""
    st.markdown("### 📊 性能指标")
    cols = st.columns(4)
    with cols[0]:
        st.metric("查询扩展", f"{metrics['qe_time']:.2f}s" if metrics['qe_time'] > 0 else "N/A",
                  delta="QE" if metrics['qe_time'] > 0 else None)
    with cols[1]:
        st.metric("检索耗时", f"{metrics['retrieval_time']:.2f}s")
    with cols[2]:
        st.metric("重排序", f"{metrics['rerank_time']:.2f}s" if metrics['rerank_time'] > 0 else "N/A",
                  delta="Rerank" if metrics['rerank_time'] > 0 else None)
    with cols[3]:
        st.metric("总耗时", f"{metrics['total_time']:.2f}s", delta="⚡")


def render_documents(documents, strategy_tags):
    """Render the strategy-tag badges and the retrieved document list."""
    st.markdown("### 📄 检索结果")
    if not documents:
        st.warning("未找到相关文档")
        return
    # FIX: the old code computed `cls` per tag but never used it (and kept a
    # dead `tags_html` variable); the class is now wired into the badge span.
    html_tags = ""
    for tag in strategy_tags:
        cls = "tag-vec" if tag == "Vector" else "tag-bm25" if tag == "BM25" else "tag-qe" if tag == "QE" else "tag-rerank"
        html_tags += f'<span class="{cls}">{tag}</span>'
    st.markdown(f"**策略组合:** {html_tags}", unsafe_allow_html=True)
    for i, doc in enumerate(documents):
        score = doc.get('final_score', 0)
        with st.expander(f"📄 文档 {i+1} | Score: {score:.4f} | {doc['filename'][:40]}...", expanded=i < 2):
            st.code(doc['content'], language="markdown")
# ================= 5. Main program =================
def main():
    """Entry point: session state, sidebar config, chat flow and debug view."""
    # --- Session state bootstrap ---
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "last_result" not in st.session_state:
        st.session_state.last_result = None
    if "suggestions" not in st.session_state:
        st.session_state.suggestions = random.sample(PRESET_QUESTIONS, 3)
    if "prompt_trigger" not in st.session_state:
        st.session_state.prompt_trigger = None

    controller = initialize_controller()

    # --- Sidebar: strategy controls + clear button ---
    with st.sidebar:
        config = render_strategy_matrix()
        st.markdown("---")
        if st.button("🗑️ 清空当前对话", use_container_width=True):
            st.session_state.messages = []
            st.session_state.last_result = None
            st.rerun()

    main_col, debug_col = st.columns([0.6, 0.4], gap="large")

    with main_col:
        st.markdown("### 💬 智能仿真问答")

        # Replay chat history.
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # Suggestion buttons (set a trigger and rerun so the click is
        # processed as a normal prompt on the next script run).
        if st.session_state.suggestions:
            st.markdown("##### 💡 您可能想问:")
            sug_cols = st.columns(3)
            for idx, suggestion in enumerate(st.session_state.suggestions):
                if sug_cols[idx].button(suggestion, use_container_width=True, key=f"sug_{idx}"):
                    logger.info(f"🔘 Triggered by button: {suggestion}")
                    st.session_state.prompt_trigger = suggestion
                    st.rerun()

        # Resolve this turn's input: pending button trigger wins over chat box.
        if st.session_state.prompt_trigger:
            user_input = st.session_state.prompt_trigger
            st.session_state.prompt_trigger = None  # clear immediately to avoid re-trigger
            logger.info(f"🔘 Triggered by button: {user_input}")
        else:
            user_input = st.chat_input("请输入您关于 COMSOL 的问题...")
            if user_input:
                logger.info(f"⌨️ Triggered by chat input: {user_input}")

        # Execute one Q&A turn.
        if user_input:
            st.session_state.messages.append({"role": "user", "content": user_input})
            with st.chat_message("user"):
                st.markdown(user_input)

            with st.spinner("🔍 检索知识库中..."):
                logger.info(f"🔎 Starting retrieval for: {user_input[:50]}...")
                result = controller.execute_strategy(user_input, config)
                st.session_state.last_result = result
                logger.info(f"✅ Retrieval done in {result['metrics']['total_time']:.2f}s")

            with st.chat_message("assistant"):
                logger.info("🤖 Generating answer...")
                answer = generate_answer(
                    controller, user_input, result['documents'],
                    st.session_state.messages, config['max_history_rounds']
                )
                st.session_state.messages.append({"role": "assistant", "content": answer})

                # Fresh follow-ups; no rerun needed — Streamlit refreshes the
                # UI when the script run ends.
                logger.info("✨ Generating follow-up questions...")
                new_sugs = generate_suggestions(controller, user_input, answer)
                st.session_state.suggestions = new_sugs if new_sugs else random.sample(PRESET_QUESTIONS, 3)
                logger.info(f"✅ Response completed in {result['metrics']['total_time']:.2f}s")

    with debug_col:
        st.markdown("### 🔍 系统调试视图")
        last = st.session_state.last_result
        if last is None:
            st.info("等待交互...")
        else:
            st.info(f"当前查询: {last.get('original_query', 'N/A')}")
            render_metrics(last['metrics'])
            with st.expander("🔧 检索链路详情", expanded=True):
                for step in last['steps']:
                    st.markdown(f"- {step}")
            render_documents(last['documents'], last['strategy_tags'])


if __name__ == "__main__":
    main()