This commit is contained in:
2026-06-09 20:55:16 +08:00
parent 721caccc58
commit 1e793ce814
2 changed files with 168 additions and 35 deletions
+78 -31
View File
@@ -630,6 +630,17 @@ var (
store *ConvStore
)
type chatSSEFrame struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Message string `json:"message,omitempty"`
Tool string `json:"tool,omitempty"`
Stage string `json:"stage,omitempty"`
Status string `json:"status,omitempty"`
Data map[string]any `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}
// ─── 路由 ─────────────────────────────────────────────────
func indexHandler(c *gin.Context) {
@@ -739,19 +750,46 @@ func chatHandler(c *gin.Context) {
return
}
// SSE 头先写出,后续插件/模型过程都通过 trace 事件实时展示。
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return
}
emit := func(frame chatSSEFrame) {
writeSSEJSON(c.Writer, frame)
flusher.Flush()
}
emitTrace := func(tool, stage, status, message string, data map[string]any) {
emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data})
}
emitError := func(err error) {
emit(chatSSEFrame{Type: "error", Error: err.Error()})
}
// 超时 context
timeout := time.Duration(profile.Config.Timeout) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
chatMessages := req.Messages
if sqlState != nil && sqlState.Enabled() {
withSQL, err := enrichMessagesWithSQL(c.Request.Context(), profile, chatMessages)
withSQL, err := enrichMessagesWithSQL(ctx, profile, chatMessages, emit)
if err != nil {
fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err)
emitTrace("sql", "error", "error", "数据库查询插件调用失败,将继续普通回答", map[string]any{"error": err.Error()})
} else {
chatMessages = withSQL
}
}
if req.WebSearch {
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), chatMessages)
withSearch, err := enrichMessagesWithSearch(ctx, chatMessages, emit)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
emitError(err)
return
}
chatMessages = withSearch
@@ -760,39 +798,22 @@ func chatHandler(c *gin.Context) {
// 构建 ark 消息列表
messages, err := buildArkMessages(chatMessages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
emitError(err)
return
}
// SSE 头
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器不支持流式响应"})
return
}
// 超时 context
timeout := time.Duration(profile.Config.Timeout) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 发起流式请求(使用 CreateChatCompletionStream
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
Messages: messages,
MaxTokens: intPtr(4096),
}.WithStream(true))
if err != nil {
fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
flusher.Flush()
emitError(err)
return
}
defer stream.Close()
emitTrace("model", "stream", "running", "模型已开始输出", nil)
var full strings.Builder
for {
@@ -803,21 +824,20 @@ func chatHandler(c *gin.Context) {
fmt.Fprintln(os.Stderr, "保存对话失败:", err)
}
}
emitTrace("model", "stream", "success", "回答生成完成", nil)
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
return
}
if err != nil {
fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
flusher.Flush()
emitError(err)
return
}
if len(resp.Choices) > 0 {
delta := resp.Choices[0].Delta.Content
if delta != "" {
full.WriteString(delta)
fmt.Fprintf(c.Writer, "data: %s\n\n", toSSE(delta))
flusher.Flush()
emit(chatSSEFrame{Type: "delta", Text: delta})
}
}
}
@@ -858,7 +878,7 @@ type duckDuckGoResponse struct {
} `json:"Infobox"`
}
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]ChatMessage, error) {
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
searchConfig := searchState.ActiveProfile()
if !searchConfig.Enabled {
return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled")
@@ -869,14 +889,19 @@ func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]Ch
return nil, errors.New("联网搜索需要输入文本问题")
}
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": searchConfig.Provider}})
results, err := webSearch(ctx, searchConfig, query)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
return nil, err
}
if len(results) == 0 {
return nil, errors.New("未搜索到相关网页结果")
err := errors.New("未搜索到相关网页结果")
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()})
return nil, err
}
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": searchConfig.Provider, "count": len(results)}})
searchContext := buildSearchContext(searchConfig, query, results)
withSearch := make([]ChatMessage, 0, len(messages)+1)
withSearch = append(withSearch, ChatMessage{Role: "system", Content: searchContext, Hidden: true})
@@ -904,40 +929,54 @@ type sqlGenerationResult struct {
Reason string `json:"reason"`
}
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) ([]ChatMessage, error) {
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
query := latestUserQuery(messages)
if query == "" {
return messages, nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "running", Message: "正在判断是否需要查询数据库"})
activate, reason, err := classifySQLActivation(ctx, profile, messages)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "error", Message: "数据库查询判断失败", Data: map[string]any{"error": err.Error()}})
return messages, err
}
if !activate {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:本轮无需查询数据库", Data: map[string]any{"activate": false, "reason": reason}})
return messages, nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:需要查询数据库", Data: map[string]any{"activate": true, "reason": reason}})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
schemaContext, err := sqlState.SchemaContext(ctx)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
generated.Database = strings.TrimSpace(generated.Database)
generated.SQL = strings.TrimSpace(generated.SQL)
if generated.SQL == "" {
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
if strings.TrimSpace(reason) != "" {
contextText += "\n激活原因:" + reason
@@ -1480,6 +1519,14 @@ func contains(items []string, target string) bool {
func intPtr(i int) *int { return &i }
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
data, err := json.Marshal(frame)
if err != nil {
data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"})
}
fmt.Fprintf(w, "data: %s\n\n", data)
}
func toJSON(s string) string {
b, _ := json.Marshal(s)
return string(b)