package main import ( "context" "errors" "strings" "testing" "time" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" ) func TestNormalizeToolRouterConfigDefaults(t *testing.T) { cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}} changed, err := normalizeToolRouterConfig(cfg) if err != nil { t.Fatal(err) } if !changed { t.Fatal("expected defaults to change config") } if cfg.ToolRouter.Timeout != defaultToolRouterTimeout { t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout) } if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens { t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens) } if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" { t.Fatal("system prompt should be defaulted") } if len(cfg.ToolRouter.Tools) != 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[1].Name != "search" || cfg.ToolRouter.Tools[2].Name != "sql" || !cfg.ToolRouter.Tools[0].Enabled || !cfg.ToolRouter.Tools[1].Enabled || !cfg.ToolRouter.Tools[2].Enabled { t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools) } } func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) { cfg := &Config{ToolRouter: ToolRouterConfig{ Enabled: true, Timeout: 1, MaxTokens: 1, SystemPrompt: "tools", Tools: []ToolRouteConfig{ {Name: "search", Enabled: true}, {Name: "sql", Enabled: true}, }, }} changed, err := normalizeToolRouterConfig(cfg) if err != nil { t.Fatal(err) } if !changed { t.Fatal("expected time tool to be added") } if len(cfg.ToolRouter.Tools) < 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[2].Name != "sql" { t.Fatalf("unexpected tool order: %#v", cfg.ToolRouter.Tools) } } func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) { cfg := &Config{ToolRouter: ToolRouterConfig{ Enabled: true, Timeout: 1, MaxTokens: 1, SystemPrompt: "tools", Tools: []ToolRouteConfig{ {Name: "sql", Enabled: true}, {Name: " SQL ", Enabled: true}, }, }} _, err := normalizeToolRouterConfig(cfg) if err == nil { t.Fatal("expected duplicate tool error") } } func TestAvailableAgentToolsUsesConfigOrderAndEnabled(t *testing.T) { oldRouter := toolRouterState oldSearch := searchState oldSQL := sqlState defer func() { toolRouterState = oldRouter searchState = oldSearch sqlState = oldSQL }() toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ Enabled: true, Tools: []ToolRouteConfig{ {Name: "search", Enabled: true}, {Name: "time", Enabled: true, Description: "custom time"}, {Name: "sql", Enabled: false}, }, }} searchState = nil sqlState = nil tools := availableAgentTools(&OpenAIProfile{}, nil) if len(tools) != 1 { t.Fatalf("tools length = %d", len(tools)) } if tools[0].name != "time" { t.Fatalf("tool name = %s", tools[0].name) } if tools[0].definition.Function == nil || tools[0].definition.Function.Description != "custom time" { t.Fatalf("unexpected definition: %#v", tools[0].definition) } } func TestRunAgentToolLoopAppendsToolMessages(t *testing.T) { oldRouter := toolRouterState defer func() { toolRouterState = oldRouter }() calls := 0 toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ Enabled: true, Timeout: 1, MaxTokens: 128, SystemPrompt: "use tools", Tools: []ToolRouteConfig{{Name: "time", Enabled: true}}, }} toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) { calls++ if req.ToolChoice != model.ToolChoiceStringTypeAuto { t.Fatalf("tool choice = %#v", req.ToolChoice) } if len(req.Tools) != 1 || req.Tools[0].Function == nil || req.Tools[0].Function.Name != "time" { t.Fatalf("unexpected tools: %#v", req.Tools) } if calls == 1 { return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "call_1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{"reason":"需要当前日期"}`}}}}}}}, nil } return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("done")}}}}, nil } messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天几号"}}, nil) if err != nil { t.Fatal(err) } if calls != 2 { t.Fatalf("calls = %d", calls) } if len(messages) < 4 { t.Fatalf("expected system/user/assistant/tool messages, got %d", len(messages)) } last := messages[len(messages)-1] if last.Role != model.ChatMessageRoleTool || last.ToolCallID != "call_1" { t.Fatalf("unexpected last message: %#v", last) } if last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "时间工具结果") { t.Fatalf("unexpected tool content: %#v", last.Content) } } func TestExecuteAgentToolCallUnknownAndError(t *testing.T) { unknown := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "missing"}}, map[string]agentTool{}, nil) if !strings.Contains(unknown, "未知工具") { t.Fatalf("unknown result = %q", unknown) } failed := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "2", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "boom"}}, map[string]agentTool{ "boom": {name: "boom", execute: func(context.Context, string) (string, error) { return "", errors.New("bad args") }}, }, nil) if !strings.Contains(failed, "bad args") { t.Fatalf("failed result = %q", failed) } } func TestRunAgentToolLoopMaxIterations(t *testing.T) { oldRouter := toolRouterState defer func() { toolRouterState = oldRouter }() toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ Enabled: true, Timeout: 1, MaxTokens: 128, SystemPrompt: "use tools", Tools: []ToolRouteConfig{{Name: "time", Enabled: true}}, }} toolRouterState.complete = func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) { return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "loop", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{}`}}}}}}}, nil } messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil) if err != nil { t.Fatal(err) } last := messages[len(messages)-1] if last.Role != model.ChatMessageRoleSystem || last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "工具调用轮数已达到上限") { t.Fatalf("unexpected last message: %#v", last) } }