Files
aichat/main.go
T
2026-06-09 13:48:14 +08:00

1409 lines
38 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"gopkg.in/yaml.v3"
)
// ─── 配置 ─────────────────────────────────────────────────
// Built-in OpenAI-compatible (Volcengine Ark) endpoint defaults, used when
// config.yaml omits a value. The timeout is in seconds.
const (
	defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3"
	defaultOpenAITimeout = 120
)

// Built-in web-search defaults: DuckDuckGo Instant Answer endpoint, request
// timeout in seconds, and the number of results handed to the model.
const (
	defaultSearchBaseURL = "https://api.duckduckgo.com/"
	defaultSearchTimeout = 10
	defaultSearchCount   = 5
)
// OpenAIConfig describes one chat-model profile from the "openai" section of
// config.yaml. APIKey is tagged json:"-" so it never appears in JSON output.
type OpenAIConfig struct {
	// Name uniquely identifies the profile (used by /api/openai/active and
	// ChatRequest.OpenAIName).
	Name string `yaml:"name" json:"name"`
	// Active marks the profile selected at startup; the first active one wins.
	Active bool `yaml:"active,omitempty" json:"active"`
	// APIKey may be overridden by the ARK_API_KEY environment variable.
	APIKey  string `yaml:"api_key" json:"-"`
	BaseURL string `yaml:"base_url" json:"base_url"`
	Model   string `yaml:"model" json:"model"`
	// Timeout is the request timeout in seconds.
	Timeout int `yaml:"timeout" json:"timeout"`
}

// OpenAIConfigs is the "openai" section; its UnmarshalYAML accepts either a
// list of profiles or a single mapping.
type OpenAIConfigs []OpenAIConfig
// SearchConfig describes one web-search profile from the "search" section of
// config.yaml. APIKey is tagged json:"-" so it never appears in JSON output.
type SearchConfig struct {
	Name   string `yaml:"name" json:"name"`
	Active bool   `yaml:"active,omitempty" json:"active"`
	// Enabled must be true for chat requests with web_search to use this profile.
	Enabled bool `yaml:"enabled" json:"enabled"`
	// Provider is one of "duckduckgo", "ddg" or "brave" (see webSearch).
	Provider string `yaml:"provider" json:"provider"`
	// APIKey is required for Brave; BRAVE_SEARCH_API_KEY overrides it.
	APIKey  string `yaml:"api_key" json:"-"`
	BaseURL string `yaml:"base_url" json:"base_url"`
	// Count is the maximum number of results requested/kept.
	Count int `yaml:"count" json:"count"`
	// Timeout is the search request timeout in seconds.
	Timeout int `yaml:"timeout" json:"timeout"`
}

// SearchConfigs is the "search" section; its UnmarshalYAML accepts either a
// list of profiles or a single mapping.
type SearchConfigs []SearchConfig
// UnmarshalYAML decodes the "search" section. It accepts a list of profiles
// or a single profile mapping (wrapped into a one-element list). An explicit
// null yields no profiles. Any other node kind is rejected.
func (configs *SearchConfigs) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind == yaml.MappingNode {
		var single SearchConfig
		if err := value.Decode(&single); err != nil {
			return err
		}
		*configs = []SearchConfig{single}
		return nil
	}
	if value.Kind == yaml.SequenceNode {
		var list []SearchConfig
		if err := value.Decode(&list); err != nil {
			return err
		}
		*configs = list
		return nil
	}
	if value.Kind == yaml.ScalarNode && value.Tag == "!!null" {
		*configs = nil
		return nil
	}
	return fmt.Errorf("search 配置格式无效")
}
// UnmarshalYAML decodes the "openai" section. It accepts a list of profiles
// or a single profile mapping (wrapped into a one-element list). An explicit
// null yields no profiles. Any other node kind is rejected.
func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind == yaml.MappingNode {
		var single OpenAIConfig
		if err := value.Decode(&single); err != nil {
			return err
		}
		*configs = []OpenAIConfig{single}
		return nil
	}
	if value.Kind == yaml.SequenceNode {
		var list []OpenAIConfig
		if err := value.Decode(&list); err != nil {
			return err
		}
		*configs = list
		return nil
	}
	if value.Kind == yaml.ScalarNode && value.Tag == "!!null" {
		*configs = nil
		return nil
	}
	return fmt.Errorf("openai 配置格式无效")
}
// Config mirrors the top-level structure of config.yaml.
type Config struct {
	// Server selects how the HTTP server listens: Mode "unix" serves on the
	// Unix socket path in Address; any other mode serves TCP on Address.
	Server struct {
		Mode    string `yaml:"mode"`
		Address string `yaml:"address"`
	} `yaml:"server"`
	OpenAI OpenAIConfigs `yaml:"openai"`
	Search SearchConfigs `yaml:"search"`
}
// defaultOpenAIConfig returns the profile used when config.yaml defines no
// "openai" entries. APIKey and Model stay empty; NewOpenAIState rejects the
// profile until they are filled in.
func defaultOpenAIConfig() OpenAIConfig {
	var c OpenAIConfig
	c.Name = "default"
	c.Active = true
	c.BaseURL = defaultOpenAIBaseURL
	c.Timeout = defaultOpenAITimeout
	return c
}
// defaultSearchConfig returns the enabled DuckDuckGo profile used when
// config.yaml defines no "search" entries.
func defaultSearchConfig() SearchConfig {
	var c SearchConfig
	c.Name, c.Provider = "duckduckgo", "duckduckgo"
	c.Active, c.Enabled = true, true
	c.BaseURL = defaultSearchBaseURL
	c.Count = defaultSearchCount
	c.Timeout = defaultSearchTimeout
	return c
}
// defaultConfig returns the configuration written when config.yaml does not
// exist: TCP on 0.0.0.0:8080 plus one default OpenAI and one default search
// profile.
func defaultConfig() Config {
	c := Config{
		OpenAI: OpenAIConfigs{defaultOpenAIConfig()},
		Search: SearchConfigs{defaultSearchConfig()},
	}
	c.Server.Mode, c.Server.Address = "tcp", "0.0.0.0:8080"
	return c
}
func loadConfig(path string) (*Config, error) {
if err := ensureConfigFile(path); err != nil {
return nil, err
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var cfg Config
if err = yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
if _, err := normalizeOpenAIConfigs(&cfg); err != nil {
return nil, err
}
// 环境变量优先
if key := os.Getenv("ARK_API_KEY"); key != "" {
for i := range cfg.OpenAI {
cfg.OpenAI[i].APIKey = key
}
}
if _, err := normalizeSearchConfigs(&cfg); err != nil {
return nil, err
}
if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" {
for i := range cfg.Search {
if strings.ToLower(cfg.Search[i].Provider) == "brave" {
cfg.Search[i].APIKey = key
}
}
}
return &cfg, nil
}
// ensureConfigFile writes defaultConfig() to path when the file is missing.
// For an existing file it:
//   - back-fills missing server keys,
//   - treats "openai"/"search" sections that are not YAML lists as needing a
//     rewrite,
//   - normalizes the profiles.
// The file is rewritten only when one of these steps changed something.
func ensureConfigFile(path string) error {
	defaults := defaultConfig()
	_, statErr := os.Stat(path)
	switch {
	case os.IsNotExist(statErr):
		return writeConfig(path, defaults)
	case statErr != nil:
		return fmt.Errorf("检查配置文件失败: %w", statErr)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("读取配置文件失败: %w", err)
	}
	var current Config
	if err := yaml.Unmarshal(data, &current); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}
	// A generic decode reveals which keys are actually present in the file.
	var raw map[string]any
	if err := yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}

	dirty := false
	server, _ := raw["server"].(map[string]any)
	if server == nil {
		current.Server = defaults.Server
		dirty = true
	} else {
		if _, has := server["mode"]; !has {
			current.Server.Mode = defaults.Server.Mode
			dirty = true
		}
		if _, has := server["address"]; !has {
			current.Server.Address = defaults.Server.Address
			dirty = true
		}
	}

	// Non-list sections are rewritten in list form.
	if _, isList := raw["openai"].([]any); !isList {
		dirty = true
	}
	openAIChanged, err := normalizeOpenAIConfigs(&current)
	if err != nil {
		return err
	}
	if _, isList := raw["search"].([]any); !isList {
		dirty = true
	}
	searchChanged, err := normalizeSearchConfigs(&current)
	if err != nil {
		return err
	}

	if !(dirty || openAIChanged || searchChanged) {
		return nil
	}
	return writeConfig(path, current)
}
// normalizeOpenAIConfigs fills in defaults for the "openai" profiles of cfg.
// It reports whether anything was modified, so callers can persist it.
//
// Rules:
//   - An empty list becomes a single defaultOpenAIConfig().
//   - Names are trimmed. A missing name falls back to the model, then to
//     "openai-<n>".
//   - A derived name that is already taken gets a "-<n>" suffix (as in
//     normalizeSearchConfigs), so two unnamed profiles using the same model
//     no longer abort startup.
//   - An empty base_url and a non-positive timeout get the package defaults.
//   - Exactly one profile ends up Active: the first one marked active, or the
//     first profile when none is.
//
// It returns an error when two profiles end up with the same name.
func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
	changed := false
	if len(cfg.OpenAI) == 0 {
		cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
		changed = true
	}
	activeIndex := -1
	seen := make(map[string]bool, len(cfg.OpenAI))
	for i := range cfg.OpenAI {
		profile := &cfg.OpenAI[i]

		name := strings.TrimSpace(profile.Name)
		if name == "" {
			name = strings.TrimSpace(profile.Model)
			if name == "" {
				name = fmt.Sprintf("openai-%d", i+1)
			}
			// Only derived names are made unique; explicit duplicates are
			// still reported below.
			if seen[name] {
				name = fmt.Sprintf("%s-%d", name, i+1)
			}
		}
		if name != profile.Name {
			profile.Name = name
			changed = true
		}
		if seen[name] {
			return changed, fmt.Errorf("openai 配置名称重复: %s", name)
		}
		seen[name] = true

		if strings.TrimSpace(profile.BaseURL) == "" {
			profile.BaseURL = defaultOpenAIBaseURL
			changed = true
		}
		if profile.Timeout <= 0 {
			profile.Timeout = defaultOpenAITimeout
			changed = true
		}
		if profile.Active {
			if activeIndex == -1 {
				activeIndex = i
			} else {
				// Keep only the first active profile.
				profile.Active = false
				changed = true
			}
		}
	}
	if activeIndex == -1 {
		cfg.OpenAI[0].Active = true
		changed = true
	}
	return changed, nil
}
func normalizeSearchConfigs(cfg *Config) (bool, error) {
changed := false
if len(cfg.Search) == 0 {
cfg.Search = SearchConfigs{defaultSearchConfig()}
changed = true
}
activeIndex := -1
seen := map[string]bool{}
for i := range cfg.Search {
profile := &cfg.Search[i]
profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider))
if profile.Provider == "" {
profile.Provider = "duckduckgo"
changed = true
}
name := strings.TrimSpace(profile.Name)
if name == "" {
name = profile.Provider
if seen[name] {
name = fmt.Sprintf("%s-%d", profile.Provider, i+1)
}
profile.Name = name
changed = true
} else if name != profile.Name {
profile.Name = name
changed = true
}
if seen[name] {
return changed, fmt.Errorf("search 配置名称重复: %s", name)
}
seen[name] = true
if strings.TrimSpace(profile.BaseURL) == "" {
switch profile.Provider {
case "duckduckgo", "ddg":
profile.BaseURL = defaultSearchBaseURL
case "brave":
profile.BaseURL = "https://api.search.brave.com/res/v1/web/search"
default:
return changed, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
}
changed = true
}
if profile.Count <= 0 {
profile.Count = defaultSearchCount
changed = true
}
if profile.Timeout <= 0 {
profile.Timeout = defaultSearchTimeout
changed = true
}
if profile.Active {
if activeIndex == -1 {
activeIndex = i
} else {
profile.Active = false
changed = true
}
}
}
if activeIndex == -1 {
cfg.Search[0].Active = true
changed = true
}
return changed, nil
}
// writeConfig serializes cfg as YAML and writes it to path.
//
// The file holds API keys (api_key), so a newly created file is readable by
// its owner only (0600). For an existing file, os.WriteFile truncates it and
// keeps its current permissions.
func writeConfig(path string, cfg Config) error {
	data, err := yaml.Marshal(&cfg)
	if err != nil {
		return fmt.Errorf("生成配置文件失败: %w", err)
	}
	if err := os.WriteFile(path, data, 0o600); err != nil {
		return fmt.Errorf("写入配置文件失败: %w", err)
	}
	return nil
}
// ─── 请求结构 ─────────────────────────────────────────────
// ChatMessage is a single chat turn, as received in ChatRequest and as stored
// in conversation files.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	// ImageURL is an optional image: a base64 data URI or an http(s) URL.
	ImageURL string `json:"image_url,omitempty"`
	// ImageURLAlias accepts the same value under the "imageURL" key; it is
	// only used when ImageURL is empty.
	ImageURLAlias string `json:"imageURL,omitempty"`
	// Hidden marks messages not meant for display, such as the injected
	// web-search context; genConvTitle skips them.
	Hidden bool `json:"hidden,omitempty"`
}

// ChatRequest is the JSON body of POST /api/chat.
type ChatRequest struct {
	// ConversationID, when non-empty, names the conversation that the
	// finished exchange is saved to.
	ConversationID string        `json:"conversation_id,omitempty"`
	Messages       []ChatMessage `json:"messages"`
	// WebSearch prepends results from the active search profile.
	WebSearch bool `json:"web_search,omitempty"`
	// OpenAIName selects an OpenAI profile; empty means the active one.
	OpenAIName string `json:"openai_name,omitempty"`
}
// Conversation is one saved chat, stored by ConvStore as <dir>/<ID>.json.
type Conversation struct {
	ID        string    `json:"id"`
	Title     string    `json:"title"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
	// Messages is cleared by ConvStore.List and omitted from JSON when empty.
	Messages []ChatMessage `json:"messages,omitempty"`
}

// ConvStore persists conversations as one JSON file per conversation in dir.
// mu serializes every file operation performed through the store.
type ConvStore struct {
	dir string
	mu  sync.Mutex
}
// OpenAIProfile pairs a validated OpenAI/Ark configuration with the SDK
// client that NewOpenAIState built for it.
type OpenAIProfile struct {
	Config OpenAIConfig
	Client *ark.Client
}

// OpenAIState holds every chat-model profile and the name of the active one.
// profiles and order are only written by NewOpenAIState; mu guards
// activeName, which SwitchActive changes at runtime.
type OpenAIState struct {
	mu         sync.RWMutex
	profiles   map[string]*OpenAIProfile
	order      []string // profile names in configuration order
	activeName string
}

// activeProfileRequest is the JSON body of POST /api/openai/active and
// POST /api/search/active.
type activeProfileRequest struct {
	Name string `json:"name"`
}

// openAIListResponse is the GET /api/openai payload (API keys cleared).
type openAIListResponse struct {
	Active   string         `json:"active"`
	Profiles []OpenAIConfig `json:"profiles"`
}

// searchListResponse is the GET /api/search payload (API keys cleared).
type searchListResponse struct {
	Active   string         `json:"active"`
	Profiles []SearchConfig `json:"profiles"`
}
// NewOpenAIState validates every OpenAI profile and builds one Ark client per
// profile, keeping configuration order. The first profile marked Active
// becomes the active one; if none is, the first profile is used.
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
	if len(configs) == 0 {
		return nil, errors.New("openai 配置不能为空")
	}
	st := &OpenAIState{
		profiles: make(map[string]*OpenAIProfile, len(configs)),
		order:    make([]string, 0, len(configs)),
	}
	for _, c := range configs {
		switch {
		case strings.TrimSpace(c.Name) == "":
			return nil, errors.New("openai.name 不能为空")
		case strings.TrimSpace(c.APIKey) == "":
			return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", c.Name)
		case strings.TrimSpace(c.Model) == "":
			return nil, fmt.Errorf("openai.%s.model 未配置", c.Name)
		case strings.TrimSpace(c.BaseURL) == "":
			return nil, fmt.Errorf("openai.%s.base_url 未配置", c.Name)
		case c.Timeout <= 0:
			return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", c.Name)
		}
		if _, dup := st.profiles[c.Name]; dup {
			return nil, fmt.Errorf("openai 配置名称重复: %s", c.Name)
		}
		client := ark.NewClientWithApiKey(
			c.APIKey,
			ark.WithBaseUrl(c.BaseURL),
			ark.WithTimeout(time.Duration(c.Timeout)*time.Second),
		)
		st.profiles[c.Name] = &OpenAIProfile{Config: c, Client: client}
		st.order = append(st.order, c.Name)
		if st.activeName == "" && c.Active {
			st.activeName = c.Name
		}
	}
	if st.activeName == "" {
		st.activeName = st.order[0]
	}
	return st, nil
}
// ActiveProfile returns the currently active OpenAI profile.
func (s *OpenAIState) ActiveProfile() *OpenAIProfile {
	s.mu.RLock()
	current := s.profiles[s.activeName]
	s.mu.RUnlock()
	return current
}
// GetProfile returns the profile named name (surrounding whitespace
// ignored), or the active profile when name is blank.
func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) {
	key := strings.TrimSpace(name)
	s.mu.RLock()
	defer s.mu.RUnlock()
	if key == "" {
		return s.profiles[s.activeName], nil
	}
	if found, ok := s.profiles[key]; ok {
		return found, nil
	}
	return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
}
// SwitchActive makes the profile named name (trimmed) the active one and
// returns it.
func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) {
	target := strings.TrimSpace(name)
	if target == "" {
		return nil, errors.New("OpenAI 配置名称不能为空")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	selected, found := s.profiles[target]
	if !found {
		return nil, fmt.Errorf("OpenAI 配置不存在: %s", target)
	}
	s.activeName = target
	return selected, nil
}
// ListProfiles returns every profile in configuration order, with API keys
// cleared and Active reflecting the current selection.
func (s *OpenAIState) ListProfiles() openAIListResponse {
	s.mu.RLock()
	defer s.mu.RUnlock()
	resp := openAIListResponse{
		Active:   s.activeName,
		Profiles: make([]OpenAIConfig, 0, len(s.order)),
	}
	for _, name := range s.order {
		resp.Profiles = append(resp.Profiles, publicOpenAIConfig(s.profiles[name], name == s.activeName))
	}
	return resp
}
// publicOpenAIConfig returns a copy of the profile's configuration that is
// safe to send to clients: the API key is cleared and Active is set to active.
func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
	out := profile.Config
	out.APIKey, out.Active = "", active
	return out
}
// SearchState holds every web-search profile and the name of the active one.
// profiles and order are only written by NewSearchState; mu guards
// activeName, which SwitchActive changes at runtime.
type SearchState struct {
	mu         sync.RWMutex
	profiles   map[string]SearchConfig
	order      []string // profile names in configuration order
	activeName string
}
// NewSearchState validates the search profiles and indexes them by name,
// keeping configuration order. The first profile marked Active becomes the
// active one; if none is, the first profile is used.
func NewSearchState(configs []SearchConfig) (*SearchState, error) {
	if len(configs) == 0 {
		return nil, errors.New("search 配置不能为空")
	}
	st := &SearchState{
		profiles: make(map[string]SearchConfig, len(configs)),
		order:    make([]string, 0, len(configs)),
	}
	for _, c := range configs {
		switch {
		case strings.TrimSpace(c.Name) == "":
			return nil, errors.New("search.name 不能为空")
		case strings.TrimSpace(c.Provider) == "":
			return nil, fmt.Errorf("search.%s.provider 未配置", c.Name)
		case strings.TrimSpace(c.BaseURL) == "":
			return nil, fmt.Errorf("search.%s.base_url 未配置", c.Name)
		case c.Timeout <= 0:
			return nil, fmt.Errorf("search.%s.timeout 必须大于 0", c.Name)
		}
		if _, dup := st.profiles[c.Name]; dup {
			return nil, fmt.Errorf("search 配置名称重复: %s", c.Name)
		}
		st.profiles[c.Name] = c
		st.order = append(st.order, c.Name)
		if st.activeName == "" && c.Active {
			st.activeName = c.Name
		}
	}
	if st.activeName == "" {
		st.activeName = st.order[0]
	}
	return st, nil
}
// ActiveProfile returns a copy of the currently active search profile.
func (s *SearchState) ActiveProfile() SearchConfig {
	s.mu.RLock()
	current := s.profiles[s.activeName]
	s.mu.RUnlock()
	return current
}
// SwitchActive makes the search profile named name (trimmed) the active one
// and returns it.
func (s *SearchState) SwitchActive(name string) (SearchConfig, error) {
	target := strings.TrimSpace(name)
	if target == "" {
		return SearchConfig{}, errors.New("搜索配置名称不能为空")
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	selected, found := s.profiles[target]
	if !found {
		return SearchConfig{}, fmt.Errorf("搜索配置不存在: %s", target)
	}
	s.activeName = target
	return selected, nil
}
// ListProfiles returns every search profile in configuration order, with API
// keys cleared and Active reflecting the current selection.
func (s *SearchState) ListProfiles() searchListResponse {
	s.mu.RLock()
	defer s.mu.RUnlock()
	resp := searchListResponse{
		Active:   s.activeName,
		Profiles: make([]SearchConfig, 0, len(s.order)),
	}
	for _, name := range s.order {
		resp.Profiles = append(resp.Profiles, publicSearchConfig(s.profiles[name], name == s.activeName))
	}
	return resp
}
// publicSearchConfig returns config with the API key cleared and Active set
// to active, ready to be sent to clients.
func publicSearchConfig(config SearchConfig, active bool) SearchConfig {
	sanitized := config
	sanitized.APIKey, sanitized.Active = "", active
	return sanitized
}
// ─── 全局变量 ─────────────────────────────────────────────
// Process-wide state, assigned once in main before the HTTP server starts.
var (
	cfg         *Config      // configuration loaded from config.yaml
	aiState     *OpenAIState // OpenAI/Ark profiles and the active selection
	searchState *SearchState // web-search profiles and the active selection
	store       *ConvStore   // conversation files under "conversations"
)
// ─── 路由 ─────────────────────────────────────────────────
// indexHandler renders chat.html with the active profile's model and name.
func indexHandler(c *gin.Context) {
	active := aiState.ActiveProfile().Config
	data := gin.H{
		"Title":      "AI 对话",
		"Model":      active.Model,
		"OpenAIName": active.Name,
	}
	c.HTML(http.StatusOK, "chat.html", data)
}
// listOpenAIHandler serves GET /api/openai: every profile (keys cleared) plus
// the active profile name.
func listOpenAIHandler(c *gin.Context) {
	list := aiState.ListProfiles()
	c.JSON(http.StatusOK, list)
}
// switchOpenAIHandler serves POST /api/openai/active: it selects the profile
// named in the body and echoes it back with the API key cleared.
func switchOpenAIHandler(c *gin.Context) {
	var body activeProfileRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + bindErr.Error()})
		return
	}
	selected, switchErr := aiState.SwitchActive(body.Name)
	if switchErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": switchErr.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"active":  selected.Config.Name,
		"profile": publicOpenAIConfig(selected, true),
	})
}
// listSearchHandler serves GET /api/search: every search profile (keys
// cleared) plus the active profile name.
func listSearchHandler(c *gin.Context) {
	list := searchState.ListProfiles()
	c.JSON(http.StatusOK, list)
}
// switchSearchHandler serves POST /api/search/active: it selects the search
// profile named in the body and echoes it back with the API key cleared.
func switchSearchHandler(c *gin.Context) {
	var body activeProfileRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + bindErr.Error()})
		return
	}
	selected, switchErr := searchState.SwitchActive(body.Name)
	if switchErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": switchErr.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"active":  selected.Name,
		"profile": publicSearchConfig(selected, true),
	})
}
// listConversationsHandler serves GET /api/conversations (newest first,
// without message bodies).
func listConversationsHandler(c *gin.Context) {
	convs, err := store.List()
	if err == nil {
		c.JSON(http.StatusOK, convs)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// createConversationHandler serves POST /api/conversations: it creates and
// returns an empty conversation.
func createConversationHandler(c *gin.Context) {
	conv, err := store.Create()
	if err == nil {
		c.JSON(http.StatusOK, conv)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": "创建对话失败: " + err.Error()})
}
// getConversationHandler serves GET /api/conversations/:id. A missing
// conversation yields 404; other failures yield 500.
func getConversationHandler(c *gin.Context) {
	conv, err := store.Get(c.Param("id"))
	if err == nil {
		c.JSON(http.StatusOK, conv)
		return
	}
	// ConvStore.Get reports a missing file with exactly this message.
	code := http.StatusInternalServerError
	if err.Error() == "对话不存在" {
		code = http.StatusNotFound
	}
	c.JSON(code, gin.H{"error": err.Error()})
}
// deleteConversationHandler serves DELETE /api/conversations/:id and answers
// 204 on success (including when the conversation did not exist).
func deleteConversationHandler(c *gin.Context) {
	err := store.Delete(c.Param("id"))
	if err == nil {
		c.Status(http.StatusNoContent)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// chatHandler serves POST /api/chat and streams the model's reply as
// Server-Sent Events.
//
// Request handling:
//   - When web_search is set, search results are prepended as a hidden
//     system message. That message goes to the model only; it is not saved.
//   - When conversation_id is set, the client's messages plus the assembled
//     assistant reply are saved once the stream ends normally.
//
// Stream format:
//   - `data: "<delta>"` for each content delta,
//   - `data: {"error":...}` on failure,
//   - `data: [DONE]` at the end.
//
// Validation errors are returned before streaming, as JSON with status 400.
func chatHandler(c *gin.Context) {
	var req ChatRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
		return
	}
	if len(req.Messages) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"})
		return
	}
	profile, err := aiState.GetProfile(req.OpenAIName)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	chatMessages := req.Messages
	if req.WebSearch {
		withSearch, err := enrichMessagesWithSearch(c.Request.Context(), req.Messages)
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		chatMessages = withSearch
	}

	// Convert to Ark SDK messages (validates image URLs).
	messages, err := buildArkMessages(chatMessages)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Check streaming support BEFORE committing the 200 status and SSE
	// headers: once WriteHeader has run, a JSON error can no longer be sent.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "服务器不支持流式响应"})
		return
	}
	header := c.Writer.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	header.Set("X-Accel-Buffering", "no") // ask reverse proxies (nginx) not to buffer
	c.Writer.WriteHeader(http.StatusOK)

	// Bound the upstream call by the profile timeout and the client request.
	timeout := time.Duration(profile.Config.Timeout) * time.Second
	ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
	defer cancel()

	stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
		Model:     profile.Config.Model,
		Messages:  messages,
		MaxTokens: intPtr(4096),
	}.WithStream(true))
	if err != nil {
		fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
		flusher.Flush()
		return
	}
	defer stream.Close()

	var full strings.Builder
	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			if req.ConversationID != "" {
				if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil {
					fmt.Fprintln(os.Stderr, "保存对话失败:", err)
				}
			}
			fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
			flusher.Flush()
			return
		}
		if err != nil {
			fmt.Fprintf(c.Writer, "data: {\"error\":%s}\n\n", toJSON(err.Error()))
			flusher.Flush()
			return
		}
		if len(resp.Choices) > 0 {
			delta := resp.Choices[0].Delta.Content
			if delta != "" {
				full.WriteString(delta)
				fmt.Fprintf(c.Writer, "data: %s\n\n", toSSE(delta))
				flusher.Flush()
			}
		}
	}
}
// ─── 辅助函数 ─────────────────────────────────────────────
// searchResult is one normalized web-search hit that is fed into the model
// prompt. Brave's web.results entries are decoded directly into this type via
// its JSON tags.
type searchResult struct {
	Title       string `json:"title"`
	URL         string `json:"url"`
	Description string `json:"description"`
}

// braveSearchResponse captures only the web.results part of a Brave Web
// Search API response.
type braveSearchResponse struct {
	Web struct {
		Results []searchResult `json:"results"`
	} `json:"web"`
}

// duckDuckGoResponse holds the subset of a DuckDuckGo Instant Answer
// response used by duckDuckGoSearch: the abstract, the related topics (each
// may carry one level of nested Topics), and the infobox label/value rows.
type duckDuckGoResponse struct {
	Abstract       string `json:"Abstract"`
	AbstractSource string `json:"AbstractSource"`
	AbstractURL    string `json:"AbstractURL"`
	Heading        string `json:"Heading"`
	RelatedTopics  []struct {
		Text     string `json:"Text"`
		FirstURL string `json:"FirstURL"`
		Topics   []struct {
			Text     string `json:"Text"`
			FirstURL string `json:"FirstURL"`
		} `json:"Topics"`
	} `json:"RelatedTopics"`
	Infobox struct {
		Content []struct {
			Label string `json:"label"`
			Value string `json:"value"`
		} `json:"content"`
	} `json:"Infobox"`
}
// enrichMessagesWithSearch runs a web search for the latest user question
// with the active search profile. It returns a new slice: a hidden system
// message carrying the results, followed by the original messages. It fails
// when search is disabled, there is no text question, or nothing was found.
func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage) ([]ChatMessage, error) {
	active := searchState.ActiveProfile()
	if !active.Enabled {
		return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled")
	}
	query := latestUserQuery(messages)
	if query == "" {
		return nil, errors.New("联网搜索需要输入文本问题")
	}
	results, err := webSearch(ctx, active, query)
	switch {
	case err != nil:
		return nil, err
	case len(results) == 0:
		return nil, errors.New("未搜索到相关网页结果")
	}
	hint := ChatMessage{
		Role:    "system",
		Content: buildSearchContext(active, query, results),
		Hidden:  true,
	}
	return append([]ChatMessage{hint}, messages...), nil
}
// latestUserQuery returns the trimmed content of the last message whose role
// is "user", or "" when there is no such message.
func latestUserQuery(messages []ChatMessage) string {
	for i := len(messages); i > 0; i-- {
		if m := messages[i-1]; m.Role == "user" {
			return strings.TrimSpace(m.Content)
		}
	}
	return ""
}
// webSearch dispatches query to the backend selected by config.Provider
// (case-insensitive): "duckduckgo"/"ddg" or "brave".
func webSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	provider := strings.ToLower(config.Provider)
	if provider == "duckduckgo" || provider == "ddg" {
		return duckDuckGoSearch(ctx, config, query)
	}
	if provider == "brave" {
		return braveWebSearch(ctx, config, query)
	}
	return nil, fmt.Errorf("暂不支持搜索服务: %s", config.Provider)
}
// duckDuckGoSearch queries the DuckDuckGo Instant Answer API at
// config.BaseURL. It returns at most config.Count results, built in this
// order: the abstract, the infobox rows, then the related topics (nested
// topic groups included). The request is bounded by config.Timeout seconds
// and the response body is read up to 2 MiB.
func duckDuckGoSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	reqCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
	defer cancel()

	endpoint, err := url.Parse(config.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("搜索服务地址无效: %w", err)
	}
	params := endpoint.Query()
	params.Set("q", query)
	params.Set("format", "json")
	params.Set("no_html", "1")
	params.Set("skip_disambig", "1")
	endpoint.RawQuery = params.Encode()

	httpReq, err := http.NewRequestWithContext(reqCtx, http.MethodGet, endpoint.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("创建搜索请求失败: %w", err)
	}
	httpReq.Header.Set("Accept", "application/json")
	httpReq.Header.Set("User-Agent", "aichat/1.0")

	httpResp, err := http.DefaultClient.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("联网搜索失败: %w", err)
	}
	defer httpResp.Body.Close()
	body, err := io.ReadAll(io.LimitReader(httpResp.Body, 2*1024*1024))
	if err != nil {
		return nil, fmt.Errorf("读取搜索响应失败: %w", err)
	}
	if code := httpResp.StatusCode; code < 200 || code >= 300 {
		return nil, fmt.Errorf("搜索服务返回错误 %d: %s", code, strings.TrimSpace(string(body)))
	}
	var parsed duckDuckGoResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("解析搜索响应失败: %w", err)
	}

	limit := config.Count
	if limit <= 0 {
		limit = defaultSearchCount
	}
	results := make([]searchResult, 0, limit)
	full := func() bool { return len(results) >= limit }

	// The abstract, when present, always comes first.
	if abstract := strings.TrimSpace(parsed.Abstract); abstract != "" {
		title := strings.TrimSpace(parsed.Heading)
		if title == "" {
			title = strings.TrimSpace(parsed.AbstractSource)
		}
		if title == "" {
			title = "DuckDuckGo 摘要"
		}
		results = append(results, searchResult{
			Title:       title,
			URL:         strings.TrimSpace(parsed.AbstractURL),
			Description: abstract,
		})
	}

	for _, row := range parsed.Infobox.Content {
		if full() {
			break
		}
		if strings.TrimSpace(row.Label) != "" && strings.TrimSpace(row.Value) != "" {
			results = append(results, searchResult{Title: row.Label, Description: row.Value})
		}
	}

	addTopic := func(text, link string) {
		if full() || strings.TrimSpace(text) == "" {
			return
		}
		title, desc := splitDuckDuckGoText(text)
		results = append(results, searchResult{Title: title, URL: strings.TrimSpace(link), Description: desc})
	}
	for _, topic := range parsed.RelatedTopics {
		addTopic(topic.Text, topic.FirstURL)
		for _, sub := range topic.Topics {
			addTopic(sub.Text, sub.FirstURL)
		}
		if full() {
			break
		}
	}
	return results, nil
}
func splitDuckDuckGoText(text string) (string, string) {
text = strings.TrimSpace(text)
parts := strings.SplitN(text, " - ", 2)
if len(parts) == 2 {
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
}
runes := []rune(text)
if len(runes) > 42 {
return string(runes[:42]) + "...", text
}
return text, text
}
// braveWebSearch queries the Brave Web Search API at config.BaseURL,
// authenticating with config.APIKey via the X-Subscription-Token header.
//
// Brave accepts at most 20 results per request, so the requested count is
// clamped to [1, 20] (a non-positive count falls back to defaultSearchCount).
// At most that many results are returned. The request is bounded by
// config.Timeout seconds and the body is read up to 2 MiB. A non-2xx status
// is returned as an error that includes the response body.
func braveWebSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
	const braveMaxCount = 20
	if config.APIKey == "" {
		return nil, errors.New("Brave 搜索未配置 API Key,请设置 search.api_key 或环境变量 BRAVE_SEARCH_API_KEY")
	}
	count := config.Count
	if count <= 0 {
		count = defaultSearchCount
	}
	if count > braveMaxCount {
		count = braveMaxCount
	}

	searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
	defer cancel()
	u, err := url.Parse(config.BaseURL)
	if err != nil {
		return nil, fmt.Errorf("搜索服务地址无效: %w", err)
	}
	q := u.Query()
	q.Set("q", query)
	q.Set("count", fmt.Sprintf("%d", count))
	q.Set("search_lang", "zh-hans")
	q.Set("country", "CN")
	u.RawQuery = q.Encode()

	req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
	if err != nil {
		return nil, fmt.Errorf("创建搜索请求失败: %w", err)
	}
	req.Header.Set("Accept", "application/json")
	req.Header.Set("User-Agent", "aichat/1.0") // same UA as the DuckDuckGo backend
	req.Header.Set("X-Subscription-Token", config.APIKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("联网搜索失败: %w", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
	if err != nil {
		return nil, fmt.Errorf("读取搜索响应失败: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	var parsed braveSearchResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("解析搜索响应失败: %w", err)
	}
	results := parsed.Web.Results
	if len(results) > count {
		results = results[:count]
	}
	return results, nil
}
// buildSearchContext renders the search results as a system prompt. The
// prompt contains:
//   - the search source (config.Name and config.Provider),
//   - a caveat when the provider is DuckDuckGo,
//   - the current local time and the query,
//   - one numbered entry per result,
//   - a closing instruction not to invent answers.
func buildSearchContext(config SearchConfig, query string, results []searchResult) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "用户开启了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider)
	switch strings.ToLower(config.Provider) {
	case "duckduckgo", "ddg":
		fmt.Fprintln(&sb, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。")
	}
	fmt.Fprintf(&sb, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))
	fmt.Fprintf(&sb, "搜索词: %s\n\n", query)
	fmt.Fprintln(&sb, "搜索结果:")
	for idx := range results {
		title := strings.TrimSpace(results[idx].Title)
		link := strings.TrimSpace(results[idx].URL)
		summary := strings.TrimSpace(results[idx].Description)
		fmt.Fprintf(&sb, "%d. 标题: %s\n", idx+1, title)
		if link != "" {
			fmt.Fprintf(&sb, " 链接: %s\n", link)
		}
		if summary != "" {
			fmt.Fprintf(&sb, " 摘要: %s\n", summary)
		}
	}
	fmt.Fprintln(&sb, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。")
	return sb.String()
}
func newUUID() string {
b := make([]byte, 16)
_, _ = rand.Read(b)
b[6] = (b[6] & 0x0f) | 0x40
b[8] = (b[8] & 0x3f) | 0x80
return hex.EncodeToString(b[:4]) + "-" + hex.EncodeToString(b[4:6]) + "-" +
hex.EncodeToString(b[6:8]) + "-" + hex.EncodeToString(b[8:10]) + "-" +
hex.EncodeToString(b[10:])
}
// ─── ConvStore ─────────────────────────────────────────────
// NewConvStore returns a store that keeps one JSON file per conversation in
// dir, creating the directory (mode 0755) if needed. A failure to create it
// is reported on stderr instead of being silently dropped; later reads and
// writes then fail with their own errors.
func NewConvStore(dir string) *ConvStore {
	if err := os.MkdirAll(dir, 0755); err != nil {
		fmt.Fprintln(os.Stderr, "创建对话目录失败:", err)
	}
	return &ConvStore{dir: dir}
}
// path maps a conversation ID to its JSON file inside s.dir.
//
// IDs come straight from URLs (/api/conversations/:id) and from the chat
// request body, so they are untrusted. filepath.Base drops every directory
// component, including ".." segments and, on Windows, backslash-separated
// parts. This keeps the resulting file inside s.dir.
func (s *ConvStore) path(id string) string {
	return filepath.Join(s.dir, filepath.Base(id)+".json")
}
// Create persists a new, empty conversation titled "新对话" with a fresh
// random ID and returns it.
func (s *ConvStore) Create() (*Conversation, error) {
	now := time.Now()
	conv := &Conversation{ID: newUUID(), Title: "新对话", CreatedAt: now, UpdatedAt: now}
	if err := s.Save(conv); err != nil {
		return nil, err
	}
	return conv, nil
}
// Save stamps conv.UpdatedAt with the current time and writes the
// conversation to its file via atomicWriteJSON, holding the store lock.
func (s *ConvStore) Save(conv *Conversation) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	conv.UpdatedAt = time.Now()
	target := s.path(conv.ID)
	return atomicWriteJSON(target, conv)
}
func (s *ConvStore) Get(id string) (*Conversation, error) {
s.mu.Lock()
defer s.mu.Unlock()
data, err := os.ReadFile(s.path(id))
if err != nil {
if os.IsNotExist(err) {
return nil, errors.New("对话不存在")
}
return nil, fmt.Errorf("读取对话失败: %w", err)
}
var conv Conversation
if err := json.Unmarshal(data, &conv); err != nil {
return nil, fmt.Errorf("解析对话失败: %w", err)
}
return &conv, nil
}
// List returns every stored conversation, most recently updated first, with
// Messages cleared to keep the listing small. Unreadable or unparsable files
// are skipped.
//
// The result is never nil, so an empty store is encoded as [] rather than
// null in JSON. Sorting is stable, so conversations with equal timestamps
// keep the directory (file-name) order from os.ReadDir.
func (s *ConvStore) List() ([]Conversation, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	entries, err := os.ReadDir(s.dir)
	if err != nil {
		return nil, fmt.Errorf("读取对话目录失败: %w", err)
	}
	list := make([]Conversation, 0, len(entries))
	for _, e := range entries {
		// Only *.json files; this also skips atomicWriteJSON's *.tmp files.
		if e.IsDir() || filepath.Ext(e.Name()) != ".json" {
			continue
		}
		data, err := os.ReadFile(filepath.Join(s.dir, e.Name()))
		if err != nil {
			continue
		}
		var conv Conversation
		if err := json.Unmarshal(data, &conv); err != nil {
			continue
		}
		conv.Messages = nil // the listing does not include message bodies
		list = append(list, conv)
	}
	sort.SliceStable(list, func(i, j int) bool {
		return list[i].UpdatedAt.After(list[j].UpdatedAt)
	})
	return list, nil
}
// Delete removes the conversation file. Deleting a conversation that does
// not exist is not an error.
func (s *ConvStore) Delete(id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	err := os.Remove(s.path(id))
	if err == nil || os.IsNotExist(err) {
		return nil
	}
	return fmt.Errorf("删除对话失败: %w", err)
}
// atomicWriteJSON marshals v as JSON, writes it to path+".tmp" and renames
// that file over path. On POSIX filesystems, readers therefore never see a
// half-written file.
//
// If writing or renaming fails, the temporary file is removed instead of
// being left next to the target. The temporary name is fixed, so callers
// must not write the same path concurrently; ConvStore holds its mutex for
// this.
func atomicWriteJSON(path string, v any) error {
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}
	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, data, 0644); err != nil {
		_ = os.Remove(tmp) // best effort: a partial temp file is useless
		return err
	}
	if err := os.Rename(tmp, path); err != nil {
		_ = os.Remove(tmp) // best effort cleanup; report the rename error
		return err
	}
	return nil
}
// saveConversationMessages replaces the stored messages of conversation id
// with messages followed by the assistant reply. While the title is still
// empty or the "新对话" placeholder, it derives a title from the messages.
func saveConversationMessages(id string, messages []ChatMessage, assistantContent string) error {
	conv, err := store.Get(id)
	if err != nil {
		return err
	}
	history := make([]ChatMessage, 0, len(messages)+1)
	history = append(history, messages...)
	history = append(history, ChatMessage{Role: "assistant", Content: assistantContent})
	conv.Messages = history
	if t := conv.Title; t == "" || t == "新对话" {
		conv.Title = genConvTitle(history)
	}
	return store.Save(conv)
}
// genConvTitle builds a title from the first visible user message that has
// non-blank content. Line breaks become spaces, and titles longer than 30
// runes are cut with "...". It returns "新对话" when no such message exists.
func genConvTitle(messages []ChatMessage) string {
	for _, m := range messages {
		content := strings.TrimSpace(m.Content)
		if m.Hidden || m.Role != "user" || content == "" {
			continue
		}
		title := strings.ReplaceAll(strings.ReplaceAll(content, "\r\n", " "), "\n", " ")
		if r := []rune(title); len(r) > 30 {
			return string(r[:30]) + "..."
		}
		return title
	}
	return "新对话"
}
// maxImageSize is the largest decoded image (in bytes, 4 MiB) accepted in a
// base64 data URI.
const maxImageSize = 4 << 20

// allowedImageTypes lists the MIME types accepted for base64 data-URI images.
var allowedImageTypes = map[string]bool{
	"image/gif":  true,
	"image/jpeg": true,
	"image/png":  true,
	"image/webp": true,
}
// buildArkMessages converts chat messages to Ark SDK messages in order. It
// stops at the first message that fails validation and returns that error.
func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
	converted := make([]*model.ChatCompletionMessage, len(chatMessages))
	for i := range chatMessages {
		msg, err := buildArkMessage(chatMessages[i])
		if err != nil {
			return nil, err
		}
		converted[i] = msg
	}
	return converted, nil
}
// buildArkMessage converts one chat message into an Ark SDK message.
//
// Without an image (ImageURL, or ImageURLAlias as a fallback) the content is
// sent as a plain string. With an image, the content becomes a multimodal
// part list: the validated image first, then the text if it is non-empty.
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
	rawImage := m.ImageURL
	if rawImage == "" {
		rawImage = m.ImageURLAlias
	}
	if rawImage == "" {
		text := m.Content
		return &model.ChatCompletionMessage{
			Role:    m.Role,
			Content: &model.ChatCompletionMessageContent{StringValue: &text},
		}, nil
	}
	imageURL, err := normalizeImageURL(rawImage)
	if err != nil {
		return nil, err
	}
	parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)}
	if m.Content != "" {
		parts = append(parts, textPart(m.Content))
	}
	return &model.ChatCompletionMessage{
		Role:    m.Role,
		Content: &model.ChatCompletionMessageContent{ListValue: parts},
	}, nil
}
// imagePart wraps url as an image_url content part with automatic detail.
func imagePart(url string) *model.ChatCompletionMessageContentPart {
	part := new(model.ChatCompletionMessageContentPart)
	part.Type = model.ChatCompletionMessageContentPartTypeImageURL
	part.ImageURL = &model.ChatMessageImageURL{URL: url, Detail: model.ImageURLDetailAuto}
	return part
}
// textPart wraps text as a text content part.
func textPart(text string) *model.ChatCompletionMessageContentPart {
	part := new(model.ChatCompletionMessageContentPart)
	part.Type = model.ChatCompletionMessageContentPartTypeText
	part.Text = text
	return part
}
// normalizeImageURL validates an image reference and returns it trimmed. It
// accepts either:
//   - a base64 data URI ("data:" prefix, case-insensitive), which is
//     delegated to normalizeImageDataURI, or
//   - an absolute http/https URL with a host.
func normalizeImageURL(raw string) (string, error) {
	raw = strings.TrimSpace(raw)
	switch {
	case raw == "":
		return "", errors.New("图片地址不能为空")
	case strings.HasPrefix(strings.ToLower(raw), "data:"):
		return normalizeImageDataURI(raw)
	}
	u, err := url.Parse(raw)
	valid := err == nil && u.Host != "" && (u.Scheme == "http" || u.Scheme == "https")
	if !valid {
		return "", errors.New("图片地址无效,仅支持 http/https URL 或 base64 data URI")
	}
	return raw, nil
}
func normalizeImageDataURI(raw string) (string, error) {
comma := strings.Index(raw, ",")
if comma < 0 {
return "", errors.New("图片 base64 数据格式错误")
}
meta := strings.ToLower(strings.TrimSpace(raw[5:comma]))
payload := strings.TrimSpace(raw[comma+1:])
if payload == "" {
return "", errors.New("图片 base64 数据不能为空")
}
parts := strings.Split(meta, ";")
if len(parts) < 2 || !contains(parts[1:], "base64") {
return "", errors.New("图片 data URI 必须使用 base64 编码")
}
mime := parts[0]
if !allowedImageTypes[mime] {
return "", errors.New("图片格式不支持,仅支持 jpeg/png/webp/gif")
}
decoded, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return "", errors.New("图片 base64 数据无效")
}
if len(decoded) > maxImageSize {
return "", errors.New("图片过大,请选择小于 4MB 的图片")
}
return "data:" + mime + ";base64," + payload, nil
}
// contains reports whether any element of items equals target once
// surrounding whitespace is trimmed from the element.
func contains(items []string, target string) bool {
	for i := 0; i < len(items); i++ {
		if target == strings.TrimSpace(items[i]) {
			return true
		}
	}
	return false
}
// intPtr returns a pointer to a fresh copy of i, for optional *int fields
// such as MaxTokens.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// toJSON returns s encoded as a JSON string literal, quotes included.
// Marshaling a Go string cannot fail; the fallback only keeps the compiler
// and readers honest about the error.
func toJSON(s string) string {
	if encoded, err := json.Marshal(s); err == nil {
		return string(encoded)
	}
	return ""
}
func toSSE(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, "\n", `\n`)
s = strings.ReplaceAll(s, "\r", "")
s = strings.ReplaceAll(s, `"`, `\"`)
return fmt.Sprintf(`"%s"`, s)
}
// ─── 入口 ─────────────────────────────────────────────────
// main loads config.yaml, initializes the model profiles, search profiles and
// conversation store, registers the HTTP routes, and serves them. It listens
// on a Unix socket when server.mode is "unix" and on TCP otherwise.
func main() {
	var err error
	cfg, err = loadConfig("config.yaml")
	if err != nil {
		fmt.Fprintln(os.Stderr, "配置加载失败:", err)
		os.Exit(1)
	}
	// Build one Volcengine Ark SDK client per configured profile.
	aiState, err = NewOpenAIState(cfg.OpenAI)
	if err != nil {
		fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err)
		os.Exit(1)
	}
	searchState, err = NewSearchState(cfg.Search)
	if err != nil {
		fmt.Fprintln(os.Stderr, "搜索配置初始化失败:", err)
		os.Exit(1)
	}
	store = NewConvStore("conversations")

	// Gin routes.
	r := gin.Default()
	r.LoadHTMLGlob("templates/*")
	r.Static("/static", "./static")
	r.GET("/", indexHandler)
	r.POST("/api/chat", chatHandler)
	r.GET("/api/openai", listOpenAIHandler)
	r.POST("/api/openai/active", switchOpenAIHandler)
	r.GET("/api/search", listSearchHandler)
	r.POST("/api/search/active", switchSearchHandler)
	r.GET("/api/conversations", listConversationsHandler)
	r.POST("/api/conversations", createConversationHandler)
	r.GET("/api/conversations/:id", getConversationHandler)
	r.DELETE("/api/conversations/:id", deleteConversationHandler)

	switch strings.ToLower(cfg.Server.Mode) {
	case "unix":
		socketPath := cfg.Server.Address
		// Clear a stale socket left by a previous run, but refuse to delete
		// anything that is not a socket, so a mistyped address cannot remove
		// a regular file.
		if info, statErr := os.Lstat(socketPath); statErr == nil {
			if info.Mode()&os.ModeSocket == 0 {
				fmt.Fprintln(os.Stderr, "Unix socket 路径已存在且不是 socket:", socketPath)
				os.Exit(1)
			}
			if rmErr := os.Remove(socketPath); rmErr != nil {
				fmt.Fprintln(os.Stderr, "删除旧的 Unix socket 失败:", rmErr)
				os.Exit(1)
			}
		}
		ln, listenErr := net.Listen("unix", socketPath)
		if listenErr != nil {
			fmt.Fprintln(os.Stderr, "监听 Unix socket 失败:", listenErr)
			os.Exit(1)
		}
		fmt.Println("服务已启动,监听 Unix socket:", socketPath)
		if serveErr := http.Serve(ln, r); serveErr != nil {
			fmt.Fprintln(os.Stderr, "服务异常退出:", serveErr)
			os.Exit(1)
		}
	default:
		fmt.Println("服务已启动,监听 TCP:", cfg.Server.Address)
		if runErr := r.Run(cfg.Server.Address); runErr != nil {
			fmt.Fprintln(os.Stderr, "服务异常退出:", runErr)
			os.Exit(1)
		}
	}
}