package main import ( "bufio" "bytes" "context" "crypto/rand" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net" "net/http" "net/url" "os" "path/filepath" "sort" "strings" "sync" "time" "unicode" searchagent "aichat/agents/search" sqlquery "aichat/agents/sql" timeagent "aichat/agents/time" "github.com/gin-gonic/gin" ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "gopkg.in/yaml.v3" ) // ─── 配置 ───────────────────────────────────────────────── const ( defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3" defaultOpenAITimeout = 120 defaultToolRouterTimeout = 30 defaultToolRouterMaxTokens = 512 defaultToolRouterSystemText = `你可以按需直接调用可用工具来回答用户问题。 如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且后续需要搜索或查询数据库,应先调用 time 获取绝对日期范围。 需要实时网页资料、新闻、当前版本、近期事件、网页核验或用户明确要求联网时,调用 search。 需要查询本地业务数据、日程、会议、待办、记录、统计或时间范围内数据时,调用 sql。 工具结果优先于模型内置知识;工具失败时必须如实说明,不要编造结果。 只调用确实必要的工具。` ) type OpenAIConfig struct { Name string `yaml:"name" json:"name"` Active bool `yaml:"active,omitempty" json:"active"` APIKey string `yaml:"api_key" json:"-"` BaseURL string `yaml:"base_url" json:"base_url"` Model string `yaml:"model" json:"model"` Timeout int `yaml:"timeout" json:"timeout"` ParseThinkTags *bool `yaml:"parse_think_tags,omitempty" json:"parse_think_tags,omitempty"` } type OpenAIConfigs []OpenAIConfig type ToolRouterConfig struct { Enabled bool `yaml:"enabled" json:"enabled"` OpenAIName string `yaml:"openai_name" json:"openai_name"` Timeout int `yaml:"timeout" json:"timeout"` MaxTokens int `yaml:"max_tokens" json:"max_tokens"` SystemPrompt string `yaml:"system_prompt" json:"system_prompt"` Tools []ToolRouteConfig `yaml:"tools" json:"tools"` } type ToolRouteConfig struct { Name string `yaml:"name" json:"name"` Enabled bool `yaml:"enabled" json:"enabled"` Description string `yaml:"description" json:"description"` } func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error { switch value.Kind { case yaml.SequenceNode: var items []OpenAIConfig if err := value.Decode(&items); err != nil { return err } *configs = items case yaml.MappingNode: var item OpenAIConfig if err := value.Decode(&item); err != nil { return err } *configs = []OpenAIConfig{item} case yaml.ScalarNode: if value.Tag == "!!null" { *configs = nil return nil } return fmt.Errorf("openai 配置格式无效") default: return fmt.Errorf("openai 配置格式无效") } return nil } type Config struct { Server struct { Mode string `yaml:"mode"` Address string `yaml:"address"` } `yaml:"server"` OpenAI OpenAIConfigs `yaml:"openai"` ToolRouter ToolRouterConfig `yaml:"tool_router"` } func defaultOpenAIConfig() OpenAIConfig { return OpenAIConfig{ Name: "default", Active: true, BaseURL: defaultOpenAIBaseURL, Timeout: defaultOpenAITimeout, } } func defaultToolRouterConfig() ToolRouterConfig { return ToolRouterConfig{ Enabled: true, OpenAIName: "", Timeout: defaultToolRouterTimeout, MaxTokens: defaultToolRouterMaxTokens, SystemPrompt: defaultToolRouterSystemText, Tools: []ToolRouteConfig{ {Name: "time", Enabled: true, Description: ""}, {Name: "search", Enabled: true, Description: ""}, {Name: "sql", Enabled: true, Description: ""}, }, } } func defaultConfig() Config { var cfg Config cfg.Server.Mode = "tcp" cfg.Server.Address = "0.0.0.0:8080" cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} cfg.ToolRouter = defaultToolRouterConfig() return cfg } func loadConfig(path string) (*Config, error) { if err := ensureConfigFile(path); err != nil { return nil, err } data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("读取配置文件失败: %w", err) } var cfg Config if err = yaml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("解析配置文件失败: %w", err) } if _, err := normalizeOpenAIConfigs(&cfg); err != nil { return nil, err } // 环境变量优先 if key := os.Getenv("ARK_API_KEY"); key != "" { for i := range cfg.OpenAI { cfg.OpenAI[i].APIKey = key } } legacySearchProfiles = readLegacySearchProfiles(data) if _, err := normalizeToolRouterConfig(&cfg); err != nil { return nil, err } return &cfg, nil } func ensureConfigFile(path string) error { defaults := defaultConfig() if _, err := os.Stat(path); err != nil { if !os.IsNotExist(err) { return fmt.Errorf("检查配置文件失败: %w", err) } return writeConfig(path, defaults) } data, err := os.ReadFile(path) if err != nil { return fmt.Errorf("读取配置文件失败: %w", err) } var cfg Config if err = yaml.Unmarshal(data, &cfg); err != nil { return fmt.Errorf("解析配置文件失败: %w", err) } var raw map[string]any if err = yaml.Unmarshal(data, &raw); err != nil { return fmt.Errorf("解析配置文件失败: %w", err) } changed := false server, _ := raw["server"].(map[string]any) if server == nil { cfg.Server = defaults.Server changed = true } else { if _, ok := server["mode"]; !ok { cfg.Server.Mode = defaults.Server.Mode changed = true } if _, ok := server["address"]; !ok { cfg.Server.Address = defaults.Server.Address changed = true } } if _, ok := raw["openai"].([]any); !ok { changed = true } if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil { return err } else if normalized { changed = true } if _, ok := raw["tool_router"]; !ok { cfg.ToolRouter = defaults.ToolRouter changed = true } else if normalized, err := normalizeToolRouterConfig(&cfg); err != nil { return err } else if normalized { changed = true } if !changed { return nil } return writeConfig(path, cfg) } func normalizeOpenAIConfigs(cfg *Config) (bool, error) { changed := false if len(cfg.OpenAI) == 0 { cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} changed = true } activeIndex := -1 seen := map[string]bool{} for i := range cfg.OpenAI { profile := &cfg.OpenAI[i] name := strings.TrimSpace(profile.Name) if name == "" { name = strings.TrimSpace(profile.Model) if name == "" { name = fmt.Sprintf("openai-%d", i+1) } profile.Name = name changed = true } else if name != profile.Name { profile.Name = name changed = true } if seen[name] { return changed, fmt.Errorf("openai 配置名称重复: %s", name) } seen[name] = true if strings.TrimSpace(profile.BaseURL) == "" { profile.BaseURL = defaultOpenAIBaseURL changed = true } if profile.Timeout <= 0 { profile.Timeout = defaultOpenAITimeout changed = true } if profile.Active { if activeIndex == -1 { activeIndex = i } else { profile.Active = false changed = true } } } if activeIndex == -1 { cfg.OpenAI[0].Active = true changed = true } return changed, nil } func isLegacyToolRouterPrompt(prompt string) bool { prompt = strings.TrimSpace(prompt) return strings.Contains(prompt, "工具路由器") || strings.Contains(prompt, "route_tools") || strings.Contains(prompt, `"tools":[`) } func normalizeToolRouterConfig(cfg *Config) (bool, error) { changed := false defaults := defaultToolRouterConfig() cfg.ToolRouter.OpenAIName = strings.TrimSpace(cfg.ToolRouter.OpenAIName) if cfg.ToolRouter.Timeout <= 0 { cfg.ToolRouter.Timeout = defaultToolRouterTimeout changed = true } if cfg.ToolRouter.MaxTokens <= 0 { cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens changed = true } systemPrompt := strings.TrimSpace(cfg.ToolRouter.SystemPrompt) if systemPrompt == "" || isLegacyToolRouterPrompt(systemPrompt) { cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText changed = true } else if systemPrompt != cfg.ToolRouter.SystemPrompt { cfg.ToolRouter.SystemPrompt = systemPrompt changed = true } if len(cfg.ToolRouter.Tools) == 0 { cfg.ToolRouter.Tools = defaults.Tools changed = true } seen := map[string]bool{} for i := range cfg.ToolRouter.Tools { tool := &cfg.ToolRouter.Tools[i] name := strings.ToLower(strings.TrimSpace(tool.Name)) if name == "" { name = fmt.Sprintf("tool-%d", i+1) } if name != tool.Name { tool.Name = name changed = true } tool.Description = strings.TrimSpace(tool.Description) if seen[name] { return changed, fmt.Errorf("tool_router.tools 配置名称重复: %s", name) } seen[name] = true } byName := map[string]ToolRouteConfig{} for _, tool := range cfg.ToolRouter.Tools { byName[tool.Name] = tool } merged := make([]ToolRouteConfig, 0, len(cfg.ToolRouter.Tools)+len(defaults.Tools)) used := map[string]bool{} for _, tool := range defaults.Tools { if existing, ok := byName[tool.Name]; ok { merged = append(merged, existing) } else { merged = append(merged, tool) changed = true } used[tool.Name] = true } for _, tool := range cfg.ToolRouter.Tools { if !used[tool.Name] { merged = append(merged, tool) } } if len(merged) != len(cfg.ToolRouter.Tools) { changed = true } else { for i := range merged { if merged[i].Name != cfg.ToolRouter.Tools[i].Name { changed = true break } } } cfg.ToolRouter.Tools = merged return changed, nil } func readLegacySearchProfiles(data []byte) []searchagent.ProfileConfig { var legacy struct { Search searchagent.ProfileConfigs `yaml:"search"` } if err := yaml.Unmarshal(data, &legacy); err != nil { return nil } return []searchagent.ProfileConfig(legacy.Search) } func writeConfig(path string, cfg Config) error { data, err := yaml.Marshal(&cfg) if err != nil { return fmt.Errorf("生成配置文件失败: %w", err) } if err := os.WriteFile(path, data, 0644); err != nil { return fmt.Errorf("写入配置文件失败: %w", err) } return nil } // ─── 请求结构 ───────────────────────────────────────────── type ChatMessage struct { Role string `json:"role"` Content string `json:"content"` ImageURL string `json:"image_url,omitempty"` // base64 data URI 或 http URL ImageURLAlias string `json:"imageURL,omitempty"` Hidden bool `json:"hidden,omitempty"` } type ChatRequest struct { ConversationID string `json:"conversation_id,omitempty"` Messages []ChatMessage `json:"messages"` WebSearch bool `json:"web_search,omitempty"` OpenAIName string `json:"openai_name,omitempty"` } type Conversation struct { ID string `json:"id"` Title string `json:"title"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` Messages []ChatMessage `json:"messages,omitempty"` } type ConvStore struct { dir string mu sync.Mutex } type OpenAIProfile struct { Config OpenAIConfig Client *ark.Client } type OpenAIState struct { mu sync.RWMutex profiles map[string]*OpenAIProfile order []string activeName string } type activeProfileRequest struct { Name string `json:"name"` } type openAIListResponse struct { Active string `json:"active"` Profiles []OpenAIConfig `json:"profiles"` } type chatCompleter func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) type ToolRouterState struct { cfg *ToolRouterConfig ai *OpenAIState complete chatCompleter } func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) { if config == nil { cfg := defaultToolRouterConfig() config = &cfg } if ai == nil { return nil, errors.New("工具路由需要 OpenAI 状态") } if config.Enabled && strings.TrimSpace(config.OpenAIName) != "" { if _, err := ai.GetProfile(config.OpenAIName); err != nil { return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err) } } return &ToolRouterState{cfg: config, ai: ai, complete: completeChatWithTimeout}, nil } func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) { state := &OpenAIState{ profiles: make(map[string]*OpenAIProfile, len(configs)), order: make([]string, 0, len(configs)), } for _, config := range configs { if strings.TrimSpace(config.Name) == "" { return nil, errors.New("openai.name 不能为空") } if strings.TrimSpace(config.APIKey) == "" { return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", config.Name) } if strings.TrimSpace(config.Model) == "" { return nil, fmt.Errorf("openai.%s.model 未配置", config.Name) } if strings.TrimSpace(config.BaseURL) == "" { return nil, fmt.Errorf("openai.%s.base_url 未配置", config.Name) } if config.Timeout <= 0 { return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", config.Name) } if _, ok := state.profiles[config.Name]; ok { return nil, fmt.Errorf("openai 配置名称重复: %s", config.Name) } state.profiles[config.Name] = &OpenAIProfile{ Config: config, Client: ark.NewClientWithApiKey( config.APIKey, ark.WithBaseUrl(config.BaseURL), ark.WithTimeout(time.Duration(config.Timeout)*time.Second), ), } state.order = append(state.order, config.Name) if config.Active && state.activeName == "" { state.activeName = config.Name } } if len(state.order) == 0 { return nil, errors.New("openai 配置不能为空") } if state.activeName == "" { state.activeName = state.order[0] } return state, nil } func (s *OpenAIState) ActiveProfile() *OpenAIProfile { s.mu.RLock() defer s.mu.RUnlock() return s.profiles[s.activeName] } func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) { s.mu.RLock() defer s.mu.RUnlock() if strings.TrimSpace(name) == "" { return s.profiles[s.activeName], nil } profile, ok := s.profiles[strings.TrimSpace(name)] if !ok { return nil, fmt.Errorf("OpenAI 配置不存在: %s", name) } return profile, nil } func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) { name = strings.TrimSpace(name) if name == "" { return nil, errors.New("OpenAI 配置名称不能为空") } s.mu.Lock() defer s.mu.Unlock() profile, ok := s.profiles[name] if !ok { return nil, fmt.Errorf("OpenAI 配置不存在: %s", name) } s.activeName = name return profile, nil } func (s *OpenAIState) ListProfiles() openAIListResponse { s.mu.RLock() defer s.mu.RUnlock() profiles := make([]OpenAIConfig, 0, len(s.order)) for _, name := range s.order { profile := s.profiles[name] config := profile.Config config.APIKey = "" config.Active = name == s.activeName profiles = append(profiles, config) } return openAIListResponse{Active: s.activeName, Profiles: profiles} } func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig { config := profile.Config config.APIKey = "" config.Active = active return config } func (s *ToolRouterState) RouterProfile(fallback *OpenAIProfile) *OpenAIProfile { if s == nil || s.cfg == nil || s.ai == nil { return fallback } name := strings.TrimSpace(s.cfg.OpenAIName) if name == "" { return fallback } profile, err := s.ai.GetProfile(name) if err != nil { return fallback } return profile } func isOllamaProfile(profile *OpenAIProfile) bool { if profile == nil { return false } u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL)) if err != nil { return strings.Contains(profile.Config.BaseURL, ":11434") } host := strings.ToLower(u.Hostname()) port := u.Port() return port == "11434" && (host == "127.0.0.1" || host == "localhost" || host == "::1") } func shouldParseThinkTags(profile *OpenAIProfile) bool { if profile == nil { return false } if profile.Config.ParseThinkTags != nil { return *profile.Config.ParseThinkTags } return isOllamaProfile(profile) } // ─── 全局变量 ───────────────────────────────────────────── var ( cfg *Config aiState *OpenAIState searchState *searchagent.State legacySearchProfiles []searchagent.ProfileConfig toolRouterState *ToolRouterState sqlState *sqlquery.State store *ConvStore ) type chatSSEFrame struct { Type string `json:"type"` Text string `json:"text,omitempty"` Message string `json:"message,omitempty"` Tool string `json:"tool,omitempty"` Stage string `json:"stage,omitempty"` Status string `json:"status,omitempty"` Data map[string]any `json:"data,omitempty"` Stats *tokenUsageStats `json:"stats,omitempty"` Error string `json:"error,omitempty"` } type tokenUsageStats struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` ToolPromptTokens int `json:"tool_prompt_tokens"` ToolCompletionTokens int `json:"tool_completion_tokens"` TotalTokens int `json:"total_tokens"` CompletionTokensPerSec float64 `json:"completion_tokens_per_sec"` PeakCompletionTokensPerSec float64 `json:"peak_completion_tokens_per_sec"` Estimated bool `json:"estimated"` } type tokenUsageTracker struct { mu sync.Mutex promptTokens int completionTokens int toolPromptTokens int toolCompletionTokens int } type tokenUsageContextKey struct{} func newTokenUsageTracker() *tokenUsageTracker { return &tokenUsageTracker{} } func contextWithTokenUsage(ctx context.Context, tracker *tokenUsageTracker) context.Context { if tracker == nil { return ctx } return context.WithValue(ctx, tokenUsageContextKey{}, tracker) } func tokenUsageFromContext(ctx context.Context) *tokenUsageTracker { tracker, _ := ctx.Value(tokenUsageContextKey{}).(*tokenUsageTracker) return tracker } func (t *tokenUsageTracker) addTool(promptTokens, completionTokens int) { if t == nil { return } t.mu.Lock() defer t.mu.Unlock() t.toolPromptTokens += promptTokens t.toolCompletionTokens += completionTokens } func (t *tokenUsageTracker) setModel(promptTokens, completionTokens int) { if t == nil { return } t.mu.Lock() defer t.mu.Unlock() t.promptTokens = promptTokens t.completionTokens = completionTokens } func (t *tokenUsageTracker) snapshot(tokensPerSecond, peakTokensPerSecond float64) tokenUsageStats { if t == nil { return tokenUsageStats{Estimated: true} } t.mu.Lock() defer t.mu.Unlock() total := t.promptTokens + t.completionTokens + t.toolPromptTokens + t.toolCompletionTokens return tokenUsageStats{ PromptTokens: t.promptTokens, CompletionTokens: t.completionTokens, ToolPromptTokens: t.toolPromptTokens, ToolCompletionTokens: t.toolCompletionTokens, TotalTokens: total, CompletionTokensPerSec: tokensPerSecond, PeakCompletionTokensPerSec: peakTokensPerSecond, Estimated: true, } } // ─── 路由 ───────────────────────────────────────────────── func indexHandler(c *gin.Context) { profile := aiState.ActiveProfile() c.HTML(http.StatusOK, "chat.html", gin.H{ "Title": "AI 对话", "Model": profile.Config.Model, "OpenAIName": profile.Config.Name, }) } func listOpenAIHandler(c *gin.Context) { c.JSON(http.StatusOK, aiState.ListProfiles()) } func switchOpenAIHandler(c *gin.Context) { var req activeProfileRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) return } profile, err := aiState.SwitchActive(req.Name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "active": profile.Config.Name, "profile": publicOpenAIConfig(profile, true), }) } func listSearchHandler(c *gin.Context) { c.JSON(http.StatusOK, searchState.ListProfiles()) } func switchSearchHandler(c *gin.Context) { var req activeProfileRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) return } profile, err := searchState.SwitchActive(req.Name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } profile.APIKey = "" profile.Active = true c.JSON(http.StatusOK, gin.H{ "active": profile.Name, "profile": profile, }) } func listConversationsHandler(c *gin.Context) { convs, err := store.List() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, convs) } func createConversationHandler(c *gin.Context) { conv, err := store.Create() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "创建对话失败: " + err.Error()}) return } c.JSON(http.StatusOK, conv) } func getConversationHandler(c *gin.Context) { conv, err := store.Get(c.Param("id")) if err != nil { status := http.StatusInternalServerError if err.Error() == "对话不存在" { status = http.StatusNotFound } c.JSON(status, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, conv) } func deleteConversationHandler(c *gin.Context) { if err := store.Delete(c.Param("id")); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.Status(http.StatusNoContent) } // chatHandler 流式 SSE 对话接口 func chatHandler(c *gin.Context) { var req ChatRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) return } if len(req.Messages) == 0 { c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"}) return } profile, err := aiState.GetProfile(req.OpenAIName) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // SSE 头先写出,后续插件/模型过程都通过 trace 事件实时展示。 c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.WriteHeader(http.StatusOK) flusher, ok := c.Writer.(http.Flusher) if !ok { return } emit := func(frame chatSSEFrame) { writeSSEJSON(c.Writer, frame) flusher.Flush() } emitTrace := func(tool, stage, status, message string, data map[string]any) { emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data}) } emitError := func(err error) { emit(chatSSEFrame{Type: "error", Error: err.Error()}) } // 超时 context timeout := time.Duration(profile.Config.Timeout) * time.Second ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() usage := newTokenUsageTracker() ctx = contextWithTokenUsage(ctx, usage) // 用 Function Calling 工具循环替代旧的路由+隐藏上下文机制 messages, err := runAgentToolLoop(ctx, profile, req.Messages, emit) if err != nil { fmt.Fprintln(os.Stderr, "Agent 工具循环失败:", err) messages, err = buildArkMessages(req.Messages) if err != nil { emitError(err) return } } promptTokens := estimateChatMessagesTokens(req.Messages) if isOllamaProfile(profile) && hasImageMessage(req.Messages) { emitTrace("model", "request", "running", "正在通过 Ollama 原生接口调用视觉模型", nil) err = streamOllamaChat(ctx, profile, messages, promptTokens, usage, emit, func(content string) { if req.ConversationID != "" { if err := saveConversationMessages(req.ConversationID, req.Messages, content); err != nil { fmt.Fprintln(os.Stderr, "保存对话失败:", err) } } }) if err != nil { emitError(err) } return } emitTrace("model", "request", "running", "正在调用模型生成回答", nil) stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ Model: profile.Config.Model, Messages: messages, MaxTokens: intPtr(4096), }.WithStream(true)) if err != nil { emitError(err) return } defer stream.Close() emitTrace("model", "stream", "running", "模型已开始输出", nil) var full strings.Builder completionTokens := 0 streamStarted := time.Now() windowStarted := streamStarted windowTokens := 0 peakTokensPerSecond := 0.0 parseThinkTags := shouldParseThinkTags(profile) thinkParser := &thinkTagParser{} emitDelta := func(delta string) { if delta == "" { return } now := time.Now() deltaTokens := estimateTokenCount(delta) windowTokens += deltaTokens windowElapsed := now.Sub(windowStarted).Seconds() if windowElapsed >= 1 { windowSpeed := float64(windowTokens) / windowElapsed if windowSpeed > peakTokensPerSecond { peakTokensPerSecond = windowSpeed } windowStarted = now windowTokens = 0 } else if peakTokensPerSecond == 0 && windowElapsed > 0.25 { peakTokensPerSecond = float64(windowTokens) / windowElapsed } full.WriteString(delta) completionTokens += deltaTokens usage.setModel(promptTokens, completionTokens) stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) } emitModelContent := func(delta string) { if delta == "" { return } if !parseThinkTags { emitDelta(delta) return } visible, reasoning := thinkParser.Accept(delta) if reasoning != "" { emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) } emitDelta(visible) } for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { if parseThinkTags { visible, reasoning := thinkParser.Flush() if reasoning != "" { emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) } emitDelta(visible) } usage.setModel(promptTokens, completionTokens) if windowTokens > 0 { windowElapsed := time.Since(windowStarted).Seconds() if windowElapsed > 0.25 { windowSpeed := float64(windowTokens) / windowElapsed if windowSpeed > peakTokensPerSecond { peakTokensPerSecond = windowSpeed } } } if peakTokensPerSecond == 0 { peakTokensPerSecond = tokensPerSecond(completionTokens, streamStarted) } if req.ConversationID != "" { if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil { fmt.Fprintln(os.Stderr, "保存对话失败:", err) } } finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) emit(chatSSEFrame{Type: "stats", Stats: &finalStats}) emitTrace("model", "stream", "success", "回答生成完成", nil) fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() return } if err != nil { emitError(err) return } if len(resp.Choices) > 0 { emitModelContent(resp.Choices[0].Delta.Content) // 思考过程 reasoning_content 单独事件推送 if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" { emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent}) } } } } // ─── 辅助函数 ───────────────────────────────────────────── func estimateChatMessagesTokens(messages []ChatMessage) int { total := 0 for _, msg := range messages { total += estimateTokenCount(msg.Role) + estimateTokenCount(msg.Content) + 4 if msg.ImageURL != "" || msg.ImageURLAlias != "" { total += 85 } } return total } func estimateTokenCount(text string) int { text = strings.TrimSpace(text) if text == "" { return 0 } tokens := 0 asciiRunes := 0 flushASCII := func() { if asciiRunes > 0 { tokens += (asciiRunes + 3) / 4 asciiRunes = 0 } } for _, r := range text { if unicode.IsSpace(r) { flushASCII() continue } if r <= unicode.MaxASCII { asciiRunes++ continue } flushASCII() tokens++ } flushASCII() if tokens == 0 { return 1 } return tokens } func tokensPerSecond(tokens int, start time.Time) float64 { elapsed := time.Since(start).Seconds() if tokens <= 0 || elapsed <= 0 { return 0 } return float64(tokens) / elapsed } type agentTool struct { name string definition *model.Tool execute func(context.Context, string) (string, error) } func (t agentTool) Name() string { return t.name } const maxAgentToolIterations = 6 func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agentTool { if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled { return nil } tools := make([]agentTool, 0, len(toolRouterState.cfg.Tools)) for _, item := range toolRouterState.cfg.Tools { if !item.Enabled { continue } description := strings.TrimSpace(item.Description) switch item.Name { case timeagent.ToolName: tools = append(tools, agentTool{ name: timeagent.ToolName, definition: timeagent.ToolDefinition(description), execute: func(ctx context.Context, args string) (string, error) { result, err := timeagent.ExecuteTool(args, time.Now()) if err == nil && emit != nil { emit(chatSSEFrame{Type: "trace", Tool: timeagent.ToolName, Stage: "resolve", Status: "success", Message: "已获取当前时间上下文"}) } return result, err }, }) case searchagent.ToolName: if searchState == nil || !searchState.Enabled() { continue } tools = append(tools, agentTool{ name: searchagent.ToolName, definition: searchState.ToolDefinition(description), execute: func(ctx context.Context, args string) (string, error) { if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "request", Status: "running", Message: "正在联网搜索"}) } result, err := searchState.ExecuteTool(ctx, args) if emit != nil { status := "success" message := "联网搜索完成" if err != nil { status = "error" message = "联网搜索失败" } emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "results", Status: status, Message: message}) } return result, err }, }) case sqlquery.ToolName: if sqlState == nil || !sqlState.Enabled() { continue } tools = append(tools, agentTool{ name: sqlquery.ToolName, definition: sqlState.ToolDefinition(description), execute: func(ctx context.Context, args string) (string, error) { if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: "running", Message: "正在查询数据库"}) } generator := func(ctx context.Context, prompt string, maxTokens int) (string, error) { return completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, maxTokens) } result, err := sqlState.ExecuteTool(ctx, args, generator) if emit != nil { status := "success" message := "数据库查询完成" if err != nil { status = "error" message = "数据库查询失败" } emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: status, Message: message}) } return result, err }, }) } } return tools } func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) { finalMessages, err := buildArkMessages(chatMessages) if err != nil { return nil, err } routerProfile := profile if toolRouterState != nil { routerProfile = toolRouterState.RouterProfile(profile) } tools := availableAgentTools(routerProfile, emit) if len(tools) == 0 { return finalMessages, nil } decisionMessages := append([]*model.ChatCompletionMessage(nil), finalMessages...) if hasImageMessage(chatMessages) { decisionMessages, err = buildToolDecisionMessages(chatMessages) if err != nil { return nil, err } if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "检测到图片输入,工具判断阶段将使用纯文本上下文"}) } } toolByName := make(map[string]agentTool, len(tools)) definitions := make([]*model.Tool, 0, len(tools)) availableNames := make([]string, 0, len(tools)) toolDescriptions := make([]string, 0, len(tools)) for _, tool := range tools { toolByName[tool.name] = tool definitions = append(definitions, tool.definition) availableNames = append(availableNames, tool.name) if tool.definition != nil && tool.definition.Function != nil { toolDescriptions = append(toolDescriptions, fmt.Sprintf("%s: %s", tool.name, tool.definition.Function.Description)) } } if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}}) } if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" { systemMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)} finalMessages = append([]*model.ChatCompletionMessage{systemMessage}, finalMessages...) decisionMessages = append([]*model.ChatCompletionMessage{systemMessage}, decisionMessages...) } for i := 0; i < maxAgentToolIterations; i++ { if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}}) } resp, err := toolRouterState.complete(ctx, routerProfile, model.CreateChatCompletionRequest{ Model: routerProfile.Config.Model, Messages: decisionMessages, MaxTokens: intPtr(toolRouterState.cfg.MaxTokens), Tools: definitions, ToolChoice: model.ToolChoiceStringTypeAuto, ParallelToolCalls: boolPtr(false), }, time.Duration(toolRouterState.cfg.Timeout)*time.Second) if err != nil { return finalMessages, err } if tracker := tokenUsageFromContext(ctx); tracker != nil { tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens) } if len(resp.Choices) == 0 { return finalMessages, nil } choice := resp.Choices[0] decisionPreview := chatMessageContentString(choice.Message.Content) if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "decision", Status: "success", Message: "工具判断响应已返回", Data: map[string]any{"iteration": i + 1, "finish_reason": string(choice.FinishReason), "content_preview": truncateString(decisionPreview, 800)}}) } calls := choice.Message.ToolCalls if len(calls) == 0 && choice.Message.FunctionCall != nil { calls = []*model.ToolCall{{ID: "legacy_function_call", Type: model.ToolTypeFunction, Function: *choice.Message.FunctionCall}} } if len(calls) == 0 { if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"}) } return finalMessages, nil } callNames := make([]string, 0, len(calls)) for _, call := range calls { if call != nil { callNames = append(callNames, call.Function.Name) } } if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}}) } assistantMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content} finalMessages = append(finalMessages, assistantMessage) decisionMessages = append(decisionMessages, assistantMessage) for _, call := range calls { result := executeAgentToolCall(ctx, call, toolByName, emit) toolMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)} finalMessages = append(finalMessages, toolMessage) decisionMessages = append(decisionMessages, toolMessage) } } limitMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")} finalMessages = append(finalMessages, limitMessage) return finalMessages, nil } type thinkTagParser struct { inThink bool buffer string } const ( thinkOpenTag = "" thinkCloseTag = "" ) func (p *thinkTagParser) Accept(delta string) (visible string, reasoning string) { p.buffer += delta for p.buffer != "" { if p.inThink { idx := strings.Index(p.buffer, thinkCloseTag) if idx >= 0 { reasoning += p.buffer[:idx] p.buffer = p.buffer[idx+len(thinkCloseTag):] p.inThink = false continue } keep := tagPrefixSuffixLen(p.buffer, thinkCloseTag) if len(p.buffer) > keep { reasoning += p.buffer[:len(p.buffer)-keep] p.buffer = p.buffer[len(p.buffer)-keep:] } return visible, reasoning } idx := strings.Index(p.buffer, thinkOpenTag) if idx >= 0 { visible += p.buffer[:idx] p.buffer = p.buffer[idx+len(thinkOpenTag):] p.inThink = true continue } keep := tagPrefixSuffixLen(p.buffer, thinkOpenTag) if len(p.buffer) > keep { visible += p.buffer[:len(p.buffer)-keep] p.buffer = p.buffer[len(p.buffer)-keep:] } return visible, reasoning } return visible, reasoning } func (p *thinkTagParser) Flush() (visible string, reasoning string) { if p.inThink { reasoning = p.buffer } else { visible = p.buffer } p.buffer = "" p.inThink = false return visible, reasoning } func tagPrefixSuffixLen(text, tag string) int { limit := len(tag) - 1 if len(text) < limit { limit = len(text) } for i := limit; i > 0; i-- { if strings.HasPrefix(tag, text[len(text)-i:]) { return i } } return 0 } func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string { if call == nil || call.Type != model.ToolTypeFunction { result := "工具调用无效:仅支持 function 类型工具。" if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "execute", Status: "error", Message: result}) } return result } toolName := call.Function.Name if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "arguments", Status: "running", Message: "准备执行工具", Data: map[string]any{"tool_call_id": call.ID, "arguments": call.Function.Arguments}}) } tool, ok := tools[toolName] if !ok { result := fmt.Sprintf("工具调用失败:未知工具 %s。", toolName) if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "execute", Status: "error", Message: result}) } return result } started := time.Now() result, err := tool.execute(ctx, call.Function.Arguments) durationMs := time.Since(started).Milliseconds() if err != nil { message := fmt.Sprintf("工具 %s 执行失败:%v", tool.name, err) if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "execute", Status: "error", Message: "工具执行失败", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "error": err.Error()}}) } return message } if strings.TrimSpace(result) == "" { result = fmt.Sprintf("工具 %s 执行完成,但没有返回内容。", tool.name) } if emit != nil { emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "result", Status: "success", Message: "工具执行完成", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "result_preview": truncateString(result, 1200)}}) } return result } type ollamaChatRequest struct { Model string `json:"model"` Messages []ollamaChatMessage `json:"messages"` Stream bool `json:"stream"` Options map[string]int `json:"options,omitempty"` } type ollamaChatMessage struct { Role string `json:"role"` Content string `json:"content"` Images []string `json:"images,omitempty"` } type ollamaChatResponse struct { Message struct { Role string `json:"role"` Content string `json:"content"` Thinking string `json:"thinking"` } `json:"message"` Done bool `json:"done"` PromptEvalCount int `json:"prompt_eval_count"` EvalCount int `json:"eval_count"` DoneReason string `json:"done_reason"` } func streamOllamaChat(ctx context.Context, profile *OpenAIProfile, messages []*model.ChatCompletionMessage, promptTokens int, usage *tokenUsageTracker, emit func(chatSSEFrame), onDone func(string)) error { requestMessages, err := buildOllamaMessages(messages) if err != nil { return err } baseURL, err := ollamaBaseURL(profile) if err != nil { return err } body, err := json.Marshal(ollamaChatRequest{ Model: profile.Config.Model, Messages: requestMessages, Stream: true, Options: map[string]int{"num_predict": 4096}, }) if err != nil { return err } req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(baseURL, "/")+"/api/chat", bytes.NewReader(body)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return fmt.Errorf("Ollama 原生接口调用失败: %s %s", resp.Status, strings.TrimSpace(string(data))) } emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "running", Message: "Ollama 视觉模型已开始输出"}) parseThinkTags := shouldParseThinkTags(profile) thinkParser := &thinkTagParser{} var full strings.Builder completionTokens := 0 streamStarted := time.Now() peakTokensPerSecond := 0.0 emitDelta := func(delta string) { if delta == "" { return } full.WriteString(delta) completionTokens += estimateTokenCount(delta) usage.setModel(promptTokens, completionTokens) currentSpeed := tokensPerSecond(completionTokens, streamStarted) if currentSpeed > peakTokensPerSecond { peakTokensPerSecond = currentSpeed } stats := usage.snapshot(currentSpeed, peakTokensPerSecond) emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) } emitContent := func(delta string) { if delta == "" { return } if !parseThinkTags { emitDelta(delta) return } visible, reasoning := thinkParser.Accept(delta) if reasoning != "" { emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) } emitDelta(visible) } scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" { continue } var chunk ollamaChatResponse if err := json.Unmarshal([]byte(line), &chunk); err != nil { return fmt.Errorf("解析 Ollama 流失败: %w", err) } if chunk.Message.Thinking != "" { emit(chatSSEFrame{Type: "reasoning", Text: chunk.Message.Thinking}) } emitContent(chunk.Message.Content) if chunk.Done { if chunk.PromptEvalCount > 0 || chunk.EvalCount > 0 { usage.setModel(chunk.PromptEvalCount, chunk.EvalCount) } break } } if err := scanner.Err(); err != nil { return err } if parseThinkTags { visible, reasoning := thinkParser.Flush() if reasoning != "" { emit(chatSSEFrame{Type: "reasoning", Text: reasoning}) } emitDelta(visible) } if onDone != nil { onDone(full.String()) } finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) emit(chatSSEFrame{Type: "stats", Stats: &finalStats}) emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "success", Message: "回答生成完成"}) return nil } func buildOllamaMessages(messages []*model.ChatCompletionMessage) ([]ollamaChatMessage, error) { result := make([]ollamaChatMessage, 0, len(messages)) for _, msg := range messages { if msg == nil { continue } role := string(msg.Role) if msg.Role == model.ChatMessageRoleTool { role = string(model.ChatMessageRoleUser) } item := ollamaChatMessage{Role: role} if msg.Content == nil { if len(msg.ToolCalls) > 0 { continue } result = append(result, item) continue } if msg.Content.StringValue != nil { item.Content = *msg.Content.StringValue if msg.Role == model.ChatMessageRoleTool { item.Content = "工具结果:\n" + item.Content } result = append(result, item) continue } for _, part := range msg.Content.ListValue { if part == nil { continue } switch part.Type { case model.ChatCompletionMessageContentPartTypeText: if part.Text != "" { if item.Content != "" { item.Content += "\n" } item.Content += part.Text } case model.ChatCompletionMessageContentPartTypeImageURL: if part.ImageURL == nil { continue } image, err := ollamaImagePayload(part.ImageURL.URL) if err != nil { return nil, err } item.Images = append(item.Images, image) } } result = append(result, item) } return result, nil } func ollamaImagePayload(raw string) (string, error) { raw = strings.TrimSpace(raw) if strings.HasPrefix(strings.ToLower(raw), "data:") { comma := strings.Index(raw, ",") if comma < 0 { return "", errors.New("图片 base64 数据格式错误") } return strings.TrimSpace(raw[comma+1:]), nil } return raw, nil } func ollamaBaseURL(profile *OpenAIProfile) (string, error) { if profile == nil { return "", errors.New("Ollama 配置为空") } u, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL)) if err != nil { return "", err } if strings.TrimRight(u.Path, "/") == "/v1" { u.Path = strings.TrimSuffix(strings.TrimRight(u.Path, "/"), "/v1") } u.RawQuery = "" u.Fragment = "" return strings.TrimRight(u.String(), "/"), nil } func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) { return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, time.Duration(profile.Config.Timeout)*time.Second) } func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) { messages, err := buildArkMessages(chatMessages) if err != nil { return "", err } completionCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() stream, err := profile.Client.CreateChatCompletionStream(completionCtx, model.CreateChatCompletionRequest{ Model: profile.Config.Model, Messages: messages, MaxTokens: intPtr(maxTokens), }.WithStream(true)) if err != nil { return "", err } defer stream.Close() promptTokens := estimateChatMessagesTokens(chatMessages) completionTokens := 0 parseThinkTags := shouldParseThinkTags(profile) thinkParser := &thinkTagParser{} var b strings.Builder appendVisible := func(delta string) { if delta == "" { return } b.WriteString(delta) completionTokens += estimateTokenCount(delta) } for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { if parseThinkTags { visible, _ := thinkParser.Flush() appendVisible(visible) } if tracker := tokenUsageFromContext(ctx); tracker != nil { tracker.addTool(promptTokens, completionTokens) } return b.String(), nil } if err != nil { return "", err } if len(resp.Choices) > 0 { delta := resp.Choices[0].Delta.Content if parseThinkTags { visible, _ := thinkParser.Accept(delta) appendVisible(visible) } else { appendVisible(delta) } } } } func completeChatWithTimeout(ctx context.Context, profile *OpenAIProfile, request model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) { completionCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() return profile.Client.CreateChatCompletion(completionCtx, request.WithStream(false)) } func newUUID() string { b := make([]byte, 16) _, _ = rand.Read(b) b[6] = (b[6] & 0x0f) | 0x40 b[8] = (b[8] & 0x3f) | 0x80 return hex.EncodeToString(b[:4]) + "-" + hex.EncodeToString(b[4:6]) + "-" + hex.EncodeToString(b[6:8]) + "-" + hex.EncodeToString(b[8:10]) + "-" + hex.EncodeToString(b[10:]) } // ─── ConvStore ───────────────────────────────────────────── func NewConvStore(dir string) *ConvStore { os.MkdirAll(dir, 0755) return &ConvStore{dir: dir} } func (s *ConvStore) path(id string) string { return filepath.Join(s.dir, id+".json") } func (s *ConvStore) Create() (*Conversation, error) { conv := &Conversation{ ID: newUUID(), Title: "新对话", CreatedAt: time.Now(), UpdatedAt: time.Now(), } if err := s.Save(conv); err != nil { return nil, err } return conv, nil } func (s *ConvStore) Save(conv *Conversation) error { s.mu.Lock() defer s.mu.Unlock() conv.UpdatedAt = time.Now() return atomicWriteJSON(s.path(conv.ID), conv) } func (s *ConvStore) Get(id string) (*Conversation, error) { s.mu.Lock() defer s.mu.Unlock() data, err := os.ReadFile(s.path(id)) if err != nil { if os.IsNotExist(err) { return nil, errors.New("对话不存在") } return nil, fmt.Errorf("读取对话失败: %w", err) } var conv Conversation if err := json.Unmarshal(data, &conv); err != nil { return nil, fmt.Errorf("解析对话失败: %w", err) } return &conv, nil } func (s *ConvStore) List() ([]Conversation, error) { s.mu.Lock() defer s.mu.Unlock() entries, err := os.ReadDir(s.dir) if err != nil { return nil, fmt.Errorf("读取对话目录失败: %w", err) } var list []Conversation for _, e := range entries { if e.IsDir() || filepath.Ext(e.Name()) != ".json" { continue } data, err := os.ReadFile(filepath.Join(s.dir, e.Name())) if err != nil { continue } var conv Conversation if err := json.Unmarshal(data, &conv); err != nil { continue } conv.Messages = nil // 列表不返回消息体 list = append(list, conv) } sort.Slice(list, func(i, j int) bool { return list[i].UpdatedAt.After(list[j].UpdatedAt) }) return list, nil } func (s *ConvStore) Delete(id string) error { s.mu.Lock() defer s.mu.Unlock() if err := os.Remove(s.path(id)); err != nil && !os.IsNotExist(err) { return fmt.Errorf("删除对话失败: %w", err) } return nil } func atomicWriteJSON(path string, v any) error { tmp := path + ".tmp" data, err := json.Marshal(v) if err != nil { return err } if err := os.WriteFile(tmp, data, 0644); err != nil { return err } return os.Rename(tmp, path) } func saveConversationMessages(id string, messages []ChatMessage, assistantContent string) error { conv, err := store.Get(id) if err != nil { return err } conv.Messages = append([]ChatMessage(nil), messages...) conv.Messages = append(conv.Messages, ChatMessage{Role: "assistant", Content: assistantContent}) if conv.Title == "" || conv.Title == "新对话" { conv.Title = genConvTitle(conv.Messages) } return store.Save(conv) } func genConvTitle(messages []ChatMessage) string { for _, m := range messages { if m.Hidden { continue } if m.Role == "user" && strings.TrimSpace(m.Content) != "" { title := strings.TrimSpace(m.Content) title = strings.ReplaceAll(title, "\r\n", " ") title = strings.ReplaceAll(title, "\n", " ") runes := []rune(title) if len(runes) > 30 { return string(runes[:30]) + "..." } return title } } return "新对话" } const maxImageSize = 4 * 1024 * 1024 var allowedImageTypes = map[string]bool{ "image/jpeg": true, "image/png": true, "image/webp": true, "image/gif": true, } func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) { messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages)) for _, m := range chatMessages { msg, err := buildArkMessage(m) if err != nil { return nil, err } messages = append(messages, msg) } return messages, nil } func hasImageMessage(messages []ChatMessage) bool { for _, msg := range messages { if strings.TrimSpace(msg.ImageURL) != "" || strings.TrimSpace(msg.ImageURLAlias) != "" { return true } } return false } func buildToolDecisionMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) { messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages)) for _, m := range chatMessages { content := m.Content if strings.TrimSpace(m.ImageURL) != "" || strings.TrimSpace(m.ImageURLAlias) != "" { content = strings.TrimSpace(content) placeholder := "[用户上传了一张图片。工具判断阶段不读取图片内容;如果问题主要依赖识图,应不要调用工具,交给最终多模态模型回答。]" if content == "" { content = placeholder } else { content += "\n\n" + placeholder } } messages = append(messages, &model.ChatCompletionMessage{Role: m.Role, Content: stringContent(content)}) } return messages, nil } func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) { msg := &model.ChatCompletionMessage{Role: m.Role} if m.ImageURL == "" && m.ImageURLAlias != "" { m.ImageURL = m.ImageURLAlias } if m.ImageURL == "" { msg.Content = &model.ChatCompletionMessageContent{ StringValue: &m.Content, } return msg, nil } imageURL, err := normalizeImageURL(m.ImageURL) if err != nil { return nil, err } // 有图片时:文字内容可有可无(图片 caption 场景),均构造多模态消息 // 若无文字,则只传图片 part;若同时有图片和文字,先文后图 parts := make([]*model.ChatCompletionMessageContentPart, 0, 2) if m.Content != "" { parts = append(parts, textPart(m.Content)) } parts = append(parts, imagePart(imageURL)) msg.Content = &model.ChatCompletionMessageContent{ListValue: parts} return msg, nil } func imagePart(url string) *model.ChatCompletionMessageContentPart { return &model.ChatCompletionMessageContentPart{ Type: model.ChatCompletionMessageContentPartTypeImageURL, ImageURL: &model.ChatMessageImageURL{ URL: url, Detail: model.ImageURLDetailAuto, }, } } func textPart(text string) *model.ChatCompletionMessageContentPart { return &model.ChatCompletionMessageContentPart{ Type: model.ChatCompletionMessageContentPartTypeText, Text: text, } } func stringContent(text string) *model.ChatCompletionMessageContent { return &model.ChatCompletionMessageContent{StringValue: &text} } func chatMessageContentString(content *model.ChatCompletionMessageContent) string { if content == nil || content.StringValue == nil { return "" } return *content.StringValue } func normalizeImageURL(raw string) (string, error) { raw = strings.TrimSpace(raw) if raw == "" { return "", errors.New("图片地址不能为空") } lower := strings.ToLower(raw) if strings.HasPrefix(lower, "data:") { return normalizeImageDataURI(raw) } u, err := url.Parse(raw) if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") { return "", errors.New("图片地址无效,仅支持 http/https URL 或 base64 data URI") } return raw, nil } func normalizeImageDataURI(raw string) (string, error) { comma := strings.Index(raw, ",") if comma < 0 { return "", errors.New("图片 base64 数据格式错误") } meta := strings.ToLower(strings.TrimSpace(raw[5:comma])) payload := strings.TrimSpace(raw[comma+1:]) if payload == "" { return "", errors.New("图片 base64 数据不能为空") } parts := strings.Split(meta, ";") if len(parts) < 2 || !contains(parts[1:], "base64") { return "", errors.New("图片 data URI 必须使用 base64 编码") } mime := parts[0] if !allowedImageTypes[mime] { return "", errors.New("图片格式不支持,仅支持 jpeg/png/webp/gif") } decoded, err := base64.StdEncoding.DecodeString(payload) if err != nil { return "", errors.New("图片 base64 数据无效") } if len(decoded) > maxImageSize { return "", errors.New("图片过大,请选择小于 4MB 的图片") } return "data:" + mime + ";base64," + payload, nil } func contains(items []string, target string) bool { for _, item := range items { if strings.TrimSpace(item) == target { return true } } return false } func intPtr(i int) *int { return &i } func boolPtr(v bool) *bool { return &v } func truncateString(text string, maxRunes int) string { runes := []rune(strings.TrimSpace(text)) if maxRunes <= 0 || len(runes) <= maxRunes { return string(runes) } return string(runes[:maxRunes]) + "..." } func writeSSEJSON(w io.Writer, frame chatSSEFrame) { data, err := json.Marshal(frame) if err != nil { data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"}) } fmt.Fprintf(w, "data: %s\n\n", data) } func toJSON(s string) string { b, _ := json.Marshal(s) return string(b) } func toSSE(s string) string { s = strings.ReplaceAll(s, `\`, `\\`) s = strings.ReplaceAll(s, "\n", `\n`) s = strings.ReplaceAll(s, "\r", "") s = strings.ReplaceAll(s, `"`, `\"`) return fmt.Sprintf(`"%s"`, s) } // ─── 入口 ───────────────────────────────────────────────── func main() { var err error cfg, err = loadConfig("config.yaml") if err != nil { fmt.Fprintln(os.Stderr, "配置加载失败:", err) os.Exit(1) } // 初始化火山方舟 SDK 客户端 aiState, err = NewOpenAIState(cfg.OpenAI) if err != nil { fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err) os.Exit(1) } searchConfig, err := searchagent.LoadConfig("agents/search/config.yaml", legacySearchProfiles) if err != nil { fmt.Fprintln(os.Stderr, "联网搜索配置加载失败:", err) os.Exit(1) } searchState, err = searchagent.NewState(searchConfig) if err != nil { fmt.Fprintln(os.Stderr, "联网搜索初始化失败:", err) os.Exit(1) } sqlConfig, err := sqlquery.LoadConfig("agents/sql/config.yaml") if err != nil { fmt.Fprintln(os.Stderr, "SQL 查询插件配置加载失败:", err) os.Exit(1) } sqlState, err = sqlquery.NewState(sqlConfig) if err != nil { fmt.Fprintln(os.Stderr, "SQL 查询插件初始化失败:", err) os.Exit(1) } defer sqlState.Close() toolRouterState, err = NewToolRouterState(&cfg.ToolRouter, aiState) if err != nil { fmt.Fprintln(os.Stderr, "工具路由配置初始化失败:", err) os.Exit(1) } store = NewConvStore("conversations") // Gin 路由 r := gin.Default() r.LoadHTMLGlob("templates/*") r.Static("/static", "./static") r.GET("/", indexHandler) r.POST("/api/chat", chatHandler) r.GET("/api/openai", listOpenAIHandler) r.POST("/api/openai/active", switchOpenAIHandler) r.GET("/api/search", listSearchHandler) r.POST("/api/search/active", switchSearchHandler) r.GET("/api/conversations", listConversationsHandler) r.POST("/api/conversations", createConversationHandler) r.GET("/api/conversations/:id", getConversationHandler) r.DELETE("/api/conversations/:id", deleteConversationHandler) // 根据配置选择监听方式 switch strings.ToLower(cfg.Server.Mode) { case "unix": socketPath := cfg.Server.Address if _, statErr := os.Stat(socketPath); statErr == nil { os.Remove(socketPath) } ln, listenErr := net.Listen("unix", socketPath) if listenErr != nil { fmt.Fprintln(os.Stderr, "监听 Unix socket 失败:", listenErr) os.Exit(1) } fmt.Println("服务已启动,监听 Unix socket:", socketPath) if serveErr := http.Serve(ln, r); serveErr != nil { fmt.Fprintln(os.Stderr, "服务异常退出:", serveErr) os.Exit(1) } default: fmt.Println("服务已启动,监听 TCP:", cfg.Server.Address) if runErr := r.Run(cfg.Server.Address); runErr != nil { fmt.Fprintln(os.Stderr, "服务异常退出:", runErr) os.Exit(1) } } }