Files
aichat/main_test.go
T
2026-06-11 18:04:47 +08:00

192 lines
6.9 KiB
Go

package main
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
)
// TestNormalizeToolRouterConfigDefaults checks that normalizing an enabled
// router with nothing else configured reports a change and fills in the
// default timeout, the default max-token budget, a non-blank system prompt,
// and exactly the tools time, search and sql (in that order), all enabled.
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
	cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}

	changed, err := normalizeToolRouterConfig(cfg)
	if err != nil {
		t.Fatal(err)
	}
	if !changed {
		t.Fatal("expected defaults to change config")
	}

	router := &cfg.ToolRouter
	switch {
	case router.Timeout != defaultToolRouterTimeout:
		t.Fatalf("timeout = %d", router.Timeout)
	case router.MaxTokens != defaultToolRouterMaxTokens:
		t.Fatalf("max_tokens = %d", router.MaxTokens)
	case strings.TrimSpace(router.SystemPrompt) == "":
		t.Fatal("system prompt should be defaulted")
	}

	// Expected default tool order; every entry must also be enabled.
	wantNames := []string{"time", "search", "sql"}
	if len(router.Tools) != len(wantNames) {
		t.Fatalf("unexpected tools: %#v", router.Tools)
	}
	for i, name := range wantNames {
		if router.Tools[i].Name != name || !router.Tools[i].Enabled {
			t.Fatalf("unexpected tools: %#v", router.Tools)
		}
	}
}
// TestNormalizeToolRouterConfigAddsTimeBeforeSQL checks that a fully
// specified config listing only search and sql is reported as changed, and
// that afterwards "time" is the first tool while "sql" sits at index 2.
func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
	cfg := &Config{ToolRouter: ToolRouterConfig{
		Enabled:      true,
		Timeout:      1,
		MaxTokens:    1,
		SystemPrompt: "tools",
		Tools: []ToolRouteConfig{
			{Name: "search", Enabled: true},
			{Name: "sql", Enabled: true},
		},
	}}

	changed, err := normalizeToolRouterConfig(cfg)
	switch {
	case err != nil:
		t.Fatal(err)
	case !changed:
		t.Fatal("expected time tool to be added")
	}

	// The length check short-circuits before the indexed name comparisons.
	tools := cfg.ToolRouter.Tools
	orderOK := len(tools) >= 3 && tools[0].Name == "time" && tools[2].Name == "sql"
	if !orderOK {
		t.Fatalf("unexpected tool order: %#v", tools)
	}
}
// TestNormalizeToolRouterConfigDuplicateTools checks that normalization fails
// when the tool list names "sql" twice, once as "sql" and once as " SQL "
// (surrounding spaces, upper case).
func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
	cfg := &Config{ToolRouter: ToolRouterConfig{
		Enabled:      true,
		Timeout:      1,
		MaxTokens:    1,
		SystemPrompt: "tools",
		Tools: []ToolRouteConfig{
			{Name: "sql", Enabled: true},
			{Name: " SQL ", Enabled: true},
		},
	}}

	if _, err := normalizeToolRouterConfig(cfg); err == nil {
		t.Fatal("expected duplicate tool error")
	}
}
// TestAvailableAgentToolsUsesConfigOrderAndEnabled replaces the package-level
// router, search and SQL state, then checks what availableAgentTools returns.
//
// The router lists search (enabled), time (enabled, custom description) and
// sql (disabled), while searchState and sqlState are nil. The test expects
// only "time" to come back, with the configured description on its function
// definition. The saved globals are restored when the test finishes.
func TestAvailableAgentToolsUsesConfigOrderAndEnabled(t *testing.T) {
	savedRouter, savedSearch, savedSQL := toolRouterState, searchState, sqlState
	t.Cleanup(func() {
		toolRouterState, searchState, sqlState = savedRouter, savedSearch, savedSQL
	})

	toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
		Enabled: true,
		Tools: []ToolRouteConfig{
			{Name: "search", Enabled: true},
			{Name: "time", Enabled: true, Description: "custom time"},
			{Name: "sql", Enabled: false},
		},
	}}
	searchState, sqlState = nil, nil

	tools := availableAgentTools(&OpenAIProfile{}, nil)
	if got := len(tools); got != 1 {
		t.Fatalf("tools length = %d", got)
	}
	if name := tools[0].name; name != "time" {
		t.Fatalf("tool name = %s", name)
	}
	if fn := tools[0].definition.Function; fn == nil || fn.Description != "custom time" {
		t.Fatalf("unexpected definition: %#v", tools[0].definition)
	}
}
// TestRunAgentToolLoopAppendsToolMessages drives runAgentToolLoop with a stub
// completion function installed on toolRouterState.
//
// The stub asks for the "time" tool on its first call and returns the plain
// answer "done" on its second call. On every call it checks that:
//   - the request uses automatic tool choice;
//   - the request advertises exactly the "time" tool.
//
// Afterwards the test checks that:
//   - the stub ran exactly twice;
//   - the returned transcript has at least four messages;
//   - the transcript ends with a tool-role message for call_1 whose text
//     contains "时间工具结果" ("time tool result").
func TestRunAgentToolLoopAppendsToolMessages(t *testing.T) {
	oldRouter := toolRouterState
	defer func() { toolRouterState = oldRouter }()

	calls := 0
	toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
		Enabled:      true,
		Timeout:      1,
		MaxTokens:    128,
		SystemPrompt: "use tools",
		Tools:        []ToolRouteConfig{{Name: "time", Enabled: true}},
	}}
	toolRouterState.complete = func(_ context.Context, _ *OpenAIProfile, req model.CreateChatCompletionRequest, _ time.Duration) (model.ChatCompletionResponse, error) {
		calls++
		// runAgentToolLoop's implementation is not visible here and could call
		// this stub from another goroutine. t.Fatalf (FailNow) may only be
		// called from the test goroutine. So report with t.Errorf and fail the
		// completion; the test still fails either way.
		if req.ToolChoice != model.ToolChoiceStringTypeAuto {
			t.Errorf("tool choice = %#v", req.ToolChoice)
			return model.ChatCompletionResponse{}, errors.New("stub: unexpected tool choice")
		}
		if len(req.Tools) != 1 || req.Tools[0].Function == nil || req.Tools[0].Function.Name != "time" {
			t.Errorf("unexpected tools: %#v", req.Tools)
			return model.ChatCompletionResponse{}, errors.New("stub: unexpected tools")
		}
		if calls == 1 {
			// First round: request the time tool (reason "需要当前日期" = "need the current date").
			return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "call_1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{"reason":"需要当前日期"}`}}}}}}}, nil
		}
		// Second round: a plain answer with no tool calls.
		return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("done")}}}}, nil
	}

	// User prompt "今天几号" = "what's the date today".
	messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天几号"}}, nil)
	if err != nil {
		t.Fatal(err)
	}
	if calls != 2 {
		t.Fatalf("calls = %d", calls)
	}
	if len(messages) < 4 {
		t.Fatalf("expected system/user/assistant/tool messages, got %d", len(messages))
	}
	last := messages[len(messages)-1]
	if last.Role != model.ChatMessageRoleTool || last.ToolCallID != "call_1" {
		t.Fatalf("unexpected last message: %#v", last)
	}
	if last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "时间工具结果") {
		t.Fatalf("unexpected tool content: %#v", last.Content)
	}
}
func TestExecuteAgentToolCallUnknownAndError(t *testing.T) {
unknown := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "missing"}}, map[string]agentTool{}, nil)
if !strings.Contains(unknown, "未知工具") {
t.Fatalf("unknown result = %q", unknown)
}
failed := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "2", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "boom"}}, map[string]agentTool{
"boom": {name: "boom", execute: func(context.Context, string) (string, error) { return "", errors.New("bad args") }},
}, nil)
if !strings.Contains(failed, "bad args") {
t.Fatalf("failed result = %q", failed)
}
}
// TestRunAgentToolLoopMaxIterations installs a stub completion that requests
// the "time" tool on every call, so the model never produces a final answer.
//
// It checks that runAgentToolLoop still returns without error. It also checks
// that the transcript ends with a system message containing
// "工具调用轮数已达到上限" ("the tool-call round limit has been reached").
func TestRunAgentToolLoopMaxIterations(t *testing.T) {
	oldRouter := toolRouterState
	defer func() { toolRouterState = oldRouter }()

	toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
		Enabled:      true,
		Timeout:      1,
		MaxTokens:    128,
		SystemPrompt: "use tools",
		Tools:        []ToolRouteConfig{{Name: "time", Enabled: true}},
	}}
	toolRouterState.complete = func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) {
		return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "loop", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{}`}}}}}}}, nil
	}

	// User prompt "今天" = "today".
	messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Guard the index below so an empty transcript fails the test cleanly
	// instead of panicking with an index-out-of-range error.
	if len(messages) == 0 {
		t.Fatal("expected at least one message, got none")
	}
	last := messages[len(messages)-1]
	if last.Role != model.ChatMessageRoleSystem || last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "工具调用轮数已达到上限") {
		t.Fatalf("unexpected last message: %#v", last)
	}
}