package main

import (
	"context"
	"errors"
	"strings"
	"testing"
	"time"
)

// fakeChatTool is a minimal ChatTool stub for router tests. Its name,
// description and enabled flag are fixed at construction, and Enrich is a
// no-op, so tests never depend on real tool behavior.
type fakeChatTool struct {
	name        string
	description string
	enabled     bool
}

// Compile-time check: the stub must keep satisfying ChatTool as the interface evolves.
var _ ChatTool = fakeChatTool{}

func (f fakeChatTool) Name() string        { return f.name }
func (f fakeChatTool) Description() string { return f.description }
func (f fakeChatTool) Enabled() bool       { return f.enabled }

// Enrich returns no messages and no error.
func (f fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) {
	return nil, nil
}

// TestNormalizeToolRouterConfigDefaults checks that an enabled router with no
// other settings gets a default timeout, max tokens, a non-blank system prompt,
// and the enabled tools time, search, sql in that order.
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
	cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
	changed, err := normalizeToolRouterConfig(cfg)
	if err != nil {
		t.Fatal(err)
	}
	if !changed {
		t.Fatal("expected defaults to change config")
	}
	if cfg.ToolRouter.Timeout != defaultToolRouterTimeout {
		t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout)
	}
	if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens {
		t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens)
	}
	if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
		t.Fatal("system prompt should be defaulted")
	}
	wantTools := []string{"time", "search", "sql"}
	if len(cfg.ToolRouter.Tools) != len(wantTools) {
		t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools)
	}
	for i, name := range wantTools {
		if got := cfg.ToolRouter.Tools[i]; got.Name != name || !got.Enabled {
			t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools)
		}
	}
}

// TestNormalizeToolRouterConfigAddsTimeBeforeSQL checks that a config listing
// only search and sql gets the time tool inserted at the front, ahead of sql.
func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
	cfg := &Config{ToolRouter: ToolRouterConfig{
		Enabled:      true,
		Timeout:      1,
		MaxTokens:    1,
		SystemPrompt: "route",
		Tools: []ToolRouteConfig{
			{Name: "search", Enabled: true},
			{Name: "sql", Enabled: true},
		},
	}}
	changed, err := normalizeToolRouterConfig(cfg)
	if err != nil {
		t.Fatal(err)
	}
	if !changed {
		t.Fatal("expected time tool to be added")
	}
	tools := cfg.ToolRouter.Tools
	if len(tools) < 3 || tools[0].Name != "time" || tools[2].Name != "sql" {
		t.Fatalf("unexpected tool order: %#v", tools)
	}
}

// TestNormalizeToolRouterConfigDuplicateTools checks that tool names that only
// differ by case/surrounding whitespace are rejected as duplicates.
func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
	cfg := &Config{ToolRouter: ToolRouterConfig{
		Enabled:      true,
		Timeout:      1,
		MaxTokens:    1,
		SystemPrompt: "route",
		Tools: []ToolRouteConfig{
			{Name: "sql", Enabled: true},
			{Name: " SQL ", Enabled: true},
		},
	}}
	if _, err := normalizeToolRouterConfig(cfg); err == nil {
		t.Fatal("expected duplicate tool error")
	}
}

// TestParseToolRoutingDecision checks that a ```json-fenced router reply is
// accepted, tool names are trimmed and lowercased, reasons are trimmed, and
// non-JSON input is rejected.
func TestParseToolRoutingDecision(t *testing.T) {
	decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```")
	if err != nil {
		t.Fatal(err)
	}
	if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" {
		t.Fatalf("unexpected decision: %#v", decision)
	}
	if decision.Reason != "总原因" {
		t.Fatalf("reason = %q", decision.Reason)
	}
	if _, err := parseToolRoutingDecision("not json"); err == nil {
		t.Fatal("expected malformed JSON error")
	}
}

// TestFilterToolSelections checks that unknown tools are dropped, duplicates
// keep their first reason, and the result follows configured tool order rather
// than the router's reply order.
func TestFilterToolSelections(t *testing.T) {
	tools := map[string]ChatTool{
		"time":   fakeChatTool{name: "time", enabled: true},
		"sql":    fakeChatTool{name: "sql", enabled: true},
		"search": fakeChatTool{name: "search", enabled: true},
	}
	decision := ToolRoutingDecision{Tools: []ToolSelection{
		{Name: "unknown", Reason: "ignore"},
		{Name: "search", Reason: "second in config"},
		{Name: "sql", Reason: "third in config"},
		{Name: "time", Reason: "first in config"},
		{Name: "sql", Reason: "duplicate"},
	}}
	selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}})
	if len(selected) != 3 {
		t.Fatalf("selected length = %d", len(selected))
	}
	if selected[0].Name != "time" || selected[0].Reason != "first in config" {
		t.Fatalf("first selection = %#v", selected[0])
	}
	if selected[1].Name != "search" || selected[1].Reason != "second in config" {
		t.Fatalf("second selection = %#v", selected[1])
	}
	if selected[2].Name != "sql" || selected[2].Reason != "third in config" {
		t.Fatalf("third selection = %#v", selected[2])
	}
}

// TestEnsureTimeSelectionForRelativeSearch checks that a query containing a
// relative date ("today") forces the time tool in front of search, with a
// reason mentioning relative dates.
func TestEnsureTimeSelectionForRelativeSearch(t *testing.T) {
	tools := map[string]ChatTool{
		"time":   fakeChatTool{name: "time", enabled: true},
		"search": fakeChatTool{name: "search", enabled: true},
	}
	selected := ensureTimeSelectionForRelativeQuery(
		[]ToolSelection{{Name: "search", Reason: "查询历史事件"}},
		tools,
		[]ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}},
		"历史上的今天都发生了什么?",
	)
	if len(selected) != 2 || selected[0].Name != "time" || selected[1].Name != "search" {
		t.Fatalf("unexpected selected tools: %#v", selected)
	}
	if !strings.Contains(selected[0].Reason, "相对日期") {
		t.Fatalf("unexpected time reason: %#v", selected[0])
	}
}

// TestEnsureTimeSelectionSkipsOrdinarySearch checks that a query without any
// relative date leaves the selection untouched.
func TestEnsureTimeSelectionSkipsOrdinarySearch(t *testing.T) {
	tools := map[string]ChatTool{
		"time":   fakeChatTool{name: "time", enabled: true},
		"search": fakeChatTool{name: "search", enabled: true},
	}
	selected := ensureTimeSelectionForRelativeQuery(
		[]ToolSelection{{Name: "search", Reason: "查询资料"}},
		tools,
		[]ToolRouteConfig{{Name: "time"}, {Name: "search"}},
		"查一下 Go 语言官网",
	)
	if len(selected) != 1 || selected[0].Name != "search" {
		t.Fatalf("unexpected selected tools: %#v", selected)
	}
}

// TestRunTimeToolAddsHiddenDateRanges checks that the time tool prepends one
// hidden system message describing date ranges (start / end_exclusive,
// half-open interval) for a "this month" question.
func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) {
	messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}}
	withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {})
	if err != nil {
		t.Fatal(err)
	}
	if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" {
		t.Fatalf("unexpected messages: %#v", withTime)
	}
	for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} {
		if !strings.Contains(withTime[0].Content, want) {
			t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content)
		}
	}
}

// TestBuildToolRouterPrompt checks that the router prompt contains the system
// prompt, every tool name, the configured description when set (falling back
// to the tool's own description otherwise), and the latest user question.
func TestBuildToolRouterPrompt(t *testing.T) {
	cfg := &ToolRouterConfig{
		SystemPrompt: "router",
		Tools: []ToolRouteConfig{
			{Name: "time", Enabled: true},
			{Name: "sql", Enabled: true, Description: "configured sql"},
			{Name: "search", Enabled: true},
		},
	}
	tools := map[string]ChatTool{
		"time":   fakeChatTool{name: "time", description: "fallback time", enabled: true},
		"sql":    fakeChatTool{name: "sql", description: "fallback sql", enabled: true},
		"search": fakeChatTool{name: "search", description: "fallback search", enabled: true},
	}
	prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools)
	for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} {
		if !strings.Contains(prompt, want) {
			t.Fatalf("prompt missing %q:\n%s", want, prompt)
		}
	}
}

// TestRouteToolsUsesConfiguredRouterProfileAndTimeout checks that routeTools
// calls the completer with the profile named by cfg.OpenAIName (not the active
// chat profile), the configured max tokens, and Timeout converted to seconds.
func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) {
	ai := &OpenAIState{
		profiles: map[string]*OpenAIProfile{
			"chat":   {Config: OpenAIConfig{Name: "chat"}},
			"router": {Config: OpenAIConfig{Name: "router"}},
		},
		order:      []string{"chat", "router"},
		activeName: "chat",
	}
	state := &ToolRouterState{cfg: &ToolRouterConfig{
		OpenAIName:   "router",
		Timeout:      7,
		MaxTokens:    123,
		SystemPrompt: "router prompt",
		Tools:        []ToolRouteConfig{{Name: "sql", Enabled: true}},
	}, ai: ai}
	state.complete = func(_ context.Context, profile *OpenAIProfile, _ []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
		// Use t.Errorf, not t.Fatalf: the completer may be invoked on a goroutine
		// other than the test's, where FailNow is not permitted.
		if profile == nil {
			t.Error("completer received nil profile")
			return "", errors.New("nil profile")
		}
		if profile.Config.Name != "router" {
			t.Errorf("profile = %s", profile.Config.Name)
		}
		if maxTokens != 123 {
			t.Errorf("maxTokens = %d", maxTokens)
		}
		if timeout != 7*time.Second {
			t.Errorf("timeout = %s", timeout)
		}
		return `{"tools":[],"reason":"无需工具"}`, nil
	}
	decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
	if err != nil {
		t.Fatal(err)
	}
	if len(decision.Tools) != 0 || decision.Reason != "无需工具" {
		t.Fatalf("unexpected decision: %#v", decision)
	}
}

// TestRouteToolsCompleterError checks that a completer failure surfaces as an
// error from routeTools.
func TestRouteToolsCompleterError(t *testing.T) {
	ai := &OpenAIState{
		profiles:   map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}},
		activeName: "chat",
	}
	state := &ToolRouterState{cfg: &ToolRouterConfig{
		Timeout:      1,
		MaxTokens:    1,
		SystemPrompt: "router",
		Tools:        []ToolRouteConfig{{Name: "sql", Enabled: true}},
	}, ai: ai}
	state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) {
		return "", errors.New("boom")
	}
	_, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
	if err == nil {
		t.Fatal("expected completer error")
	}
}