diff --git a/main.go b/main.go index fe77e8a..d8e9c09 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "bufio" + "bytes" "context" "crypto/rand" "encoding/base64" @@ -46,12 +48,13 @@ const ( ) type OpenAIConfig struct { - Name string `yaml:"name" json:"name"` - Active bool `yaml:"active,omitempty" json:"active"` - APIKey string `yaml:"api_key" json:"-"` - BaseURL string `yaml:"base_url" json:"base_url"` - Model string `yaml:"model" json:"model"` - Timeout int `yaml:"timeout" json:"timeout"` + Name string `yaml:"name" json:"name"` + Active bool `yaml:"active,omitempty" json:"active"` + APIKey string `yaml:"api_key" json:"-"` + BaseURL string `yaml:"base_url" json:"base_url"` + Model string `yaml:"model" json:"model"` + Timeout int `yaml:"timeout" json:"timeout"` + ParseThinkTags *bool `yaml:"parse_think_tags,omitempty" json:"parse_think_tags,omitempty"` } type OpenAIConfigs []OpenAIConfig @@ -559,6 +562,44 @@ func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig { return config } +func (s *ToolRouterState) RouterProfile(fallback *OpenAIProfile) *OpenAIProfile { + if s == nil || s.cfg == nil || s.ai == nil { + return fallback + } + name := strings.TrimSpace(s.cfg.OpenAIName) + if name == "" { + return fallback + } + profile, err := s.ai.GetProfile(name) + if err != nil { + return fallback + } + return profile +} + +func isOllamaProfile(profile *OpenAIProfile) bool { + if profile == nil { + return false + } + u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL)) + if err != nil { + return strings.Contains(profile.Config.BaseURL, ":11434") + } + host := strings.ToLower(u.Hostname()) + port := u.Port() + return port == "11434" && (host == "127.0.0.1" || host == "localhost" || host == "::1") +} + +func shouldParseThinkTags(profile *OpenAIProfile) bool { + if profile == nil { + return false + } + if profile.Config.ParseThinkTags != nil { + return *profile.Config.ParseThinkTags + } + return isOllamaProfile(profile) +} + // ─── 全局变量 ───────────────────────────────────────────── var ( @@ -810,6 +851,21 @@ func chatHandler(c *gin.Context) { } promptTokens := estimateChatMessagesTokens(req.Messages) + if isOllamaProfile(profile) && hasImageMessage(req.Messages) { + emitTrace("model", "request", "running", "正在通过 Ollama 原生接口调用视觉模型", nil) + err = streamOllamaChat(ctx, profile, messages, promptTokens, usage, emit, func(content string) { + if req.ConversationID != "" { + if err := saveConversationMessages(req.ConversationID, req.Messages, content); err != nil { + fmt.Fprintln(os.Stderr, "保存对话失败:", err) + } + } + }) + if err != nil { + emitError(err) + } + return + } + emitTrace("model", "request", "running", "正在调用模型生成回答", nil) stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ Model: profile.Config.Model, @@ -829,9 +885,56 @@ func chatHandler(c *gin.Context) { windowStarted := streamStarted windowTokens := 0 peakTokensPerSecond := 0.0 + parseThinkTags := shouldParseThinkTags(profile) + thinkParser := &thinkTagParser{} + emitDelta := func(delta string) { + if delta == "" { + return + } + now := time.Now() + deltaTokens := estimateTokenCount(delta) + windowTokens += deltaTokens + windowElapsed := now.Sub(windowStarted).Seconds() + if windowElapsed >= 1 { + windowSpeed := float64(windowTokens) / windowElapsed + if windowSpeed > peakTokensPerSecond { + peakTokensPerSecond = windowSpeed + } + windowStarted = now + windowTokens = 0 + } else if peakTokensPerSecond == 0 && windowElapsed > 0.25 { + peakTokensPerSecond = float64(windowTokens) / windowElapsed + } + full.WriteString(delta) + completionTokens += deltaTokens + usage.setModel(promptTokens, completionTokens) + stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) + emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) + } + emitModelContent := func(delta string) { + if delta == "" { + return + } + if !parseThinkTags { + emitDelta(delta) + return + } + visible, reasoning := thinkParser.Accept(delta) + if reasoning != "" { + emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) + } + emitDelta(visible) + } for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { + if parseThinkTags { + visible, reasoning := thinkParser.Flush() + if reasoning != "" { + emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) + } + emitDelta(visible) + } usage.setModel(promptTokens, completionTokens) if windowTokens > 0 { windowElapsed := time.Since(windowStarted).Seconds() @@ -862,28 +965,7 @@ func chatHandler(c *gin.Context) { return } if len(resp.Choices) > 0 { - delta := resp.Choices[0].Delta.Content - if delta != "" { - now := time.Now() - deltaTokens := estimateTokenCount(delta) - windowTokens += deltaTokens - windowElapsed := now.Sub(windowStarted).Seconds() - if windowElapsed >= 1 { - windowSpeed := float64(windowTokens) / windowElapsed - if windowSpeed > peakTokensPerSecond { - peakTokensPerSecond = windowSpeed - } - windowStarted = now - windowTokens = 0 - } else if peakTokensPerSecond == 0 && windowElapsed > 0.25 { - peakTokensPerSecond = float64(windowTokens) / windowElapsed - } - full.WriteString(delta) - completionTokens += deltaTokens - usage.setModel(promptTokens, completionTokens) - stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) - emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) - } + emitModelContent(resp.Choices[0].Delta.Content) // 思考过程 reasoning_content 单独事件推送 if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" { emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent}) @@ -1035,13 +1117,27 @@ func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agen } func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) { - messages, err := buildArkMessages(chatMessages) + finalMessages, err := buildArkMessages(chatMessages) if err != nil { return nil, err } - tools := availableAgentTools(profile, emit) + routerProfile := profile + if toolRouterState != nil { + routerProfile = toolRouterState.RouterProfile(profile) + } + tools := availableAgentTools(routerProfile, emit) if len(tools) == 0 { - return messages, nil + return finalMessages, nil + } + decisionMessages := append([]*model.ChatCompletionMessage(nil), finalMessages...) + if hasImageMessage(chatMessages) { + decisionMessages, err = buildToolDecisionMessages(chatMessages) + if err != nil { + return nil, err + } + if emit != nil { + emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "检测到图片输入,工具判断阶段将使用纯文本上下文"}) + } } toolByName := make(map[string]agentTool, len(tools)) definitions := make([]*model.Tool, 0, len(tools)) @@ -1059,28 +1155,30 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}}) } if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" { - messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...) + systemMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)} + finalMessages = append([]*model.ChatCompletionMessage{systemMessage}, finalMessages...) + decisionMessages = append([]*model.ChatCompletionMessage{systemMessage}, decisionMessages...) } for i := 0; i < maxAgentToolIterations; i++ { if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}}) } - resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{ - Model: profile.Config.Model, - Messages: messages, + resp, err := toolRouterState.complete(ctx, routerProfile, model.CreateChatCompletionRequest{ + Model: routerProfile.Config.Model, + Messages: decisionMessages, MaxTokens: intPtr(toolRouterState.cfg.MaxTokens), Tools: definitions, ToolChoice: model.ToolChoiceStringTypeAuto, ParallelToolCalls: boolPtr(false), }, time.Duration(toolRouterState.cfg.Timeout)*time.Second) if err != nil { - return messages, err + return finalMessages, err } if tracker := tokenUsageFromContext(ctx); tracker != nil { tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens) } if len(resp.Choices) == 0 { - return messages, nil + return finalMessages, nil } choice := resp.Choices[0] decisionPreview := chatMessageContentString(choice.Message.Content) @@ -1095,7 +1193,7 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"}) } - return messages, nil + return finalMessages, nil } callNames := make([]string, 0, len(calls)) for _, call := range calls { @@ -1106,14 +1204,89 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}}) } - messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content}) + assistantMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content} + finalMessages = append(finalMessages, assistantMessage) + decisionMessages = append(decisionMessages, assistantMessage) for _, call := range calls { result := executeAgentToolCall(ctx, call, toolByName, emit) - messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)}) + toolMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)} + finalMessages = append(finalMessages, toolMessage) + decisionMessages = append(decisionMessages, toolMessage) } } - messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")}) - return messages, nil + limitMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")} + finalMessages = append(finalMessages, limitMessage) + return finalMessages, nil +} + +type thinkTagParser struct { + inThink bool + buffer string +} + +const ( + thinkOpenTag = "" + thinkCloseTag = "" +) + +func (p *thinkTagParser) Accept(delta string) (visible string, reasoning string) { + p.buffer += delta + for p.buffer != "" { + if p.inThink { + idx := strings.Index(p.buffer, thinkCloseTag) + if idx >= 0 { + reasoning += p.buffer[:idx] + p.buffer = p.buffer[idx+len(thinkCloseTag):] + p.inThink = false + continue + } + keep := tagPrefixSuffixLen(p.buffer, thinkCloseTag) + if len(p.buffer) > keep { + reasoning += p.buffer[:len(p.buffer)-keep] + p.buffer = p.buffer[len(p.buffer)-keep:] + } + return visible, reasoning + } + + idx := strings.Index(p.buffer, thinkOpenTag) + if idx >= 0 { + visible += p.buffer[:idx] + p.buffer = p.buffer[idx+len(thinkOpenTag):] + p.inThink = true + continue + } + keep := tagPrefixSuffixLen(p.buffer, thinkOpenTag) + if len(p.buffer) > keep { + visible += p.buffer[:len(p.buffer)-keep] + p.buffer = p.buffer[len(p.buffer)-keep:] + } + return visible, reasoning + } + return visible, reasoning +} + +func (p *thinkTagParser) Flush() (visible string, reasoning string) { + if p.inThink { + reasoning = p.buffer + } else { + visible = p.buffer + } + p.buffer = "" + p.inThink = false + return visible, reasoning +} + +func tagPrefixSuffixLen(text, tag string) int { + limit := len(tag) - 1 + if len(text) < limit { + limit = len(text) + } + for i := limit; i > 0; i-- { + if strings.HasPrefix(tag, text[len(text)-i:]) { + return i + } + } + return 0 } func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string { @@ -1155,6 +1328,223 @@ func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[s return result } +type ollamaChatRequest struct { + Model string `json:"model"` + Messages []ollamaChatMessage `json:"messages"` + Stream bool `json:"stream"` + Options map[string]int `json:"options,omitempty"` +} + +type ollamaChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` + Images []string `json:"images,omitempty"` +} + +type ollamaChatResponse struct { + Message struct { + Role string `json:"role"` + Content string `json:"content"` + Thinking string `json:"thinking"` + } `json:"message"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` + DoneReason string `json:"done_reason"` +} + +func streamOllamaChat(ctx context.Context, profile *OpenAIProfile, messages []*model.ChatCompletionMessage, promptTokens int, usage *tokenUsageTracker, emit func(chatSSEFrame), onDone func(string)) error { + requestMessages, err := buildOllamaMessages(messages) + if err != nil { + return err + } + baseURL, err := ollamaBaseURL(profile) + if err != nil { + return err + } + body, err := json.Marshal(ollamaChatRequest{ + Model: profile.Config.Model, + Messages: requestMessages, + Stream: true, + Options: map[string]int{"num_predict": 4096}, + }) + if err != nil { + return err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(baseURL, "/")+"/api/chat", bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return fmt.Errorf("Ollama 原生接口调用失败: %s %s", resp.Status, strings.TrimSpace(string(data))) + } + + emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "running", Message: "Ollama 视觉模型已开始输出"}) + parseThinkTags := shouldParseThinkTags(profile) + thinkParser := &thinkTagParser{} + var full strings.Builder + completionTokens := 0 + streamStarted := time.Now() + peakTokensPerSecond := 0.0 + emitDelta := func(delta string) { + if delta == "" { + return + } + full.WriteString(delta) + completionTokens += estimateTokenCount(delta) + usage.setModel(promptTokens, completionTokens) + currentSpeed := tokensPerSecond(completionTokens, streamStarted) + if currentSpeed > peakTokensPerSecond { + peakTokensPerSecond = currentSpeed + } + stats := usage.snapshot(currentSpeed, peakTokensPerSecond) + emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) + } + emitContent := func(delta string) { + if delta == "" { + return + } + if !parseThinkTags { + emitDelta(delta) + return + } + visible, reasoning := thinkParser.Accept(delta) + if reasoning != "" { + emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) + } + emitDelta(visible) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + var chunk ollamaChatResponse + if err := json.Unmarshal([]byte(line), &chunk); err != nil { + return fmt.Errorf("解析 Ollama 流失败: %w", err) + } + if chunk.Message.Thinking != "" { + emit(chatSSEFrame{Type: "reasoning", Text: chunk.Message.Thinking}) + } + emitContent(chunk.Message.Content) + if chunk.Done { + if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 { + usage.setModel(chunk.PromptEvalCount, chunk.EvalCount) + } + break + } + } + if err := scanner.Err(); err != nil { + return err + } + if parseThinkTags { + visible, reasoning := thinkParser.Flush() + if reasoning != "" { + emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) + } + emitDelta(visible) + } + if onDone != nil { + onDone(full.String()) + } + finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) + emit(chatSSEFrame{Type: "stats", Stats: &finalStats}) + emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "success", Message: "回答生成完成"}) + return nil +} + +func buildOllamaMessages(messages []*model.ChatCompletionMessage) ([]ollamaChatMessage, error) { + result := make([]ollamaChatMessage, 0, len(messages)) + for _, msg := range messages { + if msg == nil { + continue + } + role := string(msg.Role) + if msg.Role == model.ChatMessageRoleTool { + role = string(model.ChatMessageRoleUser) + } + item := ollamaChatMessage{Role: role} + if msg.Content == nil { + if len(msg.ToolCalls) > 0 { + continue + } + result = append(result, item) + continue + } + if msg.Content.StringValue != nil { + item.Content = *msg.Content.StringValue + if msg.Role == model.ChatMessageRoleTool { + item.Content = "工具结果:\n" + item.Content + } + result = append(result, item) + continue + } + for _, part := range msg.Content.ListValue { + if part == nil { + continue + } + switch part.Type { + case model.ChatCompletionMessageContentPartTypeText: + if part.Text != "" { + if item.Content != "" { + item.Content += "\n" + } + item.Content += part.Text + } + case model.ChatCompletionMessageContentPartTypeImageURL: + if part.ImageURL == nil { + continue + } + image, err := ollamaImagePayload(part.ImageURL.URL) + if err != nil { + return nil, err + } + item.Images = append(item.Images, image) + } + } + result = append(result, item) + } + return result, nil +} + +func ollamaImagePayload(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if strings.HasPrefix(strings.ToLower(raw), "data:") { + comma := strings.Index(raw, ",") + if comma < 0 { + return "", errors.New("图片 base64 数据格式错误") + } + return strings.TrimSpace(raw[comma+1:]), nil + } + return raw, nil +} + +func ollamaBaseURL(profile *OpenAIProfile) (string, error) { + if profile == nil { + return "", errors.New("Ollama 配置为空") + } + u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL)) + if err != nil { + return "", err + } + if strings.TrimRight(u.Path, "/") == "/v1" { + u.Path = strings.TrimSuffix(strings.TrimRight(u.Path, "/"), "/v1") + } + u.RawQuery = "" + u.Fragment = "" + return strings.TrimRight(u.String(), "/"), nil +} + func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) { return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, time.Duration(profile.Config.Timeout)*time.Second) } @@ -1178,10 +1568,23 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe promptTokens := estimateChatMessagesTokens(chatMessages) completionTokens := 0 + parseThinkTags := shouldParseThinkTags(profile) + thinkParser := &thinkTagParser{} var b strings.Builder + appendVisible := func(delta string) { + if delta == "" { + return + } + b.WriteString(delta) + completionTokens += estimateTokenCount(delta) + } for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { + if parseThinkTags { + visible, _ := thinkParser.Flush() + appendVisible(visible) + } if tracker := tokenUsageFromContext(ctx); tracker != nil { tracker.addTool(promptTokens, completionTokens) } @@ -1192,8 +1595,12 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe } if len(resp.Choices) > 0 { delta := resp.Choices[0].Delta.Content - b.WriteString(delta) - completionTokens += estimateTokenCount(delta) + if parseThinkTags { + visible, _ := thinkParser.Accept(delta) + appendVisible(visible) + } else { + appendVisible(delta) + } } } } @@ -1368,6 +1775,33 @@ func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessag return messages, nil } +func hasImageMessage(messages []ChatMessage) bool { + for _, msg := range messages { + if strings.TrimSpace(msg.ImageURL) != "" || strings.TrimSpace(msg.ImageURLAlias) != "" { + return true + } + } + return false +} + +func buildToolDecisionMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) { + messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages)) + for _, m := range chatMessages { + content := m.Content + if strings.TrimSpace(m.ImageURL) != "" || strings.TrimSpace(m.ImageURLAlias) != "" { + content = strings.TrimSpace(content) + placeholder := "[用户上传了一张图片。工具判断阶段不读取图片内容;如果问题主要依赖识图,应不要调用工具,交给最终多模态模型回答。]" + if content == "" { + content = placeholder + } else { + content += "\n\n" + placeholder + } + } + messages = append(messages, &model.ChatCompletionMessage{Role: m.Role, Content: stringContent(content)}) + } + return messages, nil +} + func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) { msg := &model.ChatCompletionMessage{Role: m.Role} @@ -1388,11 +1822,12 @@ func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) { } // 有图片时:文字内容可有可无(图片 caption 场景),均构造多模态消息 - // 若无文字,则只传图片 part;若同时有图片和文字,先图后文 - parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)} + // 若无文字,则只传图片 part;若同时有图片和文字,先文后图 + parts := make([]*model.ChatCompletionMessageContentPart, 0, 2) if m.Content != "" { parts = append(parts, textPart(m.Content)) } + parts = append(parts, imagePart(imageURL)) msg.Content = &model.ChatCompletionMessageContent{ListValue: parts} return msg, nil } diff --git a/main_test.go b/main_test.go index cefebeb..2ee35fb 100644 --- a/main_test.go +++ b/main_test.go @@ -189,3 +189,176 @@ func TestRunAgentToolLoopMaxIterations(t *testing.T) { t.Fatalf("unexpected last message: %#v", last) } } + +func TestBuildArkMessageImageTextOrder(t *testing.T) { + msg, err := buildArkMessage(ChatMessage{Role: "user", Content: "请描述图片", ImageURL: "data:image/png;base64,aGVsbG8="}) + if err != nil { + t.Fatal(err) + } + if msg.Content == nil || len(msg.Content.ListValue) != 2 { + t.Fatalf("unexpected content: %#v", msg.Content) + } + if msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeText || msg.Content.ListValue[0].Text != "请描述图片" { + t.Fatalf("first part should be text: %#v", msg.Content.ListValue[0]) + } + if msg.Content.ListValue[1].Type != model.ChatCompletionMessageContentPartTypeImageURL || msg.Content.ListValue[1].ImageURL == nil { + t.Fatalf("second part should be image: %#v", msg.Content.ListValue[1]) + } +} + +func TestBuildArkMessageImageOnly(t *testing.T) { + msg, err := buildArkMessage(ChatMessage{Role: "user", ImageURL: "data:image/png;base64,aGVsbG8="}) + if err != nil { + t.Fatal(err) + } + if msg.Content == nil || len(msg.Content.ListValue) != 1 || msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeImageURL { + t.Fatalf("unexpected content: %#v", msg.Content) + } +} + +func TestThinkTagParserSingleChunk(t *testing.T) { + parser := &thinkTagParser{} + visible, reasoning := parser.Accept("hello abc world") + flushVisible, flushReasoning := parser.Flush() + visible += flushVisible + reasoning += flushReasoning + if visible != "hello world" || reasoning != "abc" { + t.Fatalf("visible=%q reasoning=%q", visible, reasoning) + } +} + +func TestThinkTagParserAcrossChunks(t *testing.T) { + parser := &thinkTagParser{} + var visible, reasoning string + for _, chunk := range []string{"hello abc world"} { + v, r := parser.Accept(chunk) + visible += v + reasoning += r + } + v, r := parser.Flush() + visible += v + reasoning += r + if visible != "hello world" || reasoning != "abc" { + t.Fatalf("visible=%q reasoning=%q", visible, reasoning) + } +} + +func TestThinkTagParserUnclosedThink(t *testing.T) { + parser := &thinkTagParser{} + visible, reasoning := parser.Accept("answer still thinking") + v, r := parser.Flush() + visible += v + reasoning += r + if visible != "answer " || reasoning != "still thinking" { + t.Fatalf("visible=%q reasoning=%q", visible, reasoning) + } +} + +func TestShouldParseThinkTags(t *testing.T) { + if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1"}}) { + t.Fatal("expected local ollama to parse think tags") + } + if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL}}) { + t.Fatal("expected remote profile not to parse think tags by default") + } + falseValue := false + if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1", ParseThinkTags: &falseValue}}) { + t.Fatal("explicit false should disable think parsing") + } + trueValue := true + if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL, ParseThinkTags: &trueValue}}) { + t.Fatal("explicit true should enable think parsing") + } +} + +func TestBuildToolDecisionMessagesRemovesImages(t *testing.T) { + messages, err := buildToolDecisionMessages([]ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}}) + if err != nil { + t.Fatal(err) + } + if len(messages) != 1 || messages[0].Content == nil || messages[0].Content.StringValue == nil { + t.Fatalf("unexpected messages: %#v", messages) + } + if !strings.Contains(*messages[0].Content.StringValue, "工具判断阶段不读取图片内容") { + t.Fatalf("missing image placeholder: %q", *messages[0].Content.StringValue) + } + if messages[0].Content.ListValue != nil { + t.Fatalf("decision message should be text-only: %#v", messages[0].Content) + } +} + +func TestRunAgentToolLoopImageUsesTextOnlyDecisionMessages(t *testing.T) { + oldRouter := toolRouterState + defer func() { toolRouterState = oldRouter }() + + toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{ + Enabled: true, + Timeout: 1, + MaxTokens: 128, + SystemPrompt: "use tools", + Tools: []ToolRouteConfig{{Name: "time", Enabled: true}}, + }} + toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) { + for _, msg := range req.Messages { + if msg.Content != nil && len(msg.Content.ListValue) > 0 { + t.Fatalf("tool decision should not receive multimodal content: %#v", msg.Content) + } + } + joined := "" + for _, msg := range req.Messages { + if msg.Content != nil && msg.Content.StringValue != nil { + joined += *msg.Content.StringValue + } + } + if !strings.Contains(joined, "工具判断阶段不读取图片内容") { + t.Fatalf("missing placeholder in decision messages: %q", joined) + } + return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil + } + + messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "chat"}}, []ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}}, nil) + if err != nil { + t.Fatal(err) + } + foundImage := false + for _, msg := range messages { + if msg.Content != nil && len(msg.Content.ListValue) > 0 { + foundImage = true + } + } + if !foundImage { + t.Fatalf("final messages should retain image: %#v", messages) + } +} + +func TestRunAgentToolLoopUsesConfiguredRouterProfile(t *testing.T) { + oldRouter := toolRouterState + defer func() { toolRouterState = oldRouter }() + + ai, err := NewOpenAIState([]OpenAIConfig{ + {Name: "chat", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "chat-model", Timeout: 1, Active: true}, + {Name: "router", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "router-model", Timeout: 1}, + }) + if err != nil { + t.Fatal(err) + } + toolRouterState = &ToolRouterState{ai: ai, cfg: &ToolRouterConfig{ + Enabled: true, + OpenAIName: "router", + Timeout: 1, + MaxTokens: 128, + SystemPrompt: "use tools", + Tools: []ToolRouteConfig{{Name: "time", Enabled: true}}, + }} + toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) { + if profile.Config.Name != "router" || req.Model != "router-model" { + t.Fatalf("router profile not used: profile=%s model=%s", profile.Config.Name, req.Model) + } + return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil + } + + _, err = runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Name: "chat", Model: "chat-model"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil) + if err != nil { + t.Fatal(err) + } +}