更新工具链
This commit is contained in:
+101
-162
@@ -6,21 +6,10 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
type fakeChatTool struct {
|
||||
name string
|
||||
description string
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (t fakeChatTool) Name() string { return t.name }
|
||||
func (t fakeChatTool) Description() string { return t.description }
|
||||
func (t fakeChatTool) Enabled() bool { return t.enabled }
|
||||
func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
|
||||
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
|
||||
changed, err := normalizeToolRouterConfig(cfg)
|
||||
@@ -49,7 +38,7 @@ func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 1,
|
||||
SystemPrompt: "route",
|
||||
SystemPrompt: "tools",
|
||||
Tools: []ToolRouteConfig{
|
||||
{Name: "search", Enabled: true},
|
||||
{Name: "sql", Enabled: true},
|
||||
@@ -72,7 +61,7 @@ func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 1,
|
||||
SystemPrompt: "route",
|
||||
SystemPrompt: "tools",
|
||||
Tools: []ToolRouteConfig{
|
||||
{Name: "sql", Enabled: true},
|
||||
{Name: " SQL ", Enabled: true},
|
||||
@@ -84,169 +73,119 @@ func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolRoutingDecision(t *testing.T) {
|
||||
decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" {
|
||||
t.Fatalf("unexpected decision: %#v", decision)
|
||||
}
|
||||
if decision.Reason != "总原因" {
|
||||
t.Fatalf("reason = %q", decision.Reason)
|
||||
}
|
||||
func TestAvailableAgentToolsUsesConfigOrderAndEnabled(t *testing.T) {
|
||||
oldRouter := toolRouterState
|
||||
oldSearch := searchState
|
||||
oldSQL := sqlState
|
||||
defer func() {
|
||||
toolRouterState = oldRouter
|
||||
searchState = oldSearch
|
||||
sqlState = oldSQL
|
||||
}()
|
||||
|
||||
if _, err := parseToolRoutingDecision("not json"); err == nil {
|
||||
t.Fatal("expected malformed JSON error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterToolSelections(t *testing.T) {
|
||||
tools := map[string]ChatTool{
|
||||
"time": fakeChatTool{name: "time", enabled: true},
|
||||
"sql": fakeChatTool{name: "sql", enabled: true},
|
||||
"search": fakeChatTool{name: "search", enabled: true},
|
||||
}
|
||||
decision := ToolRoutingDecision{Tools: []ToolSelection{
|
||||
{Name: "unknown", Reason: "ignore"},
|
||||
{Name: "search", Reason: "second in config"},
|
||||
{Name: "sql", Reason: "third in config"},
|
||||
{Name: "time", Reason: "first in config"},
|
||||
{Name: "sql", Reason: "duplicate"},
|
||||
}}
|
||||
selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}})
|
||||
if len(selected) != 3 {
|
||||
t.Fatalf("selected length = %d", len(selected))
|
||||
}
|
||||
if selected[0].Name != "time" || selected[0].Reason != "first in config" {
|
||||
t.Fatalf("first selection = %#v", selected[0])
|
||||
}
|
||||
if selected[1].Name != "search" {
|
||||
t.Fatalf("second selection = %#v", selected[1])
|
||||
}
|
||||
if selected[2].Name != "sql" || selected[2].Reason != "third in config" {
|
||||
t.Fatalf("third selection = %#v", selected[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureTimeSelectionForRelativeSearch(t *testing.T) {
|
||||
tools := map[string]ChatTool{
|
||||
"time": fakeChatTool{name: "time", enabled: true},
|
||||
"search": fakeChatTool{name: "search", enabled: true},
|
||||
}
|
||||
selected := ensureTimeSelectionForRelativeQuery(
|
||||
[]ToolSelection{{Name: "search", Reason: "查询历史事件"}},
|
||||
tools,
|
||||
[]ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}},
|
||||
"历史上的今天都发生了什么?",
|
||||
)
|
||||
if len(selected) != 2 || selected[0].Name != "time" || selected[1].Name != "search" {
|
||||
t.Fatalf("unexpected selected tools: %#v", selected)
|
||||
}
|
||||
if !strings.Contains(selected[0].Reason, "相对日期") {
|
||||
t.Fatalf("unexpected time reason: %#v", selected[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureTimeSelectionSkipsOrdinarySearch(t *testing.T) {
|
||||
tools := map[string]ChatTool{
|
||||
"time": fakeChatTool{name: "time", enabled: true},
|
||||
"search": fakeChatTool{name: "search", enabled: true},
|
||||
}
|
||||
selected := ensureTimeSelectionForRelativeQuery(
|
||||
[]ToolSelection{{Name: "search", Reason: "查询资料"}},
|
||||
tools,
|
||||
[]ToolRouteConfig{{Name: "time"}, {Name: "search"}},
|
||||
"查一下 Go 语言官网",
|
||||
)
|
||||
if len(selected) != 1 || selected[0].Name != "search" {
|
||||
t.Fatalf("unexpected selected tools: %#v", selected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) {
|
||||
messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}}
|
||||
withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" {
|
||||
t.Fatalf("unexpected messages: %#v", withTime)
|
||||
}
|
||||
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} {
|
||||
if !strings.Contains(withTime[0].Content, want) {
|
||||
t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolRouterPrompt(t *testing.T) {
|
||||
cfg := &ToolRouterConfig{
|
||||
SystemPrompt: "router",
|
||||
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||
Enabled: true,
|
||||
Tools: []ToolRouteConfig{
|
||||
{Name: "time", Enabled: true},
|
||||
{Name: "sql", Enabled: true, Description: "configured sql"},
|
||||
{Name: "search", Enabled: true},
|
||||
{Name: "time", Enabled: true, Description: "custom time"},
|
||||
{Name: "sql", Enabled: false},
|
||||
},
|
||||
}}
|
||||
searchState = nil
|
||||
sqlState = nil
|
||||
|
||||
tools := availableAgentTools(&OpenAIProfile{}, nil)
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("tools length = %d", len(tools))
|
||||
}
|
||||
tools := map[string]ChatTool{
|
||||
"time": fakeChatTool{name: "time", description: "fallback time", enabled: true},
|
||||
"sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true},
|
||||
"search": fakeChatTool{name: "search", description: "fallback search", enabled: true},
|
||||
if tools[0].name != "time" {
|
||||
t.Fatalf("tool name = %s", tools[0].name)
|
||||
}
|
||||
prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools)
|
||||
for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} {
|
||||
if !strings.Contains(prompt, want) {
|
||||
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
||||
}
|
||||
if tools[0].definition.Function == nil || tools[0].definition.Function.Description != "custom time" {
|
||||
t.Fatalf("unexpected definition: %#v", tools[0].definition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) {
|
||||
ai := &OpenAIState{
|
||||
profiles: map[string]*OpenAIProfile{
|
||||
"chat": {Config: OpenAIConfig{Name: "chat"}},
|
||||
"router": {Config: OpenAIConfig{Name: "router"}},
|
||||
},
|
||||
order: []string{"chat", "router"},
|
||||
activeName: "chat",
|
||||
func TestRunAgentToolLoopAppendsToolMessages(t *testing.T) {
|
||||
oldRouter := toolRouterState
|
||||
defer func() { toolRouterState = oldRouter }()
|
||||
|
||||
calls := 0
|
||||
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 128,
|
||||
SystemPrompt: "use tools",
|
||||
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
||||
}}
|
||||
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
||||
calls++
|
||||
if req.ToolChoice != model.ToolChoiceStringTypeAuto {
|
||||
t.Fatalf("tool choice = %#v", req.ToolChoice)
|
||||
}
|
||||
if len(req.Tools) != 1 || req.Tools[0].Function == nil || req.Tools[0].Function.Name != "time" {
|
||||
t.Fatalf("unexpected tools: %#v", req.Tools)
|
||||
}
|
||||
if calls == 1 {
|
||||
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "call_1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{"reason":"需要当前日期"}`}}}}}}}, nil
|
||||
}
|
||||
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("done")}}}}, nil
|
||||
}
|
||||
state := &ToolRouterState{cfg: &ToolRouterConfig{
|
||||
OpenAIName: "router",
|
||||
Timeout: 7,
|
||||
MaxTokens: 123,
|
||||
SystemPrompt: "router prompt",
|
||||
Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}},
|
||||
}, ai: ai}
|
||||
state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
|
||||
if profile.Config.Name != "router" {
|
||||
t.Fatalf("profile = %s", profile.Config.Name)
|
||||
}
|
||||
if maxTokens != 123 {
|
||||
t.Fatalf("maxTokens = %d", maxTokens)
|
||||
}
|
||||
if timeout != 7*time.Second {
|
||||
t.Fatalf("timeout = %s", timeout)
|
||||
}
|
||||
return `{"tools":[],"reason":"无需工具"}`, nil
|
||||
}
|
||||
decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
|
||||
|
||||
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天几号"}}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(decision.Tools) != 0 || decision.Reason != "无需工具" {
|
||||
t.Fatalf("unexpected decision: %#v", decision)
|
||||
if calls != 2 {
|
||||
t.Fatalf("calls = %d", calls)
|
||||
}
|
||||
if len(messages) < 4 {
|
||||
t.Fatalf("expected system/user/assistant/tool messages, got %d", len(messages))
|
||||
}
|
||||
last := messages[len(messages)-1]
|
||||
if last.Role != model.ChatMessageRoleTool || last.ToolCallID != "call_1" {
|
||||
t.Fatalf("unexpected last message: %#v", last)
|
||||
}
|
||||
if last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "时间工具结果") {
|
||||
t.Fatalf("unexpected tool content: %#v", last.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteToolsCompleterError(t *testing.T) {
|
||||
ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"}
|
||||
state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai}
|
||||
state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) {
|
||||
return "", errors.New("boom")
|
||||
func TestExecuteAgentToolCallUnknownAndError(t *testing.T) {
|
||||
unknown := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "missing"}}, map[string]agentTool{}, nil)
|
||||
if !strings.Contains(unknown, "未知工具") {
|
||||
t.Fatalf("unknown result = %q", unknown)
|
||||
}
|
||||
_, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
|
||||
if err == nil {
|
||||
t.Fatal("expected completer error")
|
||||
|
||||
failed := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "2", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "boom"}}, map[string]agentTool{
|
||||
"boom": {name: "boom", execute: func(context.Context, string) (string, error) { return "", errors.New("bad args") }},
|
||||
}, nil)
|
||||
if !strings.Contains(failed, "bad args") {
|
||||
t.Fatalf("failed result = %q", failed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAgentToolLoopMaxIterations(t *testing.T) {
|
||||
oldRouter := toolRouterState
|
||||
defer func() { toolRouterState = oldRouter }()
|
||||
|
||||
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 128,
|
||||
SystemPrompt: "use tools",
|
||||
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
||||
}}
|
||||
toolRouterState.complete = func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) {
|
||||
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "loop", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{}`}}}}}}}, nil
|
||||
}
|
||||
|
||||
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
last := messages[len(messages)-1]
|
||||
if last.Role != model.ChatMessageRoleSystem || last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "工具调用轮数已达到上限") {
|
||||
t.Fatalf("unexpected last message: %#v", last)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user