支持工具链

This commit is contained in:
2026-06-10 12:07:07 +08:00
parent 1e793ce814
commit fe2477dd97
9 changed files with 1632 additions and 545 deletions
+520
View File
@@ -0,0 +1,520 @@
package search
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"gopkg.in/yaml.v3"
)
const (
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
仅当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
普通知识、闲聊、代码推理、已有上下文可回答的问题不要调用。`
defaultBaseURL = "https://api.duckduckgo.com/"
defaultTimeout = 10
defaultCount = 5
)
type Config struct {
Enabled bool `yaml:"enabled" json:"enabled"`
ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"`
Profiles ProfileConfigs `yaml:"profiles" json:"profiles"`
}
type ProfileConfig struct {
Name string `yaml:"name" json:"name"`
Active bool `yaml:"active,omitempty" json:"active"`
Enabled bool `yaml:"enabled" json:"enabled"`
Provider string `yaml:"provider" json:"provider"`
APIKey string `yaml:"api_key" json:"-"`
BaseURL string `yaml:"base_url" json:"base_url"`
Count int `yaml:"count" json:"count"`
Timeout int `yaml:"timeout" json:"timeout"`
}
type ProfileConfigs []ProfileConfig
func (configs *ProfileConfigs) UnmarshalYAML(value *yaml.Node) error {
switch value.Kind {
case yaml.SequenceNode:
var items []ProfileConfig
if err := value.Decode(&items); err != nil {
return err
}
*configs = items
case yaml.MappingNode:
var item ProfileConfig
if err := value.Decode(&item); err != nil {
return err
}
*configs = []ProfileConfig{item}
case yaml.ScalarNode:
if value.Tag == "!!null" {
*configs = nil
return nil
}
return fmt.Errorf("search 配置格式无效")
default:
return fmt.Errorf("search 配置格式无效")
}
return nil
}
type State struct {
mu sync.RWMutex
cfg *Config
profiles map[string]ProfileConfig
order []string
activeName string
}
type Result struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
}
type ListResponse struct {
Active string `json:"active"`
Profiles []ProfileConfig `json:"profiles"`
}
type braveSearchResponse struct {
Web struct {
Results []Result `json:"results"`
} `json:"web"`
}
type duckDuckGoResponse struct {
Abstract string `json:"Abstract"`
AbstractSource string `json:"AbstractSource"`
AbstractURL string `json:"AbstractURL"`
Heading string `json:"Heading"`
RelatedTopics []struct {
Text string `json:"Text"`
FirstURL string `json:"FirstURL"`
Topics []struct {
Text string `json:"Text"`
FirstURL string `json:"FirstURL"`
} `json:"Topics"`
} `json:"RelatedTopics"`
Infobox struct {
Content []struct {
Label string `json:"label"`
Value string `json:"value"`
} `json:"content"`
} `json:"Infobox"`
}
func LoadConfig(path string, legacyProfiles ...[]ProfileConfig) (*Config, error) {
if _, err := os.Stat(path); err != nil {
if !os.IsNotExist(err) {
return nil, fmt.Errorf("检查联网搜索配置失败: %w", err)
}
cfg := defaultConfig()
if len(legacyProfiles) > 0 && len(legacyProfiles[0]) > 0 {
cfg.Profiles = legacyProfiles[0]
}
if err := normalizeConfig(&cfg); err != nil {
return nil, err
}
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
return nil, fmt.Errorf("创建联网搜索配置目录失败: %w", err)
}
data, err := yaml.Marshal(&cfg)
if err != nil {
return nil, fmt.Errorf("生成联网搜索配置失败: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
return nil, fmt.Errorf("写入联网搜索配置失败: %w", err)
}
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取联网搜索配置失败: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析联网搜索配置失败: %w", err)
}
if err := normalizeConfig(&cfg); err != nil {
return nil, err
}
applyEnv(&cfg)
return &cfg, nil
}
func NewState(cfg *Config) (*State, error) {
state := &State{cfg: cfg, profiles: map[string]ProfileConfig{}}
if cfg == nil || !cfg.Enabled {
return state, nil
}
for _, config := range cfg.Profiles {
if strings.TrimSpace(config.Name) == "" {
return nil, errors.New("search.profiles.name 不能为空")
}
if strings.TrimSpace(config.Provider) == "" {
return nil, fmt.Errorf("search.%s.provider 未配置", config.Name)
}
if strings.TrimSpace(config.BaseURL) == "" {
return nil, fmt.Errorf("search.%s.base_url 未配置", config.Name)
}
if config.Timeout <= 0 {
return nil, fmt.Errorf("search.%s.timeout 必须大于 0", config.Name)
}
if _, ok := state.profiles[config.Name]; ok {
return nil, fmt.Errorf("search 配置名称重复: %s", config.Name)
}
state.profiles[config.Name] = config
state.order = append(state.order, config.Name)
if config.Active && state.activeName == "" {
state.activeName = config.Name
}
}
if len(state.order) == 0 {
return state, nil
}
if state.activeName == "" {
state.activeName = state.order[0]
}
return state, nil
}
func (s *State) Enabled() bool {
if s == nil || s.cfg == nil || !s.cfg.Enabled || s.activeName == "" {
return false
}
profile := s.ActiveProfile()
return profile.Enabled
}
func (s *State) ActivationPrompt() string {
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.ActivationPrompt) == "" {
return defaultActivationPrompt
}
return strings.TrimSpace(s.cfg.ActivationPrompt)
}
func (s *State) ActiveProfile() ProfileConfig {
s.mu.RLock()
defer s.mu.RUnlock()
return s.profiles[s.activeName]
}
func (s *State) SwitchActive(name string) (ProfileConfig, error) {
name = strings.TrimSpace(name)
if name == "" {
return ProfileConfig{}, errors.New("搜索配置名称不能为空")
}
s.mu.Lock()
defer s.mu.Unlock()
profile, ok := s.profiles[name]
if !ok {
return ProfileConfig{}, fmt.Errorf("搜索配置不存在: %s", name)
}
s.activeName = name
return profile, nil
}
func (s *State) ListProfiles() ListResponse {
s.mu.RLock()
defer s.mu.RUnlock()
profiles := make([]ProfileConfig, 0, len(s.order))
for _, name := range s.order {
profile := s.profiles[name]
profile.APIKey = ""
profile.Active = name == s.activeName
profiles = append(profiles, profile)
}
return ListResponse{Active: s.activeName, Profiles: profiles}
}
func (s *State) Search(ctx context.Context, query string) ([]Result, ProfileConfig, error) {
if !s.Enabled() {
return nil, ProfileConfig{}, errors.New("联网搜索未启用,请先在 agents/search/config.yaml 中启用 search")
}
profile := s.ActiveProfile()
switch strings.ToLower(profile.Provider) {
case "duckduckgo", "ddg":
results, err := duckDuckGoSearch(ctx, profile, query)
return results, profile, err
case "brave":
results, err := braveWebSearch(ctx, profile, query)
return results, profile, err
default:
return nil, profile, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
}
}
func BuildResultContext(config ProfileConfig, query string, results []Result, routeReason string) string {
var b strings.Builder
fmt.Fprintf(&b, "工具路由调用了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider)
if strings.ToLower(config.Provider) == "duckduckgo" || strings.ToLower(config.Provider) == "ddg" {
fmt.Fprintln(&b, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。")
}
fmt.Fprintf(&b, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))
fmt.Fprintf(&b, "搜索词: %s\n", query)
if strings.TrimSpace(routeReason) != "" {
fmt.Fprintf(&b, "调用原因: %s\n", strings.TrimSpace(routeReason))
}
fmt.Fprintln(&b, "\n搜索结果:")
for i, r := range results {
fmt.Fprintf(&b, "%d. 标题: %s\n", i+1, strings.TrimSpace(r.Title))
if strings.TrimSpace(r.URL) != "" {
fmt.Fprintf(&b, " 链接: %s\n", strings.TrimSpace(r.URL))
}
if strings.TrimSpace(r.Description) != "" {
fmt.Fprintf(&b, " 摘要: %s\n", strings.TrimSpace(r.Description))
}
}
fmt.Fprintln(&b, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。")
return b.String()
}
func BuildErrorContext(query string, err error) string {
return fmt.Sprintf("工具路由尝试联网搜索但失败。用户问题:%s\n错误:%v\n请向用户说明联网搜索失败,不要编造搜索结果。", query, err)
}
func defaultConfig() Config {
return Config{
Enabled: true,
ActivationPrompt: defaultActivationPrompt,
Profiles: ProfileConfigs{defaultProfileConfig()},
}
}
func defaultProfileConfig() ProfileConfig {
return ProfileConfig{
Name: "duckduckgo",
Active: true,
Enabled: true,
Provider: "duckduckgo",
BaseURL: defaultBaseURL,
Count: defaultCount,
Timeout: defaultTimeout,
}
}
func normalizeConfig(cfg *Config) error {
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
cfg.ActivationPrompt = defaultActivationPrompt
} else {
cfg.ActivationPrompt = strings.TrimSpace(cfg.ActivationPrompt)
}
if len(cfg.Profiles) == 0 {
cfg.Profiles = ProfileConfigs{defaultProfileConfig()}
}
activeIndex := -1
seen := map[string]bool{}
for i := range cfg.Profiles {
profile := &cfg.Profiles[i]
profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider))
if profile.Provider == "" {
profile.Provider = "duckduckgo"
}
name := strings.TrimSpace(profile.Name)
if name == "" {
name = profile.Provider
if seen[name] {
name = fmt.Sprintf("%s-%d", profile.Provider, i+1)
}
}
profile.Name = name
if seen[name] {
return fmt.Errorf("search 配置名称重复: %s", name)
}
seen[name] = true
if strings.TrimSpace(profile.BaseURL) == "" {
switch profile.Provider {
case "duckduckgo", "ddg":
profile.BaseURL = defaultBaseURL
case "brave":
profile.BaseURL = "https://api.search.brave.com/res/v1/web/search"
default:
return fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
}
}
if profile.Count <= 0 {
profile.Count = defaultCount
}
if profile.Timeout <= 0 {
profile.Timeout = defaultTimeout
}
if profile.Active {
if activeIndex == -1 {
activeIndex = i
} else {
profile.Active = false
}
}
}
if activeIndex == -1 {
cfg.Profiles[0].Active = true
}
return nil
}
func applyEnv(cfg *Config) {
if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" {
for i := range cfg.Profiles {
if strings.ToLower(cfg.Profiles[i].Provider) == "brave" {
cfg.Profiles[i].APIKey = key
}
}
}
}
func duckDuckGoSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) {
searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
defer cancel()
u, err := url.Parse(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("搜索服务地址无效: %w", err)
}
q := u.Query()
q.Set("q", query)
q.Set("format", "json")
q.Set("no_html", "1")
q.Set("skip_disambig", "1")
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("创建搜索请求失败: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "aichat/1.0")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("联网搜索失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取搜索响应失败: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var parsed duckDuckGoResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("解析搜索响应失败: %w", err)
}
limit := config.Count
if limit <= 0 {
limit = defaultCount
}
results := make([]Result, 0, limit)
if strings.TrimSpace(parsed.Abstract) != "" {
title := strings.TrimSpace(parsed.Heading)
if title == "" {
title = strings.TrimSpace(parsed.AbstractSource)
}
if title == "" {
title = "DuckDuckGo 摘要"
}
results = append(results, Result{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)})
}
for _, item := range parsed.Infobox.Content {
if len(results) >= limit {
break
}
if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" {
continue
}
results = append(results, Result{Title: item.Label, Description: item.Value})
}
appendRelated := func(text, firstURL string) {
if len(results) >= limit || strings.TrimSpace(text) == "" {
return
}
title, desc := splitDuckDuckGoText(text)
results = append(results, Result{Title: title, URL: strings.TrimSpace(firstURL), Description: desc})
}
for _, topic := range parsed.RelatedTopics {
appendRelated(topic.Text, topic.FirstURL)
for _, nested := range topic.Topics {
appendRelated(nested.Text, nested.FirstURL)
}
if len(results) >= limit {
break
}
}
return results, nil
}
func splitDuckDuckGoText(text string) (string, string) {
text = strings.TrimSpace(text)
parts := strings.SplitN(text, " - ", 2)
if len(parts) == 2 {
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
}
runes := []rune(text)
if len(runes) > 42 {
return string(runes[:42]) + "...", text
}
return text, text
}
func braveWebSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) {
if config.APIKey == "" {
return nil, errors.New("Brave 搜索未配置 API Key,请设置 agents/search/config.yaml 中的 api_key 或环境变量 BRAVE_SEARCH_API_KEY")
}
searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
defer cancel()
u, err := url.Parse(config.BaseURL)
if err != nil {
return nil, fmt.Errorf("搜索服务地址无效: %w", err)
}
q := u.Query()
q.Set("q", query)
q.Set("count", fmt.Sprintf("%d", config.Count))
q.Set("search_lang", "zh-hans")
q.Set("country", "CN")
u.RawQuery = q.Encode()
req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("创建搜索请求失败: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", config.APIKey)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("联网搜索失败: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
if err != nil {
return nil, fmt.Errorf("读取搜索响应失败: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
var parsed braveSearchResponse
if err := json.Unmarshal(body, &parsed); err != nil {
return nil, fmt.Errorf("解析搜索响应失败: %w", err)
}
return parsed.Web.Results, nil
}
+113
View File
@@ -0,0 +1,113 @@
package search
import (
"context"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
)
func TestNormalizeConfigDefaults(t *testing.T) {
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "duckduckgo"}}}
if err := normalizeConfig(cfg); err != nil {
t.Fatal(err)
}
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
t.Fatal("activation prompt should be defaulted")
}
profile := cfg.Profiles[0]
if profile.Name != "duckduckgo" || !profile.Active || profile.BaseURL != defaultBaseURL || profile.Count != defaultCount || profile.Timeout != defaultTimeout {
t.Fatalf("unexpected profile: %#v", profile)
}
}
func TestNormalizeConfigDuplicateProfile(t *testing.T) {
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Name: "duck", Provider: "duckduckgo"}, {Name: " duck ", Provider: "duckduckgo"}}}
if err := normalizeConfig(cfg); err == nil {
t.Fatal("expected duplicate profile error")
}
}
func TestNormalizeConfigUnsupportedProvider(t *testing.T) {
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "unknown"}}}
if err := normalizeConfig(cfg); err == nil {
t.Fatal("expected unsupported provider error")
}
}
func TestApplyEnvBraveAPIKey(t *testing.T) {
t.Setenv("BRAVE_SEARCH_API_KEY", "from-env")
cfg := &Config{Profiles: ProfileConfigs{{Name: "brave", Provider: "brave", APIKey: "from-config"}}}
applyEnv(cfg)
if cfg.Profiles[0].APIKey != "from-env" {
t.Fatalf("api key = %q", cfg.Profiles[0].APIKey)
}
}
func TestListProfilesRedactsAPIKey(t *testing.T) {
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "brave", Active: true, Enabled: true, Provider: "brave", APIKey: "secret", BaseURL: "https://example.com", Count: 1, Timeout: 1}}})
if err != nil {
t.Fatal(err)
}
list := state.ListProfiles()
if len(list.Profiles) != 1 || list.Profiles[0].APIKey != "" || !list.Profiles[0].Active {
t.Fatalf("unexpected list: %#v", list)
}
}
func TestBuildResultContext(t *testing.T) {
text := BuildResultContext(ProfileConfig{Name: "duckduckgo", Provider: "duckduckgo"}, "最新消息", []Result{{Title: "标题", URL: "https://example.com", Description: "摘要"}}, "需要最新信息")
for _, want := range []string{"联网搜索", "最新消息", "标题", "https://example.com", "需要最新信息", "不要编造"} {
if !strings.Contains(text, want) {
t.Fatalf("context missing %q:\n%s", want, text)
}
}
}
func TestDuckDuckGoSearchParsesResults(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("q") != "golang" {
t.Fatalf("query = %s", r.URL.RawQuery)
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"Heading":"Go",
"Abstract":"Go is a language",
"AbstractURL":"https://go.dev",
"RelatedTopics":[{"Text":"Gopher - mascot","FirstURL":"https://go.dev/blog/gopher"}]
}`))
}))
defer server.Close()
results, err := duckDuckGoSearch(context.Background(), ProfileConfig{Provider: "duckduckgo", BaseURL: server.URL, Count: 2, Timeout: 1}, "golang")
if err != nil {
t.Fatal(err)
}
if len(results) != 2 || results[0].Title != "Go" || results[1].Title != "Gopher" {
t.Fatalf("unexpected results: %#v", results)
}
}
func TestBraveMissingAPIKey(t *testing.T) {
_, err := braveWebSearch(context.Background(), ProfileConfig{Provider: "brave", BaseURL: "https://example.com", Timeout: 1, Count: 1}, "query")
if err == nil || !strings.Contains(err.Error(), "API Key") {
t.Fatalf("unexpected error: %v", err)
}
}
func TestLoadConfigWritesLegacyProfiles(t *testing.T) {
dir := t.TempDir()
path := dir + string(os.PathSeparator) + "config.yaml"
cfg, err := LoadConfig(path, []ProfileConfig{{Name: "legacy", Active: true, Enabled: true, Provider: "duckduckgo", Count: 3, Timeout: 2}})
if err != nil {
t.Fatal(err)
}
if len(cfg.Profiles) != 1 || cfg.Profiles[0].Name != "legacy" {
t.Fatalf("unexpected cfg: %#v", cfg)
}
if _, err := os.Stat(path); err != nil {
t.Fatal(err)
}
}
@@ -22,6 +22,7 @@ import (
const (
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库
仅当用户询问数据库表记录字段时间状态内容统计最近/最早/某时间范围内的数据时返回 activate=true
当用户询问日程日程安排行程会议待办今天/明天/本周有什么安排时必须返回 activate=true并说明应查询 tab_calendar_events
普通知识问答代码问题闲聊联网搜索问题返回 activate=false
只返回 JSON: {"activate": true/false, "reason": "..."}`
defaultDatabaseName = "default"
@@ -1,6 +1,15 @@
package sqlquery
import "testing"
import (
"strings"
"testing"
)
func TestDefaultActivationPromptMentionsCalendarEvents(t *testing.T) {
if !strings.Contains(defaultActivationPrompt, "tab_calendar_events") {
t.Fatal("default activation prompt should mention tab_calendar_events for calendar queries")
}
}
func TestValidateReadOnlySQLAllowsSelectAndWith(t *testing.T) {
queries := []string{
+76
View File
@@ -0,0 +1,76 @@
package timeagent
import (
"fmt"
"strings"
"time"
)
const ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、明天、昨天、本周、本月、本年、最近、日程安排等相对时间表达时,应先调用此工具;如果后续还需要查数据库,可继续调用 sql。"
type Range struct {
Start time.Time
End time.Time
}
type Context struct {
Now time.Time
Today Range
Tomorrow Range
Yesterday Range
ThisWeek Range
ThisMonth Range
ThisYear Range
}
func Resolve(now time.Time) Context {
loc := now.Location()
today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)}
tomorrowStart := today.End
tomorrow := Range{Start: tomorrowStart, End: tomorrowStart.AddDate(0, 0, 1)}
yesterdayEnd := today.Start
yesterday := Range{Start: yesterdayEnd.AddDate(0, 0, -1), End: yesterdayEnd}
weekdayOffset := (int(now.Weekday()) + 6) % 7
weekStart := today.Start.AddDate(0, 0, -weekdayOffset)
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, loc)
yearStart := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, loc)
return Context{
Now: now,
Today: today,
Tomorrow: tomorrow,
Yesterday: yesterday,
ThisWeek: Range{Start: weekStart, End: weekStart.AddDate(0, 0, 7)},
ThisMonth: Range{Start: monthStart, End: monthStart.AddDate(0, 1, 0)},
ThisYear: Range{Start: yearStart, End: yearStart.AddDate(1, 0, 0)},
}
}
func BuildContext(ctx Context, routeReason string) string {
var b strings.Builder
b.WriteString("时间工具结果。后续工具必须优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n")
fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST"))
fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today))
fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow))
fmt.Fprintf(&b, "昨天:%s\n", FormatSQLRange(ctx.Yesterday))
fmt.Fprintf(&b, "本周:%s\n", FormatSQLRange(ctx.ThisWeek))
fmt.Fprintf(&b, "本月:%s\n", FormatSQLRange(ctx.ThisMonth))
fmt.Fprintf(&b, "本年:%s\n", FormatSQLRange(ctx.ThisYear))
b.WriteString("SQL 日期过滤建议:对日程/事件类表使用半开区间,例如 event_time >= start AND event_time < end;如果实际字段名不同,必须使用 schema 中存在的时间字段。\n")
if strings.TrimSpace(routeReason) != "" {
b.WriteString("激活原因:" + strings.TrimSpace(routeReason) + "\n")
}
return b.String()
}
func FormatDate(t time.Time) string {
return t.Format("2006-01-02")
}
func FormatSQLRange(r Range) string {
return fmt.Sprintf("start=%s, end_exclusive=%s", r.Start.Format("2006-01-02 15:04:05"), r.End.Format("2006-01-02 15:04:05"))
}
func startOfDay(t time.Time) time.Time {
y, m, d := t.Date()
return time.Date(y, m, d, 0, 0, 0, 0, t.Location())
}
+31
View File
@@ -0,0 +1,31 @@
package timeagent
import (
"strings"
"testing"
"time"
)
func TestResolveBuildsCalendarRanges(t *testing.T) {
loc := time.FixedZone("CST", 8*60*60)
ctx := Resolve(time.Date(2026, 6, 10, 13, 14, 15, 0, loc))
if got := FormatSQLRange(ctx.Today); got != "start=2026-06-10 00:00:00, end_exclusive=2026-06-11 00:00:00" {
t.Fatalf("today = %s", got)
}
if got := FormatSQLRange(ctx.ThisMonth); got != "start=2026-06-01 00:00:00, end_exclusive=2026-07-01 00:00:00" {
t.Fatalf("this month = %s", got)
}
if got := FormatSQLRange(ctx.ThisWeek); got != "start=2026-06-08 00:00:00, end_exclusive=2026-06-15 00:00:00" {
t.Fatalf("this week = %s", got)
}
}
func TestBuildContextIncludesSQLHints(t *testing.T) {
ctx := Resolve(time.Date(2026, 6, 10, 13, 14, 15, 0, time.UTC))
text := BuildContext(ctx, "需要日期范围")
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间", "需要日期范围"} {
if !strings.Contains(text, want) {
t.Fatalf("context missing %q:\n%s", want, text)
}
}
}
+613 -514
View File
File diff suppressed because it is too large Load Diff
+217
View File
@@ -0,0 +1,217 @@
package main
import (
"context"
"errors"
"strings"
"testing"
"time"
)
type fakeChatTool struct {
name string
description string
enabled bool
}
func (t fakeChatTool) Name() string { return t.name }
func (t fakeChatTool) Description() string { return t.description }
func (t fakeChatTool) Enabled() bool { return t.enabled }
func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) {
return nil, nil
}
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
changed, err := normalizeToolRouterConfig(cfg)
if err != nil {
t.Fatal(err)
}
if !changed {
t.Fatal("expected defaults to change config")
}
if cfg.ToolRouter.Timeout != defaultToolRouterTimeout {
t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout)
}
if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens {
t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens)
}
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
t.Fatal("system prompt should be defaulted")
}
if len(cfg.ToolRouter.Tools) != 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[1].Name != "search" || cfg.ToolRouter.Tools[2].Name != "sql" || !cfg.ToolRouter.Tools[0].Enabled || !cfg.ToolRouter.Tools[1].Enabled || !cfg.ToolRouter.Tools[2].Enabled {
t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools)
}
}
func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
cfg := &Config{ToolRouter: ToolRouterConfig{
Enabled: true,
Timeout: 1,
MaxTokens: 1,
SystemPrompt: "route",
Tools: []ToolRouteConfig{
{Name: "search", Enabled: true},
{Name: "sql", Enabled: true},
},
}}
changed, err := normalizeToolRouterConfig(cfg)
if err != nil {
t.Fatal(err)
}
if !changed {
t.Fatal("expected time tool to be added")
}
if len(cfg.ToolRouter.Tools) < 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[2].Name != "sql" {
t.Fatalf("unexpected tool order: %#v", cfg.ToolRouter.Tools)
}
}
func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
cfg := &Config{ToolRouter: ToolRouterConfig{
Enabled: true,
Timeout: 1,
MaxTokens: 1,
SystemPrompt: "route",
Tools: []ToolRouteConfig{
{Name: "sql", Enabled: true},
{Name: " SQL ", Enabled: true},
},
}}
_, err := normalizeToolRouterConfig(cfg)
if err == nil {
t.Fatal("expected duplicate tool error")
}
}
func TestParseToolRoutingDecision(t *testing.T) {
decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```")
if err != nil {
t.Fatal(err)
}
if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" {
t.Fatalf("unexpected decision: %#v", decision)
}
if decision.Reason != "总原因" {
t.Fatalf("reason = %q", decision.Reason)
}
if _, err := parseToolRoutingDecision("not json"); err == nil {
t.Fatal("expected malformed JSON error")
}
}
func TestFilterToolSelections(t *testing.T) {
tools := map[string]ChatTool{
"time": fakeChatTool{name: "time", enabled: true},
"sql": fakeChatTool{name: "sql", enabled: true},
"search": fakeChatTool{name: "search", enabled: true},
}
decision := ToolRoutingDecision{Tools: []ToolSelection{
{Name: "unknown", Reason: "ignore"},
{Name: "search", Reason: "second in config"},
{Name: "sql", Reason: "third in config"},
{Name: "time", Reason: "first in config"},
{Name: "sql", Reason: "duplicate"},
}}
selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}})
if len(selected) != 3 {
t.Fatalf("selected length = %d", len(selected))
}
if selected[0].Name != "time" || selected[0].Reason != "first in config" {
t.Fatalf("first selection = %#v", selected[0])
}
if selected[1].Name != "search" {
t.Fatalf("second selection = %#v", selected[1])
}
if selected[2].Name != "sql" || selected[2].Reason != "third in config" {
t.Fatalf("third selection = %#v", selected[2])
}
}
func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) {
messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}}
withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {})
if err != nil {
t.Fatal(err)
}
if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" {
t.Fatalf("unexpected messages: %#v", withTime)
}
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} {
if !strings.Contains(withTime[0].Content, want) {
t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content)
}
}
}
func TestBuildToolRouterPrompt(t *testing.T) {
cfg := &ToolRouterConfig{
SystemPrompt: "router",
Tools: []ToolRouteConfig{
{Name: "time", Enabled: true},
{Name: "sql", Enabled: true, Description: "configured sql"},
{Name: "search", Enabled: true},
},
}
tools := map[string]ChatTool{
"time": fakeChatTool{name: "time", description: "fallback time", enabled: true},
"sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true},
"search": fakeChatTool{name: "search", description: "fallback search", enabled: true},
}
prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools)
for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} {
if !strings.Contains(prompt, want) {
t.Fatalf("prompt missing %q:\n%s", want, prompt)
}
}
}
func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) {
ai := &OpenAIState{
profiles: map[string]*OpenAIProfile{
"chat": {Config: OpenAIConfig{Name: "chat"}},
"router": {Config: OpenAIConfig{Name: "router"}},
},
order: []string{"chat", "router"},
activeName: "chat",
}
state := &ToolRouterState{cfg: &ToolRouterConfig{
OpenAIName: "router",
Timeout: 7,
MaxTokens: 123,
SystemPrompt: "router prompt",
Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}},
}, ai: ai}
state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
if profile.Config.Name != "router" {
t.Fatalf("profile = %s", profile.Config.Name)
}
if maxTokens != 123 {
t.Fatalf("maxTokens = %d", maxTokens)
}
if timeout != 7*time.Second {
t.Fatalf("timeout = %s", timeout)
}
return `{"tools":[],"reason":"无需工具"}`, nil
}
decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
if err != nil {
t.Fatal(err)
}
if len(decision.Tools) != 0 || decision.Reason != "无需工具" {
t.Fatalf("unexpected decision: %#v", decision)
}
}
func TestRouteToolsCompleterError(t *testing.T) {
ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"}
state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai}
state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) {
return "", errors.New("boom")
}
_, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
if err == nil {
t.Fatal("expected completer error")
}
}
+51 -30
View File
@@ -158,7 +158,7 @@
align-items: center;
gap: 8px;
}
#btnClear, #btnPreset, #btnSearch {
#btnClear, #btnPreset {
background: none;
border: 1px solid var(--border);
color: var(--text-dim);
@@ -169,13 +169,7 @@
transition: all .15s;
}
#btnClear:hover { border-color: var(--danger); color: var(--danger); }
#btnPreset:hover, #btnSearch:hover { border-color: var(--accent); color: var(--accent); }
#btnSearch.active {
background: var(--accent-soft);
border-color: var(--accent);
color: var(--accent);
}
#btnPreset:hover { border-color: var(--accent); color: var(--accent); }
/* ── 消息区 ── */
#messages {
flex: 1;
@@ -292,6 +286,13 @@
overflow-x: auto;
}
.answer-text { display: inline; }
.token-stats {
margin-top: 8px;
color: var(--text-dim);
font-size: 0.76rem;
white-space: normal;
}
.token-stats:empty { display: none; }
/* 错误消息 */
.error-msg {
@@ -513,7 +514,6 @@
<option value="{{ .OpenAIName }}">{{ .Model }}</option>
</select>
<select id="searchSelect" class="model-badge profile-select" title="切换搜索源"></select>
<button id="btnSearch" title="开启后,本轮提问会先联网搜索">联网搜索:关</button>
<button id="btnPreset" title="设置预先提示词">预设</button>
<button id="btnClear" title="开始新对话">新对话</button>
</div>
@@ -554,7 +554,7 @@
</svg>
</button>
</div>
<p class="hint">Enter 发送 &nbsp;·&nbsp; Shift+Enter 换行 &nbsp;·&nbsp; 支持图片多模态</p>
<p class="hint">Enter 发送 &nbsp;·&nbsp; Shift+Enter 换行 &nbsp;·&nbsp; 支持图片多模态 &nbsp;·&nbsp; 工具路由会自动判断是否需要联网搜索</p>
</footer>
</div>
@@ -576,7 +576,6 @@
let history = []; // {role, content, image_url?}
let currentConvId = null;
let pending = false;
let webSearchEnabled = false;
let openAIProfiles = [];
let activeOpenAIName = '{{ .OpenAIName }}';
let searchProfiles = [];
@@ -591,7 +590,6 @@ const inputBox = document.getElementById('inputBox');
const btnSend = document.getElementById('btnSend');
const btnClear = document.getElementById('btnClear');
const btnPreset = document.getElementById('btnPreset');
const btnSearch = document.getElementById('btnSearch');
const modelSelect = document.getElementById('modelSelect');
const searchSelect = document.getElementById('searchSelect');
const btnNewChat = document.getElementById('btnNewChat');
@@ -666,16 +664,10 @@ function setInputDisabled(disabled) {
btnSend.disabled = disabled;
inputBox.disabled = disabled;
fileInput.disabled = disabled;
btnSearch.disabled = disabled;
modelSelect.disabled = disabled || openAIProfiles.length <= 1;
searchSelect.disabled = disabled || searchProfiles.length <= 1;
}
function updateSearchButton() {
btnSearch.classList.toggle('active', webSearchEnabled);
btnSearch.textContent = webSearchEnabled ? '联网搜索:开' : '联网搜索:关';
}
async function loadOpenAIProfiles() {
const res = await fetch('/api/openai');
if (!res.ok) {
@@ -869,13 +861,42 @@ function addAIBubble() {
const txt = document.createElement('span');
txt.className = 'answer-text';
const stats = document.createElement('div');
stats.className = 'token-stats';
bub.appendChild(trace);
bub.appendChild(txt);
bub.appendChild(stats);
row.appendChild(av);
row.appendChild(bub);
msgBox.appendChild(row);
scrollToBottom();
return { bub, txt, trace };
return { bub, txt, trace, stats };
}
function formatTokenStats(stats) {
if (!stats) return '';
const avgSpeed = typeof stats.completion_tokens_per_sec === 'number'
? stats.completion_tokens_per_sec.toFixed(1)
: '0.0';
const peakSpeed = typeof stats.peak_completion_tokens_per_sec === 'number'
? stats.peak_completion_tokens_per_sec.toFixed(1)
: '0.0';
const parts = [
`平均 ${avgSpeed} tokens/sec`,
`最高 ${peakSpeed} tokens/sec`,
`总 token ${stats.total_tokens || 0}`,
`输入 ${stats.prompt_tokens || 0}`,
`输出 ${stats.completion_tokens || 0}`,
];
const toolTokens = (stats.tool_prompt_tokens || 0) + (stats.tool_completion_tokens || 0);
if (toolTokens) parts.push(`工具 ${toolTokens}`);
if (stats.estimated) parts.push('本地估算');
return parts.join(' ');
}
function updateTokenStats(aiBubble, stats) {
if (!aiBubble.stats) return;
aiBubble.stats.textContent = formatTokenStats(stats);
}
function appendTrace(aiBubble, frame) {
@@ -893,6 +914,7 @@ function appendTrace(aiBubble, frame) {
if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`);
if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`);
if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`);
if (Array.isArray(data.tools) && data.tools.length) stats.push(`工具: ${data.tools.join(', ')}`);
if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''}`);
if (data.reason) stats.push(`原因: ${data.reason}`);
if (data.error) stats.push(`错误: ${data.error}`);
@@ -909,7 +931,7 @@ function appendTrace(aiBubble, frame) {
scrollToBottom();
}
async function streamChat(messages, aiBubble, webSearch = false) {
async function streamChat(messages, aiBubble) {
const txtEl = aiBubble.txt;
let full = '';
@@ -919,8 +941,7 @@ async function streamChat(messages, aiBubble, webSearch = false) {
body: JSON.stringify({
conversation_id: currentConvId,
messages,
web_search: webSearch,
openai_name: activeOpenAIName,
openai_name: activeOpenAIName,
}),
});
@@ -965,8 +986,14 @@ async function streamChat(messages, aiBubble, webSearch = false) {
if (delta) {
full += delta;
txtEl.innerHTML = renderMarkdown(full);
scrollToBottom();
}
updateTokenStats(aiBubble, parsed.stats);
scrollToBottom();
continue;
}
if (parsed.type === 'stats') {
updateTokenStats(aiBubble, parsed.stats);
scrollToBottom();
continue;
}
if (parsed.type === 'trace') {
@@ -1070,7 +1097,7 @@ async function sendMessage() {
clearImage();
aiBubble = addAIBubble();
const full = await streamChat(history, aiBubble, webSearchEnabled);
const full = await streamChat(history, aiBubble);
history.push({ role: 'assistant', content: full });
scrollToBottom();
await loadConversationList();
@@ -1145,11 +1172,6 @@ inputBox.addEventListener('keydown', e => {
});
btnSend.addEventListener('click', sendMessage);
btnSearch.addEventListener('click', () => {
if (pending) return;
webSearchEnabled = !webSearchEnabled;
updateSearchButton();
});
modelSelect.addEventListener('change', async () => {
if (pending) {
modelSelect.value = activeOpenAIName;
@@ -1218,7 +1240,6 @@ presetModal.addEventListener('click', e => {
});
// 自动聚焦 & 初始化
updateSearchButton();
loadOpenAIProfiles().catch(e => alert(e.message));
loadSearchProfiles().catch(e => alert(e.message));
loadConversationList();