1409 lines
38 KiB
Go
1409 lines
38 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"path/filepath"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// ─── Configuration ─────────────────────────────────────────

// Defaults for the OpenAI-compatible (Volcengine Ark) chat backend.
// The timeout is in seconds.
const (
	defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3"
	defaultOpenAITimeout = 120
)

// Defaults for the web-search backend. The timeout is in seconds; the count
// is the maximum number of search results used per query.
const (
	defaultSearchBaseURL = "https://api.duckduckgo.com/"
	defaultSearchTimeout = 10
	defaultSearchCount   = 5
)
|
||
|
||
// OpenAIConfig describes one chat-model profile from the `openai` section of
// the YAML config. The API key is never emitted as JSON (json:"-").
type OpenAIConfig struct {
	// Name identifies the profile; names must be unique.
	Name string `yaml:"name" json:"name"`
	// Active marks the selected profile; normalization keeps at most one.
	Active bool `yaml:"active,omitempty" json:"active"`
	// APIKey is the secret for the backend; ARK_API_KEY overrides it at load time.
	APIKey  string `yaml:"api_key" json:"-"`
	BaseURL string `yaml:"base_url" json:"base_url"`
	Model   string `yaml:"model" json:"model"`
	// Timeout is the request timeout in seconds.
	Timeout int `yaml:"timeout" json:"timeout"`
}
|
||
|
||
// OpenAIConfigs is the list of chat profiles. Its UnmarshalYAML also accepts a
// single mapping or null in place of a list.
type OpenAIConfigs []OpenAIConfig
|
||
|
||
// SearchConfig describes one web-search profile from the `search` section of
// the YAML config. The API key is never emitted as JSON (json:"-").
type SearchConfig struct {
	Name   string `yaml:"name" json:"name"`
	Active bool   `yaml:"active,omitempty" json:"active"`
	// Enabled must be true for the profile to be used by web-search enrichment.
	Enabled bool `yaml:"enabled" json:"enabled"`
	// Provider selects the backend: "duckduckgo"/"ddg" or "brave".
	Provider string `yaml:"provider" json:"provider"`
	// APIKey is required by Brave; BRAVE_SEARCH_API_KEY overrides it at load time.
	APIKey  string `yaml:"api_key" json:"-"`
	BaseURL string `yaml:"base_url" json:"base_url"`
	// Count is the maximum number of results requested/kept per query.
	Count int `yaml:"count" json:"count"`
	// Timeout is the per-search request timeout in seconds.
	Timeout int `yaml:"timeout" json:"timeout"`
}
|
||
|
||
// SearchConfigs is the list of search profiles. Its UnmarshalYAML also accepts
// a single mapping or null in place of a list.
type SearchConfigs []SearchConfig
|
||
|
||
// UnmarshalYAML lets the `search` key hold a list of profiles, a single profile
// mapping (wrapped into a one-element list), or null (no profiles). Any other
// node shape is rejected.
func (configs *SearchConfigs) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind == yaml.ScalarNode && value.Tag == "!!null" {
		*configs = nil
		return nil
	}
	if value.Kind == yaml.MappingNode {
		var single SearchConfig
		if err := value.Decode(&single); err != nil {
			return err
		}
		*configs = SearchConfigs{single}
		return nil
	}
	if value.Kind != yaml.SequenceNode {
		return fmt.Errorf("search 配置格式无效")
	}
	var list []SearchConfig
	if err := value.Decode(&list); err != nil {
		return err
	}
	*configs = list
	return nil
}
|
||
|
||
// UnmarshalYAML lets the `openai` key hold a list of profiles, a single profile
// mapping (wrapped into a one-element list), or null (no profiles). Any other
// node shape is rejected.
func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind == yaml.ScalarNode && value.Tag == "!!null" {
		*configs = nil
		return nil
	}
	if value.Kind == yaml.MappingNode {
		var single OpenAIConfig
		if err := value.Decode(&single); err != nil {
			return err
		}
		*configs = OpenAIConfigs{single}
		return nil
	}
	if value.Kind != yaml.SequenceNode {
		return fmt.Errorf("openai 配置格式无效")
	}
	var list []OpenAIConfig
	if err := value.Decode(&list); err != nil {
		return err
	}
	*configs = list
	return nil
}
|
||
|
||
// Config mirrors the YAML config file: server listen settings plus the chat
// (openai) and web-search (search) profile lists.
type Config struct {
	Server struct {
		// Mode defaults to "tcp". NOTE(review): the set of accepted modes is
		// decided by the startup code outside this view — confirm there.
		Mode string `yaml:"mode"`
		// Address defaults to "0.0.0.0:8080".
		Address string `yaml:"address"`
	} `yaml:"server"`
	OpenAI OpenAIConfigs `yaml:"openai"`
	Search SearchConfigs `yaml:"search"`
}
|
||
|
||
func defaultOpenAIConfig() OpenAIConfig {
|
||
return OpenAIConfig{
|
||
Name: "default",
|
||
Active: true,
|
||
BaseURL: defaultOpenAIBaseURL,
|
||
Timeout: defaultOpenAITimeout,
|
||
}
|
||
}
|
||
|
||
func defaultSearchConfig() SearchConfig {
|
||
return SearchConfig{
|
||
Name: "duckduckgo",
|
||
Active: true,
|
||
Enabled: true,
|
||
Provider: "duckduckgo",
|
||
BaseURL: defaultSearchBaseURL,
|
||
Count: defaultSearchCount,
|
||
Timeout: defaultSearchTimeout,
|
||
}
|
||
}
|
||
|
||
func defaultConfig() Config {
|
||
var cfg Config
|
||
cfg.Server.Mode = "tcp"
|
||
cfg.Server.Address = "0.0.0.0:8080"
|
||
cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
|
||
cfg.Search = SearchConfigs{defaultSearchConfig()}
|
||
return cfg
|
||
}
|
||
|
||
func loadConfig(path string) (*Config, error) {
|
||
if err := ensureConfigFile(path); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||
}
|
||
var cfg Config
|
||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||
}
|
||
if _, err := normalizeOpenAIConfigs(&cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
// 环境变量优先
|
||
if key := os.Getenv("ARK_API_KEY"); key != "" {
|
||
for i := range cfg.OpenAI {
|
||
cfg.OpenAI[i].APIKey = key
|
||
}
|
||
}
|
||
if _, err := normalizeSearchConfigs(&cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" {
|
||
for i := range cfg.Search {
|
||
if strings.ToLower(cfg.Search[i].Provider) == "brave" {
|
||
cfg.Search[i].APIKey = key
|
||
}
|
||
}
|
||
}
|
||
return &cfg, nil
|
||
}
|
||
|
||
// ensureConfigFile guarantees that a usable config file exists at path.
//
// If the file does not exist, it is created from defaultConfig(). Otherwise the
// file is parsed twice. The first parse goes into Config for typed values. The
// second goes into a generic map so that "key absent" can be told apart from
// "key set to its zero value".
//
// Missing server keys are filled from the defaults, and the openai/search
// sections are normalized. The file is rewritten only when something changed.
// A change includes openai or search not being written as a YAML list (a
// single mapping, null or absent), so the file is upgraded to list form.
//
// NOTE: the rewrite goes through yaml.Marshal of Config, so comments and custom
// formatting in the user's file are not preserved.
func ensureConfigFile(path string) error {
	defaults := defaultConfig()
	if _, err := os.Stat(path); err != nil {
		if !os.IsNotExist(err) {
			return fmt.Errorf("检查配置文件失败: %w", err)
		}
		// No config file yet: write the defaults.
		return writeConfig(path, defaults)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("读取配置文件失败: %w", err)
	}
	var cfg Config
	if err = yaml.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}
	// Untyped view of the same document, used only to check which keys exist.
	var raw map[string]any
	if err = yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}

	changed := false
	server, _ := raw["server"].(map[string]any)
	if server == nil {
		// No (or null) server section: take the whole default.
		cfg.Server = defaults.Server
		changed = true
	} else {
		if _, ok := server["mode"]; !ok {
			cfg.Server.Mode = defaults.Server.Mode
			changed = true
		}
		if _, ok := server["address"]; !ok {
			cfg.Server.Address = defaults.Server.Address
			changed = true
		}
	}

	// Rewrite a non-list openai section in list form.
	if _, ok := raw["openai"].([]any); !ok {
		changed = true
	}
	if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil {
		return err
	} else if normalized {
		changed = true
	}

	// Same for the search section.
	if _, ok := raw["search"].([]any); !ok {
		changed = true
	}
	if normalized, err := normalizeSearchConfigs(&cfg); err != nil {
		return err
	} else if normalized {
		changed = true
	}

	if !changed {
		return nil
	}
	return writeConfig(path, cfg)
}
|
||
|
||
// normalizeOpenAIConfigs fills defaults for the chat profiles in cfg and
// enforces the profile invariants. It reports whether anything was modified so
// the caller can decide whether to rewrite the config file.
//
//   - An empty list is replaced by the built-in default profile.
//   - Names are trimmed. A missing name is derived from the model, or from the
//     1-based position ("openai-N"). A model-derived name that an earlier
//     profile already uses gets a "-N" suffix, mirroring how
//     normalizeSearchConfigs derives names.
//   - Empty base URLs and non-positive timeouts get the package defaults.
//   - Only the first profile marked active stays active. If none is active,
//     the first profile becomes active.
//
// It returns an error when two profiles end up with the same name (for example
// two explicit duplicates); cfg may be partially modified in that case.
func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
	changed := false
	if len(cfg.OpenAI) == 0 {
		cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
		changed = true
	}

	activeIndex := -1
	seen := map[string]bool{}
	for i := range cfg.OpenAI {
		profile := &cfg.OpenAI[i]
		name := strings.TrimSpace(profile.Name)
		if name == "" {
			name = strings.TrimSpace(profile.Model)
			if name == "" {
				name = fmt.Sprintf("openai-%d", i+1)
			} else if seen[name] {
				// Two unnamed profiles for the same model: disambiguate rather
				// than fail on a name the user never wrote.
				name = fmt.Sprintf("%s-%d", name, i+1)
			}
			profile.Name = name
			changed = true
		} else if name != profile.Name {
			profile.Name = name
			changed = true
		}
		if seen[name] {
			return changed, fmt.Errorf("openai 配置名称重复: %s", name)
		}
		seen[name] = true

		if strings.TrimSpace(profile.BaseURL) == "" {
			profile.BaseURL = defaultOpenAIBaseURL
			changed = true
		}
		if profile.Timeout <= 0 {
			profile.Timeout = defaultOpenAITimeout
			changed = true
		}
		if profile.Active {
			if activeIndex == -1 {
				activeIndex = i
			} else {
				// Only the first active profile wins.
				profile.Active = false
				changed = true
			}
		}
	}
	if activeIndex == -1 {
		cfg.OpenAI[0].Active = true
		changed = true
	}
	return changed, nil
}
|
||
|
||
// normalizeSearchConfigs fills defaults for the search profiles in cfg and
// enforces the profile invariants. It reports whether anything was modified.
//
//   - An empty list is replaced by the built-in DuckDuckGo profile.
//   - Provider is trimmed and lower-cased; an empty provider becomes "duckduckgo".
//   - A missing name defaults to the provider. If an earlier profile already
//     uses that name, a "-N" suffix (1-based position) is added.
//   - An empty base URL is filled per provider. An unknown provider without a
//     base URL is rejected.
//   - Non-positive count/timeout values get the package defaults.
//   - Only the first active profile stays active. If none is active, the first
//     profile becomes active.
//
// It returns an error for duplicate names or an unsupported provider; cfg may
// be partially modified in that case.
func normalizeSearchConfigs(cfg *Config) (bool, error) {
	changed := false
	if len(cfg.Search) == 0 {
		cfg.Search = SearchConfigs{defaultSearchConfig()}
		changed = true
	}

	activeIndex := -1
	seen := map[string]bool{}
	for i := range cfg.Search {
		profile := &cfg.Search[i]
		// NOTE(review): trimming/lower-casing here does not set changed, so a
		// provider like "Brave" is normalized in memory but not written back.
		profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider))
		if profile.Provider == "" {
			profile.Provider = "duckduckgo"
			changed = true
		}
		name := strings.TrimSpace(profile.Name)
		if name == "" {
			name = profile.Provider
			if seen[name] {
				name = fmt.Sprintf("%s-%d", profile.Provider, i+1)
			}
			profile.Name = name
			changed = true
		} else if name != profile.Name {
			profile.Name = name
			changed = true
		}
		if seen[name] {
			return changed, fmt.Errorf("search 配置名称重复: %s", name)
		}
		seen[name] = true

		if strings.TrimSpace(profile.BaseURL) == "" {
			switch profile.Provider {
			case "duckduckgo", "ddg":
				profile.BaseURL = defaultSearchBaseURL
			case "brave":
				profile.BaseURL = "https://api.search.brave.com/res/v1/web/search"
			default:
				return changed, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
			}
			changed = true
		}
		if profile.Count <= 0 {
			profile.Count = defaultSearchCount
			changed = true
		}
		if profile.Timeout <= 0 {
			profile.Timeout = defaultSearchTimeout
			changed = true
		}
		if profile.Active {
			if activeIndex == -1 {
				activeIndex = i
			} else {
				// Only the first active profile wins.
				profile.Active = false
				changed = true
			}
		}
	}
	if activeIndex == -1 {
		cfg.Search[0].Active = true
		changed = true
	}
	return changed, nil
}
|
||
|
||
// writeConfig serializes cfg as YAML and writes it to path.
//
// The file can hold API keys (api_key), so a newly created file is readable
// and writable by its owner only (0600). os.WriteFile applies the permission
// only when it creates the file; an existing file keeps its current mode.
func writeConfig(path string, cfg Config) error {
	data, err := yaml.Marshal(&cfg)
	if err != nil {
		return fmt.Errorf("生成配置文件失败: %w", err)
	}
	if err := os.WriteFile(path, data, 0600); err != nil {
		return fmt.Errorf("写入配置文件失败: %w", err)
	}
	return nil
}
|
||
|
||
// ─── Request types ─────────────────────────────────────────

// ChatMessage is a single chat turn exchanged with the client and stored in
// Conversation.Messages.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	// ImageURL is an optional image: a base64 data URI or an http URL.
	ImageURL string `json:"image_url,omitempty"`
	// ImageURLAlias accepts the same value under the camelCase key "imageURL".
	// NOTE(review): presumably merged with ImageURL when building model
	// messages (outside this view) — confirm.
	ImageURLAlias string `json:"imageURL,omitempty"`
	// Hidden marks messages not meant for display; the injected web-search
	// system prompt is created with Hidden set.
	Hidden bool `json:"hidden,omitempty"`
}
|
||
|
||
// ChatRequest is the JSON body accepted by chatHandler.
type ChatRequest struct {
	// ConversationID, when set, makes chatHandler persist the exchange.
	ConversationID string        `json:"conversation_id,omitempty"`
	Messages       []ChatMessage `json:"messages"`
	// WebSearch enables search-result enrichment before calling the model.
	WebSearch bool `json:"web_search,omitempty"`
	// OpenAIName selects a chat profile by name; empty means the active one.
	OpenAIName string `json:"openai_name,omitempty"`
}
|
||
|
||
// Conversation is a stored chat, saved by ConvStore as one JSON file.
// ConvStore.List clears Messages so that listings carry only metadata.
type Conversation struct {
	ID        string    `json:"id"`
	Title     string    `json:"title"`
	CreatedAt time.Time `json:"created_at"`
	// UpdatedAt is refreshed by ConvStore.Save on every write.
	UpdatedAt time.Time     `json:"updated_at"`
	Messages  []ChatMessage `json:"messages,omitempty"`
}
|
||
|
||
// ConvStore persists conversations as <dir>/<id>.json files.
type ConvStore struct {
	dir string
	// mu serializes the store's file operations (Save, Get, List, Delete).
	mu sync.Mutex
}
|
||
|
||
// OpenAIProfile pairs a validated chat configuration with the Ark client
// built from it.
type OpenAIProfile struct {
	Config OpenAIConfig
	Client *ark.Client
}
|
||
|
||
// OpenAIState holds all chat profiles and tracks which one is active.
// Lookups take mu for reading; switching the active profile takes it for writing.
type OpenAIState struct {
	mu sync.RWMutex
	// profiles maps profile name to profile.
	profiles map[string]*OpenAIProfile
	// order keeps the configuration order for listings.
	order      []string
	activeName string
}
|
||
|
||
// activeProfileRequest is the JSON body of the switch-profile endpoints:
// {"name": "<profile name>"}.
type activeProfileRequest struct {
	Name string `json:"name"`
}
|
||
|
||
// openAIListResponse is the JSON listing of chat profiles, with API keys
// blanked and the active profile's name repeated in Active.
type openAIListResponse struct {
	Active   string         `json:"active"`
	Profiles []OpenAIConfig `json:"profiles"`
}
|
||
|
||
// searchListResponse is the JSON listing of search profiles, with API keys
// blanked and the active profile's name repeated in Active.
type searchListResponse struct {
	Active   string         `json:"active"`
	Profiles []SearchConfig `json:"profiles"`
}
|
||
|
||
// NewOpenAIState validates the chat profiles and builds one Ark client per
// profile. Configuration order is kept for listings. The first profile flagged
// Active becomes the active one, falling back to the first profile.
//
// It fails on a blank name, a missing api_key/model/base_url, a non-positive
// timeout, a duplicate name, or an empty list.
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
	st := &OpenAIState{
		profiles: make(map[string]*OpenAIProfile, len(configs)),
		order:    make([]string, 0, len(configs)),
	}
	for _, pc := range configs {
		switch {
		case strings.TrimSpace(pc.Name) == "":
			return nil, errors.New("openai.name 不能为空")
		case strings.TrimSpace(pc.APIKey) == "":
			return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", pc.Name)
		case strings.TrimSpace(pc.Model) == "":
			return nil, fmt.Errorf("openai.%s.model 未配置", pc.Name)
		case strings.TrimSpace(pc.BaseURL) == "":
			return nil, fmt.Errorf("openai.%s.base_url 未配置", pc.Name)
		case pc.Timeout <= 0:
			return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", pc.Name)
		}
		if _, dup := st.profiles[pc.Name]; dup {
			return nil, fmt.Errorf("openai 配置名称重复: %s", pc.Name)
		}

		client := ark.NewClientWithApiKey(
			pc.APIKey,
			ark.WithBaseUrl(pc.BaseURL),
			ark.WithTimeout(time.Duration(pc.Timeout)*time.Second),
		)
		st.profiles[pc.Name] = &OpenAIProfile{Config: pc, Client: client}
		st.order = append(st.order, pc.Name)
		if pc.Active && st.activeName == "" {
			st.activeName = pc.Name
		}
	}

	if len(st.order) == 0 {
		return nil, errors.New("openai 配置不能为空")
	}
	if st.activeName == "" {
		st.activeName = st.order[0]
	}
	return st, nil
}
|
||
|
||
func (s *OpenAIState) ActiveProfile() *OpenAIProfile {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
return s.profiles[s.activeName]
|
||
}
|
||
|
||
// GetProfile looks up a chat profile by name, ignoring surrounding whitespace.
// A blank name selects the active profile. An unknown name yields an error
// that quotes the name exactly as given.
func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) {
	key := strings.TrimSpace(name)
	s.mu.RLock()
	defer s.mu.RUnlock()
	if key == "" {
		return s.profiles[s.activeName], nil
	}
	if p, ok := s.profiles[key]; ok {
		return p, nil
	}
	return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
}
|
||
|
||
// SwitchActive makes the named chat profile active (whitespace is trimmed)
// and returns it. A blank or unknown name is an error and leaves the active
// profile unchanged.
func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) {
	target := strings.TrimSpace(name)
	if target == "" {
		return nil, errors.New("OpenAI 配置名称不能为空")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	picked, found := s.profiles[target]
	if !found {
		return nil, fmt.Errorf("OpenAI 配置不存在: %s", target)
	}
	s.activeName = target
	return picked, nil
}
|
||
|
||
func (s *OpenAIState) ListProfiles() openAIListResponse {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
profiles := make([]OpenAIConfig, 0, len(s.order))
|
||
for _, name := range s.order {
|
||
profile := s.profiles[name]
|
||
config := profile.Config
|
||
config.APIKey = ""
|
||
config.Active = name == s.activeName
|
||
profiles = append(profiles, config)
|
||
}
|
||
return openAIListResponse{Active: s.activeName, Profiles: profiles}
|
||
}
|
||
|
||
func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
|
||
config := profile.Config
|
||
config.APIKey = ""
|
||
config.Active = active
|
||
return config
|
||
}
|
||
|
||
// SearchState holds all search profiles and tracks which one is active.
// Lookups take mu for reading; switching the active profile takes it for writing.
type SearchState struct {
	mu sync.RWMutex
	// profiles maps profile name to its configuration (stored by value).
	profiles map[string]SearchConfig
	// order keeps the configuration order for listings.
	order      []string
	activeName string
}
|
||
|
||
// NewSearchState validates the search profiles and indexes them by name.
// Configuration order is kept for listings. The first profile flagged Active
// becomes the active one, falling back to the first profile.
//
// It fails on a blank name, a missing provider/base_url, a non-positive
// timeout, a duplicate name, or an empty list.
func NewSearchState(configs []SearchConfig) (*SearchState, error) {
	st := &SearchState{
		profiles: make(map[string]SearchConfig, len(configs)),
		order:    make([]string, 0, len(configs)),
	}
	for _, sc := range configs {
		switch {
		case strings.TrimSpace(sc.Name) == "":
			return nil, errors.New("search.name 不能为空")
		case strings.TrimSpace(sc.Provider) == "":
			return nil, fmt.Errorf("search.%s.provider 未配置", sc.Name)
		case strings.TrimSpace(sc.BaseURL) == "":
			return nil, fmt.Errorf("search.%s.base_url 未配置", sc.Name)
		case sc.Timeout <= 0:
			return nil, fmt.Errorf("search.%s.timeout 必须大于 0", sc.Name)
		}
		if _, dup := st.profiles[sc.Name]; dup {
			return nil, fmt.Errorf("search 配置名称重复: %s", sc.Name)
		}

		st.profiles[sc.Name] = sc
		st.order = append(st.order, sc.Name)
		if sc.Active && st.activeName == "" {
			st.activeName = sc.Name
		}
	}

	if len(st.order) == 0 {
		return nil, errors.New("search 配置不能为空")
	}
	if st.activeName == "" {
		st.activeName = st.order[0]
	}
	return st, nil
}
|
||
|
||
func (s *SearchState) ActiveProfile() SearchConfig {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
return s.profiles[s.activeName]
|
||
}
|
||
|
||
// SwitchActive makes the named search profile active (whitespace is trimmed)
// and returns it. A blank or unknown name is an error and leaves the active
// profile unchanged.
func (s *SearchState) SwitchActive(name string) (SearchConfig, error) {
	target := strings.TrimSpace(name)
	if target == "" {
		return SearchConfig{}, errors.New("搜索配置名称不能为空")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	picked, found := s.profiles[target]
	if !found {
		return SearchConfig{}, fmt.Errorf("搜索配置不存在: %s", target)
	}
	s.activeName = target
	return picked, nil
}
|
||
|
||
func (s *SearchState) ListProfiles() searchListResponse {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
profiles := make([]SearchConfig, 0, len(s.order))
|
||
for _, name := range s.order {
|
||
profile := s.profiles[name]
|
||
profile.APIKey = ""
|
||
profile.Active = name == s.activeName
|
||
profiles = append(profiles, profile)
|
||
}
|
||
return searchListResponse{Active: s.activeName, Profiles: profiles}
|
||
}
|
||
|
||
func publicSearchConfig(config SearchConfig, active bool) SearchConfig {
|
||
config.APIKey = ""
|
||
config.Active = active
|
||
return config
|
||
}
|
||
|
||
// ─── Globals ───────────────────────────────────────────────

// Process-wide state shared by the HTTP handlers.
// NOTE(review): presumably initialized once at startup (outside this view)
// before requests are served — confirm.
var (
	cfg         *Config
	aiState     *OpenAIState
	searchState *SearchState
	store       *ConvStore
)
|
||
|
||
// ─── Routes ────────────────────────────────────────────────

// indexHandler renders chat.html, showing the active chat profile's model and
// name.
func indexHandler(c *gin.Context) {
	active := aiState.ActiveProfile().Config
	c.HTML(http.StatusOK, "chat.html", gin.H{
		"Title":      "AI 对话",
		"Model":      active.Model,
		"OpenAIName": active.Name,
	})
}
|
||
|
||
// listOpenAIHandler returns the chat profiles (API keys redacted) as JSON.
func listOpenAIHandler(c *gin.Context) {
	snapshot := aiState.ListProfiles()
	c.JSON(http.StatusOK, snapshot)
}
|
||
|
||
// switchOpenAIHandler makes the chat profile named in the JSON body active.
// A malformed body or an unknown name is answered with 400.
func switchOpenAIHandler(c *gin.Context) {
	var body activeProfileRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + bindErr.Error()})
		return
	}

	selected, switchErr := aiState.SwitchActive(body.Name)
	if switchErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": switchErr.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"active":  selected.Config.Name,
		"profile": publicOpenAIConfig(selected, true),
	})
}
|
||
|
||
// listSearchHandler returns the search profiles (API keys redacted) as JSON.
func listSearchHandler(c *gin.Context) {
	snapshot := searchState.ListProfiles()
	c.JSON(http.StatusOK, snapshot)
}
|
||
|
||
// switchSearchHandler makes the search profile named in the JSON body active.
// A malformed body or an unknown name is answered with 400.
func switchSearchHandler(c *gin.Context) {
	var body activeProfileRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + bindErr.Error()})
		return
	}

	selected, switchErr := searchState.SwitchActive(body.Name)
	if switchErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": switchErr.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"active":  selected.Name,
		"profile": publicSearchConfig(selected, true),
	})
}
|
||
|
||
// listConversationsHandler returns conversation metadata, newest first.
func listConversationsHandler(c *gin.Context) {
	convs, err := store.List()
	if err == nil {
		c.JSON(http.StatusOK, convs)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
|
||
|
||
// createConversationHandler creates and persists an empty conversation and
// returns it as JSON.
func createConversationHandler(c *gin.Context) {
	conv, err := store.Create()
	if err == nil {
		c.JSON(http.StatusOK, conv)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": "创建对话失败: " + err.Error()})
}
|
||
|
||
// getConversationHandler returns one conversation including its messages.
// A missing conversation yields 404; any other failure yields 500.
func getConversationHandler(c *gin.Context) {
	conv, err := store.Get(c.Param("id"))
	if err == nil {
		c.JSON(http.StatusOK, conv)
		return
	}

	status := http.StatusInternalServerError
	// Coupled to the exact message returned by ConvStore.Get for a missing file.
	if err.Error() == "对话不存在" {
		status = http.StatusNotFound
	}
	c.JSON(status, gin.H{"error": err.Error()})
}
|
||
|
||
// deleteConversationHandler deletes a conversation and answers 204. Deleting
// a conversation that does not exist also succeeds.
func deleteConversationHandler(c *gin.Context) {
	err := store.Delete(c.Param("id"))
	if err == nil {
		c.Status(http.StatusNoContent)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
|
||
|
||
// chatHandler serves the streaming chat endpoint over Server-Sent Events.
//
// The following steps run before streaming starts, and a failure in any of
// them is answered with a regular JSON error: request validation, profile
// lookup, optional web-search enrichment, message conversion, and the
// streaming-support check.
//
// After the SSE headers are sent:
//   - each content delta is written as a "data:" event;
//   - an upstream error is written as a `data: {"error": ...}` event;
//   - a successful stream ends with "data: [DONE]".
//
// When ConversationID is set, the request's own messages (without the injected
// search context) plus the full assistant reply are saved once the stream
// finishes normally. A save failure is only logged to stderr.
func chatHandler(c *gin.Context) {
	var req ChatRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
		return
	}
	if len(req.Messages) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"})
		return
	}
	profile, err := aiState.GetProfile(req.OpenAIName)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	chatMessages := req.Messages
	if req.WebSearch {
		withSearch, err := enrichMessagesWithSearch(c.Request.Context(), req.Messages)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		chatMessages = withSearch
	}

	// Build the Ark message list.
	messages, err := buildArkMessages(chatMessages)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Check streaming support BEFORE touching the response. Once the SSE
	// Content-Type and the 200 status have been set, a JSON error response
	// would go out labelled as text/event-stream.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器不支持流式响应"})
		return
	}

	// SSE headers.
	header := c.Writer.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	// Asks nginx-style reverse proxies not to buffer the event stream.
	header.Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)

	// The profile timeout bounds the whole upstream call, streaming included.
	// The call is also cancelled when the client's request context ends.
	timeout := time.Duration(profile.Config.Timeout) * time.Second
	ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
	defer cancel()

	// Start the streaming completion (CreateChatCompletionStream).
	stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
		Model:     profile.Config.Model,
		Messages:  messages,
		MaxTokens: intPtr(4096),
	}.WithStream(true))
	if err != nil {
		fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
		flusher.Flush()
		return
	}
	defer stream.Close()

	var full strings.Builder
	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			if req.ConversationID != "" {
				if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil {
					fmt.Fprintln(os.Stderr, "保存对话失败:", err)
				}
			}
			fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
			flusher.Flush()
			return
		}
		if err != nil {
			fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
			flusher.Flush()
			return
		}
		if len(resp.Choices) > 0 {
			delta := resp.Choices[0].Delta.Content
			if delta != "" {
				full.WriteString(delta)
				fmt.Fprintf(c.Writer, "data: %s\n\n", toSSE(delta))
				flusher.Flush()
			}
		}
	}
}
|
||
|
||
// ─── Helpers ───────────────────────────────────────────────

// searchResult is one normalized web-search hit, shared by all providers.
// The JSON tags match Brave's web result objects, which decode directly into it.
type searchResult struct {
	Title       string `json:"title"`
	URL         string `json:"url"`
	Description string `json:"description"`
}
|
||
|
||
// braveSearchResponse is the subset of the Brave web-search response that is
// used: the list under web.results.
type braveSearchResponse struct {
	Web struct {
		Results []searchResult `json:"results"`
	} `json:"web"`
}
|
||
|
||
type duckDuckGoResponse struct {
|
||
Abstract string `json:"Abstract"`
|
||
AbstractSource string `json:"AbstractSource"`
|
||
AbstractURL string `json:"AbstractURL"`
|
||
Heading string `json:"Heading"`
|
||
RelatedTopics []struct {
|
||
Text string `json:"Text"`
|
||
FirstURL string `json:"FirstURL"`
|
||
Topics []struct {
|
||
Text string `json:"Text"`
|
||
FirstURL string `json:"FirstURL"`
|
||
} `json:"Topics"`
|
||
} `json:"RelatedTopics"`
|
||
Infobox struct {
|
||
Content []struct {
|
||
Label string `json:"label"`
|
||
Value string `json:"value"`
|
||
} `json:"content"`
|
||
} `json:"Infobox"`
|
||
}
|
||
|
||
// enrichMessagesWithSearch searches the web for the latest user question using
// the active search profile. It returns the conversation with the results
// prepended as a hidden system message.
//
// It fails when the active profile is disabled, when there is no textual user
// question, when the search fails, or when it returns no results.
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]ChatMessage, error) {
	active := searchState.ActiveProfile()
	if !active.Enabled {
		return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled")
	}

	question := latestUserQuery(messages)
	if question == "" {
		return nil, errors.New("联网搜索需要输入文本问题")
	}

	hits, err := webSearch(ctx, active, question)
	switch {
	case err != nil:
		return nil, err
	case len(hits) == 0:
		return nil, errors.New("未搜索到相关网页结果")
	}

	contextMsg := ChatMessage{
		Role:    "system",
		Content: buildSearchContext(active, question, hits),
		Hidden:  true,
	}
	return append([]ChatMessage{contextMsg}, messages...), nil
}
|
||
|
||
// latestUserQuery returns the trimmed text of the most recent "user" message,
// or "" when there is none. If that message has no text (for example an
// image-only turn), "" is returned; earlier messages are not consulted.
func latestUserQuery(messages []ChatMessage) string {
	for offset := range messages {
		m := messages[len(messages)-1-offset]
		if m.Role == "user" {
			return strings.TrimSpace(m.Content)
		}
	}
	return ""
}
|
||
|
||
// webSearch dispatches query to the backend named by config.Provider
// (case-insensitive): "duckduckgo"/"ddg" or "brave". Any other provider is an
// error.
func webSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	provider := strings.ToLower(config.Provider)
	if provider == "duckduckgo" || provider == "ddg" {
		return duckDuckGoSearch(ctx, config, query)
	}
	if provider == "brave" {
		return braveWebSearch(ctx, config, query)
	}
	return nil, fmt.Errorf("暂不支持搜索服务: %s", config.Provider)
}
|
||
|
||
// duckDuckGoSearch queries the DuckDuckGo Instant Answer API, which is not a
// full web search. It flattens the answer into at most config.Count results
// (defaultSearchCount if not positive), in this order:
//  1. the abstract, if any; its title is the Heading, else the AbstractSource,
//     else a fixed label;
//  2. infobox label/value rows whose label and value are both non-blank;
//  3. related topics, including nested topic groups, each split into a title
//     and a description by splitDuckDuckGoText.
//
// The request is bounded by config.Timeout seconds, and at most 2 MiB of the
// response body is read. A non-2xx status is returned as an error that
// includes the body.
func duckDuckGoSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
	defer cancel()

	u, err := url.Parse(config.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("搜索服务地址无效: %w", err)
	}
	q := u.Query()
	q.Set("q", query)
	q.Set("format", "json")
	q.Set("no_html", "1")
	q.Set("skip_disambig", "1")
	u.RawQuery = q.Encode()

	req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("创建搜索请求失败: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", "aichat/1.0")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("联网搜索失败: %w", err)
	}
	defer resp.Body.Close()

	// Cap the body at 2 MiB.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
	if err != nil {
		return nil, fmt.Errorf("读取搜索响应失败: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var parsed duckDuckGoResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("解析搜索响应失败: %w", err)
	}

	limit := config.Count
	if limit <= 0 {
		limit = defaultSearchCount
	}
	results := make([]searchResult, 0, limit)
	// 1. Abstract (limit >= 1 here, so no limit check is needed).
	if strings.TrimSpace(parsed.Abstract) != "" {
		title := strings.TrimSpace(parsed.Heading)
		if title == "" {
			title = strings.TrimSpace(parsed.AbstractSource)
		}
		if title == "" {
			title = "DuckDuckGo 摘要"
		}
		results = append(results, searchResult{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)})
	}
	// 2. Infobox rows; they have no URL.
	for _, item := range parsed.Infobox.Content {
		if len(results) >= limit {
			break
		}
		if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" {
			continue
		}
		results = append(results, searchResult{Title: item.Label, Description: item.Value})
	}
	// 3. Related topics. appendRelated enforces the limit itself, so nested
	// topics never overshoot it.
	appendRelated := func(text, firstURL string) {
		if len(results) >= limit || strings.TrimSpace(text) == "" {
			return
		}
		title, desc := splitDuckDuckGoText(text)
		results = append(results, searchResult{Title: title, URL: strings.TrimSpace(firstURL), Description: desc})
	}
	for _, topic := range parsed.RelatedTopics {
		appendRelated(topic.Text, topic.FirstURL)
		for _, nested := range topic.Topics {
			appendRelated(nested.Text, nested.FirstURL)
		}
		if len(results) >= limit {
			break
		}
	}
	return results, nil
}
|
||
|
||
func splitDuckDuckGoText(text string) (string, string) {
|
||
text = strings.TrimSpace(text)
|
||
parts := strings.SplitN(text, " - ", 2)
|
||
if len(parts) == 2 {
|
||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||
}
|
||
runes := []rune(text)
|
||
if len(runes) > 42 {
|
||
return string(runes[:42]) + "...", text
|
||
}
|
||
return text, text
|
||
}
|
||
|
||
// braveWebSearch queries the Brave web-search API with the profile's API key
// and returns its web results as-is.
//
// Fixed parameters: zh-hans language, CN country, and config.Count results.
// The request is bounded by config.Timeout seconds, and at most 2 MiB of the
// response body is read. A missing API key, an invalid base URL, a transport
// failure, a non-2xx status (with the body), or undecodable JSON is returned
// as an error.
func braveWebSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	if config.APIKey == "" {
		return nil, errors.New("Brave 搜索未配置 API Key,请设置 search.api_key 或环境变量 BRAVE_SEARCH_API_KEY")
	}
	reqCtx, cancelReq := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
	defer cancelReq()

	endpoint, err := url.Parse(config.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("搜索服务地址无效: %w", err)
	}
	params := endpoint.Query()
	params.Set("q", query)
	params.Set("count", fmt.Sprintf("%d", config.Count))
	params.Set("search_lang", "zh-hans")
	params.Set("country", "CN")
	endpoint.RawQuery = params.Encode()

	httpReq, err := http.NewRequestWithContext(reqCtx, http.MethodGet, endpoint.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("创建搜索请求失败: %w", err)
	}
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("X-Subscription-Token", config.APIKey)

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("联网搜索失败: %w", err)
	}
	defer httpResp.Body.Close()

	payload, err := io.ReadAll(io.LimitReader(httpResp.Body, 2<<20)) // 2 MiB cap
	if err != nil {
		return nil, fmt.Errorf("读取搜索响应失败: %w", err)
	}
	if code := httpResp.StatusCode; code < 200 || code > 299 {
		return nil, fmt.Errorf("搜索服务返回错误 %d: %s", code, strings.TrimSpace(string(payload)))
	}

	var decoded braveSearchResponse
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return nil, fmt.Errorf("解析搜索响应失败: %w", err)
	}
	return decoded.Web.Results, nil
}
|
||
|
||
// buildSearchContext renders search results as the system prompt injected
// before the conversation. The prompt contains:
//   - the search source;
//   - a caveat for DuckDuckGo;
//   - the current local time and the query;
//   - a numbered list of titles, with links and summaries when non-blank;
//   - a closing instruction not to invent answers.
func buildSearchContext(config SearchConfig, query string, results []searchResult) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "用户开启了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider)
	switch strings.ToLower(config.Provider) {
	case "duckduckgo", "ddg":
		fmt.Fprintln(&sb, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。")
	}
	fmt.Fprintf(&sb, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&sb, "搜索词: %s\n\n", query)
	fmt.Fprintln(&sb, "搜索结果:")

	for idx, hit := range results {
		fmt.Fprintf(&sb, "%d. 标题: %s\n", idx+1, strings.TrimSpace(hit.Title))
		if link := strings.TrimSpace(hit.URL); link != "" {
			fmt.Fprintf(&sb, " 链接: %s\n", link)
		}
		if summary := strings.TrimSpace(hit.Description); summary != "" {
			fmt.Fprintf(&sb, " 摘要: %s\n", summary)
		}
	}

	fmt.Fprintln(&sb, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。")
	return sb.String()
}
|
||
|
||
func newUUID() string {
|
||
b := make([]byte, 16)
|
||
_, _ = rand.Read(b)
|
||
b[6] = (b[6] & 0x0f) | 0x40
|
||
b[8] = (b[8] & 0x3f) | 0x80
|
||
return hex.EncodeToString(b[:4]) + "-" + hex.EncodeToString(b[4:6]) + "-" +
|
||
hex.EncodeToString(b[6:8]) + "-" + hex.EncodeToString(b[8:10]) + "-" +
|
||
hex.EncodeToString(b[10:])
|
||
}
|
||
|
||
// ─── ConvStore ─────────────────────────────────────────────

// NewConvStore returns a store that keeps one JSON file per conversation in
// dir. It creates dir (and its parents) if needed.
//
// A failure to create the directory is reported on stderr rather than
// returned, to keep the signature; later reads and writes will surface the
// underlying error.
func NewConvStore(dir string) *ConvStore {
	if err := os.MkdirAll(dir, 0755); err != nil {
		fmt.Fprintln(os.Stderr, "创建对话目录失败:", err)
	}
	return &ConvStore{dir: dir}
}
|
||
|
||
// path maps a conversation ID to its JSON file inside s.dir.
//
// IDs can come from clients (URL parameters and the chat request body), so
// only the final path element of id is used. An ID such as "../../tmp/x"
// therefore cannot read, write or delete files outside the store directory.
// IDs produced by newUUID contain no separators and map exactly as before.
func (s *ConvStore) path(id string) string {
	return filepath.Join(s.dir, filepath.Base(id)+".json")
}
|
||
|
||
// Create starts a new, empty conversation titled "新对话" with a fresh UUID,
// persists it, and returns it. Save refreshes UpdatedAt during the write.
func (s *ConvStore) Create() (*Conversation, error) {
	now := time.Now()
	fresh := &Conversation{
		ID:        newUUID(),
		Title:     "新对话",
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err := s.Save(fresh); err != nil {
		return nil, err
	}
	return fresh, nil
}
|
||
|
||
func (s *ConvStore) Save(conv *Conversation) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
conv.UpdatedAt = time.Now()
|
||
return atomicWriteJSON(s.path(conv.ID), conv)
|
||
}
|
||
|
||
func (s *ConvStore) Get(id string) (*Conversation, error) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
data, err := os.ReadFile(s.path(id))
|
||
if err != nil {
|
||
if os.IsNotExist(err) {
|
||
return nil, errors.New("对话不存在")
|
||
}
|
||
return nil, fmt.Errorf("读取对话失败: %w", err)
|
||
}
|
||
var conv Conversation
|
||
if err := json.Unmarshal(data, &conv); err != nil {
|
||
return nil, fmt.Errorf("解析对话失败: %w", err)
|
||
}
|
||
return &conv, nil
|
||
}
|
||
|
||
func (s *ConvStore) List() ([]Conversation, error) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
entries, err := os.ReadDir(s.dir)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取对话目录失败: %w", err)
|
||
}
|
||
|
||
var list []Conversation
|
||
for _, e := range entries {
|
||
if e.IsDir() || filepath.Ext(e.Name()) != ".json" {
|
||
continue
|
||
}
|
||
data, err := os.ReadFile(filepath.Join(s.dir, e.Name()))
|
||
if err != nil {
|
||
continue
|
||
}
|
||
var conv Conversation
|
||
if err := json.Unmarshal(data, &conv); err != nil {
|
||
continue
|
||
}
|
||
conv.Messages = nil // 列表不返回消息体
|
||
list = append(list, conv)
|
||
}
|
||
|
||
sort.Slice(list, func(i, j int) bool {
|
||
return list[i].UpdatedAt.After(list[j].UpdatedAt)
|
||
})
|
||
return list, nil
|
||
}
|
||
|
||
func (s *ConvStore) Delete(id string) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
if err := os.Remove(s.path(id)); err != nil && !os.IsNotExist(err) {
|
||
return fmt.Errorf("删除对话失败: %w", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// atomicWriteJSON marshals v as JSON and replaces the file at path so that
// readers see either the old content or the complete new content.
//
// It works in these steps:
//  1. Write the data to a uniquely named temporary file ("<base>.*.tmp") in
//     the same directory.
//  2. Flush the file to stable storage.
//  3. Set mode 0644.
//  4. Rename it over path.
//
// On any failure the temporary file is removed, so no stray files are left
// behind.
func atomicWriteJSON(path string, v any) (err error) {
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}

	// Same directory so the final rename stays on one filesystem
	// (rename is only atomic within a filesystem).
	f, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".*.tmp")
	if err != nil {
		return err
	}
	tmp := f.Name()
	defer func() {
		if err != nil {
			_ = f.Close() // may already be closed; the error is irrelevant here
			_ = os.Remove(tmp)
		}
	}()

	if _, err = f.Write(data); err != nil {
		return err
	}
	// CreateTemp creates files with mode 0600; keep the previous 0644 mode.
	if err = f.Chmod(0644); err != nil {
		return err
	}
	// Without Sync, a crash after the rename could expose a truncated file.
	if err = f.Sync(); err != nil {
		return err
	}
	if err = f.Close(); err != nil {
		return err
	}
	return os.Rename(tmp, path)
}
|
||
|
||
func saveConversationMessages(id string, messages []ChatMessage, assistantContent string) error {
|
||
conv, err := store.Get(id)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
conv.Messages = append([]ChatMessage(nil), messages...)
|
||
conv.Messages = append(conv.Messages, ChatMessage{Role: "assistant", Content: assistantContent})
|
||
if conv.Title == "" || conv.Title == "新对话" {
|
||
conv.Title = genConvTitle(conv.Messages)
|
||
}
|
||
return store.Save(conv)
|
||
}
|
||
|
||
func genConvTitle(messages []ChatMessage) string {
|
||
for _, m := range messages {
|
||
if m.Hidden {
|
||
continue
|
||
}
|
||
if m.Role == "user" && strings.TrimSpace(m.Content) != "" {
|
||
title := strings.TrimSpace(m.Content)
|
||
title = strings.ReplaceAll(title, "\r\n", " ")
|
||
title = strings.ReplaceAll(title, "\n", " ")
|
||
runes := []rune(title)
|
||
if len(runes) > 30 {
|
||
return string(runes[:30]) + "..."
|
||
}
|
||
return title
|
||
}
|
||
}
|
||
return "新对话"
|
||
}
|
||
|
||
// maxImageSize is the largest decoded size, in bytes (4 MiB), accepted for an
// inline base64 data-URI image.
const maxImageSize = 4 << 20
|
||
|
||
// allowedImageTypes is the set of lower-case MIME types accepted for base64
// data-URI images; any type not listed maps to false.
var allowedImageTypes = map[string]bool{
	"image/gif":  true,
	"image/jpeg": true,
	"image/png":  true,
	"image/webp": true,
}
|
||
|
||
func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
|
||
messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages))
|
||
for _, m := range chatMessages {
|
||
msg, err := buildArkMessage(m)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
messages = append(messages, msg)
|
||
}
|
||
return messages, nil
|
||
}
|
||
|
||
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
|
||
msg := &model.ChatCompletionMessage{Role: m.Role}
|
||
|
||
if m.ImageURL == "" && m.ImageURLAlias != "" {
|
||
m.ImageURL = m.ImageURLAlias
|
||
}
|
||
|
||
if m.ImageURL == "" {
|
||
msg.Content = &model.ChatCompletionMessageContent{
|
||
StringValue: &m.Content,
|
||
}
|
||
return msg, nil
|
||
}
|
||
|
||
imageURL, err := normalizeImageURL(m.ImageURL)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 有图片时:文字内容可有可无(图片 caption 场景),均构造多模态消息
|
||
// 若无文字,则只传图片 part;若同时有图片和文字,先图后文
|
||
parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)}
|
||
if m.Content != "" {
|
||
parts = append(parts, textPart(m.Content))
|
||
}
|
||
msg.Content = &model.ChatCompletionMessageContent{ListValue: parts}
|
||
return msg, nil
|
||
}
|
||
|
||
func imagePart(url string) *model.ChatCompletionMessageContentPart {
|
||
return &model.ChatCompletionMessageContentPart{
|
||
Type: model.ChatCompletionMessageContentPartTypeImageURL,
|
||
ImageURL: &model.ChatMessageImageURL{
|
||
URL: url,
|
||
Detail: model.ImageURLDetailAuto,
|
||
},
|
||
}
|
||
}
|
||
|
||
func textPart(text string) *model.ChatCompletionMessageContentPart {
|
||
return &model.ChatCompletionMessageContentPart{
|
||
Type: model.ChatCompletionMessageContentPartTypeText,
|
||
Text: text,
|
||
}
|
||
}
|
||
|
||
func normalizeImageURL(raw string) (string, error) {
|
||
raw = strings.TrimSpace(raw)
|
||
if raw == "" {
|
||
return "", errors.New("图片地址不能为空")
|
||
}
|
||
|
||
lower := strings.ToLower(raw)
|
||
if strings.HasPrefix(lower, "data:") {
|
||
return normalizeImageDataURI(raw)
|
||
}
|
||
|
||
u, err := url.Parse(raw)
|
||
if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
|
||
return "", errors.New("图片地址无效,仅支持 http/https URL 或 base64 data URI")
|
||
}
|
||
return raw, nil
|
||
}
|
||
|
||
func normalizeImageDataURI(raw string) (string, error) {
|
||
comma := strings.Index(raw, ",")
|
||
if comma < 0 {
|
||
return "", errors.New("图片 base64 数据格式错误")
|
||
}
|
||
|
||
meta := strings.ToLower(strings.TrimSpace(raw[5:comma]))
|
||
payload := strings.TrimSpace(raw[comma+1:])
|
||
if payload == "" {
|
||
return "", errors.New("图片 base64 数据不能为空")
|
||
}
|
||
parts := strings.Split(meta, ";")
|
||
if len(parts) < 2 || !contains(parts[1:], "base64") {
|
||
return "", errors.New("图片 data URI 必须使用 base64 编码")
|
||
}
|
||
|
||
mime := parts[0]
|
||
if !allowedImageTypes[mime] {
|
||
return "", errors.New("图片格式不支持,仅支持 jpeg/png/webp/gif")
|
||
}
|
||
|
||
decoded, err := base64.StdEncoding.DecodeString(payload)
|
||
if err != nil {
|
||
return "", errors.New("图片 base64 数据无效")
|
||
}
|
||
if len(decoded) > maxImageSize {
|
||
return "", errors.New("图片过大,请选择小于 4MB 的图片")
|
||
}
|
||
|
||
return "data:" + mime + ";base64," + payload, nil
|
||
}
|
||
|
||
// contains reports whether any element of items equals target exactly
// (case-sensitive) once its surrounding whitespace is trimmed.
func contains(items []string, target string) bool {
	for i := range items {
		if strings.TrimSpace(items[i]) != target {
			continue
		}
		return true
	}
	return false
}
|
||
|
||
// intPtr returns a pointer to a fresh int holding i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
|
||
|
||
// toJSON renders s as a quoted JSON string literal using encoding/json. That
// includes its HTML-safe escaping, e.g. '<' becomes \u003c. Marshalling a
// string cannot fail in practice. If it ever did, the empty string would be
// returned.
func toJSON(s string) string {
	if encoded, err := json.Marshal(s); err == nil {
		return string(encoded)
	}
	return ""
}
|
||
|
||
// toSSE renders s as a double-quoted, JSON-compatible string literal for an
// SSE data field:
//   - Backslash, double quote, newline and tab get their JSON escapes.
//   - Every other C0 control character (U+0000–U+001F) becomes \u00XX.
//   - Carriage returns are dropped, so CRLF line endings arrive as \n.
//   - All other bytes, including UTF-8 sequences, pass through unchanged.
//
// The previous version wrote tabs and other control characters raw. JSON
// parsers reject those inside strings, so any model output containing a tab
// made the event undecodable for a client that JSON-parses the payload.
func toSSE(s string) string {
	const hexDigits = "0123456789abcdef"
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('"')
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch c {
		case '\\':
			b.WriteString(`\\`)
		case '"':
			b.WriteString(`\"`)
		case '\n':
			b.WriteString(`\n`)
		case '\r':
			// Dropped, as before.
		case '\t':
			b.WriteString(`\t`)
		default:
			if c < 0x20 {
				b.WriteString(`\u00`)
				b.WriteByte(hexDigits[c>>4])
				b.WriteByte(hexDigits[c&0x0f])
			} else {
				b.WriteByte(c)
			}
		}
	}
	b.WriteByte('"')
	return b.String()
}
|
||
|
||
// ─── 入口 ─────────────────────────────────────────────────
|
||
|
||
func main() {
|
||
var err error
|
||
cfg, err = loadConfig("config.yaml")
|
||
if err != nil {
|
||
fmt.Fprintln(os.Stderr, "配置加载失败:", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
// 初始化火山方舟 SDK 客户端
|
||
aiState, err = NewOpenAIState(cfg.OpenAI)
|
||
if err != nil {
|
||
fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err)
|
||
os.Exit(1)
|
||
}
|
||
searchState, err = NewSearchState(cfg.Search)
|
||
if err != nil {
|
||
fmt.Fprintln(os.Stderr, "搜索配置初始化失败:", err)
|
||
os.Exit(1)
|
||
}
|
||
store = NewConvStore("conversations")
|
||
|
||
// Gin 路由
|
||
r := gin.Default()
|
||
r.LoadHTMLGlob("templates/*")
|
||
r.Static("/static", "./static")
|
||
|
||
r.GET("/", indexHandler)
|
||
r.POST("/api/chat", chatHandler)
|
||
r.GET("/api/openai", listOpenAIHandler)
|
||
r.POST("/api/openai/active", switchOpenAIHandler)
|
||
r.GET("/api/search", listSearchHandler)
|
||
r.POST("/api/search/active", switchSearchHandler)
|
||
r.GET("/api/conversations", listConversationsHandler)
|
||
r.POST("/api/conversations", createConversationHandler)
|
||
r.GET("/api/conversations/:id", getConversationHandler)
|
||
r.DELETE("/api/conversations/:id", deleteConversationHandler)
|
||
|
||
// 根据配置选择监听方式
|
||
switch strings.ToLower(cfg.Server.Mode) {
|
||
case "unix":
|
||
socketPath := cfg.Server.Address
|
||
if _, statErr := os.Stat(socketPath); statErr == nil {
|
||
os.Remove(socketPath)
|
||
}
|
||
ln, listenErr := net.Listen("unix", socketPath)
|
||
if listenErr != nil {
|
||
fmt.Fprintln(os.Stderr, "监听 Unix socket 失败:", listenErr)
|
||
os.Exit(1)
|
||
}
|
||
fmt.Println("服务已启动,监听 Unix socket:", socketPath)
|
||
if serveErr := http.Serve(ln, r); serveErr != nil {
|
||
fmt.Fprintln(os.Stderr, "服务异常退出:", serveErr)
|
||
os.Exit(1)
|
||
}
|
||
default:
|
||
fmt.Println("服务已启动,监听 TCP:", cfg.Server.Address)
|
||
if runErr := r.Run(cfg.Server.Address); runErr != nil {
|
||
fmt.Fprintln(os.Stderr, "服务异常退出:", runErr)
|
||
os.Exit(1)
|
||
}
|
||
}
|
||
}
|