From 1e793ce814e8c4119c4274fb7aa6573d547871e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Tue, 9 Jun 2026 20:55:16 +0800 Subject: [PATCH] up --- main.go | 109 +++++++++++++++++++++++++++++++------------- templates/chat.html | 94 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 168 insertions(+), 35 deletions(-) diff --git a/main.go b/main.go index 1aecd70..f9a27ee 100644 --- a/main.go +++ b/main.go @@ -630,6 +630,17 @@ var ( store *ConvStore ) +type chatSSEFrame struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Message string `json:"message,omitempty"` + Tool string `json:"tool,omitempty"` + Stage string `json:"stage,omitempty"` + Status string `json:"status,omitempty"` + Data map[string]any `json:"data,omitempty"` + Error string `json:"error,omitempty"` +} + // ─── 路由 ───────────────────────────────────────────────── func indexHandler(c *gin.Context) { @@ -739,19 +750,46 @@ func chatHandler(c *gin.Context) { return } + // SSE 头先写出,后续插件/模型过程都通过 trace 事件实时展示。 + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return + } + emit := func(frame chatSSEFrame) { + writeSSEJSON(c.Writer, frame) + flusher.Flush() + } + emitTrace := func(tool, stage, status, message string, data map[string]any) { + emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data}) + } + emitError := func(err error) { + emit(chatSSEFrame{Type: "error", Error: err.Error()}) + } + + // 超时 context + timeout := time.Duration(profile.Config.Timeout) * time.Second + ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) + defer cancel() + chatMessages := req.Messages if sqlState != nil && sqlState.Enabled() { - withSQL, err := enrichMessagesWithSQL(c.Request.Context(), profile, chatMessages) + withSQL, err := enrichMessagesWithSQL(ctx, profile, chatMessages, emit) if err != nil { fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err) + emitTrace("sql", "error", "error", "数据库查询插件调用失败,将继续普通回答", map[string]any{"error": err.Error()}) } else { chatMessages = withSQL } } if req.WebSearch { - withSearch, err := enrichMessagesWithSearch(c.Request.Context(), chatMessages) + withSearch, err := enrichMessagesWithSearch(ctx, chatMessages, emit) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + emitError(err) return } chatMessages = withSearch @@ -760,39 +798,22 @@ func chatHandler(c *gin.Context) { // 构建 ark 消息列表 messages, err := buildArkMessages(chatMessages) if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + emitError(err) return } - // SSE 头 - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.WriteHeader(http.StatusOK) - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器不支持流式响应"}) - return - } - - // 超时 context - timeout := time.Duration(profile.Config.Timeout) * time.Second - ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) - defer cancel() - - // 发起流式请求(使用 CreateChatCompletionStream) + emitTrace("model", "request", "running", "正在调用模型生成回答", nil) stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ Model: profile.Config.Model, Messages: messages, MaxTokens: intPtr(4096), }.WithStream(true)) if err != nil { - fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error())) - flusher.Flush() + emitError(err) return } defer stream.Close() + emitTrace("model", "stream", "running", "模型已开始输出", nil) var full strings.Builder for { @@ -803,21 +824,20 @@ func chatHandler(c *gin.Context) { fmt.Fprintln(os.Stderr, "保存对话失败:", err) } } + emitTrace("model", "stream", "success", "回答生成完成", nil) fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() return } if err != nil { - fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error())) - flusher.Flush() + emitError(err) return } if len(resp.Choices) > 0 { delta := resp.Choices[0].Delta.Content if delta != "" { full.WriteString(delta) - fmt.Fprintf(c.Writer, "data: %s\n\n", toSSE(delta)) - flusher.Flush() + emit(chatSSEFrame{Type: "delta", Text: delta}) } } } @@ -858,7 +878,7 @@ type duckDuckGoResponse struct { } `json:"Infobox"` } -func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]ChatMessage, error) { +func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { searchConfig := searchState.ActiveProfile() if !searchConfig.Enabled { return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled") @@ -869,14 +889,19 @@ func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]Ch return nil, errors.New("联网搜索需要输入文本问题") } + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": searchConfig.Provider}}) results, err := webSearch(ctx, searchConfig, query) if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}}) return nil, err } if len(results) == 0 { - return nil, errors.New("未搜索到相关网页结果") + err := errors.New("未搜索到相关网页结果") + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()}) + return nil, err } + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": searchConfig.Provider, "count": len(results)}}) searchContext := buildSearchContext(searchConfig, query, results) withSearch := make([]ChatMessage, 0, len(messages)+1) withSearch = append(withSearch, ChatMessage{Role: "system", Content: searchContext, Hidden: true}) @@ -904,40 +929,54 @@ type sqlGenerationResult struct { Reason string `json:"reason"` } -func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) ([]ChatMessage, error) { +func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { query := latestUserQuery(messages) if query == "" { return messages, nil } + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "running", Message: "正在判断是否需要查询数据库"}) activate, reason, err := classifySQLActivation(ctx, profile, messages) if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "error", Message: "数据库查询判断失败", Data: map[string]any{"error": err.Error()}}) return messages, err } if !activate { + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:本轮无需查询数据库", Data: map[string]any{"activate": false, "reason": reason}}) return messages, nil } + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:需要查询数据库", Data: map[string]any{"activate": true, "reason": reason}}) + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"}) schemaContext, err := sqlState.SchemaContext(ctx) if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"}) + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"}) generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext) if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } generated.Database = strings.TrimSpace(generated.Database) generated.SQL = strings.TrimSpace(generated.SQL) if generated.SQL == "" { err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason) + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}}) result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL) if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}}) + emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}}) contextText := sqlquery.BuildResultContext(query, generated.SQL, result) if strings.TrimSpace(reason) != "" { contextText += "\n激活原因:" + reason @@ -1480,6 +1519,14 @@ func contains(items []string, target string) bool { func intPtr(i int) *int { return &i } +func writeSSEJSON(w io.Writer, frame chatSSEFrame) { + data, err := json.Marshal(frame) + if err != nil { + data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"}) + } + fmt.Fprintf(w, "data: %s\n\n", data) +} + func toJSON(s string) string { b, _ := json.Marshal(s) return string(b) diff --git a/templates/chat.html b/templates/chat.html index bcad0a2..90f57a3 100644 --- a/templates/chat.html +++ b/templates/chat.html @@ -259,6 +259,40 @@ } @keyframes blink { 50% { opacity: 0; } } + /* 执行过程 */ + .trace-panel { + display: flex; + flex-direction: column; + gap: 5px; + margin-bottom: 8px; + color: var(--text-dim); + font-size: 0.78rem; + line-height: 1.45; + white-space: normal; + } + .trace-panel:empty { display: none; } + .trace-item { + border-left: 2px solid var(--accent-border); + padding-left: 8px; + } + .trace-item.running { color: var(--text-dim); } + .trace-item.success { opacity: .86; } + .trace-item.error { + color: var(--danger); + border-left-color: var(--danger); + } + .trace-detail { + margin-top: 4px; + font-family: Consolas, 'Fira Code', monospace; + background: var(--code-bg); + border: 1px solid var(--border); + border-radius: 6px; + padding: 6px 8px; + white-space: pre-wrap; + overflow-x: auto; + } + .answer-text { display: inline; } + /* 错误消息 */ .error-msg { color: var(--danger); @@ -830,13 +864,49 @@ function addAIBubble() { const bub = document.createElement('div'); bub.className = 'bubble typing-cursor'; + const trace = document.createElement('div'); + trace.className = 'trace-panel'; + const txt = document.createElement('span'); + txt.className = 'answer-text'; + bub.appendChild(trace); bub.appendChild(txt); row.appendChild(av); row.appendChild(bub); msgBox.appendChild(row); scrollToBottom(); - return { bub, txt }; + return { bub, txt, trace }; +} + +function appendTrace(aiBubble, frame) { + if (!aiBubble.trace) return; + const item = document.createElement('div'); + item.className = `trace-item ${frame.status || ''}`; + const label = frame.message || [frame.tool, frame.stage, frame.status].filter(Boolean).join(' '); + item.textContent = label; + + const data = frame.data || {}; + const details = []; + if (data.sql) details.push(data.sql); + const stats = []; + if (data.database) stats.push(`数据库: ${data.database}`); + if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`); + if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`); + if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`); + if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''} 行`); + if (data.reason) stats.push(`原因: ${data.reason}`); + if (data.error) stats.push(`错误: ${data.error}`); + if (stats.length) details.push(stats.join(' | ')); + + if (details.length) { + const detail = document.createElement('div'); + detail.className = 'trace-detail'; + detail.textContent = details.join('\n'); + item.appendChild(detail); + } + + aiBubble.trace.appendChild(item); + scrollToBottom(); } async function streamChat(messages, aiBubble, webSearch = false) { @@ -880,13 +950,29 @@ async function streamChat(messages, aiBubble, webSearch = false) { } try { const parsed = JSON.parse(raw); - if (parsed && typeof parsed === 'object' && parsed.error) { - throw new Error(parsed.error); - } if (typeof parsed === 'string') { full += parsed; txtEl.innerHTML = renderMarkdown(full); scrollToBottom(); + continue; + } + if (parsed && typeof parsed === 'object') { + if (parsed.type === 'error' || parsed.error) { + throw new Error(parsed.error || parsed.message || '流式响应错误'); + } + if (parsed.type === 'delta') { + const delta = parsed.text || ''; + if (delta) { + full += delta; + txtEl.innerHTML = renderMarkdown(full); + scrollToBottom(); + } + continue; + } + if (parsed.type === 'trace') { + appendTrace(aiBubble, parsed); + continue; + } } } catch (e) { if (e instanceof SyntaxError) continue;