Spaces:
Running
Running
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import jieba | |
| import requests | |
| import os | |
| import sys | |
| import subprocess | |
| from openai import OpenAI | |
| from rank_bm25 import BM25Okapi | |
| from sklearn.metrics.pairwise import cosine_similarity | |
# ================= 1. Global configuration & CSS injection =================
# API key is read from the environment; it is validated later, before any request.
API_KEY = os.getenv("SILICONFLOW_API_KEY")
API_BASE = "https://api.siliconflow.cn/v1"
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"   # dense embedding model for vector retrieval
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"       # reranker applied to fused candidates
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"       # chat/generation model
DATA_FILENAME = "comsol_embedded.parquet"     # pre-embedded knowledge base file
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"
# Streamlit page setup: wide layout with the sidebar open by default.
st.set_page_config(
    page_title="COMSOL Dark Expert",
    page_icon="🌌",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- Inject custom CSS (dark "deep space" theme, kept from the earlier design) ---
# NOTE: the CSS below is sent to the browser verbatim; its /* */ comments are
# part of the runtime string and are intentionally left untouched.
st.markdown("""
<style>
/* 1. 整体背景 - 深空黑 */
.stApp {
background-color: #050505;
background-image: radial-gradient(circle at 50% 0%, #1a1f35 0%, #050505 60%);
color: #e0e0e0;
font-family: 'Inter', system-ui, -apple-system, sans-serif;
}
/* 2. 隐藏默认组件 */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* 3. 聊天气泡 */
[data-testid="stChatMessage"] {
background: rgba(255, 255, 255, 0.03);
border: 1px solid rgba(255, 255, 255, 0.08);
border-radius: 16px;
backdrop-filter: blur(12px);
box-shadow: 0 4px 20px rgba(0,0,0,0.2);
padding: 1.2rem;
}
/* 用户气泡 */
[data-testid="stChatMessage"][data-testid="user"] {
background: rgba(41, 181, 232, 0.1);
border-color: rgba(41, 181, 232, 0.2);
}
/* 4. 自定义标题栏 */
.custom-header {
border-bottom: 1px solid rgba(255,255,255,0.1);
padding-bottom: 1rem;
margin-bottom: 2rem;
display: flex;
align-items: center;
gap: 1rem;
}
.glitch-text {
font-size: 2rem;
font-weight: 800;
background: linear-gradient(120deg, #fff, #29B5E8);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
letter-spacing: -1px;
}
/* 5. 快捷按钮 */
div.stButton > button {
background: rgba(255,255,255,0.05);
color: #aaa;
border: 1px solid rgba(255,255,255,0.1);
border-radius: 20px;
padding: 0.5rem 1rem;
font-size: 0.85rem;
transition: all 0.3s;
width: 100%;
}
div.stButton > button:hover {
background: rgba(41, 181, 232, 0.2);
color: #fff;
border-color: #29B5E8;
transform: translateY(-2px);
}
/* 6. 输入框 */
.stChatInputContainer textarea {
background-color: #0f1115 !important;
border: 1px solid #333 !important;
color: white !important;
border-radius: 12px !important;
}
/* 7. Expander */
.streamlit-expanderHeader {
background-color: rgba(255,255,255,0.02);
border: 1px solid rgba(255,255,255,0.05);
border-radius: 8px;
color: #bbb;
}
</style>
""", unsafe_allow_html=True)
# ================= 2. Core logic (data & RAG) =================
# Fail fast: without an API key, embeddings, reranking and generation all fail.
if not API_KEY:
    st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
    st.stop()
def download_with_curl(url, output_path):
    """Download *url* to *output_path* with the system ``curl`` binary.

    A browser-like User-Agent is sent because the file host appears to
    reject default curl requests.

    Returns:
        bool: True on success; False on any failure (non-zero curl exit
        code, or curl not being runnable at all). Failures are logged to
        stdout, never raised.
    """
    cmd = [
        "curl", "-L",  # follow redirects
        "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "-o", output_path,
        "--fail",  # exit non-zero on HTTP errors (4xx/5xx) instead of saving the error page
        url,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as e:  # e.g. curl binary not installed / not executable
        print(f"Curl download error: {e}")
        return False
    if result.returncode != 0:
        # Check the exit code directly instead of raising a generic Exception
        # just to catch it again; keep the original log-line format.
        print(f"Curl download error: Curl failed: {result.stderr}")
        return False
    return True
def get_data_file_path():
    """Locate the embedded-knowledge parquet file, downloading it if absent.

    Checks a list of likely on-disk locations first; when none exists, the
    file is downloaded — preferably via curl (browser UA), falling back to
    ``requests``.

    Returns:
        str: path to an existing parquet file.

    Note:
        Calls ``st.stop()`` (i.e. never returns) when every strategy fails.
    """
    possible_paths = [
        DATA_FILENAME, os.path.join("/app", DATA_FILENAME),
        os.path.join("processed_data", DATA_FILENAME),
        os.path.join("src", DATA_FILENAME),
        os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME,
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    # Prefer /app (container app dir); fall back to /tmp if it is unwritable.
    download_target = "/app/" + DATA_FILENAME
    try:
        os.makedirs(os.path.dirname(download_target), exist_ok=True)
    except OSError:  # read-only filesystem or missing permissions (was a bare except)
        download_target = "/tmp/" + DATA_FILENAME

    status_container = st.empty()
    status_container.info("📡 正在接入神经元网络... (下载核心数据中)")

    # Strategy 1: curl with a browser User-Agent.
    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target

    # Strategy 2: plain requests download, streamed to keep memory bounded.
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # (connect, read) timeout: the read timeout is per-chunk, so a slow
        # but live download is not aborted; a dead connection no longer hangs.
        r = requests.get(DATA_URL, headers=headers, stream=True, timeout=(10, 60))
        r.raise_for_status()
        with open(download_target, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        status_container.empty()
        return download_target
    except Exception as e:
        st.error(f"❌ 数据链路中断。Error: {e}")
        st.stop()
class FullRetriever:
    """Hybrid RAG retriever: dense vectors + BM25, fused with RRF, then API-reranked."""

    def __init__(self, parquet_path):
        """Load the parquet knowledge base and build both retrieval indexes.

        Args:
            parquet_path: path to a parquet file with at least the columns
                ``content``, ``filename`` and ``embedding`` (list/array per row).
        """
        try:
            self.df = pd.read_parquet(parquet_path)
        except Exception as e:
            st.error(f"Memory Matrix Load Failed: {e}")
            st.stop()
        self.documents = self.df['content'].tolist()
        # Pre-computed document embeddings stacked into an (n_docs, dim) matrix.
        self.embeddings = np.stack(self.df['embedding'].values)
        # BM25 over jieba-tokenized, lower-cased documents (Chinese-aware keywords).
        self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
        # Reranker request config is built once here to avoid per-query rebuilding.
        self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
        self.rerank_url = f"{API_BASE}/rerank"

    def _get_emb(self, q):
        """Embed the query string; on API failure return a zero vector.

        The fallback is sized to the stored document embeddings so that
        cosine_similarity can never hit a dimension mismatch (the previous
        hard-coded 1024 would crash for any other embedding dimension).
        """
        try:
            return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding
        except Exception:
            return [0.0] * self.embeddings.shape[1]

    def hybrid_search(self, query: str, top_k=5):
        """Retrieve the *top_k* most relevant documents for *query*.

        Pipeline: dense top-100 + BM25 top-100 → Reciprocal Rank Fusion →
        top-50 candidates → remote reranker (best-effort).

        Returns:
            tuple: (list of {"score", "filename", "content"} dicts,
                    concatenated context string for the LLM prompt).
        """
        # 1. Dense retrieval: top-100 by cosine similarity.
        q_emb = self._get_emb(query)
        vec_scores = cosine_similarity([q_emb], self.embeddings)[0]
        vec_idx = np.argsort(vec_scores)[-100:][::-1]
        # 2. Keyword retrieval: top-100 by BM25 score.
        kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
        # 3. Reciprocal Rank Fusion (k=60); keep the best 50 candidates.
        fused = {}
        for r, i in enumerate(vec_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        for r, i in enumerate(kw_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x: x[1], reverse=True)[:50]]
        c_docs = [self.documents[i] for i in c_idxs]
        # 4. Rerank via API; on any failure fall back to RRF order with zero scores.
        try:
            payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k}
            resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10)
            resp.raise_for_status()  # treat HTTP errors as failures instead of parsing an error body
            results = resp.json().get('results', [])
        except Exception:
            results = [{"index": i, "relevance_score": 0.0} for i in range(min(top_k, len(c_docs)))]
        final_res = []
        context = ""
        for i, item in enumerate(results):
            orig_idx = c_idxs[item['index']]
            row = self.df.iloc[orig_idx]
            final_res.append({
                "score": item['relevance_score'],
                "filename": row['filename'],
                "content": row['content'],
            })
            context += f"[文档{i+1}]: {row['content']}\n\n"
        return final_res, context
@st.cache_resource
def load_engine():
    """Build the retriever once per server process and reuse it.

    Streamlit re-runs the whole script on every user interaction; without
    caching, each interaction would re-read the parquet file and rebuild the
    BM25 index. ``st.cache_resource`` keeps a single shared instance instead.
    """
    real_path = get_data_file_path()
    return FullRetriever(real_path)
# ================= 3. UI main program =================
def main():
    """Render the Streamlit UI: header, sidebar, chat column and evidence column."""
    # Custom HTML header, styled by the CSS injected at module import time.
    st.markdown("""
<div class="custom-header">
<div style="font-size: 3rem;">🌌</div>
<div>
<div class="glitch-text">COMSOL DARK EXPERT</div>
<div style="color: #666; font-size: 0.9rem; letter-spacing: 1px;">
NEURAL SIMULATION ASSISTANT <span style="color:#29B5E8">V4.1 Fixed</span>
</div>
</div>
</div>
""", unsafe_allow_html=True)
    retriever = load_engine()
    with st.sidebar:
        st.markdown("### ⚙️ 控制台")
        top_k = st.slider("检索深度", 1, 10, 4)
        temp = st.slider("发散度", 0.0, 1.0, 0.3)
        st.markdown("---")
        if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True):
            st.session_state.messages = []
            st.session_state.current_refs = []
            st.rerun()
    if "messages" not in st.session_state: st.session_state.messages = []
    if "current_refs" not in st.session_state: st.session_state.current_refs = []
    col_chat, col_evidence = st.columns([0.65, 0.35], gap="large")
    # ------------------ Input source handling ------------------
    # user_input holds the query regardless of whether it came from a
    # starter button or the bottom chat input box.
    user_input = None
    with col_chat:
        # 1. Show starter buttons only while the history is empty.
        if not st.session_state.messages:
            st.markdown("##### 💡 初始化提问序列 (Starter Sequence)")
            c1, c2, c3 = st.columns(3)
            # A button click feeds user_input directly.
            if c1.button("🌊 流固耦合接口设置"):
                user_input = "怎么设置流固耦合接口?"
            elif c2.button("⚡ 低频电磁场网格"):
                user_input = "低频电磁场网格划分有哪些技巧?"
            elif c3.button("📉 求解器不收敛"):
                user_input = "求解器不收敛通常怎么解决?"
        # 2. Render the chat history.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])
        # 3. Bottom chat box (skipped when a button already set user_input).
        if not user_input:
            user_input = st.chat_input("输入指令或物理参数问题...")
    # ------------------ Unified message append ------------------
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Force a rerun so the question shows on screen immediately
        # (especially important for the button-click path).
        st.rerun()
    # ------------------ Unified generation trigger (the core fix) ------------------
    # If the last message was sent by the user, the assistant still owes a reply.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        # Take the latest user message as the retrieval query.
        last_query = st.session_state.messages[-1]["content"]
        with col_chat:  # make sure everything renders inside the chat column
            with st.spinner("🔍 正在扫描向量空间..."):
                refs, context = retriever.hybrid_search(last_query, top_k=top_k)
                st.session_state.current_refs = refs
            system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。
要求:
1. 语气专业、客观,逻辑严密。
2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。
3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。
参考文档:
{context}
"""
            with st.chat_message("assistant"):
                resp_cont = st.empty()
                full_resp = ""
                client = OpenAI(base_url=API_BASE, api_key=API_KEY)
                try:
                    stream = client.chat.completions.create(
                        model=GEN_MODEL_NAME,
                        messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:],  # last 6 turns; system prompt prepended separately
                        temperature=temp,
                        stream=True
                    )
                    # Stream tokens into the placeholder with a cursor marker.
                    for chunk in stream:
                        txt = chunk.choices[0].delta.content
                        if txt:
                            full_resp += txt
                            resp_cont.markdown(full_resp + " ▌")
                    resp_cont.markdown(full_resp)
                    st.session_state.messages.append({"role": "assistant", "content": full_resp})
                except Exception as e:
                    st.error(f"Neural Generation Failed: {e}")
    # ------------------ Right-hand evidence column ------------------
    with col_evidence:
        st.markdown("### 📚 神经记忆 (Evidence)")
        if st.session_state.current_refs:
            for i, ref in enumerate(st.session_state.current_refs):
                score = ref['score']
                # Traffic-light coloring of the relevance score.
                score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c"
                with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)):
                    st.markdown(f"""
<div style="margin-bottom:5px;">
<span style="color:#888;">Relevance:</span>
<span style="color:{score_color}; font-weight:bold;">{score:.4f}</span>
</div>
""", unsafe_allow_html=True)
                    st.code(ref['content'], language="text")
        else:
            st.info("等待输入指令以检索知识库...")
            st.markdown("""
<div style="opacity:0.3; font-size:0.8rem; margin-top:20px;">
Waiting for query signal...<br>
Index Status: Ready<br>
Awaiting Input
</div>
""", unsafe_allow_html=True)
if __name__ == "__main__":
    main()