Spaces:
Running
Running
File size: 13,897 Bytes
16c7408 b38ace2 42c5163 b38ace2 80c7ecd 90f053a b38ace2 90f053a b38ace2 90f053a b38ace2 90f053a b38ace2 96aed6c 90f053a 96aed6c 90f053a b38ace2 90f053a b38ace2 96aed6c 90f053a 96aed6c 90f053a 96aed6c 90f053a 80c7ecd 96aed6c 90f053a 96aed6c 90f053a 96aed6c 90f053a 80c7ecd 42c5163 96aed6c 42c5163 80c7ecd 90f053a 80c7ecd 96aed6c 80c7ecd 90f053a 96aed6c b38ace2 90f053a 96aed6c 90f053a b38ace2 96aed6c b38ace2 96aed6c b38ace2 96aed6c b38ace2 90f053a 42c5163 b38ace2 90f053a b38ace2 42c5163 b38ace2 42c5163 96aed6c b38ace2 90f053a b38ace2 96aed6c b38ace2 90f053a b38ace2 90f053a 96aed6c 90f053a b38ace2 90f053a b38ace2 90f053a 96aed6c 90f053a 96aed6c b38ace2 96aed6c b38ace2 96aed6c 90f053a 96aed6c 90f053a 96aed6c 90f053a b38ace2 96aed6c b38ace2 96aed6c 90f053a 96aed6c 90f053a b38ace2 90f053a b38ace2 96aed6c b38ace2 90f053a 96aed6c b38ace2 90f053a b38ace2 96aed6c b38ace2 90f053a b38ace2 90f053a 16c7408 b38ace2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 |
import streamlit as st
import pandas as pd
import numpy as np
import jieba
import requests
import os
import sys
import subprocess
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
# ================= 1. Global configuration & CSS injection =================
# SiliconFlow API credentials and model identifiers for the RAG pipeline.
API_KEY = os.getenv("SILICONFLOW_API_KEY")  # supplied via environment / Spaces Secrets
API_BASE = "https://api.siliconflow.cn/v1"
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"  # dense query embedding
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"      # candidate reranking
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"      # answer generation
# Pre-embedded corpus file name and its remote download location.
DATA_FILENAME = "comsol_embedded.parquet"
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"
# Page chrome; the sidebar hosts retrieval/generation controls.
st.set_page_config(
    page_title="COMSOL Dark Expert",
    page_icon="🌌",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- Inject custom CSS (dark "deep space" theme; preserves the earlier aesthetic) ---
st.markdown("""
<style>
/* 1. 整体背景 - 深空黑 */
.stApp {
background-color: #050505;
background-image: radial-gradient(circle at 50% 0%, #1a1f35 0%, #050505 60%);
color: #e0e0e0;
font-family: 'Inter', system-ui, -apple-system, sans-serif;
}
/* 2. 隐藏默认组件 */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* 3. 聊天气泡 */
[data-testid="stChatMessage"] {
background: rgba(255, 255, 255, 0.03);
border: 1px solid rgba(255, 255, 255, 0.08);
border-radius: 16px;
backdrop-filter: blur(12px);
box-shadow: 0 4px 20px rgba(0,0,0,0.2);
padding: 1.2rem;
}
/* 用户气泡 */
[data-testid="stChatMessage"][data-testid="user"] {
background: rgba(41, 181, 232, 0.1);
border-color: rgba(41, 181, 232, 0.2);
}
/* 4. 自定义标题栏 */
.custom-header {
border-bottom: 1px solid rgba(255,255,255,0.1);
padding-bottom: 1rem;
margin-bottom: 2rem;
display: flex;
align-items: center;
gap: 1rem;
}
.glitch-text {
font-size: 2rem;
font-weight: 800;
background: linear-gradient(120deg, #fff, #29B5E8);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
letter-spacing: -1px;
}
/* 5. 快捷按钮 */
div.stButton > button {
background: rgba(255,255,255,0.05);
color: #aaa;
border: 1px solid rgba(255,255,255,0.1);
border-radius: 20px;
padding: 0.5rem 1rem;
font-size: 0.85rem;
transition: all 0.3s;
width: 100%;
}
div.stButton > button:hover {
background: rgba(41, 181, 232, 0.2);
color: #fff;
border-color: #29B5E8;
transform: translateY(-2px);
}
/* 6. 输入框 */
.stChatInputContainer textarea {
background-color: #0f1115 !important;
border: 1px solid #333 !important;
color: white !important;
border-radius: 12px !important;
}
/* 7. Expander */
.streamlit-expanderHeader {
background-color: rgba(255,255,255,0.02);
border: 1px solid rgba(255,255,255,0.05);
border-radius: 8px;
color: #bbb;
}
</style>
""", unsafe_allow_html=True)
# ================= 2. Core logic (data & RAG) =================
# Fail fast with a visible hint if the API key secret is missing.
if not API_KEY:
    st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
    st.stop()
def download_with_curl(url, output_path):
    """Download *url* to *output_path* with the system ``curl`` binary.

    A browser User-Agent is sent because the file host rejects default
    HTTP clients. Returns True on success, False on any failure so the
    caller can fall back to a plain ``requests`` download.
    """
    cmd = [
        "curl", "-L",  # follow redirects
        "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "-o", output_path,
        "--fail",  # exit non-zero on HTTP errors instead of saving the error page
        url,
    ]
    try:
        # timeout prevents a stalled transfer from blocking app startup forever
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=600)
    except (OSError, subprocess.SubprocessError) as e:
        # OSError: curl binary missing; SubprocessError covers TimeoutExpired etc.
        print(f"Curl download error: {e}")
        return False
    if result.returncode != 0:
        print(f"Curl download error: Curl failed: {result.stderr}")
        return False
    return True
def get_data_file_path():
    """Locate the embedded-corpus parquet file, downloading it if absent.

    Checks a list of likely on-disk locations first; on a miss, downloads
    via curl (primary) or requests (fallback). Returns a path to an
    existing file, or calls ``st.stop()`` if every strategy fails.
    """
    possible_paths = [
        DATA_FILENAME, os.path.join("/app", DATA_FILENAME),
        os.path.join("processed_data", DATA_FILENAME),
        os.path.join("src", DATA_FILENAME),
        os.path.join("..", DATA_FILENAME), "/tmp/" + DATA_FILENAME
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    # Prefer /app as the download target; fall back to /tmp when /app is
    # not writable (read-only container filesystems).
    download_target = "/app/" + DATA_FILENAME
    try:
        os.makedirs(os.path.dirname(download_target), exist_ok=True)
    except OSError:
        download_target = "/tmp/" + DATA_FILENAME

    status_container = st.empty()
    status_container.info("📡 正在接入神经元网络... (下载核心数据中)")
    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target

    # Fallback: plain streamed HTTP download. timeout bounds the connect/read
    # waits so a dead host cannot hang the app indefinitely.
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        r = requests.get(DATA_URL, headers=headers, stream=True, timeout=60)
        r.raise_for_status()
        with open(download_target, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
        status_container.empty()
        return download_target
    except Exception as e:
        st.error(f"❌ 数据链路中断。Error: {e}")
        st.stop()
class FullRetriever:
    """Hybrid retriever over the pre-embedded COMSOL corpus.

    Combines dense retrieval (remote embedding + cosine similarity) with
    sparse retrieval (BM25 over jieba tokens), fuses the two rankings via
    Reciprocal Rank Fusion, then reranks the fused candidates with a
    remote reranker model.
    """

    def __init__(self, parquet_path):
        # Load the corpus; abort the Streamlit app if the file is unreadable.
        try:
            self.df = pd.read_parquet(parquet_path)
        except Exception as e:
            st.error(f"Memory Matrix Load Failed: {e}")
            st.stop()
        self.documents = self.df['content'].tolist()
        # (n_docs, dim) matrix of precomputed document embeddings.
        self.embeddings = np.stack(self.df['embedding'].values)
        self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
        # Reranker request parts built once to avoid per-query reconstruction.
        self.rerank_headers = {"Content-Type": "application/json", "Authorization": f"Bearer {API_KEY}"}
        self.rerank_url = f"{API_BASE}/rerank"

    def _get_emb(self, q):
        """Embed query *q* via the API; on failure return a zero vector."""
        try:
            return self.client.embeddings.create(model=EMBEDDING_MODEL, input=[q]).data[0].embedding
        except Exception:
            # Match the corpus embedding width instead of hard-coding 1024 so
            # cosine_similarity stays valid if the embedding model changes.
            return [0.0] * self.embeddings.shape[1]

    def hybrid_search(self, query: str, top_k=5):
        """Return ``(results, context)`` for *query*.

        results: list of ``{score, filename, content}`` dicts, best first.
        context: the same contents concatenated as numbered blocks for the LLM.
        """
        # 1. Dense channel: top-100 documents by cosine similarity.
        q_emb = self._get_emb(query)
        vec_scores = cosine_similarity([q_emb], self.embeddings)[0]
        vec_idx = np.argsort(vec_scores)[-100:][::-1]
        # 2. Sparse channel: top-100 documents by BM25 score.
        kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
        # 3. Reciprocal Rank Fusion (k=60) of both rankings; keep top 50.
        fused = {}
        for r, i in enumerate(vec_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        for r, i in enumerate(kw_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x: x[1], reverse=True)[:50]]
        c_docs = [self.documents[i] for i in c_idxs]
        # 4. Rerank via the remote model. raise_for_status routes HTTP errors
        # to the fallback (fused order, zero scores) instead of silently
        # producing an empty result set from an error payload.
        try:
            payload = {"model": RERANK_MODEL, "query": query, "documents": c_docs, "top_n": top_k}
            resp = requests.post(self.rerank_url, headers=self.rerank_headers, json=payload, timeout=10)
            resp.raise_for_status()
            results = resp.json().get('results', [])
        except Exception:
            results = [{"index": i, "relevance_score": 0.0} for i in range(min(top_k, len(c_docs)))]
        final_res = []
        context = ""
        for i, item in enumerate(results):
            orig_idx = c_idxs[item['index']]
            row = self.df.iloc[orig_idx]
            final_res.append({
                "score": item['relevance_score'],
                "filename": row['filename'],
                "content": row['content']
            })
            context += f"[文档{i+1}]: {row['content']}\n\n"
        return final_res, context
@st.cache_resource
def load_engine():
    """Build the retriever once per server process (Streamlit caches the instance)."""
    return FullRetriever(get_data_file_path())
# ================= 3. Main UI program =================
def main():
    """Render the chat page and drive one Streamlit script run.

    Flow per rerun: draw header/sidebar -> collect input (starter button or
    chat box) -> append the user message and rerun -> on the next run a
    trailing user message triggers retrieval + streamed generation, then
    the evidence column is rendered from ``st.session_state.current_refs``.
    """
    # Custom header banner (styled by the injected CSS).
    st.markdown("""
<div class="custom-header">
<div style="font-size: 3rem;">🌌</div>
<div>
<div class="glitch-text">COMSOL DARK EXPERT</div>
<div style="color: #666; font-size: 0.9rem; letter-spacing: 1px;">
NEURAL SIMULATION ASSISTANT <span style="color:#29B5E8">V4.1 Fixed</span>
</div>
</div>
</div>
""", unsafe_allow_html=True)
    retriever = load_engine()

    # Sidebar: retrieval/generation knobs plus a history reset button.
    with st.sidebar:
        st.markdown("### ⚙️ 控制台")
        top_k = st.slider("检索深度", 1, 10, 4)
        temp = st.slider("发散度", 0.0, 1.0, 0.3)
        st.markdown("---")
        if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True):
            st.session_state.messages = []
            st.session_state.current_refs = []
            st.rerun()

    # Session-state defaults (first run only).
    if "messages" not in st.session_state: st.session_state.messages = []
    if "current_refs" not in st.session_state: st.session_state.current_refs = []

    col_chat, col_evidence = st.columns([0.65, 0.35], gap="large")

    # ------------------ Input sources ------------------
    # user_input holds the query regardless of whether it came from a
    # starter button or the bottom chat input box.
    user_input = None
    with col_chat:
        # 1. With an empty history, show starter-question shortcut buttons.
        if not st.session_state.messages:
            st.markdown("##### 💡 初始化提问序列 (Starter Sequence)")
            c1, c2, c3 = st.columns(3)
            # A button click assigns directly to user_input.
            if c1.button("🌊 流固耦合接口设置"):
                user_input = "怎么设置流固耦合接口?"
            elif c2.button("⚡ 低频电磁场网格"):
                user_input = "低频电磁场网格划分有哪些技巧?"
            elif c3.button("📉 求解器不收敛"):
                user_input = "求解器不收敛通常怎么解决?"
        # 2. Render the chat history.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])
        # 3. Bottom chat box (skipped when a starter button already filled user_input).
        if not user_input:
            user_input = st.chat_input("输入指令或物理参数问题...")

    # ------------------ Unified message append ------------------
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Rerun so the user's message shows immediately (matters for button clicks).
        st.rerun()

    # ------------------ Unified generation trigger (core of the fix) ------------------
    # If the last message is from the user, the assistant still owes a reply.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        # Use the most recent user message as the retrieval query.
        last_query = st.session_state.messages[-1]["content"]
        with col_chat:  # ensure output lands in the chat column
            with st.spinner("🔍 正在扫描向量空间..."):
                refs, context = retriever.hybrid_search(last_query, top_k=top_k)
                st.session_state.current_refs = refs
            system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。
要求:
1. 语气专业、客观,逻辑严密。
2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。
3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。
参考文档:
{context}
"""
            with st.chat_message("assistant"):
                resp_cont = st.empty()
                full_resp = ""
                client = OpenAI(base_url=API_BASE, api_key=API_KEY)
                try:
                    stream = client.chat.completions.create(
                        model=GEN_MODEL_NAME,
                        # System prompt plus the last six chat turns as context.
                        messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:],
                        temperature=temp,
                        stream=True
                    )
                    # Stream tokens into the placeholder with a cursor glyph.
                    for chunk in stream:
                        txt = chunk.choices[0].delta.content
                        if txt:
                            full_resp += txt
                            resp_cont.markdown(full_resp + " ▌")
                    resp_cont.markdown(full_resp)
                    st.session_state.messages.append({"role": "assistant", "content": full_resp})
                except Exception as e:
                    st.error(f"Neural Generation Failed: {e}")

    # ------------------ Right-hand evidence column ------------------
    with col_evidence:
        st.markdown("### 📚 神经记忆 (Evidence)")
        if st.session_state.current_refs:
            for i, ref in enumerate(st.session_state.current_refs):
                score = ref['score']
                # Traffic-light color keyed on rerank relevance score.
                score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c"
                with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)):
                    st.markdown(f"""
<div style="margin-bottom:5px;">
<span style="color:#888;">Relevance:</span>
<span style="color:{score_color}; font-weight:bold;">{score:.4f}</span>
</div>
""", unsafe_allow_html=True)
                    st.code(ref['content'], language="text")
        else:
            st.info("等待输入指令以检索知识库...")
            st.markdown("""
<div style="opacity:0.3; font-size:0.8rem; margin-top:20px;">
Waiting for query signal...<br>
Index Status: Ready<br>
Awaiting Input
</div>
""", unsafe_allow_html=True)
# Script entry point.
if __name__ == "__main__":
    main()