From 4429efdc7c2e514b92b1a812bf63e58a48c0182d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Tue, 9 Jun 2026 13:01:01 +0800 Subject: [PATCH] up --- main.go | 328 +++++++++++++++++++++++++++++++++++++------- templates/chat.html | 71 +++++++++- 2 files changed, 351 insertions(+), 48 deletions(-) diff --git a/main.go b/main.go index cbdfcbd..02ce48b 100644 --- a/main.go +++ b/main.go @@ -27,17 +27,54 @@ import ( // ─── 配置 ───────────────────────────────────────────────── +const ( + defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3" + defaultOpenAITimeout = 120 +) + +type OpenAIConfig struct { + Name string `yaml:"name" json:"name"` + Active bool `yaml:"active,omitempty" json:"active"` + APIKey string `yaml:"api_key" json:"-"` + BaseURL string `yaml:"base_url" json:"base_url"` + Model string `yaml:"model" json:"model"` + Timeout int `yaml:"timeout" json:"timeout"` +} + +type OpenAIConfigs []OpenAIConfig + +func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error { + switch value.Kind { + case yaml.SequenceNode: + var items []OpenAIConfig + if err := value.Decode(&items); err != nil { + return err + } + *configs = items + case yaml.MappingNode: + var item OpenAIConfig + if err := value.Decode(&item); err != nil { + return err + } + *configs = []OpenAIConfig{item} + case yaml.ScalarNode: + if value.Tag == "!!null" { + *configs = nil + return nil + } + return fmt.Errorf("openai 配置格式无效") + default: + return fmt.Errorf("openai 配置格式无效") + } + return nil +} + type Config struct { Server struct { Mode string `yaml:"mode"` Address string `yaml:"address"` } `yaml:"server"` - OpenAI struct { - APIKey string `yaml:"api_key"` - BaseURL string `yaml:"base_url"` - Model string `yaml:"model"` - Timeout int `yaml:"timeout"` - } `yaml:"openai"` + OpenAI OpenAIConfigs `yaml:"openai"` Search struct { Enabled bool `yaml:"enabled"` Provider string `yaml:"provider"` @@ -48,12 +85,20 @@ type Config struct { } `yaml:"search"` } +func defaultOpenAIConfig() OpenAIConfig { + return OpenAIConfig{ + Name: "default", + Active: true, + BaseURL: defaultOpenAIBaseURL, + Timeout: defaultOpenAITimeout, + } +} + func defaultConfig() Config { var cfg Config cfg.Server.Mode = "tcp" cfg.Server.Address = "0.0.0.0:8080" - cfg.OpenAI.BaseURL = "https://ark.cn-beijing.volces.com/api/v3" - cfg.OpenAI.Timeout = 120 + cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} cfg.Search.Provider = "brave" cfg.Search.BaseURL = "https://api.search.brave.com/res/v1/web/search" cfg.Search.Count = 5 @@ -74,9 +119,14 @@ func loadConfig(path string) (*Config, error) { if err = yaml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("解析配置文件失败: %w", err) } + if _, err := normalizeOpenAIConfigs(&cfg); err != nil { + return nil, err + } // 环境变量优先 if key := os.Getenv("ARK_API_KEY"); key != "" { - cfg.OpenAI.APIKey = key + for i := range cfg.OpenAI { + cfg.OpenAI[i].APIKey = key + } } if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" { cfg.Search.APIKey = key @@ -122,27 +172,13 @@ func ensureConfigFile(path string) error { } } - openai, _ := raw["openai"].(map[string]any) - if openai == nil { - cfg.OpenAI = defaults.OpenAI + if _, ok := raw["openai"].([]any); !ok { + changed = true + } + if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil { + return err + } else if normalized { changed = true - } else { - if _, ok := openai["api_key"]; !ok { - cfg.OpenAI.APIKey = defaults.OpenAI.APIKey - changed = true - } - if _, ok := openai["base_url"]; !ok { - cfg.OpenAI.BaseURL = defaults.OpenAI.BaseURL - changed = true - } - if _, ok := openai["model"]; !ok { - cfg.OpenAI.Model = defaults.OpenAI.Model - changed = true - } - if _, ok := openai["timeout"]; !ok { - cfg.OpenAI.Timeout = defaults.OpenAI.Timeout - changed = true - } } search, _ := raw["search"].(map[string]any) @@ -182,6 +218,58 @@ func ensureConfigFile(path string) error { return writeConfig(path, cfg) } +func normalizeOpenAIConfigs(cfg *Config) (bool, error) { + changed := false + if len(cfg.OpenAI) == 0 { + cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} + changed = true + } + + activeIndex := -1 + seen := map[string]bool{} + for i := range cfg.OpenAI { + profile := &cfg.OpenAI[i] + name := strings.TrimSpace(profile.Name) + if name == "" { + name = strings.TrimSpace(profile.Model) + if name == "" { + name = fmt.Sprintf("openai-%d", i+1) + } + profile.Name = name + changed = true + } else if name != profile.Name { + profile.Name = name + changed = true + } + if seen[name] { + return changed, fmt.Errorf("openai 配置名称重复: %s", name) + } + seen[name] = true + + if strings.TrimSpace(profile.BaseURL) == "" { + profile.BaseURL = defaultOpenAIBaseURL + changed = true + } + if profile.Timeout <= 0 { + profile.Timeout = defaultOpenAITimeout + changed = true + } + if profile.Active { + if activeIndex == -1 { + activeIndex = i + } else { + profile.Active = false + changed = true + } + } + } + if activeIndex == -1 { + cfg.OpenAI[0].Active = true + changed = true + } + return changed, nil +} + func writeConfig(path string, cfg Config) error { data, err := yaml.Marshal(&cfg) if err != nil { @@ -207,6 +295,7 @@ type ChatRequest struct { ConversationID string `json:"conversation_id,omitempty"` Messages []ChatMessage `json:"messages"` WebSearch bool `json:"web_search,omitempty"` + OpenAIName string `json:"openai_name,omitempty"` } type Conversation struct { @@ -222,20 +311,165 @@ type ConvStore struct { mu sync.Mutex } +type OpenAIProfile struct { + Config OpenAIConfig + Client *ark.Client +} + +type OpenAIState struct { + mu sync.RWMutex + profiles map[string]*OpenAIProfile + order []string + activeName string +} + +type openAIActiveRequest struct { + Name string `json:"name"` +} + +type openAIListResponse struct { + Active string `json:"active"` + Profiles []OpenAIConfig `json:"profiles"` +} + +func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) { + state := &OpenAIState{ + profiles: make(map[string]*OpenAIProfile, len(configs)), + order: make([]string, 0, len(configs)), + } + for _, config := range configs { + if strings.TrimSpace(config.Name) == "" { + return nil, errors.New("openai.name 不能为空") + } + if strings.TrimSpace(config.APIKey) == "" { + return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", config.Name) + } + if strings.TrimSpace(config.Model) == "" { + return nil, fmt.Errorf("openai.%s.model 未配置", config.Name) + } + if strings.TrimSpace(config.BaseURL) == "" { + return nil, fmt.Errorf("openai.%s.base_url 未配置", config.Name) + } + if config.Timeout <= 0 { + return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", config.Name) + } + if _, ok := state.profiles[config.Name]; ok { + return nil, fmt.Errorf("openai 配置名称重复: %s", config.Name) + } + state.profiles[config.Name] = &OpenAIProfile{ + Config: config, + Client: ark.NewClientWithApiKey( + config.APIKey, + ark.WithBaseUrl(config.BaseURL), + ark.WithTimeout(time.Duration(config.Timeout)*time.Second), + ), + } + state.order = append(state.order, config.Name) + if config.Active && state.activeName == "" { + state.activeName = config.Name + } + } + if len(state.order) == 0 { + return nil, errors.New("openai 配置不能为空") + } + if state.activeName == "" { + state.activeName = state.order[0] + } + return state, nil +} + +func (s *OpenAIState) ActiveProfile() *OpenAIProfile { + s.mu.RLock() + defer s.mu.RUnlock() + return s.profiles[s.activeName] +} + +func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) { + s.mu.RLock() + defer s.mu.RUnlock() + if strings.TrimSpace(name) == "" { + return s.profiles[s.activeName], nil + } + profile, ok := s.profiles[strings.TrimSpace(name)] + if !ok { + return nil, fmt.Errorf("OpenAI 配置不存在: %s", name) + } + return profile, nil +} + +func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) { + name = strings.TrimSpace(name) + if name == "" { + return nil, errors.New("OpenAI 配置名称不能为空") + } + s.mu.Lock() + defer s.mu.Unlock() + profile, ok := s.profiles[name] + if !ok { + return nil, fmt.Errorf("OpenAI 配置不存在: %s", name) + } + s.activeName = name + return profile, nil +} + +func (s *OpenAIState) ListProfiles() openAIListResponse { + s.mu.RLock() + defer s.mu.RUnlock() + profiles := make([]OpenAIConfig, 0, len(s.order)) + for _, name := range s.order { + profile := s.profiles[name] + config := profile.Config + config.APIKey = "" + config.Active = name == s.activeName + profiles = append(profiles, config) + } + return openAIListResponse{Active: s.activeName, Profiles: profiles} +} + +func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig { + config := profile.Config + config.APIKey = "" + config.Active = active + return config +} + // ─── 全局变量 ───────────────────────────────────────────── var ( - cfg *Config - aiClient *ark.Client - store *ConvStore + cfg *Config + aiState *OpenAIState + store *ConvStore ) // ─── 路由 ───────────────────────────────────────────────── func indexHandler(c *gin.Context) { + profile := aiState.ActiveProfile() c.HTML(http.StatusOK, "chat.html", gin.H{ - "Title": "AI 对话", - "Model": cfg.OpenAI.Model, + "Title": "AI 对话", + "Model": profile.Config.Model, + "OpenAIName": profile.Config.Name, + }) +} + +func listOpenAIHandler(c *gin.Context) { + c.JSON(http.StatusOK, aiState.ListProfiles()) +} + +func switchOpenAIHandler(c *gin.Context) { + var req openAIActiveRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()}) + return + } + profile, err := aiState.SwitchActive(req.Name) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{ + "active": profile.Config.Name, + "profile": publicOpenAIConfig(profile, true), }) } @@ -289,6 +523,11 @@ func chatHandler(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"}) return } + profile, err := aiState.GetProfile(req.OpenAIName) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } chatMessages := req.Messages if req.WebSearch { @@ -320,13 +559,13 @@ func chatHandler(c *gin.Context) { } // 超时 context - timeout := time.Duration(cfg.OpenAI.Timeout) * time.Second + timeout := time.Duration(profile.Config.Timeout) * time.Second ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() // 发起流式请求(使用 CreateChatCompletionStream) - stream, err := aiClient.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ - Model: cfg.OpenAI.Model, + stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ + Model: profile.Config.Model, Messages: messages, MaxTokens: intPtr(4096), }.WithStream(true)) @@ -774,17 +1013,12 @@ func main() { os.Exit(1) } - if cfg.OpenAI.APIKey == "" { - fmt.Fprintln(os.Stderr, "错误: openai.api_key 未配置,也未设置环境变量 ARK_API_KEY") + // 初始化火山方舟 SDK 客户端 + aiState, err = NewOpenAIState(cfg.OpenAI) + if err != nil { + fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err) os.Exit(1) } - - // 初始化火山方舟 SDK 客户端 - aiClient = ark.NewClientWithApiKey( - cfg.OpenAI.APIKey, - ark.WithBaseUrl(cfg.OpenAI.BaseURL), - ark.WithTimeout(time.Duration(cfg.OpenAI.Timeout)*time.Second), - ) store = NewConvStore("conversations") // Gin 路由 @@ -794,6 +1028,8 @@ func main() { r.GET("/", indexHandler) r.POST("/api/chat", chatHandler) + r.GET("/api/openai", listOpenAIHandler) + r.POST("/api/openai/active", switchOpenAIHandler) r.GET("/api/conversations", listConversationsHandler) r.POST("/api/conversations", createConversationHandler) r.GET("/api/conversations/:id", getConversationHandler) diff --git a/templates/chat.html b/templates/chat.html index 51060bb..bf1ba1b 100644 --- a/templates/chat.html +++ b/templates/chat.html @@ -144,6 +144,15 @@ padding: 3px 10px; border-radius: 20px; } + header .model-select { + max-width: 260px; + cursor: pointer; + outline: none; + } + header .model-select:disabled { + opacity: .65; + cursor: not-allowed; + } .header-actions { display: flex; align-items: center; @@ -466,7 +475,9 @@ {{ .Title }}
- {{ .Model }} + @@ -531,6 +542,8 @@ let history = []; // {role, content, image_url?} let currentConvId = null; let pending = false; let webSearchEnabled = false; +let openAIProfiles = []; +let activeOpenAIName = '{{ .OpenAIName }}'; let imageB64 = ''; // 当前待发送图片的 data URI let imageName = ''; @@ -542,6 +555,7 @@ const btnSend = document.getElementById('btnSend'); const btnClear = document.getElementById('btnClear'); const btnPreset = document.getElementById('btnPreset'); const btnSearch = document.getElementById('btnSearch'); +const modelSelect = document.getElementById('modelSelect'); const btnNewChat = document.getElementById('btnNewChat'); const convList = document.getElementById('convList'); const presetModal = document.getElementById('presetModal'); @@ -615,6 +629,7 @@ function setInputDisabled(disabled) { inputBox.disabled = disabled; fileInput.disabled = disabled; btnSearch.disabled = disabled; + modelSelect.disabled = disabled || openAIProfiles.length <= 1; } function updateSearchButton() { @@ -622,6 +637,27 @@ function updateSearchButton() { btnSearch.textContent = webSearchEnabled ? '联网搜索:开' : '联网搜索:关'; } +async function loadOpenAIProfiles() { + const res = await fetch('/api/openai'); + if (!res.ok) { + const err = await res.json().catch(() => ({ error: '加载模型配置失败' })); + throw new Error(err.error || '加载模型配置失败'); + } + const data = await res.json(); + openAIProfiles = Array.isArray(data.profiles) ? data.profiles : []; + activeOpenAIName = data.active || activeOpenAIName; + + modelSelect.innerHTML = ''; + for (const profile of openAIProfiles) { + const opt = document.createElement('option'); + opt.value = profile.name; + opt.textContent = `${profile.name} · ${profile.model}`; + opt.selected = profile.name === activeOpenAIName; + modelSelect.appendChild(opt); + } + modelSelect.disabled = pending || openAIProfiles.length <= 1; +} + // ── 对话列表 ────────────────────────────────────────────── async function loadConversationList() { try { @@ -784,7 +820,12 @@ async function streamChat(messages, aiBubble, webSearch = false) { const res = await fetch('/api/chat', { method: 'POST', headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ conversation_id: currentConvId, messages, web_search: webSearch }), + body: JSON.stringify({ + conversation_id: currentConvId, + messages, + web_search: webSearch, + openai_name: activeOpenAIName, + }), }); if (!res.ok) { @@ -997,6 +1038,31 @@ btnSearch.addEventListener('click', () => { webSearchEnabled = !webSearchEnabled; updateSearchButton(); }); +modelSelect.addEventListener('change', async () => { + if (pending) { + modelSelect.value = activeOpenAIName; + return; + } + const nextName = modelSelect.value; + const prevName = activeOpenAIName; + try { + const res = await fetch('/api/openai/active', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ name: nextName }), + }); + if (!res.ok) { + const err = await res.json().catch(() => ({ error: '切换模型失败' })); + throw new Error(err.error || '切换模型失败'); + } + const data = await res.json(); + activeOpenAIName = data.active; + modelSelect.value = activeOpenAIName; + } catch (e) { + modelSelect.value = prevName; + alert(e.message); + } +}); btnNewChat.addEventListener('click', newConversation); btnClear.addEventListener('click', newConversation); btnPreset.addEventListener('click', openPresetModal); @@ -1016,6 +1082,7 @@ presetModal.addEventListener('click', e => { // 自动聚焦 & 初始化 updateSearchButton(); +loadOpenAIProfiles().catch(e => alert(e.message)); loadConversationList(); inputBox.focus();