diff --git a/main.go b/main.go
index fe77e8a..d8e9c09 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,8 @@
package main
import (
+ "bufio"
+ "bytes"
"context"
"crypto/rand"
"encoding/base64"
@@ -46,12 +48,13 @@ const (
)
type OpenAIConfig struct {
- Name string `yaml:"name" json:"name"`
- Active bool `yaml:"active,omitempty" json:"active"`
- APIKey string `yaml:"api_key" json:"-"`
- BaseURL string `yaml:"base_url" json:"base_url"`
- Model string `yaml:"model" json:"model"`
- Timeout int `yaml:"timeout" json:"timeout"`
+ Name string `yaml:"name" json:"name"`
+ Active bool `yaml:"active,omitempty" json:"active"`
+ APIKey string `yaml:"api_key" json:"-"`
+ BaseURL string `yaml:"base_url" json:"base_url"`
+ Model string `yaml:"model" json:"model"`
+ Timeout int `yaml:"timeout" json:"timeout"`
+ ParseThinkTags *bool `yaml:"parse_think_tags,omitempty" json:"parse_think_tags,omitempty"`
}
type OpenAIConfigs []OpenAIConfig
@@ -559,6 +562,44 @@ func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
return config
}
+func (s *ToolRouterState) RouterProfile(fallback *OpenAIProfile) *OpenAIProfile {
+ if s == nil || s.cfg == nil || s.ai == nil {
+ return fallback
+ }
+ name := strings.TrimSpace(s.cfg.OpenAIName)
+ if name == "" {
+ return fallback
+ }
+ profile, err := s.ai.GetProfile(name)
+ if err != nil {
+ return fallback
+ }
+ return profile
+}
+
+func isOllamaProfile(profile *OpenAIProfile) bool {
+ if profile == nil {
+ return false
+ }
+ u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL))
+ if err != nil {
+ return strings.Contains(profile.Config.BaseURL, ":11434")
+ }
+ host := strings.ToLower(u.Hostname())
+ port := u.Port()
+ return port == "11434" && (host == "127.0.0.1" || host == "localhost" || host == "::1")
+}
+
+func shouldParseThinkTags(profile *OpenAIProfile) bool {
+ if profile == nil {
+ return false
+ }
+ if profile.Config.ParseThinkTags != nil {
+ return *profile.Config.ParseThinkTags
+ }
+ return isOllamaProfile(profile)
+}
+
// ─── 全局变量 ─────────────────────────────────────────────
var (
@@ -810,6 +851,21 @@ func chatHandler(c *gin.Context) {
}
promptTokens := estimateChatMessagesTokens(req.Messages)
+ if isOllamaProfile(profile) && hasImageMessage(req.Messages) {
+ emitTrace("model", "request", "running", "正在通过 Ollama 原生接口调用视觉模型", nil)
+ err = streamOllamaChat(ctx, profile, messages, promptTokens, usage, emit, func(content string) {
+ if req.ConversationID != "" {
+ if err := saveConversationMessages(req.ConversationID, req.Messages, content); err != nil {
+ fmt.Fprintln(os.Stderr, "保存对话失败:", err)
+ }
+ }
+ })
+ if err != nil {
+ emitError(err)
+ }
+ return
+ }
+
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
@@ -829,9 +885,56 @@ func chatHandler(c *gin.Context) {
windowStarted := streamStarted
windowTokens := 0
peakTokensPerSecond := 0.0
+ parseThinkTags := shouldParseThinkTags(profile)
+ thinkParser := &thinkTagParser{}
+ emitDelta := func(delta string) {
+ if delta == "" {
+ return
+ }
+ now := time.Now()
+ deltaTokens := estimateTokenCount(delta)
+ windowTokens += deltaTokens
+ windowElapsed := now.Sub(windowStarted).Seconds()
+ if windowElapsed >= 1 {
+ windowSpeed := float64(windowTokens) / windowElapsed
+ if windowSpeed > peakTokensPerSecond {
+ peakTokensPerSecond = windowSpeed
+ }
+ windowStarted = now
+ windowTokens = 0
+ } else if peakTokensPerSecond == 0 && windowElapsed > 0.25 {
+ peakTokensPerSecond = float64(windowTokens) / windowElapsed
+ }
+ full.WriteString(delta)
+ completionTokens += deltaTokens
+ usage.setModel(promptTokens, completionTokens)
+ stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
+ emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
+ }
+ emitModelContent := func(delta string) {
+ if delta == "" {
+ return
+ }
+ if !parseThinkTags {
+ emitDelta(delta)
+ return
+ }
+ visible, reasoning := thinkParser.Accept(delta)
+ if reasoning != "" {
+ emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
+ }
+ emitDelta(visible)
+ }
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
+ if parseThinkTags {
+ visible, reasoning := thinkParser.Flush()
+ if reasoning != "" {
+ emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
+ }
+ emitDelta(visible)
+ }
usage.setModel(promptTokens, completionTokens)
if windowTokens > 0 {
windowElapsed := time.Since(windowStarted).Seconds()
@@ -862,28 +965,7 @@ func chatHandler(c *gin.Context) {
return
}
if len(resp.Choices) > 0 {
- delta := resp.Choices[0].Delta.Content
- if delta != "" {
- now := time.Now()
- deltaTokens := estimateTokenCount(delta)
- windowTokens += deltaTokens
- windowElapsed := now.Sub(windowStarted).Seconds()
- if windowElapsed >= 1 {
- windowSpeed := float64(windowTokens) / windowElapsed
- if windowSpeed > peakTokensPerSecond {
- peakTokensPerSecond = windowSpeed
- }
- windowStarted = now
- windowTokens = 0
- } else if peakTokensPerSecond == 0 && windowElapsed > 0.25 {
- peakTokensPerSecond = float64(windowTokens) / windowElapsed
- }
- full.WriteString(delta)
- completionTokens += deltaTokens
- usage.setModel(promptTokens, completionTokens)
- stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
- emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
- }
+ emitModelContent(resp.Choices[0].Delta.Content)
// 思考过程 reasoning_content 单独事件推送
if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" {
emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent})
@@ -1035,13 +1117,27 @@ func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agen
}
func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) {
- messages, err := buildArkMessages(chatMessages)
+ finalMessages, err := buildArkMessages(chatMessages)
if err != nil {
return nil, err
}
- tools := availableAgentTools(profile, emit)
+ routerProfile := profile
+ if toolRouterState != nil {
+ routerProfile = toolRouterState.RouterProfile(profile)
+ }
+ tools := availableAgentTools(routerProfile, emit)
if len(tools) == 0 {
- return messages, nil
+ return finalMessages, nil
+ }
+ decisionMessages := append([]*model.ChatCompletionMessage(nil), finalMessages...)
+ if hasImageMessage(chatMessages) {
+ decisionMessages, err = buildToolDecisionMessages(chatMessages)
+ if err != nil {
+ return nil, err
+ }
+ if emit != nil {
+ emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "检测到图片输入,工具判断阶段将使用纯文本上下文"})
+ }
}
toolByName := make(map[string]agentTool, len(tools))
definitions := make([]*model.Tool, 0, len(tools))
@@ -1059,28 +1155,30 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}})
}
if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" {
- messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...)
+ systemMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}
+ finalMessages = append([]*model.ChatCompletionMessage{systemMessage}, finalMessages...)
+ decisionMessages = append([]*model.ChatCompletionMessage{systemMessage}, decisionMessages...)
}
for i := 0; i < maxAgentToolIterations; i++ {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}})
}
- resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{
- Model: profile.Config.Model,
- Messages: messages,
+ resp, err := toolRouterState.complete(ctx, routerProfile, model.CreateChatCompletionRequest{
+ Model: routerProfile.Config.Model,
+ Messages: decisionMessages,
MaxTokens: intPtr(toolRouterState.cfg.MaxTokens),
Tools: definitions,
ToolChoice: model.ToolChoiceStringTypeAuto,
ParallelToolCalls: boolPtr(false),
}, time.Duration(toolRouterState.cfg.Timeout)*time.Second)
if err != nil {
- return messages, err
+ return finalMessages, err
}
if tracker := tokenUsageFromContext(ctx); tracker != nil {
tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens)
}
if len(resp.Choices) == 0 {
- return messages, nil
+ return finalMessages, nil
}
choice := resp.Choices[0]
decisionPreview := chatMessageContentString(choice.Message.Content)
@@ -1095,7 +1193,7 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"})
}
- return messages, nil
+ return finalMessages, nil
}
callNames := make([]string, 0, len(calls))
for _, call := range calls {
@@ -1106,14 +1204,89 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}})
}
- messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content})
+ assistantMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content}
+ finalMessages = append(finalMessages, assistantMessage)
+ decisionMessages = append(decisionMessages, assistantMessage)
for _, call := range calls {
result := executeAgentToolCall(ctx, call, toolByName, emit)
- messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)})
+ toolMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)}
+ finalMessages = append(finalMessages, toolMessage)
+ decisionMessages = append(decisionMessages, toolMessage)
}
}
- messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")})
- return messages, nil
+ limitMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")}
+ finalMessages = append(finalMessages, limitMessage)
+ return finalMessages, nil
+}
+
+type thinkTagParser struct {
+ inThink bool
+ buffer string
+}
+
+const (
+ thinkOpenTag = "<think>"
+ thinkCloseTag = "</think>"
+)
+
+func (p *thinkTagParser) Accept(delta string) (visible string, reasoning string) {
+ p.buffer += delta
+ for p.buffer != "" {
+ if p.inThink {
+ idx := strings.Index(p.buffer, thinkCloseTag)
+ if idx >= 0 {
+ reasoning += p.buffer[:idx]
+ p.buffer = p.buffer[idx+len(thinkCloseTag):]
+ p.inThink = false
+ continue
+ }
+ keep := tagPrefixSuffixLen(p.buffer, thinkCloseTag)
+ if len(p.buffer) > keep {
+ reasoning += p.buffer[:len(p.buffer)-keep]
+ p.buffer = p.buffer[len(p.buffer)-keep:]
+ }
+ return visible, reasoning
+ }
+
+ idx := strings.Index(p.buffer, thinkOpenTag)
+ if idx >= 0 {
+ visible += p.buffer[:idx]
+ p.buffer = p.buffer[idx+len(thinkOpenTag):]
+ p.inThink = true
+ continue
+ }
+ keep := tagPrefixSuffixLen(p.buffer, thinkOpenTag)
+ if len(p.buffer) > keep {
+ visible += p.buffer[:len(p.buffer)-keep]
+ p.buffer = p.buffer[len(p.buffer)-keep:]
+ }
+ return visible, reasoning
+ }
+ return visible, reasoning
+}
+
+func (p *thinkTagParser) Flush() (visible string, reasoning string) {
+ if p.inThink {
+ reasoning = p.buffer
+ } else {
+ visible = p.buffer
+ }
+ p.buffer = ""
+ p.inThink = false
+ return visible, reasoning
+}
+
+func tagPrefixSuffixLen(text, tag string) int {
+ limit := len(tag) - 1
+ if len(text) < limit {
+ limit = len(text)
+ }
+ for i := limit; i > 0; i-- {
+ if strings.HasPrefix(tag, text[len(text)-i:]) {
+ return i
+ }
+ }
+ return 0
}
func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string {
@@ -1155,6 +1328,223 @@ func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[s
return result
}
+type ollamaChatRequest struct {
+ Model string `json:"model"`
+ Messages []ollamaChatMessage `json:"messages"`
+ Stream bool `json:"stream"`
+ Options map[string]int `json:"options,omitempty"`
+}
+
+type ollamaChatMessage struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ Images []string `json:"images,omitempty"`
+}
+
+type ollamaChatResponse struct {
+ Message struct {
+ Role string `json:"role"`
+ Content string `json:"content"`
+ Thinking string `json:"thinking"`
+ } `json:"message"`
+ Done bool `json:"done"`
+ PromptEvalCount int `json:"prompt_eval_count"`
+ EvalCount int `json:"eval_count"`
+ DoneReason string `json:"done_reason"`
+}
+
+func streamOllamaChat(ctx context.Context, profile *OpenAIProfile, messages []*model.ChatCompletionMessage, promptTokens int, usage *tokenUsageTracker, emit func(chatSSEFrame), onDone func(string)) error {
+ requestMessages, err := buildOllamaMessages(messages)
+ if err != nil {
+ return err
+ }
+ baseURL, err := ollamaBaseURL(profile)
+ if err != nil {
+ return err
+ }
+ body, err := json.Marshal(ollamaChatRequest{
+ Model: profile.Config.Model,
+ Messages: requestMessages,
+ Stream: true,
+ Options: map[string]int{"num_predict": 4096},
+ })
+ if err != nil {
+ return err
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(baseURL, "/")+"/api/chat", bytes.NewReader(body))
+ if err != nil {
+ return err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
+ return fmt.Errorf("Ollama 原生接口调用失败: %s %s", resp.Status, strings.TrimSpace(string(data)))
+ }
+
+ emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "running", Message: "Ollama 视觉模型已开始输出"})
+ parseThinkTags := shouldParseThinkTags(profile)
+ thinkParser := &thinkTagParser{}
+ var full strings.Builder
+ completionTokens := 0
+ streamStarted := time.Now()
+ peakTokensPerSecond := 0.0
+ emitDelta := func(delta string) {
+ if delta == "" {
+ return
+ }
+ full.WriteString(delta)
+ completionTokens += estimateTokenCount(delta)
+ usage.setModel(promptTokens, completionTokens)
+ currentSpeed := tokensPerSecond(completionTokens, streamStarted)
+ if currentSpeed > peakTokensPerSecond {
+ peakTokensPerSecond = currentSpeed
+ }
+ stats := usage.snapshot(currentSpeed, peakTokensPerSecond)
+ emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
+ }
+ emitContent := func(delta string) {
+ if delta == "" {
+ return
+ }
+ if !parseThinkTags {
+ emitDelta(delta)
+ return
+ }
+ visible, reasoning := thinkParser.Accept(delta)
+ if reasoning != "" {
+ emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
+ }
+ emitDelta(visible)
+ }
+
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
+ for scanner.Scan() {
+ line := strings.TrimSpace(scanner.Text())
+ if line == "" {
+ continue
+ }
+ var chunk ollamaChatResponse
+ if err := json.Unmarshal([]byte(line), &chunk); err != nil {
+ return fmt.Errorf("解析 Ollama 流失败: %w", err)
+ }
+ if chunk.Message.Thinking != "" {
+ emit(chatSSEFrame{Type: "reasoning", Text: chunk.Message.Thinking})
+ }
+ emitContent(chunk.Message.Content)
+ if chunk.Done {
+ if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 {
+ usage.setModel(chunk.PromptEvalCount, chunk.EvalCount)
+ }
+ break
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ return err
+ }
+ if parseThinkTags {
+ visible, reasoning := thinkParser.Flush()
+ if reasoning != "" {
+ emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
+ }
+ emitDelta(visible)
+ }
+ if onDone != nil {
+ onDone(full.String())
+ }
+ finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
+ emit(chatSSEFrame{Type: "stats", Stats: &finalStats})
+ emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "success", Message: "回答生成完成"})
+ return nil
+}
+
+func buildOllamaMessages(messages []*model.ChatCompletionMessage) ([]ollamaChatMessage, error) {
+ result := make([]ollamaChatMessage, 0, len(messages))
+ for _, msg := range messages {
+ if msg == nil {
+ continue
+ }
+ role := string(msg.Role)
+ if msg.Role == model.ChatMessageRoleTool {
+ role = string(model.ChatMessageRoleUser)
+ }
+ item := ollamaChatMessage{Role: role}
+ if msg.Content == nil {
+ if len(msg.ToolCalls) > 0 {
+ continue
+ }
+ result = append(result, item)
+ continue
+ }
+ if msg.Content.StringValue != nil {
+ item.Content = *msg.Content.StringValue
+ if msg.Role == model.ChatMessageRoleTool {
+ item.Content = "工具结果:\n" + item.Content
+ }
+ result = append(result, item)
+ continue
+ }
+ for _, part := range msg.Content.ListValue {
+ if part == nil {
+ continue
+ }
+ switch part.Type {
+ case model.ChatCompletionMessageContentPartTypeText:
+ if part.Text != "" {
+ if item.Content != "" {
+ item.Content += "\n"
+ }
+ item.Content += part.Text
+ }
+ case model.ChatCompletionMessageContentPartTypeImageURL:
+ if part.ImageURL == nil {
+ continue
+ }
+ image, err := ollamaImagePayload(part.ImageURL.URL)
+ if err != nil {
+ return nil, err
+ }
+ item.Images = append(item.Images, image)
+ }
+ }
+ result = append(result, item)
+ }
+ return result, nil
+}
+
+func ollamaImagePayload(raw string) (string, error) {
+ raw = strings.TrimSpace(raw)
+ if strings.HasPrefix(strings.ToLower(raw), "data:") {
+ comma := strings.Index(raw, ",")
+ if comma < 0 {
+ return "", errors.New("图片 base64 数据格式错误")
+ }
+ return strings.TrimSpace(raw[comma+1:]), nil
+ }
+ return raw, nil
+}
+
+func ollamaBaseURL(profile *OpenAIProfile) (string, error) {
+ if profile == nil {
+ return "", errors.New("Ollama 配置为空")
+ }
+ u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL))
+ if err != nil {
+ return "", err
+ }
+ if strings.TrimRight(u.Path, "/") == "/v1" {
+ u.Path = strings.TrimSuffix(strings.TrimRight(u.Path, "/"), "/v1")
+ }
+ u.RawQuery = ""
+ u.Fragment = ""
+ return strings.TrimRight(u.String(), "/"), nil
+}
+
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, time.Duration(profile.Config.Timeout)*time.Second)
}
@@ -1178,10 +1568,23 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
promptTokens := estimateChatMessagesTokens(chatMessages)
completionTokens := 0
+ parseThinkTags := shouldParseThinkTags(profile)
+ thinkParser := &thinkTagParser{}
var b strings.Builder
+ appendVisible := func(delta string) {
+ if delta == "" {
+ return
+ }
+ b.WriteString(delta)
+ completionTokens += estimateTokenCount(delta)
+ }
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
+ if parseThinkTags {
+ visible, _ := thinkParser.Flush()
+ appendVisible(visible)
+ }
if tracker := tokenUsageFromContext(ctx); tracker != nil {
tracker.addTool(promptTokens, completionTokens)
}
@@ -1192,8 +1595,12 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
}
if len(resp.Choices) > 0 {
delta := resp.Choices[0].Delta.Content
- b.WriteString(delta)
- completionTokens += estimateTokenCount(delta)
+ if parseThinkTags {
+ visible, _ := thinkParser.Accept(delta)
+ appendVisible(visible)
+ } else {
+ appendVisible(delta)
+ }
}
}
}
@@ -1368,6 +1775,33 @@ func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessag
return messages, nil
}
+func hasImageMessage(messages []ChatMessage) bool {
+ for _, msg := range messages {
+ if strings.TrimSpace(msg.ImageURL) != "" || strings.TrimSpace(msg.ImageURLAlias) != "" {
+ return true
+ }
+ }
+ return false
+}
+
+func buildToolDecisionMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
+ messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages))
+ for _, m := range chatMessages {
+ content := m.Content
+ if strings.TrimSpace(m.ImageURL) != "" || strings.TrimSpace(m.ImageURLAlias) != "" {
+ content = strings.TrimSpace(content)
+ placeholder := "[用户上传了一张图片。工具判断阶段不读取图片内容;如果问题主要依赖识图,应不要调用工具,交给最终多模态模型回答。]"
+ if content == "" {
+ content = placeholder
+ } else {
+ content += "\n\n" + placeholder
+ }
+ }
+ messages = append(messages, &model.ChatCompletionMessage{Role: m.Role, Content: stringContent(content)})
+ }
+ return messages, nil
+}
+
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
msg := &model.ChatCompletionMessage{Role: m.Role}
@@ -1388,11 +1822,12 @@ func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
}
// 有图片时:文字内容可有可无(图片 caption 场景),均构造多模态消息
- // 若无文字,则只传图片 part;若同时有图片和文字,先图后文
- parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)}
+ // 若无文字,则只传图片 part;若同时有图片和文字,先文后图
+ parts := make([]*model.ChatCompletionMessageContentPart, 0, 2)
if m.Content != "" {
parts = append(parts, textPart(m.Content))
}
+ parts = append(parts, imagePart(imageURL))
msg.Content = &model.ChatCompletionMessageContent{ListValue: parts}
return msg, nil
}
diff --git a/main_test.go b/main_test.go
index cefebeb..2ee35fb 100644
--- a/main_test.go
+++ b/main_test.go
@@ -189,3 +189,176 @@ func TestRunAgentToolLoopMaxIterations(t *testing.T) {
t.Fatalf("unexpected last message: %#v", last)
}
}
+
+func TestBuildArkMessageImageTextOrder(t *testing.T) {
+ msg, err := buildArkMessage(ChatMessage{Role: "user", Content: "请描述图片", ImageURL: "data:image/png;base64,aGVsbG8="})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if msg.Content == nil || len(msg.Content.ListValue) != 2 {
+ t.Fatalf("unexpected content: %#v", msg.Content)
+ }
+ if msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeText || msg.Content.ListValue[0].Text != "请描述图片" {
+ t.Fatalf("first part should be text: %#v", msg.Content.ListValue[0])
+ }
+ if msg.Content.ListValue[1].Type != model.ChatCompletionMessageContentPartTypeImageURL || msg.Content.ListValue[1].ImageURL == nil {
+ t.Fatalf("second part should be image: %#v", msg.Content.ListValue[1])
+ }
+}
+
+func TestBuildArkMessageImageOnly(t *testing.T) {
+ msg, err := buildArkMessage(ChatMessage{Role: "user", ImageURL: "data:image/png;base64,aGVsbG8="})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if msg.Content == nil || len(msg.Content.ListValue) != 1 || msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeImageURL {
+ t.Fatalf("unexpected content: %#v", msg.Content)
+ }
+}
+
+func TestThinkTagParserSingleChunk(t *testing.T) {
+ parser := &thinkTagParser{}
+ visible, reasoning := parser.Accept("hello <think>abc</think>world")
+ flushVisible, flushReasoning := parser.Flush()
+ visible += flushVisible
+ reasoning += flushReasoning
+ if visible != "hello world" || reasoning != "abc" {
+ t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
+ }
+}
+
+func TestThinkTagParserAcrossChunks(t *testing.T) {
+ parser := &thinkTagParser{}
+ var visible, reasoning string
+ for _, chunk := range []string{"hello <th", "ink>abc</thi", "nk>world"} {
+ v, r := parser.Accept(chunk)
+ visible += v
+ reasoning += r
+ }
+ v, r := parser.Flush()
+ visible += v
+ reasoning += r
+ if visible != "hello world" || reasoning != "abc" {
+ t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
+ }
+}
+
+func TestThinkTagParserUnclosedThink(t *testing.T) {
+ parser := &thinkTagParser{}
+ visible, reasoning := parser.Accept("answer <think>still thinking")
+ v, r := parser.Flush()
+ visible += v
+ reasoning += r
+ if visible != "answer " || reasoning != "still thinking" {
+ t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
+ }
+}
+
+func TestShouldParseThinkTags(t *testing.T) {
+ if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1"}}) {
+ t.Fatal("expected local ollama to parse think tags")
+ }
+ if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL}}) {
+ t.Fatal("expected remote profile not to parse think tags by default")
+ }
+ falseValue := false
+ if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1", ParseThinkTags: &falseValue}}) {
+ t.Fatal("explicit false should disable think parsing")
+ }
+ trueValue := true
+ if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL, ParseThinkTags: &trueValue}}) {
+ t.Fatal("explicit true should enable think parsing")
+ }
+}
+
+func TestBuildToolDecisionMessagesRemovesImages(t *testing.T) {
+ messages, err := buildToolDecisionMessages([]ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(messages) != 1 || messages[0].Content == nil || messages[0].Content.StringValue == nil {
+ t.Fatalf("unexpected messages: %#v", messages)
+ }
+ if !strings.Contains(*messages[0].Content.StringValue, "工具判断阶段不读取图片内容") {
+ t.Fatalf("missing image placeholder: %q", *messages[0].Content.StringValue)
+ }
+ if messages[0].Content.ListValue != nil {
+ t.Fatalf("decision message should be text-only: %#v", messages[0].Content)
+ }
+}
+
+func TestRunAgentToolLoopImageUsesTextOnlyDecisionMessages(t *testing.T) {
+ oldRouter := toolRouterState
+ defer func() { toolRouterState = oldRouter }()
+
+ toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
+ Enabled: true,
+ Timeout: 1,
+ MaxTokens: 128,
+ SystemPrompt: "use tools",
+ Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
+ }}
+ toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
+ for _, msg := range req.Messages {
+ if msg.Content != nil && len(msg.Content.ListValue) > 0 {
+ t.Fatalf("tool decision should not receive multimodal content: %#v", msg.Content)
+ }
+ }
+ joined := ""
+ for _, msg := range req.Messages {
+ if msg.Content != nil && msg.Content.StringValue != nil {
+ joined += *msg.Content.StringValue
+ }
+ }
+ if !strings.Contains(joined, "工具判断阶段不读取图片内容") {
+ t.Fatalf("missing placeholder in decision messages: %q", joined)
+ }
+ return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil
+ }
+
+ messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "chat"}}, []ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ foundImage := false
+ for _, msg := range messages {
+ if msg.Content != nil && len(msg.Content.ListValue) > 0 {
+ foundImage = true
+ }
+ }
+ if !foundImage {
+ t.Fatalf("final messages should retain image: %#v", messages)
+ }
+}
+
+func TestRunAgentToolLoopUsesConfiguredRouterProfile(t *testing.T) {
+ oldRouter := toolRouterState
+ defer func() { toolRouterState = oldRouter }()
+
+ ai, err := NewOpenAIState([]OpenAIConfig{
+ {Name: "chat", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "chat-model", Timeout: 1, Active: true},
+ {Name: "router", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "router-model", Timeout: 1},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ toolRouterState = &ToolRouterState{ai: ai, cfg: &ToolRouterConfig{
+ Enabled: true,
+ OpenAIName: "router",
+ Timeout: 1,
+ MaxTokens: 128,
+ SystemPrompt: "use tools",
+ Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
+ }}
+ toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
+ if profile.Config.Name != "router" || req.Model != "router-model" {
+ t.Fatalf("router profile not used: profile=%s model=%s", profile.Config.Name, req.Model)
+ }
+ return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil
+ }
+
+ _, err = runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Name: "chat", Model: "chat-model"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+}