# comsol-rag-expert / src/streamlit_app.py
# (Hugging Face Spaces page header retained as a comment:
#  leezhuuu's picture / "Update src/streamlit_app.py" / commit 96aed6c verified)
import streamlit as st
import pandas as pd
import numpy as np
import jieba
import requests
import os
import sys
import subprocess
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
# ================= 1. Global configuration & CSS injection =================
# SiliconFlow API credentials and model identifiers.
API_KEY = os.getenv("SILICONFLOW_API_KEY")  # validated below; app stops if missing
API_BASE = "https://api.siliconflow.cn/v1"
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-4B"
RERANK_MODEL = "Qwen/Qwen3-Reranker-4B"
GEN_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
# Pre-embedded knowledge base; FullRetriever reads its 'content',
# 'filename' and 'embedding' columns.
DATA_FILENAME = "comsol_embedded.parquet"
DATA_URL = "https://share.leezhu.cn/graduation_design_data/comsol_embedded.parquet"

st.set_page_config(
    page_title="COMSOL Dark Expert",
    page_icon="🌌",
    layout="wide",
    initial_sidebar_state="expanded"
)
# --- Inject custom CSS (dark "deep space" theme; string content unchanged) ---
st.markdown("""
<style>
/* 1. 整体背景 - 深空黑 */
.stApp {
background-color: #050505;
background-image: radial-gradient(circle at 50% 0%, #1a1f35 0%, #050505 60%);
color: #e0e0e0;
font-family: 'Inter', system-ui, -apple-system, sans-serif;
}
/* 2. 隐藏默认组件 */
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
header {visibility: hidden;}
/* 3. 聊天气泡 */
[data-testid="stChatMessage"] {
background: rgba(255, 255, 255, 0.03);
border: 1px solid rgba(255, 255, 255, 0.08);
border-radius: 16px;
backdrop-filter: blur(12px);
box-shadow: 0 4px 20px rgba(0,0,0,0.2);
padding: 1.2rem;
}
/* 用户气泡 */
[data-testid="stChatMessage"][data-testid="user"] {
background: rgba(41, 181, 232, 0.1);
border-color: rgba(41, 181, 232, 0.2);
}
/* 4. 自定义标题栏 */
.custom-header {
border-bottom: 1px solid rgba(255,255,255,0.1);
padding-bottom: 1rem;
margin-bottom: 2rem;
display: flex;
align-items: center;
gap: 1rem;
}
.glitch-text {
font-size: 2rem;
font-weight: 800;
background: linear-gradient(120deg, #fff, #29B5E8);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
letter-spacing: -1px;
}
/* 5. 快捷按钮 */
div.stButton > button {
background: rgba(255,255,255,0.05);
color: #aaa;
border: 1px solid rgba(255,255,255,0.1);
border-radius: 20px;
padding: 0.5rem 1rem;
font-size: 0.85rem;
transition: all 0.3s;
width: 100%;
}
div.stButton > button:hover {
background: rgba(41, 181, 232, 0.2);
color: #fff;
border-color: #29B5E8;
transform: translateY(-2px);
}
/* 6. 输入框 */
.stChatInputContainer textarea {
background-color: #0f1115 !important;
border: 1px solid #333 !important;
color: white !important;
border-radius: 12px !important;
}
/* 7. Expander */
.streamlit-expanderHeader {
background-color: rgba(255,255,255,0.02);
border: 1px solid rgba(255,255,255,0.05);
border-radius: 8px;
color: #bbb;
}
</style>
""", unsafe_allow_html=True)
# ================= 2. Core logic (data & RAG) =================
# Fail fast: without the API key neither embedding, rerank nor generation works.
if not API_KEY:
    st.error("⚠️ 未检测到 API Key。请在 Settings -> Secrets 中配置 `SILICONFLOW_API_KEY`。")
    st.stop()
def download_with_curl(url, output_path):
    """Download *url* to *output_path* by shelling out to curl.

    A browser User-Agent is sent because the file host appears to filter
    default client agents (the requests fallback in get_data_file_path
    sends one too).

    Returns True on success, False on any failure (curl binary missing,
    non-zero exit, HTTP error via --fail, or timeout).
    """
    cmd = [
        "curl", "-L",  # follow redirects
        "-A", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
        "-o", output_path,
        "--fail",  # turn HTTP 4xx/5xx into a non-zero exit code
        url,
    ]
    try:
        # List argv with shell=False: no shell-injection surface via the URL.
        # timeout added so a stalled transfer cannot hang the app forever.
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    except (OSError, subprocess.SubprocessError) as e:
        # OSError: curl not installed; SubprocessError covers TimeoutExpired.
        # (Original raised a generic Exception only to catch everything here.)
        print(f"Curl download error: {e}")
        return False
    if result.returncode != 0:
        print(f"Curl download error: Curl failed: {result.stderr}")
        return False
    return True
def get_data_file_path():
    """Locate the knowledge-base parquet file, downloading it if absent.

    Checks a list of likely local locations first; if none exists,
    downloads via curl (primary) then requests (fallback).

    Returns the path to the parquet file; stops the Streamlit script
    with an error message if every channel fails.
    """
    possible_paths = [
        DATA_FILENAME,
        os.path.join("/app", DATA_FILENAME),
        os.path.join("processed_data", DATA_FILENAME),
        os.path.join("src", DATA_FILENAME),
        os.path.join("..", DATA_FILENAME),
        "/tmp/" + DATA_FILENAME,
    ]
    for path in possible_paths:
        if os.path.exists(path):
            return path

    # Prefer /app; fall back to /tmp when the filesystem is read-only.
    download_target = "/app/" + DATA_FILENAME
    try:
        os.makedirs(os.path.dirname(download_target), exist_ok=True)
    except OSError:  # was a bare except: only filesystem errors are expected here
        download_target = "/tmp/" + DATA_FILENAME

    status_container = st.empty()
    status_container.info("📡 正在接入神经元网络... (下载核心数据中)")

    # Primary channel: curl subprocess (handles the host's UA filtering).
    if download_with_curl(DATA_URL, download_target):
        status_container.empty()
        return download_target

    # Fallback channel: streamed requests download.
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        # timeout=(connect, read) added: the original call could hang forever;
        # the with-block closes the connection even on partial failure.
        with requests.get(DATA_URL, headers=headers, stream=True,
                          timeout=(10, 300)) as r:
            r.raise_for_status()
            with open(download_target, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        status_container.empty()
        return download_target
    except Exception as e:
        st.error(f"❌ 数据链路中断。Error: {e}")
        st.stop()
class FullRetriever:
    """Hybrid retriever: dense vectors + BM25, fused by RRF, then reranked."""

    def __init__(self, parquet_path):
        """Load the pre-embedded corpus and build the in-memory indexes.

        The parquet file must provide 'content', 'filename' and
        'embedding' columns; the app stops if it cannot be read.
        """
        try:
            self.df = pd.read_parquet(parquet_path)
        except Exception as e:
            st.error(f"Memory Matrix Load Failed: {e}")
            st.stop()
        self.documents = self.df['content'].tolist()
        # (n_docs, dim) matrix of pre-computed document embeddings.
        self.embeddings = np.stack(self.df['embedding'].values)
        self.bm25 = BM25Okapi([jieba.lcut(str(d).lower()) for d in self.documents])
        self.client = OpenAI(base_url=API_BASE, api_key=API_KEY)
        # Rerank endpoint details built once here to avoid per-query rebuilds.
        self.rerank_headers = {"Content-Type": "application/json",
                               "Authorization": f"Bearer {API_KEY}"}
        self.rerank_url = f"{API_BASE}/rerank"

    def _get_emb(self, q):
        """Embed query *q* via the API.

        On failure, returns a zero vector matching the stored embedding
        dimension. (The original hardcoded 1024 zeros, which would crash
        cosine_similarity whenever the corpus embeddings have a different
        dimensionality.)
        """
        try:
            return self.client.embeddings.create(
                model=EMBEDDING_MODEL, input=[q]).data[0].embedding
        except Exception as e:
            print(f"Embedding API error: {e}")  # was a silent bare except
            return [0.0] * self.embeddings.shape[1]

    def hybrid_search(self, query: str, top_k=5):
        """Return the *top_k* most relevant documents for *query*.

        Pipeline: top-100 dense retrieval + top-100 BM25 retrieval,
        Reciprocal Rank Fusion of both lists, API rerank of the top-50
        fused candidates.

        Returns (results, context): a list of dicts with
        score/filename/content, and a single string concatenating the
        retrieved contents for the LLM prompt.
        """
        # 1. Dense retrieval: cosine similarity against all embeddings.
        q_emb = self._get_emb(query)
        vec_scores = cosine_similarity([q_emb], self.embeddings)[0]
        vec_idx = np.argsort(vec_scores)[-100:][::-1]
        # 2. Keyword retrieval: BM25 over jieba-tokenised query.
        kw_idx = np.argsort(self.bm25.get_scores(jieba.lcut(query.lower())))[-100:][::-1]
        # 3. Reciprocal Rank Fusion with the standard k=60 smoothing constant.
        fused = {}
        for r, i in enumerate(vec_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        for r, i in enumerate(kw_idx):
            fused[i] = fused.get(i, 0) + 1 / (60 + r + 1)
        c_idxs = [x[0] for x in sorted(fused.items(), key=lambda x: x[1], reverse=True)[:50]]
        c_docs = [self.documents[i] for i in c_idxs]
        # 4. Rerank via API; on failure fall back to RRF order with 0 scores.
        try:
            payload = {"model": RERANK_MODEL, "query": query,
                       "documents": c_docs, "top_n": top_k}
            resp = requests.post(self.rerank_url, headers=self.rerank_headers,
                                 json=payload, timeout=10)
            # Was missing: an HTTP error body previously yielded [] silently.
            resp.raise_for_status()
            results = resp.json().get('results', [])
        except Exception as e:
            print(f"Rerank API error: {e}")  # was a silent bare except
            # Size the fallback directly instead of building 50 dicts and slicing.
            results = [{"index": i, "relevance_score": 0.0}
                       for i in range(min(top_k, len(c_docs)))]
        # Map rerank positions back to corpus rows and build the prompt context.
        final_res = []
        context = ""
        for i, item in enumerate(results):
            orig_idx = c_idxs[item['index']]
            row = self.df.iloc[orig_idx]
            final_res.append({
                "score": item['relevance_score'],
                "filename": row['filename'],
                "content": row['content'],
            })
            context += f"[文档{i+1}]: {row['content']}\n\n"
        return final_res, context
@st.cache_resource
def load_engine():
    """Build the retriever once per process; Streamlit caches the instance."""
    return FullRetriever(get_data_file_path())
# ================= 3. Main UI program =================
def main():
    """Render the chat UI and drive the retrieve-then-generate loop.

    Flow per Streamlit rerun: collect input (starter button or chat box)
    -> append it to session state and rerun -> if the last message is
    from the user, retrieve evidence and stream the assistant reply ->
    always render the evidence panel.
    """
    # Page header banner (HTML styled by the injected CSS).
    st.markdown("""
<div class="custom-header">
<div style="font-size: 3rem;">🌌</div>
<div>
<div class="glitch-text">COMSOL DARK EXPERT</div>
<div style="color: #666; font-size: 0.9rem; letter-spacing: 1px;">
NEURAL SIMULATION ASSISTANT <span style="color:#29B5E8">V4.1 Fixed</span>
</div>
</div>
</div>
""", unsafe_allow_html=True)
    retriever = load_engine()

    # Sidebar controls: retrieval depth and generation temperature.
    with st.sidebar:
        st.markdown("### ⚙️ 控制台")
        top_k = st.slider("检索深度", 1, 10, 4)    # retrieval depth (top_k)
        temp = st.slider("发散度", 0.0, 1.0, 0.3)  # LLM temperature
        st.markdown("---")
        if st.button("🗑️ 清空记忆 (Clear)", use_container_width=True):
            st.session_state.messages = []
            st.session_state.current_refs = []
            st.rerun()

    # Session-state initialisation (chat history + last retrieved evidence).
    if "messages" not in st.session_state: st.session_state.messages = []
    if "current_refs" not in st.session_state: st.session_state.current_refs = []

    col_chat, col_evidence = st.columns([0.65, 0.35], gap="large")

    # ------------------ Input sources ------------------
    # user_input is set either by a starter button or by the chat box.
    user_input = None
    with col_chat:
        # 1. While the history is empty, show quick-start buttons.
        if not st.session_state.messages:
            st.markdown("##### 💡 初始化提问序列 (Starter Sequence)")
            c1, c2, c3 = st.columns(3)
            # A button click assigns its canned question to user_input.
            if c1.button("🌊 流固耦合接口设置"):
                user_input = "怎么设置流固耦合接口?"
            elif c2.button("⚡ 低频电磁场网格"):
                user_input = "低频电磁场网格划分有哪些技巧?"
            elif c3.button("📉 求解器不收敛"):
                user_input = "求解器不收敛通常怎么解决?"
        # 2. Replay the chat history.
        for msg in st.session_state.messages:
            with st.chat_message(msg["role"]):
                st.markdown(msg["content"])
        # 3. Bottom chat box (skipped when a button already set user_input).
        if not user_input:
            user_input = st.chat_input("输入指令或物理参数问题...")

    # ------------------ Append the new user message ------------------
    if user_input:
        st.session_state.messages.append({"role": "user", "content": user_input})
        # Rerun immediately so the question shows in the UI before
        # generation starts (especially important for button clicks).
        st.rerun()

    # ------------------ Trigger generation (core of the fix) ------------------
    # If the last message is from the user, the assistant still owes a reply.
    if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
        # The latest user utterance drives retrieval.
        last_query = st.session_state.messages[-1]["content"]
        with col_chat:  # render inside the chat column
            with st.spinner("🔍 正在扫描向量空间..."):
                refs, context = retriever.hybrid_search(last_query, top_k=top_k)
                st.session_state.current_refs = refs
            system_prompt = f"""你是一个COMSOL高级仿真专家。请基于提供的文档回答问题。
要求:
1. 语气专业、客观,逻辑严密。
2. 涉及物理公式时,**必须**使用 LaTeX 格式(例如 $E = mc^2$)。
3. 涉及步骤或参数对比时,优先使用 Markdown 列表或表格。
参考文档:
{context}
"""
            with st.chat_message("assistant"):
                resp_cont = st.empty()
                full_resp = ""
                client = OpenAI(base_url=API_BASE, api_key=API_KEY)
                try:
                    stream = client.chat.completions.create(
                        model=GEN_MODEL_NAME,
                        messages=[{"role": "system", "content": system_prompt}] + st.session_state.messages[-6:],  # last 6 turns; system prompt prepended separately
                        temperature=temp,
                        stream=True
                    )
                    # Stream tokens into the placeholder with a cursor glyph.
                    for chunk in stream:
                        txt = chunk.choices[0].delta.content
                        if txt:
                            full_resp += txt
                            resp_cont.markdown(full_resp + " ▌")
                    resp_cont.markdown(full_resp)
                    st.session_state.messages.append({"role": "assistant", "content": full_resp})
                except Exception as e:
                    st.error(f"Neural Generation Failed: {e}")

    # ------------------ Right-hand evidence panel ------------------
    with col_evidence:
        st.markdown("### 📚 神经记忆 (Evidence)")
        if st.session_state.current_refs:
            for i, ref in enumerate(st.session_state.current_refs):
                score = ref['score']
                # Traffic-light colour keyed on the rerank relevance score.
                score_color = "#00ff41" if score > 0.6 else "#ffb700" if score > 0.4 else "#ff003c"
                with st.expander(f"📄 Doc {i+1}: {ref['filename'][:20]}...", expanded=(i==0)):
                    st.markdown(f"""
<div style="margin-bottom:5px;">
<span style="color:#888;">Relevance:</span>
<span style="color:{score_color}; font-weight:bold;">{score:.4f}</span>
</div>
""", unsafe_allow_html=True)
                    st.code(ref['content'], language="text")
        else:
            st.info("等待输入指令以检索知识库...")
            st.markdown("""
<div style="opacity:0.3; font-size:0.8rem; margin-top:20px;">
Waiting for query signal...<br>
Index Status: Ready<br>
Awaiting Input
</div>
""", unsafe_allow_html=True)
# Script entry point (guarded so importing this module has no UI side effects).
if __name__ == "__main__":
    main()