This commit is contained in:
2026-06-09 20:55:16 +08:00
parent 721caccc58
commit 1e793ce814
2 changed files with 168 additions and 35 deletions
+78 -31
View File
@@ -630,6 +630,17 @@ var (
store *ConvStore
)
type chatSSEFrame struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Message string `json:"message,omitempty"`
Tool string `json:"tool,omitempty"`
Stage string `json:"stage,omitempty"`
Status string `json:"status,omitempty"`
Data map[string]any `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// ─── 路由 ─────────────────────────────────────────────────
func indexHandler(c *gin.Context) {
@@ -739,19 +750,46 @@ func chatHandler(c *gin.Context) {
return
}
// SSE 头先写出,后续插件/模型过程都通过 trace 事件实时展示。
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return
}
emit := func(frame chatSSEFrame) {
writeSSEJSON(c.Writer, frame)
flusher.Flush()
}
emitTrace := func(tool, stage, status, message string, data map[string]any) {
emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data})
}
emitError := func(err error) {
emit(chatSSEFrame{Type: "error", Error: err.Error()})
}
// 超时 context
timeout := time.Duration(profile.Config.Timeout) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
chatMessages := req.Messages
if sqlState != nil && sqlState.Enabled() {
withSQL, err := enrichMessagesWithSQL(c.Request.Context(), profile, chatMessages)
withSQL, err := enrichMessagesWithSQL(ctx, profile, chatMessages, emit)
if err != nil {
fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err)
emitTrace("sql", "error", "error", "数据库查询插件调用失败,将继续普通回答", map[string]any{"error": err.Error()})
} else {
chatMessages = withSQL
}
}
if req.WebSearch {
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), chatMessages)
withSearch, err := enrichMessagesWithSearch(ctx, chatMessages, emit)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
emitError(err)
return
}
chatMessages = withSearch
@@ -760,39 +798,22 @@ func chatHandler(c *gin.Context) {
// 构建 ark 消息列表
messages, err := buildArkMessages(chatMessages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
emitError(err)
return
}
// SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器不支持流式响应"})
return
}
// 超时 context
timeout := time.Duration(profile.Config.Timeout) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 发起流式请求(使用 CreateChatCompletionStream
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
Messages: messages,
MaxTokens: intPtr(4096),
}.WithStream(true))
if err != nil {
fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
flusher.Flush()
emitError(err)
return
}
defer stream.Close()
emitTrace("model", "stream", "running", "模型已开始输出", nil)
var full strings.Builder
for {
@@ -803,21 +824,20 @@ func chatHandler(c *gin.Context) {
fmt.Fprintln(os.Stderr, "保存对话失败:", err)
}
}
emitTrace("model", "stream", "success", "回答生成完成", nil)
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
return
}
if err != nil {
fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
flusher.Flush()
emitError(err)
return
}
if len(resp.Choices) > 0 {
delta := resp.Choices[0].Delta.Content
if delta != "" {
full.WriteString(delta)
fmt.Fprintf(c.Writer, "data: %s\n\n", toSSE(delta))
flusher.Flush()
emit(chatSSEFrame{Type: "delta", Text: delta})
}
}
}
@@ -858,7 +878,7 @@ type duckDuckGoResponse struct {
} `json:"Infobox"`
}
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]ChatMessage, error) {
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
searchConfig := searchState.ActiveProfile()
if !searchConfig.Enabled {
return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled")
@@ -869,14 +889,19 @@ func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]Ch
return nil, errors.New("联网搜索需要输入文本问题")
}
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": searchConfig.Provider}})
results, err := webSearch(ctx, searchConfig, query)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
return nil, err
}
if len(results) == 0 {
return nil, errors.New("未搜索到相关网页结果")
err := errors.New("未搜索到相关网页结果")
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()})
return nil, err
}
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": searchConfig.Provider, "count": len(results)}})
searchContext := buildSearchContext(searchConfig, query, results)
withSearch := make([]ChatMessage, 0, len(messages)+1)
withSearch = append(withSearch, ChatMessage{Role: "system", Content: searchContext, Hidden: true})
@@ -904,40 +929,54 @@ type sqlGenerationResult struct {
Reason string `json:"reason"`
}
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) ([]ChatMessage, error) {
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
query := latestUserQuery(messages)
if query == "" {
return messages, nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "running", Message: "正在判断是否需要查询数据库"})
activate, reason, err := classifySQLActivation(ctx, profile, messages)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "error", Message: "数据库查询判断失败", Data: map[string]any{"error": err.Error()}})
return messages, err
}
if !activate {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:本轮无需查询数据库", Data: map[string]any{"activate": false, "reason": reason}})
return messages, nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:需要查询数据库", Data: map[string]any{"activate": true, "reason": reason}})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
schemaContext, err := sqlState.SchemaContext(ctx)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
generated.Database = strings.TrimSpace(generated.Database)
generated.SQL = strings.TrimSpace(generated.SQL)
if generated.SQL == "" {
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
if strings.TrimSpace(reason) != "" {
contextText += "\n激活原因:" + reason
@@ -1480,6 +1519,14 @@ func contains(items []string, target string) bool {
func intPtr(i int) *int { return &i }
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
data, err := json.Marshal(frame)
if err != nil {
data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"})
}
fmt.Fprintf(w, "data: %s\n\n", data)
}
func toJSON(s string) string {
b, _ := json.Marshal(s)
return string(b)
+90 -4
View File
@@ -259,6 +259,40 @@
}
@keyframes blink { 50% { opacity: 0; } }
/* 执行过程 */
.trace-panel {
display: flex;
flex-direction: column;
gap: 5px;
margin-bottom: 8px;
color: var(--text-dim);
font-size: 0.78rem;
line-height: 1.45;
white-space: normal;
}
.trace-panel:empty { display: none; }
.trace-item {
border-left: 2px solid var(--accent-border);
padding-left: 8px;
}
.trace-item.running { color: var(--text-dim); }
.trace-item.success { opacity: .86; }
.trace-item.error {
color: var(--danger);
border-left-color: var(--danger);
}
.trace-detail {
margin-top: 4px;
font-family: Consolas, 'Fira Code', monospace;
background: var(--code-bg);
border: 1px solid var(--border);
border-radius: 6px;
padding: 6px 8px;
white-space: pre-wrap;
overflow-x: auto;
}
.answer-text { display: inline; }
/* 错误消息 */
.error-msg {
color: var(--danger);
@@ -830,13 +864,49 @@ function addAIBubble() {
const bub = document.createElement('div');
bub.className = 'bubble typing-cursor';
const trace = document.createElement('div');
trace.className = 'trace-panel';
const txt = document.createElement('span');
txt.className = 'answer-text';
bub.appendChild(trace);
bub.appendChild(txt);
row.appendChild(av);
row.appendChild(bub);
msgBox.appendChild(row);
scrollToBottom();
return { bub, txt };
return { bub, txt, trace };
}
function appendTrace(aiBubble, frame) {
if (!aiBubble.trace) return;
const item = document.createElement('div');
item.className = `trace-item ${frame.status || ''}`;
const label = frame.message || [frame.tool, frame.stage, frame.status].filter(Boolean).join(' ');
item.textContent = label;
const data = frame.data || {};
const details = [];
if (data.sql) details.push(data.sql);
const stats = [];
if (data.database) stats.push(`数据库: ${data.database}`);
if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`);
if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`);
if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`);
if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''}`);
if (data.reason) stats.push(`原因: ${data.reason}`);
if (data.error) stats.push(`错误: ${data.error}`);
if (stats.length) details.push(stats.join(' '));
if (details.length) {
const detail = document.createElement('div');
detail.className = 'trace-detail';
detail.textContent = details.join('\n');
item.appendChild(detail);
}
aiBubble.trace.appendChild(item);
scrollToBottom();
}
async function streamChat(messages, aiBubble, webSearch = false) {
@@ -880,13 +950,29 @@ async function streamChat(messages, aiBubble, webSearch = false) {
}
try {
const parsed = JSON.parse(raw);
if (parsed && typeof parsed === 'object' && parsed.error) {
throw new Error(parsed.error);
}
if (typeof parsed === 'string') {
full += parsed;
txtEl.innerHTML = renderMarkdown(full);
scrollToBottom();
continue;
}
if (parsed && typeof parsed === 'object') {
if (parsed.type === 'error' || parsed.error) {
throw new Error(parsed.error || parsed.message || '流式响应错误');
}
if (parsed.type === 'delta') {
const delta = parsed.text || '';
if (delta) {
full += delta;
txtEl.innerHTML = renderMarkdown(full);
scrollToBottom();
}
continue;
}
if (parsed.type === 'trace') {
appendTrace(aiBubble, parsed);
continue;
}
}
} catch (e) {
if (e instanceof SyntaxError) continue;