up
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
@@ -46,12 +48,13 @@ const (
|
||||
)
|
||||
|
||||
type OpenAIConfig struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Active bool `yaml:"active,omitempty" json:"active"`
|
||||
APIKey string `yaml:"api_key" json:"-"`
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
Timeout int `yaml:"timeout" json:"timeout"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Active bool `yaml:"active,omitempty" json:"active"`
|
||||
APIKey string `yaml:"api_key" json:"-"`
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
Timeout int `yaml:"timeout" json:"timeout"`
|
||||
ParseThinkTags *bool `yaml:"parse_think_tags,omitempty" json:"parse_think_tags,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIConfigs []OpenAIConfig
|
||||
@@ -559,6 +562,44 @@ func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
|
||||
return config
|
||||
}
|
||||
|
||||
func (s *ToolRouterState) RouterProfile(fallback *OpenAIProfile) *OpenAIProfile {
|
||||
if s == nil || s.cfg == nil || s.ai == nil {
|
||||
return fallback
|
||||
}
|
||||
name := strings.TrimSpace(s.cfg.OpenAIName)
|
||||
if name == "" {
|
||||
return fallback
|
||||
}
|
||||
profile, err := s.ai.GetProfile(name)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
return profile
|
||||
}
|
||||
|
||||
func isOllamaProfile(profile *OpenAIProfile) bool {
|
||||
if profile == nil {
|
||||
return false
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL))
|
||||
if err != nil {
|
||||
return strings.Contains(profile.Config.BaseURL, ":11434")
|
||||
}
|
||||
host := strings.ToLower(u.Hostname())
|
||||
port := u.Port()
|
||||
return port == "11434" && (host == "127.0.0.1" || host == "localhost" || host == "::1")
|
||||
}
|
||||
|
||||
func shouldParseThinkTags(profile *OpenAIProfile) bool {
|
||||
if profile == nil {
|
||||
return false
|
||||
}
|
||||
if profile.Config.ParseThinkTags != nil {
|
||||
return *profile.Config.ParseThinkTags
|
||||
}
|
||||
return isOllamaProfile(profile)
|
||||
}
|
||||
|
||||
// ─── 全局变量 ─────────────────────────────────────────────
|
||||
|
||||
var (
|
||||
@@ -810,6 +851,21 @@ func chatHandler(c *gin.Context) {
|
||||
}
|
||||
promptTokens := estimateChatMessagesTokens(req.Messages)
|
||||
|
||||
if isOllamaProfile(profile) && hasImageMessage(req.Messages) {
|
||||
emitTrace("model", "request", "running", "正在通过 Ollama 原生接口调用视觉模型", nil)
|
||||
err = streamOllamaChat(ctx, profile, messages, promptTokens, usage, emit, func(content string) {
|
||||
if req.ConversationID != "" {
|
||||
if err := saveConversationMessages(req.ConversationID, req.Messages, content); err != nil {
|
||||
fmt.Fprintln(os.Stderr, "保存对话失败:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
emitError(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
|
||||
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
|
||||
Model: profile.Config.Model,
|
||||
@@ -829,9 +885,56 @@ func chatHandler(c *gin.Context) {
|
||||
windowStarted := streamStarted
|
||||
windowTokens := 0
|
||||
peakTokensPerSecond := 0.0
|
||||
parseThinkTags := shouldParseThinkTags(profile)
|
||||
thinkParser := &thinkTagParser{}
|
||||
emitDelta := func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
deltaTokens := estimateTokenCount(delta)
|
||||
windowTokens += deltaTokens
|
||||
windowElapsed := now.Sub(windowStarted).Seconds()
|
||||
if windowElapsed >= 1 {
|
||||
windowSpeed := float64(windowTokens) / windowElapsed
|
||||
if windowSpeed > peakTokensPerSecond {
|
||||
peakTokensPerSecond = windowSpeed
|
||||
}
|
||||
windowStarted = now
|
||||
windowTokens = 0
|
||||
} else if peakTokensPerSecond == 0 && windowElapsed > 0.25 {
|
||||
peakTokensPerSecond = float64(windowTokens) / windowElapsed
|
||||
}
|
||||
full.WriteString(delta)
|
||||
completionTokens += deltaTokens
|
||||
usage.setModel(promptTokens, completionTokens)
|
||||
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
|
||||
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
|
||||
}
|
||||
emitModelContent := func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
if !parseThinkTags {
|
||||
emitDelta(delta)
|
||||
return
|
||||
}
|
||||
visible, reasoning := thinkParser.Accept(delta)
|
||||
if reasoning != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
|
||||
}
|
||||
emitDelta(visible)
|
||||
}
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
if parseThinkTags {
|
||||
visible, reasoning := thinkParser.Flush()
|
||||
if reasoning != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
|
||||
}
|
||||
emitDelta(visible)
|
||||
}
|
||||
usage.setModel(promptTokens, completionTokens)
|
||||
if windowTokens > 0 {
|
||||
windowElapsed := time.Since(windowStarted).Seconds()
|
||||
@@ -862,28 +965,7 @@ func chatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
delta := resp.Choices[0].Delta.Content
|
||||
if delta != "" {
|
||||
now := time.Now()
|
||||
deltaTokens := estimateTokenCount(delta)
|
||||
windowTokens += deltaTokens
|
||||
windowElapsed := now.Sub(windowStarted).Seconds()
|
||||
if windowElapsed >= 1 {
|
||||
windowSpeed := float64(windowTokens) / windowElapsed
|
||||
if windowSpeed > peakTokensPerSecond {
|
||||
peakTokensPerSecond = windowSpeed
|
||||
}
|
||||
windowStarted = now
|
||||
windowTokens = 0
|
||||
} else if peakTokensPerSecond == 0 && windowElapsed > 0.25 {
|
||||
peakTokensPerSecond = float64(windowTokens) / windowElapsed
|
||||
}
|
||||
full.WriteString(delta)
|
||||
completionTokens += deltaTokens
|
||||
usage.setModel(promptTokens, completionTokens)
|
||||
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
|
||||
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
|
||||
}
|
||||
emitModelContent(resp.Choices[0].Delta.Content)
|
||||
// 思考过程 reasoning_content 单独事件推送
|
||||
if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent})
|
||||
@@ -1035,13 +1117,27 @@ func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agen
|
||||
}
|
||||
|
||||
func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) {
|
||||
messages, err := buildArkMessages(chatMessages)
|
||||
finalMessages, err := buildArkMessages(chatMessages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools := availableAgentTools(profile, emit)
|
||||
routerProfile := profile
|
||||
if toolRouterState != nil {
|
||||
routerProfile = toolRouterState.RouterProfile(profile)
|
||||
}
|
||||
tools := availableAgentTools(routerProfile, emit)
|
||||
if len(tools) == 0 {
|
||||
return messages, nil
|
||||
return finalMessages, nil
|
||||
}
|
||||
decisionMessages := append([]*model.ChatCompletionMessage(nil), finalMessages...)
|
||||
if hasImageMessage(chatMessages) {
|
||||
decisionMessages, err = buildToolDecisionMessages(chatMessages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "检测到图片输入,工具判断阶段将使用纯文本上下文"})
|
||||
}
|
||||
}
|
||||
toolByName := make(map[string]agentTool, len(tools))
|
||||
definitions := make([]*model.Tool, 0, len(tools))
|
||||
@@ -1059,28 +1155,30 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}})
|
||||
}
|
||||
if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" {
|
||||
messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...)
|
||||
systemMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}
|
||||
finalMessages = append([]*model.ChatCompletionMessage{systemMessage}, finalMessages...)
|
||||
decisionMessages = append([]*model.ChatCompletionMessage{systemMessage}, decisionMessages...)
|
||||
}
|
||||
for i := 0; i < maxAgentToolIterations; i++ {
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}})
|
||||
}
|
||||
resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{
|
||||
Model: profile.Config.Model,
|
||||
Messages: messages,
|
||||
resp, err := toolRouterState.complete(ctx, routerProfile, model.CreateChatCompletionRequest{
|
||||
Model: routerProfile.Config.Model,
|
||||
Messages: decisionMessages,
|
||||
MaxTokens: intPtr(toolRouterState.cfg.MaxTokens),
|
||||
Tools: definitions,
|
||||
ToolChoice: model.ToolChoiceStringTypeAuto,
|
||||
ParallelToolCalls: boolPtr(false),
|
||||
}, time.Duration(toolRouterState.cfg.Timeout)*time.Second)
|
||||
if err != nil {
|
||||
return messages, err
|
||||
return finalMessages, err
|
||||
}
|
||||
if tracker := tokenUsageFromContext(ctx); tracker != nil {
|
||||
tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return messages, nil
|
||||
return finalMessages, nil
|
||||
}
|
||||
choice := resp.Choices[0]
|
||||
decisionPreview := chatMessageContentString(choice.Message.Content)
|
||||
@@ -1095,7 +1193,7 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"})
|
||||
}
|
||||
return messages, nil
|
||||
return finalMessages, nil
|
||||
}
|
||||
callNames := make([]string, 0, len(calls))
|
||||
for _, call := range calls {
|
||||
@@ -1106,14 +1204,89 @@ func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}})
|
||||
}
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content})
|
||||
assistantMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content}
|
||||
finalMessages = append(finalMessages, assistantMessage)
|
||||
decisionMessages = append(decisionMessages, assistantMessage)
|
||||
for _, call := range calls {
|
||||
result := executeAgentToolCall(ctx, call, toolByName, emit)
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)})
|
||||
toolMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)}
|
||||
finalMessages = append(finalMessages, toolMessage)
|
||||
decisionMessages = append(decisionMessages, toolMessage)
|
||||
}
|
||||
}
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")})
|
||||
return messages, nil
|
||||
limitMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")}
|
||||
finalMessages = append(finalMessages, limitMessage)
|
||||
return finalMessages, nil
|
||||
}
|
||||
|
||||
type thinkTagParser struct {
|
||||
inThink bool
|
||||
buffer string
|
||||
}
|
||||
|
||||
const (
|
||||
thinkOpenTag = "<think>"
|
||||
thinkCloseTag = "</think>"
|
||||
)
|
||||
|
||||
func (p *thinkTagParser) Accept(delta string) (visible string, reasoning string) {
|
||||
p.buffer += delta
|
||||
for p.buffer != "" {
|
||||
if p.inThink {
|
||||
idx := strings.Index(p.buffer, thinkCloseTag)
|
||||
if idx >= 0 {
|
||||
reasoning += p.buffer[:idx]
|
||||
p.buffer = p.buffer[idx+len(thinkCloseTag):]
|
||||
p.inThink = false
|
||||
continue
|
||||
}
|
||||
keep := tagPrefixSuffixLen(p.buffer, thinkCloseTag)
|
||||
if len(p.buffer) > keep {
|
||||
reasoning += p.buffer[:len(p.buffer)-keep]
|
||||
p.buffer = p.buffer[len(p.buffer)-keep:]
|
||||
}
|
||||
return visible, reasoning
|
||||
}
|
||||
|
||||
idx := strings.Index(p.buffer, thinkOpenTag)
|
||||
if idx >= 0 {
|
||||
visible += p.buffer[:idx]
|
||||
p.buffer = p.buffer[idx+len(thinkOpenTag):]
|
||||
p.inThink = true
|
||||
continue
|
||||
}
|
||||
keep := tagPrefixSuffixLen(p.buffer, thinkOpenTag)
|
||||
if len(p.buffer) > keep {
|
||||
visible += p.buffer[:len(p.buffer)-keep]
|
||||
p.buffer = p.buffer[len(p.buffer)-keep:]
|
||||
}
|
||||
return visible, reasoning
|
||||
}
|
||||
return visible, reasoning
|
||||
}
|
||||
|
||||
func (p *thinkTagParser) Flush() (visible string, reasoning string) {
|
||||
if p.inThink {
|
||||
reasoning = p.buffer
|
||||
} else {
|
||||
visible = p.buffer
|
||||
}
|
||||
p.buffer = ""
|
||||
p.inThink = false
|
||||
return visible, reasoning
|
||||
}
|
||||
|
||||
func tagPrefixSuffixLen(text, tag string) int {
|
||||
limit := len(tag) - 1
|
||||
if len(text) < limit {
|
||||
limit = len(text)
|
||||
}
|
||||
for i := limit; i > 0; i-- {
|
||||
if strings.HasPrefix(tag, text[len(text)-i:]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string {
|
||||
@@ -1155,6 +1328,223 @@ func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[s
|
||||
return result
|
||||
}
|
||||
|
||||
type ollamaChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ollamaChatMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Options map[string]int `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Images []string `json:"images,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaChatResponse struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
} `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
DoneReason string `json:"done_reason"`
|
||||
}
|
||||
|
||||
func streamOllamaChat(ctx context.Context, profile *OpenAIProfile, messages []*model.ChatCompletionMessage, promptTokens int, usage *tokenUsageTracker, emit func(chatSSEFrame), onDone func(string)) error {
|
||||
requestMessages, err := buildOllamaMessages(messages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
baseURL, err := ollamaBaseURL(profile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
body, err := json.Marshal(ollamaChatRequest{
|
||||
Model: profile.Config.Model,
|
||||
Messages: requestMessages,
|
||||
Stream: true,
|
||||
Options: map[string]int{"num_predict": 4096},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(baseURL, "/")+"/api/chat", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("Ollama 原生接口调用失败: %s %s", resp.Status, strings.TrimSpace(string(data)))
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "running", Message: "Ollama 视觉模型已开始输出"})
|
||||
parseThinkTags := shouldParseThinkTags(profile)
|
||||
thinkParser := &thinkTagParser{}
|
||||
var full strings.Builder
|
||||
completionTokens := 0
|
||||
streamStarted := time.Now()
|
||||
peakTokensPerSecond := 0.0
|
||||
emitDelta := func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
full.WriteString(delta)
|
||||
completionTokens += estimateTokenCount(delta)
|
||||
usage.setModel(promptTokens, completionTokens)
|
||||
currentSpeed := tokensPerSecond(completionTokens, streamStarted)
|
||||
if currentSpeed > peakTokensPerSecond {
|
||||
peakTokensPerSecond = currentSpeed
|
||||
}
|
||||
stats := usage.snapshot(currentSpeed, peakTokensPerSecond)
|
||||
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
|
||||
}
|
||||
emitContent := func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
if !parseThinkTags {
|
||||
emitDelta(delta)
|
||||
return
|
||||
}
|
||||
visible, reasoning := thinkParser.Accept(delta)
|
||||
if reasoning != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
|
||||
}
|
||||
emitDelta(visible)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
var chunk ollamaChatResponse
|
||||
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
||||
return fmt.Errorf("解析 Ollama 流失败: %w", err)
|
||||
}
|
||||
if chunk.Message.Thinking != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: chunk.Message.Thinking})
|
||||
}
|
||||
emitContent(chunk.Message.Content)
|
||||
if chunk.Done {
|
||||
if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 {
|
||||
usage.setModel(chunk.PromptEvalCount, chunk.EvalCount)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
if parseThinkTags {
|
||||
visible, reasoning := thinkParser.Flush()
|
||||
if reasoning != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
|
||||
}
|
||||
emitDelta(visible)
|
||||
}
|
||||
if onDone != nil {
|
||||
onDone(full.String())
|
||||
}
|
||||
finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
|
||||
emit(chatSSEFrame{Type: "stats", Stats: &finalStats})
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "success", Message: "回答生成完成"})
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildOllamaMessages(messages []*model.ChatCompletionMessage) ([]ollamaChatMessage, error) {
|
||||
result := make([]ollamaChatMessage, 0, len(messages))
|
||||
for _, msg := range messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
role := string(msg.Role)
|
||||
if msg.Role == model.ChatMessageRoleTool {
|
||||
role = string(model.ChatMessageRoleUser)
|
||||
}
|
||||
item := ollamaChatMessage{Role: role}
|
||||
if msg.Content == nil {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
continue
|
||||
}
|
||||
result = append(result, item)
|
||||
continue
|
||||
}
|
||||
if msg.Content.StringValue != nil {
|
||||
item.Content = *msg.Content.StringValue
|
||||
if msg.Role == model.ChatMessageRoleTool {
|
||||
item.Content = "工具结果:\n" + item.Content
|
||||
}
|
||||
result = append(result, item)
|
||||
continue
|
||||
}
|
||||
for _, part := range msg.Content.ListValue {
|
||||
if part == nil {
|
||||
continue
|
||||
}
|
||||
switch part.Type {
|
||||
case model.ChatCompletionMessageContentPartTypeText:
|
||||
if part.Text != "" {
|
||||
if item.Content != "" {
|
||||
item.Content += "\n"
|
||||
}
|
||||
item.Content += part.Text
|
||||
}
|
||||
case model.ChatCompletionMessageContentPartTypeImageURL:
|
||||
if part.ImageURL == nil {
|
||||
continue
|
||||
}
|
||||
image, err := ollamaImagePayload(part.ImageURL.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.Images = append(item.Images, image)
|
||||
}
|
||||
}
|
||||
result = append(result, item)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func ollamaImagePayload(raw string) (string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if strings.HasPrefix(strings.ToLower(raw), "data:") {
|
||||
comma := strings.Index(raw, ",")
|
||||
if comma < 0 {
|
||||
return "", errors.New("图片 base64 数据格式错误")
|
||||
}
|
||||
return strings.TrimSpace(raw[comma+1:]), nil
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func ollamaBaseURL(profile *OpenAIProfile) (string, error) {
|
||||
if profile == nil {
|
||||
return "", errors.New("Ollama 配置为空")
|
||||
}
|
||||
u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if strings.TrimRight(u.Path, "/") == "/v1" {
|
||||
u.Path = strings.TrimSuffix(strings.TrimRight(u.Path, "/"), "/v1")
|
||||
}
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
return strings.TrimRight(u.String(), "/"), nil
|
||||
}
|
||||
|
||||
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
|
||||
return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, time.Duration(profile.Config.Timeout)*time.Second)
|
||||
}
|
||||
@@ -1178,10 +1568,23 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
|
||||
|
||||
promptTokens := estimateChatMessagesTokens(chatMessages)
|
||||
completionTokens := 0
|
||||
parseThinkTags := shouldParseThinkTags(profile)
|
||||
thinkParser := &thinkTagParser{}
|
||||
var b strings.Builder
|
||||
appendVisible := func(delta string) {
|
||||
if delta == "" {
|
||||
return
|
||||
}
|
||||
b.WriteString(delta)
|
||||
completionTokens += estimateTokenCount(delta)
|
||||
}
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
if parseThinkTags {
|
||||
visible, _ := thinkParser.Flush()
|
||||
appendVisible(visible)
|
||||
}
|
||||
if tracker := tokenUsageFromContext(ctx); tracker != nil {
|
||||
tracker.addTool(promptTokens, completionTokens)
|
||||
}
|
||||
@@ -1192,8 +1595,12 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
delta := resp.Choices[0].Delta.Content
|
||||
b.WriteString(delta)
|
||||
completionTokens += estimateTokenCount(delta)
|
||||
if parseThinkTags {
|
||||
visible, _ := thinkParser.Accept(delta)
|
||||
appendVisible(visible)
|
||||
} else {
|
||||
appendVisible(delta)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1368,6 +1775,33 @@ func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessag
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func hasImageMessage(messages []ChatMessage) bool {
|
||||
for _, msg := range messages {
|
||||
if strings.TrimSpace(msg.ImageURL) != "" || strings.TrimSpace(msg.ImageURLAlias) != "" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildToolDecisionMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
|
||||
messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages))
|
||||
for _, m := range chatMessages {
|
||||
content := m.Content
|
||||
if strings.TrimSpace(m.ImageURL) != "" || strings.TrimSpace(m.ImageURLAlias) != "" {
|
||||
content = strings.TrimSpace(content)
|
||||
placeholder := "[用户上传了一张图片。工具判断阶段不读取图片内容;如果问题主要依赖识图,应不要调用工具,交给最终多模态模型回答。]"
|
||||
if content == "" {
|
||||
content = placeholder
|
||||
} else {
|
||||
content += "\n\n" + placeholder
|
||||
}
|
||||
}
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: m.Role, Content: stringContent(content)})
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
|
||||
msg := &model.ChatCompletionMessage{Role: m.Role}
|
||||
|
||||
@@ -1388,11 +1822,12 @@ func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
|
||||
}
|
||||
|
||||
// 有图片时:文字内容可有可无(图片 caption 场景),均构造多模态消息
|
||||
// 若无文字,则只传图片 part;若同时有图片和文字,先图后文
|
||||
parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)}
|
||||
// 若无文字,则只传图片 part;若同时有图片和文字,先文后图
|
||||
parts := make([]*model.ChatCompletionMessageContentPart, 0, 2)
|
||||
if m.Content != "" {
|
||||
parts = append(parts, textPart(m.Content))
|
||||
}
|
||||
parts = append(parts, imagePart(imageURL))
|
||||
msg.Content = &model.ChatCompletionMessageContent{ListValue: parts}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
+173
@@ -189,3 +189,176 @@ func TestRunAgentToolLoopMaxIterations(t *testing.T) {
|
||||
t.Fatalf("unexpected last message: %#v", last)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildArkMessageImageTextOrder(t *testing.T) {
|
||||
msg, err := buildArkMessage(ChatMessage{Role: "user", Content: "请描述图片", ImageURL: "data:image/png;base64,aGVsbG8="})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg.Content == nil || len(msg.Content.ListValue) != 2 {
|
||||
t.Fatalf("unexpected content: %#v", msg.Content)
|
||||
}
|
||||
if msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeText || msg.Content.ListValue[0].Text != "请描述图片" {
|
||||
t.Fatalf("first part should be text: %#v", msg.Content.ListValue[0])
|
||||
}
|
||||
if msg.Content.ListValue[1].Type != model.ChatCompletionMessageContentPartTypeImageURL || msg.Content.ListValue[1].ImageURL == nil {
|
||||
t.Fatalf("second part should be image: %#v", msg.Content.ListValue[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildArkMessageImageOnly(t *testing.T) {
|
||||
msg, err := buildArkMessage(ChatMessage{Role: "user", ImageURL: "data:image/png;base64,aGVsbG8="})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if msg.Content == nil || len(msg.Content.ListValue) != 1 || msg.Content.ListValue[0].Type != model.ChatCompletionMessageContentPartTypeImageURL {
|
||||
t.Fatalf("unexpected content: %#v", msg.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkTagParserSingleChunk(t *testing.T) {
|
||||
parser := &thinkTagParser{}
|
||||
visible, reasoning := parser.Accept("hello <think>abc</think> world")
|
||||
flushVisible, flushReasoning := parser.Flush()
|
||||
visible += flushVisible
|
||||
reasoning += flushReasoning
|
||||
if visible != "hello world" || reasoning != "abc" {
|
||||
t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkTagParserAcrossChunks(t *testing.T) {
|
||||
parser := &thinkTagParser{}
|
||||
var visible, reasoning string
|
||||
for _, chunk := range []string{"hello <thi", "nk>abc</thi", "nk> world"} {
|
||||
v, r := parser.Accept(chunk)
|
||||
visible += v
|
||||
reasoning += r
|
||||
}
|
||||
v, r := parser.Flush()
|
||||
visible += v
|
||||
reasoning += r
|
||||
if visible != "hello world" || reasoning != "abc" {
|
||||
t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkTagParserUnclosedThink(t *testing.T) {
|
||||
parser := &thinkTagParser{}
|
||||
visible, reasoning := parser.Accept("answer <think>still thinking")
|
||||
v, r := parser.Flush()
|
||||
visible += v
|
||||
reasoning += r
|
||||
if visible != "answer " || reasoning != "still thinking" {
|
||||
t.Fatalf("visible=%q reasoning=%q", visible, reasoning)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldParseThinkTags(t *testing.T) {
|
||||
if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1"}}) {
|
||||
t.Fatal("expected local ollama to parse think tags")
|
||||
}
|
||||
if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL}}) {
|
||||
t.Fatal("expected remote profile not to parse think tags by default")
|
||||
}
|
||||
falseValue := false
|
||||
if shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: "http://127.0.0.1:11434/v1", ParseThinkTags: &falseValue}}) {
|
||||
t.Fatal("explicit false should disable think parsing")
|
||||
}
|
||||
trueValue := true
|
||||
if !shouldParseThinkTags(&OpenAIProfile{Config: OpenAIConfig{BaseURL: defaultOpenAIBaseURL, ParseThinkTags: &trueValue}}) {
|
||||
t.Fatal("explicit true should enable think parsing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolDecisionMessagesRemovesImages(t *testing.T) {
|
||||
messages, err := buildToolDecisionMessages([]ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(messages) != 1 || messages[0].Content == nil || messages[0].Content.StringValue == nil {
|
||||
t.Fatalf("unexpected messages: %#v", messages)
|
||||
}
|
||||
if !strings.Contains(*messages[0].Content.StringValue, "工具判断阶段不读取图片内容") {
|
||||
t.Fatalf("missing image placeholder: %q", *messages[0].Content.StringValue)
|
||||
}
|
||||
if messages[0].Content.ListValue != nil {
|
||||
t.Fatalf("decision message should be text-only: %#v", messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAgentToolLoopImageUsesTextOnlyDecisionMessages(t *testing.T) {
|
||||
oldRouter := toolRouterState
|
||||
defer func() { toolRouterState = oldRouter }()
|
||||
|
||||
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 128,
|
||||
SystemPrompt: "use tools",
|
||||
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
||||
}}
|
||||
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Content != nil && len(msg.Content.ListValue) > 0 {
|
||||
t.Fatalf("tool decision should not receive multimodal content: %#v", msg.Content)
|
||||
}
|
||||
}
|
||||
joined := ""
|
||||
for _, msg := range req.Messages {
|
||||
if msg.Content != nil && msg.Content.StringValue != nil {
|
||||
joined += *msg.Content.StringValue
|
||||
}
|
||||
}
|
||||
if !strings.Contains(joined, "工具判断阶段不读取图片内容") {
|
||||
t.Fatalf("missing placeholder in decision messages: %q", joined)
|
||||
}
|
||||
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil
|
||||
}
|
||||
|
||||
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "chat"}}, []ChatMessage{{Role: "user", Content: "描述这张图", ImageURL: "data:image/png;base64,aGVsbG8="}}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
foundImage := false
|
||||
for _, msg := range messages {
|
||||
if msg.Content != nil && len(msg.Content.ListValue) > 0 {
|
||||
foundImage = true
|
||||
}
|
||||
}
|
||||
if !foundImage {
|
||||
t.Fatalf("final messages should retain image: %#v", messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunAgentToolLoopUsesConfiguredRouterProfile(t *testing.T) {
|
||||
oldRouter := toolRouterState
|
||||
defer func() { toolRouterState = oldRouter }()
|
||||
|
||||
ai, err := NewOpenAIState([]OpenAIConfig{
|
||||
{Name: "chat", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "chat-model", Timeout: 1, Active: true},
|
||||
{Name: "router", APIKey: "key", BaseURL: defaultOpenAIBaseURL, Model: "router-model", Timeout: 1},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
toolRouterState = &ToolRouterState{ai: ai, cfg: &ToolRouterConfig{
|
||||
Enabled: true,
|
||||
OpenAIName: "router",
|
||||
Timeout: 1,
|
||||
MaxTokens: 128,
|
||||
SystemPrompt: "use tools",
|
||||
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
||||
}}
|
||||
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
||||
if profile.Config.Name != "router" || req.Model != "router-model" {
|
||||
t.Fatalf("router profile not used: profile=%s model=%s", profile.Config.Name, req.Model)
|
||||
}
|
||||
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("no tool")}}}}, nil
|
||||
}
|
||||
|
||||
_, err = runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Name: "chat", Model: "chat-model"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user