This commit is contained in:
2026-06-09 13:01:01 +08:00
parent 4dcf8c351a
commit 4429efdc7c
2 changed files with 351 additions and 48 deletions
+282 -46
View File
@@ -27,17 +27,54 @@ import (
// ─── 配置 ─────────────────────────────────────────────────
const (
defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3"
defaultOpenAITimeout = 120
)
type OpenAIConfig struct {
Name string `yaml:"name" json:"name"`
Active bool `yaml:"active,omitempty" json:"active"`
APIKey string `yaml:"api_key" json:"-"`
BaseURL string `yaml:"base_url" json:"base_url"`
Model string `yaml:"model" json:"model"`
Timeout int `yaml:"timeout" json:"timeout"`
}
type OpenAIConfigs []OpenAIConfig
func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error {
switch value.Kind {
case yaml.SequenceNode:
var items []OpenAIConfig
if err := value.Decode(&items); err != nil {
return err
}
*configs = items
case yaml.MappingNode:
var item OpenAIConfig
if err := value.Decode(&item); err != nil {
return err
}
*configs = []OpenAIConfig{item}
case yaml.ScalarNode:
if value.Tag == "!!null" {
*configs = nil
return nil
}
return fmt.Errorf("openai 配置格式无效")
default:
return fmt.Errorf("openai 配置格式无效")
}
return nil
}
type Config struct {
Server struct {
Mode string `yaml:"mode"`
Address string `yaml:"address"`
} `yaml:"server"`
OpenAI struct {
APIKey string `yaml:"api_key"`
BaseURL string `yaml:"base_url"`
Model string `yaml:"model"`
Timeout int `yaml:"timeout"`
} `yaml:"openai"`
OpenAI OpenAIConfigs `yaml:"openai"`
Search struct {
Enabled bool `yaml:"enabled"`
Provider string `yaml:"provider"`
@@ -48,12 +85,20 @@ type Config struct {
} `yaml:"search"`
}
func defaultOpenAIConfig() OpenAIConfig {
return OpenAIConfig{
Name: "default",
Active: true,
BaseURL: defaultOpenAIBaseURL,
Timeout: defaultOpenAITimeout,
}
}
func defaultConfig() Config {
var cfg Config
cfg.Server.Mode = "tcp"
cfg.Server.Address = "0.0.0.0:8080"
cfg.OpenAI.BaseURL = "https://ark.cn-beijing.volces.com/api/v3"
cfg.OpenAI.Timeout = 120
cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
cfg.Search.Provider = "brave"
cfg.Search.BaseURL = "https://api.search.brave.com/res/v1/web/search"
cfg.Search.Count = 5
@@ -74,9 +119,14 @@ func loadConfig(path string) (*Config, error) {
if err = yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
if _, err := normalizeOpenAIConfigs(&cfg); err != nil {
return nil, err
}
// 环境变量优先
if key := os.Getenv("ARK_API_KEY"); key != "" {
cfg.OpenAI.APIKey = key
for i := range cfg.OpenAI {
cfg.OpenAI[i].APIKey = key
}
}
if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" {
cfg.Search.APIKey = key
@@ -122,27 +172,13 @@ func ensureConfigFile(path string) error {
}
}
openai, _ := raw["openai"].(map[string]any)
if openai == nil {
cfg.OpenAI = defaults.OpenAI
if _, ok := raw["openai"].([]any); !ok {
changed = true
}
if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil {
return err
} else if normalized {
changed = true
} else {
if _, ok := openai["api_key"]; !ok {
cfg.OpenAI.APIKey = defaults.OpenAI.APIKey
changed = true
}
if _, ok := openai["base_url"]; !ok {
cfg.OpenAI.BaseURL = defaults.OpenAI.BaseURL
changed = true
}
if _, ok := openai["model"]; !ok {
cfg.OpenAI.Model = defaults.OpenAI.Model
changed = true
}
if _, ok := openai["timeout"]; !ok {
cfg.OpenAI.Timeout = defaults.OpenAI.Timeout
changed = true
}
}
search, _ := raw["search"].(map[string]any)
@@ -182,6 +218,58 @@ func ensureConfigFile(path string) error {
return writeConfig(path, cfg)
}
func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
changed := false
if len(cfg.OpenAI) == 0 {
cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
changed = true
}
activeIndex := -1
seen := map[string]bool{}
for i := range cfg.OpenAI {
profile := &cfg.OpenAI[i]
name := strings.TrimSpace(profile.Name)
if name == "" {
name = strings.TrimSpace(profile.Model)
if name == "" {
name = fmt.Sprintf("openai-%d", i+1)
}
profile.Name = name
changed = true
} else if name != profile.Name {
profile.Name = name
changed = true
}
if seen[name] {
return changed, fmt.Errorf("openai 配置名称重复: %s", name)
}
seen[name] = true
if strings.TrimSpace(profile.BaseURL) == "" {
profile.BaseURL = defaultOpenAIBaseURL
changed = true
}
if profile.Timeout <= 0 {
profile.Timeout = defaultOpenAITimeout
changed = true
}
if profile.Active {
if activeIndex == -1 {
activeIndex = i
} else {
profile.Active = false
changed = true
}
}
}
if activeIndex == -1 {
cfg.OpenAI[0].Active = true
changed = true
}
return changed, nil
}
func writeConfig(path string, cfg Config) error {
data, err := yaml.Marshal(&cfg)
if err != nil {
@@ -207,6 +295,7 @@ type ChatRequest struct {
ConversationID string `json:"conversation_id,omitempty"`
Messages []ChatMessage `json:"messages"`
WebSearch bool `json:"web_search,omitempty"`
OpenAIName string `json:"openai_name,omitempty"`
}
type Conversation struct {
@@ -222,20 +311,165 @@ type ConvStore struct {
mu sync.Mutex
}
type OpenAIProfile struct {
Config OpenAIConfig
Client *ark.Client
}
type OpenAIState struct {
mu sync.RWMutex
profiles map[string]*OpenAIProfile
order []string
activeName string
}
type openAIActiveRequest struct {
Name string `json:"name"`
}
type openAIListResponse struct {
Active string `json:"active"`
Profiles []OpenAIConfig `json:"profiles"`
}
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
state := &OpenAIState{
profiles: make(map[string]*OpenAIProfile, len(configs)),
order: make([]string, 0, len(configs)),
}
for _, config := range configs {
if strings.TrimSpace(config.Name) == "" {
return nil, errors.New("openai.name 不能为空")
}
if strings.TrimSpace(config.APIKey) == "" {
return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", config.Name)
}
if strings.TrimSpace(config.Model) == "" {
return nil, fmt.Errorf("openai.%s.model 未配置", config.Name)
}
if strings.TrimSpace(config.BaseURL) == "" {
return nil, fmt.Errorf("openai.%s.base_url 未配置", config.Name)
}
if config.Timeout <= 0 {
return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", config.Name)
}
if _, ok := state.profiles[config.Name]; ok {
return nil, fmt.Errorf("openai 配置名称重复: %s", config.Name)
}
state.profiles[config.Name] = &OpenAIProfile{
Config: config,
Client: ark.NewClientWithApiKey(
config.APIKey,
ark.WithBaseUrl(config.BaseURL),
ark.WithTimeout(time.Duration(config.Timeout)*time.Second),
),
}
state.order = append(state.order, config.Name)
if config.Active && state.activeName == "" {
state.activeName = config.Name
}
}
if len(state.order) == 0 {
return nil, errors.New("openai 配置不能为空")
}
if state.activeName == "" {
state.activeName = state.order[0]
}
return state, nil
}
func (s *OpenAIState) ActiveProfile() *OpenAIProfile {
s.mu.RLock()
defer s.mu.RUnlock()
return s.profiles[s.activeName]
}
func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) {
s.mu.RLock()
defer s.mu.RUnlock()
if strings.TrimSpace(name) == "" {
return s.profiles[s.activeName], nil
}
profile, ok := s.profiles[strings.TrimSpace(name)]
if !ok {
return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
}
return profile, nil
}
func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) {
name = strings.TrimSpace(name)
if name == "" {
return nil, errors.New("OpenAI 配置名称不能为空")
}
s.mu.Lock()
defer s.mu.Unlock()
profile, ok := s.profiles[name]
if !ok {
return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
}
s.activeName = name
return profile, nil
}
func (s *OpenAIState) ListProfiles() openAIListResponse {
s.mu.RLock()
defer s.mu.RUnlock()
profiles := make([]OpenAIConfig, 0, len(s.order))
for _, name := range s.order {
profile := s.profiles[name]
config := profile.Config
config.APIKey = ""
config.Active = name == s.activeName
profiles = append(profiles, config)
}
return openAIListResponse{Active: s.activeName, Profiles: profiles}
}
func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
config := profile.Config
config.APIKey = ""
config.Active = active
return config
}
// ─── 全局变量 ─────────────────────────────────────────────
var (
cfg *Config
aiClient *ark.Client
store *ConvStore
cfg *Config
aiState *OpenAIState
store *ConvStore
)
// ─── 路由 ─────────────────────────────────────────────────
func indexHandler(c *gin.Context) {
profile := aiState.ActiveProfile()
c.HTML(http.StatusOK, "chat.html", gin.H{
"Title": "AI 对话",
"Model": cfg.OpenAI.Model,
"Title": "AI 对话",
"Model": profile.Config.Model,
"OpenAIName": profile.Config.Name,
})
}
func listOpenAIHandler(c *gin.Context) {
c.JSON(http.StatusOK, aiState.ListProfiles())
}
func switchOpenAIHandler(c *gin.Context) {
var req openAIActiveRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
return
}
profile, err := aiState.SwitchActive(req.Name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"active": profile.Config.Name,
"profile": publicOpenAIConfig(profile, true),
})
}
@@ -289,6 +523,11 @@ func chatHandler(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"})
return
}
profile, err := aiState.GetProfile(req.OpenAIName)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
chatMessages := req.Messages
if req.WebSearch {
@@ -320,13 +559,13 @@ func chatHandler(c *gin.Context) {
}
// 超时 context
timeout := time.Duration(cfg.OpenAI.Timeout) * time.Second
timeout := time.Duration(profile.Config.Timeout) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
// 发起流式请求(使用 CreateChatCompletionStream
stream, err := aiClient.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
Model: cfg.OpenAI.Model,
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
Messages: messages,
MaxTokens: intPtr(4096),
}.WithStream(true))
@@ -774,17 +1013,12 @@ func main() {
os.Exit(1)
}
if cfg.OpenAI.APIKey == "" {
fmt.Fprintln(os.Stderr, "错误: openai.api_key 未配置,也未设置环境变量 ARK_API_KEY")
// 初始化火山方舟 SDK 客户端
aiState, err = NewOpenAIState(cfg.OpenAI)
if err != nil {
fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err)
os.Exit(1)
}
// 初始化火山方舟 SDK 客户端
aiClient = ark.NewClientWithApiKey(
cfg.OpenAI.APIKey,
ark.WithBaseUrl(cfg.OpenAI.BaseURL),
ark.WithTimeout(time.Duration(cfg.OpenAI.Timeout)*time.Second),
)
store = NewConvStore("conversations")
// Gin 路由
@@ -794,6 +1028,8 @@ func main() {
r.GET("/", indexHandler)
r.POST("/api/chat", chatHandler)
r.GET("/api/openai", listOpenAIHandler)
r.POST("/api/openai/active", switchOpenAIHandler)
r.GET("/api/conversations", listConversationsHandler)
r.POST("/api/conversations", createConversationHandler)
r.GET("/api/conversations/:id", getConversationHandler)
+69 -2
View File
@@ -144,6 +144,15 @@
padding: 3px 10px;
border-radius: 20px;
}
header .model-select {
max-width: 260px;
cursor: pointer;
outline: none;
}
header .model-select:disabled {
opacity: .65;
cursor: not-allowed;
}
.header-actions {
display: flex;
align-items: center;
@@ -466,7 +475,9 @@
{{ .Title }}
</div>
<div class="header-actions">
<span class="model-badge">{{ .Model }}</span>
<select id="modelSelect" class="model-badge model-select" title="切换 OpenAI 配置">
<option value="{{ .OpenAIName }}">{{ .Model }}</option>
</select>
<button id="btnSearch" title="开启后,本轮提问会先联网搜索">联网搜索:关</button>
<button id="btnPreset" title="设置预先提示词">预设</button>
<button id="btnClear" title="开始新对话">新对话</button>
@@ -531,6 +542,8 @@ let history = []; // {role, content, image_url?}
let currentConvId = null;
let pending = false;
let webSearchEnabled = false;
let openAIProfiles = [];
let activeOpenAIName = '{{ .OpenAIName }}';
let imageB64 = ''; // 当前待发送图片的 data URI
let imageName = '';
@@ -542,6 +555,7 @@ const btnSend = document.getElementById('btnSend');
const btnClear = document.getElementById('btnClear');
const btnPreset = document.getElementById('btnPreset');
const btnSearch = document.getElementById('btnSearch');
const modelSelect = document.getElementById('modelSelect');
const btnNewChat = document.getElementById('btnNewChat');
const convList = document.getElementById('convList');
const presetModal = document.getElementById('presetModal');
@@ -615,6 +629,7 @@ function setInputDisabled(disabled) {
inputBox.disabled = disabled;
fileInput.disabled = disabled;
btnSearch.disabled = disabled;
modelSelect.disabled = disabled || openAIProfiles.length <= 1;
}
function updateSearchButton() {
@@ -622,6 +637,27 @@ function updateSearchButton() {
btnSearch.textContent = webSearchEnabled ? '联网搜索:开' : '联网搜索:关';
}
async function loadOpenAIProfiles() {
const res = await fetch('/api/openai');
if (!res.ok) {
const err = await res.json().catch(() => ({ error: '加载模型配置失败' }));
throw new Error(err.error || '加载模型配置失败');
}
const data = await res.json();
openAIProfiles = Array.isArray(data.profiles) ? data.profiles : [];
activeOpenAIName = data.active || activeOpenAIName;
modelSelect.innerHTML = '';
for (const profile of openAIProfiles) {
const opt = document.createElement('option');
opt.value = profile.name;
opt.textContent = `${profile.name} · ${profile.model}`;
opt.selected = profile.name === activeOpenAIName;
modelSelect.appendChild(opt);
}
modelSelect.disabled = pending || openAIProfiles.length <= 1;
}
// ── 对话列表 ──────────────────────────────────────────────
async function loadConversationList() {
try {
@@ -784,7 +820,12 @@ async function streamChat(messages, aiBubble, webSearch = false) {
const res = await fetch('/api/chat', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ conversation_id: currentConvId, messages, web_search: webSearch }),
body: JSON.stringify({
conversation_id: currentConvId,
messages,
web_search: webSearch,
openai_name: activeOpenAIName,
}),
});
if (!res.ok) {
@@ -997,6 +1038,31 @@ btnSearch.addEventListener('click', () => {
webSearchEnabled = !webSearchEnabled;
updateSearchButton();
});
modelSelect.addEventListener('change', async () => {
if (pending) {
modelSelect.value = activeOpenAIName;
return;
}
const nextName = modelSelect.value;
const prevName = activeOpenAIName;
try {
const res = await fetch('/api/openai/active', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ name: nextName }),
});
if (!res.ok) {
const err = await res.json().catch(() => ({ error: '切换模型失败' }));
throw new Error(err.error || '切换模型失败');
}
const data = await res.json();
activeOpenAIName = data.active;
modelSelect.value = activeOpenAIName;
} catch (e) {
modelSelect.value = prevName;
alert(e.message);
}
});
btnNewChat.addEventListener('click', newConversation);
btnClear.addEventListener('click', newConversation);
btnPreset.addEventListener('click', openPresetModal);
@@ -1016,6 +1082,7 @@ presetModal.addEventListener('click', e => {
// 自动聚焦 & 初始化
updateSearchButton();
loadOpenAIProfiles().catch(e => alert(e.message));
loadConversationList();
inputBox.focus();
</script>