package main import ( "context" "crypto/rand" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net" "net/http" "net/url" "os" "path/filepath" "sort" "strings" "sync" "time" sqlquery "aichat/agents/SQL_query" "github.com/gin-gonic/gin" ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" "gopkg.in/yaml.v3" ) // ─── 配置 ───────────────────────────────────────────────── const ( defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3" defaultOpenAITimeout = 120 defaultSearchBaseURL = "https://api.duckduckgo.com/" defaultSearchTimeout = 10 defaultSearchCount = 5 ) type OpenAIConfig struct { Name string `yaml:"name" json:"name"` Active bool `yaml:"active,omitempty" json:"active"` APIKey string `yaml:"api_key" json:"-"` BaseURL string `yaml:"base_url" json:"base_url"` Model string `yaml:"model" json:"model"` Timeout int `yaml:"timeout" json:"timeout"` } type OpenAIConfigs []OpenAIConfig type SearchConfig struct { Name string `yaml:"name" json:"name"` Active bool `yaml:"active,omitempty" json:"active"` Enabled bool `yaml:"enabled" json:"enabled"` Provider string `yaml:"provider" json:"provider"` APIKey string `yaml:"api_key" json:"-"` BaseURL string `yaml:"base_url" json:"base_url"` Count int `yaml:"count" json:"count"` Timeout int `yaml:"timeout" json:"timeout"` } type SearchConfigs []SearchConfig func (configs *SearchConfigs) UnmarshalYAML(value *yaml.Node) error { switch value.Kind { case yaml.SequenceNode: var items []SearchConfig if err := value.Decode(&items); err != nil { return err } *configs = items case yaml.MappingNode: var item SearchConfig if err := value.Decode(&item); err != nil { return err } *configs = []SearchConfig{item} case yaml.ScalarNode: if value.Tag == "!!null" { *configs = nil return nil } return fmt.Errorf("search 配置格式无效") default: return fmt.Errorf("search 配置格式无效") } return nil } func (configs *OpenAIConfigs) UnmarshalYAML(value 
*yaml.Node) error { switch value.Kind { case yaml.SequenceNode: var items []OpenAIConfig if err := value.Decode(&items); err != nil { return err } *configs = items case yaml.MappingNode: var item OpenAIConfig if err := value.Decode(&item); err != nil { return err } *configs = []OpenAIConfig{item} case yaml.ScalarNode: if value.Tag == "!!null" { *configs = nil return nil } return fmt.Errorf("openai 配置格式无效") default: return fmt.Errorf("openai 配置格式无效") } return nil } type Config struct { Server struct { Mode string `yaml:"mode"` Address string `yaml:"address"` } `yaml:"server"` OpenAI OpenAIConfigs `yaml:"openai"` Search SearchConfigs `yaml:"search"` } func defaultOpenAIConfig() OpenAIConfig { return OpenAIConfig{ Name: "default", Active: true, BaseURL: defaultOpenAIBaseURL, Timeout: defaultOpenAITimeout, } } func defaultSearchConfig() SearchConfig { return SearchConfig{ Name: "duckduckgo", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: defaultSearchBaseURL, Count: defaultSearchCount, Timeout: defaultSearchTimeout, } } func defaultConfig() Config { var cfg Config cfg.Server.Mode = "tcp" cfg.Server.Address = "0.0.0.0:8080" cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} cfg.Search = SearchConfigs{defaultSearchConfig()} return cfg } func loadConfig(path string) (*Config, error) { if err := ensureConfigFile(path); err != nil { return nil, err } data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("读取配置文件失败: %w", err) } var cfg Config if err = yaml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("解析配置文件失败: %w", err) } if _, err := normalizeOpenAIConfigs(&cfg); err != nil { return nil, err } // 环境变量优先 if key := os.Getenv("ARK_API_KEY"); key != "" { for i := range cfg.OpenAI { cfg.OpenAI[i].APIKey = key } } if _, err := normalizeSearchConfigs(&cfg); err != nil { return nil, err } if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" { for i := range cfg.Search { if strings.ToLower(cfg.Search[i].Provider) == "brave" { 
cfg.Search[i].APIKey = key } } } return &cfg, nil } func ensureConfigFile(path string) error { defaults := defaultConfig() if _, err := os.Stat(path); err != nil { if !os.IsNotExist(err) { return fmt.Errorf("检查配置文件失败: %w", err) } return writeConfig(path, defaults) } data, err := os.ReadFile(path) if err != nil { return fmt.Errorf("读取配置文件失败: %w", err) } var cfg Config if err = yaml.Unmarshal(data, &cfg); err != nil { return fmt.Errorf("解析配置文件失败: %w", err) } var raw map[string]any if err = yaml.Unmarshal(data, &raw); err != nil { return fmt.Errorf("解析配置文件失败: %w", err) } changed := false server, _ := raw["server"].(map[string]any) if server == nil { cfg.Server = defaults.Server changed = true } else { if _, ok := server["mode"]; !ok { cfg.Server.Mode = defaults.Server.Mode changed = true } if _, ok := server["address"]; !ok { cfg.Server.Address = defaults.Server.Address changed = true } } if _, ok := raw["openai"].([]any); !ok { changed = true } if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil { return err } else if normalized { changed = true } if _, ok := raw["search"].([]any); !ok { changed = true } if normalized, err := normalizeSearchConfigs(&cfg); err != nil { return err } else if normalized { changed = true } if !changed { return nil } return writeConfig(path, cfg) } func normalizeOpenAIConfigs(cfg *Config) (bool, error) { changed := false if len(cfg.OpenAI) == 0 { cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} changed = true } activeIndex := -1 seen := map[string]bool{} for i := range cfg.OpenAI { profile := &cfg.OpenAI[i] name := strings.TrimSpace(profile.Name) if name == "" { name = strings.TrimSpace(profile.Model) if name == "" { name = fmt.Sprintf("openai-%d", i+1) } profile.Name = name changed = true } else if name != profile.Name { profile.Name = name changed = true } if seen[name] { return changed, fmt.Errorf("openai 配置名称重复: %s", name) } seen[name] = true if strings.TrimSpace(profile.BaseURL) == "" { profile.BaseURL = 
defaultOpenAIBaseURL changed = true } if profile.Timeout <= 0 { profile.Timeout = defaultOpenAITimeout changed = true } if profile.Active { if activeIndex == -1 { activeIndex = i } else { profile.Active = false changed = true } } } if activeIndex == -1 { cfg.OpenAI[0].Active = true changed = true } return changed, nil } func normalizeSearchConfigs(cfg *Config) (bool, error) { changed := false if len(cfg.Search) == 0 { cfg.Search = SearchConfigs{defaultSearchConfig()} changed = true } activeIndex := -1 seen := map[string]bool{} for i := range cfg.Search { profile := &cfg.Search[i] profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider)) if profile.Provider == "" { profile.Provider = "duckduckgo" changed = true } name := strings.TrimSpace(profile.Name) if name == "" { name = profile.Provider if seen[name] { name = fmt.Sprintf("%s-%d", profile.Provider, i+1) } profile.Name = name changed = true } else if name != profile.Name { profile.Name = name changed = true } if seen[name] { return changed, fmt.Errorf("search 配置名称重复: %s", name) } seen[name] = true if strings.TrimSpace(profile.BaseURL) == "" { switch profile.Provider { case "duckduckgo", "ddg": profile.BaseURL = defaultSearchBaseURL case "brave": profile.BaseURL = "https://api.search.brave.com/res/v1/web/search" default: return changed, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider) } changed = true } if profile.Count <= 0 { profile.Count = defaultSearchCount changed = true } if profile.Timeout <= 0 { profile.Timeout = defaultSearchTimeout changed = true } if profile.Active { if activeIndex == -1 { activeIndex = i } else { profile.Active = false changed = true } } } if activeIndex == -1 { cfg.Search[0].Active = true changed = true } return changed, nil } func writeConfig(path string, cfg Config) error { data, err := yaml.Marshal(&cfg) if err != nil { return fmt.Errorf("生成配置文件失败: %w", err) } if err := os.WriteFile(path, data, 0644); err != nil { return fmt.Errorf("写入配置文件失败: %w", err) } return nil } // 
─── 请求结构 ───────────────────────────────────────────── type ChatMessage struct { Role string `json:"role"` Content string `json:"content"` ImageURL string `json:"image_url,omitempty"` // base64 data URI 或 http URL ImageURLAlias string `json:"imageURL,omitempty"` Hidden bool `json:"hidden,omitempty"` } type ChatRequest struct { ConversationID string `json:"conversation_id,omitempty"` Messages []ChatMessage `json:"messages"` WebSearch bool `json:"web_search,omitempty"` OpenAIName string `json:"openai_name,omitempty"` } type Conversation struct { ID string `json:"id"` Title string `json:"title"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` Messages []ChatMessage `json:"messages,omitempty"` } type ConvStore struct { dir string mu sync.Mutex } type OpenAIProfile struct { Config OpenAIConfig Client *ark.Client } type OpenAIState struct { mu sync.RWMutex profiles map[string]*OpenAIProfile order []string activeName string } type activeProfileRequest struct { Name string `json:"name"` } type openAIListResponse struct { Active string `json:"active"` Profiles []OpenAIConfig `json:"profiles"` } type searchListResponse struct { Active string `json:"active"` Profiles []SearchConfig `json:"profiles"` } func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) { state := &OpenAIState{ profiles: make(map[string]*OpenAIProfile, len(configs)), order: make([]string, 0, len(configs)), } for _, config := range configs { if strings.TrimSpace(config.Name) == "" { return nil, errors.New("openai.name 不能为空") } if strings.TrimSpace(config.APIKey) == "" { return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", config.Name) } if strings.TrimSpace(config.Model) == "" { return nil, fmt.Errorf("openai.%s.model 未配置", config.Name) } if strings.TrimSpace(config.BaseURL) == "" { return nil, fmt.Errorf("openai.%s.base_url 未配置", config.Name) } if config.Timeout <= 0 { return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", config.Name) } if _, ok := 
state.profiles[config.Name]; ok { return nil, fmt.Errorf("openai 配置名称重复: %s", config.Name) } state.profiles[config.Name] = &OpenAIProfile{ Config: config, Client: ark.NewClientWithApiKey( config.APIKey, ark.WithBaseUrl(config.BaseURL), ark.WithTimeout(time.Duration(config.Timeout)*time.Second), ), } state.order = append(state.order, config.Name) if config.Active && state.activeName == "" { state.activeName = config.Name } } if len(state.order) == 0 { return nil, errors.New("openai 配置不能为空") } if state.activeName == "" { state.activeName = state.order[0] } return state, nil } func (s *OpenAIState) ActiveProfile() *OpenAIProfile { s.mu.RLock() defer s.mu.RUnlock() return s.profiles[s.activeName] } func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) { s.mu.RLock() defer s.mu.RUnlock() if strings.TrimSpace(name) == "" { return s.profiles[s.activeName], nil } profile, ok := s.profiles[strings.TrimSpace(name)] if !ok { return nil, fmt.Errorf("OpenAI 配置不存在: %s", name) } return profile, nil } func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) { name = strings.TrimSpace(name) if name == "" { return nil, errors.New("OpenAI 配置名称不能为空") } s.mu.Lock() defer s.mu.Unlock() profile, ok := s.profiles[name] if !ok { return nil, fmt.Errorf("OpenAI 配置不存在: %s", name) } s.activeName = name return profile, nil } func (s *OpenAIState) ListProfiles() openAIListResponse { s.mu.RLock() defer s.mu.RUnlock() profiles := make([]OpenAIConfig, 0, len(s.order)) for _, name := range s.order { profile := s.profiles[name] config := profile.Config config.APIKey = "" config.Active = name == s.activeName profiles = append(profiles, config) } return openAIListResponse{Active: s.activeName, Profiles: profiles} } func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig { config := profile.Config config.APIKey = "" config.Active = active return config } type SearchState struct { mu sync.RWMutex profiles map[string]SearchConfig order []string activeName 
string } func NewSearchState(configs []SearchConfig) (*SearchState, error) { state := &SearchState{ profiles: make(map[string]SearchConfig, len(configs)), order: make([]string, 0, len(configs)), } for _, config := range configs { if strings.TrimSpace(config.Name) == "" { return nil, errors.New("search.name 不能为空") } if strings.TrimSpace(config.Provider) == "" { return nil, fmt.Errorf("search.%s.provider 未配置", config.Name) } if strings.TrimSpace(config.BaseURL) == "" { return nil, fmt.Errorf("search.%s.base_url 未配置", config.Name) } if config.Timeout <= 0 { return nil, fmt.Errorf("search.%s.timeout 必须大于 0", config.Name) } if _, ok := state.profiles[config.Name]; ok { return nil, fmt.Errorf("search 配置名称重复: %s", config.Name) } state.profiles[config.Name] = config state.order = append(state.order, config.Name) if config.Active && state.activeName == "" { state.activeName = config.Name } } if len(state.order) == 0 { return nil, errors.New("search 配置不能为空") } if state.activeName == "" { state.activeName = state.order[0] } return state, nil } func (s *SearchState) ActiveProfile() SearchConfig { s.mu.RLock() defer s.mu.RUnlock() return s.profiles[s.activeName] } func (s *SearchState) SwitchActive(name string) (SearchConfig, error) { name = strings.TrimSpace(name) if name == "" { return SearchConfig{}, errors.New("搜索配置名称不能为空") } s.mu.Lock() defer s.mu.Unlock() profile, ok := s.profiles[name] if !ok { return SearchConfig{}, fmt.Errorf("搜索配置不存在: %s", name) } s.activeName = name return profile, nil } func (s *SearchState) ListProfiles() searchListResponse { s.mu.RLock() defer s.mu.RUnlock() profiles := make([]SearchConfig, 0, len(s.order)) for _, name := range s.order { profile := s.profiles[name] profile.APIKey = "" profile.Active = name == s.activeName profiles = append(profiles, profile) } return searchListResponse{Active: s.activeName, Profiles: profiles} } func publicSearchConfig(config SearchConfig, active bool) SearchConfig { config.APIKey = "" config.Active = active 
return config } // ─── 全局变量 ───────────────────────────────────────────── var ( cfg *Config aiState *OpenAIState searchState *SearchState sqlState *sqlquery.State store *ConvStore ) type chatSSEFrame struct { Type string `json:"type"` Text string `json:"text,omitempty"` Message string `json:"message,omitempty"` Tool string `json:"tool,omitempty"` Stage string `json:"stage,omitempty"` Status string `json:"status,omitempty"` Data map[string]any `json:"data,omitempty"` Error string `json:"error,omitempty"` } // ─── 路由 ───────────────────────────────────────────────── func indexHandler(c *gin.Context) { profile := aiState.ActiveProfile() c.HTML(http.StatusOK, "chat.html", gin.H{ "Title": "AI 对话", "Model": profile.Config.Model, "OpenAIName": profile.Config.Name, }) } func listOpenAIHandler(c *gin.Context) { c.JSON(http.StatusOK, aiState.ListProfiles()) } func switchOpenAIHandler(c *gin.Context) { var req activeProfileRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) return } profile, err := aiState.SwitchActive(req.Name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "active": profile.Config.Name, "profile": publicOpenAIConfig(profile, true), }) } func listSearchHandler(c *gin.Context) { c.JSON(http.StatusOK, searchState.ListProfiles()) } func switchSearchHandler(c *gin.Context) { var req activeProfileRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) return } profile, err := searchState.SwitchActive(req.Name) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, gin.H{ "active": profile.Name, "profile": publicSearchConfig(profile, true), }) } func listConversationsHandler(c *gin.Context) { convs, err := store.List() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) 
return } c.JSON(http.StatusOK, convs) } func createConversationHandler(c *gin.Context) { conv, err := store.Create() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "创建对话失败: " + err.Error()}) return } c.JSON(http.StatusOK, conv) } func getConversationHandler(c *gin.Context) { conv, err := store.Get(c.Param("id")) if err != nil { status := http.StatusInternalServerError if err.Error() == "对话不存在" { status = http.StatusNotFound } c.JSON(status, gin.H{"error": err.Error()}) return } c.JSON(http.StatusOK, conv) } func deleteConversationHandler(c *gin.Context) { if err := store.Delete(c.Param("id")); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } c.Status(http.StatusNoContent) } // chatHandler 流式 SSE 对话接口 func chatHandler(c *gin.Context) { var req ChatRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) return } if len(req.Messages) == 0 { c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"}) return } profile, err := aiState.GetProfile(req.OpenAIName) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } // SSE 头先写出,后续插件/模型过程都通过 trace 事件实时展示。 c.Writer.Header().Set("Content-Type", "text/event-stream") c.Writer.Header().Set("Cache-Control", "no-cache") c.Writer.Header().Set("Connection", "keep-alive") c.Writer.Header().Set("X-Accel-Buffering", "no") c.Writer.WriteHeader(http.StatusOK) flusher, ok := c.Writer.(http.Flusher) if !ok { return } emit := func(frame chatSSEFrame) { writeSSEJSON(c.Writer, frame) flusher.Flush() } emitTrace := func(tool, stage, status, message string, data map[string]any) { emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data}) } emitError := func(err error) { emit(chatSSEFrame{Type: "error", Error: err.Error()}) } // 超时 context timeout := time.Duration(profile.Config.Timeout) * time.Second ctx, cancel := 
context.WithTimeout(c.Request.Context(), timeout) defer cancel() chatMessages := req.Messages if sqlState != nil && sqlState.Enabled() { withSQL, err := enrichMessagesWithSQL(ctx, profile, chatMessages, emit) if err != nil { fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err) emitTrace("sql", "error", "error", "数据库查询插件调用失败,将继续普通回答", map[string]any{"error": err.Error()}) } else { chatMessages = withSQL } } if req.WebSearch { withSearch, err := enrichMessagesWithSearch(ctx, chatMessages, emit) if err != nil { emitError(err) return } chatMessages = withSearch } // 构建 ark 消息列表 messages, err := buildArkMessages(chatMessages) if err != nil { emitError(err) return } emitTrace("model", "request", "running", "正在调用模型生成回答", nil) stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ Model: profile.Config.Model, Messages: messages, MaxTokens: intPtr(4096), }.WithStream(true)) if err != nil { emitError(err) return } defer stream.Close() emitTrace("model", "stream", "running", "模型已开始输出", nil) var full strings.Builder for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { if req.ConversationID != "" { if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil { fmt.Fprintln(os.Stderr, "保存对话失败:", err) } } emitTrace("model", "stream", "success", "回答生成完成", nil) fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() return } if err != nil { emitError(err) return } if len(resp.Choices) > 0 { delta := resp.Choices[0].Delta.Content if delta != "" { full.WriteString(delta) emit(chatSSEFrame{Type: "delta", Text: delta}) } } } } // ─── 辅助函数 ───────────────────────────────────────────── type searchResult struct { Title string `json:"title"` URL string `json:"url"` Description string `json:"description"` } type braveSearchResponse struct { Web struct { Results []searchResult `json:"results"` } `json:"web"` } type duckDuckGoResponse struct { Abstract string `json:"Abstract"` AbstractSource string 
`json:"AbstractSource"` AbstractURL string `json:"AbstractURL"` Heading string `json:"Heading"` RelatedTopics []struct { Text string `json:"Text"` FirstURL string `json:"FirstURL"` Topics []struct { Text string `json:"Text"` FirstURL string `json:"FirstURL"` } `json:"Topics"` } `json:"RelatedTopics"` Infobox struct { Content []struct { Label string `json:"label"` Value string `json:"value"` } `json:"content"` } `json:"Infobox"` } func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { searchConfig := searchState.ActiveProfile() if !searchConfig.Enabled { return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled") } query := latestUserQuery(messages) if query == "" { return nil, errors.New("联网搜索需要输入文本问题") } emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": searchConfig.Provider}}) results, err := webSearch(ctx, searchConfig, query) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}}) return nil, err } if len(results) == 0 { err := errors.New("未搜索到相关网页结果") emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()}) return nil, err } emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": searchConfig.Provider, "count": len(results)}}) searchContext := buildSearchContext(searchConfig, query, results) withSearch := make([]ChatMessage, 0, len(messages)+1) withSearch = append(withSearch, ChatMessage{Role: "system", Content: searchContext, Hidden: true}) withSearch = append(withSearch, messages...) 
return withSearch, nil } func latestUserQuery(messages []ChatMessage) string { for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { return strings.TrimSpace(messages[i].Content) } } return "" } type sqlActivationDecision struct { Activate bool `json:"activate"` Reason string `json:"reason"` } type sqlGenerationResult struct { Database string `json:"database"` SQL string `json:"sql"` Reason string `json:"reason"` } func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { query := latestUserQuery(messages) if query == "" { return messages, nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "running", Message: "正在判断是否需要查询数据库"}) activate, reason, err := classifySQLActivation(ctx, profile, messages) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "error", Message: "数据库查询判断失败", Data: map[string]any{"error": err.Error()}}) return messages, err } if !activate { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:本轮无需查询数据库", Data: map[string]any{"activate": false, "reason": reason}}) return messages, nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:需要查询数据库", Data: map[string]any{"activate": true, "reason": reason}}) emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"}) schemaContext, err := sqlState.SchemaContext(ctx) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"}) emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", 
Message: "正在生成只读 SQL"}) generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } generated.Database = strings.TrimSpace(generated.Database) generated.SQL = strings.TrimSpace(generated.SQL) if generated.SQL == "" { err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason) emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}}) result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}}) return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}}) emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}}) contextText := sqlquery.BuildResultContext(query, generated.SQL, result) if strings.TrimSpace(reason) != "" { contextText += "\n激活原因:" + reason } return prependSQLContext(messages, contextText), nil } func 
prependSQLContext(messages []ChatMessage, content string) []ChatMessage { withSQL := make([]ChatMessage, 0, len(messages)+1) withSQL = append(withSQL, ChatMessage{Role: "system", Content: content, Hidden: true}) withSQL = append(withSQL, messages...) return withSQL } func classifySQLActivation(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) (bool, string, error) { query := latestUserQuery(messages) prompt := fmt.Sprintf("%s\n\n最新用户问题:%s", sqlState.ActivationPrompt(), query) text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 512) if err != nil { return false, "", err } var decision sqlActivationDecision if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil { return false, "", fmt.Errorf("解析 SQL 查询激活结果失败: %w", err) } return decision.Activate, decision.Reason, nil } func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) { prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题和数据库 schema 生成一条只读 SQL。 要求: - 只能返回 JSON,不要使用 Markdown。 - JSON 格式:{"database":"数据库名称","sql":"SELECT ... 
LIMIT N","reason":"生成原因"} - 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。 - 必须只使用 schema 中出现的数据库、表和字段。 - 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。 - 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。 %s 用户问题:%s`, schemaContext, userQuery) text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024) if err != nil { return nil, err } var generated sqlGenerationResult if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil { return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err) } return &generated, nil } func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) { messages, err := buildArkMessages(chatMessages) if err != nil { return "", err } timeout := time.Duration(profile.Config.Timeout) * time.Second completionCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() stream, err := profile.Client.CreateChatCompletionStream(completionCtx, model.CreateChatCompletionRequest{ Model: profile.Config.Model, Messages: messages, MaxTokens: intPtr(maxTokens), }.WithStream(true)) if err != nil { return "", err } defer stream.Close() var b strings.Builder for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { return b.String(), nil } if err != nil { return "", err } if len(resp.Choices) > 0 { b.WriteString(resp.Choices[0].Delta.Content) } } } func extractJSONObject(text string) string { text = strings.TrimSpace(text) start := strings.Index(text, "{") end := strings.LastIndex(text, "}") if start >= 0 && end > start { return text[start : end+1] } return text } func webSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) { switch strings.ToLower(config.Provider) { case "duckduckgo", "ddg": return duckDuckGoSearch(ctx, config, query) case "brave": return braveWebSearch(ctx, config, query) default: return nil, fmt.Errorf("暂不支持搜索服务: %s", config.Provider) } } func 
duckDuckGoSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) { searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second) defer cancel() u, err := url.Parse(config.BaseURL) if err != nil { return nil, fmt.Errorf("搜索服务地址无效: %w", err) } q := u.Query() q.Set("q", query) q.Set("format", "json") q.Set("no_html", "1") q.Set("skip_disambig", "1") u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil) if err != nil { return nil, fmt.Errorf("创建搜索请求失败: %w", err) } req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", "aichat/1.0") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, fmt.Errorf("联网搜索失败: %w", err) } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) if err != nil { return nil, fmt.Errorf("读取搜索响应失败: %w", err) } if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) } var parsed duckDuckGoResponse if err := json.Unmarshal(body, &parsed); err != nil { return nil, fmt.Errorf("解析搜索响应失败: %w", err) } limit := config.Count if limit <= 0 { limit = defaultSearchCount } results := make([]searchResult, 0, limit) if strings.TrimSpace(parsed.Abstract) != "" { title := strings.TrimSpace(parsed.Heading) if title == "" { title = strings.TrimSpace(parsed.AbstractSource) } if title == "" { title = "DuckDuckGo 摘要" } results = append(results, searchResult{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)}) } for _, item := range parsed.Infobox.Content { if len(results) >= limit { break } if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" { continue } results = append(results, searchResult{Title: item.Label, Description: item.Value}) } appendRelated := func(text, firstURL string) { if len(results) >= limit || 
strings.TrimSpace(text) == "" {
			return
		}
		title, desc := splitDuckDuckGoText(text)
		results = append(results, searchResult{Title: title, URL: strings.TrimSpace(firstURL), Description: desc})
	}
	for _, topic := range parsed.RelatedTopics {
		appendRelated(topic.Text, topic.FirstURL)
		for _, nested := range topic.Topics {
			appendRelated(nested.Text, nested.FirstURL)
		}
		if len(results) >= limit {
			break
		}
	}
	return results, nil
}

// splitDuckDuckGoText splits a DuckDuckGo related-topic text of the form
// "Title - Description" into its trimmed halves. Without a " - " separator
// the title is the text cut to 42 runes (plus "...") when it is longer than
// that, and the description is the whole trimmed text.
func splitDuckDuckGoText(text string) (string, string) {
	text = strings.TrimSpace(text)
	if title, desc, found := strings.Cut(text, " - "); found {
		return strings.TrimSpace(title), strings.TrimSpace(desc)
	}
	runes := []rune(text)
	if len(runes) > 42 {
		return string(runes[:42]) + "...", text
	}
	return text, text
}

// braveWebSearch queries the Brave web search endpoint at config.BaseURL for
// query and returns the web results of the response.
//
// The request sends count=config.Count, search_lang=zh-hans and country=CN,
// authenticates with the X-Subscription-Token header, and is bounded by
// config.Timeout seconds on top of ctx. At most 2 MiB of the body is read.
// A non-2xx status is returned as an error that includes the response body.
func braveWebSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	if config.APIKey == "" {
		return nil, errors.New("Brave 搜索未配置 API Key,请设置 search.api_key 或环境变量 BRAVE_SEARCH_API_KEY")
	}
	searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
	defer cancel()

	u, err := url.Parse(config.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("搜索服务地址无效: %w", err)
	}
	q := u.Query()
	q.Set("q", query)
	q.Set("count", fmt.Sprintf("%d", config.Count))
	q.Set("search_lang", "zh-hans")
	q.Set("country", "CN")
	u.RawQuery = q.Encode()

	req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("创建搜索请求失败: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("X-Subscription-Token", config.APIKey)

	// DefaultClient has no timeout of its own; searchCtx bounds the request.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("联网搜索失败: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
	if err != nil {
		return nil, fmt.Errorf("读取搜索响应失败: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	var parsed braveSearchResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("解析搜索响应失败: %w", err)
	}
	return parsed.Web.Results, nil
}

// buildSearchContext renders search results into the (Chinese) context text
// given to the model. The text contains the source name/provider, a caveat
// for DuckDuckGo/ddg (its instant-answer API is not a full web search), the
// current local time, the query and a numbered title/link/summary list.
// Blank links and summaries are omitted.
func buildSearchContext(config SearchConfig, query string, results []searchResult) string {
	var b strings.Builder
	fmt.Fprintf(&b, "用户开启了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider)
	if provider := strings.ToLower(config.Provider); provider == "duckduckgo" || provider == "ddg" {
		fmt.Fprintln(&b, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。")
	}
	fmt.Fprintf(&b, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&b, "搜索词: %s\n\n", query)
	fmt.Fprintln(&b, "搜索结果:")
	for i, r := range results {
		fmt.Fprintf(&b, "%d. 标题: %s\n", i+1, strings.TrimSpace(r.Title))
		if link := strings.TrimSpace(r.URL); link != "" {
			fmt.Fprintf(&b, "   链接: %s\n", link)
		}
		if desc := strings.TrimSpace(r.Description); desc != "" {
			fmt.Fprintf(&b, "   摘要: %s\n", desc)
		}
	}
	fmt.Fprintln(&b, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。")
	return b.String()
}

// newUUID returns a random version-4 UUID (RFC 4122 variant bits) in the
// canonical lowercase 8-4-4-4-12 hex form.
func newUUID() string {
	b := make([]byte, 16)
	// crypto/rand is not expected to fail on supported platforms; the error
	// is deliberately ignored.
	_, _ = rand.Read(b)
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
	return hex.EncodeToString(b[:4]) + "-" +
		hex.EncodeToString(b[4:6]) + "-" +
		hex.EncodeToString(b[6:8]) + "-" +
		hex.EncodeToString(b[8:10]) + "-" +
		hex.EncodeToString(b[10:])
}

// ─── ConvStore ─────────────────────────────────────────────

// NewConvStore returns a store that keeps one JSON file per conversation in
// dir, creating the directory when missing. A failure to create it is
// reported on stderr; later Save calls then fail with the file-system error.
func NewConvStore(dir string) *ConvStore {
	if err := os.MkdirAll(dir, 0o755); err != nil {
		fmt.Fprintln(os.Stderr, "创建对话目录失败:", err)
	}
	return &ConvStore{dir: dir}
}

// validConvID reports whether id is safe to use as a conversation file name.
// IDs arrive from request data, so empty IDs and IDs containing path
// separators or NUL (which could address files outside the store directory)
// are rejected.
func validConvID(id string) bool {
	return id != "" && !strings.ContainsAny(id, "/\\\x00")
}

// path returns the JSON file for conversation id. Callers check validConvID
// first so the result stays inside s.dir.
func (s *ConvStore) path(id string) string {
	return filepath.Join(s.dir, id+".json")
}

// Create persists and returns a new empty conversation with a fresh UUID and
// the default title.
func (s *ConvStore) Create() (*Conversation, error) {
	conv := &Conversation{
		ID:        newUUID(),
		Title:     "新对话",
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}
	if err := s.Save(conv); err != nil {
		return nil, err
	}
	return conv, nil
}

// Save stamps conv.UpdatedAt with the current time and writes conv to its
// file under the store lock. An unusable conv.ID is rejected before anything
// is modified.
func (s *ConvStore) Save(conv *Conversation) error {
	if !validConvID(conv.ID) {
		return errors.New("无效的对话 ID")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	conv.UpdatedAt = time.Now()
	return atomicWriteJSON(s.path(conv.ID), conv)
}

// Get loads conversation id. Missing files and unusable IDs both yield the
// "对话不存在" (conversation not found) error.
func (s *ConvStore) Get(id string) (*Conversation, error) {
	if !validConvID(id) {
		return nil, errors.New("对话不存在")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	data, err := os.ReadFile(s.path(id))
	if err != nil {
		if os.IsNotExist(err) {
			return nil, errors.New("对话不存在")
		}
		return nil, fmt.Errorf("读取对话失败: %w", err)
	}
	var conv Conversation
	if err := json.Unmarshal(data, &conv); err != nil {
		return nil, fmt.Errorf("解析对话失败: %w", err)
	}
	return &conv, nil
}

// List returns every stored conversation without message bodies, most
// recently updated first. Unreadable or malformed files are skipped.
func (s *ConvStore) List() ([]Conversation, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entries, err := os.ReadDir(s.dir)
	if err != nil {
		return nil, fmt.Errorf("读取对话目录失败: %w", err)
	}
	var list []Conversation
	for _, e := range entries {
		if e.IsDir() || filepath.Ext(e.Name()) != ".json" {
			continue
		}
		data, err := os.ReadFile(filepath.Join(s.dir, e.Name()))
		if err != nil {
			continue
		}
		var conv Conversation
		if err := json.Unmarshal(data, &conv); err != nil {
			continue
		}
		conv.Messages = nil // the list view does not return message bodies
		list = append(list, conv)
	}
	sort.Slice(list, func(i, j int) bool {
		return list[i].UpdatedAt.After(list[j].UpdatedAt)
	})
	return list, nil
}

// Delete removes conversation id. Deleting a conversation that does not
// exist is not an error; an unusable ID is.
func (s *ConvStore) Delete(id string) error {
	if !validConvID(id) {
		return errors.New("无效的对话 ID")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if err := os.Remove(s.path(id)); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("删除对话失败: %w", err)
	}
	return nil
}

// atomicWriteJSON marshals v and replaces path with it via a sibling ".tmp"
// file and a rename, so readers never observe a half-written file. No fsync
// is performed. The temporary file is removed if either step fails.
func atomicWriteJSON(path string, v any) error {
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}
	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, data, 0o644); err != nil {
		_ = os.Remove(tmp) // best effort: do not leave a partial file behind
		return err
	}
	if err := os.Rename(tmp, path); err != nil {
		_ = os.Remove(tmp)
		return err
	}
	return nil
}

// saveConversationMessages replaces the stored messages of conversation id
// with a copy of messages followed by one assistant message holding
// assistantContent. It also derives a title when the conversation still has
// the default (or an empty) title.
func saveConversationMessages(id string, messages []ChatMessage, assistantContent string) error {
	conv, err := store.Get(id)
	if err != nil {
		return err
	}
	conv.Messages = append([]ChatMessage(nil), messages...)
	conv.Messages = append(conv.Messages, ChatMessage{Role: "assistant", Content: assistantContent})
	if conv.Title == "" || conv.Title == "新对话" {
		conv.Title = genConvTitle(conv.Messages)
	}
	return store.Save(conv)
}

// genConvTitle derives a title from the first visible, non-blank user
// message. Every whitespace run (including \r, \n and \t) collapses to one
// space so the title is a single line, and it is cut to 30 runes plus "...".
// It returns "新对话" (new conversation) when no such message exists.
func genConvTitle(messages []ChatMessage) string {
	for _, m := range messages {
		if m.Hidden || m.Role != "user" {
			continue
		}
		title := strings.Join(strings.Fields(m.Content), " ")
		if title == "" {
			continue
		}
		runes := []rune(title)
		if len(runes) > 30 {
			return string(runes[:30]) + "..."
		}
		return title
	}
	return "新对话"
}

// maxImageSize is the largest accepted decoded image payload, in bytes.
const maxImageSize = 4 * 1024 * 1024

// allowedImageTypes lists the MIME types accepted in base64 data URIs.
var allowedImageTypes = map[string]bool{
	"image/jpeg": true,
	"image/png":  true,
	"image/webp": true,
	"image/gif":  true,
}

// buildArkMessages converts chat messages to Ark SDK messages, stopping at
// the first message whose image cannot be normalized.
func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
	messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages))
	for _, m := range chatMessages {
		msg, err := buildArkMessage(m)
		if err != nil {
			return nil, err
		}
		messages = append(messages, msg)
	}
	return messages, nil
}

// buildArkMessage converts one chat message. Text-only messages become a
// plain string content. A message with an image (ImageURL, or ImageURLAlias
// as a fallback) becomes a multimodal list: the image part first, then the
// text part when Content is non-empty.
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
	msg := &model.ChatCompletionMessage{Role: m.Role}
	if m.ImageURL == "" && m.ImageURLAlias != "" {
		m.ImageURL = m.ImageURLAlias
	}
	if m.ImageURL == "" {
		// m is this call's own copy, so taking &m.Content is safe.
		msg.Content = &model.ChatCompletionMessageContent{
			StringValue: &m.Content,
		}
		return msg, nil
	}
	imageURL, err := normalizeImageURL(m.ImageURL)
	if err != nil {
		return nil, err
	}
	parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)}
	if m.Content != "" {
		parts = append(parts, textPart(m.Content))
	}
	msg.Content = &model.ChatCompletionMessageContent{ListValue: parts}
	return msg, nil
}

// imagePart builds an image_url content part with automatic detail level.
func imagePart(url string) *model.ChatCompletionMessageContentPart {
	return &model.ChatCompletionMessageContentPart{
		Type: model.ChatCompletionMessageContentPartTypeImageURL,
		ImageURL: &model.ChatMessageImageURL{
			URL:    url,
			Detail: model.ImageURLDetailAuto,
		},
	}
}

// textPart builds a text content part.
func textPart(text string) *model.ChatCompletionMessageContentPart {
	return &model.ChatCompletionMessageContentPart{
		Type: model.ChatCompletionMessageContentPartTypeText,
		Text: text,
	}
}

// normalizeImageURL validates an image reference. Input that starts with
// "data:" (case-insensitive) is checked by normalizeImageDataURI. Anything
// else must be an absolute http/https URL with a host and is returned
// trimmed.
func normalizeImageURL(raw string) (string, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return "", errors.New("图片地址不能为空")
	}
	if strings.HasPrefix(strings.ToLower(raw), "data:") {
		return normalizeImageDataURI(raw)
	}
	u, err := url.Parse(raw)
	if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
		return "", errors.New("图片地址无效,仅支持 http/https URL 或 base64 data URI")
	}
	return raw, nil
}

// normalizeImageDataURI validates a base64 image data URI. The caller has
// already checked the 5-byte "data:" prefix. The MIME type must be in
// allowedImageTypes, a ";base64" parameter must be present, and the decoded
// payload must be at most maxImageSize bytes. On success the URI is rebuilt
// as "data:<mime>;base64,<payload>" with the lower-cased MIME type.
func normalizeImageDataURI(raw string) (string, error) {
	comma := strings.Index(raw, ",")
	if comma < 0 {
		return "", errors.New("图片 base64 数据格式错误")
	}
	meta := strings.ToLower(strings.TrimSpace(raw[5:comma]))
	payload := strings.TrimSpace(raw[comma+1:])
	if payload == "" {
		return "", errors.New("图片 base64 数据不能为空")
	}
	parts := strings.Split(meta, ";")
	if len(parts) < 2 || !contains(parts[1:], "base64") {
		return "", errors.New("图片 data URI 必须使用 base64 编码")
	}
	mime := parts[0]
	if !allowedImageTypes[mime] {
		return "", errors.New("图片格式不支持,仅支持 jpeg/png/webp/gif")
	}
	decoded, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return "", errors.New("图片 base64 数据无效")
	}
	if len(decoded) > maxImageSize {
		return "", errors.New("图片过大,请选择小于 4MB 的图片")
	}
	return "data:" + mime + ";base64," + payload, nil
}

// contains reports whether any item equals target after trimming spaces.
func contains(items []string, target string) bool {
	for _, item := range items {
		if strings.TrimSpace(item) == target {
			return true
		}
	}
	return false
}

// intPtr returns a pointer to a copy of i.
func intPtr(i int) *int { return &i }

// writeSSEJSON writes frame as one SSE "data:" event. If the frame cannot be
// marshaled, an error frame is written instead. Write errors are ignored.
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
	data, err := json.Marshal(frame)
	if err != nil {
		data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"})
	}
	fmt.Fprintf(w, "data: %s\n\n", data)
}

// toJSON returns s as a JSON string literal.
func toJSON(s string) string {
	b, _ := json.Marshal(s)
	return string(b)
}

// toSSE returns s as a double-quoted JSON string literal for a single SSE
// "data:" line. Carriage returns are dropped. Newlines, quotes, backslashes
// and all other control characters (e.g. tabs, which the previous manual
// escaping left raw and so produced invalid JSON) are escaped, so the result
// never spans lines. Unlike toJSON, <, > and & are not HTML-escaped.
func toSSE(s string) string {
	var b strings.Builder
	enc := json.NewEncoder(&b)
	enc.SetEscapeHTML(false)
	_ = enc.Encode(strings.ReplaceAll(s, "\r", "")) // encoding a string cannot fail
	return strings.TrimSuffix(b.String(), "\n")    // Encode appends a newline
}

// ─── 入口 ─────────────────────────────────────────────────

// main runs the service and exits with status 1 on failure. All work
// happens in runMain: os.Exit skips deferred calls, so calling it only here
// lets runMain's cleanup (closing the SQL plugin, removing the Unix socket)
// run first.
func main() {
	if err := runMain(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// runMain loads configuration, initializes the model, search and SQL plugin
// state and the conversation store, registers the HTTP routes, and serves on
// a Unix socket (server.mode "unix") or TCP (any other mode). It returns the
// first failure, prefixed with the same messages the program prints.
func runMain() error {
	var err error
	cfg, err = loadConfig("config.yaml")
	if err != nil {
		return fmt.Errorf("配置加载失败: %w", err)
	}
	// Initialize the Volcano Engine Ark SDK client.
	aiState, err = NewOpenAIState(cfg.OpenAI)
	if err != nil {
		return fmt.Errorf("OpenAI 配置初始化失败: %w", err)
	}
	searchState, err = NewSearchState(cfg.Search)
	if err != nil {
		return fmt.Errorf("搜索配置初始化失败: %w", err)
	}
	sqlConfig, err := sqlquery.LoadConfig("agents/SQL_query/config.yaml")
	if err != nil {
		return fmt.Errorf("SQL 查询插件配置加载失败: %w", err)
	}
	sqlState, err = sqlquery.NewState(sqlConfig)
	if err != nil {
		return fmt.Errorf("SQL 查询插件初始化失败: %w", err)
	}
	defer sqlState.Close()

	store = NewConvStore("conversations")

	// Gin routes.
	r := gin.Default()
	r.LoadHTMLGlob("templates/*")
	r.Static("/static", "./static")
	r.GET("/", indexHandler)
	r.POST("/api/chat", chatHandler)
	r.GET("/api/openai", listOpenAIHandler)
	r.POST("/api/openai/active", switchOpenAIHandler)
	r.GET("/api/search", listSearchHandler)
	r.POST("/api/search/active", switchSearchHandler)
	r.GET("/api/conversations", listConversationsHandler)
	r.POST("/api/conversations", createConversationHandler)
	r.GET("/api/conversations/:id", getConversationHandler)
	r.DELETE("/api/conversations/:id", deleteConversationHandler)

	srv := &http.Server{
		Handler: r,
		// Limit how long a client may take to send request headers (slow
		// header attacks). No Read/WriteTimeout: chat replies are streamed
		// and can legitimately take a long time.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Choose the listener according to the configuration.
	switch strings.ToLower(cfg.Server.Mode) {
	case "unix":
		return serveUnix(srv, cfg.Server.Address)
	default:
		addr := cfg.Server.Address
		if addr == "" {
			addr = ":http" // same default http.ListenAndServe applies
		}
		ln, err := net.Listen("tcp", addr)
		if err != nil {
			return fmt.Errorf("服务异常退出: %w", err)
		}
		fmt.Println("服务已启动,监听 TCP:", cfg.Server.Address)
		if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return fmt.Errorf("服务异常退出: %w", err)
		}
		return nil
	}
}

// serveUnix serves srv on a Unix domain socket at socketPath. A socket left
// behind by a previous run is removed first. Any other kind of file at that
// path is reported as an error rather than deleted. The socket file is
// removed again when serving stops.
func serveUnix(srv *http.Server, socketPath string) error {
	if info, err := os.Lstat(socketPath); err == nil {
		if info.Mode()&os.ModeSocket == 0 {
			return fmt.Errorf("监听 Unix socket 失败: %s 已存在且不是 socket 文件", socketPath)
		}
		if err := os.Remove(socketPath); err != nil {
			return fmt.Errorf("监听 Unix socket 失败: %w", err)
		}
	}
	ln, err := net.Listen("unix", socketPath)
	if err != nil {
		return fmt.Errorf("监听 Unix socket 失败: %w", err)
	}
	defer func() { _ = os.Remove(socketPath) }()

	fmt.Println("服务已启动,监听 Unix socket:", socketPath)
	// Serve closes ln when it returns.
	if err := srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("服务异常退出: %w", err)
	}
	return nil
}