diff --git a/agents/search/search.go b/agents/search/search.go new file mode 100644 index 0000000..d93f86c --- /dev/null +++ b/agents/search/search.go @@ -0,0 +1,520 @@ +package search + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "gopkg.in/yaml.v3" +) + +const ( + defaultActivationPrompt = `判断用户问题是否需要联网搜索。 +仅当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。 +普通知识、闲聊、代码推理、已有上下文可回答的问题不要调用。` + defaultBaseURL = "https://api.duckduckgo.com/" + defaultTimeout = 10 + defaultCount = 5 +) + +type Config struct { + Enabled bool `yaml:"enabled" json:"enabled"` + ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"` + Profiles ProfileConfigs `yaml:"profiles" json:"profiles"` +} + +type ProfileConfig struct { + Name string `yaml:"name" json:"name"` + Active bool `yaml:"active,omitempty" json:"active"` + Enabled bool `yaml:"enabled" json:"enabled"` + Provider string `yaml:"provider" json:"provider"` + APIKey string `yaml:"api_key" json:"-"` + BaseURL string `yaml:"base_url" json:"base_url"` + Count int `yaml:"count" json:"count"` + Timeout int `yaml:"timeout" json:"timeout"` +} + +type ProfileConfigs []ProfileConfig + +func (configs *ProfileConfigs) UnmarshalYAML(value *yaml.Node) error { + switch value.Kind { + case yaml.SequenceNode: + var items []ProfileConfig + if err := value.Decode(&items); err != nil { + return err + } + *configs = items + case yaml.MappingNode: + var item ProfileConfig + if err := value.Decode(&item); err != nil { + return err + } + *configs = []ProfileConfig{item} + case yaml.ScalarNode: + if value.Tag == "!!null" { + *configs = nil + return nil + } + return fmt.Errorf("search 配置格式无效") + default: + return fmt.Errorf("search 配置格式无效") + } + return nil +} + +type State struct { + mu sync.RWMutex + cfg *Config + profiles map[string]ProfileConfig + order []string + activeName string +} + +type Result struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` +} + +type ListResponse struct { + Active string `json:"active"` + Profiles []ProfileConfig `json:"profiles"` +} + +type braveSearchResponse struct { + Web struct { + Results []Result `json:"results"` + } `json:"web"` +} + +type duckDuckGoResponse struct { + Abstract string `json:"Abstract"` + AbstractSource string `json:"AbstractSource"` + AbstractURL string `json:"AbstractURL"` + Heading string `json:"Heading"` + RelatedTopics []struct { + Text string `json:"Text"` + FirstURL string `json:"FirstURL"` + Topics []struct { + Text string `json:"Text"` + FirstURL string `json:"FirstURL"` + } `json:"Topics"` + } `json:"RelatedTopics"` + Infobox struct { + Content []struct { + Label string `json:"label"` + Value string `json:"value"` + } `json:"content"` + } `json:"Infobox"` +} + +func LoadConfig(path string, legacyProfiles ...[]ProfileConfig) (*Config, error) { + if _, err := os.Stat(path); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("检查联网搜索配置失败: %w", err) + } + cfg := defaultConfig() + if len(legacyProfiles) > 0 && len(legacyProfiles[0]) > 0 { + cfg.Profiles = legacyProfiles[0] + } + if err := normalizeConfig(&cfg); err != nil { + return nil, err + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return nil, fmt.Errorf("创建联网搜索配置目录失败: %w", err) + } + data, err := yaml.Marshal(&cfg) + if err != nil { + return nil, fmt.Errorf("生成联网搜索配置失败: %w", err) + } + if err := os.WriteFile(path, data, 0644); err != nil { + return nil, fmt.Errorf("写入联网搜索配置失败: %w", err) + } + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取联网搜索配置失败: %w", err) + } + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("解析联网搜索配置失败: %w", err) + } + if err := normalizeConfig(&cfg); err != nil { + return nil, err + } + applyEnv(&cfg) + return &cfg, nil +} + +func NewState(cfg *Config) (*State, error) { + state := &State{cfg: cfg, profiles: map[string]ProfileConfig{}} + if cfg == nil || !cfg.Enabled { + return state, nil + } + for _, config := range cfg.Profiles { + if strings.TrimSpace(config.Name) == "" { + return nil, errors.New("search.profiles.name 不能为空") + } + if strings.TrimSpace(config.Provider) == "" { + return nil, fmt.Errorf("search.%s.provider 未配置", config.Name) + } + if strings.TrimSpace(config.BaseURL) == "" { + return nil, fmt.Errorf("search.%s.base_url 未配置", config.Name) + } + if config.Timeout <= 0 { + return nil, fmt.Errorf("search.%s.timeout 必须大于 0", config.Name) + } + if _, ok := state.profiles[config.Name]; ok { + return nil, fmt.Errorf("search 配置名称重复: %s", config.Name) + } + state.profiles[config.Name] = config + state.order = append(state.order, config.Name) + if config.Active && state.activeName == "" { + state.activeName = config.Name + } + } + if len(state.order) == 0 { + return state, nil + } + if state.activeName == "" { + state.activeName = state.order[0] + } + return state, nil +} + +func (s *State) Enabled() bool { + if s == nil || s.cfg == nil || !s.cfg.Enabled || s.activeName == "" { + return false + } + profile := s.ActiveProfile() + return profile.Enabled +} + +func (s *State) ActivationPrompt() string { + if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.ActivationPrompt) == "" { + return defaultActivationPrompt + } + return strings.TrimSpace(s.cfg.ActivationPrompt) +} + +func (s *State) ActiveProfile() ProfileConfig { + s.mu.RLock() + defer s.mu.RUnlock() + return s.profiles[s.activeName] +} + +func (s *State) SwitchActive(name string) (ProfileConfig, error) { + name = strings.TrimSpace(name) + if name == "" { + return ProfileConfig{}, errors.New("搜索配置名称不能为空") + } + s.mu.Lock() + defer s.mu.Unlock() + profile, ok := s.profiles[name] + if !ok { + return ProfileConfig{}, fmt.Errorf("搜索配置不存在: %s", name) + } + s.activeName = name + return profile, nil +} + +func (s *State) ListProfiles() ListResponse { + s.mu.RLock() + defer s.mu.RUnlock() + profiles := make([]ProfileConfig, 0, len(s.order)) + for _, name := range s.order { + profile := s.profiles[name] + profile.APIKey = "" + profile.Active = name == s.activeName + profiles = append(profiles, profile) + } + return ListResponse{Active: s.activeName, Profiles: profiles} +} + +func (s *State) Search(ctx context.Context, query string) ([]Result, ProfileConfig, error) { + if !s.Enabled() { + return nil, ProfileConfig{}, errors.New("联网搜索未启用,请先在 agents/search/config.yaml 中启用 search") + } + profile := s.ActiveProfile() + switch strings.ToLower(profile.Provider) { + case "duckduckgo", "ddg": + results, err := duckDuckGoSearch(ctx, profile, query) + return results, profile, err + case "brave": + results, err := braveWebSearch(ctx, profile, query) + return results, profile, err + default: + return nil, profile, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider) + } +} + +func BuildResultContext(config ProfileConfig, query string, results []Result, routeReason string) string { + var b strings.Builder + fmt.Fprintf(&b, "工具路由调用了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider) + if strings.ToLower(config.Provider) == "duckduckgo" || strings.ToLower(config.Provider) == "ddg" { + fmt.Fprintln(&b, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。") + } + fmt.Fprintf(&b, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05")) + fmt.Fprintf(&b, "搜索词: %s\n", query) + if strings.TrimSpace(routeReason) != "" { + fmt.Fprintf(&b, "调用原因: %s\n", strings.TrimSpace(routeReason)) + } + fmt.Fprintln(&b, "\n搜索结果:") + for i, r := range results { + fmt.Fprintf(&b, "%d. 标题: %s\n", i+1, strings.TrimSpace(r.Title)) + if strings.TrimSpace(r.URL) != "" { + fmt.Fprintf(&b, " 链接: %s\n", strings.TrimSpace(r.URL)) + } + if strings.TrimSpace(r.Description) != "" { + fmt.Fprintf(&b, " 摘要: %s\n", strings.TrimSpace(r.Description)) + } + } + fmt.Fprintln(&b, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。") + return b.String() +} + +func BuildErrorContext(query string, err error) string { + return fmt.Sprintf("工具路由尝试联网搜索但失败。用户问题:%s\n错误:%v\n请向用户说明联网搜索失败,不要编造搜索结果。", query, err) +} + +func defaultConfig() Config { + return Config{ + Enabled: true, + ActivationPrompt: defaultActivationPrompt, + Profiles: ProfileConfigs{defaultProfileConfig()}, + } +} + +func defaultProfileConfig() ProfileConfig { + return ProfileConfig{ + Name: "duckduckgo", + Active: true, + Enabled: true, + Provider: "duckduckgo", + BaseURL: defaultBaseURL, + Count: defaultCount, + Timeout: defaultTimeout, + } +} + +func normalizeConfig(cfg *Config) error { + if strings.TrimSpace(cfg.ActivationPrompt) == "" { + cfg.ActivationPrompt = defaultActivationPrompt + } else { + cfg.ActivationPrompt = strings.TrimSpace(cfg.ActivationPrompt) + } + if len(cfg.Profiles) == 0 { + cfg.Profiles = ProfileConfigs{defaultProfileConfig()} + } + activeIndex := -1 + seen := map[string]bool{} + for i := range cfg.Profiles { + profile := &cfg.Profiles[i] + profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider)) + if profile.Provider == "" { + profile.Provider = "duckduckgo" + } + name := strings.TrimSpace(profile.Name) + if name == "" { + name = profile.Provider + if seen[name] { + name = fmt.Sprintf("%s-%d", profile.Provider, i+1) + } + } + profile.Name = name + if seen[name] { + return fmt.Errorf("search 配置名称重复: %s", name) + } + seen[name] = true + if strings.TrimSpace(profile.BaseURL) == "" { + switch profile.Provider { + case "duckduckgo", "ddg": + profile.BaseURL = defaultBaseURL + case "brave": + profile.BaseURL = "https://api.search.brave.com/res/v1/web/search" + default: + return fmt.Errorf("暂不支持搜索服务: %s", profile.Provider) + } + } + if profile.Count <= 0 { + profile.Count = defaultCount + } + if profile.Timeout <= 0 { + profile.Timeout = defaultTimeout + } + if profile.Active { + if activeIndex == -1 { + activeIndex = i + } else { + profile.Active = false + } + } + } + if activeIndex == -1 { + cfg.Profiles[0].Active = true + } + return nil +} + +func applyEnv(cfg *Config) { + if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" { + for i := range cfg.Profiles { + if strings.ToLower(cfg.Profiles[i].Provider) == "brave" { + cfg.Profiles[i].APIKey = key + } + } + } +} + +func duckDuckGoSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) { + searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second) + defer cancel() + + u, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("搜索服务地址无效: %w", err) + } + q := u.Query() + q.Set("q", query) + q.Set("format", "json") + q.Set("no_html", "1") + q.Set("skip_disambig", "1") + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("创建搜索请求失败: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "aichat/1.0") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("联网搜索失败: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) + if err != nil { + return nil, fmt.Errorf("读取搜索响应失败: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var parsed duckDuckGoResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("解析搜索响应失败: %w", err) + } + + limit := config.Count + if limit <= 0 { + limit = defaultCount + } + results := make([]Result, 0, limit) + if strings.TrimSpace(parsed.Abstract) != "" { + title := strings.TrimSpace(parsed.Heading) + if title == "" { + title = strings.TrimSpace(parsed.AbstractSource) + } + if title == "" { + title = "DuckDuckGo 摘要" + } + results = append(results, Result{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)}) + } + for _, item := range parsed.Infobox.Content { + if len(results) >= limit { + break + } + if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" { + continue + } + results = append(results, Result{Title: item.Label, Description: item.Value}) + } + appendRelated := func(text, firstURL string) { + if len(results) >= limit || strings.TrimSpace(text) == "" { + return + } + title, desc := splitDuckDuckGoText(text) + results = append(results, Result{Title: title, URL: strings.TrimSpace(firstURL), Description: desc}) + } + for _, topic := range parsed.RelatedTopics { + appendRelated(topic.Text, topic.FirstURL) + for _, nested := range topic.Topics { + appendRelated(nested.Text, nested.FirstURL) + } + if len(results) >= limit { + break + } + } + return results, nil +} + +func splitDuckDuckGoText(text string) (string, string) { + text = strings.TrimSpace(text) + parts := strings.SplitN(text, " - ", 2) + if len(parts) == 2 { + return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) + } + runes := []rune(text) + if len(runes) > 42 { + return string(runes[:42]) + "...", text + } + return text, text +} + +func braveWebSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) { + if config.APIKey == "" { + return nil, errors.New("Brave 搜索未配置 API Key,请设置 agents/search/config.yaml 中的 api_key 或环境变量 BRAVE_SEARCH_API_KEY") + } + searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second) + defer cancel() + + u, err := url.Parse(config.BaseURL) + if err != nil { + return nil, fmt.Errorf("搜索服务地址无效: %w", err) + } + q := u.Query() + q.Set("q", query) + q.Set("count", fmt.Sprintf("%d", config.Count)) + q.Set("search_lang", "zh-hans") + q.Set("country", "CN") + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, fmt.Errorf("创建搜索请求失败: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("X-Subscription-Token", config.APIKey) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("联网搜索失败: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) + if err != nil { + return nil, fmt.Errorf("读取搜索响应失败: %w", err) + } + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var parsed braveSearchResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("解析搜索响应失败: %w", err) + } + return parsed.Web.Results, nil +} diff --git a/agents/search/search_test.go b/agents/search/search_test.go new file mode 100644 index 0000000..75847d1 --- /dev/null +++ b/agents/search/search_test.go @@ -0,0 +1,113 @@ +package search + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" +) + +func TestNormalizeConfigDefaults(t *testing.T) { + cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "duckduckgo"}}} + if err := normalizeConfig(cfg); err != nil { + t.Fatal(err) + } + if strings.TrimSpace(cfg.ActivationPrompt) == "" { + t.Fatal("activation prompt should be defaulted") + } + profile := cfg.Profiles[0] + if profile.Name != "duckduckgo" || !profile.Active || profile.BaseURL != defaultBaseURL || profile.Count != defaultCount || profile.Timeout != defaultTimeout { + t.Fatalf("unexpected profile: %#v", profile) + } +} + +func TestNormalizeConfigDuplicateProfile(t *testing.T) { + cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Name: "duck", Provider: "duckduckgo"}, {Name: " duck ", Provider: "duckduckgo"}}} + if err := normalizeConfig(cfg); err == nil { + t.Fatal("expected duplicate profile error") + } +} + +func TestNormalizeConfigUnsupportedProvider(t *testing.T) { + cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "unknown"}}} + if err := normalizeConfig(cfg); err == nil { + t.Fatal("expected unsupported provider error") + } +} + +func TestApplyEnvBraveAPIKey(t *testing.T) { + t.Setenv("BRAVE_SEARCH_API_KEY", "from-env") + cfg := &Config{Profiles: ProfileConfigs{{Name: "brave", Provider: "brave", APIKey: "from-config"}}} + applyEnv(cfg) + if cfg.Profiles[0].APIKey != "from-env" { + t.Fatalf("api key = %q", cfg.Profiles[0].APIKey) + } +} + +func TestListProfilesRedactsAPIKey(t *testing.T) { + state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "brave", Active: true, Enabled: true, Provider: "brave", APIKey: "secret", BaseURL: "https://example.com", Count: 1, Timeout: 1}}}) + if err != nil { + t.Fatal(err) + } + list := state.ListProfiles() + if len(list.Profiles) != 1 || list.Profiles[0].APIKey != "" || !list.Profiles[0].Active { + t.Fatalf("unexpected list: %#v", list) + } +} + +func TestBuildResultContext(t *testing.T) { + text := BuildResultContext(ProfileConfig{Name: "duckduckgo", Provider: "duckduckgo"}, "最新消息", []Result{{Title: "标题", URL: "https://example.com", Description: "摘要"}}, "需要最新信息") + for _, want := range []string{"联网搜索", "最新消息", "标题", "https://example.com", "需要最新信息", "不要编造"} { + if !strings.Contains(text, want) { + t.Fatalf("context missing %q:\n%s", want, text) + } + } +} + +func TestDuckDuckGoSearchParsesResults(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("q") != "golang" { + t.Fatalf("query = %s", r.URL.RawQuery) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "Heading":"Go", + "Abstract":"Go is a language", + "AbstractURL":"https://go.dev", + "RelatedTopics":[{"Text":"Gopher - mascot","FirstURL":"https://go.dev/blog/gopher"}] + }`)) + })) + defer server.Close() + + results, err := duckDuckGoSearch(context.Background(), ProfileConfig{Provider: "duckduckgo", BaseURL: server.URL, Count: 2, Timeout: 1}, "golang") + if err != nil { + t.Fatal(err) + } + if len(results) != 2 || results[0].Title != "Go" || results[1].Title != "Gopher" { + t.Fatalf("unexpected results: %#v", results) + } +} + +func TestBraveMissingAPIKey(t *testing.T) { + _, err := braveWebSearch(context.Background(), ProfileConfig{Provider: "brave", BaseURL: "https://example.com", Timeout: 1, Count: 1}, "query") + if err == nil || !strings.Contains(err.Error(), "API Key") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestLoadConfigWritesLegacyProfiles(t *testing.T) { + dir := t.TempDir() + path := dir + string(os.PathSeparator) + "config.yaml" + cfg, err := LoadConfig(path, []ProfileConfig{{Name: "legacy", Active: true, Enabled: true, Provider: "duckduckgo", Count: 3, Timeout: 2}}) + if err != nil { + t.Fatal(err) + } + if len(cfg.Profiles) != 1 || cfg.Profiles[0].Name != "legacy" { + t.Fatalf("unexpected cfg: %#v", cfg) + } + if _, err := os.Stat(path); err != nil { + t.Fatal(err) + } +} diff --git a/agents/SQL_query/sql_query.go b/agents/sql/sql_query.go similarity index 99% rename from agents/SQL_query/sql_query.go rename to agents/sql/sql_query.go index 49d1d82..cd4aba0 100644 --- a/agents/SQL_query/sql_query.go +++ b/agents/sql/sql_query.go @@ -22,6 +22,7 @@ import ( const ( defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。 仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。 +当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。 普通知识问答、代码问题、闲聊、联网搜索问题返回 activate=false。 只返回 JSON: {"activate": true/false, "reason": "..."}` defaultDatabaseName = "default" diff --git a/agents/SQL_query/sql_query_test.go b/agents/sql/sql_query_test.go similarity index 78% rename from agents/SQL_query/sql_query_test.go rename to agents/sql/sql_query_test.go index d1f4a03..e144988 100644 --- a/agents/SQL_query/sql_query_test.go +++ b/agents/sql/sql_query_test.go @@ -1,6 +1,15 @@ package sqlquery -import "testing" +import ( + "strings" + "testing" +) + +func TestDefaultActivationPromptMentionsCalendarEvents(t *testing.T) { + if !strings.Contains(defaultActivationPrompt, "tab_calendar_events") { + t.Fatal("default activation prompt should mention tab_calendar_events for calendar queries") + } +} func TestValidateReadOnlySQLAllowsSelectAndWith(t *testing.T) { queries := []string{ diff --git a/agents/time/time.go b/agents/time/time.go new file mode 100644 index 0000000..3bc7c5f --- /dev/null +++ b/agents/time/time.go @@ -0,0 +1,76 @@ +package timeagent + +import ( + "fmt" + "strings" + "time" +) + +const ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、明天、昨天、本周、本月、本年、最近、日程安排等相对时间表达时,应先调用此工具;如果后续还需要查数据库,可继续调用 sql。" + +type Range struct { + Start time.Time + End time.Time +} + +type Context struct { + Now time.Time + Today Range + Tomorrow Range + Yesterday Range + ThisWeek Range + ThisMonth Range + ThisYear Range +} + +func Resolve(now time.Time) Context { + loc := now.Location() + today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)} + tomorrowStart := today.End + tomorrow := Range{Start: tomorrowStart, End: tomorrowStart.AddDate(0, 0, 1)} + yesterdayEnd := today.Start + yesterday := Range{Start: yesterdayEnd.AddDate(0, 0, -1), End: yesterdayEnd} + weekdayOffset := (int(now.Weekday()) + 6) % 7 + weekStart := today.Start.AddDate(0, 0, -weekdayOffset) + monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, loc) + yearStart := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, loc) + return Context{ + Now: now, + Today: today, + Tomorrow: tomorrow, + Yesterday: yesterday, + ThisWeek: Range{Start: weekStart, End: weekStart.AddDate(0, 0, 7)}, + ThisMonth: Range{Start: monthStart, End: monthStart.AddDate(0, 1, 0)}, + ThisYear: Range{Start: yearStart, End: yearStart.AddDate(1, 0, 0)}, + } +} + +func BuildContext(ctx Context, routeReason string) string { + var b strings.Builder + b.WriteString("时间工具结果。后续工具必须优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n") + fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST")) + fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today)) + fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow)) + fmt.Fprintf(&b, "昨天:%s\n", FormatSQLRange(ctx.Yesterday)) + fmt.Fprintf(&b, "本周:%s\n", FormatSQLRange(ctx.ThisWeek)) + fmt.Fprintf(&b, "本月:%s\n", FormatSQLRange(ctx.ThisMonth)) + fmt.Fprintf(&b, "本年:%s\n", FormatSQLRange(ctx.ThisYear)) + b.WriteString("SQL 日期过滤建议:对日程/事件类表使用半开区间,例如 event_time >= start AND event_time < end;如果实际字段名不同,必须使用 schema 中存在的时间字段。\n") + if strings.TrimSpace(routeReason) != "" { + b.WriteString("激活原因:" + strings.TrimSpace(routeReason) + "\n") + } + return b.String() +} + +func FormatDate(t time.Time) string { + return t.Format("2006-01-02") +} + +func FormatSQLRange(r Range) string { + return fmt.Sprintf("start=%s, end_exclusive=%s", r.Start.Format("2006-01-02 15:04:05"), r.End.Format("2006-01-02 15:04:05")) +} + +func startOfDay(t time.Time) time.Time { + y, m, d := t.Date() + return time.Date(y, m, d, 0, 0, 0, 0, t.Location()) +} diff --git a/agents/time/time_test.go b/agents/time/time_test.go new file mode 100644 index 0000000..45dbc9f --- /dev/null +++ b/agents/time/time_test.go @@ -0,0 +1,31 @@ +package timeagent + +import ( + "strings" + "testing" + "time" +) + +func TestResolveBuildsCalendarRanges(t *testing.T) { + loc := time.FixedZone("CST", 8*60*60) + ctx := Resolve(time.Date(2026, 6, 10, 13, 14, 15, 0, loc)) + if got := FormatSQLRange(ctx.Today); got != "start=2026-06-10 00:00:00, end_exclusive=2026-06-11 00:00:00" { + t.Fatalf("today = %s", got) + } + if got := FormatSQLRange(ctx.ThisMonth); got != "start=2026-06-01 00:00:00, end_exclusive=2026-07-01 00:00:00" { + t.Fatalf("this month = %s", got) + } + if got := FormatSQLRange(ctx.ThisWeek); got != "start=2026-06-08 00:00:00, end_exclusive=2026-06-15 00:00:00" { + t.Fatalf("this week = %s", got) + } +} + +func TestBuildContextIncludesSQLHints(t *testing.T) { + ctx := Resolve(time.Date(2026, 6, 10, 13, 14, 15, 0, time.UTC)) + text := BuildContext(ctx, "需要日期范围") + for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间", "需要日期范围"} { + if !strings.Contains(text, want) { + t.Fatalf("context missing %q:\n%s", want, text) + } + } +} diff --git a/main.go b/main.go index f9a27ee..3293219 100644 --- a/main.go +++ b/main.go @@ -18,8 +18,11 @@ import ( "strings" "sync" "time" + "unicode" - sqlquery "aichat/agents/SQL_query" + searchagent "aichat/agents/search" + sqlquery "aichat/agents/sql" + timeagent "aichat/agents/time" "github.com/gin-gonic/gin" ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime" @@ -30,11 +33,19 @@ import ( // ─── 配置 ───────────────────────────────────────────────── const ( - defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3" - defaultOpenAITimeout = 120 - defaultSearchBaseURL = "https://api.duckduckgo.com/" - defaultSearchTimeout = 10 - defaultSearchCount = 5 + defaultOpenAIBaseURL = "https://ark.cn-beijing.volces.com/api/v3" + defaultOpenAITimeout = 120 + defaultToolRouterTimeout = 30 + defaultToolRouterMaxTokens = 512 + defaultToolRouterSystemText = `你是工具路由器。根据用户最新问题和可用工具列表,判断本轮是否需要调用一个或多个工具。 +只能返回 JSON,不要使用 Markdown。 +JSON 格式:{"tools":[{"name":"工具名称","reason":"..."}],"reason":"..."} +工具名称必须来自“可用工具”列表。 +可以选择多个工具,工具会按配置顺序依次执行;后面的工具可以使用前面工具写入的上下文。 +如果用户问题包含今天、明天、昨天、本周、本月、本年、最近等相对时间,并且还需要查询数据库,请同时选择 time 和 sql。 +例如“本月有什么日程安排”应选择 time 和 sql:先获取本月绝对日期范围,再查询日程表。 +如果无需工具,返回 {"tools":[],"reason":"..."}。 +只选择确实必要的工具。` ) type OpenAIConfig struct { @@ -48,43 +59,19 @@ type OpenAIConfig struct { type OpenAIConfigs []OpenAIConfig -type SearchConfig struct { - Name string `yaml:"name" json:"name"` - Active bool `yaml:"active,omitempty" json:"active"` - Enabled bool `yaml:"enabled" json:"enabled"` - Provider string `yaml:"provider" json:"provider"` - APIKey string `yaml:"api_key" json:"-"` - BaseURL string `yaml:"base_url" json:"base_url"` - Count int `yaml:"count" json:"count"` - Timeout int `yaml:"timeout" json:"timeout"` +type ToolRouterConfig struct { + Enabled bool `yaml:"enabled" json:"enabled"` + OpenAIName string `yaml:"openai_name" json:"openai_name"` + Timeout int `yaml:"timeout" json:"timeout"` + MaxTokens int `yaml:"max_tokens" json:"max_tokens"` + SystemPrompt string `yaml:"system_prompt" json:"system_prompt"` + Tools []ToolRouteConfig `yaml:"tools" json:"tools"` } -type SearchConfigs []SearchConfig - -func (configs *SearchConfigs) UnmarshalYAML(value *yaml.Node) error { - switch value.Kind { - case yaml.SequenceNode: - var items []SearchConfig - if err := value.Decode(&items); err != nil { - return err - } - *configs = items - case yaml.MappingNode: - var item SearchConfig - if err := value.Decode(&item); err != nil { - return err - } - *configs = []SearchConfig{item} - case yaml.ScalarNode: - if value.Tag == "!!null" { - *configs = nil - return nil - } - return fmt.Errorf("search 配置格式无效") - default: - return fmt.Errorf("search 配置格式无效") - } - return nil +type ToolRouteConfig struct { + Name string `yaml:"name" json:"name"` + Enabled bool `yaml:"enabled" json:"enabled"` + Description string `yaml:"description" json:"description"` } func (configs *OpenAIConfigs) UnmarshalYAML(value *yaml.Node) error { @@ -118,8 +105,8 @@ type Config struct { Mode string `yaml:"mode"` Address string `yaml:"address"` } `yaml:"server"` - OpenAI OpenAIConfigs `yaml:"openai"` - Search SearchConfigs `yaml:"search"` + OpenAI OpenAIConfigs `yaml:"openai"` + ToolRouter ToolRouterConfig `yaml:"tool_router"` } func defaultOpenAIConfig() OpenAIConfig { @@ -131,15 +118,18 @@ func defaultOpenAIConfig() OpenAIConfig { } } -func defaultSearchConfig() SearchConfig { - return SearchConfig{ - Name: "duckduckgo", - Active: true, - Enabled: true, - Provider: "duckduckgo", - BaseURL: defaultSearchBaseURL, - Count: defaultSearchCount, - Timeout: defaultSearchTimeout, +func defaultToolRouterConfig() ToolRouterConfig { + return ToolRouterConfig{ + Enabled: true, + OpenAIName: "", + Timeout: defaultToolRouterTimeout, + MaxTokens: defaultToolRouterMaxTokens, + SystemPrompt: defaultToolRouterSystemText, + Tools: []ToolRouteConfig{ + {Name: "time", Enabled: true, Description: ""}, + {Name: "search", Enabled: true, Description: ""}, + {Name: "sql", Enabled: true, Description: ""}, + }, } } @@ -148,7 +138,7 @@ func defaultConfig() Config { cfg.Server.Mode = "tcp" cfg.Server.Address = "0.0.0.0:8080" cfg.OpenAI = OpenAIConfigs{defaultOpenAIConfig()} - cfg.Search = SearchConfigs{defaultSearchConfig()} + cfg.ToolRouter = defaultToolRouterConfig() return cfg } @@ -174,16 +164,10 @@ func loadConfig(path string) (*Config, error) { cfg.OpenAI[i].APIKey = key } } - if _, err := normalizeSearchConfigs(&cfg); err != nil { + legacySearchProfiles = readLegacySearchProfiles(data) + if _, err := normalizeToolRouterConfig(&cfg); err != nil { return nil, err } - if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" { - for i := range cfg.Search { - if strings.ToLower(cfg.Search[i].Provider) == "brave" { - cfg.Search[i].APIKey = key - } - } - } return &cfg, nil } @@ -234,10 +218,10 @@ func ensureConfigFile(path string) error { changed = true } - if _, ok := raw["search"].([]any); !ok { + if _, ok := raw["tool_router"]; !ok { + cfg.ToolRouter = defaults.ToolRouter changed = true - } - if normalized, err := normalizeSearchConfigs(&cfg); err != nil { + } else if normalized, err := normalizeToolRouterConfig(&cfg); err != nil { return err } else if normalized { changed = true @@ -301,74 +285,90 @@ func normalizeOpenAIConfigs(cfg *Config) (bool, error) { return changed, nil } -func normalizeSearchConfigs(cfg *Config) (bool, error) { +func normalizeToolRouterConfig(cfg *Config) (bool, error) { changed := false - if len(cfg.Search) == 0 { - cfg.Search = SearchConfigs{defaultSearchConfig()} + defaults := defaultToolRouterConfig() + cfg.ToolRouter.OpenAIName = strings.TrimSpace(cfg.ToolRouter.OpenAIName) + if cfg.ToolRouter.Timeout <= 0 { + cfg.ToolRouter.Timeout = defaultToolRouterTimeout + changed = true + } + if cfg.ToolRouter.MaxTokens <= 0 { + cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens + changed = true + } + if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" { + cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText + changed = true + } else if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) != cfg.ToolRouter.SystemPrompt { + cfg.ToolRouter.SystemPrompt = strings.TrimSpace(cfg.ToolRouter.SystemPrompt) + changed = true + } + if len(cfg.ToolRouter.Tools) == 0 { + cfg.ToolRouter.Tools = defaults.Tools changed = true } - - activeIndex := -1 seen := map[string]bool{} - for i := range cfg.Search { - profile := &cfg.Search[i] - profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider)) - if profile.Provider == "" { - profile.Provider = "duckduckgo" - changed = true - } - name := strings.TrimSpace(profile.Name) + for i := range cfg.ToolRouter.Tools { + tool := &cfg.ToolRouter.Tools[i] + name := strings.ToLower(strings.TrimSpace(tool.Name)) if name == "" { - name = profile.Provider - if seen[name] { - name = fmt.Sprintf("%s-%d", profile.Provider, i+1) - } - profile.Name = name - changed = true - } else if name != profile.Name { - profile.Name = name + name = fmt.Sprintf("tool-%d", i+1) + } + if name != tool.Name { + tool.Name = name changed = true } + tool.Description = strings.TrimSpace(tool.Description) if seen[name] { - return changed, fmt.Errorf("search 配置名称重复: %s", name) + return changed, fmt.Errorf("tool_router.tools 配置名称重复: %s", name) } seen[name] = true - - if strings.TrimSpace(profile.BaseURL) == "" { - switch profile.Provider { - case "duckduckgo", "ddg": - profile.BaseURL = defaultSearchBaseURL - case "brave": - profile.BaseURL = "https://api.search.brave.com/res/v1/web/search" - default: - return changed, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider) - } + } + byName := map[string]ToolRouteConfig{} + for _, tool := range cfg.ToolRouter.Tools { + byName[tool.Name] = tool + } + merged := make([]ToolRouteConfig, 0, len(cfg.ToolRouter.Tools)+len(defaults.Tools)) + used := map[string]bool{} + for _, tool := range defaults.Tools { + if existing, ok := byName[tool.Name]; ok { + merged = append(merged, existing) + } else { + merged = append(merged, tool) changed = true } - if profile.Count <= 0 { - profile.Count = defaultSearchCount - changed = true - } - if profile.Timeout <= 0 { - profile.Timeout = defaultSearchTimeout - changed = true - } - if profile.Active { - if activeIndex == -1 { - activeIndex = i - } else { - profile.Active = false - changed = true - } + used[tool.Name] = true + } + for _, tool := range cfg.ToolRouter.Tools { + if !used[tool.Name] { + merged = append(merged, tool) } } - if activeIndex == -1 { - cfg.Search[0].Active = true + if len(merged) != len(cfg.ToolRouter.Tools) { changed = true + } else { + for i := range merged { + if merged[i].Name != cfg.ToolRouter.Tools[i].Name { + changed = true + break + } + } } + cfg.ToolRouter.Tools = merged return changed, nil } +func readLegacySearchProfiles(data []byte) []searchagent.ProfileConfig { + var legacy struct { + Search searchagent.ProfileConfigs `yaml:"search"` + } + if err := yaml.Unmarshal(data, &legacy); err != nil { + return nil + } + return []searchagent.ProfileConfig(legacy.Search) +} + func writeConfig(path string, cfg Config) error { data, err := yaml.Marshal(&cfg) if err != nil { @@ -431,9 +431,28 @@ type openAIListResponse struct { Profiles []OpenAIConfig `json:"profiles"` } -type searchListResponse struct { - Active string `json:"active"` - Profiles []SearchConfig `json:"profiles"` +type toolTextCompleter func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) + +type ToolRouterState struct { + cfg *ToolRouterConfig + ai *OpenAIState + complete toolTextCompleter +} + +func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) { + if config == nil { + cfg := defaultToolRouterConfig() + config = &cfg + } + if ai == nil { + return nil, errors.New("工具路由需要 OpenAI 状态") + } + if config.Enabled && strings.TrimSpace(config.OpenAIName) != "" { + if _, err := ai.GetProfile(config.OpenAIName); err != nil { + return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err) + } + } + return &ToolRouterState{cfg: config, ai: ai, complete: completeTextWithTimeout}, nil } func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) { @@ -537,108 +556,104 @@ func publicOpenAIConfig(profile *OpenAIProfile, active bool) OpenAIConfig { return config } -type SearchState struct { - mu sync.RWMutex - profiles map[string]SearchConfig - order []string - activeName string -} - -func NewSearchState(configs []SearchConfig) (*SearchState, error) { - state := &SearchState{ - profiles: make(map[string]SearchConfig, len(configs)), - order: make([]string, 0, len(configs)), - } - for _, config := range configs { - if strings.TrimSpace(config.Name) == "" { - return nil, errors.New("search.name 不能为空") - } - if strings.TrimSpace(config.Provider) == "" { - return nil, fmt.Errorf("search.%s.provider 未配置", config.Name) - } - if strings.TrimSpace(config.BaseURL) == "" { - return nil, fmt.Errorf("search.%s.base_url 未配置", config.Name) - } - if config.Timeout <= 0 { - return nil, fmt.Errorf("search.%s.timeout 必须大于 0", config.Name) - } - if _, ok := state.profiles[config.Name]; ok { - return nil, fmt.Errorf("search 配置名称重复: %s", config.Name) - } - state.profiles[config.Name] = config - state.order = append(state.order, config.Name) - if config.Active && state.activeName == "" { - state.activeName = config.Name - } - } - if len(state.order) == 0 { - return nil, errors.New("search 配置不能为空") - } - if state.activeName == "" { - state.activeName = state.order[0] - } - return state, nil -} - -func (s *SearchState) ActiveProfile() SearchConfig { - s.mu.RLock() - defer s.mu.RUnlock() - return s.profiles[s.activeName] -} - -func (s *SearchState) SwitchActive(name string) (SearchConfig, error) { - name = strings.TrimSpace(name) - if name == "" { - return SearchConfig{}, errors.New("搜索配置名称不能为空") - } - s.mu.Lock() - defer s.mu.Unlock() - profile, ok := s.profiles[name] - if !ok { - return SearchConfig{}, fmt.Errorf("搜索配置不存在: %s", name) - } - s.activeName = name - return profile, nil -} - -func (s *SearchState) ListProfiles() searchListResponse { - s.mu.RLock() - defer s.mu.RUnlock() - profiles := make([]SearchConfig, 0, len(s.order)) - for _, name := range s.order { - profile := s.profiles[name] - profile.APIKey = "" - profile.Active = name == s.activeName - profiles = append(profiles, profile) - } - return searchListResponse{Active: s.activeName, Profiles: profiles} -} - -func publicSearchConfig(config SearchConfig, active bool) SearchConfig { - config.APIKey = "" - config.Active = active - return config -} - // ─── 全局变量 ───────────────────────────────────────────── var ( - cfg *Config - aiState *OpenAIState - searchState *SearchState - sqlState *sqlquery.State - store *ConvStore + cfg *Config + aiState *OpenAIState + searchState *searchagent.State + legacySearchProfiles []searchagent.ProfileConfig + toolRouterState *ToolRouterState + sqlState *sqlquery.State + store *ConvStore ) type chatSSEFrame struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Message string `json:"message,omitempty"` - Tool string `json:"tool,omitempty"` - Stage string `json:"stage,omitempty"` - Status string `json:"status,omitempty"` - Data map[string]any `json:"data,omitempty"` - Error string `json:"error,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Message string `json:"message,omitempty"` + Tool string `json:"tool,omitempty"` + Stage string `json:"stage,omitempty"` + Status string `json:"status,omitempty"` + Data map[string]any `json:"data,omitempty"` + Stats *tokenUsageStats `json:"stats,omitempty"` + Error string `json:"error,omitempty"` +} + +type tokenUsageStats struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + ToolPromptTokens int `json:"tool_prompt_tokens"` + ToolCompletionTokens int `json:"tool_completion_tokens"` + TotalTokens int `json:"total_tokens"` + CompletionTokensPerSec float64 `json:"completion_tokens_per_sec"` + PeakCompletionTokensPerSec float64 `json:"peak_completion_tokens_per_sec"` + Estimated bool `json:"estimated"` +} + +type tokenUsageTracker struct { + mu sync.Mutex + promptTokens int + completionTokens int + toolPromptTokens int + toolCompletionTokens int +} + +type tokenUsageContextKey struct{} + +func newTokenUsageTracker() *tokenUsageTracker { + return &tokenUsageTracker{} +} + +func contextWithTokenUsage(ctx context.Context, tracker *tokenUsageTracker) context.Context { + if tracker == nil { + return ctx + } + return context.WithValue(ctx, tokenUsageContextKey{}, tracker) +} + +func tokenUsageFromContext(ctx context.Context) *tokenUsageTracker { + tracker, _ := ctx.Value(tokenUsageContextKey{}).(*tokenUsageTracker) + return tracker +} + +func (t *tokenUsageTracker) addTool(promptTokens, completionTokens int) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.toolPromptTokens += promptTokens + t.toolCompletionTokens += completionTokens +} + +func (t *tokenUsageTracker) setModel(promptTokens, completionTokens int) { + if t == nil { + return + } + t.mu.Lock() + defer t.mu.Unlock() + t.promptTokens = promptTokens + t.completionTokens = completionTokens +} + +func (t *tokenUsageTracker) snapshot(tokensPerSecond, peakTokensPerSecond float64) tokenUsageStats { + if t == nil { + return tokenUsageStats{Estimated: true} + } + t.mu.Lock() + defer t.mu.Unlock() + total := t.promptTokens + t.completionTokens + t.toolPromptTokens + t.toolCompletionTokens + return tokenUsageStats{ + PromptTokens: t.promptTokens, + CompletionTokens: t.completionTokens, + ToolPromptTokens: t.toolPromptTokens, + ToolCompletionTokens: t.toolCompletionTokens, + TotalTokens: total, + CompletionTokensPerSec: tokensPerSecond, + PeakCompletionTokensPerSec: peakTokensPerSecond, + Estimated: true, + } } // ─── 路由 ───────────────────────────────────────────────── @@ -688,9 +703,11 @@ func switchSearchHandler(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + profile.APIKey = "" + profile.Active = true c.JSON(http.StatusOK, gin.H{ "active": profile.Name, - "profile": publicSearchConfig(profile, true), + "profile": profile, }) } @@ -775,32 +792,23 @@ func chatHandler(c *gin.Context) { timeout := time.Duration(profile.Config.Timeout) * time.Second ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() + usage := newTokenUsageTracker() + ctx = contextWithTokenUsage(ctx, usage) chatMessages := req.Messages - if sqlState != nil && sqlState.Enabled() { - withSQL, err := enrichMessagesWithSQL(ctx, profile, chatMessages, emit) - if err != nil { - fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err) - emitTrace("sql", "error", "error", "数据库查询插件调用失败,将继续普通回答", map[string]any{"error": err.Error()}) - } else { - chatMessages = withSQL - } + withTools, err := enrichMessagesWithRoutedTools(ctx, profile, chatMessages, emit) + if err != nil { + fmt.Fprintln(os.Stderr, "工具路由调用失败:", err) + } else { + chatMessages = withTools } - if req.WebSearch { - withSearch, err := enrichMessagesWithSearch(ctx, chatMessages, emit) - if err != nil { - emitError(err) - return - } - chatMessages = withSearch - } - // 构建 ark 消息列表 messages, err := buildArkMessages(chatMessages) if err != nil { emitError(err) return } + promptTokens := estimateChatMessagesTokens(chatMessages) emitTrace("model", "request", "running", "正在调用模型生成回答", nil) stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{ @@ -816,14 +824,34 @@ func chatHandler(c *gin.Context) { emitTrace("model", "stream", "running", "模型已开始输出", nil) var full strings.Builder + completionTokens := 0 + streamStarted := time.Now() + windowStarted := streamStarted + windowTokens := 0 + peakTokensPerSecond := 0.0 for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { + usage.setModel(promptTokens, completionTokens) + if windowTokens > 0 { + windowElapsed := time.Since(windowStarted).Seconds() + if windowElapsed > 0.25 { + windowSpeed := float64(windowTokens) / windowElapsed + if windowSpeed > peakTokensPerSecond { + peakTokensPerSecond = windowSpeed + } + } + } + if peakTokensPerSecond == 0 { + peakTokensPerSecond = tokensPerSecond(completionTokens, streamStarted) + } if req.ConversationID != "" { if err := saveConversationMessages(req.ConversationID, req.Messages, full.String()); err != nil { fmt.Fprintln(os.Stderr, "保存对话失败:", err) } } + finalStats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) + emit(chatSSEFrame{Type: "stats", Stats: &finalStats}) emitTrace("model", "stream", "success", "回答生成完成", nil) fmt.Fprintf(c.Writer, "data: [DONE]\n\n") flusher.Flush() @@ -836,8 +864,25 @@ func chatHandler(c *gin.Context) { if len(resp.Choices) > 0 { delta := resp.Choices[0].Delta.Content if delta != "" { + now := time.Now() + deltaTokens := estimateTokenCount(delta) + windowTokens += deltaTokens + windowElapsed := now.Sub(windowStarted).Seconds() + if windowElapsed >= 1 { + windowSpeed := float64(windowTokens) / windowElapsed + if windowSpeed > peakTokensPerSecond { + peakTokensPerSecond = windowSpeed + } + windowStarted = now + windowTokens = 0 + } else if peakTokensPerSecond == 0 && windowElapsed > 0.25 { + peakTokensPerSecond = float64(windowTokens) / windowElapsed + } full.WriteString(delta) - emit(chatSSEFrame{Type: "delta", Text: delta}) + completionTokens += deltaTokens + usage.setModel(promptTokens, completionTokens) + stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond) + emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats}) } } } @@ -845,70 +890,6 @@ func chatHandler(c *gin.Context) { // ─── 辅助函数 ───────────────────────────────────────────── -type searchResult struct { - Title string `json:"title"` - URL string `json:"url"` - Description string `json:"description"` -} - -type braveSearchResponse struct { - Web struct { - Results []searchResult `json:"results"` - } `json:"web"` -} - -type duckDuckGoResponse struct { - Abstract string `json:"Abstract"` - AbstractSource string `json:"AbstractSource"` - AbstractURL string `json:"AbstractURL"` - Heading string `json:"Heading"` - RelatedTopics []struct { - Text string `json:"Text"` - FirstURL string `json:"FirstURL"` - Topics []struct { - Text string `json:"Text"` - FirstURL string `json:"FirstURL"` - } `json:"Topics"` - } `json:"RelatedTopics"` - Infobox struct { - Content []struct { - Label string `json:"label"` - Value string `json:"value"` - } `json:"content"` - } `json:"Infobox"` -} - -func enrichMessagesWithSearch(ctx context.Context, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { - searchConfig := searchState.ActiveProfile() - if !searchConfig.Enabled { - return nil, errors.New("联网搜索未启用,请先在 config.yaml 中配置 search.enabled") - } - - query := latestUserQuery(messages) - if query == "" { - return nil, errors.New("联网搜索需要输入文本问题") - } - - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": searchConfig.Provider}}) - results, err := webSearch(ctx, searchConfig, query) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}}) - return nil, err - } - if len(results) == 0 { - err := errors.New("未搜索到相关网页结果") - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()}) - return nil, err - } - - emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": searchConfig.Provider, "count": len(results)}}) - searchContext := buildSearchContext(searchConfig, query, results) - withSearch := make([]ChatMessage, 0, len(messages)+1) - withSearch = append(withSearch, ChatMessage{Role: "system", Content: searchContext, Hidden: true}) - withSearch = append(withSearch, messages...) - return withSearch, nil -} - func latestUserQuery(messages []ChatMessage) string { for i := len(messages) - 1; i >= 0; i-- { if messages[i].Role == "user" { @@ -918,9 +899,124 @@ func latestUserQuery(messages []ChatMessage) string { return "" } -type sqlActivationDecision struct { - Activate bool `json:"activate"` - Reason string `json:"reason"` +func estimateChatMessagesTokens(messages []ChatMessage) int { + total := 0 + for _, msg := range messages { + total += estimateTokenCount(msg.Role) + estimateTokenCount(msg.Content) + 4 + if msg.ImageURL != "" || msg.ImageURLAlias != "" { + total += 85 + } + } + return total +} + +func estimateTokenCount(text string) int { + text = strings.TrimSpace(text) + if text == "" { + return 0 + } + tokens := 0 + asciiRunes := 0 + flushASCII := func() { + if asciiRunes > 0 { + tokens += (asciiRunes + 3) / 4 + asciiRunes = 0 + } + } + for _, r := range text { + if unicode.IsSpace(r) { + flushASCII() + continue + } + if r <= unicode.MaxASCII { + asciiRunes++ + continue + } + flushASCII() + tokens++ + } + flushASCII() + if tokens == 0 { + return 1 + } + return tokens +} + +func tokensPerSecond(tokens int, start time.Time) float64 { + elapsed := time.Since(start).Seconds() + if tokens <= 0 || elapsed <= 0 { + return 0 + } + return float64(tokens) / elapsed +} + +type ToolSelection struct { + Name string `json:"name"` + Reason string `json:"reason"` +} + +type ToolRoutingDecision struct { + Tools []ToolSelection `json:"tools"` + Reason string `json:"reason"` +} + +type ChatTool interface { + Name() string + Description() string + Enabled() bool + Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) +} + +type TimeChatTool struct{} + +func (t TimeChatTool) Name() string { return "time" } + +func (t TimeChatTool) Description() string { + return timeagent.ActivationPrompt +} + +func (t TimeChatTool) Enabled() bool { return true } + +func (t TimeChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { + return runTimeTool(ctx, messages, routeReason, emit) +} + +type SQLChatTool struct { + state *sqlquery.State +} + +func (t SQLChatTool) Name() string { return "sql" } + +func (t SQLChatTool) Description() string { + if t.state == nil { + return "" + } + return t.state.ActivationPrompt() +} + +func (t SQLChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() } + +func (t SQLChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { + return runSQLTool(ctx, t.state, profile, messages, routeReason, emit) +} + +type SearchChatTool struct { + state *searchagent.State +} + +func (t SearchChatTool) Name() string { return "search" } + +func (t SearchChatTool) Description() string { + if t.state == nil { + return "" + } + return t.state.ActivationPrompt() +} + +func (t SearchChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() } + +func (t SearchChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { + return runSearchTool(ctx, t.state, messages, routeReason, emit) } type sqlGenerationResult struct { @@ -929,29 +1025,17 @@ type sqlGenerationResult struct { Reason string `json:"reason"` } -func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { +func runSQLTool(ctx context.Context, state *sqlquery.State, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { query := latestUserQuery(messages) if query == "" { return messages, nil } - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "running", Message: "正在判断是否需要查询数据库"}) - activate, reason, err := classifySQLActivation(ctx, profile, messages) - if err != nil { - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "error", Message: "数据库查询判断失败", Data: map[string]any{"error": err.Error()}}) - return messages, err - } - if !activate { - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:本轮无需查询数据库", Data: map[string]any{"activate": false, "reason": reason}}) - return messages, nil - } - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "classify", Status: "success", Message: "判断结果:需要查询数据库", Data: map[string]any{"activate": true, "reason": reason}}) - emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"}) - schemaContext, err := sqlState.SchemaContext(ctx) + schemaContext, err := state.SchemaContext(ctx) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}}) - return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil + return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"}) @@ -959,59 +1043,230 @@ func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}}) - return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil + return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil } generated.Database = strings.TrimSpace(generated.Database) generated.SQL = strings.TrimSpace(generated.SQL) if generated.SQL == "" { err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason) emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}}) - return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil + return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}}) - result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL) + result, err := state.ExecuteReadOnly(ctx, generated.Database, generated.SQL) if err != nil { emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}}) - return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil + return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil } emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}}) emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}}) contextText := sqlquery.BuildResultContext(query, generated.SQL, result) - if strings.TrimSpace(reason) != "" { - contextText += "\n激活原因:" + reason + if strings.TrimSpace(routeReason) != "" { + contextText += "\n激活原因:" + routeReason } - return prependSQLContext(messages, contextText), nil + return prependHiddenContext(messages, contextText), nil } -func prependSQLContext(messages []ChatMessage, content string) []ChatMessage { - withSQL := make([]ChatMessage, 0, len(messages)+1) - withSQL = append(withSQL, ChatMessage{Role: "system", Content: content, Hidden: true}) - withSQL = append(withSQL, messages...) - return withSQL +func prependHiddenContext(messages []ChatMessage, content string) []ChatMessage { + withContext := make([]ChatMessage, 0, len(messages)+1) + withContext = append(withContext, ChatMessage{Role: "system", Content: content, Hidden: true}) + withContext = append(withContext, messages...) + return withContext } -func classifySQLActivation(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) (bool, string, error) { +func runTimeTool(ctx context.Context, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { + _ = ctx + resolved := timeagent.Resolve(time.Now()) + emit(chatSSEFrame{Type: "trace", Tool: "time", Stage: "resolve", Status: "success", Message: "已获取当前时间上下文", Data: map[string]any{ + "today": timeagent.FormatDate(resolved.Now), + "this_month": fmt.Sprintf("%s 至 %s", timeagent.FormatDate(resolved.ThisMonth.Start), timeagent.FormatDate(resolved.ThisMonth.End.AddDate(0, 0, -1))), + }}) + return prependHiddenContext(messages, timeagent.BuildContext(resolved, routeReason)), nil +} + +func runSearchTool(ctx context.Context, state *searchagent.State, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) { query := latestUserQuery(messages) - prompt := fmt.Sprintf("%s\n\n最新用户问题:%s", sqlState.ActivationPrompt(), query) - text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 512) + if query == "" { + return messages, nil + } + if state == nil || !state.Enabled() { + err := errors.New("联网搜索未启用") + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索未启用", Data: map[string]any{"error": err.Error()}}) + return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil + } + active := state.ActiveProfile() + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": active.Provider}}) + results, profile, err := state.Search(ctx, query) if err != nil { - return false, "", err + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}}) + return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil } - var decision sqlActivationDecision + if len(results) == 0 { + err := errors.New("未搜索到相关网页结果") + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "error", Message: err.Error()}) + return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil + } + emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": profile.Provider, "count": len(results)}}) + return prependHiddenContext(messages, searchagent.BuildResultContext(profile, query, results, routeReason)), nil +} + +func enrichMessagesWithRoutedTools(ctx context.Context, chatProfile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) { + if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled { + return messages, nil + } + if latestUserQuery(messages) == "" { + return messages, nil + } + tools := availableChatTools(toolRouterState.cfg) + if len(tools) == 0 { + return messages, nil + } + + emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "running", Message: "正在进行工具路由"}) + decision, err := routeTools(ctx, toolRouterState, chatProfile, messages, tools) + if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "error", Message: "工具路由失败,将继续普通回答", Data: map[string]any{"error": err.Error()}}) + return messages, err + } + selected := filterToolSelections(decision, tools, toolRouterState.cfg.Tools) + if len(selected) == 0 { + emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:无需调用工具", Data: map[string]any{"reason": decision.Reason}}) + return messages, nil + } + + names := make([]string, 0, len(selected)) + for _, item := range selected { + names = append(names, item.Name) + } + emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:将调用 " + strings.Join(names, ", "), Data: map[string]any{"tools": names, "reason": decision.Reason}}) + + current := messages + for _, item := range selected { + tool := tools[item.Name] + next, err := tool.Enrich(ctx, chatProfile, current, firstNonEmpty(item.Reason, decision.Reason), emit) + if err != nil { + emit(chatSSEFrame{Type: "trace", Tool: item.Name, Stage: "error", Status: "error", Message: "工具调用失败,将继续普通回答", Data: map[string]any{"error": err.Error()}}) + continue + } + current = next + } + return current, nil +} + +func availableChatTools(config *ToolRouterConfig) map[string]ChatTool { + configured := map[string]ToolRouteConfig{} + for _, item := range config.Tools { + configured[item.Name] = item + } + registered := []ChatTool{ + TimeChatTool{}, + SearchChatTool{state: searchState}, + SQLChatTool{state: sqlState}, + } + available := map[string]ChatTool{} + for _, tool := range registered { + name := tool.Name() + item, ok := configured[name] + if !ok || !item.Enabled || !tool.Enabled() { + continue + } + available[name] = tool + } + return available +} + +func routeTools(ctx context.Context, state *ToolRouterState, chatProfile *OpenAIProfile, messages []ChatMessage, tools map[string]ChatTool) (ToolRoutingDecision, error) { + routerProfile := chatProfile + if strings.TrimSpace(state.cfg.OpenAIName) != "" { + profile, err := state.ai.GetProfile(state.cfg.OpenAIName) + if err != nil { + return ToolRoutingDecision{}, err + } + routerProfile = profile + } + prompt := buildToolRouterPrompt(state.cfg, messages, tools) + text, err := state.complete(ctx, routerProfile, []ChatMessage{{Role: "system", Content: prompt}}, state.cfg.MaxTokens, time.Duration(state.cfg.Timeout)*time.Second) + if err != nil { + return ToolRoutingDecision{}, err + } + return parseToolRoutingDecision(text) +} + +func buildToolRouterPrompt(config *ToolRouterConfig, messages []ChatMessage, tools map[string]ChatTool) string { + query := latestUserQuery(messages) + var b strings.Builder + b.WriteString(strings.TrimSpace(config.SystemPrompt)) + b.WriteString("\n\n可用工具:\n") + for _, item := range config.Tools { + tool, ok := tools[item.Name] + if !ok { + continue + } + description := strings.TrimSpace(item.Description) + if description == "" { + description = tool.Description() + } + fmt.Fprintf(&b, "- name: %s\n description: %s\n", item.Name, description) + } + fmt.Fprintf(&b, "\n最新用户问题:%s", query) + return b.String() +} + +func parseToolRoutingDecision(text string) (ToolRoutingDecision, error) { + var decision ToolRoutingDecision if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil { - return false, "", fmt.Errorf("解析 SQL 查询激活结果失败: %w", err) + return decision, fmt.Errorf("解析工具路由结果失败: %w", err) } - return decision.Activate, decision.Reason, nil + for i := range decision.Tools { + decision.Tools[i].Name = strings.ToLower(strings.TrimSpace(decision.Tools[i].Name)) + decision.Tools[i].Reason = strings.TrimSpace(decision.Tools[i].Reason) + } + decision.Reason = strings.TrimSpace(decision.Reason) + return decision, nil +} + +func filterToolSelections(decision ToolRoutingDecision, tools map[string]ChatTool, order []ToolRouteConfig) []ToolSelection { + selected := map[string]ToolSelection{} + for _, item := range decision.Tools { + if item.Name == "" { + continue + } + if _, ok := tools[item.Name]; !ok { + continue + } + if _, ok := selected[item.Name]; !ok { + selected[item.Name] = item + } + } + result := make([]ToolSelection, 0, len(selected)) + for _, item := range order { + if selection, ok := selected[item.Name]; ok { + result = append(result, selection) + } + } + return result +} + +func firstNonEmpty(items ...string) string { + for _, item := range items { + if strings.TrimSpace(item) != "" { + return strings.TrimSpace(item) + } + } + return "" } func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) { - prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题和数据库 schema 生成一条只读 SQL。 + prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、隐藏上下文和数据库 schema 生成一条只读 SQL。 要求: - 只能返回 JSON,不要使用 Markdown。 - JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"} - 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。 - 必须只使用 schema 中出现的数据库、表和字段。 +- 如果隐藏上下文中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。 +- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。 +- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。 - 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。 - 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。 @@ -1030,11 +1285,14 @@ func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQu } func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) { + return completeTextWithTimeout(ctx, profile, chatMessages, maxTokens, time.Duration(profile.Config.Timeout)*time.Second) +} + +func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) { messages, err := buildArkMessages(chatMessages) if err != nil { return "", err } - timeout := time.Duration(profile.Config.Timeout) * time.Second completionCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() stream, err := profile.Client.CreateChatCompletionStream(completionCtx, model.CreateChatCompletionRequest{ @@ -1047,17 +1305,24 @@ func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []Ch } defer stream.Close() + promptTokens := estimateChatMessagesTokens(chatMessages) + completionTokens := 0 var b strings.Builder for { resp, err := stream.Recv() if errors.Is(err, io.EOF) { + if tracker := tokenUsageFromContext(ctx); tracker != nil { + tracker.addTool(promptTokens, completionTokens) + } return b.String(), nil } if err != nil { return "", err } if len(resp.Choices) > 0 { - b.WriteString(resp.Choices[0].Delta.Content) + delta := resp.Choices[0].Delta.Content + b.WriteString(delta) + completionTokens += estimateTokenCount(delta) } } } @@ -1072,182 +1337,6 @@ func extractJSONObject(text string) string { return text } -func webSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) { - switch strings.ToLower(config.Provider) { - case "duckduckgo", "ddg": - return duckDuckGoSearch(ctx, config, query) - case "brave": - return braveWebSearch(ctx, config, query) - default: - return nil, fmt.Errorf("暂不支持搜索服务: %s", config.Provider) - } -} - -func duckDuckGoSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) { - searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second) - defer cancel() - - u, err := url.Parse(config.BaseURL) - if err != nil { - return nil, fmt.Errorf("搜索服务地址无效: %w", err) - } - q := u.Query() - q.Set("q", query) - q.Set("format", "json") - q.Set("no_html", "1") - q.Set("skip_disambig", "1") - u.RawQuery = q.Encode() - - req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, fmt.Errorf("创建搜索请求失败: %w", err) - } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "aichat/1.0") - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("联网搜索失败: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) - if err != nil { - return nil, fmt.Errorf("读取搜索响应失败: %w", err) - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) - } - - var parsed duckDuckGoResponse - if err := json.Unmarshal(body, &parsed); err != nil { - return nil, fmt.Errorf("解析搜索响应失败: %w", err) - } - - limit := config.Count - if limit <= 0 { - limit = defaultSearchCount - } - results := make([]searchResult, 0, limit) - if strings.TrimSpace(parsed.Abstract) != "" { - title := strings.TrimSpace(parsed.Heading) - if title == "" { - title = strings.TrimSpace(parsed.AbstractSource) - } - if title == "" { - title = "DuckDuckGo 摘要" - } - results = append(results, searchResult{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)}) - } - for _, item := range parsed.Infobox.Content { - if len(results) >= limit { - break - } - if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" { - continue - } - results = append(results, searchResult{Title: item.Label, Description: item.Value}) - } - appendRelated := func(text, firstURL string) { - if len(results) >= limit || strings.TrimSpace(text) == "" { - return - } - title, desc := splitDuckDuckGoText(text) - results = append(results, searchResult{Title: title, URL: strings.TrimSpace(firstURL), Description: desc}) - } - for _, topic := range parsed.RelatedTopics { - appendRelated(topic.Text, topic.FirstURL) - for _, nested := range topic.Topics { - appendRelated(nested.Text, nested.FirstURL) - } - if len(results) >= limit { - break - } - } - return results, nil -} - -func splitDuckDuckGoText(text string) (string, string) { - text = strings.TrimSpace(text) - parts := strings.SplitN(text, " - ", 2) - if len(parts) == 2 { - return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]) - } - runes := []rune(text) - if len(runes) > 42 { - return string(runes[:42]) + "...", text - } - return text, text -} - -func braveWebSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) { - if config.APIKey == "" { - return nil, errors.New("Brave 搜索未配置 API Key,请设置 search.api_key 或环境变量 BRAVE_SEARCH_API_KEY") - } - searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second) - defer cancel() - - u, err := url.Parse(config.BaseURL) - if err != nil { - return nil, fmt.Errorf("搜索服务地址无效: %w", err) - } - q := u.Query() - q.Set("q", query) - q.Set("count", fmt.Sprintf("%d", config.Count)) - q.Set("search_lang", "zh-hans") - q.Set("country", "CN") - u.RawQuery = q.Encode() - - req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, fmt.Errorf("创建搜索请求失败: %w", err) - } - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Subscription-Token", config.APIKey) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return nil, fmt.Errorf("联网搜索失败: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024)) - if err != nil { - return nil, fmt.Errorf("读取搜索响应失败: %w", err) - } - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) - } - - var parsed braveSearchResponse - if err := json.Unmarshal(body, &parsed); err != nil { - return nil, fmt.Errorf("解析搜索响应失败: %w", err) - } - return parsed.Web.Results, nil -} - -func buildSearchContext(config SearchConfig, query string, results []searchResult) string { - var b strings.Builder - fmt.Fprintf(&b, "用户开启了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider) - if strings.ToLower(config.Provider) == "duckduckgo" || strings.ToLower(config.Provider) == "ddg" { - fmt.Fprintln(&b, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。") - } - fmt.Fprintf(&b, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05")) - fmt.Fprintf(&b, "搜索词: %s\n\n", query) - fmt.Fprintln(&b, "搜索结果:") - for i, r := range results { - fmt.Fprintf(&b, "%d. 标题: %s\n", i+1, strings.TrimSpace(r.Title)) - if strings.TrimSpace(r.URL) != "" { - fmt.Fprintf(&b, " 链接: %s\n", strings.TrimSpace(r.URL)) - } - if strings.TrimSpace(r.Description) != "" { - fmt.Fprintf(&b, " 摘要: %s\n", strings.TrimSpace(r.Description)) - } - } - fmt.Fprintln(&b, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。") - return b.String() -} - func newUUID() string { b := make([]byte, 16) _, _ = rand.Read(b) @@ -1556,12 +1645,17 @@ func main() { fmt.Fprintln(os.Stderr, "OpenAI 配置初始化失败:", err) os.Exit(1) } - searchState, err = NewSearchState(cfg.Search) + searchConfig, err := searchagent.LoadConfig("agents/search/config.yaml", legacySearchProfiles) if err != nil { - fmt.Fprintln(os.Stderr, "搜索配置初始化失败:", err) + fmt.Fprintln(os.Stderr, "联网搜索配置加载失败:", err) os.Exit(1) } - sqlConfig, err := sqlquery.LoadConfig("agents/SQL_query/config.yaml") + searchState, err = searchagent.NewState(searchConfig) + if err != nil { + fmt.Fprintln(os.Stderr, "联网搜索初始化失败:", err) + os.Exit(1) + } + sqlConfig, err := sqlquery.LoadConfig("agents/sql/config.yaml") if err != nil { fmt.Fprintln(os.Stderr, "SQL 查询插件配置加载失败:", err) os.Exit(1) @@ -1572,6 +1666,11 @@ func main() { os.Exit(1) } defer sqlState.Close() + toolRouterState, err = NewToolRouterState(&cfg.ToolRouter, aiState) + if err != nil { + fmt.Fprintln(os.Stderr, "工具路由配置初始化失败:", err) + os.Exit(1) + } store = NewConvStore("conversations") // Gin 路由 diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..6f50605 --- /dev/null +++ b/main_test.go @@ -0,0 +1,217 @@ +package main + +import ( + "context" + "errors" + "strings" + "testing" + "time" +) + +type fakeChatTool struct { + name string + description string + enabled bool +} + +func (t fakeChatTool) Name() string { return t.name } +func (t fakeChatTool) Description() string { return t.description } +func (t fakeChatTool) Enabled() bool { return t.enabled } +func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) { + return nil, nil +} + +func TestNormalizeToolRouterConfigDefaults(t *testing.T) { + cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}} + changed, err := normalizeToolRouterConfig(cfg) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected defaults to change config") + } + if cfg.ToolRouter.Timeout != defaultToolRouterTimeout { + t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout) + } + if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens { + t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens) + } + if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" { + t.Fatal("system prompt should be defaulted") + } + if len(cfg.ToolRouter.Tools) != 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[1].Name != "search" || cfg.ToolRouter.Tools[2].Name != "sql" || !cfg.ToolRouter.Tools[0].Enabled || !cfg.ToolRouter.Tools[1].Enabled || !cfg.ToolRouter.Tools[2].Enabled { + t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools) + } +} + +func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) { + cfg := &Config{ToolRouter: ToolRouterConfig{ + Enabled: true, + Timeout: 1, + MaxTokens: 1, + SystemPrompt: "route", + Tools: []ToolRouteConfig{ + {Name: "search", Enabled: true}, + {Name: "sql", Enabled: true}, + }, + }} + changed, err := normalizeToolRouterConfig(cfg) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected time tool to be added") + } + if len(cfg.ToolRouter.Tools) < 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[2].Name != "sql" { + t.Fatalf("unexpected tool order: %#v", cfg.ToolRouter.Tools) + } +} + +func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) { + cfg := &Config{ToolRouter: ToolRouterConfig{ + Enabled: true, + Timeout: 1, + MaxTokens: 1, + SystemPrompt: "route", + Tools: []ToolRouteConfig{ + {Name: "sql", Enabled: true}, + {Name: " SQL ", Enabled: true}, + }, + }} + _, err := normalizeToolRouterConfig(cfg) + if err == nil { + t.Fatal("expected duplicate tool error") + } +} + +func TestParseToolRoutingDecision(t *testing.T) { + decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```") + if err != nil { + t.Fatal(err) + } + if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" { + t.Fatalf("unexpected decision: %#v", decision) + } + if decision.Reason != "总原因" { + t.Fatalf("reason = %q", decision.Reason) + } + + if _, err := parseToolRoutingDecision("not json"); err == nil { + t.Fatal("expected malformed JSON error") + } +} + +func TestFilterToolSelections(t *testing.T) { + tools := map[string]ChatTool{ + "time": fakeChatTool{name: "time", enabled: true}, + "sql": fakeChatTool{name: "sql", enabled: true}, + "search": fakeChatTool{name: "search", enabled: true}, + } + decision := ToolRoutingDecision{Tools: []ToolSelection{ + {Name: "unknown", Reason: "ignore"}, + {Name: "search", Reason: "second in config"}, + {Name: "sql", Reason: "third in config"}, + {Name: "time", Reason: "first in config"}, + {Name: "sql", Reason: "duplicate"}, + }} + selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}}) + if len(selected) != 3 { + t.Fatalf("selected length = %d", len(selected)) + } + if selected[0].Name != "time" || selected[0].Reason != "first in config" { + t.Fatalf("first selection = %#v", selected[0]) + } + if selected[1].Name != "search" { + t.Fatalf("second selection = %#v", selected[1]) + } + if selected[2].Name != "sql" || selected[2].Reason != "third in config" { + t.Fatalf("third selection = %#v", selected[2]) + } +} + +func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) { + messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}} + withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {}) + if err != nil { + t.Fatal(err) + } + if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" { + t.Fatalf("unexpected messages: %#v", withTime) + } + for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} { + if !strings.Contains(withTime[0].Content, want) { + t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content) + } + } +} + +func TestBuildToolRouterPrompt(t *testing.T) { + cfg := &ToolRouterConfig{ + SystemPrompt: "router", + Tools: []ToolRouteConfig{ + {Name: "time", Enabled: true}, + {Name: "sql", Enabled: true, Description: "configured sql"}, + {Name: "search", Enabled: true}, + }, + } + tools := map[string]ChatTool{ + "time": fakeChatTool{name: "time", description: "fallback time", enabled: true}, + "sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true}, + "search": fakeChatTool{name: "search", description: "fallback search", enabled: true}, + } + prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools) + for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} { + if !strings.Contains(prompt, want) { + t.Fatalf("prompt missing %q:\n%s", want, prompt) + } + } +} + +func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) { + ai := &OpenAIState{ + profiles: map[string]*OpenAIProfile{ + "chat": {Config: OpenAIConfig{Name: "chat"}}, + "router": {Config: OpenAIConfig{Name: "router"}}, + }, + order: []string{"chat", "router"}, + activeName: "chat", + } + state := &ToolRouterState{cfg: &ToolRouterConfig{ + OpenAIName: "router", + Timeout: 7, + MaxTokens: 123, + SystemPrompt: "router prompt", + Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}, + }, ai: ai} + state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) { + if profile.Config.Name != "router" { + t.Fatalf("profile = %s", profile.Config.Name) + } + if maxTokens != 123 { + t.Fatalf("maxTokens = %d", maxTokens) + } + if timeout != 7*time.Second { + t.Fatalf("timeout = %s", timeout) + } + return `{"tools":[],"reason":"无需工具"}`, nil + } + decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}}) + if err != nil { + t.Fatal(err) + } + if len(decision.Tools) != 0 || decision.Reason != "无需工具" { + t.Fatalf("unexpected decision: %#v", decision) + } +} + +func TestRouteToolsCompleterError(t *testing.T) { + ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"} + state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai} + state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) { + return "", errors.New("boom") + } + _, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}}) + if err == nil { + t.Fatal("expected completer error") + } +} diff --git a/templates/chat.html b/templates/chat.html index 90f57a3..a750d33 100644 --- a/templates/chat.html +++ b/templates/chat.html @@ -158,7 +158,7 @@ align-items: center; gap: 8px; } - #btnClear, #btnPreset, #btnSearch { + #btnClear, #btnPreset { background: none; border: 1px solid var(--border); color: var(--text-dim); @@ -169,13 +169,7 @@ transition: all .15s; } #btnClear:hover { border-color: var(--danger); color: var(--danger); } - #btnPreset:hover, #btnSearch:hover { border-color: var(--accent); color: var(--accent); } - #btnSearch.active { - background: var(--accent-soft); - border-color: var(--accent); - color: var(--accent); - } - + #btnPreset:hover { border-color: var(--accent); color: var(--accent); } /* ── 消息区 ── */ #messages { flex: 1; @@ -292,6 +286,13 @@ overflow-x: auto; } .answer-text { display: inline; } + .token-stats { + margin-top: 8px; + color: var(--text-dim); + font-size: 0.76rem; + white-space: normal; + } + .token-stats:empty { display: none; } /* 错误消息 */ .error-msg { @@ -513,7 +514,6 @@ - @@ -554,7 +554,7 @@ -
Enter 发送 · Shift+Enter 换行 · 支持图片多模态
+Enter 发送 · Shift+Enter 换行 · 支持图片多模态 · 工具路由会自动判断是否需要联网搜索
@@ -576,7 +576,6 @@ let history = []; // {role, content, image_url?} let currentConvId = null; let pending = false; -let webSearchEnabled = false; let openAIProfiles = []; let activeOpenAIName = '{{ .OpenAIName }}'; let searchProfiles = []; @@ -591,7 +590,6 @@ const inputBox = document.getElementById('inputBox'); const btnSend = document.getElementById('btnSend'); const btnClear = document.getElementById('btnClear'); const btnPreset = document.getElementById('btnPreset'); -const btnSearch = document.getElementById('btnSearch'); const modelSelect = document.getElementById('modelSelect'); const searchSelect = document.getElementById('searchSelect'); const btnNewChat = document.getElementById('btnNewChat'); @@ -666,16 +664,10 @@ function setInputDisabled(disabled) { btnSend.disabled = disabled; inputBox.disabled = disabled; fileInput.disabled = disabled; - btnSearch.disabled = disabled; modelSelect.disabled = disabled || openAIProfiles.length <= 1; searchSelect.disabled = disabled || searchProfiles.length <= 1; } -function updateSearchButton() { - btnSearch.classList.toggle('active', webSearchEnabled); - btnSearch.textContent = webSearchEnabled ? '联网搜索:开' : '联网搜索:关'; -} - async function loadOpenAIProfiles() { const res = await fetch('/api/openai'); if (!res.ok) { @@ -869,13 +861,42 @@ function addAIBubble() { const txt = document.createElement('span'); txt.className = 'answer-text'; + const stats = document.createElement('div'); + stats.className = 'token-stats'; bub.appendChild(trace); bub.appendChild(txt); + bub.appendChild(stats); row.appendChild(av); row.appendChild(bub); msgBox.appendChild(row); scrollToBottom(); - return { bub, txt, trace }; + return { bub, txt, trace, stats }; +} + +function formatTokenStats(stats) { + if (!stats) return ''; + const avgSpeed = typeof stats.completion_tokens_per_sec === 'number' + ? stats.completion_tokens_per_sec.toFixed(1) + : '0.0'; + const peakSpeed = typeof stats.peak_completion_tokens_per_sec === 'number' + ? stats.peak_completion_tokens_per_sec.toFixed(1) + : '0.0'; + const parts = [ + `平均 ${avgSpeed} tokens/sec`, + `最高 ${peakSpeed} tokens/sec`, + `总 token ${stats.total_tokens || 0}`, + `输入 ${stats.prompt_tokens || 0}`, + `输出 ${stats.completion_tokens || 0}`, + ]; + const toolTokens = (stats.tool_prompt_tokens || 0) + (stats.tool_completion_tokens || 0); + if (toolTokens) parts.push(`工具 ${toolTokens}`); + if (stats.estimated) parts.push('本地估算'); + return parts.join(' | '); +} + +function updateTokenStats(aiBubble, stats) { + if (!aiBubble.stats) return; + aiBubble.stats.textContent = formatTokenStats(stats); } function appendTrace(aiBubble, frame) { @@ -893,6 +914,7 @@ function appendTrace(aiBubble, frame) { if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`); if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`); if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`); + if (Array.isArray(data.tools) && data.tools.length) stats.push(`工具: ${data.tools.join(', ')}`); if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''} 行`); if (data.reason) stats.push(`原因: ${data.reason}`); if (data.error) stats.push(`错误: ${data.error}`); @@ -909,7 +931,7 @@ function appendTrace(aiBubble, frame) { scrollToBottom(); } -async function streamChat(messages, aiBubble, webSearch = false) { +async function streamChat(messages, aiBubble) { const txtEl = aiBubble.txt; let full = ''; @@ -919,8 +941,7 @@ async function streamChat(messages, aiBubble, webSearch = false) { body: JSON.stringify({ conversation_id: currentConvId, messages, - web_search: webSearch, - openai_name: activeOpenAIName, + openai_name: activeOpenAIName, }), }); @@ -965,8 +986,14 @@ async function streamChat(messages, aiBubble, webSearch = false) { if (delta) { full += delta; txtEl.innerHTML = renderMarkdown(full); - scrollToBottom(); } + updateTokenStats(aiBubble, parsed.stats); + scrollToBottom(); + continue; + } + if (parsed.type === 'stats') { + updateTokenStats(aiBubble, parsed.stats); + scrollToBottom(); continue; } if (parsed.type === 'trace') { @@ -1070,7 +1097,7 @@ async function sendMessage() { clearImage(); aiBubble = addAIBubble(); - const full = await streamChat(history, aiBubble, webSearchEnabled); + const full = await streamChat(history, aiBubble); history.push({ role: 'assistant', content: full }); scrollToBottom(); await loadConversationList(); @@ -1145,11 +1172,6 @@ inputBox.addEventListener('keydown', e => { }); btnSend.addEventListener('click', sendMessage); -btnSearch.addEventListener('click', () => { - if (pending) return; - webSearchEnabled = !webSearchEnabled; - updateSearchButton(); -}); modelSelect.addEventListener('change', async () => { if (pending) { modelSelect.value = activeOpenAIName; @@ -1218,7 +1240,6 @@ presetModal.addEventListener('click', e => { }); // 自动聚焦 & 初始化 -updateSearchButton(); loadOpenAIProfiles().catch(e => alert(e.message)); loadSearchProfiles().catch(e => alert(e.message)); loadConversationList();