1760 lines
54 KiB
Go
1760 lines
54 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
"unicode"
|
|
|
|
searchagent "aichat/agents/search"
|
|
sqlquery "aichat/agents/sql"
|
|
timeagent "aichat/agents/time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
|
"gopkg.in/yaml.v3"
|
|
)
|
|
|
|
// ─── 配置 ─────────────────────────────────────────────────
|
|
|
|
const (
	// defaultOpenAIBaseURL is used when an openai profile leaves base_url empty.
	defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3"
	// defaultOpenAITimeout is the per-profile request timeout, in seconds
	// (NewOpenAIState and chatHandler multiply it by time.Second).
	defaultOpenAITimeout = 120
	// defaultToolRouterTimeout is the tool-router call timeout; presumably
	// seconds like defaultOpenAITimeout — TODO confirm against the router code.
	defaultToolRouterTimeout = 30
	// defaultToolRouterMaxTokens caps the router model's reply length
	// (presumably passed as the int argument of toolTextCompleter — confirm).
	defaultToolRouterMaxTokens = 512
	// defaultToolRouterSystemText is the router system prompt (Chinese). It asks
	// the model for a bare JSON object {"tools":[{"name","reason"}],"reason"}
	// and instructs it to put "time" before "search"/"sql" when the question
	// contains relative dates. The text is sent to the model verbatim.
	defaultToolRouterSystemText = `你是工具路由器。根据用户最新问题和可用工具列表,判断本轮是否需要调用一个或多个工具。
只能返回 JSON,不要使用 Markdown。
JSON 格式:{"tools":[{"name":"工具名称","reason":"..."}],"reason":"..."}
工具名称必须来自“可用工具”列表。
可以选择多个工具,工具会按配置顺序依次执行;后面的工具可以使用前面工具写入的上下文。
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且还需要调用 search 或 sql,必须同时选择 time,并让 time 排在这些工具之前。
例如“历史上的今天都发生了什么”应选择 time 和 search:先获取今天的绝对日期,再搜索当天历史事件;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
例如“本月有什么日程安排”应选择 time 和 sql:先获取本月绝对日期范围,再查询日程表。
如果无需工具,返回 {"tools":[],"reason":"..."}。
只选择确实必要的工具。`
)
|
|
|
|
// OpenAIConfig is one chat-model profile from the `openai` config section.
// Despite the name, clients are built with the Volcengine Ark SDK
// (see NewOpenAIState).
type OpenAIConfig struct {
	Name    string `yaml:"name" json:"name"`
	Active  bool   `yaml:"active,omitempty" json:"active"` // after normalization at most one profile is active
	APIKey  string `yaml:"api_key" json:"-"`                // never serialized to JSON responses
	BaseURL string `yaml:"base_url" json:"base_url"`
	Model   string `yaml:"model" json:"model"`
	Timeout int    `yaml:"timeout" json:"timeout"` // request timeout in seconds
}

// OpenAIConfigs is the `openai` section. Its UnmarshalYAML accepts either a
// list of profiles or a single profile mapping.
type OpenAIConfigs []OpenAIConfig

// ToolRouterConfig is the `tool_router` section, which controls the
// model-driven tool selection performed before the main chat completion.
type ToolRouterConfig struct {
	Enabled      bool              `yaml:"enabled" json:"enabled"`
	OpenAIName   string            `yaml:"openai_name" json:"openai_name"` // profile used for routing; empty presumably means the active profile — confirm
	Timeout      int               `yaml:"timeout" json:"timeout"`         // presumably seconds, like OpenAIConfig.Timeout
	MaxTokens    int               `yaml:"max_tokens" json:"max_tokens"`
	SystemPrompt string            `yaml:"system_prompt" json:"system_prompt"`
	Tools        []ToolRouteConfig `yaml:"tools" json:"tools"`
}

// ToolRouteConfig is one entry of tool_router.tools. Names are trimmed and
// lower-cased by normalizeToolRouterConfig.
type ToolRouteConfig struct {
	Name        string `yaml:"name" json:"name"`
	Enabled     bool   `yaml:"enabled" json:"enabled"`
	Description string `yaml:"description" json:"description"`
}
|
|
|
|
func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error {
|
|
switch value.Kind {
|
|
case yaml.SequenceNode:
|
|
var items []OpenAIConfig
|
|
if err := value.Decode(&items); err != nil {
|
|
return err
|
|
}
|
|
*configs = items
|
|
case yaml.MappingNode:
|
|
var item OpenAIConfig
|
|
if err := value.Decode(&item); err != nil {
|
|
return err
|
|
}
|
|
*configs = []OpenAIConfig{item}
|
|
case yaml.ScalarNode:
|
|
if value.Tag == "!!null" {
|
|
*configs = nil
|
|
return nil
|
|
}
|
|
return fmt.Errorf("openai 配置格式无效")
|
|
default:
|
|
return fmt.Errorf("openai 配置格式无效")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Config is the top-level YAML configuration file. Only these three sections
// are modelled; writeConfig marshals exactly this struct.
type Config struct {
	// Server holds listener settings; defaults are mode "tcp" and
	// address "0.0.0.0:8080" (see defaultConfig).
	Server struct {
		Mode    string `yaml:"mode"`
		Address string `yaml:"address"`
	} `yaml:"server"`
	OpenAI     OpenAIConfigs    `yaml:"openai"`
	ToolRouter ToolRouterConfig `yaml:"tool_router"`
}
|
|
|
|
func defaultOpenAIConfig() OpenAIConfig {
|
|
return OpenAIConfig{
|
|
Name: "default",
|
|
Active: true,
|
|
BaseURL: defaultOpenAIBaseURL,
|
|
Timeout: defaultOpenAITimeout,
|
|
}
|
|
}
|
|
|
|
func defaultToolRouterConfig() ToolRouterConfig {
|
|
return ToolRouterConfig{
|
|
Enabled: true,
|
|
OpenAIName: "",
|
|
Timeout: defaultToolRouterTimeout,
|
|
MaxTokens: defaultToolRouterMaxTokens,
|
|
SystemPrompt: defaultToolRouterSystemText,
|
|
Tools: []ToolRouteConfig{
|
|
{Name: "time", Enabled: true, Description: ""},
|
|
{Name: "search", Enabled: true, Description: ""},
|
|
{Name: "sql", Enabled: true, Description: ""},
|
|
},
|
|
}
|
|
}
|
|
|
|
func defaultConfig() Config {
|
|
var cfg Config
|
|
cfg.Server.Mode = "tcp"
|
|
cfg.Server.Address = "0.0.0.0:8080"
|
|
cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
|
|
cfg.ToolRouter = defaultToolRouterConfig()
|
|
return cfg
|
|
}
|
|
|
|
func loadConfig(path string) (*Config, error) {
|
|
if err := ensureConfigFile(path); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
|
}
|
|
var cfg Config
|
|
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
|
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
|
}
|
|
if _, err := normalizeOpenAIConfigs(&cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
// 环境变量优先
|
|
if key := os.Getenv("ARK_API_KEY"); key != "" {
|
|
for i := range cfg.OpenAI {
|
|
cfg.OpenAI[i].APIKey = key
|
|
}
|
|
}
|
|
legacySearchProfiles = readLegacySearchProfiles(data)
|
|
if _, err := normalizeToolRouterConfig(&cfg); err != nil {
|
|
return nil, err
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
// ensureConfigFile makes sure a config file exists at path and contains every
// section this program expects.
//
//   - If the file does not exist, it is created from defaultConfig().
//   - Otherwise the file is parsed. Missing server.mode / server.address, a
//     non-list `openai` section and a missing `tool_router` section are filled
//     in or normalized. If anything changed, the file is rewritten.
//
// Returns an error if the file cannot be stat'ed, read, parsed or written, or
// if normalization finds duplicate names.
//
// NOTE(review): the rewrite marshals the Config struct, which only has
// server/openai/tool_router fields. Any other top-level keys in the file (for
// example the legacy `search` section read by readLegacySearchProfiles) are
// therefore dropped from disk whenever a rewrite happens — confirm this is intended.
func ensureConfigFile(path string) error {
	defaults := defaultConfig()
	if _, err := os.Stat(path); err != nil {
		if !os.IsNotExist(err) {
			return fmt.Errorf("检查配置文件失败: %w", err)
		}
		return writeConfig(path, defaults)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("读取配置文件失败: %w", err)
	}
	var cfg Config
	if err = yaml.Unmarshal(data, &cfg); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}
	// Decode a second time into a generic map. The typed decode cannot tell an
	// absent key from a zero value, and the presence checks below need that.
	var raw map[string]any
	if err = yaml.Unmarshal(data, &raw); err != nil {
		return fmt.Errorf("解析配置文件失败: %w", err)
	}

	changed := false
	server, _ := raw["server"].(map[string]any)
	if server == nil {
		cfg.Server = defaults.Server
		changed = true
	} else {
		if _, ok := server["mode"]; !ok {
			cfg.Server.Mode = defaults.Server.Mode
			changed = true
		}
		if _, ok := server["address"]; !ok {
			cfg.Server.Address = defaults.Server.Address
			changed = true
		}
	}

	// OpenAIConfigs.UnmarshalYAML also accepts a single mapping (or null).
	// Anything that is not already a list is rewritten in list form.
	if _, ok := raw["openai"].([]any); !ok {
		changed = true
	}
	if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil {
		return err
	} else if normalized {
		changed = true
	}

	// A missing tool_router section is replaced wholesale by the defaults.
	// An existing one is only normalized.
	if _, ok := raw["tool_router"]; !ok {
		cfg.ToolRouter = defaults.ToolRouter
		changed = true
	} else if normalized, err := normalizeToolRouterConfig(&cfg); err != nil {
		return err
	} else if normalized {
		changed = true
	}

	if !changed {
		return nil
	}
	return writeConfig(path, cfg)
}
|
|
|
|
// normalizeOpenAIConfigs fills in and sanitizes the `openai` profile list in
// place. It reports whether anything was modified, so callers can decide
// whether to rewrite the config file.
//
// Rules:
//   - an empty list is replaced by a single defaultOpenAIConfig();
//   - names are trimmed; an empty name falls back to the trimmed model name,
//     then to "openai-<n>" (1-based position);
//   - an empty/whitespace base_url and a non-positive timeout get defaults;
//   - at most one profile stays active: the first active one wins and later
//     ones are deactivated; if none is active, the first profile is activated.
//
// Returns an error on duplicate names. cfg may already be partially modified
// at that point, and the returned bool reflects the changes made so far.
func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
	changed := false
	if len(cfg.OpenAI) == 0 {
		cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
		changed = true
	}

	activeIndex := -1
	seen := map[string]bool{}
	for i := range cfg.OpenAI {
		profile := &cfg.OpenAI[i]
		name := strings.TrimSpace(profile.Name)
		if name == "" {
			name = strings.TrimSpace(profile.Model)
			if name == "" {
				name = fmt.Sprintf("openai-%d", i+1)
			}
			profile.Name = name
			changed = true
		} else if name != profile.Name {
			// Name had surrounding whitespace: store the trimmed form.
			profile.Name = name
			changed = true
		}
		if seen[name] {
			return changed, fmt.Errorf("openai 配置名称重复: %s", name)
		}
		seen[name] = true

		if strings.TrimSpace(profile.BaseURL) == "" {
			profile.BaseURL = defaultOpenAIBaseURL
			changed = true
		}
		if profile.Timeout <= 0 {
			profile.Timeout = defaultOpenAITimeout
			changed = true
		}
		if profile.Active {
			if activeIndex == -1 {
				activeIndex = i
			} else {
				// Only the first active profile is kept active.
				profile.Active = false
				changed = true
			}
		}
	}
	if activeIndex == -1 {
		cfg.OpenAI[0].Active = true
		changed = true
	}
	return changed, nil
}
|
|
|
|
func normalizeToolRouterConfig(cfg *Config) (bool, error) {
|
|
changed := false
|
|
defaults := defaultToolRouterConfig()
|
|
cfg.ToolRouter.OpenAIName = strings.TrimSpace(cfg.ToolRouter.OpenAIName)
|
|
if cfg.ToolRouter.Timeout <= 0 {
|
|
cfg.ToolRouter.Timeout = defaultToolRouterTimeout
|
|
changed = true
|
|
}
|
|
if cfg.ToolRouter.MaxTokens <= 0 {
|
|
cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens
|
|
changed = true
|
|
}
|
|
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
|
|
cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText
|
|
changed = true
|
|
} else if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) != cfg.ToolRouter.SystemPrompt {
|
|
cfg.ToolRouter.SystemPrompt = strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
|
|
changed = true
|
|
}
|
|
if len(cfg.ToolRouter.Tools) == 0 {
|
|
cfg.ToolRouter.Tools = defaults.Tools
|
|
changed = true
|
|
}
|
|
seen := map[string]bool{}
|
|
for i := range cfg.ToolRouter.Tools {
|
|
tool := &cfg.ToolRouter.Tools[i]
|
|
name := strings.ToLower(strings.TrimSpace(tool.Name))
|
|
if name == "" {
|
|
name = fmt.Sprintf("tool-%d", i+1)
|
|
}
|
|
if name != tool.Name {
|
|
tool.Name = name
|
|
changed = true
|
|
}
|
|
tool.Description = strings.TrimSpace(tool.Description)
|
|
if seen[name] {
|
|
return changed, fmt.Errorf("tool_router.tools 配置名称重复: %s", name)
|
|
}
|
|
seen[name] = true
|
|
}
|
|
byName := map[string]ToolRouteConfig{}
|
|
for _, tool := range cfg.ToolRouter.Tools {
|
|
byName[tool.Name] = tool
|
|
}
|
|
merged := make([]ToolRouteConfig, 0, len(cfg.ToolRouter.Tools)+len(defaults.Tools))
|
|
used := map[string]bool{}
|
|
for _, tool := range defaults.Tools {
|
|
if existing, ok := byName[tool.Name]; ok {
|
|
merged = append(merged, existing)
|
|
} else {
|
|
merged = append(merged, tool)
|
|
changed = true
|
|
}
|
|
used[tool.Name] = true
|
|
}
|
|
for _, tool := range cfg.ToolRouter.Tools {
|
|
if !used[tool.Name] {
|
|
merged = append(merged, tool)
|
|
}
|
|
}
|
|
if len(merged) != len(cfg.ToolRouter.Tools) {
|
|
changed = true
|
|
} else {
|
|
for i := range merged {
|
|
if merged[i].Name != cfg.ToolRouter.Tools[i].Name {
|
|
changed = true
|
|
break
|
|
}
|
|
}
|
|
}
|
|
cfg.ToolRouter.Tools = merged
|
|
return changed, nil
|
|
}
|
|
|
|
func readLegacySearchProfiles(data []byte) []searchagent.ProfileConfig {
|
|
var legacy struct {
|
|
Search searchagent.ProfileConfigs `yaml:"search"`
|
|
}
|
|
if err := yaml.Unmarshal(data, &legacy); err != nil {
|
|
return nil
|
|
}
|
|
return []searchagent.ProfileConfig(legacy.Search)
|
|
}
|
|
|
|
func writeConfig(path string, cfg Config) error {
|
|
data, err := yaml.Marshal(&cfg)
|
|
if err != nil {
|
|
return fmt.Errorf("生成配置文件失败: %w", err)
|
|
}
|
|
if err := os.WriteFile(path, data, 0644); err != nil {
|
|
return fmt.Errorf("写入配置文件失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ─── 请求结构 ─────────────────────────────────────────────
|
|
|
|
// ChatMessage is one message in a chat request or a stored conversation.
type ChatMessage struct {
	Role          string `json:"role"`
	Content       string `json:"content"`
	ImageURL      string `json:"image_url,omitempty"` // base64 data URI or http URL
	ImageURLAlias string `json:"imageURL,omitempty"`  // alternate JSON key for an image; presumably for camelCase clients — confirm
	Hidden        bool   `json:"hidden,omitempty"`    // set on tool-context system messages added by prependHiddenContext
}

// ChatRequest is the JSON body accepted by chatHandler.
type ChatRequest struct {
	ConversationID string        `json:"conversation_id,omitempty"` // when set, chatHandler saves the exchange after a clean stream end
	Messages       []ChatMessage `json:"messages"`
	WebSearch      bool          `json:"web_search,omitempty"`  // NOTE(review): not read by chatHandler in this file — confirm whether still used
	OpenAIName     string        `json:"openai_name,omitempty"` // profile to use; empty selects the active profile (see GetProfile)
}

// Conversation is a saved chat with its metadata.
type Conversation struct {
	ID        string        `json:"id"`
	Title     string        `json:"title"`
	CreatedAt time.Time     `json:"created_at"`
	UpdatedAt time.Time     `json:"updated_at"`
	Messages  []ChatMessage `json:"messages,omitempty"`
}

// ConvStore persists conversations under dir. Its methods are not in this
// view; mu presumably serializes access to the stored files — confirm.
type ConvStore struct {
	dir string
	mu  sync.Mutex
}

// OpenAIProfile pairs a profile's configuration with its Ark client.
type OpenAIProfile struct {
	Config OpenAIConfig
	Client *ark.Client
}

// OpenAIState holds every configured profile and which one is active.
// mu guards profiles, order and activeName.
type OpenAIState struct {
	mu         sync.RWMutex
	profiles   map[string]*OpenAIProfile
	order      []string // profile names in configuration order
	activeName string
}

// activeProfileRequest is the JSON body of the "switch active profile" endpoints.
type activeProfileRequest struct {
	Name string `json:"name"`
}

// openAIListResponse is returned by OpenAIState.ListProfiles; API keys are blanked.
type openAIListResponse struct {
	Active   string         `json:"active"`
	Profiles []OpenAIConfig `json:"profiles"`
}

// toolTextCompleter performs a single non-streaming text completion for tool
// use. The int is presumably a max-token limit and the Duration a timeout —
// confirm against completeTextWithTimeout.
type toolTextCompleter func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error)

// ToolRouterState bundles the router configuration, the profile registry and
// the completer used for routing calls. NewToolRouterState sets complete to
// completeTextWithTimeout; the field presumably exists so tests can swap it.
type ToolRouterState struct {
	cfg      *ToolRouterConfig
	ai       *OpenAIState
	complete toolTextCompleter
}
|
|
|
|
func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) {
|
|
if config == nil {
|
|
cfg := defaultToolRouterConfig()
|
|
config = &cfg
|
|
}
|
|
if ai == nil {
|
|
return nil, errors.New("工具路由需要 OpenAI 状态")
|
|
}
|
|
if config.Enabled && strings.TrimSpace(config.OpenAIName) != "" {
|
|
if _, err := ai.GetProfile(config.OpenAIName); err != nil {
|
|
return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err)
|
|
}
|
|
}
|
|
return &ToolRouterState{cfg: config, ai: ai, complete: completeTextWithTimeout}, nil
|
|
}
|
|
|
|
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
|
|
state := &OpenAIState{
|
|
profiles: make(map[string]*OpenAIProfile, len(configs)),
|
|
order: make([]string, 0, len(configs)),
|
|
}
|
|
for _, config := range configs {
|
|
if strings.TrimSpace(config.Name) == "" {
|
|
return nil, errors.New("openai.name 不能为空")
|
|
}
|
|
if strings.TrimSpace(config.APIKey) == "" {
|
|
return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", config.Name)
|
|
}
|
|
if strings.TrimSpace(config.Model) == "" {
|
|
return nil, fmt.Errorf("openai.%s.model 未配置", config.Name)
|
|
}
|
|
if strings.TrimSpace(config.BaseURL) == "" {
|
|
return nil, fmt.Errorf("openai.%s.base_url 未配置", config.Name)
|
|
}
|
|
if config.Timeout <= 0 {
|
|
return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", config.Name)
|
|
}
|
|
if _, ok := state.profiles[config.Name]; ok {
|
|
return nil, fmt.Errorf("openai 配置名称重复: %s", config.Name)
|
|
}
|
|
state.profiles[config.Name] = &OpenAIProfile{
|
|
Config: config,
|
|
Client: ark.NewClientWithApiKey(
|
|
config.APIKey,
|
|
ark.WithBaseUrl(config.BaseURL),
|
|
ark.WithTimeout(time.Duration(config.Timeout)*time.Second),
|
|
),
|
|
}
|
|
state.order = append(state.order, config.Name)
|
|
if config.Active && state.activeName == "" {
|
|
state.activeName = config.Name
|
|
}
|
|
}
|
|
if len(state.order) == 0 {
|
|
return nil, errors.New("openai 配置不能为空")
|
|
}
|
|
if state.activeName == "" {
|
|
state.activeName = state.order[0]
|
|
}
|
|
return state, nil
|
|
}
|
|
|
|
func (s *OpenAIState) ActiveProfile() *OpenAIProfile {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
return s.profiles[s.activeName]
|
|
}
|
|
|
|
func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
if strings.TrimSpace(name) == "" {
|
|
return s.profiles[s.activeName], nil
|
|
}
|
|
profile, ok := s.profiles[strings.TrimSpace(name)]
|
|
if !ok {
|
|
return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
|
|
}
|
|
return profile, nil
|
|
}
|
|
|
|
func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) {
|
|
name = strings.TrimSpace(name)
|
|
if name == "" {
|
|
return nil, errors.New("OpenAI 配置名称不能为空")
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
profile, ok := s.profiles[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
|
|
}
|
|
s.activeName = name
|
|
return profile, nil
|
|
}
|
|
|
|
func (s *OpenAIState) ListProfiles() openAIListResponse {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
profiles := make([]OpenAIConfig, 0, len(s.order))
|
|
for _, name := range s.order {
|
|
profile := s.profiles[name]
|
|
config := profile.Config
|
|
config.APIKey = ""
|
|
config.Active = name == s.activeName
|
|
profiles = append(profiles, config)
|
|
}
|
|
return openAIListResponse{Active: s.activeName, Profiles: profiles}
|
|
}
|
|
|
|
func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
|
|
config := profile.Config
|
|
config.APIKey = ""
|
|
config.Active = active
|
|
return config
|
|
}
|
|
|
|
// ─── 全局变量 ─────────────────────────────────────────────
|
|
|
|
// Process-wide state. It is presumably initialized once at startup (main is
// not in this view) and then read by the HTTP handlers below.
var (
	cfg                  *Config
	aiState              *OpenAIState // chat-model profiles and the active selection
	searchState          *searchagent.State
	legacySearchProfiles []searchagent.ProfileConfig // `search` section from the main config, set by loadConfig
	toolRouterState      *ToolRouterState
	sqlState             *sqlquery.State
	store                *ConvStore // conversation persistence used by the conversation handlers
)
|
|
|
|
// chatSSEFrame is one JSON event written to the SSE stream by chatHandler and
// the chat tools. Type values used in this file:
//   - "trace": tool/model progress (Tool, Stage, Status, Message, optional Data);
//   - "delta": a chunk of answer text (Text) plus running Stats;
//   - "stats": final token statistics (Stats);
//   - "error": a failure message (Error).
type chatSSEFrame struct {
	Type    string           `json:"type"`
	Text    string           `json:"text,omitempty"`
	Message string           `json:"message,omitempty"`
	Tool    string           `json:"tool,omitempty"`
	Stage   string           `json:"stage,omitempty"`
	Status  string           `json:"status,omitempty"` // "running", "success" or "error" in this file
	Data    map[string]any   `json:"data,omitempty"`
	Stats   *tokenUsageStats `json:"stats,omitempty"`
	Error   string           `json:"error,omitempty"`
}
|
|
|
|
// tokenUsageStats is the token-accounting snapshot sent to clients in SSE
// frames. Counts come from local estimation, hence Estimated is always true.
type tokenUsageStats struct {
	PromptTokens               int     `json:"prompt_tokens"`
	CompletionTokens           int     `json:"completion_tokens"`
	ToolPromptTokens           int     `json:"tool_prompt_tokens"`
	ToolCompletionTokens       int     `json:"tool_completion_tokens"`
	TotalTokens                int     `json:"total_tokens"`
	CompletionTokensPerSec     float64 `json:"completion_tokens_per_sec"`
	PeakCompletionTokensPerSec float64 `json:"peak_completion_tokens_per_sec"`
	Estimated                  bool    `json:"estimated"`
}

// tokenUsageTracker accumulates token counts for one chat request. Tool
// counts accumulate across calls; model counts are overwritten on each update.
// All methods are safe on a nil receiver and lock mu internally.
type tokenUsageTracker struct {
	mu                   sync.Mutex
	promptTokens         int
	completionTokens     int
	toolPromptTokens     int
	toolCompletionTokens int
}

// tokenUsageContextKey is the private context key under which a tracker is stored.
type tokenUsageContextKey struct{}

// newTokenUsageTracker returns an empty tracker.
func newTokenUsageTracker() *tokenUsageTracker {
	return new(tokenUsageTracker)
}

// contextWithTokenUsage returns ctx carrying tracker, or ctx unchanged when
// tracker is nil.
func contextWithTokenUsage(ctx context.Context, tracker *tokenUsageTracker) context.Context {
	if tracker != nil {
		return context.WithValue(ctx, tokenUsageContextKey{}, tracker)
	}
	return ctx
}

// tokenUsageFromContext returns the tracker stored in ctx, or nil if none.
func tokenUsageFromContext(ctx context.Context) *tokenUsageTracker {
	if tracker, ok := ctx.Value(tokenUsageContextKey{}).(*tokenUsageTracker); ok {
		return tracker
	}
	return nil
}

// addTool adds one tool call's prompt/completion tokens to the running totals.
func (t *tokenUsageTracker) addTool(promptTokens, completionTokens int) {
	if t != nil {
		t.mu.Lock()
		t.toolPromptTokens += promptTokens
		t.toolCompletionTokens += completionTokens
		t.mu.Unlock()
	}
}

// setModel replaces the main model's prompt/completion token counts.
func (t *tokenUsageTracker) setModel(promptTokens, completionTokens int) {
	if t != nil {
		t.mu.Lock()
		t.promptTokens, t.completionTokens = promptTokens, completionTokens
		t.mu.Unlock()
	}
}

// snapshot returns the current counts plus their total, together with the
// supplied throughput figures. A nil tracker yields a zero snapshot that still
// has Estimated set.
func (t *tokenUsageTracker) snapshot(tokensPerSecond, peakTokensPerSecond float64) tokenUsageStats {
	stats := tokenUsageStats{Estimated: true}
	if t == nil {
		return stats
	}
	t.mu.Lock()
	defer t.mu.Unlock()

	stats.PromptTokens = t.promptTokens
	stats.CompletionTokens = t.completionTokens
	stats.ToolPromptTokens = t.toolPromptTokens
	stats.ToolCompletionTokens = t.toolCompletionTokens
	stats.TotalTokens = stats.PromptTokens + stats.CompletionTokens + stats.ToolPromptTokens + stats.ToolCompletionTokens
	stats.CompletionTokensPerSec = tokensPerSecond
	stats.PeakCompletionTokensPerSec = peakTokensPerSecond
	return stats
}
|
|
|
|
// ─── 路由 ─────────────────────────────────────────────────
|
|
|
|
func indexHandler(c *gin.Context) {
|
|
profile := aiState.ActiveProfile()
|
|
c.HTML(http.StatusOK, "chat.html", gin.H{
|
|
"Title": "AI 对话",
|
|
"Model": profile.Config.Model,
|
|
"OpenAIName": profile.Config.Name,
|
|
})
|
|
}
|
|
|
|
func listOpenAIHandler(c *gin.Context) {
|
|
c.JSON(http.StatusOK, aiState.ListProfiles())
|
|
}
|
|
|
|
func switchOpenAIHandler(c *gin.Context) {
|
|
var req activeProfileRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
|
|
return
|
|
}
|
|
profile, err := aiState.SwitchActive(req.Name)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"active": profile.Config.Name,
|
|
"profile": publicOpenAIConfig(profile, true),
|
|
})
|
|
}
|
|
|
|
// listSearchHandler responds with searchState.ListProfiles() as JSON.
// NOTE(review): whether API keys are blanked there is decided inside
// searchagent, which is not visible here — confirm.
func listSearchHandler(c *gin.Context) {
	c.JSON(http.StatusOK, searchState.ListProfiles())
}

// switchSearchHandler makes the search profile named in the JSON body active.
// Bad JSON or a rejected name yields 400 with an "error" message. On success
// it returns the active name and the profile with its API key blanked.
func switchSearchHandler(c *gin.Context) {
	var req activeProfileRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
		return
	}
	profile, err := searchState.SwitchActive(req.Name)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	// NOTE(review): these assignments mutate whatever SwitchActive returned.
	// If it is a pointer into the live search state rather than a copy, this
	// wipes the active profile's API key for later searches. Verify the return
	// type (publicOpenAIConfig copies for the OpenAI equivalent).
	profile.APIKey = ""
	profile.Active = true
	c.JSON(http.StatusOK, gin.H{
		"active":  profile.Name,
		"profile": profile,
	})
}
|
|
|
|
func listConversationsHandler(c *gin.Context) {
|
|
convs, err := store.List()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, convs)
|
|
}
|
|
|
|
func createConversationHandler(c *gin.Context) {
|
|
conv, err := store.Create()
|
|
if err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建对话失败: " + err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, conv)
|
|
}
|
|
|
|
func getConversationHandler(c *gin.Context) {
|
|
conv, err := store.Get(c.Param("id"))
|
|
if err != nil {
|
|
status := http.StatusInternalServerError
|
|
if err.Error() == "对话不存在" {
|
|
status = http.StatusNotFound
|
|
}
|
|
c.JSON(status, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.JSON(http.StatusOK, conv)
|
|
}
|
|
|
|
func deleteConversationHandler(c *gin.Context) {
|
|
if err := store.Delete(c.Param("id")); err != nil {
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
return
|
|
}
|
|
c.Status(http.StatusNoContent)
|
|
}
|
|
|
|
// chatHandler is the streaming (SSE) chat endpoint.
//
// Flow:
//  1. Bind and validate the ChatRequest. Bad JSON, an empty message list or an
//     unknown openai_name produce a 400 JSON error; these are the only non-SSE
//     responses.
//  2. Write SSE headers, then run tool routing via enrichMessagesWithRoutedTools.
//     A routing error is only logged to stderr, and the original messages are
//     used unchanged.
//  3. Stream the model completion. Each text chunk is sent as a "delta" frame
//     with running token stats. A clean end of stream (io.EOF) sends a "stats"
//     frame, a "trace" success frame and the literal "data: [DONE]" terminator.
//
// After the headers are sent, failures are reported as "error" frames.
// Tool routing and the model stream share one context whose deadline is the
// selected profile's timeout (seconds). The conversation is saved only on a
// clean io.EOF and only when conversation_id is set. Token counts are local
// estimates (estimateTokenCount), not provider-reported usage.
func chatHandler(c *gin.Context) {
	var req ChatRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
		return
	}
	if len(req.Messages) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"})
		return
	}
	profile, err := aiState.GetProfile(req.OpenAIName)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	// Send the SSE headers first. All later tool/model progress is streamed to
	// the client as "trace" events. X-Accel-Buffering disables nginx buffering.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Accel-Buffering", "no")
	c.Writer.WriteHeader(http.StatusOK)
	// gin's ResponseWriter implements http.Flusher, so this is presumably
	// defensive. If it ever fails, the client gets a 200 with an empty body.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return
	}
	emit := func(frame chatSSEFrame) {
		writeSSEJSON(c.Writer, frame)
		flusher.Flush()
	}
	emitTrace := func(tool, stage, status, message string, data map[string]any) {
		emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data})
	}
	emitError := func(err error) {
		emit(chatSSEFrame{Type: "error", Error: err.Error()})
	}

	// Per-request deadline taken from the selected profile. It also bounds
	// tool routing. The usage tracker travels in the context so tool calls can
	// record their tokens.
	timeout := time.Duration(profile.Config.Timeout) * time.Second
	ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
	defer cancel()
	usage := newTokenUsageTracker()
	ctx = contextWithTokenUsage(ctx, usage)

	chatMessages := req.Messages
	withTools, err := enrichMessagesWithRoutedTools(ctx, profile, chatMessages, emit)
	if err != nil {
		fmt.Fprintln(os.Stderr, "工具路由调用失败:", err)
	} else {
		chatMessages = withTools
	}
	// Convert to the Ark SDK message list.
	messages, err := buildArkMessages(chatMessages)
	if err != nil {
		emitError(err)
		return
	}
	// Estimated over the enriched messages, so injected tool context is counted.
	promptTokens := estimateChatMessagesTokens(chatMessages)

	emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
	stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
		Model:     profile.Config.Model,
		Messages:  messages,
		MaxTokens: intPtr(4096),
	}.WithStream(true))
	if err != nil {
		emitError(err)
		return
	}
	defer stream.Close()
	emitTrace("model", "stream", "running", "模型已开始输出", nil)

	var full strings.Builder
	completionTokens := 0
	streamStarted := time.Now()
	// Peak throughput is measured over windows of at least one second.
	windowStarted := streamStarted
	windowTokens := 0
	peakTokensPerSecond := 0.0
	for {
		resp, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			usage.setModel(promptTokens, completionTokens)
			// Fold the final partial window into the peak if it spans more than 0.25 s.
			if windowTokens > 0 {
				windowElapsed := time.Since(windowStarted).Seconds()
				if windowElapsed > 0.25 {
					windowSpeed := float64(windowTokens) / windowElapsed
					if windowSpeed > peakTokensPerSecond {
						peakTokensPerSecond = windowSpeed
					}
				}
			}
			// No peak measured at all (very short answer): fall back to the average rate.
			if peakTokensPerSecond == 0 {
				peakTokensPerSecond = tokensPerSecond(completionTokens, streamStarted)
			}
			if req.ConversationID != "" {
				// Persist the client-supplied messages, not the tool-enriched ones.
				if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil {
					fmt.Fprintln(os.Stderr, "保存对话失败:", err)
				}
			}
			finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
			emit(chatSSEFrame{Type: "stats", Stats: &finalStats})
			emitTrace("model", "stream", "success", "回答生成完成", nil)
			fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
			flusher.Flush()
			return
		}
		if err != nil {
			emitError(err)
			return
		}
		if len(resp.Choices) > 0 {
			delta := resp.Choices[0].Delta.Content
			if delta != "" {
				now := time.Now()
				deltaTokens := estimateTokenCount(delta)
				windowTokens += deltaTokens
				windowElapsed := now.Sub(windowStarted).Seconds()
				if windowElapsed >= 1 {
					// Close the window and start a new one.
					windowSpeed := float64(windowTokens) / windowElapsed
					if windowSpeed > peakTokensPerSecond {
						peakTokensPerSecond = windowSpeed
					}
					windowStarted = now
					windowTokens = 0
				} else if peakTokensPerSecond == 0 && windowElapsed > 0.25 {
					// Provisional peak before the first full window completes.
					peakTokensPerSecond = float64(windowTokens) / windowElapsed
				}
				full.WriteString(delta)
				completionTokens += deltaTokens
				usage.setModel(promptTokens, completionTokens)
				stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
				emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
			}
		}
	}
}
|
|
|
|
// ─── 辅助函数 ─────────────────────────────────────────────
|
|
|
|
func latestUserQuery(messages []ChatMessage) string {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role == "user" {
|
|
return strings.TrimSpace(messages[i].Content)
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func estimateChatMessagesTokens(messages []ChatMessage) int {
|
|
total := 0
|
|
for _, msg := range messages {
|
|
total += estimateTokenCount(msg.Role) + estimateTokenCount(msg.Content) + 4
|
|
if msg.ImageURL != "" || msg.ImageURLAlias != "" {
|
|
total += 85
|
|
}
|
|
}
|
|
return total
|
|
}
|
|
|
|
// estimateTokenCount gives a cheap token estimate for text. Leading and trailing
// whitespace is ignored. Each whitespace-separated run of ASCII characters
// counts ceil(len/4) tokens, and every non-ASCII, non-space rune (e.g. a CJK
// character) also ends the current ASCII run and counts one token. Non-blank
// text always yields at least 1; blank text yields 0.
func estimateTokenCount(text string) int {
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return 0
	}
	count, run := 0, 0
	for _, r := range trimmed {
		switch {
		case unicode.IsSpace(r):
			count += (run + 3) / 4
			run = 0
		case r <= unicode.MaxASCII:
			run++
		default:
			count += (run+3)/4 + 1
			run = 0
		}
	}
	count += (run + 3) / 4
	if count == 0 {
		return 1
	}
	return count
}
|
|
|
|
// tokensPerSecond returns the average rate of tokens since start. It returns 0
// when tokens is not positive or no time has elapsed (including a start in the
// future).
func tokensPerSecond(tokens int, start time.Time) float64 {
	if tokens <= 0 {
		return 0
	}
	seconds := time.Since(start).Seconds()
	if seconds <= 0 {
		return 0
	}
	return float64(tokens) / seconds
}
|
|
|
|
// ToolSelection is one tool chosen by the router model. It matches the
// {"name": ..., "reason": ...} objects that defaultToolRouterSystemText asks for.
type ToolSelection struct {
	Name   string `json:"name"`
	Reason string `json:"reason"`
}

// ToolRoutingDecision is the router model's full JSON reply:
// {"tools": [...], "reason": "..."}. An empty Tools list means no tool is needed.
type ToolRoutingDecision struct {
	Tools  []ToolSelection `json:"tools"`
	Reason string          `json:"reason"`
}

// ChatTool is a tool the router can select.
//
// Description is presumably shown to the router model as part of the
// available-tools list (confirm). Enrich receives the chat profile, the
// current messages, the router's stated reason and an emit callback for
// progress frames, and returns the messages augmented with the tool's
// context. The time and sql implementations in this file report tool
// failures as injected context instead of returning an error.
type ChatTool interface {
	Name() string
	Description() string
	Enabled() bool
	Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error)
}
|
|
|
|
type TimeChatTool struct{}
|
|
|
|
func (t TimeChatTool) Name() string { return "time" }
|
|
|
|
func (t TimeChatTool) Description() string {
|
|
return timeagent.ActivationPrompt
|
|
}
|
|
|
|
func (t TimeChatTool) Enabled() bool { return true }
|
|
|
|
func (t TimeChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
return runTimeTool(ctx, messages, routeReason, emit)
|
|
}
|
|
|
|
type SQLChatTool struct {
|
|
state *sqlquery.State
|
|
}
|
|
|
|
func (t SQLChatTool) Name() string { return "sql" }
|
|
|
|
func (t SQLChatTool) Description() string {
|
|
if t.state == nil {
|
|
return ""
|
|
}
|
|
return t.state.ActivationPrompt()
|
|
}
|
|
|
|
func (t SQLChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
|
|
|
|
func (t SQLChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
return runSQLTool(ctx, t.state, profile, messages, routeReason, emit)
|
|
}
|
|
|
|
type SearchChatTool struct {
|
|
state *searchagent.State
|
|
}
|
|
|
|
func (t SearchChatTool) Name() string { return "search" }
|
|
|
|
func (t SearchChatTool) Description() string {
|
|
if t.state == nil {
|
|
return ""
|
|
}
|
|
return t.state.ActivationPrompt()
|
|
}
|
|
|
|
func (t SearchChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
|
|
|
|
func (t SearchChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
return runSearchTool(ctx, t.state, messages, routeReason, emit)
|
|
}
|
|
|
|
// sqlGenerationResult is the model's SQL-generation output. It is presumably
// decoded from JSON by generateSQLForUserQuery (not in this view — confirm).
// runSQLTool trims Database and SQL and treats an empty SQL as "no executable
// query", with Reason explaining why.
type sqlGenerationResult struct {
	Database string `json:"database"`
	SQL      string `json:"sql"`
	Reason   string `json:"reason"`
}
|
|
|
|
func runSQLTool(ctx context.Context, state *sqlquery.State, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
query := latestUserQuery(messages)
|
|
if query == "" {
|
|
return messages, nil
|
|
}
|
|
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
|
|
schemaContext, err := state.SchemaContext(ctx)
|
|
if err != nil {
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
|
|
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
}
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
|
|
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
|
|
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
|
|
if err != nil {
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
|
|
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
}
|
|
generated.Database = strings.TrimSpace(generated.Database)
|
|
generated.SQL = strings.TrimSpace(generated.SQL)
|
|
if generated.SQL == "" {
|
|
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
|
|
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
}
|
|
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
|
|
result, err := state.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
|
if err != nil {
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
|
|
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
}
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
|
|
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
|
|
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
|
|
if strings.TrimSpace(routeReason) != "" {
|
|
contextText += "\n激活原因:" + routeReason
|
|
}
|
|
return prependHiddenContext(messages, contextText), nil
|
|
}
|
|
|
|
func prependHiddenContext(messages []ChatMessage, content string) []ChatMessage {
|
|
withContext := make([]ChatMessage, 0, len(messages)+1)
|
|
withContext = append(withContext, ChatMessage{Role: "system", Content: content, Hidden: true})
|
|
withContext = append(withContext, messages...)
|
|
return withContext
|
|
}
|
|
|
|
func runTimeTool(ctx context.Context, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
_ = ctx
|
|
resolved := timeagent.Resolve(time.Now())
|
|
emit(chatSSEFrame{Type: "trace", Tool: "time", Stage: "resolve", Status: "success", Message: "已获取当前时间上下文", Data: map[string]any{
|
|
"today": timeagent.FormatDate(resolved.Now),
|
|
"this_month": fmt.Sprintf("%s 至 %s", timeagent.FormatDate(resolved.ThisMonth.Start), timeagent.FormatDate(resolved.ThisMonth.End.AddDate(0, 0, -1))),
|
|
}})
|
|
return prependHiddenContext(messages, timeagent.BuildContext(resolved, routeReason)), nil
|
|
}
|
|
|
|
func runSearchTool(ctx context.Context, state *searchagent.State, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
query := latestUserQuery(messages)
|
|
if query == "" {
|
|
return messages, nil
|
|
}
|
|
if state == nil || !state.Enabled() {
|
|
err := errors.New("联网搜索未启用")
|
|
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索未启用", Data: map[string]any{"error": err.Error()}})
|
|
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
|
|
}
|
|
active := state.ActiveProfile()
|
|
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": active.Provider}})
|
|
results, profile, err := state.Search(ctx, query)
|
|
if err != nil {
|
|
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
|
|
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
|
|
}
|
|
if len(results) == 0 {
|
|
err := errors.New("未搜索到相关网页结果")
|
|
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "warning", Message: "未搜索到相关网页结果,将使用模型知识库回答"})
|
|
return prependHiddenContext(messages, searchagent.BuildFallbackContext(profile, query, routeReason, err)), nil
|
|
}
|
|
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": profile.Provider, "count": len(results)}})
|
|
return prependHiddenContext(messages, searchagent.BuildResultContext(profile, query, results, routeReason)), nil
|
|
}
|
|
|
|
func enrichMessagesWithRoutedTools(ctx context.Context, chatProfile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled {
|
|
return messages, nil
|
|
}
|
|
if latestUserQuery(messages) == "" {
|
|
return messages, nil
|
|
}
|
|
tools := availableChatTools(toolRouterState.cfg)
|
|
if len(tools) == 0 {
|
|
return messages, nil
|
|
}
|
|
|
|
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "running", Message: "正在进行工具路由"})
|
|
decision, err := routeTools(ctx, toolRouterState, chatProfile, messages, tools)
|
|
if err != nil {
|
|
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "error", Message: "工具路由失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
|
|
return messages, err
|
|
}
|
|
selected := filterToolSelections(decision, tools, toolRouterState.cfg.Tools)
|
|
selected = ensureTimeSelectionForRelativeQuery(selected, tools, toolRouterState.cfg.Tools, latestUserQuery(messages))
|
|
if len(selected) == 0 {
|
|
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:无需调用工具", Data: map[string]any{"reason": decision.Reason}})
|
|
return messages, nil
|
|
}
|
|
|
|
names := make([]string, 0, len(selected))
|
|
for _, item := range selected {
|
|
names = append(names, item.Name)
|
|
}
|
|
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:将调用 " + strings.Join(names, ", "), Data: map[string]any{"tools": names, "reason": decision.Reason}})
|
|
|
|
current := messages
|
|
for _, item := range selected {
|
|
tool := tools[item.Name]
|
|
next, err := tool.Enrich(ctx, chatProfile, current, firstNonEmpty(item.Reason, decision.Reason), emit)
|
|
if err != nil {
|
|
emit(chatSSEFrame{Type: "trace", Tool: item.Name, Stage: "error", Status: "error", Message: "工具调用失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
|
|
continue
|
|
}
|
|
current = next
|
|
}
|
|
return current, nil
|
|
}
|
|
|
|
func availableChatTools(config *ToolRouterConfig) map[string]ChatTool {
|
|
configured := map[string]ToolRouteConfig{}
|
|
for _, item := range config.Tools {
|
|
configured[item.Name] = item
|
|
}
|
|
registered := []ChatTool{
|
|
TimeChatTool{},
|
|
SearchChatTool{state: searchState},
|
|
SQLChatTool{state: sqlState},
|
|
}
|
|
available := map[string]ChatTool{}
|
|
for _, tool := range registered {
|
|
name := tool.Name()
|
|
item, ok := configured[name]
|
|
if !ok || !item.Enabled || !tool.Enabled() {
|
|
continue
|
|
}
|
|
available[name] = tool
|
|
}
|
|
return available
|
|
}
|
|
|
|
func routeTools(ctx context.Context, state *ToolRouterState, chatProfile *OpenAIProfile, messages []ChatMessage, tools map[string]ChatTool) (ToolRoutingDecision, error) {
|
|
routerProfile := chatProfile
|
|
if strings.TrimSpace(state.cfg.OpenAIName) != "" {
|
|
profile, err := state.ai.GetProfile(state.cfg.OpenAIName)
|
|
if err != nil {
|
|
return ToolRoutingDecision{}, err
|
|
}
|
|
routerProfile = profile
|
|
}
|
|
prompt := buildToolRouterPrompt(state.cfg, messages, tools)
|
|
text, err := state.complete(ctx, routerProfile, []ChatMessage{{Role: "system", Content: prompt}}, state.cfg.MaxTokens, time.Duration(state.cfg.Timeout)*time.Second)
|
|
if err != nil {
|
|
return ToolRoutingDecision{}, err
|
|
}
|
|
return parseToolRoutingDecision(text)
|
|
}
|
|
|
|
func buildToolRouterPrompt(config *ToolRouterConfig, messages []ChatMessage, tools map[string]ChatTool) string {
|
|
query := latestUserQuery(messages)
|
|
var b strings.Builder
|
|
b.WriteString(strings.TrimSpace(config.SystemPrompt))
|
|
b.WriteString("\n\n可用工具:\n")
|
|
for _, item := range config.Tools {
|
|
tool, ok := tools[item.Name]
|
|
if !ok {
|
|
continue
|
|
}
|
|
description := strings.TrimSpace(item.Description)
|
|
if description == "" {
|
|
description = tool.Description()
|
|
}
|
|
fmt.Fprintf(&b, "- name: %s\n description: %s\n", item.Name, description)
|
|
}
|
|
fmt.Fprintf(&b, "\n最新用户问题:%s", query)
|
|
return b.String()
|
|
}
|
|
|
|
func parseToolRoutingDecision(text string) (ToolRoutingDecision, error) {
|
|
var decision ToolRoutingDecision
|
|
if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil {
|
|
return decision, fmt.Errorf("解析工具路由结果失败: %w", err)
|
|
}
|
|
for i := range decision.Tools {
|
|
decision.Tools[i].Name = strings.ToLower(strings.TrimSpace(decision.Tools[i].Name))
|
|
decision.Tools[i].Reason = strings.TrimSpace(decision.Tools[i].Reason)
|
|
}
|
|
decision.Reason = strings.TrimSpace(decision.Reason)
|
|
return decision, nil
|
|
}
|
|
|
|
func filterToolSelections(decision ToolRoutingDecision, tools map[string]ChatTool, order []ToolRouteConfig) []ToolSelection {
|
|
selected := map[string]ToolSelection{}
|
|
for _, item := range decision.Tools {
|
|
if item.Name == "" {
|
|
continue
|
|
}
|
|
if _, ok := tools[item.Name]; !ok {
|
|
continue
|
|
}
|
|
if _, ok := selected[item.Name]; !ok {
|
|
selected[item.Name] = item
|
|
}
|
|
}
|
|
return orderToolSelections(selected, order)
|
|
}
|
|
|
|
func ensureTimeSelectionForRelativeQuery(selected []ToolSelection, tools map[string]ChatTool, order []ToolRouteConfig, query string) []ToolSelection {
|
|
if !containsRelativeTime(query) || hasToolSelection(selected, "time") || (!hasToolSelection(selected, "search") && !hasToolSelection(selected, "sql")) {
|
|
return selected
|
|
}
|
|
if _, ok := tools["time"]; !ok {
|
|
return selected
|
|
}
|
|
withTime := make(map[string]ToolSelection, len(selected)+1)
|
|
for _, item := range selected {
|
|
withTime[item.Name] = item
|
|
}
|
|
withTime["time"] = ToolSelection{Name: "time", Reason: "问题包含相对日期,需要先获取当前日期"}
|
|
return orderToolSelections(withTime, order)
|
|
}
|
|
|
|
// containsRelativeTime reports whether query mentions a relative date or
// period (today, tomorrow, last week, this month, last year, recently, …).
// Such questions need the time tool to resolve absolute dates before search
// or SQL runs.
//
// Matching is by plain substring, so a rare false positive (e.g. "前天"
// inside another word) only causes the cheap time tool to be added.
// Blank queries never match.
func containsRelativeTime(query string) bool {
	query = strings.TrimSpace(query)
	if query == "" {
		return false
	}

	// "历史上的今天" is covered by "今天"; the previous-/next-period forms were
	// missing from the original list and left e.g. "上周" undetected.
	keywords := []string{
		"今天", "今日", "今晚", "明天", "后天", "昨天", "前天",
		"本周", "这周", "上周", "下周",
		"本月", "这个月", "上个月", "下个月",
		"本年", "今年", "去年", "明年",
		"最近",
	}
	for _, keyword := range keywords {
		if strings.Contains(query, keyword) {
			return true
		}
	}
	return false
}
|
|
|
|
func hasToolSelection(selected []ToolSelection, name string) bool {
|
|
for _, item := range selected {
|
|
if item.Name == name {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func orderToolSelections(selected map[string]ToolSelection, order []ToolRouteConfig) []ToolSelection {
|
|
result := make([]ToolSelection, 0, len(selected))
|
|
for _, item := range order {
|
|
if selection, ok := selected[item.Name]; ok {
|
|
result = append(result, selection)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// firstNonEmpty returns the first argument that is non-blank after trimming
// surrounding whitespace, in its trimmed form. It returns "" when every
// argument is blank or no arguments are given.
func firstNonEmpty(items ...string) string {
	for _, candidate := range items {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
|
|
|
// generateSQLForUserQuery asks the model behind profile to write one
// read-only SQL statement for userQuery. schemaContext is inserted verbatim
// into the prompt ahead of the user question. The prompt requests a JSON
// reply of the form {"database": ..., "sql": ..., "reason": ...}. The reply
// is passed through extractJSONObject and decoded into a sqlGenerationResult.
// The completion is capped at 1024 tokens and uses the profile's timeout via
// completeText.
//
// It returns the completion error unchanged, or a wrapped
// "解析 SQL 生成结果失败" error when the reply is not valid JSON. The
// read-only / LIMIT / schema rules are only requested in the prompt; this
// function does not check the generated SQL itself.
func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) {
|
|
prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、隐藏上下文和数据库 schema 生成一条只读 SQL。
|
|
要求:
|
|
- 只能返回 JSON,不要使用 Markdown。
|
|
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
|
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
|
- 必须只使用 schema 中出现的数据库、表和字段。
|
|
- 如果隐藏上下文中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
|
|
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
|
|
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
|
|
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
|
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
|
|
|
%s
|
|
|
|
用户问题:%s`, schemaContext, userQuery)
|
|
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var generated sqlGenerationResult
|
|
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
|
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
|
}
|
|
return &generated, nil
|
|
}
|
|
|
|
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
|
|
return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, time.Duration(profile.Config.Timeout)*time.Second)
|
|
}
|
|
|
|
func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
|
|
messages, err := buildArkMessages(chatMessages)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
completionCtx, cancel := context.WithTimeout(ctx, timeout)
|
|
defer cancel()
|
|
stream, err := profile.Client.CreateChatCompletionStream(completionCtx, model.CreateChatCompletionRequest{
|
|
Model: profile.Config.Model,
|
|
Messages: messages,
|
|
MaxTokens: intPtr(maxTokens),
|
|
}.WithStream(true))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer stream.Close()
|
|
|
|
promptTokens := estimateChatMessagesTokens(chatMessages)
|
|
completionTokens := 0
|
|
var b strings.Builder
|
|
for {
|
|
resp, err := stream.Recv()
|
|
if errors.Is(err, io.EOF) {
|
|
if tracker := tokenUsageFromContext(ctx); tracker != nil {
|
|
tracker.addTool(promptTokens, completionTokens)
|
|
}
|
|
return b.String(), nil
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(resp.Choices) > 0 {
|
|
delta := resp.Choices[0].Delta.Content
|
|
b.WriteString(delta)
|
|
completionTokens += estimateTokenCount(delta)
|
|
}
|
|
}
|
|
}
|
|
|
|
// extractJSONObject returns the span of text from the first "{" through the
// last "}" inclusive. This drops any prose or Markdown fences a model wraps
// around its JSON. When no such span exists, the whitespace-trimmed text is
// returned.
func extractJSONObject(text string) string {
	trimmed := strings.TrimSpace(text)
	opening := strings.IndexByte(trimmed, '{')
	closing := strings.LastIndexByte(trimmed, '}')
	if opening < 0 || closing <= opening {
		return trimmed
	}
	return trimmed[opening : closing+1]
}
|
|
|
|
// newUUID returns a random version-4 UUID (RFC 4122 variant) in the canonical
// lowercase 8-4-4-4-12 hex form. The error from rand.Read is ignored.
func newUUID() string {
	var raw [16]byte
	_, _ = rand.Read(raw[:])
	raw[6] = raw[6]&0x0f | 0x40 // version 4
	raw[8] = raw[8]&0x3f | 0x80 // RFC 4122 variant

	var out [36]byte
	hex.Encode(out[0:8], raw[0:4])
	out[8] = '-'
	hex.Encode(out[9:13], raw[4:6])
	out[13] = '-'
	hex.Encode(out[14:18], raw[6:8])
	out[18] = '-'
	hex.Encode(out[19:23], raw[8:10])
	out[23] = '-'
	hex.Encode(out[24:36], raw[10:16])
	return string(out[:])
}
|
|
|
|
// ─── ConvStore ─────────────────────────────────────────────
|
|
|
|
func NewConvStore(dir string) *ConvStore {
|
|
os.MkdirAll(dir, 0755)
|
|
return &ConvStore{dir: dir}
|
|
}
|
|
|
|
func (s *ConvStore) path(id string) string {
|
|
return filepath.Join(s.dir, id+".json")
|
|
}
|
|
|
|
func (s *ConvStore) Create() (*Conversation, error) {
|
|
conv := &Conversation{
|
|
ID: newUUID(),
|
|
Title: "新对话",
|
|
CreatedAt: time.Now(),
|
|
UpdatedAt: time.Now(),
|
|
}
|
|
if err := s.Save(conv); err != nil {
|
|
return nil, err
|
|
}
|
|
return conv, nil
|
|
}
|
|
|
|
func (s *ConvStore) Save(conv *Conversation) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
conv.UpdatedAt = time.Now()
|
|
return atomicWriteJSON(s.path(conv.ID), conv)
|
|
}
|
|
|
|
func (s *ConvStore) Get(id string) (*Conversation, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
data, err := os.ReadFile(s.path(id))
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return nil, errors.New("对话不存在")
|
|
}
|
|
return nil, fmt.Errorf("读取对话失败: %w", err)
|
|
}
|
|
var conv Conversation
|
|
if err := json.Unmarshal(data, &conv); err != nil {
|
|
return nil, fmt.Errorf("解析对话失败: %w", err)
|
|
}
|
|
return &conv, nil
|
|
}
|
|
|
|
func (s *ConvStore) List() ([]Conversation, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
entries, err := os.ReadDir(s.dir)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("读取对话目录失败: %w", err)
|
|
}
|
|
|
|
var list []Conversation
|
|
for _, e := range entries {
|
|
if e.IsDir() || filepath.Ext(e.Name()) != ".json" {
|
|
continue
|
|
}
|
|
data, err := os.ReadFile(filepath.Join(s.dir, e.Name()))
|
|
if err != nil {
|
|
continue
|
|
}
|
|
var conv Conversation
|
|
if err := json.Unmarshal(data, &conv); err != nil {
|
|
continue
|
|
}
|
|
conv.Messages = nil // 列表不返回消息体
|
|
list = append(list, conv)
|
|
}
|
|
|
|
sort.Slice(list, func(i, j int) bool {
|
|
return list[i].UpdatedAt.After(list[j].UpdatedAt)
|
|
})
|
|
return list, nil
|
|
}
|
|
|
|
func (s *ConvStore) Delete(id string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if err := os.Remove(s.path(id)); err != nil && !os.IsNotExist(err) {
|
|
return fmt.Errorf("删除对话失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// atomicWriteJSON marshals v to JSON and replaces the file at path with it so
// that readers see either the old content or the new content, never a
// partial file.
//
// The data is written to a uniquely named temporary file in the same
// directory, flushed to stable storage (Sync), given mode 0644, and renamed
// over path. The unique name keeps concurrent writers from clobbering each
// other's temporary file. On any failure the temporary file is removed
// instead of being left behind. Temporary names end in ".tmp", so they are
// not mistaken for "*.json" data files.
func atomicWriteJSON(path string, v any) error {
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}

	tmp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".*.tmp")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(tmpName) // best-effort cleanup; the original error is what matters
		}
	}()

	if _, err := tmp.Write(data); err != nil {
		_ = tmp.Close()
		return err
	}
	// Sync before rename: otherwise a crash could leave a renamed but empty
	// or truncated file in place of the previous good one.
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	// CreateTemp uses mode 0600; keep the 0644 the file was written with before.
	if err := os.Chmod(tmpName, 0644); err != nil {
		return err
	}
	if err := os.Rename(tmpName, path); err != nil {
		return err
	}
	committed = true
	return nil
}
|
|
|
|
func saveConversationMessages(id string, messages []ChatMessage, assistantContent string) error {
|
|
conv, err := store.Get(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
conv.Messages = append([]ChatMessage(nil), messages...)
|
|
conv.Messages = append(conv.Messages, ChatMessage{Role: "assistant", Content: assistantContent})
|
|
if conv.Title == "" || conv.Title == "新对话" {
|
|
conv.Title = genConvTitle(conv.Messages)
|
|
}
|
|
return store.Save(conv)
|
|
}
|
|
|
|
func genConvTitle(messages []ChatMessage) string {
|
|
for _, m := range messages {
|
|
if m.Hidden {
|
|
continue
|
|
}
|
|
if m.Role == "user" && strings.TrimSpace(m.Content) != "" {
|
|
title := strings.TrimSpace(m.Content)
|
|
title = strings.ReplaceAll(title, "\r\n", " ")
|
|
title = strings.ReplaceAll(title, "\n", " ")
|
|
runes := []rune(title)
|
|
if len(runes) > 30 {
|
|
return string(runes[:30]) + "..."
|
|
}
|
|
return title
|
|
}
|
|
}
|
|
return "新对话"
|
|
}
|
|
|
|
// maxImageSize is the largest decoded inline (base64 data URI) image that
// normalizeImageDataURI accepts, in bytes: 4 MiB.
const maxImageSize = 4 << 20

// allowedImageTypes is the set of MIME types accepted for inline base64
// images.
var allowedImageTypes = map[string]bool{
	"image/gif":  true,
	"image/jpeg": true,
	"image/png":  true,
	"image/webp": true,
}
|
|
|
|
func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
|
|
messages := make([]*model.ChatCompletionMessage, 0, len(chatMessages))
|
|
for _, m := range chatMessages {
|
|
msg, err := buildArkMessage(m)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
messages = append(messages, msg)
|
|
}
|
|
return messages, nil
|
|
}
|
|
|
|
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
|
|
msg := &model.ChatCompletionMessage{Role: m.Role}
|
|
|
|
if m.ImageURL == "" && m.ImageURLAlias != "" {
|
|
m.ImageURL = m.ImageURLAlias
|
|
}
|
|
|
|
if m.ImageURL == "" {
|
|
msg.Content = &model.ChatCompletionMessageContent{
|
|
StringValue: &m.Content,
|
|
}
|
|
return msg, nil
|
|
}
|
|
|
|
imageURL, err := normalizeImageURL(m.ImageURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 有图片时:文字内容可有可无(图片 caption 场景),均构造多模态消息
|
|
// 若无文字,则只传图片 part;若同时有图片和文字,先图后文
|
|
parts := []*model.ChatCompletionMessageContentPart{imagePart(imageURL)}
|
|
if m.Content != "" {
|
|
parts = append(parts, textPart(m.Content))
|
|
}
|
|
msg.Content = &model.ChatCompletionMessageContent{ListValue: parts}
|
|
return msg, nil
|
|
}
|
|
|
|
func imagePart(url string) *model.ChatCompletionMessageContentPart {
|
|
return &model.ChatCompletionMessageContentPart{
|
|
Type: model.ChatCompletionMessageContentPartTypeImageURL,
|
|
ImageURL: &model.ChatMessageImageURL{
|
|
URL: url,
|
|
Detail: model.ImageURLDetailAuto,
|
|
},
|
|
}
|
|
}
|
|
|
|
func textPart(text string) *model.ChatCompletionMessageContentPart {
|
|
return &model.ChatCompletionMessageContentPart{
|
|
Type: model.ChatCompletionMessageContentPartTypeText,
|
|
Text: text,
|
|
}
|
|
}
|
|
|
|
func normalizeImageURL(raw string) (string, error) {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw == "" {
|
|
return "", errors.New("图片地址不能为空")
|
|
}
|
|
|
|
lower := strings.ToLower(raw)
|
|
if strings.HasPrefix(lower, "data:") {
|
|
return normalizeImageDataURI(raw)
|
|
}
|
|
|
|
u, err := url.Parse(raw)
|
|
if err != nil || u.Host == "" || (u.Scheme != "http" && u.Scheme != "https") {
|
|
return "", errors.New("图片地址无效,仅支持 http/https URL 或 base64 data URI")
|
|
}
|
|
return raw, nil
|
|
}
|
|
|
|
func normalizeImageDataURI(raw string) (string, error) {
|
|
comma := strings.Index(raw, ",")
|
|
if comma < 0 {
|
|
return "", errors.New("图片 base64 数据格式错误")
|
|
}
|
|
|
|
meta := strings.ToLower(strings.TrimSpace(raw[5:comma]))
|
|
payload := strings.TrimSpace(raw[comma+1:])
|
|
if payload == "" {
|
|
return "", errors.New("图片 base64 数据不能为空")
|
|
}
|
|
parts := strings.Split(meta, ";")
|
|
if len(parts) < 2 || !contains(parts[1:], "base64") {
|
|
return "", errors.New("图片 data URI 必须使用 base64 编码")
|
|
}
|
|
|
|
mime := parts[0]
|
|
if !allowedImageTypes[mime] {
|
|
return "", errors.New("图片格式不支持,仅支持 jpeg/png/webp/gif")
|
|
}
|
|
|
|
decoded, err := base64.StdEncoding.DecodeString(payload)
|
|
if err != nil {
|
|
return "", errors.New("图片 base64 数据无效")
|
|
}
|
|
if len(decoded) > maxImageSize {
|
|
return "", errors.New("图片过大,请选择小于 4MB 的图片")
|
|
}
|
|
|
|
return "data:" + mime + ";base64," + payload, nil
|
|
}
|
|
|
|
// contains reports whether some element of items equals target once that
// element's surrounding whitespace is trimmed. target itself is compared
// as-is.
func contains(items []string, target string) bool {
	for i := range items {
		if strings.TrimSpace(items[i]) != target {
			continue
		}
		return true
	}
	return false
}
|
|
|
|
// intPtr returns a pointer to a fresh copy of i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
|
|
|
|
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
|
|
data, err := json.Marshal(frame)
|
|
if err != nil {
|
|
data, _ = json.Marshal(chatSSEFrame{Type: "error", Error: "序列化流事件失败"})
|
|
}
|
|
fmt.Fprintf(w, "data: %s\n\n", data)
|
|
}
|
|
|
|
// toJSON returns s encoded as a JSON string literal: quoted and escaped,
// including the HTML-safe escapes json.Marshal applies.
func toJSON(s string) string {
	encoded, _ := json.Marshal(s) // marshalling a plain string does not fail
	return string(encoded)
}
|
|
|
|
// toSSE wraps s in double quotes with lightweight escaping:
//   - backslashes and double quotes are backslash-escaped;
//   - newlines become the two characters \n;
//   - carriage returns are dropped.
//
// Other control characters pass through unchanged, so the result is not
// guaranteed to be valid JSON; toJSON provides that.
func toSSE(s string) string {
	// A single-pass Replacer is equivalent to the sequential replaces here:
	// every pattern is one character, and backslashes are escaped before any
	// escape sequences are introduced.
	escaper := strings.NewReplacer(`\`, `\\`, "\n", `\n`, "\r", "", `"`, `\"`)
	return `"` + escaper.Replace(s) + `"`
}
|
|
|
|
// ─── 入口 ─────────────────────────────────────────────────
|
|
|
|
func main() {
|
|
var err error
|
|
cfg, err = loadConfig("config.yaml")
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "配置加载失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
// 初始化火山方舟 SDK 客户端
|
|
aiState, err = NewOpenAIState(cfg.OpenAI)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
searchConfig, err := searchagent.LoadConfig("agents/search/config.yaml", legacySearchProfiles)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "联网搜索配置加载失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
searchState, err = searchagent.NewState(searchConfig)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "联网搜索初始化失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
sqlConfig, err := sqlquery.LoadConfig("agents/sql/config.yaml")
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "SQL 查询插件配置加载失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
sqlState, err = sqlquery.NewState(sqlConfig)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "SQL 查询插件初始化失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
defer sqlState.Close()
|
|
toolRouterState, err = NewToolRouterState(&cfg.ToolRouter, aiState)
|
|
if err != nil {
|
|
fmt.Fprintln(os.Stderr, "工具路由配置初始化失败:", err)
|
|
os.Exit(1)
|
|
}
|
|
store = NewConvStore("conversations")
|
|
|
|
// Gin 路由
|
|
r := gin.Default()
|
|
r.LoadHTMLGlob("templates/*")
|
|
r.Static("/static", "./static")
|
|
|
|
r.GET("/", indexHandler)
|
|
r.POST("/api/chat", chatHandler)
|
|
r.GET("/api/openai", listOpenAIHandler)
|
|
r.POST("/api/openai/active", switchOpenAIHandler)
|
|
r.GET("/api/search", listSearchHandler)
|
|
r.POST("/api/search/active", switchSearchHandler)
|
|
r.GET("/api/conversations", listConversationsHandler)
|
|
r.POST("/api/conversations", createConversationHandler)
|
|
r.GET("/api/conversations/:id", getConversationHandler)
|
|
r.DELETE("/api/conversations/:id", deleteConversationHandler)
|
|
|
|
// 根据配置选择监听方式
|
|
switch strings.ToLower(cfg.Server.Mode) {
|
|
case "unix":
|
|
socketPath := cfg.Server.Address
|
|
if _, statErr := os.Stat(socketPath); statErr == nil {
|
|
os.Remove(socketPath)
|
|
}
|
|
ln, listenErr := net.Listen("unix", socketPath)
|
|
if listenErr != nil {
|
|
fmt.Fprintln(os.Stderr, "监听 Unix socket 失败:", listenErr)
|
|
os.Exit(1)
|
|
}
|
|
fmt.Println("服务已启动,监听 Unix socket:", socketPath)
|
|
if serveErr := http.Serve(ln, r); serveErr != nil {
|
|
fmt.Fprintln(os.Stderr, "服务异常退出:", serveErr)
|
|
os.Exit(1)
|
|
}
|
|
default:
|
|
fmt.Println("服务已启动,监听 TCP:", cfg.Server.Address)
|
|
if runErr := r.Run(cfg.Server.Address); runErr != nil {
|
|
fmt.Fprintln(os.Stderr, "服务异常退出:", runErr)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
}
|