up
This commit is contained in:
@@ -630,6 +630,17 @@ var (
|
||||
store *ConvStore
|
||||
)
|
||||
|
||||
// chatSSEFrame is the JSON payload of a single server-sent event on the chat
// stream. Type selects the kind of frame; every other field is optional and is
// left out of the encoded JSON when it is empty.
type chatSSEFrame struct {
	// Type identifies the frame kind, e.g. "trace", "delta" or "error".
	Type string `json:"type"`
	// Text carries a chunk of model output in "delta" frames.
	Text string `json:"text,omitempty"`
	// Message is the human-readable progress text of a "trace" frame.
	Message string `json:"message,omitempty"`
	// Tool names the component a trace refers to, e.g. "sql", "search" or "model".
	Tool string `json:"tool,omitempty"`
	// Stage names the step within the tool, e.g. "classify" or "execute".
	Stage string `json:"stage,omitempty"`
	// Status is the state of that step, e.g. "running", "success" or "error".
	Status string `json:"status,omitempty"`
	// Data holds optional structured details attached to a trace.
	Data map[string]any `json:"data,omitempty"`
	// Error carries the error text of an "error" frame.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
// ─── 路由 ─────────────────────────────────────────────────
|
||||
|
||||
func indexHandler(c *gin.Context) {
|
||||
@@ -739,19 +750,46 @@ func chatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// SSE 头先写出,后续插件/模型过程都通过 trace 事件实时展示。
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
emit := func(frame chatSSEFrame) {
|
||||
writeSSEJSON(c.Writer, frame)
|
||||
flusher.Flush()
|
||||
}
|
||||
emitTrace := func(tool, stage, status, message string, data map[string]any) {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data})
|
||||
}
|
||||
emitError := func(err error) {
|
||||
emit(chatSSEFrame{Type: "error", Error: err.Error()})
|
||||
}
|
||||
|
||||
// 超时 context
|
||||
timeout := time.Duration(profile.Config.Timeout) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
chatMessages := req.Messages
|
||||
if sqlState != nil && sqlState.Enabled() {
|
||||
withSQL, err := enrichMessagesWithSQL(c.Request.Context(), profile, chatMessages)
|
||||
withSQL, err := enrichMessagesWithSQL(ctx, profile, chatMessages, emit)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err)
|
||||
emitTrace("sql", "error", "error", "数据库查询插件调用失败,将继续普通回答", map[string]any{"error": err.Error()})
|
||||
} else {
|
||||
chatMessages = withSQL
|
||||
}
|
||||
}
|
||||
if req.WebSearch {
|
||||
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), chatMessages)
|
||||
withSearch, err := enrichMessagesWithSearch(ctx, chatMessages, emit)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
emitError(err)
|
||||
return
|
||||
}
|
||||
chatMessages = withSearch
|
||||
@@ -760,39 +798,22 @@ func chatHandler(c *gin.Context) {
|
||||
// 构建 ark 消息列表
|
||||
messages, err := buildArkMessages(chatMessages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
emitError(err)
|
||||
return
|
||||
}
|
||||
|
||||
// SSE 头
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器不支持流式响应"})
|
||||
return
|
||||
}
|
||||
|
||||
// 超时 context
|
||||
timeout := time.Duration(profile.Config.Timeout) * time.Second
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
// 发起流式请求(使用 CreateChatCompletionStream)
|
||||
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
|
||||
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
|
||||
Model: profile.Config.Model,
|
||||
Messages: messages,
|
||||
MaxTokens: intPtr(4096),
|
||||
}.WithStream(true))
|
||||
if err != nil {
|
||||
fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
|
||||
flusher.Flush()
|
||||
emitError(err)
|
||||
return
|
||||
}
|
||||
defer stream.Close()
|
||||
emitTrace("model", "stream", "running", "模型已开始输出", nil)
|
||||
|
||||
var full strings.Builder
|
||||
for {
|
||||
@@ -803,21 +824,20 @@ func chatHandler(c *gin.Context) {
|
||||
fmt.Fprintln(os.Stderr, "保存对话失败:", err)
|
||||
}
|
||||
}
|
||||
emitTrace("model", "stream", "success", "回答生成完成", nil)
|
||||
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
|
||||
flusher.Flush()
|
||||
emitError(err)
|
||||
return
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
delta := resp.Choices[0].Delta.Content
|
||||
if delta != "" {
|
||||
full.WriteString(delta)
|
||||
fmt.Fprintf(c.Writer, "data: %s\n\n", toSSE(delta))
|
||||
flusher.Flush()
|
||||
emit(chatSSEFrame{Type: "delta", Text: delta})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -858,7 +878,7 @@ type duckDuckGoResponse struct {
|
||||
} `json:"Infobox"`
|
||||
}
|
||||
|
||||
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]ChatMessage, error) {
|
||||
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
searchConfig := searchState.ActiveProfile()
|
||||
if !searchConfig.Enabled {
|
||||
return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled")
|
||||
@@ -869,14 +889,19 @@ func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]Ch
|
||||
return nil, errors.New("联网搜索需要输入文本问题")
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": searchConfig.Provider}})
|
||||
results, err := webSearch(ctx, searchConfig, query)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, errors.New("未搜索到相关网页结果")
|
||||
err := errors.New("未搜索到相关网页结果")
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": searchConfig.Provider, "count": len(results)}})
|
||||
searchContext := buildSearchContext(searchConfig, query, results)
|
||||
withSearch := make([]ChatMessage, 0, len(messages)+1)
|
||||
withSearch = append(withSearch, ChatMessage{Role: "system", Content: searchContext, Hidden: true})
|
||||
@@ -904,40 +929,54 @@ type sqlGenerationResult struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) ([]ChatMessage, error) {
|
||||
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
query := latestUserQuery(messages)
|
||||
if query == "" {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "running", Message: "正在判断是否需要查询数据库"})
|
||||
activate, reason, err := classifySQLActivation(ctx, profile, messages)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "error", Message: "数据库查询判断失败", Data: map[string]any{"error": err.Error()}})
|
||||
return messages, err
|
||||
}
|
||||
if !activate {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:本轮无需查询数据库", Data: map[string]any{"activate": false, "reason": reason}})
|
||||
return messages, nil
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:需要查询数据库", Data: map[string]any{"activate": true, "reason": reason}})
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
|
||||
schemaContext, err := sqlState.SchemaContext(ctx)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
|
||||
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
generated.Database = strings.TrimSpace(generated.Database)
|
||||
generated.SQL = strings.TrimSpace(generated.SQL)
|
||||
if generated.SQL == "" {
|
||||
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
|
||||
result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
|
||||
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
contextText += "\n激活原因:" + reason
|
||||
@@ -1480,6 +1519,14 @@ func contains(items []string, target string) bool {
|
||||
|
||||
// intPtr returns a pointer to a fresh int holding the value i. It is handy for
// filling optional *int fields from a literal.
func intPtr(i int) *int {
	v := i
	return &v
}
|
||||
|
||||
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
|
||||
data, err := json.Marshal(frame)
|
||||
if err != nil {
|
||||
data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"})
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
}
|
||||
|
||||
func toJSON(s string) string {
|
||||
b, _ := json.Marshal(s)
|
||||
return string(b)
|
||||
|
||||
Reference in New Issue
Block a user