From d1324dc2f2d64840ed3187d0d02af3f1b36f55fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Thu, 11 Jun 2026 18:04:47 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=B7=A5=E5=85=B7=E9=93=BE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agents/search/search.go | 54 +++ agents/search/search_test.go | 39 +++ agents/sql/sql_query.go | 124 +++++++ agents/sql/sql_query_test.go | 40 +++ agents/time/time.go | 47 ++- agents/time/time_test.go | 16 + main.go | 649 +++++++++++++---------------------- main_test.go | 263 ++++++-------- templates/chat.html | 56 ++- 9 files changed, 718 insertions(+), 570 deletions(-) diff --git a/agents/search/search.go b/agents/search/search.go index 51a9bc9..c1334ad 100644 --- a/agents/search/search.go +++ b/agents/search/search.go @@ -14,10 +14,12 @@ import ( "sync" "time" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "gopkg.in/yaml.v3" ) const ( + ToolName = "search" defaultActivationPrompt = `判断用户问题是否需要联网搜索。 当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。 当用户询问“历史上的今天”、某日期历史事件、需要按当前日期动态确定查询词的常识资料时,也应调用 search;如果联网无结果,主模型会回退到自身知识库回答并说明来源。 @@ -91,6 +93,11 @@ type ListResponse struct { Profiles []ProfileConfig `json:"profiles"` } +type ToolArgs struct { + Query string `json:"query"` + Reason string `json:"reason"` +} + type braveSearchResponse struct { Web struct { Results []Result `json:"results"` @@ -208,6 +215,53 @@ func (s *State) ActivationPrompt() string { return strings.TrimSpace(s.cfg.ActivationPrompt) } +func (s *State) ToolDefinition(description string) *model.Tool { + description = strings.TrimSpace(description) + if description == "" { + description = s.ActivationPrompt() + } + return &model.Tool{ + Type: model.ToolTypeFunction, + Function: &model.FunctionDefinition{ + Name: ToolName, + Description: description, + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "要联网搜索的关键词。若问题包含相对日期,应先调用 time 工具后使用绝对日期改写查询词。", + }, + "reason": map[string]any{ + "type": "string", + "description": "调用联网搜索的原因。", + }, + }, + "required": []string{"query"}, + }, + }, + } +} + +func (s *State) ExecuteTool(ctx context.Context, args string) (string, error) { + var parsed ToolArgs + if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil { + return "", fmt.Errorf("解析搜索工具参数失败: %w", err) + } + query := strings.TrimSpace(parsed.Query) + if query == "" { + return "", errors.New("搜索关键词不能为空") + } + results, profile, err := s.Search(ctx, query) + if err != nil { + return BuildErrorContext(query, err), nil + } + if len(results) == 0 { + return BuildFallbackContext(profile, query, parsed.Reason, errors.New("未搜索到相关网页结果")), nil + } + return BuildResultContext(profile, query, results, parsed.Reason), nil +} + func (s *State) ActiveProfile() ProfileConfig { s.mu.RLock() defer s.mu.RUnlock() diff --git a/agents/search/search_test.go b/agents/search/search_test.go index 52481d9..76769c8 100644 --- a/agents/search/search_test.go +++ b/agents/search/search_test.go @@ -121,3 +121,42 @@ func TestLoadConfigWritesLegacyProfiles(t *testing.T) { t.Fatal(err) } } + +func TestToolDefinitionAndExecuteTool(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("q") != "golang" { + t.Fatalf("query = %s", r.URL.RawQuery) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"Heading":"Go","Abstract":"Go language","AbstractURL":"https://go.dev"}`)) + })) + defer server.Close() + state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: server.URL, Count: 1, Timeout: 1}}}) + if err != nil { + t.Fatal(err) + } + definition := state.ToolDefinition("custom search") + if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom search" { + t.Fatalf("unexpected definition: %#v", definition) + } + text, err := state.ExecuteTool(context.Background(), `{"query":"golang","reason":"测试搜索"}`) + if err != nil { + t.Fatal(err) + } + for _, want := range []string{"联网搜索", "golang", "Go", "https://go.dev", "测试搜索"} { + if !strings.Contains(text, want) { + t.Fatalf("tool result missing %q:\n%s", want, text) + } + } +} + +func TestExecuteToolRejectsEmptyQuery(t *testing.T) { + state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: defaultBaseURL, Count: 1, Timeout: 1}}}) + if err != nil { + t.Fatal(err) + } + _, err = state.ExecuteTool(context.Background(), `{"query":" "}`) + if err == nil || !strings.Contains(err.Error(), "不能为空") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/agents/sql/sql_query.go b/agents/sql/sql_query.go index cd4aba0..d8b5c70 100644 --- a/agents/sql/sql_query.go +++ b/agents/sql/sql_query.go @@ -3,6 +3,7 @@ package sqlquery import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "os" @@ -15,11 +16,13 @@ import ( "unicode/utf8" _ "github.com/go-sql-driver/mysql" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "gopkg.in/yaml.v3" _ "modernc.org/sqlite" ) const ( + ToolName = "sql" defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。 仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。 当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。 @@ -80,6 +83,19 @@ type QueryResult struct { MaxRows int `json:"max_rows"` } +type ToolArgs struct { + Question string `json:"question"` + Reason string `json:"reason"` +} + +type SQLGenerator func(ctx context.Context, prompt string, maxTokens int) (string, error) + +type GenerationResult struct { + Database string `json:"database"` + SQL string `json:"sql"` + Reason string `json:"reason"` +} + func LoadConfig(path string) (*Config, error) { if _, err := os.Stat(path); err != nil { if !os.IsNotExist(err) { @@ -180,6 +196,73 @@ func (s *State) ActivationPrompt() string { return strings.TrimSpace(s.cfg.ActivationPrompt) } +func (s *State) ToolDefinition(description string) *model.Tool { + description = strings.TrimSpace(description) + if description == "" { + description = s.ActivationPrompt() + } + return &model.Tool{ + Type: model.ToolTypeFunction, + Function: &model.FunctionDefinition{ + Name: ToolName, + Description: description, + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "需要查询数据库的问题。若已有 time 工具结果,应使用其中的绝对日期范围解释相对时间。", + }, + "reason": map[string]any{ + "type": "string", + "description": "调用数据库查询工具的原因。", + }, + }, + "required": []string{"question"}, + }, + }, + } +} + +func (s *State) ExecuteTool(ctx context.Context, args string, generator SQLGenerator) (string, error) { + if generator == nil { + return "", errors.New("SQL 生成器未配置") + } + var parsed ToolArgs + if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil { + return "", fmt.Errorf("解析 SQL 工具参数失败: %w", err) + } + question := strings.TrimSpace(parsed.Question) + if question == "" { + return "", errors.New("数据库查询问题不能为空") + } + schemaContext, err := s.SchemaContext(ctx) + if err != nil { + return BuildErrorContext(question, err), nil + } + generated, err := GenerateSQL(ctx, generator, question, schemaContext) + if err != nil { + return BuildErrorContext(question, err), nil + } + generated.Database = strings.TrimSpace(generated.Database) + generated.SQL = strings.TrimSpace(generated.SQL) + if generated.SQL == "" { + return BuildErrorContext(question, fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)), nil + } + result, err := s.ExecuteReadOnly(ctx, generated.Database, generated.SQL) + if err != nil { + return BuildErrorContext(question, err), nil + } + contextText := BuildResultContext(question, generated.SQL, result) + if strings.TrimSpace(parsed.Reason) != "" { + contextText += "\n调用原因:" + strings.TrimSpace(parsed.Reason) + } + if strings.TrimSpace(generated.Reason) != "" { + contextText += "\nSQL 生成原因:" + strings.TrimSpace(generated.Reason) + } + return contextText, nil +} + func (s *State) DefaultDatabase() string { if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" { return defaultDatabaseName @@ -220,6 +303,37 @@ func (s *State) SchemaContext(ctx context.Context) (string, error) { return text, nil } +func GenerateSQL(ctx context.Context, generator SQLGenerator, userQuery string, schemaContext string) (*GenerationResult, error) { + prompt := BuildSQLGenerationPrompt(userQuery, schemaContext) + text, err := generator(ctx, prompt, 1024) + if err != nil { + return nil, err + } + var generated GenerationResult + if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil { + return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err) + } + return &generated, nil +} + +func BuildSQLGenerationPrompt(userQuery string, schemaContext string) string { + return fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、工具结果上下文和数据库 schema 生成一条只读 SQL。 +要求: +- 只能返回 JSON,不要使用 Markdown。 +- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"} +- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。 +- 必须只使用 schema 中出现的数据库、表和字段。 +- 如果工具结果中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。 +- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。 +- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。 +- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。 +- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。 + +%s + +用户问题:%s`, schemaContext, userQuery) +} + func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) { if !s.Enabled() { return nil, errors.New("SQL 查询插件未启用") @@ -588,6 +702,16 @@ func (d *database) rejectExcludedTables(query string) error { return nil } +func extractJSONObject(text string) string { + text = strings.TrimSpace(text) + start := strings.Index(text, "{") + end := strings.LastIndex(text, "}") + if start >= 0 && end > start { + return text[start : end+1] + } + return text +} + func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) { defer rows.Close() columns, err := rows.Columns() diff --git a/agents/sql/sql_query_test.go b/agents/sql/sql_query_test.go index e144988..bc1a3f5 100644 --- a/agents/sql/sql_query_test.go +++ b/agents/sql/sql_query_test.go @@ -1,6 +1,7 @@ package sqlquery import ( + "context" "strings" "testing" ) @@ -43,3 +44,42 @@ func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) { } } } + +func TestToolDefinition(t *testing.T) { + state := &State{} + definition := state.ToolDefinition("custom sql") + if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" { + t.Fatalf("unexpected definition: %#v", definition) + } +} + +func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) { + prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)") + for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} { + if !strings.Contains(prompt, want) { + t.Fatalf("prompt missing %q:\n%s", want, prompt) + } + } +} + +func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) { + generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) { + if !strings.Contains(prompt, "schema") || maxTokens != 1024 { + t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens) + } + return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil + }, "查事件", "schema") + if err != nil { + t.Fatal(err) + } + if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" { + t.Fatalf("unexpected generated SQL: %#v", generated) + } + + _, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) { + return "not json", nil + }, "查事件", "schema") + if err == nil { + t.Fatal("expected malformed JSON error") + } +} diff --git a/agents/time/time.go b/agents/time/time.go index ce4c41e..8ea1b36 100644 --- a/agents/time/time.go +++ b/agents/time/time.go @@ -1,12 +1,22 @@ package timeagent import ( + "encoding/json" "fmt" "strings" "time" + + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" ) -const ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近、历史上的今天、日程安排等相对时间表达时,应先调用此工具;如果后续还需要联网搜索或查数据库,可继续调用 search 或 sql。" +const ( + ToolName = "time" + ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近、历史上的今天、日程安排等相对时间表达时,应先调用此工具;如果后续还需要联网搜索或查数据库,可继续调用 search 或 sql。" +) + +type ToolArgs struct { + Reason string `json:"reason"` +} type Range struct { Start time.Time @@ -23,6 +33,39 @@ type Context struct { ThisYear Range } +func ToolDefinition(description string) *model.Tool { + description = strings.TrimSpace(description) + if description == "" { + description = ActivationPrompt + } + return &model.Tool{ + Type: model.ToolTypeFunction, + Function: &model.FunctionDefinition{ + Name: ToolName, + Description: description, + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "reason": map[string]any{ + "type": "string", + "description": "调用时间工具的原因。", + }, + }, + }, + }, + } +} + +func ExecuteTool(args string, now time.Time) (string, error) { + var parsed ToolArgs + if strings.TrimSpace(args) != "" { + if err := json.Unmarshal([]byte(args), &parsed); err != nil { + return "", fmt.Errorf("解析时间工具参数失败: %w", err) + } + } + return BuildContext(Resolve(now), parsed.Reason), nil +} + func Resolve(now time.Time) Context { loc := now.Location() today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)} @@ -47,7 +90,7 @@ func Resolve(now time.Time) Context { func BuildContext(ctx Context, routeReason string) string { var b strings.Builder - b.WriteString("时间工具结果。后续工具必须优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n") + b.WriteString("时间工具结果。请优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n") fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST")) fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today)) fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow)) diff --git a/agents/time/time_test.go b/agents/time/time_test.go index 45dbc9f..a4f8061 100644 --- a/agents/time/time_test.go +++ b/agents/time/time_test.go @@ -29,3 +29,19 @@ func TestBuildContextIncludesSQLHints(t *testing.T) { } } } + +func TestToolDefinitionAndExecuteTool(t *testing.T) { + definition := ToolDefinition("custom description") + if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom description" { + t.Fatalf("unexpected definition: %#v", definition) + } + text, err := ExecuteTool(`{"reason":"测试原因"}`, time.Date(2026, 6, 10, 13, 14, 15, 0, time.UTC)) + if err != nil { + t.Fatal(err) + } + for _, want := range []string{"时间工具结果", "2026-06-10", "本月", "start=", "end_exclusive=", "测试原因"} { + if !strings.Contains(text, want) { + t.Fatalf("tool result missing %q:\n%s", want, text) + } + } +} diff --git a/main.go b/main.go index 3f4543a..fe77e8a 100644 --- a/main.go +++ b/main.go @@ -37,16 +37,12 @@ const ( defaultOpenAITimeout = 120 defaultToolRouterTimeout = 30 defaultToolRouterMaxTokens = 512 - defaultToolRouterSystemText = `你是工具路由器。根据用户最新问题和可用工具列表,判断本轮是否需要调用一个或多个工具。 -只能返回 JSON,不要使用 Markdown。 -JSON 格式:{"tools":[{"name":"工具名称","reason":"..."}],"reason":"..."} -工具名称必须来自“可用工具”列表。 -可以选择多个工具,工具会按配置顺序依次执行;后面的工具可以使用前面工具写入的上下文。 -如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且还需要调用 search 或 sql,必须同时选择 time,并让 time 排在这些工具之前。 -例如“历史上的今天都发生了什么”应选择 time 和 search:先获取今天的绝对日期,再搜索当天历史事件;如果联网无结果,主模型会回退到自身知识库回答并说明来源。 -例如“本月有什么日程安排”应选择 time 和 sql:先获取本月绝对日期范围,再查询日程表。 -如果无需工具,返回 {"tools":[],"reason":"..."}。 -只选择确实必要的工具。` + defaultToolRouterSystemText = `你可以按需直接调用可用工具来回答用户问题。 +如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且后续需要搜索或查询数据库,应先调用 time 获取绝对日期范围。 +需要实时网页资料、新闻、当前版本、近期事件、网页核验或用户明确要求联网时,调用 search。 +需要查询本地业务数据、日程、会议、待办、记录、统计或时间范围内数据时,调用 sql。 +工具结果优先于模型内置知识;工具失败时必须如实说明,不要编造结果。 +只调用确实必要的工具。` ) type OpenAIConfig struct { @@ -286,6 +282,11 @@ func normalizeOpenAIConfigs(cfg *Config) (bool, error) { return changed, nil } +func isLegacyToolRouterPrompt(prompt string) bool { + prompt = strings.TrimSpace(prompt) + return strings.Contains(prompt, "工具路由器") || strings.Contains(prompt, "route_tools") || strings.Contains(prompt, `"tools":[`) +} + func normalizeToolRouterConfig(cfg *Config) (bool, error) { changed := false defaults := defaultToolRouterConfig() @@ -298,11 +299,12 @@ func normalizeToolRouterConfig(cfg *Config) (bool, error) { cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens changed = true } - if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" { + systemPrompt := strings.TrimSpace(cfg.ToolRouter.SystemPrompt) + if systemPrompt == "" || isLegacyToolRouterPrompt(systemPrompt) { cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText changed = true - } else if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) != cfg.ToolRouter.SystemPrompt { - cfg.ToolRouter.SystemPrompt = strings.TrimSpace(cfg.ToolRouter.SystemPrompt) + } else if systemPrompt != cfg.ToolRouter.SystemPrompt { + cfg.ToolRouter.SystemPrompt = systemPrompt changed = true } if len(cfg.ToolRouter.Tools) == 0 { @@ -432,12 +434,12 @@ type openAIListResponse struct { Profiles []OpenAIConfig `json:"profiles"` } -type toolTextCompleter func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) +type chatCompleter func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) type ToolRouterState struct { cfg *ToolRouterConfig ai *OpenAIState - complete toolTextCompleter + complete chatCompleter } func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) { @@ -453,7 +455,7 @@ func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterS return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err) } } - return &ToolRouterState{cfg: config, ai: ai, complete: completeTextWithTimeout}, nil + return &ToolRouterState{cfg: config, ai: ai, complete: completeChatWithTimeout}, nil } func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) { @@ -796,20 +798,17 @@ func chatHandler(c *gin.Context) { usage := newTokenUsageTracker() ctx = contextWithTokenUsage(ctx, usage) - chatMessages := req.Messages - withTools, err := enrichMessagesWithRoutedTools(ctx, profile, chatMessages, emit) + // 用 Function Calling 工具循环替代旧的路由+隐藏上下文机制 + messages, err := runAgentToolLoop(ctx, profile, req.Messages, emit) if err != nil { - fmt.Fprintln(os.Stderr, "工具路由调用失败:", err) - } else { - chatMessages = withTools + fmt.Fprintln(os.Stderr, "Agent 工具循环失败:", err) + messages, err = buildArkMessages(req.Messages) + if err != nil { + emitError(err) + return + } } - // 构建 ark 消息列表 - messages, err := buildArkMessages(chatMessages) - if err != nil { - emitError(err) - return - } - promptTokens := estimateChatMessagesTokens(chatMessages) + promptTokens := estimateChatMessagesTokens(req.Messages) emitTrace("model", "request", "running", "正在调用模型生成回答", nil) stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ @@ -885,21 +884,16 @@ func chatHandler(c *gin.Context) { stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) } + // 思考过程 reasoning_content 单独事件推送 + if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" { + emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent}) + } } } } // ─── 辅助函数 ───────────────────────────────────────────── -func latestUserQuery(messages []ChatMessage) string { - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Role == "user" { - return strings.TrimSpace(messages[i].Content) - } - } - return "" -} - func estimateChatMessagesTokens(messages []ChatMessage) int { total := 0 for _, msg := range messages { @@ -951,380 +945,214 @@ func tokensPerSecond(tokens int, start time.Time) float64 { return float64(tokens) / elapsed } -type ToolSelection struct { - Name string `json:"name"` - Reason string `json:"reason"` +type agentTool struct { + name string + definition *model.Tool + execute func(context.Context, string) (string, error) } -type ToolRoutingDecision struct { - Tools []ToolSelection `json:"tools"` - Reason string `json:"reason"` -} +func (t agentTool) Name() string { return t.name } -type ChatTool interface { - Name() string - Description() string - Enabled() bool - Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) -} +const maxAgentToolIterations = 6 -type TimeChatTool struct{} - -func (t TimeChatTool) Name() string { return "time" } - -func (t TimeChatTool) Description() string { - return timeagent.ActivationPrompt -} - -func (t TimeChatTool) Enabled() bool { return true } - -func (t TimeChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { - return runTimeTool(ctx, messages, routeReason, emit) -} - -type SQLChatTool struct { - state *sqlquery.State -} - -func (t SQLChatTool) Name() string { return "sql" } - -func (t SQLChatTool) Description() string { - if t.state == nil { - return "" - } - return t.state.ActivationPrompt() -} - -func (t SQLChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() } - -func (t SQLChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { - return runSQLTool(ctx, t.state, profile, messages, routeReason, emit) -} - -type SearchChatTool struct { - state *searchagent.State -} - -func (t SearchChatTool) Name() string { return "search" } - -func (t SearchChatTool) Description() string { - if t.state == nil { - return "" - } - return t.state.ActivationPrompt() -} - -func (t SearchChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() } - -func (t SearchChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { - return runSearchTool(ctx, t.state, messages, routeReason, emit) -} - -type sqlGenerationResult struct { - Database string `json:"database"` - SQL string `json:"sql"` - Reason string `json:"reason"` -} - -func runSQLTool(ctx context.Context, state *sqlquery.State, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { - query := latestUserQuery(messages) - if query == "" { - return messages, nil - } - - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"}) - schemaContext, err := state.SchemaContext(ctx) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}}) - return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil - } - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"}) - - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"}) - generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}}) - return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil - } - generated.Database = strings.TrimSpace(generated.Database) - generated.SQL = strings.TrimSpace(generated.SQL) - if generated.SQL == "" { - err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason) - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}}) - return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil - } - - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}}) - result, err := state.ExecuteReadOnly(ctx, generated.Database, generated.SQL) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}}) - return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil - } - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}}) - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}}) - contextText := sqlquery.BuildResultContext(query, generated.SQL, result) - if strings.TrimSpace(routeReason) != "" { - contextText += "\n激活原因:" + routeReason - } - return prependHiddenContext(messages, contextText), nil -} - -func prependHiddenContext(messages []ChatMessage, content string) []ChatMessage { - withContext := make([]ChatMessage, 0, len(messages)+1) - withContext = append(withContext, ChatMessage{Role: "system", Content: content, Hidden: true}) - withContext = append(withContext, messages...) - return withContext -} - -func runTimeTool(ctx context.Context, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { - _ = ctx - resolved := timeagent.Resolve(time.Now()) - emit(chatSSEFrame{Type: "trace", Tool: "time", Stage: "resolve", Status: "success", Message: "已获取当前时间上下文", Data: map[string]any{ - "today": timeagent.FormatDate(resolved.Now), - "this_month": fmt.Sprintf("%s 至 %s", timeagent.FormatDate(resolved.ThisMonth.Start), timeagent.FormatDate(resolved.ThisMonth.End.AddDate(0, 0, -1))), - }}) - return prependHiddenContext(messages, timeagent.BuildContext(resolved, routeReason)), nil -} - -func runSearchTool(ctx context.Context, state *searchagent.State, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { - query := latestUserQuery(messages) - if query == "" { - return messages, nil - } - if state == nil || !state.Enabled() { - err := errors.New("联网搜索未启用") - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索未启用", Data: map[string]any{"error": err.Error()}}) - return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil - } - active := state.ActiveProfile() - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": active.Provider}}) - results, profile, err := state.Search(ctx, query) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}}) - return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil - } - if len(results) == 0 { - err := errors.New("未搜索到相关网页结果") - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "warning", Message: "未搜索到相关网页结果,将使用模型知识库回答"}) - return prependHiddenContext(messages, searchagent.BuildFallbackContext(profile, query, routeReason, err)), nil - } - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": profile.Provider, "count": len(results)}}) - return prependHiddenContext(messages, searchagent.BuildResultContext(profile, query, results, routeReason)), nil -} - -func enrichMessagesWithRoutedTools(ctx context.Context, chatProfile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { +func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agentTool { if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled { - return messages, nil + return nil } - if latestUserQuery(messages) == "" { - return messages, nil - } - tools := availableChatTools(toolRouterState.cfg) - if len(tools) == 0 { - return messages, nil - } - - emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "running", Message: "正在进行工具路由"}) - decision, err := routeTools(ctx, toolRouterState, chatProfile, messages, tools) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "error", Message: "工具路由失败,将继续普通回答", Data: map[string]any{"error": err.Error()}}) - return messages, err - } - selected := filterToolSelections(decision, tools, toolRouterState.cfg.Tools) - selected = ensureTimeSelectionForRelativeQuery(selected, tools, toolRouterState.cfg.Tools, latestUserQuery(messages)) - if len(selected) == 0 { - emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:无需调用工具", Data: map[string]any{"reason": decision.Reason}}) - return messages, nil - } - - names := make([]string, 0, len(selected)) - for _, item := range selected { - names = append(names, item.Name) - } - emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:将调用 " + strings.Join(names, ", "), Data: map[string]any{"tools": names, "reason": decision.Reason}}) - - current := messages - for _, item := range selected { - tool := tools[item.Name] - next, err := tool.Enrich(ctx, chatProfile, current, firstNonEmpty(item.Reason, decision.Reason), emit) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: item.Name, Stage: "error", Status: "error", Message: "工具调用失败,将继续普通回答", Data: map[string]any{"error": err.Error()}}) - continue - } - current = next - } - return current, nil -} - -func availableChatTools(config *ToolRouterConfig) map[string]ChatTool { - configured := map[string]ToolRouteConfig{} - for _, item := range config.Tools { - configured[item.Name] = item - } - registered := []ChatTool{ - TimeChatTool{}, - SearchChatTool{state: searchState}, - SQLChatTool{state: sqlState}, - } - available := map[string]ChatTool{} - for _, tool := range registered { - name := tool.Name() - item, ok := configured[name] - if !ok || !item.Enabled || !tool.Enabled() { - continue - } - available[name] = tool - } - return available -} - -func routeTools(ctx context.Context, state *ToolRouterState, chatProfile *OpenAIProfile, messages []ChatMessage, tools map[string]ChatTool) (ToolRoutingDecision, error) { - routerProfile := chatProfile - if strings.TrimSpace(state.cfg.OpenAIName) != "" { - profile, err := state.ai.GetProfile(state.cfg.OpenAIName) - if err != nil { - return ToolRoutingDecision{}, err - } - routerProfile = profile - } - prompt := buildToolRouterPrompt(state.cfg, messages, tools) - text, err := state.complete(ctx, routerProfile, []ChatMessage{{Role: "system", Content: prompt}}, state.cfg.MaxTokens, time.Duration(state.cfg.Timeout)*time.Second) - if err != nil { - return ToolRoutingDecision{}, err - } - return parseToolRoutingDecision(text) -} - -func buildToolRouterPrompt(config *ToolRouterConfig, messages []ChatMessage, tools map[string]ChatTool) string { - query := latestUserQuery(messages) - var b strings.Builder - b.WriteString(strings.TrimSpace(config.SystemPrompt)) - b.WriteString("\n\n可用工具:\n") - for _, item := range config.Tools { - tool, ok := tools[item.Name] - if !ok { + tools := make([]agentTool, 0, len(toolRouterState.cfg.Tools)) + for _, item := range toolRouterState.cfg.Tools { + if !item.Enabled { continue } description := strings.TrimSpace(item.Description) - if description == "" { - description = tool.Description() - } - fmt.Fprintf(&b, "- name: %s\n description: %s\n", item.Name, description) - } - fmt.Fprintf(&b, "\n最新用户问题:%s", query) - return b.String() -} - -func parseToolRoutingDecision(text string) (ToolRoutingDecision, error) { - var decision ToolRoutingDecision - if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil { - return decision, fmt.Errorf("解析工具路由结果失败: %w", err) - } - for i := range decision.Tools { - decision.Tools[i].Name = strings.ToLower(strings.TrimSpace(decision.Tools[i].Name)) - decision.Tools[i].Reason = strings.TrimSpace(decision.Tools[i].Reason) - } - decision.Reason = strings.TrimSpace(decision.Reason) - return decision, nil -} - -func filterToolSelections(decision ToolRoutingDecision, tools map[string]ChatTool, order []ToolRouteConfig) []ToolSelection { - selected := map[string]ToolSelection{} - for _, item := range decision.Tools { - if item.Name == "" { - continue - } - if _, ok := tools[item.Name]; !ok { - continue - } - if _, ok := selected[item.Name]; !ok { - selected[item.Name] = item + switch item.Name { + case timeagent.ToolName: + tools = append(tools, agentTool{ + name: timeagent.ToolName, + definition: timeagent.ToolDefinition(description), + execute: func(ctx context.Context, args string) (string, error) { + result, err := timeagent.ExecuteTool(args, time.Now()) + if err == nil && emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: timeagent.ToolName, Stage: "resolve", Status: "success", Message: "已获取当前时间上下文"}) + } + return result, err + }, + }) + case searchagent.ToolName: + if searchState == nil || !searchState.Enabled() { + continue + } + tools = append(tools, agentTool{ + name: searchagent.ToolName, + definition: searchState.ToolDefinition(description), + execute: func(ctx context.Context, args string) (string, error) { + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "request", Status: "running", Message: "正在联网搜索"}) + } + result, err := searchState.ExecuteTool(ctx, args) + if emit != nil { + status := "success" + message := "联网搜索完成" + if err != nil { + status = "error" + message = "联网搜索失败" + } + emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "results", Status: status, Message: message}) + } + return result, err + }, + }) + case sqlquery.ToolName: + if sqlState == nil || !sqlState.Enabled() { + continue + } + tools = append(tools, agentTool{ + name: sqlquery.ToolName, + definition: sqlState.ToolDefinition(description), + execute: func(ctx context.Context, args string) (string, error) { + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: "running", Message: "正在查询数据库"}) + } + generator := func(ctx context.Context, prompt string, maxTokens int) (string, error) { + return completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, maxTokens) + } + result, err := sqlState.ExecuteTool(ctx, args, generator) + if emit != nil { + status := "success" + message := "数据库查询完成" + if err != nil { + status = "error" + message = "数据库查询失败" + } + emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: status, Message: message}) + } + return result, err + }, + }) } } - return orderToolSelections(selected, order) + return tools } -func ensureTimeSelectionForRelativeQuery(selected []ToolSelection, tools map[string]ChatTool, order []ToolRouteConfig, query string) []ToolSelection { - if !containsRelativeTime(query) || hasToolSelection(selected, "time") || (!hasToolSelection(selected, "search") && !hasToolSelection(selected, "sql")) { - return selected - } - if _, ok := tools["time"]; !ok { - return selected - } - withTime := make(map[string]ToolSelection, len(selected)+1) - for _, item := range selected { - withTime[item.Name] = item - } - withTime["time"] = ToolSelection{Name: "time", Reason: "问题包含相对日期,需要先获取当前日期"} - return orderToolSelections(withTime, order) -} - -func containsRelativeTime(query string) bool { - query = strings.TrimSpace(query) - if query == "" { - return false - } - for _, keyword := range []string{"今天", "今日", "明天", "昨天", "本周", "这周", "本月", "这个月", "本年", "今年", "最近", "历史上的今天"} { - if strings.Contains(query, keyword) { - return true - } - } - return false -} - -func hasToolSelection(selected []ToolSelection, name string) bool { - for _, item := range selected { - if item.Name == name { - return true - } - } - return false -} - -func orderToolSelections(selected map[string]ToolSelection, order []ToolRouteConfig) []ToolSelection { - result := make([]ToolSelection, 0, len(selected)) - for _, item := range order { - if selection, ok := selected[item.Name]; ok { - result = append(result, selection) - } - } - return result -} - -func firstNonEmpty(items ...string) string { - for _, item := range items { - if strings.TrimSpace(item) != "" { - return strings.TrimSpace(item) - } - } - return "" -} - -func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) { - prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、隐藏上下文和数据库 schema 生成一条只读 SQL。 -要求: -- 只能返回 JSON,不要使用 Markdown。 -- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"} -- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。 -- 必须只使用 schema 中出现的数据库、表和字段。 -- 如果隐藏上下文中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。 -- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。 -- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。 -- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。 -- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。 - -%s - -用户问题:%s`, schemaContext, userQuery) - text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024) +func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) { + messages, err := buildArkMessages(chatMessages) if err != nil { return nil, err } - var generated sqlGenerationResult - if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil { - return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err) + tools := availableAgentTools(profile, emit) + if len(tools) == 0 { + return messages, nil } - return &generated, nil + toolByName := make(map[string]agentTool, len(tools)) + definitions := make([]*model.Tool, 0, len(tools)) + availableNames := make([]string, 0, len(tools)) + toolDescriptions := make([]string, 0, len(tools)) + for _, tool := range tools { + toolByName[tool.name] = tool + definitions = append(definitions, tool.definition) + availableNames = append(availableNames, tool.name) + if tool.definition != nil && tool.definition.Function != nil { + toolDescriptions = append(toolDescriptions, fmt.Sprintf("%s: %s", tool.name, tool.definition.Function.Description)) + } + } + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}}) + } + if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" { + messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...) + } + for i := 0; i < maxAgentToolIterations; i++ { + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}}) + } + resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{ + Model: profile.Config.Model, + Messages: messages, + MaxTokens: intPtr(toolRouterState.cfg.MaxTokens), + Tools: definitions, + ToolChoice: model.ToolChoiceStringTypeAuto, + ParallelToolCalls: boolPtr(false), + }, time.Duration(toolRouterState.cfg.Timeout)*time.Second) + if err != nil { + return messages, err + } + if tracker := tokenUsageFromContext(ctx); tracker != nil { + tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens) + } + if len(resp.Choices) == 0 { + return messages, nil + } + choice := resp.Choices[0] + decisionPreview := chatMessageContentString(choice.Message.Content) + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "decision", Status: "success", Message: "工具判断响应已返回", Data: map[string]any{"iteration": i + 1, "finish_reason": string(choice.FinishReason), "content_preview": truncateString(decisionPreview, 800)}}) + } + calls := choice.Message.ToolCalls + if len(calls) == 0 && choice.Message.FunctionCall != nil { + calls = []*model.ToolCall{{ID: "legacy_function_call", Type: model.ToolTypeFunction, Function: *choice.Message.FunctionCall}} + } + if len(calls) == 0 { + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"}) + } + return messages, nil + } + callNames := make([]string, 0, len(calls)) + for _, call := range calls { + if call != nil { + callNames = append(callNames, call.Function.Name) + } + } + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}}) + } + messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content}) + for _, call := range calls { + result := executeAgentToolCall(ctx, call, toolByName, emit) + messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)}) + } + } + messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")}) + return messages, nil +} + +func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string { + if call == nil || call.Type != model.ToolTypeFunction { + result := "工具调用无效:仅支持 function 类型工具。" + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "execute", Status: "error", Message: result}) + } + return result + } + toolName := call.Function.Name + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "arguments", Status: "running", Message: "准备执行工具", Data: map[string]any{"tool_call_id": call.ID, "arguments": call.Function.Arguments}}) + } + tool, ok := tools[toolName] + if !ok { + result := fmt.Sprintf("工具调用失败:未知工具 %s。", toolName) + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "execute", Status: "error", Message: result}) + } + return result + } + started := time.Now() + result, err := tool.execute(ctx, call.Function.Arguments) + durationMs := time.Since(started).Milliseconds() + if err != nil { + message := fmt.Sprintf("工具 %s 执行失败:%v", tool.name, err) + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "execute", Status: "error", Message: "工具执行失败", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "error": err.Error()}}) + } + return message + } + if strings.TrimSpace(result) == "" { + result = fmt.Sprintf("工具 %s 执行完成,但没有返回内容。", tool.name) + } + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "result", Status: "success", Message: "工具执行完成", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "result_preview": truncateString(result, 1200)}}) + } + return result } func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) { @@ -1370,14 +1198,10 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe } } -func extractJSONObject(text string) string { - text = strings.TrimSpace(text) - start := strings.Index(text, "{") - end := strings.LastIndex(text, "}") - if start >= 0 && end > start { - return text[start : end+1] - } - return text +func completeChatWithTimeout(ctx context.Context, profile *OpenAIProfile, request model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) { + completionCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return profile.Client.CreateChatCompletion(completionCtx, request.WithStream(false)) } func newUUID() string { @@ -1590,6 +1414,17 @@ func textPart(text string) *model.ChatCompletionMessageContentPart { } } +func stringContent(text string) *model.ChatCompletionMessageContent { + return &model.ChatCompletionMessageContent{StringValue: &text} +} + +func chatMessageContentString(content *model.ChatCompletionMessageContent) string { + if content == nil || content.StringValue == nil { + return "" + } + return *content.StringValue +} + func normalizeImageURL(raw string) (string, error) { raw = strings.TrimSpace(raw) if raw == "" { @@ -1651,6 +1486,16 @@ func contains(items []string, target string) bool { func intPtr(i int) *int { return &i } +func boolPtr(v bool) *bool { return &v } + +func truncateString(text string, maxRunes int) string { + runes := []rune(strings.TrimSpace(text)) + if maxRunes <= 0 || len(runes) <= maxRunes { + return string(runes) + } + return string(runes[:maxRunes]) + "..." +} + func writeSSEJSON(w io.Writer, frame chatSSEFrame) { data, err := json.Marshal(frame) if err != nil { diff --git a/main_test.go b/main_test.go index 10c3514..cefebeb 100644 --- a/main_test.go +++ b/main_test.go @@ -6,21 +6,10 @@ import ( "strings" "testing" "time" + + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" ) -type fakeChatTool struct { - name string - description string - enabled bool -} - -func (t fakeChatTool) Name() string { return t.name } -func (t fakeChatTool) Description() string { return t.description } -func (t fakeChatTool) Enabled() bool { return t.enabled } -func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) { - return nil, nil -} - func TestNormalizeToolRouterConfigDefaults(t *testing.T) { cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}} changed, err := normalizeToolRouterConfig(cfg) @@ -49,7 +38,7 @@ func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) { Enabled: true, Timeout: 1, MaxTokens: 1, - SystemPrompt: "route", + SystemPrompt: "tools", Tools: []ToolRouteConfig{ {Name: "search", Enabled: true}, {Name: "sql", Enabled: true}, @@ -72,7 +61,7 @@ func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) { Enabled: true, Timeout: 1, MaxTokens: 1, - SystemPrompt: "route", + SystemPrompt: "tools", Tools: []ToolRouteConfig{ {Name: "sql", Enabled: true}, {Name: " SQL ", Enabled: true}, @@ -84,169 +73,119 @@ func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) { } } -func TestParseToolRoutingDecision(t *testing.T) { - decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```") - if err != nil { - t.Fatal(err) - } - if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" { - t.Fatalf("unexpected decision: %#v", decision) - } - if decision.Reason != "总原因" { - t.Fatalf("reason = %q", decision.Reason) - } +func TestAvailableAgentToolsUsesConfigOrderAndEnabled(t *testing.T) { + oldRouter := toolRouterState + oldSearch := searchState + oldSQL := sqlState + defer func() { + toolRouterState = oldRouter + searchState = oldSearch + sqlState = oldSQL + }() - if _, err := parseToolRoutingDecision("not json"); err == nil { - t.Fatal("expected malformed JSON error") - } -} - -func TestFilterToolSelections(t *testing.T) { - tools := map[string]ChatTool{ - "time": fakeChatTool{name: "time", enabled: true}, - "sql": fakeChatTool{name: "sql", enabled: true}, - "search": fakeChatTool{name: "search", enabled: true}, - } - decision := ToolRoutingDecision{Tools: []ToolSelection{ - {Name: "unknown", Reason: "ignore"}, - {Name: "search", Reason: "second in config"}, - {Name: "sql", Reason: "third in config"}, - {Name: "time", Reason: "first in config"}, - {Name: "sql", Reason: "duplicate"}, - }} - selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}}) - if len(selected) != 3 { - t.Fatalf("selected length = %d", len(selected)) - } - if selected[0].Name != "time" || selected[0].Reason != "first in config" { - t.Fatalf("first selection = %#v", selected[0]) - } - if selected[1].Name != "search" { - t.Fatalf("second selection = %#v", selected[1]) - } - if selected[2].Name != "sql" || selected[2].Reason != "third in config" { - t.Fatalf("third selection = %#v", selected[2]) - } -} - -func TestEnsureTimeSelectionForRelativeSearch(t *testing.T) { - tools := map[string]ChatTool{ - "time": fakeChatTool{name: "time", enabled: true}, - "search": fakeChatTool{name: "search", enabled: true}, - } - selected := ensureTimeSelectionForRelativeQuery( - []ToolSelection{{Name: "search", Reason: "查询历史事件"}}, - tools, - []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}}, - "历史上的今天都发生了什么?", - ) - if len(selected) != 2 || selected[0].Name != "time" || selected[1].Name != "search" { - t.Fatalf("unexpected selected tools: %#v", selected) - } - if !strings.Contains(selected[0].Reason, "相对日期") { - t.Fatalf("unexpected time reason: %#v", selected[0]) - } -} - -func TestEnsureTimeSelectionSkipsOrdinarySearch(t *testing.T) { - tools := map[string]ChatTool{ - "time": fakeChatTool{name: "time", enabled: true}, - "search": fakeChatTool{name: "search", enabled: true}, - } - selected := ensureTimeSelectionForRelativeQuery( - []ToolSelection{{Name: "search", Reason: "查询资料"}}, - tools, - []ToolRouteConfig{{Name: "time"}, {Name: "search"}}, - "查一下 Go 语言官网", - ) - if len(selected) != 1 || selected[0].Name != "search" { - t.Fatalf("unexpected selected tools: %#v", selected) - } -} - -func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) { - messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}} - withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {}) - if err != nil { - t.Fatal(err) - } - if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" { - t.Fatalf("unexpected messages: %#v", withTime) - } - for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} { - if !strings.Contains(withTime[0].Content, want) { - t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content) - } - } -} - -func TestBuildToolRouterPrompt(t *testing.T) { - cfg := &ToolRouterConfig{ - SystemPrompt: "router", + toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ + Enabled: true, Tools: []ToolRouteConfig{ - {Name: "time", Enabled: true}, - {Name: "sql", Enabled: true, Description: "configured sql"}, {Name: "search", Enabled: true}, + {Name: "time", Enabled: true, Description: "custom time"}, + {Name: "sql", Enabled: false}, }, + }} + searchState = nil + sqlState = nil + + tools := availableAgentTools(&OpenAIProfile{}, nil) + if len(tools) != 1 { + t.Fatalf("tools length = %d", len(tools)) } - tools := map[string]ChatTool{ - "time": fakeChatTool{name: "time", description: "fallback time", enabled: true}, - "sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true}, - "search": fakeChatTool{name: "search", description: "fallback search", enabled: true}, + if tools[0].name != "time" { + t.Fatalf("tool name = %s", tools[0].name) } - prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools) - for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} { - if !strings.Contains(prompt, want) { - t.Fatalf("prompt missing %q:\n%s", want, prompt) - } + if tools[0].definition.Function == nil || tools[0].definition.Function.Description != "custom time" { + t.Fatalf("unexpected definition: %#v", tools[0].definition) } } -func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) { - ai := &OpenAIState{ - profiles: map[string]*OpenAIProfile{ - "chat": {Config: OpenAIConfig{Name: "chat"}}, - "router": {Config: OpenAIConfig{Name: "router"}}, - }, - order: []string{"chat", "router"}, - activeName: "chat", +func TestRunAgentToolLoopAppendsToolMessages(t *testing.T) { + oldRouter := toolRouterState + defer func() { toolRouterState = oldRouter }() + + calls := 0 + toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ + Enabled: true, + Timeout: 1, + MaxTokens: 128, + SystemPrompt: "use tools", + Tools: []ToolRouteConfig{{Name: "time", Enabled: true}}, + }} + toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) { + calls++ + if req.ToolChoice != model.ToolChoiceStringTypeAuto { + t.Fatalf("tool choice = %#v", req.ToolChoice) + } + if len(req.Tools) != 1 || req.Tools[0].Function == nil || req.Tools[0].Function.Name != "time" { + t.Fatalf("unexpected tools: %#v", req.Tools) + } + if calls == 1 { + return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "call_1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{"reason":"需要当前日期"}`}}}}}}}, nil + } + return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("done")}}}}, nil } - state := &ToolRouterState{cfg: &ToolRouterConfig{ - OpenAIName: "router", - Timeout: 7, - MaxTokens: 123, - SystemPrompt: "router prompt", - Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}, - }, ai: ai} - state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) { - if profile.Config.Name != "router" { - t.Fatalf("profile = %s", profile.Config.Name) - } - if maxTokens != 123 { - t.Fatalf("maxTokens = %d", maxTokens) - } - if timeout != 7*time.Second { - t.Fatalf("timeout = %s", timeout) - } - return `{"tools":[],"reason":"无需工具"}`, nil - } - decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}}) + + messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天几号"}}, nil) if err != nil { t.Fatal(err) } - if len(decision.Tools) != 0 || decision.Reason != "无需工具" { - t.Fatalf("unexpected decision: %#v", decision) + if calls != 2 { + t.Fatalf("calls = %d", calls) + } + if len(messages) < 4 { + t.Fatalf("expected system/user/assistant/tool messages, got %d", len(messages)) + } + last := messages[len(messages)-1] + if last.Role != model.ChatMessageRoleTool || last.ToolCallID != "call_1" { + t.Fatalf("unexpected last message: %#v", last) + } + if last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "时间工具结果") { + t.Fatalf("unexpected tool content: %#v", last.Content) } } -func TestRouteToolsCompleterError(t *testing.T) { - ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"} - state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai} - state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) { - return "", errors.New("boom") +func TestExecuteAgentToolCallUnknownAndError(t *testing.T) { + unknown := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "missing"}}, map[string]agentTool{}, nil) + if !strings.Contains(unknown, "未知工具") { + t.Fatalf("unknown result = %q", unknown) } - _, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}}) - if err == nil { - t.Fatal("expected completer error") + + failed := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "2", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "boom"}}, map[string]agentTool{ + "boom": {name: "boom", execute: func(context.Context, string) (string, error) { return "", errors.New("bad args") }}, + }, nil) + if !strings.Contains(failed, "bad args") { + t.Fatalf("failed result = %q", failed) + } +} + +func TestRunAgentToolLoopMaxIterations(t *testing.T) { + oldRouter := toolRouterState + defer func() { toolRouterState = oldRouter }() + + toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ + Enabled: true, + Timeout: 1, + MaxTokens: 128, + SystemPrompt: "use tools", + Tools: []ToolRouteConfig{{Name: "time", Enabled: true}}, + }} + toolRouterState.complete = func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) { + return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "loop", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{}`}}}}}}}, nil + } + + messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil) + if err != nil { + t.Fatal(err) + } + last := messages[len(messages)-1] + if last.Role != model.ChatMessageRoleSystem || last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "工具调用轮数已达到上限") { + t.Fatalf("unexpected last message: %#v", last) } } diff --git a/templates/chat.html b/templates/chat.html index a750d33..6eb2e98 100644 --- a/templates/chat.html +++ b/templates/chat.html @@ -265,6 +265,30 @@ white-space: normal; } .trace-panel:empty { display: none; } + .reasoning-panel { + display: none; + margin-bottom: 8px; + border: 1px solid var(--border); + border-radius: 8px; + background: var(--surface2); + color: var(--text-dim); + font-size: 0.78rem; + white-space: normal; + overflow: hidden; + } + .reasoning-panel.show { display: block; } + .reasoning-title { + padding: 6px 9px; + border-bottom: 1px solid var(--border); + font-weight: 600; + } + .reasoning-content { + padding: 7px 9px; + white-space: pre-wrap; + max-height: 220px; + overflow-y: auto; + font-family: Consolas, 'Fira Code', monospace; + } .trace-item { border-left: 2px solid var(--accent-border); padding-left: 8px; @@ -859,18 +883,23 @@ function addAIBubble() { const trace = document.createElement('div'); trace.className = 'trace-panel'; + const reasoning = document.createElement('div'); + reasoning.className = 'reasoning-panel'; + reasoning.innerHTML = '
思考过程(模型返回)
'; + const txt = document.createElement('span'); txt.className = 'answer-text'; const stats = document.createElement('div'); stats.className = 'token-stats'; bub.appendChild(trace); + bub.appendChild(reasoning); bub.appendChild(txt); bub.appendChild(stats); row.appendChild(av); row.appendChild(bub); msgBox.appendChild(row); scrollToBottom(); - return { bub, txt, trace, stats }; + return { bub, txt, trace, reasoning, stats }; } function formatTokenStats(stats) { @@ -903,18 +932,24 @@ function appendTrace(aiBubble, frame) { if (!aiBubble.trace) return; const item = document.createElement('div'); item.className = `trace-item ${frame.status || ''}`; + const prefix = [frame.tool, frame.stage].filter(Boolean).join('/'); const label = frame.message || [frame.tool, frame.stage, frame.status].filter(Boolean).join(' '); - item.textContent = label; + item.textContent = prefix ? `${prefix}:${label}` : label; const data = frame.data || {}; const details = []; - if (data.sql) details.push(data.sql); + if (data.arguments) details.push(`参数:\n${data.arguments}`); + if (data.sql) details.push(`SQL:\n${data.sql}`); + if (data.result_preview) details.push(`结果预览:\n${data.result_preview}`); const stats = []; + if (typeof data.iteration === 'number') stats.push(`轮次: ${data.iteration}${data.max_iterations ? '/' + data.max_iterations : ''}`); + if (data.tool_call_id) stats.push(`调用 ID: ${data.tool_call_id}`); if (data.database) stats.push(`数据库: ${data.database}`); if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`); if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`); if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`); if (Array.isArray(data.tools) && data.tools.length) stats.push(`工具: ${data.tools.join(', ')}`); + if (typeof data.duration_ms === 'number') stats.push(`耗时: ${data.duration_ms}ms`); if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''} 行`); if (data.reason) stats.push(`原因: ${data.reason}`); if (data.error) stats.push(`错误: ${data.error}`); @@ -923,7 +958,7 @@ function appendTrace(aiBubble, frame) { if (details.length) { const detail = document.createElement('div'); detail.className = 'trace-detail'; - detail.textContent = details.join('\n'); + detail.textContent = details.join('\n\n'); item.appendChild(detail); } @@ -931,6 +966,15 @@ function appendTrace(aiBubble, frame) { scrollToBottom(); } +function appendReasoning(aiBubble, text) { + if (!aiBubble.reasoning || !text) return; + aiBubble.reasoning.classList.add('show'); + const content = aiBubble.reasoning.querySelector('.reasoning-content'); + content.textContent += text; + content.scrollTop = content.scrollHeight; + scrollToBottom(); +} + async function streamChat(messages, aiBubble) { const txtEl = aiBubble.txt; let full = ''; @@ -996,6 +1040,10 @@ async function streamChat(messages, aiBubble) { scrollToBottom(); continue; } + if (parsed.type === 'reasoning') { + appendReasoning(aiBubble, parsed.text || ''); + continue; + } if (parsed.type === 'trace') { appendTrace(aiBubble, parsed); continue;