365 lines
14 KiB
Go
365 lines
14 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
|
)
|
|
|
|
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
|
|
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
|
|
changed, err := normalizeToolRouterConfig(cfg)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !changed {
|
|
t.Fatal("expected defaults to change config")
|
|
}
|
|
if cfg.ToolRouter.Timeout != defaultToolRouterTimeout {
|
|
t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout)
|
|
}
|
|
if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens {
|
|
t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens)
|
|
}
|
|
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
|
|
t.Fatal("system prompt should be defaulted")
|
|
}
|
|
if len(cfg.ToolRouter.Tools) != 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[1].Name != "search" || cfg.ToolRouter.Tools[2].Name != "sql" || !cfg.ToolRouter.Tools[0].Enabled || !cfg.ToolRouter.Tools[1].Enabled || !cfg.ToolRouter.Tools[2].Enabled {
|
|
t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
|
|
cfg := &Config{ToolRouter: ToolRouterConfig{
|
|
Enabled: true,
|
|
Timeout: 1,
|
|
MaxTokens: 1,
|
|
SystemPrompt: "tools",
|
|
Tools: []ToolRouteConfig{
|
|
{Name: "search", Enabled: true},
|
|
{Name: "sql", Enabled: true},
|
|
},
|
|
}}
|
|
changed, err := normalizeToolRouterConfig(cfg)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !changed {
|
|
t.Fatal("expected time tool to be added")
|
|
}
|
|
if len(cfg.ToolRouter.Tools) < 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[2].Name != "sql" {
|
|
t.Fatalf("unexpected tool order: %#v", cfg.ToolRouter.Tools)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
|
|
cfg := &Config{ToolRouter: ToolRouterConfig{
|
|
Enabled: true,
|
|
Timeout: 1,
|
|
MaxTokens: 1,
|
|
SystemPrompt: "tools",
|
|
Tools: []ToolRouteConfig{
|
|
{Name: "sql", Enabled: true},
|
|
{Name: " SQL ", Enabled: true},
|
|
},
|
|
}}
|
|
_, err := normalizeToolRouterConfig(cfg)
|
|
if err == nil {
|
|
t.Fatal("expected duplicate tool error")
|
|
}
|
|
}
|
|
|
|
func TestAvailableAgentToolsUsesConfigOrderAndEnabled(t *testing.T) {
|
|
oldRouter := toolRouterState
|
|
oldSearch := searchState
|
|
oldSQL := sqlState
|
|
defer func() {
|
|
toolRouterState = oldRouter
|
|
searchState = oldSearch
|
|
sqlState = oldSQL
|
|
}()
|
|
|
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
|
Enabled: true,
|
|
Tools: []ToolRouteConfig{
|
|
{Name: "search", Enabled: true},
|
|
{Name: "time", Enabled: true, Description: "custom time"},
|
|
{Name: "sql", Enabled: false},
|
|
},
|
|
}}
|
|
searchState = nil
|
|
sqlState = nil
|
|
|
|
tools := availableAgentTools(&OpenAIProfile{}, nil)
|
|
if len(tools) != 1 {
|
|
t.Fatalf("tools length = %d", len(tools))
|
|
}
|
|
if tools[0].name != "time" {
|
|
t.Fatalf("tool name = %s", tools[0].name)
|
|
}
|
|
if tools[0].definition.Function == nil || tools[0].definition.Function.Description != "custom time" {
|
|
t.Fatalf("unexpected definition: %#v", tools[0].definition)
|
|
}
|
|
}
|
|
|
|
func TestRunAgentToolLoopAppendsToolMessages(t *testing.T) {
|
|
oldRouter := toolRouterState
|
|
defer func() { toolRouterState = oldRouter }()
|
|
|
|
calls := 0
|
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
|
Enabled: true,
|
|
Timeout: 1,
|
|
MaxTokens: 128,
|
|
SystemPrompt: "use tools",
|
|
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
|
}}
|
|
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
|
calls++
|
|
if req.ToolChoice != model.ToolChoiceStringTypeAuto {
|
|
t.Fatalf("tool choice = %#v", req.ToolChoice)
|
|
}
|
|
if len(req.Tools) != 1 || req.Tools[0].Function == nil || req.Tools[0].Function.Name != "time" {
|
|
t.Fatalf("unexpected tools: %#v", req.Tools)
|
|
}
|
|
if calls == 1 {
|
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "call_1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{"reason":"需要当前日期"}`}}}}}}}, nil
|
|
}
|
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("done")}}}}, nil
|
|
}
|
|
|
|
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天几号"}}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if calls != 2 {
|
|
t.Fatalf("calls = %d", calls)
|
|
}
|
|
if len(messages) < 4 {
|
|
t.Fatalf("expected system/user/assistant/tool messages, got %d", len(messages))
|
|
}
|
|
last := messages[len(messages)-1]
|
|
if last.Role != model.ChatMessageRoleTool || last.ToolCallID != "call_1" {
|
|
t.Fatalf("unexpected last message: %#v", last)
|
|
}
|
|
if last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "时间工具结果") {
|
|
t.Fatalf("unexpected tool content: %#v", last.Content)
|
|
}
|
|
}
|
|
|
|
func TestExecuteAgentToolCallUnknownAndError(t *testing.T) {
|
|
unknown := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "missing"}}, map[string]agentTool{}, nil)
|
|
if !strings.Contains(unknown, "未知工具") {
|
|
t.Fatalf("unknown result = %q", unknown)
|
|
}
|
|
|
|
failed := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "2", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "boom"}}, map[string]agentTool{
|
|
"boom": {name: "boom", execute: func(context.Context, string) (string, error) { return "", errors.New("bad args") }},
|
|
}, nil)
|
|
if !strings.Contains(failed, "bad args") {
|
|
t.Fatalf("failed result = %q", failed)
|
|
}
|
|
}
|
|
|
|
func TestRunAgentToolLoopMaxIterations(t *testing.T) {
|
|
oldRouter := toolRouterState
|
|
defer func() { toolRouterState = oldRouter }()
|
|
|
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
|
Enabled: true,
|
|
Timeout: 1,
|
|
MaxTokens: 128,
|
|
SystemPrompt: "use tools",
|
|
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
|
}}
|
|
toolRouterState.complete = func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) {
|
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "loop", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{}`}}}}}}}, nil
|
|
}
|
|
|
|
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
last := messages[len(messages)-1]
|
|
if last.Role != model.ChatMessageRoleSystem || last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "工具调用轮数已达到上限") {
|
|
t.Fatalf("unexpected last message: %#v", last)
|
|
}
|
|
}
|
|
|
|
func TestBuildArkMessageImageTextOrder(t *testing.T) {
|
|
msg, err := buildArkMessage(ChatMessage{Role: "user", Content: "请描述图片", ImageURL: "data:image/png;base64,aGVsbG8="})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if msg.Content == nil || len(msg.Content.ListValue) != 2 {
|
|
t.Fatalf("unexpected content: %#v", msg.Content)
|
|
}
|
|
if msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeText || msg.Content.ListValue[0].Text != "请描述图片" {
|
|
t.Fatalf("first part should be text: %#v", msg.Content.ListValue[0])
|
|
}
|
|
if msg.Content.ListValue[1].Type != model.ChatCompletionMessageContentPartTypeImageURL || msg.Content.ListValue[1].ImageURL == nil {
|
|
t.Fatalf("second part should be image: %#v", msg.Content.ListValue[1])
|
|
}
|
|
}
|
|
|
|
func TestBuildArkMessageImageOnly(t *testing.T) {
|
|
msg, err := buildArkMessage(ChatMessage{Role: "user", ImageURL: "data:image/png;base64,aGVsbG8="})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if msg.Content == nil || len(msg.Content.ListValue) != 1 || msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeImageURL {
|
|
t.Fatalf("unexpected content: %#v", msg.Content)
|
|
}
|
|
}
|
|
|
|
func TestThinkTagParserSingleChunk(t *testing.T) {
|
|
parser := &thinkTagParser{}
|
|
visible, reasoning := parser.Accept("hello <think>abc</think> world")
|
|
flushVisible, flushReasoning := parser.Flush()
|
|
visible += flushVisible
|
|
reasoning += flushReasoning
|
|
if visible != "hello world" || reasoning != "abc" {
|
|
t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
|
|
}
|
|
}
|
|
|
|
func TestThinkTagParserAcrossChunks(t *testing.T) {
|
|
parser := &thinkTagParser{}
|
|
var visible, reasoning string
|
|
for _, chunk := range []string{"hello <thi", "nk>abc</thi", "nk> world"} {
|
|
v, r := parser.Accept(chunk)
|
|
visible += v
|
|
reasoning += r
|
|
}
|
|
v, r := parser.Flush()
|
|
visible += v
|
|
reasoning += r
|
|
if visible != "hello world" || reasoning != "abc" {
|
|
t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
|
|
}
|
|
}
|
|
|
|
func TestThinkTagParserUnclosedThink(t *testing.T) {
|
|
parser := &thinkTagParser{}
|
|
visible, reasoning := parser.Accept("answer <think>still thinking")
|
|
v, r := parser.Flush()
|
|
visible += v
|
|
reasoning += r
|
|
if visible != "answer " || reasoning != "still thinking" {
|
|
t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
|
|
}
|
|
}
|
|
|
|
func TestShouldParseThinkTags(t *testing.T) {
|
|
if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1"}}) {
|
|
t.Fatal("expected local ollama to parse think tags")
|
|
}
|
|
if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL}}) {
|
|
t.Fatal("expected remote profile not to parse think tags by default")
|
|
}
|
|
falseValue := false
|
|
if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1", ParseThinkTags: &falseValue}}) {
|
|
t.Fatal("explicit false should disable think parsing")
|
|
}
|
|
trueValue := true
|
|
if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL, ParseThinkTags: &trueValue}}) {
|
|
t.Fatal("explicit true should enable think parsing")
|
|
}
|
|
}
|
|
|
|
func TestBuildToolDecisionMessagesRemovesImages(t *testing.T) {
|
|
messages, err := buildToolDecisionMessages([]ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(messages) != 1 || messages[0].Content == nil || messages[0].Content.StringValue == nil {
|
|
t.Fatalf("unexpected messages: %#v", messages)
|
|
}
|
|
if !strings.Contains(*messages[0].Content.StringValue, "工具判断阶段不读取图片内容") {
|
|
t.Fatalf("missing image placeholder: %q", *messages[0].Content.StringValue)
|
|
}
|
|
if messages[0].Content.ListValue != nil {
|
|
t.Fatalf("decision message should be text-only: %#v", messages[0].Content)
|
|
}
|
|
}
|
|
|
|
func TestRunAgentToolLoopImageUsesTextOnlyDecisionMessages(t *testing.T) {
|
|
oldRouter := toolRouterState
|
|
defer func() { toolRouterState = oldRouter }()
|
|
|
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
|
Enabled: true,
|
|
Timeout: 1,
|
|
MaxTokens: 128,
|
|
SystemPrompt: "use tools",
|
|
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
|
}}
|
|
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
|
for _, msg := range req.Messages {
|
|
if msg.Content != nil && len(msg.Content.ListValue) > 0 {
|
|
t.Fatalf("tool decision should not receive multimodal content: %#v", msg.Content)
|
|
}
|
|
}
|
|
joined := ""
|
|
for _, msg := range req.Messages {
|
|
if msg.Content != nil && msg.Content.StringValue != nil {
|
|
joined += *msg.Content.StringValue
|
|
}
|
|
}
|
|
if !strings.Contains(joined, "工具判断阶段不读取图片内容") {
|
|
t.Fatalf("missing placeholder in decision messages: %q", joined)
|
|
}
|
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil
|
|
}
|
|
|
|
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "chat"}}, []ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
foundImage := false
|
|
for _, msg := range messages {
|
|
if msg.Content != nil && len(msg.Content.ListValue) > 0 {
|
|
foundImage = true
|
|
}
|
|
}
|
|
if !foundImage {
|
|
t.Fatalf("final messages should retain image: %#v", messages)
|
|
}
|
|
}
|
|
|
|
func TestRunAgentToolLoopUsesConfiguredRouterProfile(t *testing.T) {
|
|
oldRouter := toolRouterState
|
|
defer func() { toolRouterState = oldRouter }()
|
|
|
|
ai, err := NewOpenAIState([]OpenAIConfig{
|
|
{Name: "chat", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "chat-model", Timeout: 1, Active: true},
|
|
{Name: "router", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "router-model", Timeout: 1},
|
|
})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
toolRouterState = &ToolRouterState{ai: ai, cfg: &ToolRouterConfig{
|
|
Enabled: true,
|
|
OpenAIName: "router",
|
|
Timeout: 1,
|
|
MaxTokens: 128,
|
|
SystemPrompt: "use tools",
|
|
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
|
}}
|
|
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
|
if profile.Config.Name != "router" || req.Model != "router-model" {
|
|
t.Fatalf("router profile not used: profile=%s model=%s", profile.Config.Name, req.Model)
|
|
}
|
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil
|
|
}
|
|
|
|
_, err = runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Name: "chat", Model: "chat-model"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|