支持工具链

This commit is contained in:
2026-06-10 12:07:07 +08:00
parent 1e793ce814
commit fe2477dd97
9 changed files with 1632 additions and 545 deletions
+217
View File
@@ -0,0 +1,217 @@
package main
import (
"context"
"errors"
"strings"
"testing"
"time"
)
type fakeChatTool struct {
name string
description string
enabled bool
}
func (t fakeChatTool) Name() string { return t.name }
func (t fakeChatTool) Description() string { return t.description }
func (t fakeChatTool) Enabled() bool { return t.enabled }
func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) {
return nil, nil
}
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
changed, err := normalizeToolRouterConfig(cfg)
if err != nil {
t.Fatal(err)
}
if !changed {
t.Fatal("expected defaults to change config")
}
if cfg.ToolRouter.Timeout != defaultToolRouterTimeout {
t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout)
}
if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens {
t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens)
}
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
t.Fatal("system prompt should be defaulted")
}
if len(cfg.ToolRouter.Tools) != 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[1].Name != "search" || cfg.ToolRouter.Tools[2].Name != "sql" || !cfg.ToolRouter.Tools[0].Enabled || !cfg.ToolRouter.Tools[1].Enabled || !cfg.ToolRouter.Tools[2].Enabled {
t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools)
}
}
func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
cfg := &Config{ToolRouter: ToolRouterConfig{
Enabled: true,
Timeout: 1,
MaxTokens: 1,
SystemPrompt: "route",
Tools: []ToolRouteConfig{
{Name: "search", Enabled: true},
{Name: "sql", Enabled: true},
},
}}
changed, err := normalizeToolRouterConfig(cfg)
if err != nil {
t.Fatal(err)
}
if !changed {
t.Fatal("expected time tool to be added")
}
if len(cfg.ToolRouter.Tools) < 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[2].Name != "sql" {
t.Fatalf("unexpected tool order: %#v", cfg.ToolRouter.Tools)
}
}
func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
cfg := &Config{ToolRouter: ToolRouterConfig{
Enabled: true,
Timeout: 1,
MaxTokens: 1,
SystemPrompt: "route",
Tools: []ToolRouteConfig{
{Name: "sql", Enabled: true},
{Name: " SQL ", Enabled: true},
},
}}
_, err := normalizeToolRouterConfig(cfg)
if err == nil {
t.Fatal("expected duplicate tool error")
}
}
func TestParseToolRoutingDecision(t *testing.T) {
decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```")
if err != nil {
t.Fatal(err)
}
if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" {
t.Fatalf("unexpected decision: %#v", decision)
}
if decision.Reason != "总原因" {
t.Fatalf("reason = %q", decision.Reason)
}
if _, err := parseToolRoutingDecision("not json"); err == nil {
t.Fatal("expected malformed JSON error")
}
}
func TestFilterToolSelections(t *testing.T) {
tools := map[string]ChatTool{
"time": fakeChatTool{name: "time", enabled: true},
"sql": fakeChatTool{name: "sql", enabled: true},
"search": fakeChatTool{name: "search", enabled: true},
}
decision := ToolRoutingDecision{Tools: []ToolSelection{
{Name: "unknown", Reason: "ignore"},
{Name: "search", Reason: "second in config"},
{Name: "sql", Reason: "third in config"},
{Name: "time", Reason: "first in config"},
{Name: "sql", Reason: "duplicate"},
}}
selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}})
if len(selected) != 3 {
t.Fatalf("selected length = %d", len(selected))
}
if selected[0].Name != "time" || selected[0].Reason != "first in config" {
t.Fatalf("first selection = %#v", selected[0])
}
if selected[1].Name != "search" {
t.Fatalf("second selection = %#v", selected[1])
}
if selected[2].Name != "sql" || selected[2].Reason != "third in config" {
t.Fatalf("third selection = %#v", selected[2])
}
}
func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) {
messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}}
withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {})
if err != nil {
t.Fatal(err)
}
if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" {
t.Fatalf("unexpected messages: %#v", withTime)
}
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} {
if !strings.Contains(withTime[0].Content, want) {
t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content)
}
}
}
func TestBuildToolRouterPrompt(t *testing.T) {
cfg := &ToolRouterConfig{
SystemPrompt: "router",
Tools: []ToolRouteConfig{
{Name: "time", Enabled: true},
{Name: "sql", Enabled: true, Description: "configured sql"},
{Name: "search", Enabled: true},
},
}
tools := map[string]ChatTool{
"time": fakeChatTool{name: "time", description: "fallback time", enabled: true},
"sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true},
"search": fakeChatTool{name: "search", description: "fallback search", enabled: true},
}
prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools)
for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} {
if !strings.Contains(prompt, want) {
t.Fatalf("prompt missing %q:\n%s", want, prompt)
}
}
}
func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) {
ai := &OpenAIState{
profiles: map[string]*OpenAIProfile{
"chat": {Config: OpenAIConfig{Name: "chat"}},
"router": {Config: OpenAIConfig{Name: "router"}},
},
order: []string{"chat", "router"},
activeName: "chat",
}
state := &ToolRouterState{cfg: &ToolRouterConfig{
OpenAIName: "router",
Timeout: 7,
MaxTokens: 123,
SystemPrompt: "router prompt",
Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}},
}, ai: ai}
state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
if profile.Config.Name != "router" {
t.Fatalf("profile = %s", profile.Config.Name)
}
if maxTokens != 123 {
t.Fatalf("maxTokens = %d", maxTokens)
}
if timeout != 7*time.Second {
t.Fatalf("timeout = %s", timeout)
}
return `{"tools":[],"reason":"无需工具"}`, nil
}
decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
if err != nil {
t.Fatal(err)
}
if len(decision.Tools) != 0 || decision.Reason != "无需工具" {
t.Fatalf("unexpected decision: %#v", decision)
}
}
func TestRouteToolsCompleterError(t *testing.T) {
ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"}
state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai}
state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) {
return "", errors.New("boom")
}
_, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
if err == nil {
t.Fatal("expected completer error")
}
}