Files
aichat/main.go
T
2026-06-11 20:37:18 +08:00

2040 lines
60 KiB
Go

package main
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"unicode"
searchagent "aichat/agents/search"
sqlquery "aichat/agents/sql"
timeagent "aichat/agents/time"
"github.com/gin-gonic/gin"
ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"gopkg.in/yaml.v3"
)
// ─── 配置 ─────────────────────────────────────────────────
// Defaults applied when the configuration file omits a value or holds an
// invalid one.
const (
// defaultOpenAIBaseURL is the endpoint assigned to profiles without a base_url.
defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3"
// defaultOpenAITimeout is the per-profile timeout; callers multiply it by time.Second.
defaultOpenAITimeout = 120
// defaultToolRouterTimeout is the tool-decision timeout; callers multiply it by time.Second.
defaultToolRouterTimeout = 30
// defaultToolRouterMaxTokens is the MaxTokens value sent with tool-decision requests.
defaultToolRouterMaxTokens = 512
// defaultToolRouterSystemText is the system prompt used when none (or a legacy
// one) is configured. It is runtime text sent to the model and stays verbatim.
defaultToolRouterSystemText = `你可以按需直接调用可用工具来回答用户问题。
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且后续需要搜索或查询数据库,应先调用 time 获取绝对日期范围。
需要实时网页资料、新闻、当前版本、近期事件、网页核验或用户明确要求联网时,调用 search。
需要查询本地业务数据、日程、会议、待办、记录、统计或时间范围内数据时,调用 sql。
工具结果优先于模型内置知识;工具失败时必须如实说明,不要编造结果。
只调用确实必要的工具。`
)
// OpenAIConfig describes one named model endpoint profile. It is read from and
// written to the YAML configuration and is also returned by the JSON API.
type OpenAIConfig struct {
Name string `yaml:"name" json:"name"`
// Active marks the profile selected by default.
Active bool `yaml:"active,omitempty" json:"active"`
// APIKey is excluded from JSON output (json:"-") but is persisted in YAML.
APIKey string `yaml:"api_key" json:"-"`
BaseURL string `yaml:"base_url" json:"base_url"`
Model string `yaml:"model" json:"model"`
// Timeout is a number of seconds; it is multiplied by time.Second where used.
Timeout int `yaml:"timeout" json:"timeout"`
// ParseThinkTags, when non-nil, overrides the automatic decision made in
// shouldParseThinkTags.
ParseThinkTags *bool `yaml:"parse_think_tags,omitempty" json:"parse_think_tags,omitempty"`
}
// OpenAIConfigs is the list of configured profiles. Its UnmarshalYAML method
// accepts either a YAML sequence or a single mapping.
type OpenAIConfigs []OpenAIConfig
// ToolRouterConfig holds the settings of the tool-calling decision stage.
type ToolRouterConfig struct {
// Enabled turns tool calling on; when false no tools are offered to the model.
Enabled bool `yaml:"enabled" json:"enabled"`
// OpenAIName optionally pins the decision stage to a named profile; an empty
// value makes RouterProfile return the caller's fallback profile.
OpenAIName string `yaml:"openai_name" json:"openai_name"`
// Timeout is a number of seconds for one decision request.
Timeout int `yaml:"timeout" json:"timeout"`
// MaxTokens is sent as MaxTokens on decision requests.
MaxTokens int `yaml:"max_tokens" json:"max_tokens"`
// SystemPrompt is prepended as a system message when tools are available.
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
Tools []ToolRouteConfig `yaml:"tools" json:"tools"`
}
// ToolRouteConfig enables or disables a single tool by name and optionally
// supplies the description passed to the tool's definition.
type ToolRouteConfig struct {
Name string `yaml:"name" json:"name"`
Enabled bool `yaml:"enabled" json:"enabled"`
Description string `yaml:"description" json:"description"`
}
// UnmarshalYAML decodes the `openai` section, which may be written as a list
// of profiles, as a single profile mapping, or as an explicit null. Any other
// node shape is rejected.
func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error {
	if value.Kind == yaml.SequenceNode {
		// Decode into the plain slice type so this method is not re-entered.
		var list []OpenAIConfig
		if err := value.Decode(&list); err != nil {
			return err
		}
		*configs = list
		return nil
	}

	if value.Kind == yaml.MappingNode {
		// A single mapping is promoted to a one-element list.
		var single OpenAIConfig
		if err := value.Decode(&single); err != nil {
			return err
		}
		*configs = []OpenAIConfig{single}
		return nil
	}

	if value.Kind == yaml.ScalarNode && value.Tag == "!!null" {
		*configs = nil
		return nil
	}

	return fmt.Errorf("openai 配置格式无效")
}
// Config is the root of the YAML configuration file. Only the keys declared
// here are written back by writeConfig.
type Config struct {
// Server holds the listener settings (defaults: mode "tcp", address "0.0.0.0:8080").
Server struct {
Mode string `yaml:"mode"`
Address string `yaml:"address"`
} `yaml:"server"`
OpenAI OpenAIConfigs `yaml:"openai"`
ToolRouter ToolRouterConfig `yaml:"tool_router"`
}
// defaultOpenAIConfig returns the placeholder profile used when no profile is
// configured: named "default", active, with the default base URL and timeout.
// APIKey and Model are left empty.
func defaultOpenAIConfig() OpenAIConfig {
	var profile OpenAIConfig
	profile.Name = "default"
	profile.Active = true
	profile.BaseURL = defaultOpenAIBaseURL
	profile.Timeout = defaultOpenAITimeout
	return profile
}
// defaultToolRouterConfig returns the tool-router settings used when the
// section is missing: enabled, not pinned to a profile, default limits and
// prompt, and the three built-in tools (time, search, sql) all enabled with
// empty descriptions.
func defaultToolRouterConfig() ToolRouterConfig {
	builtin := []string{"time", "search", "sql"}
	tools := make([]ToolRouteConfig, 0, len(builtin))
	for _, name := range builtin {
		tools = append(tools, ToolRouteConfig{Name: name, Enabled: true, Description: ""})
	}

	var router ToolRouterConfig
	router.Enabled = true
	router.OpenAIName = ""
	router.Timeout = defaultToolRouterTimeout
	router.MaxTokens = defaultToolRouterMaxTokens
	router.SystemPrompt = defaultToolRouterSystemText
	router.Tools = tools
	return router
}
// defaultConfig assembles the configuration written when no config file
// exists yet.
func defaultConfig() Config {
	defaults := Config{
		OpenAI:     OpenAIConfigs{defaultOpenAIConfig()},
		ToolRouter: defaultToolRouterConfig(),
	}
	defaults.Server.Mode = "tcp"
	defaults.Server.Address = "0.0.0.0:8080"
	return defaults
}
// loadConfig makes sure the configuration file at path exists and is
// normalized, then reads and parses it.
//
// When the ARK_API_KEY environment variable is non-empty it replaces the API
// key of every profile. As a side effect the package-level
// legacySearchProfiles is refreshed from the `search` key of the same file.
func loadConfig(path string) (*Config, error) {
	if err := ensureConfigFile(path); err != nil {
		return nil, err
	}

	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("读取配置文件失败: %w", readErr)
	}

	loaded := new(Config)
	if parseErr := yaml.Unmarshal(raw, loaded); parseErr != nil {
		return nil, fmt.Errorf("解析配置文件失败: %w", parseErr)
	}
	if _, err := normalizeOpenAIConfigs(loaded); err != nil {
		return nil, err
	}

	// The environment variable takes precedence over keys stored in the file.
	if envKey := os.Getenv("ARK_API_KEY"); envKey != "" {
		for idx := range loaded.OpenAI {
			loaded.OpenAI[idx].APIKey = envKey
		}
	}

	legacySearchProfiles = readLegacySearchProfiles(raw)

	if _, err := normalizeToolRouterConfig(loaded); err != nil {
		return nil, err
	}
	return loaded, nil
}
// ensureConfigFile creates the configuration file with defaults when it does
// not exist; otherwise it fills in missing or invalid values and rewrites the
// file only if something changed.
//
// The file is parsed twice: once into Config for the typed values and once
// into a generic map so that "key absent" can be told apart from "key present
// with a zero value".
//
// NOTE(review): a rewrite serializes only the keys declared on Config, so any
// other top-level key in the file (for example the legacy `search` section
// read by readLegacySearchProfiles) appears to be dropped — confirm this is
// intended.
func ensureConfigFile(path string) error {
defaults := defaultConfig()
if _, err := os.Stat(path); err != nil {
if !os.IsNotExist(err) {
return fmt.Errorf("检查配置文件失败: %w", err)
}
// No file yet: write the defaults and stop.
return writeConfig(path, defaults)
}
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
}
var cfg Config
if err = yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
var raw map[string]any
if err = yaml.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
changed := false
// Fill in server settings that are absent from the file.
server, _ := raw["server"].(map[string]any)
if server == nil {
cfg.Server = defaults.Server
changed = true
} else {
if _, ok := server["mode"]; !ok {
cfg.Server.Mode = defaults.Server.Mode
changed = true
}
if _, ok := server["address"]; !ok {
cfg.Server.Address = defaults.Server.Address
changed = true
}
}
// Anything other than a YAML sequence under `openai` (missing key, single
// mapping, null) triggers a rewrite, which emits the list form.
if _, ok := raw["openai"].([]any); !ok {
changed = true
}
if normalized, err := normalizeOpenAIConfigs(&cfg); err != nil {
return err
} else if normalized {
changed = true
}
// A missing tool_router section gets the defaults wholesale; an existing one
// is normalized in place.
if _, ok := raw["tool_router"]; !ok {
cfg.ToolRouter = defaults.ToolRouter
changed = true
} else if normalized, err := normalizeToolRouterConfig(&cfg); err != nil {
return err
} else if normalized {
changed = true
}
if !changed {
return nil
}
return writeConfig(path, cfg)
}
// normalizeOpenAIConfigs repairs cfg.OpenAI in place and reports whether it
// modified anything.
//
// It inserts the default profile when the list is empty, trims names (falling
// back to the model name, then to "openai-<position>"), fills in missing base
// URLs and non-positive timeouts, and guarantees exactly one active profile:
// the first one flagged active wins, and the first profile is activated when
// none is flagged. Duplicate names produce an error.
func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
	modified := false
	if len(cfg.OpenAI) == 0 {
		cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()}
		modified = true
	}

	names := map[string]bool{}
	firstActive := -1
	for idx := range cfg.OpenAI {
		entry := &cfg.OpenAI[idx]

		// Resolve the effective name, then store it only if it differs.
		name := strings.TrimSpace(entry.Name)
		if name == "" {
			if name = strings.TrimSpace(entry.Model); name == "" {
				name = fmt.Sprintf("openai-%d", idx+1)
			}
		}
		if entry.Name != name {
			entry.Name = name
			modified = true
		}

		if names[name] {
			return modified, fmt.Errorf("openai 配置名称重复: %s", name)
		}
		names[name] = true

		if strings.TrimSpace(entry.BaseURL) == "" {
			entry.BaseURL = defaultOpenAIBaseURL
			modified = true
		}
		if entry.Timeout <= 0 {
			entry.Timeout = defaultOpenAITimeout
			modified = true
		}

		switch {
		case !entry.Active:
			// Nothing to do for inactive profiles.
		case firstActive < 0:
			firstActive = idx
		default:
			// A second active flag is cleared so only one profile stays active.
			entry.Active = false
			modified = true
		}
	}

	if firstActive < 0 {
		cfg.OpenAI[0].Active = true
		modified = true
	}
	return modified, nil
}
// isLegacyToolRouterPrompt reports whether prompt contains any of the marker
// substrings that identify a legacy tool-router prompt.
func isLegacyToolRouterPrompt(prompt string) bool {
	trimmed := strings.TrimSpace(prompt)
	for _, marker := range []string{"工具路由器", "route_tools", `"tools":[`} {
		if strings.Contains(trimmed, marker) {
			return true
		}
	}
	return false
}
// normalizeToolRouterConfig repairs cfg.ToolRouter in place and reports
// whether a change worth persisting was made.
//
// It restores non-positive limits and empty or legacy system prompts to their
// defaults, lower-cases and trims tool names, rejects duplicate tool names,
// and rebuilds the tool list so the default tools come first (in default
// order, keeping the user's settings for them) followed by any extra tools in
// their original order.
//
// Trimming OpenAIName and tool descriptions is applied but does not by itself
// set the returned flag.
func normalizeToolRouterConfig(cfg *Config) (bool, error) {
changed := false
defaults := defaultToolRouterConfig()
cfg.ToolRouter.OpenAIName = strings.TrimSpace(cfg.ToolRouter.OpenAIName)
if cfg.ToolRouter.Timeout <= 0 {
cfg.ToolRouter.Timeout = defaultToolRouterTimeout
changed = true
}
if cfg.ToolRouter.MaxTokens <= 0 {
cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens
changed = true
}
systemPrompt := strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
if systemPrompt == "" || isLegacyToolRouterPrompt(systemPrompt) {
cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText
changed = true
} else if systemPrompt != cfg.ToolRouter.SystemPrompt {
cfg.ToolRouter.SystemPrompt = systemPrompt
changed = true
}
if len(cfg.ToolRouter.Tools) == 0 {
cfg.ToolRouter.Tools = defaults.Tools
changed = true
}
// Canonicalize names and detect duplicates after canonicalization.
seen := map[string]bool{}
for i := range cfg.ToolRouter.Tools {
tool := &cfg.ToolRouter.Tools[i]
name := strings.ToLower(strings.TrimSpace(tool.Name))
if name == "" {
// Unnamed entries get a positional placeholder name.
name = fmt.Sprintf("tool-%d", i+1)
}
if name != tool.Name {
tool.Name = name
changed = true
}
tool.Description = strings.TrimSpace(tool.Description)
if seen[name] {
return changed, fmt.Errorf("tool_router.tools 配置名称重复: %s", name)
}
seen[name] = true
}
byName := map[string]ToolRouteConfig{}
for _, tool := range cfg.ToolRouter.Tools {
byName[tool.Name] = tool
}
// Merge: default tools first (user's entry if present, default otherwise),
// then the remaining user-defined tools.
merged := make([]ToolRouteConfig, 0, len(cfg.ToolRouter.Tools)+len(defaults.Tools))
used := map[string]bool{}
for _, tool := range defaults.Tools {
if existing, ok := byName[tool.Name]; ok {
merged = append(merged, existing)
} else {
merged = append(merged, tool)
changed = true
}
used[tool.Name] = true
}
for _, tool := range cfg.ToolRouter.Tools {
if !used[tool.Name] {
merged = append(merged, tool)
}
}
// A different length or a different name order both count as a change.
if len(merged) != len(cfg.ToolRouter.Tools) {
changed = true
} else {
for i := range merged {
if merged[i].Name != cfg.ToolRouter.Tools[i].Name {
changed = true
break
}
}
}
cfg.ToolRouter.Tools = merged
return changed, nil
}
// readLegacySearchProfiles extracts the top-level `search` section from raw
// configuration bytes. It returns nil when the data cannot be parsed.
func readLegacySearchProfiles(data []byte) []searchagent.ProfileConfig {
	legacy := struct {
		Search searchagent.ProfileConfigs `yaml:"search"`
	}{}
	if yaml.Unmarshal(data, &legacy) != nil {
		return nil
	}
	return []searchagent.ProfileConfig(legacy.Search)
}
// configFileMode is the permission used when the configuration file is
// created. The file stores API keys (OpenAIConfig.APIKey is serialized under
// `api_key`), so it is readable and writable by the owner only.
const configFileMode = 0o600

// writeConfig serializes cfg as YAML and writes it to path, replacing any
// existing content.
//
// The mode is applied only when the file is created; the permissions of an
// already existing file are left untouched by os.WriteFile.
func writeConfig(path string, cfg Config) error {
	data, err := yaml.Marshal(&cfg)
	if err != nil {
		return fmt.Errorf("生成配置文件失败: %w", err)
	}
	if err := os.WriteFile(path, data, configFileMode); err != nil {
		return fmt.Errorf("写入配置文件失败: %w", err)
	}
	return nil
}
// ─── 请求结构 ─────────────────────────────────────────────
// ChatMessage is one message of a conversation as exchanged with the client.
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ImageURL string `json:"image_url,omitempty"` // base64 data URI or http URL
// ImageURLAlias carries the same kind of value under the camelCase JSON key
// "imageURL".
ImageURLAlias string `json:"imageURL,omitempty"`
Hidden bool `json:"hidden,omitempty"`
}
// ChatRequest is the JSON body accepted by chatHandler.
type ChatRequest struct {
// ConversationID, when non-empty, makes the handler persist the exchange.
ConversationID string `json:"conversation_id,omitempty"`
Messages []ChatMessage `json:"messages"`
WebSearch bool `json:"web_search,omitempty"`
// OpenAIName selects a profile by name; empty selects the active profile.
OpenAIName string `json:"openai_name,omitempty"`
}
// Conversation is a stored chat session with its metadata and messages.
type Conversation struct {
ID string `json:"id"`
Title string `json:"title"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Messages []ChatMessage `json:"messages,omitempty"`
}
// ConvStore is the conversation store used by the HTTP handlers. Its methods
// are defined outside this section.
// NOTE(review): presumably dir is the directory holding conversation files and
// mu serializes access to it — confirm against the method implementations.
type ConvStore struct {
dir string
mu sync.Mutex
}
// OpenAIProfile pairs a validated profile configuration with the Ark client
// built from it.
type OpenAIProfile struct {
Config OpenAIConfig
Client *ark.Client
}
// OpenAIState is the registry of model profiles plus the name of the currently
// active one. Its methods take mu before reading or changing the fields.
type OpenAIState struct {
mu sync.RWMutex
profiles map[string]*OpenAIProfile
// order keeps profile names in configuration order for listing.
order []string
activeName string
}
// activeProfileRequest is the JSON body of the "switch active profile"
// endpoints; Name is the profile to activate.
type activeProfileRequest struct {
Name string `json:"name"`
}
// openAIListResponse is the JSON payload returned by ListProfiles: the active
// profile name and every profile's public configuration.
type openAIListResponse struct {
Active string `json:"active"`
Profiles []OpenAIConfig `json:"profiles"`
}
// chatCompleter performs one non-streaming chat completion against a profile
// within the given timeout. ToolRouterState stores one in its complete field,
// which NewToolRouterState sets to completeChatWithTimeout.
type chatCompleter func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error)
// ToolRouterState holds what the tool-decision stage needs at runtime: its
// configuration, the profile registry and the completion function.
type ToolRouterState struct {
cfg *ToolRouterConfig
ai *OpenAIState
complete chatCompleter
}
// NewToolRouterState builds the tool-router runtime state.
//
// A nil config is replaced by the defaults. ai must not be nil. When the
// router is enabled and pinned to a named profile, that profile must exist in
// ai; otherwise an error is returned.
func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) {
	if config == nil {
		fallback := defaultToolRouterConfig()
		config = &fallback
	}
	if ai == nil {
		return nil, errors.New("工具路由需要 OpenAI 状态")
	}

	pinned := config.Enabled && strings.TrimSpace(config.OpenAIName) != ""
	if pinned {
		_, err := ai.GetProfile(config.OpenAIName)
		if err != nil {
			return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err)
		}
	}

	router := &ToolRouterState{
		cfg:      config,
		ai:       ai,
		complete: completeChatWithTimeout,
	}
	return router, nil
}
// NewOpenAIState validates configs and builds the profile registry, creating
// one Ark client per profile.
//
// Every profile needs a non-blank name, API key, model and base URL and a
// positive timeout; names must be unique. The first profile flagged Active
// becomes the active one, falling back to the first profile in the list.
//
// Profile names are trimmed before they are registered. GetProfile and
// SwitchActive look profiles up by trimmed name, so registering an untrimmed
// name would make that profile unreachable.
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
	state := &OpenAIState{
		profiles: make(map[string]*OpenAIProfile, len(configs)),
		order:    make([]string, 0, len(configs)),
	}
	for _, config := range configs {
		// config is a per-iteration copy, so normalizing it does not touch the
		// caller's slice.
		config.Name = strings.TrimSpace(config.Name)
		if config.Name == "" {
			return nil, errors.New("openai.name 不能为空")
		}
		if strings.TrimSpace(config.APIKey) == "" {
			return nil, fmt.Errorf("openai.%s.api_key 未配置,也未设置环境变量 ARK_API_KEY", config.Name)
		}
		if strings.TrimSpace(config.Model) == "" {
			return nil, fmt.Errorf("openai.%s.model 未配置", config.Name)
		}
		if strings.TrimSpace(config.BaseURL) == "" {
			return nil, fmt.Errorf("openai.%s.base_url 未配置", config.Name)
		}
		if config.Timeout <= 0 {
			return nil, fmt.Errorf("openai.%s.timeout 必须大于 0", config.Name)
		}
		if _, ok := state.profiles[config.Name]; ok {
			return nil, fmt.Errorf("openai 配置名称重复: %s", config.Name)
		}
		state.profiles[config.Name] = &OpenAIProfile{
			Config: config,
			Client: ark.NewClientWithApiKey(
				config.APIKey,
				ark.WithBaseUrl(config.BaseURL),
				ark.WithTimeout(time.Duration(config.Timeout)*time.Second),
			),
		}
		state.order = append(state.order, config.Name)
		if config.Active && state.activeName == "" {
			state.activeName = config.Name
		}
	}
	if len(state.order) == 0 {
		return nil, errors.New("openai 配置不能为空")
	}
	if state.activeName == "" {
		state.activeName = state.order[0]
	}
	return state, nil
}
// ActiveProfile returns the profile registered under the current active name.
func (s *OpenAIState) ActiveProfile() *OpenAIProfile {
	s.mu.RLock()
	active := s.profiles[s.activeName]
	s.mu.RUnlock()
	return active
}
// GetProfile looks up a profile by name, ignoring surrounding whitespace. A
// blank name selects the active profile. An unknown name yields an error.
func (s *OpenAIState) GetProfile(name string) (*OpenAIProfile, error) {
	key := strings.TrimSpace(name)

	s.mu.RLock()
	defer s.mu.RUnlock()

	if key == "" {
		return s.profiles[s.activeName], nil
	}
	if found, ok := s.profiles[key]; ok {
		return found, nil
	}
	return nil, fmt.Errorf("OpenAI 配置不存在: %s", name)
}
// SwitchActive makes the named profile the active one and returns it. The
// name is trimmed first; a blank or unknown name is an error and leaves the
// active profile unchanged.
func (s *OpenAIState) SwitchActive(name string) (*OpenAIProfile, error) {
	target := strings.TrimSpace(name)
	if target == "" {
		return nil, errors.New("OpenAI 配置名称不能为空")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if selected, ok := s.profiles[target]; ok {
		s.activeName = target
		return selected, nil
	}
	return nil, fmt.Errorf("OpenAI 配置不存在: %s", target)
}
// ListProfiles returns every profile in configuration order with the API key
// blanked and the Active flag reflecting the current selection.
func (s *OpenAIState) ListProfiles() openAIListResponse {
	s.mu.RLock()
	defer s.mu.RUnlock()

	listing := openAIListResponse{
		Active:   s.activeName,
		Profiles: make([]OpenAIConfig, 0, len(s.order)),
	}
	for _, name := range s.order {
		isActive := name == s.activeName
		listing.Profiles = append(listing.Profiles, publicOpenAIConfig(s.profiles[name], isActive))
	}
	return listing
}
// publicOpenAIConfig returns a copy of the profile's configuration that is
// safe to send to clients: the API key is cleared and Active is set to the
// supplied value.
func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig {
	public := profile.Config
	public.APIKey, public.Active = "", active
	return public
}
// RouterProfile returns the profile the tool-decision stage should use: the
// profile named in the router configuration when it is set and resolvable,
// otherwise fallback. It is safe to call on a nil receiver.
func (s *ToolRouterState) RouterProfile(fallback *OpenAIProfile) *OpenAIProfile {
	configured := s != nil && s.cfg != nil && s.ai != nil
	if !configured {
		return fallback
	}

	pinned := strings.TrimSpace(s.cfg.OpenAIName)
	if pinned == "" {
		return fallback
	}

	if resolved, err := s.ai.GetProfile(pinned); err == nil {
		return resolved
	}
	return fallback
}
// isOllamaProfile reports whether the profile's base URL points at port 11434
// on the local machine (127.0.0.1, localhost or ::1). When the URL cannot be
// parsed it falls back to a plain substring check for ":11434".
func isOllamaProfile(profile *OpenAIProfile) bool {
	if profile == nil {
		return false
	}

	parsed, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL))
	if err != nil {
		return strings.Contains(profile.Config.BaseURL, ":11434")
	}
	if parsed.Port() != "11434" {
		return false
	}

	switch strings.ToLower(parsed.Hostname()) {
	case "127.0.0.1", "localhost", "::1":
		return true
	}
	return false
}
// shouldParseThinkTags decides whether <think> tags in model output should be
// split into reasoning events. An explicit ParseThinkTags setting wins;
// otherwise parsing is enabled only for profiles detected by isOllamaProfile.
func shouldParseThinkTags(profile *OpenAIProfile) bool {
	switch {
	case profile == nil:
		return false
	case profile.Config.ParseThinkTags != nil:
		return *profile.Config.ParseThinkTags
	default:
		return isOllamaProfile(profile)
	}
}
// ─── 全局变量 ─────────────────────────────────────────────
// Package-level state shared by the HTTP handlers and the agent tool loop.
// Only legacySearchProfiles is assigned in this section (by loadConfig); the
// other variables are initialized elsewhere.
var (
cfg *Config
aiState *OpenAIState
searchState *searchagent.State
// legacySearchProfiles holds the `search` section read from the config file.
legacySearchProfiles []searchagent.ProfileConfig
toolRouterState *ToolRouterState
sqlState *sqlquery.State
store *ConvStore
)
// chatSSEFrame is the JSON payload of one server-sent event emitted by the
// chat endpoint. Type selects the event kind; the values used in this file
// are "trace", "delta", "reasoning", "stats" and "error". The remaining fields
// are optional and omitted from JSON when empty.
type chatSSEFrame struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Message string `json:"message,omitempty"`
Tool string `json:"tool,omitempty"`
Stage string `json:"stage,omitempty"`
Status string `json:"status,omitempty"`
Data map[string]any `json:"data,omitempty"`
Stats *tokenUsageStats `json:"stats,omitempty"`
Error string `json:"error,omitempty"`
}
// tokenUsageStats is the token accounting snapshot sent to the client.
type tokenUsageStats struct {
	PromptTokens               int     `json:"prompt_tokens"`
	CompletionTokens           int     `json:"completion_tokens"`
	ToolPromptTokens           int     `json:"tool_prompt_tokens"`
	ToolCompletionTokens       int     `json:"tool_completion_tokens"`
	TotalTokens                int     `json:"total_tokens"`
	CompletionTokensPerSec     float64 `json:"completion_tokens_per_sec"`
	PeakCompletionTokensPerSec float64 `json:"peak_completion_tokens_per_sec"`
	Estimated                  bool    `json:"estimated"`
}

// tokenUsageTracker accumulates token counts for one chat request. All
// methods are nil-safe and guard the counters with mu.
type tokenUsageTracker struct {
	mu                   sync.Mutex
	promptTokens         int
	completionTokens     int
	toolPromptTokens     int
	toolCompletionTokens int
}

// tokenUsageContextKey is the private context key under which a tracker is
// stored.
type tokenUsageContextKey struct{}

// newTokenUsageTracker returns an empty tracker.
func newTokenUsageTracker() *tokenUsageTracker {
	return new(tokenUsageTracker)
}

// contextWithTokenUsage attaches tracker to ctx. A nil tracker leaves ctx
// unchanged.
func contextWithTokenUsage(ctx context.Context, tracker *tokenUsageTracker) context.Context {
	if tracker != nil {
		return context.WithValue(ctx, tokenUsageContextKey{}, tracker)
	}
	return ctx
}

// tokenUsageFromContext returns the tracker attached to ctx, or nil when there
// is none.
func tokenUsageFromContext(ctx context.Context) *tokenUsageTracker {
	if tracker, ok := ctx.Value(tokenUsageContextKey{}).(*tokenUsageTracker); ok {
		return tracker
	}
	return nil
}

// addTool adds the given counts to the tool-stage totals.
func (t *tokenUsageTracker) addTool(promptTokens, completionTokens int) {
	if t != nil {
		t.mu.Lock()
		t.toolPromptTokens += promptTokens
		t.toolCompletionTokens += completionTokens
		t.mu.Unlock()
	}
}

// setModel overwrites (rather than accumulates) the answer-stage counts.
func (t *tokenUsageTracker) setModel(promptTokens, completionTokens int) {
	if t != nil {
		t.mu.Lock()
		t.promptTokens, t.completionTokens = promptTokens, completionTokens
		t.mu.Unlock()
	}
}

// snapshot returns the current counts together with the supplied throughput
// figures. Estimated is always true. On a nil tracker every other field is
// zero, including the throughput figures.
func (t *tokenUsageTracker) snapshot(tokensPerSecond, peakTokensPerSecond float64) tokenUsageStats {
	stats := tokenUsageStats{Estimated: true}
	if t == nil {
		return stats
	}

	t.mu.Lock()
	defer t.mu.Unlock()

	stats.PromptTokens = t.promptTokens
	stats.CompletionTokens = t.completionTokens
	stats.ToolPromptTokens = t.toolPromptTokens
	stats.ToolCompletionTokens = t.toolCompletionTokens
	stats.TotalTokens = t.promptTokens + t.completionTokens + t.toolPromptTokens + t.toolCompletionTokens
	stats.CompletionTokensPerSec = tokensPerSecond
	stats.PeakCompletionTokensPerSec = peakTokensPerSecond
	return stats
}
// ─── 路由 ─────────────────────────────────────────────────
// indexHandler renders the chat page, passing the active profile's model and
// name to the template.
func indexHandler(c *gin.Context) {
	active := aiState.ActiveProfile()
	page := gin.H{
		"Title":      "AI 对话",
		"Model":      active.Config.Model,
		"OpenAIName": active.Config.Name,
	}
	c.HTML(http.StatusOK, "chat.html", page)
}
// listOpenAIHandler returns the configured model profiles as JSON.
func listOpenAIHandler(c *gin.Context) {
	listing := aiState.ListProfiles()
	c.JSON(http.StatusOK, listing)
}
// switchOpenAIHandler activates the model profile named in the JSON body and
// responds with its public configuration. A malformed body or an unknown name
// yields 400.
func switchOpenAIHandler(c *gin.Context) {
	var body activeProfileRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + bindErr.Error()})
		return
	}

	selected, switchErr := aiState.SwitchActive(body.Name)
	if switchErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": switchErr.Error()})
		return
	}

	payload := gin.H{
		"active":  selected.Config.Name,
		"profile": publicOpenAIConfig(selected, true),
	}
	c.JSON(http.StatusOK, payload)
}
// listSearchHandler returns the configured search profiles as JSON.
func listSearchHandler(c *gin.Context) {
	listing := searchState.ListProfiles()
	c.JSON(http.StatusOK, listing)
}
// switchSearchHandler activates the search profile named in the JSON body and
// responds with that profile, its API key blanked. A malformed body or an
// unknown name yields 400.
func switchSearchHandler(c *gin.Context) {
	var body activeProfileRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + bindErr.Error()})
		return
	}

	selected, switchErr := searchState.SwitchActive(body.Name)
	if switchErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": switchErr.Error()})
		return
	}

	// NOTE(review): this assumes SwitchActive hands back a copy. If it returns
	// a pointer into the search state, clearing APIKey here would wipe the
	// stored key — confirm against the searchagent package.
	selected.APIKey = ""
	selected.Active = true

	payload := gin.H{
		"active":  selected.Name,
		"profile": selected,
	}
	c.JSON(http.StatusOK, payload)
}
// listConversationsHandler returns the stored conversations as JSON, or 500
// with the error text when the store fails.
func listConversationsHandler(c *gin.Context) {
	conversations, listErr := store.List()
	if listErr == nil {
		c.JSON(http.StatusOK, conversations)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": listErr.Error()})
}
// createConversationHandler creates a new conversation in the store and
// returns it as JSON, or 500 when creation fails.
func createConversationHandler(c *gin.Context) {
	created, createErr := store.Create()
	if createErr == nil {
		c.JSON(http.StatusOK, created)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": "创建对话失败: " + createErr.Error()})
}
// getConversationHandler returns the conversation identified by the :id path
// parameter. The store's "对话不存在" error maps to 404; any other error maps
// to 500.
func getConversationHandler(c *gin.Context) {
	conv, getErr := store.Get(c.Param("id"))
	if getErr == nil {
		c.JSON(http.StatusOK, conv)
		return
	}

	// The not-found case is recognized by its message text.
	code := http.StatusInternalServerError
	if getErr.Error() == "对话不存在" {
		code = http.StatusNotFound
	}
	c.JSON(code, gin.H{"error": getErr.Error()})
}
// deleteConversationHandler removes the conversation identified by the :id
// path parameter and answers 204, or 500 with the error text on failure.
func deleteConversationHandler(c *gin.Context) {
	deleteErr := store.Delete(c.Param("id"))
	if deleteErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": deleteErr.Error()})
		return
	}
	c.Status(http.StatusNoContent)
}
// chatHandler is the streaming chat endpoint. It validates the JSON request,
// resolves the model profile, then answers as a server-sent event stream.
//
// Flow:
//  1. Run the function-calling tool loop (runAgentToolLoop). If it fails, the
//     error is logged to stderr and the plain conversation messages are used.
//  2. For local Ollama profiles with image input, delegate to streamOllamaChat.
//  3. Otherwise stream a chat completion, forwarding content as "delta" frames
//     (with running token statistics) and reasoning as "reasoning" frames.
//  4. At end of stream, optionally persist the exchange, emit a final "stats"
//     frame and a success trace, then write the literal "data: [DONE]" line.
//
// Request validation errors are returned as regular JSON with status 400;
// once the SSE headers are written, failures are reported as "error" frames.
// Token counts are estimates produced by estimateTokenCount.
func chatHandler(c *gin.Context) {
var req ChatRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求格式错误: " + err.Error()})
return
}
if len(req.Messages) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "消息不能为空"})
return
}
profile, err := aiState.GetProfile(req.OpenAIName)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Write the SSE headers first; all later tool/model progress is shown live
// through trace events.
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Writer.WriteHeader(http.StatusOK)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return
}
// emit writes one frame and flushes so the client sees it immediately.
emit := func(frame chatSSEFrame) {
writeSSEJSON(c.Writer, frame)
flusher.Flush()
}
emitTrace := func(tool, stage, status, message string, data map[string]any) {
emit(chatSSEFrame{Type: "trace", Tool: tool, Stage: stage, Status: status, Message: message, Data: data})
}
emitError := func(err error) {
emit(chatSSEFrame{Type: "error", Error: err.Error()})
}
// Timeout context: the profile timeout bounds the tool loop and the answer
// stream together.
timeout := time.Duration(profile.Config.Timeout) * time.Second
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
defer cancel()
usage := newTokenUsageTracker()
ctx = contextWithTokenUsage(ctx, usage)
// The function-calling tool loop replaces the old routing + hidden-context
// mechanism.
messages, err := runAgentToolLoop(ctx, profile, req.Messages, emit)
if err != nil {
// Tool-loop failure is not fatal: fall back to the untouched conversation.
fmt.Fprintln(os.Stderr, "Agent 工具循环失败:", err)
messages, err = buildArkMessages(req.Messages)
if err != nil {
emitError(err)
return
}
}
promptTokens := estimateChatMessagesTokens(req.Messages)
if isOllamaProfile(profile) && hasImageMessage(req.Messages) {
emitTrace("model", "request", "running", "正在通过 Ollama 原生接口调用视觉模型", nil)
err = streamOllamaChat(ctx, profile, messages, promptTokens, usage, emit, func(content string) {
if req.ConversationID != "" {
if err := saveConversationMessages(req.ConversationID, req.Messages, content); err != nil {
fmt.Fprintln(os.Stderr, "保存对话失败:", err)
}
}
})
if err != nil {
emitError(err)
}
return
}
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
Messages: messages,
MaxTokens: intPtr(4096),
}.WithStream(true))
if err != nil {
emitError(err)
return
}
defer stream.Close()
emitTrace("model", "stream", "running", "模型已开始输出", nil)
var full strings.Builder
completionTokens := 0
streamStarted := time.Now()
// windowStarted/windowTokens track the current measurement window used to
// derive the peak output speed.
windowStarted := streamStarted
windowTokens := 0
peakTokensPerSecond := 0.0
parseThinkTags := shouldParseThinkTags(profile)
thinkParser := &thinkTagParser{}
// emitDelta forwards visible answer text and updates the running statistics.
emitDelta := func(delta string) {
if delta == "" {
return
}
now := time.Now()
deltaTokens := estimateTokenCount(delta)
windowTokens += deltaTokens
windowElapsed := now.Sub(windowStarted).Seconds()
if windowElapsed >= 1 {
// A window of at least one second has elapsed: record its speed and start
// a new window.
windowSpeed := float64(windowTokens) / windowElapsed
if windowSpeed > peakTokensPerSecond {
peakTokensPerSecond = windowSpeed
}
windowStarted = now
windowTokens = 0
} else if peakTokensPerSecond == 0 && windowElapsed > 0.25 {
// Provide an early peak value before the first full window completes.
peakTokensPerSecond = float64(windowTokens) / windowElapsed
}
full.WriteString(delta)
completionTokens += deltaTokens
usage.setModel(promptTokens, completionTokens)
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
}
// emitModelContent routes raw model content through the <think> tag parser
// when enabled, so reasoning text is not counted or stored as answer text.
emitModelContent := func(delta string) {
if delta == "" {
return
}
if !parseThinkTags {
emitDelta(delta)
return
}
visible, reasoning := thinkParser.Accept(delta)
if reasoning != "" {
emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
}
emitDelta(visible)
}
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
// End of stream: release whatever the tag parser was still holding back.
if parseThinkTags {
visible, reasoning := thinkParser.Flush()
if reasoning != "" {
emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
}
emitDelta(visible)
}
usage.setModel(promptTokens, completionTokens)
// Account for the last, unfinished measurement window.
if windowTokens > 0 {
windowElapsed := time.Since(windowStarted).Seconds()
if windowElapsed > 0.25 {
windowSpeed := float64(windowTokens) / windowElapsed
if windowSpeed > peakTokensPerSecond {
peakTokensPerSecond = windowSpeed
}
}
}
if peakTokensPerSecond == 0 {
peakTokensPerSecond = tokensPerSecond(completionTokens, streamStarted)
}
if req.ConversationID != "" {
// Persistence failures are logged only; the answer was already delivered.
if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil {
fmt.Fprintln(os.Stderr, "保存对话失败:", err)
}
}
finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
emit(chatSSEFrame{Type: "stats", Stats: &finalStats})
emitTrace("model", "stream", "success", "回答生成完成", nil)
fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
return
}
if err != nil {
emitError(err)
return
}
if len(resp.Choices) > 0 {
emitModelContent(resp.Choices[0].Delta.Content)
// The reasoning_content field is pushed as its own event.
if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" {
emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent})
}
}
}
}
// ─── 辅助函数 ─────────────────────────────────────────────
// estimateChatMessagesTokens returns a rough prompt-token estimate for a list
// of messages: the estimated tokens of each role and content, plus a fixed 4
// per message and a fixed 85 for every message that carries an image URL.
func estimateChatMessagesTokens(messages []ChatMessage) int {
	sum := 0
	for i := range messages {
		msg := &messages[i]
		sum += estimateTokenCount(msg.Role) + estimateTokenCount(msg.Content) + 4
		hasImage := msg.ImageURL != "" || msg.ImageURLAlias != ""
		if hasImage {
			sum += 85
		}
	}
	return sum
}
// estimateTokenCount returns a heuristic token count for text.
//
// Surrounding whitespace is ignored and blank input counts as 0. Each run of
// consecutive non-space ASCII characters counts as ceil(len/4) tokens, and
// every non-ASCII, non-space rune counts as one token. Whitespace only
// separates runs.
func estimateTokenCount(text string) int {
	trimmed := strings.TrimSpace(text)
	if trimmed == "" {
		return 0
	}

	total, run := 0, 0
	for _, r := range trimmed {
		switch {
		case unicode.IsSpace(r):
			// Whitespace closes the current ASCII run.
			total += (run + 3) / 4
			run = 0
		case r <= unicode.MaxASCII:
			run++
		default:
			// A non-ASCII rune closes the run and counts as one token itself.
			total += (run+3)/4 + 1
			run = 0
		}
	}
	total += (run + 3) / 4

	if total == 0 {
		return 1
	}
	return total
}
// tokensPerSecond returns tokens divided by the seconds elapsed since start.
// It returns 0 when tokens is not positive or no time has elapsed.
func tokensPerSecond(tokens int, start time.Time) float64 {
	seconds := time.Since(start).Seconds()
	if tokens > 0 && seconds > 0 {
		return float64(tokens) / seconds
	}
	return 0
}
// agentTool bundles everything the tool loop needs for one callable tool: its
// name, the definition sent to the model, and the function that executes a
// call. execute receives the raw arguments string from the model's tool call
// and returns the text handed back to the model.
type agentTool struct {
name string
definition *model.Tool
execute func(context.Context, string) (string, error)
}
// Name returns the tool's name.
func (t agentTool) Name() string { return t.name }
// maxAgentToolIterations caps the number of tool-decision rounds performed by
// runAgentToolLoop for a single request.
const maxAgentToolIterations = 6
// availableAgentTools builds the list of tools that may be offered to the
// model for the current request.
//
// It returns nil when the tool router is missing or disabled. Otherwise it
// walks the configured tools in order and includes each enabled one whose name
// matches a known tool; search and sql are additionally skipped when their
// state is nil or reports itself disabled. Configured names that match no
// known tool are ignored.
//
// profile is the model profile the sql tool uses to generate text; emit, when
// non-nil, receives trace frames describing tool progress.
func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agentTool {
if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled {
return nil
}
tools := make([]agentTool, 0, len(toolRouterState.cfg.Tools))
for _, item := range toolRouterState.cfg.Tools {
if !item.Enabled {
continue
}
description := strings.TrimSpace(item.Description)
switch item.Name {
case timeagent.ToolName:
tools = append(tools, agentTool{
name: timeagent.ToolName,
definition: timeagent.ToolDefinition(description),
execute: func(ctx context.Context, args string) (string, error) {
result, err := timeagent.ExecuteTool(args, time.Now())
// Only success is traced here; a failure is returned without a trace frame.
if err == nil && emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: timeagent.ToolName, Stage: "resolve", Status: "success", Message: "已获取当前时间上下文"})
}
return result, err
},
})
case searchagent.ToolName:
if searchState == nil || !searchState.Enabled() {
continue
}
tools = append(tools, agentTool{
name: searchagent.ToolName,
definition: searchState.ToolDefinition(description),
execute: func(ctx context.Context, args string) (string, error) {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "request", Status: "running", Message: "正在联网搜索"})
}
result, err := searchState.ExecuteTool(ctx, args)
if emit != nil {
status := "success"
message := "联网搜索完成"
if err != nil {
status = "error"
message = "联网搜索失败"
}
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "results", Status: status, Message: message})
}
return result, err
},
})
case sqlquery.ToolName:
if sqlState == nil || !sqlState.Enabled() {
continue
}
tools = append(tools, agentTool{
name: sqlquery.ToolName,
definition: sqlState.ToolDefinition(description),
execute: func(ctx context.Context, args string) (string, error) {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: "running", Message: "正在查询数据库"})
}
// generator lets the sql tool ask the model for text, sending the prompt as a
// single system message.
generator := func(ctx context.Context, prompt string, maxTokens int) (string, error) {
return completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, maxTokens)
}
result, err := sqlState.ExecuteTool(ctx, args, generator)
if emit != nil {
status := "success"
message := "数据库查询完成"
if err != nil {
status = "error"
message = "数据库查询失败"
}
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: status, Message: message})
}
return result, err
},
})
}
}
return tools
}
// runAgentToolLoop lets the model call tools before the final answer is
// generated and returns the message list to use for that answer.
//
// The conversation is converted to Ark messages. If no tool is available the
// converted messages are returned unchanged. Otherwise the configured system
// prompt is prepended and the model is asked, for up to
// maxAgentToolIterations rounds, whether it wants to call a tool. Each
// requested call is executed and both the assistant's request and the tool
// result are appended to the returned messages. The loop ends when the model
// requests no tool, returns no choices, or the round limit is reached (in
// which case a system note about the limit is appended).
//
// Decision requests use the router profile and, for conversations containing
// images, a text-only copy of the messages. emit, when non-nil, receives
// trace frames. Token usage reported by decision requests is added to the
// tracker found in ctx, if any.
//
// On a decision-request error the messages built so far are returned together
// with the error.
//
// Nil entries in the model's tool-call list are dropped before use: the
// execution step reads call.ID and would otherwise panic on a nil call.
func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) {
	finalMessages, err := buildArkMessages(chatMessages)
	if err != nil {
		return nil, err
	}

	routerProfile := profile
	if toolRouterState != nil {
		routerProfile = toolRouterState.RouterProfile(profile)
	}
	tools := availableAgentTools(routerProfile, emit)
	if len(tools) == 0 {
		return finalMessages, nil
	}

	// notify forwards a frame only when the caller supplied an emitter.
	notify := func(frame chatSSEFrame) {
		if emit != nil {
			emit(frame)
		}
	}

	decisionMessages := append([]*model.ChatCompletionMessage(nil), finalMessages...)
	if hasImageMessage(chatMessages) {
		decisionMessages, err = buildToolDecisionMessages(chatMessages)
		if err != nil {
			return nil, err
		}
		notify(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "检测到图片输入,工具判断阶段将使用纯文本上下文"})
	}

	toolByName := make(map[string]agentTool, len(tools))
	definitions := make([]*model.Tool, 0, len(tools))
	availableNames := make([]string, 0, len(tools))
	toolDescriptions := make([]string, 0, len(tools))
	for _, tool := range tools {
		toolByName[tool.name] = tool
		definitions = append(definitions, tool.definition)
		availableNames = append(availableNames, tool.name)
		if tool.definition != nil && tool.definition.Function != nil {
			toolDescriptions = append(toolDescriptions, fmt.Sprintf("%s: %s", tool.name, tool.definition.Function.Description))
		}
	}
	notify(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}})

	// toolRouterState is non-nil here: availableAgentTools returns no tools
	// when it is nil, and that case returned above.
	if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" {
		systemMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}
		finalMessages = append([]*model.ChatCompletionMessage{systemMessage}, finalMessages...)
		decisionMessages = append([]*model.ChatCompletionMessage{systemMessage}, decisionMessages...)
	}

	for i := 0; i < maxAgentToolIterations; i++ {
		notify(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}})

		resp, err := toolRouterState.complete(ctx, routerProfile, model.CreateChatCompletionRequest{
			Model:             routerProfile.Config.Model,
			Messages:          decisionMessages,
			MaxTokens:         intPtr(toolRouterState.cfg.MaxTokens),
			Tools:             definitions,
			ToolChoice:        model.ToolChoiceStringTypeAuto,
			ParallelToolCalls: boolPtr(false),
		}, time.Duration(toolRouterState.cfg.Timeout)*time.Second)
		if err != nil {
			return finalMessages, err
		}
		if tracker := tokenUsageFromContext(ctx); tracker != nil {
			tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens)
		}
		if len(resp.Choices) == 0 {
			return finalMessages, nil
		}

		choice := resp.Choices[0]
		decisionPreview := chatMessageContentString(choice.Message.Content)
		notify(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "decision", Status: "success", Message: "工具判断响应已返回", Data: map[string]any{"iteration": i + 1, "finish_reason": string(choice.FinishReason), "content_preview": truncateString(decisionPreview, 800)}})

		// Keep only non-nil tool calls so the code below can dereference them.
		calls := make([]*model.ToolCall, 0, len(choice.Message.ToolCalls))
		for _, call := range choice.Message.ToolCalls {
			if call != nil {
				calls = append(calls, call)
			}
		}
		// Fall back to the legacy single function_call field.
		if len(calls) == 0 && choice.Message.FunctionCall != nil {
			calls = []*model.ToolCall{{ID: "legacy_function_call", Type: model.ToolTypeFunction, Function: *choice.Message.FunctionCall}}
		}
		if len(calls) == 0 {
			notify(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"})
			return finalMessages, nil
		}

		callNames := make([]string, 0, len(calls))
		for _, call := range calls {
			callNames = append(callNames, call.Function.Name)
		}
		notify(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}})

		// Record the assistant's request, then one tool message per call, in both
		// the final and the decision transcripts.
		assistantMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content}
		finalMessages = append(finalMessages, assistantMessage)
		decisionMessages = append(decisionMessages, assistantMessage)
		for _, call := range calls {
			result := executeAgentToolCall(ctx, call, toolByName, emit)
			toolMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)}
			finalMessages = append(finalMessages, toolMessage)
			decisionMessages = append(decisionMessages, toolMessage)
		}
	}

	limitMessage := &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")}
	finalMessages = append(finalMessages, limitMessage)
	return finalMessages, nil
}
// thinkTagParser incrementally splits streamed model output into visible text
// and reasoning text, where reasoning is whatever appears between <think> and
// </think> tags. It tolerates tags that are split across successive deltas.
type thinkTagParser struct {
// inThink is true while the parser is positioned inside a <think> section.
inThink bool
// buffer holds received text that has not been classified yet, for example a
// trailing fragment that could be the beginning of a tag.
buffer string
}
// Tag literals that delimit a reasoning section in streamed model output.
const (
thinkOpenTag = "<think>"
thinkCloseTag = "</think>"
)
// Accept feeds the next streamed fragment into the parser and returns the text
// that can already be classified: visible is output outside <think> sections,
// reasoning is output inside them. A trailing fragment that might be the start
// of the next expected tag is kept in the internal buffer until more input (or
// Flush) decides what it is.
func (p *thinkTagParser) Accept(delta string) (visible string, reasoning string) {
	p.buffer += delta
	for len(p.buffer) > 0 {
		// The tag we are waiting for depends on which side of a section we are on.
		tag := thinkOpenTag
		if p.inThink {
			tag = thinkCloseTag
		}

		var chunk string
		pos := strings.Index(p.buffer, tag)
		found := pos >= 0
		if found {
			chunk = p.buffer[:pos]
			p.buffer = p.buffer[pos+len(tag):]
		} else {
			// Hold back a suffix that is a proper prefix of the tag; release the rest.
			held := tagPrefixSuffixLen(p.buffer, tag)
			cut := len(p.buffer) - held
			chunk = p.buffer[:cut]
			p.buffer = p.buffer[cut:]
		}

		if p.inThink {
			reasoning += chunk
		} else {
			visible += chunk
		}
		if !found {
			break
		}
		p.inThink = !p.inThink
	}
	return visible, reasoning
}
// Flush returns whatever is still buffered, classified by the current mode
// (reasoning when inside a <think> section, visible otherwise), and resets the
// parser to its initial state.
func (p *thinkTagParser) Flush() (visible string, reasoning string) {
	rest, wasThinking := p.buffer, p.inThink
	p.buffer, p.inThink = "", false
	if wasThinking {
		return "", rest
	}
	return rest, ""
}
// tagPrefixSuffixLen returns the length of the longest suffix of text that is
// also a proper prefix of tag (at most len(tag)-1 bytes), or 0 if there is none.
func tagPrefixSuffixLen(text, tag string) int {
	longest := len(tag) - 1
	if longest > len(text) {
		longest = len(text)
	}
	for n := longest; n >= 1; n-- {
		if strings.HasSuffix(text, tag[:n]) {
			return n
		}
	}
	return 0
}
// executeAgentToolCall runs a single tool call requested by the model and
// returns the text to hand back as the tool result. Failures (invalid call,
// unknown tool, execution error) are reported as descriptive text rather than
// as an error. When emit is non-nil, trace frames describing each stage are
// sent through it.
func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string {
	// trace sends a "trace" frame if the caller supplied an emitter.
	trace := func(frame chatSSEFrame) {
		if emit == nil {
			return
		}
		frame.Type = "trace"
		emit(frame)
	}

	if call == nil || call.Type != model.ToolTypeFunction {
		const invalid = "工具调用无效:仅支持 function 类型工具。"
		trace(chatSSEFrame{Tool: "agent_tools", Stage: "execute", Status: "error", Message: invalid})
		return invalid
	}

	name, callID, args := call.Function.Name, call.ID, call.Function.Arguments
	trace(chatSSEFrame{Tool: name, Stage: "arguments", Status: "running", Message: "准备执行工具", Data: map[string]any{"tool_call_id": callID, "arguments": args}})

	tool, known := tools[name]
	if !known {
		unknown := fmt.Sprintf("工具调用失败:未知工具 %s。", name)
		trace(chatSSEFrame{Tool: name, Stage: "execute", Status: "error", Message: unknown})
		return unknown
	}

	began := time.Now()
	output, execErr := tool.execute(ctx, args)
	elapsedMs := time.Since(began).Milliseconds()
	if execErr != nil {
		trace(chatSSEFrame{Tool: tool.name, Stage: "execute", Status: "error", Message: "工具执行失败", Data: map[string]any{"tool_call_id": callID, "duration_ms": elapsedMs, "error": execErr.Error()}})
		return fmt.Sprintf("工具 %s 执行失败:%v", tool.name, execErr)
	}

	// Never hand an empty tool result back; substitute an explicit note instead.
	if strings.TrimSpace(output) == "" {
		output = fmt.Sprintf("工具 %s 执行完成,但没有返回内容。", tool.name)
	}
	trace(chatSSEFrame{Tool: tool.name, Stage: "result", Status: "success", Message: "工具执行完成", Data: map[string]any{"tool_call_id": callID, "duration_ms": elapsedMs, "result_preview": truncateString(output, 1200)}})
	return output
}
// ollamaChatRequest is the JSON body that streamOllamaChat posts to the Ollama
// /api/chat endpoint.
type ollamaChatRequest struct {
Model string `json:"model"`
Messages []ollamaChatMessage `json:"messages"`
// Stream asks for a streamed response; streamOllamaChat always sets it to true.
Stream bool `json:"stream"`
// Options carries integer-valued generation options (streamOllamaChat sends num_predict).
Options map[string]int `json:"options,omitempty"`
}
// ollamaChatMessage is a single chat message in an ollamaChatRequest.
type ollamaChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
// Images holds the image payloads produced by ollamaImagePayload; it is
// omitted from the JSON when the message has no images.
Images []string `json:"images,omitempty"`
}
// ollamaChatResponse is one JSON object decoded from a line of the streamed
// /api/chat response.
type ollamaChatResponse struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
// Thinking is forwarded by streamOllamaChat as a reasoning frame when non-empty.
Thinking string `json:"thinking"`
} `json:"message"`
// Done marks the last chunk; streamOllamaChat stops reading after it.
Done bool `json:"done"`
// PromptEvalCount and EvalCount are used as the prompt/completion token
// counts when either is positive on the final chunk.
PromptEvalCount int `json:"prompt_eval_count"`
EvalCount int `json:"eval_count"`
// DoneReason is decoded but not read by the code visible in this file section.
DoneReason string `json:"done_reason"`
}
// streamOllamaChat sends messages to the Ollama native /api/chat endpoint with
// streaming enabled and forwards the answer through emit as it arrives.
//
// Content is emitted as "delta" frames (with running token statistics) and
// reasoning as "reasoning" frames; reasoning comes either from the chunk's
// thinking field or, when think-tag parsing is enabled for the profile, from
// <think>...</think> sections inside the content. Token counts are estimated
// while streaming and replaced by the counts reported on the final chunk when
// those are present. onDone, if non-nil, receives the full visible answer.
//
// It returns an error when the request cannot be built or sent, the server
// answers with a non-2xx status, a stream line cannot be decoded, or the
// stream carries an error object.
func streamOllamaChat(ctx context.Context, profile *OpenAIProfile, messages []*model.ChatCompletionMessage, promptTokens int, usage *tokenUsageTracker, emit func(chatSSEFrame), onDone func(string)) error {
	// Other helpers in this file treat emit as optional; make a nil emitter safe here too.
	if emit == nil {
		emit = func(chatSSEFrame) {}
	}
	requestMessages, err := buildOllamaMessages(messages)
	if err != nil {
		return err
	}
	baseURL, err := ollamaBaseURL(profile)
	if err != nil {
		return err
	}
	body, err := json.Marshal(ollamaChatRequest{
		Model:    profile.Config.Model,
		Messages: requestMessages,
		Stream:   true,
		Options:  map[string]int{"num_predict": 4096},
	})
	if err != nil {
		return err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(baseURL, "/")+"/api/chat", bytes.NewReader(body))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")
	// No client-level timeout: the request lifetime is bounded by ctx.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Include a bounded excerpt of the body to aid diagnosis.
		data, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		return fmt.Errorf("Ollama 原生接口调用失败: %s %s", resp.Status, strings.TrimSpace(string(data)))
	}
	emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "running", Message: "Ollama 视觉模型已开始输出"})

	parseThinkTags := shouldParseThinkTags(profile)
	thinkParser := &thinkTagParser{}
	var full strings.Builder
	completionTokens := 0
	streamStarted := time.Now()
	peakTokensPerSecond := 0.0

	// emitDelta forwards visible text and refreshes the estimated statistics.
	emitDelta := func(delta string) {
		if delta == "" {
			return
		}
		full.WriteString(delta)
		completionTokens += estimateTokenCount(delta)
		usage.setModel(promptTokens, completionTokens)
		currentSpeed := tokensPerSecond(completionTokens, streamStarted)
		if currentSpeed > peakTokensPerSecond {
			peakTokensPerSecond = currentSpeed
		}
		stats := usage.snapshot(currentSpeed, peakTokensPerSecond)
		emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
	}
	// emitContent routes content through the think-tag parser when enabled.
	emitContent := func(delta string) {
		if delta == "" {
			return
		}
		if !parseThinkTags {
			emitDelta(delta)
			return
		}
		visible, reasoning := thinkParser.Accept(delta)
		if reasoning != "" {
			emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
		}
		emitDelta(visible)
	}

	// Token counts reported by the server on the final chunk. They are applied
	// only after the parser has been flushed, because flushing may emit one
	// more delta that would otherwise overwrite them with the estimate.
	reportedPrompt, reportedEval := 0, 0

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
	for scanner.Scan() {
		line := bytes.TrimSpace(scanner.Bytes())
		if len(line) == 0 {
			continue
		}
		// A failure after streaming has started arrives as an object with an
		// "error" property instead of a normal chunk; decode it alongside.
		var chunk struct {
			ollamaChatResponse
			Error string `json:"error"`
		}
		if err := json.Unmarshal(line, &chunk); err != nil {
			return fmt.Errorf("解析 Ollama 流失败: %w", err)
		}
		if chunk.Error != "" {
			return fmt.Errorf("Ollama 流式输出失败: %s", chunk.Error)
		}
		if chunk.Message.Thinking != "" {
			emit(chatSSEFrame{Type: "reasoning", Text: chunk.Message.Thinking})
		}
		emitContent(chunk.Message.Content)
		if chunk.Done {
			reportedPrompt, reportedEval = chunk.PromptEvalCount, chunk.EvalCount
			break
		}
	}
	if err := scanner.Err(); err != nil {
		return err
	}
	if parseThinkTags {
		visible, reasoning := thinkParser.Flush()
		if reasoning != "" {
			emit(chatSSEFrame{Type: "reasoning", Text: reasoning})
		}
		emitDelta(visible)
	}
	if reportedPrompt > 0 || reportedEval > 0 {
		usage.setModel(reportedPrompt, reportedEval)
	}
	if onDone != nil {
		onDone(full.String())
	}
	finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
	emit(chatSSEFrame{Type: "stats", Stats: &finalStats})
	emit(chatSSEFrame{Type: "trace", Tool: "model", Stage: "stream", Status: "success", Message: "回答生成完成"})
	return nil
}
// buildOllamaMessages converts chat-completion messages into the message shape
// used by ollamaChatRequest.
//
// Nil messages are skipped. Tool-role messages are sent with the user role,
// and their string content is prefixed with a "tool result" marker. A message
// without content is dropped when it only carries tool calls and kept (empty)
// otherwise. For multi-part content, non-empty text parts are joined with
// newlines and image parts are converted with ollamaImagePayload, whose error
// aborts the conversion.
func buildOllamaMessages(messages []*model.ChatCompletionMessage) ([]ollamaChatMessage, error) {
	converted := make([]ollamaChatMessage, 0, len(messages))
	for _, src := range messages {
		if src == nil {
			continue
		}
		fromTool := src.Role == model.ChatMessageRoleTool
		out := ollamaChatMessage{Role: string(src.Role)}
		if fromTool {
			out.Role = string(model.ChatMessageRoleUser)
		}

		switch {
		case src.Content == nil:
			if len(src.ToolCalls) > 0 {
				// Pure tool-call request: nothing to send.
				continue
			}
		case src.Content.StringValue != nil:
			out.Content = *src.Content.StringValue
			if fromTool {
				out.Content = "工具结果:\n" + out.Content
			}
		default:
			for _, part := range src.Content.ListValue {
				if part == nil {
					continue
				}
				switch part.Type {
				case model.ChatCompletionMessageContentPartTypeText:
					if part.Text == "" {
						continue
					}
					if out.Content != "" {
						out.Content += "\n"
					}
					out.Content += part.Text
				case model.ChatCompletionMessageContentPartTypeImageURL:
					if part.ImageURL == nil {
						continue
					}
					payload, err := ollamaImagePayload(part.ImageURL.URL)
					if err != nil {
						return nil, err
					}
					out.Images = append(out.Images, payload)
				}
			}
		}
		converted = append(converted, out)
	}
	return converted, nil
}
// ollamaImagePayload extracts the image payload from an image reference.
// For a data URI (prefix "data:", case-insensitive) it returns the trimmed part
// after the first comma, or an error if there is no comma. Any other input is
// returned trimmed and otherwise unchanged.
func ollamaImagePayload(raw string) (string, error) {
	trimmed := strings.TrimSpace(raw)
	if !strings.HasPrefix(strings.ToLower(trimmed), "data:") {
		return trimmed, nil
	}
	_, payload, found := strings.Cut(trimmed, ",")
	if !found {
		return "", errors.New("图片 base64 数据格式错误")
	}
	return strings.TrimSpace(payload), nil
}
// ollamaBaseURL derives the base URL of the Ollama native API from the
// profile's configured base URL: a path of exactly "/v1" (with optional
// trailing slashes) is removed, query and fragment are dropped, and trailing
// slashes are trimmed. It fails for a nil profile or an unparsable URL.
func ollamaBaseURL(profile *OpenAIProfile) (string, error) {
	if profile == nil {
		return "", errors.New("Ollama 配置为空")
	}
	parsed, err := url.Parse(strings.TrimSpace(profile.Config.BaseURL))
	if err != nil {
		return "", err
	}
	// Only the exact "/v1" path is stripped; longer paths are left as configured.
	if trimmedPath := strings.TrimRight(parsed.Path, "/"); trimmedPath == "/v1" {
		parsed.Path = ""
	}
	parsed.RawQuery, parsed.Fragment = "", ""
	return strings.TrimRight(parsed.String(), "/"), nil
}
// completeText runs completeTextWithTimeout using the profile's configured
// timeout, interpreted as a number of seconds.
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
	timeout := time.Duration(profile.Config.Timeout) * time.Second
	return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, timeout)
}
// completeTextWithTimeout requests a streamed chat completion for chatMessages
// and returns the concatenated visible text once the stream ends.
//
// The request is bounded by timeout on top of ctx. When think-tag parsing is
// enabled for the profile, reasoning inside <think> sections is discarded.
// If ctx carries a token usage tracker, the estimated prompt and completion
// token counts are added to it as tool usage on success.
func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
	arkMessages, err := buildArkMessages(chatMessages)
	if err != nil {
		return "", err
	}
	boundedCtx, release := context.WithTimeout(ctx, timeout)
	defer release()

	request := model.CreateChatCompletionRequest{
		Model:     profile.Config.Model,
		Messages:  arkMessages,
		MaxTokens: intPtr(maxTokens),
	}
	stream, err := profile.Client.CreateChatCompletionStream(boundedCtx, request.WithStream(true))
	if err != nil {
		return "", err
	}
	defer stream.Close()

	promptEstimate := estimateChatMessagesTokens(chatMessages)
	generated := 0
	stripThink := shouldParseThinkTags(profile)
	parser := &thinkTagParser{}
	var answer strings.Builder
	// keep appends visible text to the answer and counts its estimated tokens.
	keep := func(text string) {
		if text != "" {
			answer.WriteString(text)
			generated += estimateTokenCount(text)
		}
	}

	for {
		chunk, recvErr := stream.Recv()
		if recvErr != nil {
			if !errors.Is(recvErr, io.EOF) {
				return "", recvErr
			}
			break // normal end of stream
		}
		if len(chunk.Choices) == 0 {
			continue
		}
		piece := chunk.Choices[0].Delta.Content
		if stripThink {
			piece, _ = parser.Accept(piece)
		}
		keep(piece)
	}

	if stripThink {
		rest, _ := parser.Flush()
		keep(rest)
	}
	// The tracker is looked up on the caller's context, not the bounded one.
	if tracker := tokenUsageFromContext(ctx); tracker != nil {
		tracker.addTool(promptEstimate, generated)
	}
	return answer.String(), nil
}
// completeChatWithTimeout performs a non-streaming chat completion for request,
// bounded by timeout on top of ctx.
func completeChatWithTimeout(ctx context.Context, profile *OpenAIProfile, request model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
	boundedCtx, release := context.WithTimeout(ctx, timeout)
	defer release()
	nonStreaming := request.WithStream(false)
	return profile.Client.CreateChatCompletion(boundedCtx, nonStreaming)
}
// newUUID returns a random identifier in the canonical 8-4-4-4-12 lowercase
// hexadecimal layout, with the version nibble set to 4 and the variant bits
// set to 10.
func newUUID() string {
	var raw [16]byte
	// The error from the random source is deliberately ignored, as before.
	_, _ = rand.Read(raw[:])
	raw[6] = raw[6]&0x0f | 0x40
	raw[8] = raw[8]&0x3f | 0x80

	var text [36]byte
	hex.Encode(text[0:8], raw[0:4])
	text[8] = '-'
	hex.Encode(text[9:13], raw[4:6])
	text[13] = '-'
	hex.Encode(text[14:18], raw[6:8])
	text[18] = '-'
	hex.Encode(text[19:23], raw[8:10])
	text[23] = '-'
	hex.Encode(text[24:36], raw[10:16])
	return string(text[:])
}
// ─── ConvStore ─────────────────────────────────────────────
// NewConvStore returns a conversation store rooted at dir, attempting to
// create the directory (and any missing parents) first. A failure to create
// it is not reported here.
func NewConvStore(dir string) *ConvStore {
	// Best effort: the result is intentionally discarded.
	_ = os.MkdirAll(dir, 0o755)
	cs := &ConvStore{dir: dir}
	return cs
}
// path returns the file that stores the conversation with the given id.
//
// Ids can originate from request data, so only the final path element of id is
// used: directory components and ".." segments are dropped, which keeps the
// result inside s.dir. Ids without path separators map to the same file as
// before.
func (s *ConvStore) path(id string) string {
	return filepath.Join(s.dir, filepath.Base(id)+".json")
}
// Create makes a new conversation with a fresh id and the default title,
// persists it, and returns it.
func (s *ConvStore) Create() (*Conversation, error) {
	now := time.Now()
	conv := new(Conversation)
	conv.ID = newUUID()
	conv.Title = "新对话"
	conv.CreatedAt, conv.UpdatedAt = now, now // Save refreshes UpdatedAt
	if err := s.Save(conv); err != nil {
		return nil, err
	}
	return conv, nil
}
// Save stamps conv with the current time as its update time and writes it to
// the conversation's file while holding the store lock.
func (s *ConvStore) Save(conv *Conversation) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	target := s.path(conv.ID)
	conv.UpdatedAt = time.Now()
	return atomicWriteJSON(target, conv)
}
// Get loads the conversation stored under id. It returns a "conversation does
// not exist" error when the file is missing, and wrapped errors when the file
// cannot be read or decoded.
func (s *ConvStore) Get(id string) (*Conversation, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	raw, err := os.ReadFile(s.path(id))
	switch {
	case err == nil:
		// fall through to decoding
	case os.IsNotExist(err):
		return nil, errors.New("对话不存在")
	default:
		return nil, fmt.Errorf("读取对话失败: %w", err)
	}

	conv := new(Conversation)
	if err := json.Unmarshal(raw, conv); err != nil {
		return nil, fmt.Errorf("解析对话失败: %w", err)
	}
	return conv, nil
}
// List returns all stored conversations, most recently updated first, with
// their message bodies removed. Files that cannot be read or decoded are
// skipped; only a failure to read the directory is reported. The result is nil
// when no conversation could be loaded.
func (s *ConvStore) List() ([]Conversation, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	entries, err := os.ReadDir(s.dir)
	if err != nil {
		return nil, fmt.Errorf("读取对话目录失败: %w", err)
	}

	var summaries []Conversation
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != ".json" {
			continue
		}
		raw, readErr := os.ReadFile(filepath.Join(s.dir, name))
		if readErr != nil {
			continue
		}
		var conv Conversation
		if json.Unmarshal(raw, &conv) != nil {
			continue
		}
		// The listing does not return message bodies.
		conv.Messages = nil
		summaries = append(summaries, conv)
	}

	sort.Slice(summaries, func(a, b int) bool {
		return summaries[b].UpdatedAt.Before(summaries[a].UpdatedAt)
	})
	return summaries, nil
}
// Delete removes the file of the conversation stored under id. Deleting a
// conversation that does not exist is not an error.
func (s *ConvStore) Delete(id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	err := os.Remove(s.path(id))
	if err == nil || os.IsNotExist(err) {
		return nil
	}
	return fmt.Errorf("删除对话失败: %w", err)
}
// atomicWriteJSON encodes v as JSON and stores it at path by writing to
// path+".tmp" and renaming that file over path, so readers never observe a
// partially written file.
//
// It returns the encoding, write, or rename error. On a write or rename
// failure the temporary file is removed so that no stale ".tmp" file is left
// next to the target.
func atomicWriteJSON(path string, v any) error {
	data, err := json.Marshal(v)
	if err != nil {
		return err
	}
	tmp := path + ".tmp"
	if err := os.WriteFile(tmp, data, 0644); err != nil {
		// WriteFile may have created or truncated the file before failing.
		_ = os.Remove(tmp)
		return err
	}
	if err := os.Rename(tmp, path); err != nil {
		_ = os.Remove(tmp)
		return err
	}
	return nil
}
// saveConversationMessages replaces the stored messages of conversation id
// with messages followed by the assistant's reply, derives a title when the
// conversation still has an empty or default one, and saves the result.
func saveConversationMessages(id string, messages []ChatMessage, assistantContent string) error {
	conv, err := store.Get(id)
	if err != nil {
		return err
	}

	// Build a fresh slice so the caller's backing array is never shared.
	history := make([]ChatMessage, 0, len(messages)+1)
	history = append(history, messages...)
	history = append(history, ChatMessage{Role: "assistant", Content: assistantContent})
	conv.Messages = history

	switch conv.Title {
	case "", "新对话":
		conv.Title = genConvTitle(history)
	}
	return store.Save(conv)
}
// genConvTitle derives a conversation title from the first non-hidden user
// message with non-blank content: the content is trimmed, line breaks become
// spaces, and anything beyond 30 runes is cut and marked with "...". The
// default title is returned when no such message exists.
func genConvTitle(messages []ChatMessage) string {
	const maxTitleRunes = 30
	lineBreaks := strings.NewReplacer("\r\n", " ", "\n", " ")

	for i := range messages {
		msg := &messages[i]
		if msg.Hidden || msg.Role != "user" {
			continue
		}
		text := strings.TrimSpace(msg.Content)
		if text == "" {
			continue
		}
		text = lineBreaks.Replace(text)
		if runes := []rune(text); len(runes) > maxTitleRunes {
			return string(runes[:maxTitleRunes]) + "..."
		}
		return text
	}
	return "新对话"
}
// maxImageSize is the largest accepted decoded image size in bytes (4 MiB).
const maxImageSize = 4 << 20
// allowedImageTypes is the set of MIME types accepted by normalizeImageDataURI
// for base64 image data URIs.
var allowedImageTypes = map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/webp": true,
"image/gif": true,
}
// buildArkMessages converts every chat message with buildArkMessage, keeping
// the order. The first conversion error aborts the whole conversion.
func buildArkMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
	converted := make([]*model.ChatCompletionMessage, len(chatMessages))
	for i := range chatMessages {
		built, err := buildArkMessage(chatMessages[i])
		if err != nil {
			return nil, err
		}
		converted[i] = built
	}
	return converted, nil
}
// hasImageMessage reports whether any message carries a non-blank image URL in
// either of its image fields.
func hasImageMessage(messages []ChatMessage) bool {
	for i := range messages {
		noImage := strings.TrimSpace(messages[i].ImageURL) == "" &&
			strings.TrimSpace(messages[i].ImageURLAlias) == ""
		if !noImage {
			return true
		}
	}
	return false
}
// buildToolDecisionMessages converts chat messages into text-only messages for
// the tool decision step. Images are never included; a message that carries
// one gets a placeholder note appended to its trimmed text (or used as its
// text when there is none). The returned error is always nil.
func buildToolDecisionMessages(chatMessages []ChatMessage) ([]*model.ChatCompletionMessage, error) {
	const imageNote = "[用户上传了一张图片。工具判断阶段不读取图片内容;如果问题主要依赖识图,应不要调用工具,交给最终多模态模型回答。]"

	converted := make([]*model.ChatCompletionMessage, 0, len(chatMessages))
	for _, m := range chatMessages {
		text := m.Content
		hasImage := strings.TrimSpace(m.ImageURL) != "" || strings.TrimSpace(m.ImageURLAlias) != ""
		if hasImage {
			if text = strings.TrimSpace(text); text != "" {
				text += "\n\n" + imageNote
			} else {
				text = imageNote
			}
		}
		converted = append(converted, &model.ChatCompletionMessage{Role: m.Role, Content: stringContent(text)})
	}
	return converted, nil
}
// buildArkMessage converts one chat message into a chat-completion message.
//
// The image reference is taken from ImageURL, falling back to ImageURLAlias;
// blank (whitespace-only) values count as absent, matching hasImageMessage and
// buildToolDecisionMessages. Without an image the message has plain string
// content. With an image the content is multi-part: the text part first (only
// when the text is non-empty), then the image part. An image reference that
// normalizeImageURL rejects is returned as an error.
func buildArkMessage(m ChatMessage) (*model.ChatCompletionMessage, error) {
	msg := &model.ChatCompletionMessage{Role: m.Role}

	rawImage := strings.TrimSpace(m.ImageURL)
	if rawImage == "" {
		rawImage = strings.TrimSpace(m.ImageURLAlias)
	}
	if rawImage == "" {
		msg.Content = &model.ChatCompletionMessageContent{
			StringValue: &m.Content,
		}
		return msg, nil
	}

	imageURL, err := normalizeImageURL(rawImage)
	if err != nil {
		return nil, err
	}
	// Text is optional alongside an image (caption-style input).
	parts := make([]*model.ChatCompletionMessageContentPart, 0, 2)
	if m.Content != "" {
		parts = append(parts, textPart(m.Content))
	}
	parts = append(parts, imagePart(imageURL))
	msg.Content = &model.ChatCompletionMessageContent{ListValue: parts}
	return msg, nil
}
// imagePart builds an image content part for url with automatic detail level.
func imagePart(url string) *model.ChatCompletionMessageContentPart {
	image := &model.ChatMessageImageURL{URL: url, Detail: model.ImageURLDetailAuto}
	part := new(model.ChatCompletionMessageContentPart)
	part.Type = model.ChatCompletionMessageContentPartTypeImageURL
	part.ImageURL = image
	return part
}
// textPart builds a text content part holding text.
func textPart(text string) *model.ChatCompletionMessageContentPart {
	part := new(model.ChatCompletionMessageContentPart)
	part.Type = model.ChatCompletionMessageContentPartTypeText
	part.Text = text
	return part
}
// stringContent wraps text as plain string message content.
func stringContent(text string) *model.ChatCompletionMessageContent {
	content := new(model.ChatCompletionMessageContent)
	content.StringValue = &text
	return content
}
// chatMessageContentString returns the plain string value of content, or ""
// when content is nil or does not hold a string value.
func chatMessageContentString(content *model.ChatCompletionMessageContent) string {
	if content != nil && content.StringValue != nil {
		return *content.StringValue
	}
	return ""
}
// normalizeImageURL validates an image reference and returns it trimmed.
// Data URIs are delegated to normalizeImageDataURI; anything else must be an
// http or https URL with a host. Empty input is rejected.
func normalizeImageURL(raw string) (string, error) {
	trimmed := strings.TrimSpace(raw)
	switch {
	case trimmed == "":
		return "", errors.New("图片地址不能为空")
	case strings.HasPrefix(strings.ToLower(trimmed), "data:"):
		return normalizeImageDataURI(trimmed)
	}

	parsed, err := url.Parse(trimmed)
	valid := err == nil && parsed.Host != "" && (parsed.Scheme == "http" || parsed.Scheme == "https")
	if !valid {
		return "", errors.New("图片地址无效,仅支持 http/https URL 或 base64 data URI")
	}
	return trimmed, nil
}
// normalizeImageDataURI validates a base64 image data URI and returns it in
// the canonical form "data:<mime>;base64,<payload>".
//
// raw is expected to start with the 5-byte "data:" prefix (the caller checks
// this). The URI must contain a comma, declare base64 encoding, use a MIME
// type listed in allowedImageTypes, and carry a non-empty payload that decodes
// as standard base64 to at most maxImageSize bytes.
func normalizeImageDataURI(raw string) (string, error) {
	header, body, ok := strings.Cut(raw, ",")
	if !ok {
		return "", errors.New("图片 base64 数据格式错误")
	}
	// Skip the "data:" prefix; what remains is "<mime>[;param...]".
	fields := strings.Split(strings.ToLower(strings.TrimSpace(header[5:])), ";")
	payload := strings.TrimSpace(body)
	if payload == "" {
		return "", errors.New("图片 base64 数据不能为空")
	}
	if len(fields) < 2 || !contains(fields[1:], "base64") {
		return "", errors.New("图片 data URI 必须使用 base64 编码")
	}
	mediaType := fields[0]
	if !allowedImageTypes[mediaType] {
		return "", errors.New("图片格式不支持,仅支持 jpeg/png/webp/gif")
	}

	decoded, err := base64.StdEncoding.DecodeString(payload)
	switch {
	case err != nil:
		return "", errors.New("图片 base64 数据无效")
	case len(decoded) > maxImageSize:
		return "", errors.New("图片过大,请选择小于 4MB 的图片")
	}
	return "data:" + mediaType + ";base64," + payload, nil
}
// contains reports whether any element of items equals target after the
// element's surrounding whitespace is trimmed. target itself is not trimmed.
func contains(items []string, target string) bool {
	found := false
	for idx := 0; idx < len(items) && !found; idx++ {
		found = strings.TrimSpace(items[idx]) == target
	}
	return found
}
// intPtr returns a pointer to a fresh copy of i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// boolPtr returns a pointer to a fresh copy of v.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
// truncateString trims surrounding whitespace from text and, when maxRunes is
// positive and the result is longer than maxRunes runes, cuts it to maxRunes
// runes followed by "...". A non-positive maxRunes disables truncation.
func truncateString(text string, maxRunes int) string {
	runes := []rune(strings.TrimSpace(text))
	if maxRunes > 0 && len(runes) > maxRunes {
		return string(runes[:maxRunes]) + "..."
	}
	return string(runes)
}
// writeSSEJSON writes frame to w as one server-sent event: a "data: " line
// holding the frame's JSON, followed by a blank line. If the frame cannot be
// encoded, an error frame is written in its place.
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
	payload, err := json.Marshal(frame)
	if err != nil {
		fallback := chatSSEFrame{Type: "error", Error: "序列化流事件失败"}
		payload, _ = json.Marshal(fallback)
	}
	fmt.Fprintf(w, "data: %s\n\n", payload)
}
// toJSON returns s encoded as a JSON string literal, or "" if encoding fails.
func toJSON(s string) string {
	encoded, err := json.Marshal(s)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// toSSE returns s wrapped in double quotes with backslashes and double quotes
// backslash-escaped, newlines written as the two characters \n, and carriage
// returns removed. No other characters are escaped.
func toSSE(s string) string {
	escaper := strings.NewReplacer(
		`\`, `\\`,
		"\n", `\n`,
		"\r", "",
		`"`, `\"`,
	)
	return fmt.Sprintf(`"%s"`, escaper.Replace(s))
}
// ─── 入口 ─────────────────────────────────────────────────
// main loads configuration, initialises the model, search, SQL and tool-router
// state and the conversation store, registers the HTTP routes, and serves on a
// Unix socket or a TCP address depending on cfg.Server.Mode. Any startup or
// serving failure is printed to stderr and ends the process with status 1.
func main() {
	// os.Exit does not run deferred calls, so resources that must be released
	// before a fatal exit are registered in cleanup and released by fail.
	var cleanup func()
	fail := func(what string, cause error) {
		fmt.Fprintln(os.Stderr, what, cause)
		if cleanup != nil {
			cleanup()
		}
		os.Exit(1)
	}

	var err error
	cfg, err = loadConfig("config.yaml")
	if err != nil {
		fail("配置加载失败:", err)
	}
	// Initialise the model client state (Volcengine Ark SDK).
	aiState, err = NewOpenAIState(cfg.OpenAI)
	if err != nil {
		fail("OpenAI 配置初始化失败:", err)
	}
	searchConfig, err := searchagent.LoadConfig("agents/search/config.yaml", legacySearchProfiles)
	if err != nil {
		fail("联网搜索配置加载失败:", err)
	}
	searchState, err = searchagent.NewState(searchConfig)
	if err != nil {
		fail("联网搜索初始化失败:", err)
	}
	sqlConfig, err := sqlquery.LoadConfig("agents/sql/config.yaml")
	if err != nil {
		fail("SQL 查询插件配置加载失败:", err)
	}
	sqlState, err = sqlquery.NewState(sqlConfig)
	if err != nil {
		fail("SQL 查询插件初始化失败:", err)
	}
	// From here on the SQL state must be closed on every way out of main.
	cleanup = func() { sqlState.Close() }
	defer cleanup()

	toolRouterState, err = NewToolRouterState(&cfg.ToolRouter, aiState)
	if err != nil {
		fail("工具路由配置初始化失败:", err)
	}
	store = NewConvStore("conversations")

	// Gin routes.
	r := gin.Default()
	r.LoadHTMLGlob("templates/*")
	r.Static("/static", "./static")
	r.GET("/", indexHandler)
	r.POST("/api/chat", chatHandler)
	r.GET("/api/openai", listOpenAIHandler)
	r.POST("/api/openai/active", switchOpenAIHandler)
	r.GET("/api/search", listSearchHandler)
	r.POST("/api/search/active", switchSearchHandler)
	r.GET("/api/conversations", listConversationsHandler)
	r.POST("/api/conversations", createConversationHandler)
	r.GET("/api/conversations/:id", getConversationHandler)
	r.DELETE("/api/conversations/:id", deleteConversationHandler)

	// Choose the listener according to the configuration.
	switch strings.ToLower(cfg.Server.Mode) {
	case "unix":
		socketPath := cfg.Server.Address
		// Remove a leftover socket file from a previous run so Listen can bind.
		if _, statErr := os.Stat(socketPath); statErr == nil {
			os.Remove(socketPath)
		}
		ln, listenErr := net.Listen("unix", socketPath)
		if listenErr != nil {
			fail("监听 Unix socket 失败:", listenErr)
		}
		fmt.Println("服务已启动,监听 Unix socket:", socketPath)
		if serveErr := http.Serve(ln, r); serveErr != nil {
			fail("服务异常退出:", serveErr)
		}
	default:
		fmt.Println("服务已启动,监听 TCP:", cfg.Server.Address)
		if runErr := r.Run(cfg.Server.Address); runErr != nil {
			fail("服务异常退出:", runErr)
		}
	}
}