支持工具链
This commit is contained in:
@@ -0,0 +1,520 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
|
||||
仅当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
|
||||
普通知识、闲聊、代码推理、已有上下文可回答的问题不要调用。`
|
||||
defaultBaseURL = "https://api.duckduckgo.com/"
|
||||
defaultTimeout = 10
|
||||
defaultCount = 5
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"`
|
||||
Profiles ProfileConfigs `yaml:"profiles" json:"profiles"`
|
||||
}
|
||||
|
||||
type ProfileConfig struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Active bool `yaml:"active,omitempty" json:"active"`
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Provider string `yaml:"provider" json:"provider"`
|
||||
APIKey string `yaml:"api_key" json:"-"`
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Count int `yaml:"count" json:"count"`
|
||||
Timeout int `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
type ProfileConfigs []ProfileConfig
|
||||
|
||||
func (configs *ProfileConfigs) UnmarshalYAML(value *yaml.Node) error {
|
||||
switch value.Kind {
|
||||
case yaml.SequenceNode:
|
||||
var items []ProfileConfig
|
||||
if err := value.Decode(&items); err != nil {
|
||||
return err
|
||||
}
|
||||
*configs = items
|
||||
case yaml.MappingNode:
|
||||
var item ProfileConfig
|
||||
if err := value.Decode(&item); err != nil {
|
||||
return err
|
||||
}
|
||||
*configs = []ProfileConfig{item}
|
||||
case yaml.ScalarNode:
|
||||
if value.Tag == "!!null" {
|
||||
*configs = nil
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("search 配置格式无效")
|
||||
default:
|
||||
return fmt.Errorf("search 配置格式无效")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type State struct {
|
||||
mu sync.RWMutex
|
||||
cfg *Config
|
||||
profiles map[string]ProfileConfig
|
||||
order []string
|
||||
activeName string
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type ListResponse struct {
|
||||
Active string `json:"active"`
|
||||
Profiles []ProfileConfig `json:"profiles"`
|
||||
}
|
||||
|
||||
type braveSearchResponse struct {
|
||||
Web struct {
|
||||
Results []Result `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
type duckDuckGoResponse struct {
|
||||
Abstract string `json:"Abstract"`
|
||||
AbstractSource string `json:"AbstractSource"`
|
||||
AbstractURL string `json:"AbstractURL"`
|
||||
Heading string `json:"Heading"`
|
||||
RelatedTopics []struct {
|
||||
Text string `json:"Text"`
|
||||
FirstURL string `json:"FirstURL"`
|
||||
Topics []struct {
|
||||
Text string `json:"Text"`
|
||||
FirstURL string `json:"FirstURL"`
|
||||
} `json:"Topics"`
|
||||
} `json:"RelatedTopics"`
|
||||
Infobox struct {
|
||||
Content []struct {
|
||||
Label string `json:"label"`
|
||||
Value string `json:"value"`
|
||||
} `json:"content"`
|
||||
} `json:"Infobox"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string, legacyProfiles ...[]ProfileConfig) (*Config, error) {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("检查联网搜索配置失败: %w", err)
|
||||
}
|
||||
cfg := defaultConfig()
|
||||
if len(legacyProfiles) > 0 && len(legacyProfiles[0]) > 0 {
|
||||
cfg.Profiles = legacyProfiles[0]
|
||||
}
|
||||
if err := normalizeConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建联网搜索配置目录失败: %w", err)
|
||||
}
|
||||
data, err := yaml.Marshal(&cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成联网搜索配置失败: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入联网搜索配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取联网搜索配置失败: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("解析联网搜索配置失败: %w", err)
|
||||
}
|
||||
if err := normalizeConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyEnv(&cfg)
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func NewState(cfg *Config) (*State, error) {
|
||||
state := &State{cfg: cfg, profiles: map[string]ProfileConfig{}}
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return state, nil
|
||||
}
|
||||
for _, config := range cfg.Profiles {
|
||||
if strings.TrimSpace(config.Name) == "" {
|
||||
return nil, errors.New("search.profiles.name 不能为空")
|
||||
}
|
||||
if strings.TrimSpace(config.Provider) == "" {
|
||||
return nil, fmt.Errorf("search.%s.provider 未配置", config.Name)
|
||||
}
|
||||
if strings.TrimSpace(config.BaseURL) == "" {
|
||||
return nil, fmt.Errorf("search.%s.base_url 未配置", config.Name)
|
||||
}
|
||||
if config.Timeout <= 0 {
|
||||
return nil, fmt.Errorf("search.%s.timeout 必须大于 0", config.Name)
|
||||
}
|
||||
if _, ok := state.profiles[config.Name]; ok {
|
||||
return nil, fmt.Errorf("search 配置名称重复: %s", config.Name)
|
||||
}
|
||||
state.profiles[config.Name] = config
|
||||
state.order = append(state.order, config.Name)
|
||||
if config.Active && state.activeName == "" {
|
||||
state.activeName = config.Name
|
||||
}
|
||||
}
|
||||
if len(state.order) == 0 {
|
||||
return state, nil
|
||||
}
|
||||
if state.activeName == "" {
|
||||
state.activeName = state.order[0]
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *State) Enabled() bool {
|
||||
if s == nil || s.cfg == nil || !s.cfg.Enabled || s.activeName == "" {
|
||||
return false
|
||||
}
|
||||
profile := s.ActiveProfile()
|
||||
return profile.Enabled
|
||||
}
|
||||
|
||||
func (s *State) ActivationPrompt() string {
|
||||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.ActivationPrompt) == "" {
|
||||
return defaultActivationPrompt
|
||||
}
|
||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||
}
|
||||
|
||||
func (s *State) ActiveProfile() ProfileConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.profiles[s.activeName]
|
||||
}
|
||||
|
||||
func (s *State) SwitchActive(name string) (ProfileConfig, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return ProfileConfig{}, errors.New("搜索配置名称不能为空")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
profile, ok := s.profiles[name]
|
||||
if !ok {
|
||||
return ProfileConfig{}, fmt.Errorf("搜索配置不存在: %s", name)
|
||||
}
|
||||
s.activeName = name
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func (s *State) ListProfiles() ListResponse {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
profiles := make([]ProfileConfig, 0, len(s.order))
|
||||
for _, name := range s.order {
|
||||
profile := s.profiles[name]
|
||||
profile.APIKey = ""
|
||||
profile.Active = name == s.activeName
|
||||
profiles = append(profiles, profile)
|
||||
}
|
||||
return ListResponse{Active: s.activeName, Profiles: profiles}
|
||||
}
|
||||
|
||||
func (s *State) Search(ctx context.Context, query string) ([]Result, ProfileConfig, error) {
|
||||
if !s.Enabled() {
|
||||
return nil, ProfileConfig{}, errors.New("联网搜索未启用,请先在 agents/search/config.yaml 中启用 search")
|
||||
}
|
||||
profile := s.ActiveProfile()
|
||||
switch strings.ToLower(profile.Provider) {
|
||||
case "duckduckgo", "ddg":
|
||||
results, err := duckDuckGoSearch(ctx, profile, query)
|
||||
return results, profile, err
|
||||
case "brave":
|
||||
results, err := braveWebSearch(ctx, profile, query)
|
||||
return results, profile, err
|
||||
default:
|
||||
return nil, profile, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResultContext(config ProfileConfig, query string, results []Result, routeReason string) string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "工具路由调用了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider)
|
||||
if strings.ToLower(config.Provider) == "duckduckgo" || strings.ToLower(config.Provider) == "ddg" {
|
||||
fmt.Fprintln(&b, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。")
|
||||
}
|
||||
fmt.Fprintf(&b, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))
|
||||
fmt.Fprintf(&b, "搜索词: %s\n", query)
|
||||
if strings.TrimSpace(routeReason) != "" {
|
||||
fmt.Fprintf(&b, "调用原因: %s\n", strings.TrimSpace(routeReason))
|
||||
}
|
||||
fmt.Fprintln(&b, "\n搜索结果:")
|
||||
for i, r := range results {
|
||||
fmt.Fprintf(&b, "%d. 标题: %s\n", i+1, strings.TrimSpace(r.Title))
|
||||
if strings.TrimSpace(r.URL) != "" {
|
||||
fmt.Fprintf(&b, " 链接: %s\n", strings.TrimSpace(r.URL))
|
||||
}
|
||||
if strings.TrimSpace(r.Description) != "" {
|
||||
fmt.Fprintf(&b, " 摘要: %s\n", strings.TrimSpace(r.Description))
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(&b, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func BuildErrorContext(query string, err error) string {
|
||||
return fmt.Sprintf("工具路由尝试联网搜索但失败。用户问题:%s\n错误:%v\n请向用户说明联网搜索失败,不要编造搜索结果。", query, err)
|
||||
}
|
||||
|
||||
func defaultConfig() Config {
|
||||
return Config{
|
||||
Enabled: true,
|
||||
ActivationPrompt: defaultActivationPrompt,
|
||||
Profiles: ProfileConfigs{defaultProfileConfig()},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultProfileConfig() ProfileConfig {
|
||||
return ProfileConfig{
|
||||
Name: "duckduckgo",
|
||||
Active: true,
|
||||
Enabled: true,
|
||||
Provider: "duckduckgo",
|
||||
BaseURL: defaultBaseURL,
|
||||
Count: defaultCount,
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeConfig(cfg *Config) error {
|
||||
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
|
||||
cfg.ActivationPrompt = defaultActivationPrompt
|
||||
} else {
|
||||
cfg.ActivationPrompt = strings.TrimSpace(cfg.ActivationPrompt)
|
||||
}
|
||||
if len(cfg.Profiles) == 0 {
|
||||
cfg.Profiles = ProfileConfigs{defaultProfileConfig()}
|
||||
}
|
||||
activeIndex := -1
|
||||
seen := map[string]bool{}
|
||||
for i := range cfg.Profiles {
|
||||
profile := &cfg.Profiles[i]
|
||||
profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider))
|
||||
if profile.Provider == "" {
|
||||
profile.Provider = "duckduckgo"
|
||||
}
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
name = profile.Provider
|
||||
if seen[name] {
|
||||
name = fmt.Sprintf("%s-%d", profile.Provider, i+1)
|
||||
}
|
||||
}
|
||||
profile.Name = name
|
||||
if seen[name] {
|
||||
return fmt.Errorf("search 配置名称重复: %s", name)
|
||||
}
|
||||
seen[name] = true
|
||||
if strings.TrimSpace(profile.BaseURL) == "" {
|
||||
switch profile.Provider {
|
||||
case "duckduckgo", "ddg":
|
||||
profile.BaseURL = defaultBaseURL
|
||||
case "brave":
|
||||
profile.BaseURL = "https://api.search.brave.com/res/v1/web/search"
|
||||
default:
|
||||
return fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
|
||||
}
|
||||
}
|
||||
if profile.Count <= 0 {
|
||||
profile.Count = defaultCount
|
||||
}
|
||||
if profile.Timeout <= 0 {
|
||||
profile.Timeout = defaultTimeout
|
||||
}
|
||||
if profile.Active {
|
||||
if activeIndex == -1 {
|
||||
activeIndex = i
|
||||
} else {
|
||||
profile.Active = false
|
||||
}
|
||||
}
|
||||
}
|
||||
if activeIndex == -1 {
|
||||
cfg.Profiles[0].Active = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyEnv(cfg *Config) {
|
||||
if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" {
|
||||
for i := range cfg.Profiles {
|
||||
if strings.ToLower(cfg.Profiles[i].Provider) == "brave" {
|
||||
cfg.Profiles[i].APIKey = key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func duckDuckGoSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) {
|
||||
searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
u, err := url.Parse(config.BaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索服务地址无效: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
q.Set("format", "json")
|
||||
q.Set("no_html", "1")
|
||||
q.Set("skip_disambig", "1")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建搜索请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "aichat/1.0")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("联网搜索失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取搜索响应失败: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed duckDuckGoResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("解析搜索响应失败: %w", err)
|
||||
}
|
||||
|
||||
limit := config.Count
|
||||
if limit <= 0 {
|
||||
limit = defaultCount
|
||||
}
|
||||
results := make([]Result, 0, limit)
|
||||
if strings.TrimSpace(parsed.Abstract) != "" {
|
||||
title := strings.TrimSpace(parsed.Heading)
|
||||
if title == "" {
|
||||
title = strings.TrimSpace(parsed.AbstractSource)
|
||||
}
|
||||
if title == "" {
|
||||
title = "DuckDuckGo 摘要"
|
||||
}
|
||||
results = append(results, Result{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)})
|
||||
}
|
||||
for _, item := range parsed.Infobox.Content {
|
||||
if len(results) >= limit {
|
||||
break
|
||||
}
|
||||
if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" {
|
||||
continue
|
||||
}
|
||||
results = append(results, Result{Title: item.Label, Description: item.Value})
|
||||
}
|
||||
appendRelated := func(text, firstURL string) {
|
||||
if len(results) >= limit || strings.TrimSpace(text) == "" {
|
||||
return
|
||||
}
|
||||
title, desc := splitDuckDuckGoText(text)
|
||||
results = append(results, Result{Title: title, URL: strings.TrimSpace(firstURL), Description: desc})
|
||||
}
|
||||
for _, topic := range parsed.RelatedTopics {
|
||||
appendRelated(topic.Text, topic.FirstURL)
|
||||
for _, nested := range topic.Topics {
|
||||
appendRelated(nested.Text, nested.FirstURL)
|
||||
}
|
||||
if len(results) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func splitDuckDuckGoText(text string) (string, string) {
|
||||
text = strings.TrimSpace(text)
|
||||
parts := strings.SplitN(text, " - ", 2)
|
||||
if len(parts) == 2 {
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
runes := []rune(text)
|
||||
if len(runes) > 42 {
|
||||
return string(runes[:42]) + "...", text
|
||||
}
|
||||
return text, text
|
||||
}
|
||||
|
||||
func braveWebSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) {
|
||||
if config.APIKey == "" {
|
||||
return nil, errors.New("Brave 搜索未配置 API Key,请设置 agents/search/config.yaml 中的 api_key 或环境变量 BRAVE_SEARCH_API_KEY")
|
||||
}
|
||||
searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
u, err := url.Parse(config.BaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索服务地址无效: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
q.Set("count", fmt.Sprintf("%d", config.Count))
|
||||
q.Set("search_lang", "zh-hans")
|
||||
q.Set("country", "CN")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建搜索请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", config.APIKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("联网搜索失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取搜索响应失败: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed braveSearchResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("解析搜索响应失败: %w", err)
|
||||
}
|
||||
return parsed.Web.Results, nil
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeConfigDefaults(t *testing.T) {
|
||||
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "duckduckgo"}}}
|
||||
if err := normalizeConfig(cfg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
|
||||
t.Fatal("activation prompt should be defaulted")
|
||||
}
|
||||
profile := cfg.Profiles[0]
|
||||
if profile.Name != "duckduckgo" || !profile.Active || profile.BaseURL != defaultBaseURL || profile.Count != defaultCount || profile.Timeout != defaultTimeout {
|
||||
t.Fatalf("unexpected profile: %#v", profile)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeConfigDuplicateProfile(t *testing.T) {
|
||||
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Name: "duck", Provider: "duckduckgo"}, {Name: " duck ", Provider: "duckduckgo"}}}
|
||||
if err := normalizeConfig(cfg); err == nil {
|
||||
t.Fatal("expected duplicate profile error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeConfigUnsupportedProvider(t *testing.T) {
|
||||
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "unknown"}}}
|
||||
if err := normalizeConfig(cfg); err == nil {
|
||||
t.Fatal("expected unsupported provider error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyEnvBraveAPIKey(t *testing.T) {
|
||||
t.Setenv("BRAVE_SEARCH_API_KEY", "from-env")
|
||||
cfg := &Config{Profiles: ProfileConfigs{{Name: "brave", Provider: "brave", APIKey: "from-config"}}}
|
||||
applyEnv(cfg)
|
||||
if cfg.Profiles[0].APIKey != "from-env" {
|
||||
t.Fatalf("api key = %q", cfg.Profiles[0].APIKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListProfilesRedactsAPIKey(t *testing.T) {
|
||||
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "brave", Active: true, Enabled: true, Provider: "brave", APIKey: "secret", BaseURL: "https://example.com", Count: 1, Timeout: 1}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
list := state.ListProfiles()
|
||||
if len(list.Profiles) != 1 || list.Profiles[0].APIKey != "" || !list.Profiles[0].Active {
|
||||
t.Fatalf("unexpected list: %#v", list)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildResultContext(t *testing.T) {
|
||||
text := BuildResultContext(ProfileConfig{Name: "duckduckgo", Provider: "duckduckgo"}, "最新消息", []Result{{Title: "标题", URL: "https://example.com", Description: "摘要"}}, "需要最新信息")
|
||||
for _, want := range []string{"联网搜索", "最新消息", "标题", "https://example.com", "需要最新信息", "不要编造"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("context missing %q:\n%s", want, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuckDuckGoSearchParsesResults(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("q") != "golang" {
|
||||
t.Fatalf("query = %s", r.URL.RawQuery)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"Heading":"Go",
|
||||
"Abstract":"Go is a language",
|
||||
"AbstractURL":"https://go.dev",
|
||||
"RelatedTopics":[{"Text":"Gopher - mascot","FirstURL":"https://go.dev/blog/gopher"}]
|
||||
}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
results, err := duckDuckGoSearch(context.Background(), ProfileConfig{Provider: "duckduckgo", BaseURL: server.URL, Count: 2, Timeout: 1}, "golang")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(results) != 2 || results[0].Title != "Go" || results[1].Title != "Gopher" {
|
||||
t.Fatalf("unexpected results: %#v", results)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBraveMissingAPIKey(t *testing.T) {
|
||||
_, err := braveWebSearch(context.Background(), ProfileConfig{Provider: "brave", BaseURL: "https://example.com", Timeout: 1, Count: 1}, "query")
|
||||
if err == nil || !strings.Contains(err.Error(), "API Key") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadConfigWritesLegacyProfiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := dir + string(os.PathSeparator) + "config.yaml"
|
||||
cfg, err := LoadConfig(path, []ProfileConfig{{Name: "legacy", Active: true, Enabled: true, Provider: "duckduckgo", Count: 3, Timeout: 2}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(cfg.Profiles) != 1 || cfg.Profiles[0].Name != "legacy" {
|
||||
t.Fatalf("unexpected cfg: %#v", cfg)
|
||||
}
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
const (
|
||||
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
|
||||
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
|
||||
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
|
||||
普通知识问答、代码问题、闲聊、联网搜索问题返回 activate=false。
|
||||
只返回 JSON: {"activate": true/false, "reason": "..."}`
|
||||
defaultDatabaseName = "default"
|
||||
@@ -1,6 +1,15 @@
|
||||
package sqlquery
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultActivationPromptMentionsCalendarEvents(t *testing.T) {
|
||||
if !strings.Contains(defaultActivationPrompt, "tab_calendar_events") {
|
||||
t.Fatal("default activation prompt should mention tab_calendar_events for calendar queries")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateReadOnlySQLAllowsSelectAndWith(t *testing.T) {
|
||||
queries := []string{
|
||||
@@ -0,0 +1,76 @@
|
||||
package timeagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、明天、昨天、本周、本月、本年、最近、日程安排等相对时间表达时,应先调用此工具;如果后续还需要查数据库,可继续调用 sql。"
|
||||
|
||||
type Range struct {
|
||||
Start time.Time
|
||||
End time.Time
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
Now time.Time
|
||||
Today Range
|
||||
Tomorrow Range
|
||||
Yesterday Range
|
||||
ThisWeek Range
|
||||
ThisMonth Range
|
||||
ThisYear Range
|
||||
}
|
||||
|
||||
func Resolve(now time.Time) Context {
|
||||
loc := now.Location()
|
||||
today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)}
|
||||
tomorrowStart := today.End
|
||||
tomorrow := Range{Start: tomorrowStart, End: tomorrowStart.AddDate(0, 0, 1)}
|
||||
yesterdayEnd := today.Start
|
||||
yesterday := Range{Start: yesterdayEnd.AddDate(0, 0, -1), End: yesterdayEnd}
|
||||
weekdayOffset := (int(now.Weekday()) + 6) % 7
|
||||
weekStart := today.Start.AddDate(0, 0, -weekdayOffset)
|
||||
monthStart := time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, loc)
|
||||
yearStart := time.Date(now.Year(), 1, 1, 0, 0, 0, 0, loc)
|
||||
return Context{
|
||||
Now: now,
|
||||
Today: today,
|
||||
Tomorrow: tomorrow,
|
||||
Yesterday: yesterday,
|
||||
ThisWeek: Range{Start: weekStart, End: weekStart.AddDate(0, 0, 7)},
|
||||
ThisMonth: Range{Start: monthStart, End: monthStart.AddDate(0, 1, 0)},
|
||||
ThisYear: Range{Start: yearStart, End: yearStart.AddDate(1, 0, 0)},
|
||||
}
|
||||
}
|
||||
|
||||
func BuildContext(ctx Context, routeReason string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("时间工具结果。后续工具必须优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n")
|
||||
fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST"))
|
||||
fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today))
|
||||
fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow))
|
||||
fmt.Fprintf(&b, "昨天:%s\n", FormatSQLRange(ctx.Yesterday))
|
||||
fmt.Fprintf(&b, "本周:%s\n", FormatSQLRange(ctx.ThisWeek))
|
||||
fmt.Fprintf(&b, "本月:%s\n", FormatSQLRange(ctx.ThisMonth))
|
||||
fmt.Fprintf(&b, "本年:%s\n", FormatSQLRange(ctx.ThisYear))
|
||||
b.WriteString("SQL 日期过滤建议:对日程/事件类表使用半开区间,例如 event_time >= start AND event_time < end;如果实际字段名不同,必须使用 schema 中存在的时间字段。\n")
|
||||
if strings.TrimSpace(routeReason) != "" {
|
||||
b.WriteString("激活原因:" + strings.TrimSpace(routeReason) + "\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func FormatDate(t time.Time) string {
|
||||
return t.Format("2006-01-02")
|
||||
}
|
||||
|
||||
func FormatSQLRange(r Range) string {
|
||||
return fmt.Sprintf("start=%s, end_exclusive=%s", r.Start.Format("2006-01-02 15:04:05"), r.End.Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
y, m, d := t.Date()
|
||||
return time.Date(y, m, d, 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package timeagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestResolveBuildsCalendarRanges(t *testing.T) {
|
||||
loc := time.FixedZone("CST", 8*60*60)
|
||||
ctx := Resolve(time.Date(2026, 6, 10, 13, 14, 15, 0, loc))
|
||||
if got := FormatSQLRange(ctx.Today); got != "start=2026-06-10 00:00:00, end_exclusive=2026-06-11 00:00:00" {
|
||||
t.Fatalf("today = %s", got)
|
||||
}
|
||||
if got := FormatSQLRange(ctx.ThisMonth); got != "start=2026-06-01 00:00:00, end_exclusive=2026-07-01 00:00:00" {
|
||||
t.Fatalf("this month = %s", got)
|
||||
}
|
||||
if got := FormatSQLRange(ctx.ThisWeek); got != "start=2026-06-08 00:00:00, end_exclusive=2026-06-15 00:00:00" {
|
||||
t.Fatalf("this week = %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildContextIncludesSQLHints(t *testing.T) {
|
||||
ctx := Resolve(time.Date(2026, 6, 10, 13, 14, 15, 0, time.UTC))
|
||||
text := BuildContext(ctx, "需要日期范围")
|
||||
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间", "需要日期范围"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("context missing %q:\n%s", want, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
+217
@@ -0,0 +1,217 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeChatTool struct {
|
||||
name string
|
||||
description string
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func (t fakeChatTool) Name() string { return t.name }
|
||||
func (t fakeChatTool) Description() string { return t.description }
|
||||
func (t fakeChatTool) Enabled() bool { return t.enabled }
|
||||
func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
|
||||
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
|
||||
changed, err := normalizeToolRouterConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatal("expected defaults to change config")
|
||||
}
|
||||
if cfg.ToolRouter.Timeout != defaultToolRouterTimeout {
|
||||
t.Fatalf("timeout = %d", cfg.ToolRouter.Timeout)
|
||||
}
|
||||
if cfg.ToolRouter.MaxTokens != defaultToolRouterMaxTokens {
|
||||
t.Fatalf("max_tokens = %d", cfg.ToolRouter.MaxTokens)
|
||||
}
|
||||
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
|
||||
t.Fatal("system prompt should be defaulted")
|
||||
}
|
||||
if len(cfg.ToolRouter.Tools) != 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[1].Name != "search" || cfg.ToolRouter.Tools[2].Name != "sql" || !cfg.ToolRouter.Tools[0].Enabled || !cfg.ToolRouter.Tools[1].Enabled || !cfg.ToolRouter.Tools[2].Enabled {
|
||||
t.Fatalf("unexpected tools: %#v", cfg.ToolRouter.Tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
|
||||
cfg := &Config{ToolRouter: ToolRouterConfig{
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 1,
|
||||
SystemPrompt: "route",
|
||||
Tools: []ToolRouteConfig{
|
||||
{Name: "search", Enabled: true},
|
||||
{Name: "sql", Enabled: true},
|
||||
},
|
||||
}}
|
||||
changed, err := normalizeToolRouterConfig(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !changed {
|
||||
t.Fatal("expected time tool to be added")
|
||||
}
|
||||
if len(cfg.ToolRouter.Tools) < 3 || cfg.ToolRouter.Tools[0].Name != "time" || cfg.ToolRouter.Tools[2].Name != "sql" {
|
||||
t.Fatalf("unexpected tool order: %#v", cfg.ToolRouter.Tools)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
|
||||
cfg := &Config{ToolRouter: ToolRouterConfig{
|
||||
Enabled: true,
|
||||
Timeout: 1,
|
||||
MaxTokens: 1,
|
||||
SystemPrompt: "route",
|
||||
Tools: []ToolRouteConfig{
|
||||
{Name: "sql", Enabled: true},
|
||||
{Name: " SQL ", Enabled: true},
|
||||
},
|
||||
}}
|
||||
_, err := normalizeToolRouterConfig(cfg)
|
||||
if err == nil {
|
||||
t.Fatal("expected duplicate tool error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolRoutingDecision(t *testing.T) {
|
||||
decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" {
|
||||
t.Fatalf("unexpected decision: %#v", decision)
|
||||
}
|
||||
if decision.Reason != "总原因" {
|
||||
t.Fatalf("reason = %q", decision.Reason)
|
||||
}
|
||||
|
||||
if _, err := parseToolRoutingDecision("not json"); err == nil {
|
||||
t.Fatal("expected malformed JSON error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterToolSelections(t *testing.T) {
|
||||
tools := map[string]ChatTool{
|
||||
"time": fakeChatTool{name: "time", enabled: true},
|
||||
"sql": fakeChatTool{name: "sql", enabled: true},
|
||||
"search": fakeChatTool{name: "search", enabled: true},
|
||||
}
|
||||
decision := ToolRoutingDecision{Tools: []ToolSelection{
|
||||
{Name: "unknown", Reason: "ignore"},
|
||||
{Name: "search", Reason: "second in config"},
|
||||
{Name: "sql", Reason: "third in config"},
|
||||
{Name: "time", Reason: "first in config"},
|
||||
{Name: "sql", Reason: "duplicate"},
|
||||
}}
|
||||
selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}})
|
||||
if len(selected) != 3 {
|
||||
t.Fatalf("selected length = %d", len(selected))
|
||||
}
|
||||
if selected[0].Name != "time" || selected[0].Reason != "first in config" {
|
||||
t.Fatalf("first selection = %#v", selected[0])
|
||||
}
|
||||
if selected[1].Name != "search" {
|
||||
t.Fatalf("second selection = %#v", selected[1])
|
||||
}
|
||||
if selected[2].Name != "sql" || selected[2].Reason != "third in config" {
|
||||
t.Fatalf("third selection = %#v", selected[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) {
|
||||
messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}}
|
||||
withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" {
|
||||
t.Fatalf("unexpected messages: %#v", withTime)
|
||||
}
|
||||
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} {
|
||||
if !strings.Contains(withTime[0].Content, want) {
|
||||
t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildToolRouterPrompt(t *testing.T) {
|
||||
cfg := &ToolRouterConfig{
|
||||
SystemPrompt: "router",
|
||||
Tools: []ToolRouteConfig{
|
||||
{Name: "time", Enabled: true},
|
||||
{Name: "sql", Enabled: true, Description: "configured sql"},
|
||||
{Name: "search", Enabled: true},
|
||||
},
|
||||
}
|
||||
tools := map[string]ChatTool{
|
||||
"time": fakeChatTool{name: "time", description: "fallback time", enabled: true},
|
||||
"sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true},
|
||||
"search": fakeChatTool{name: "search", description: "fallback search", enabled: true},
|
||||
}
|
||||
prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools)
|
||||
for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} {
|
||||
if !strings.Contains(prompt, want) {
|
||||
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) {
|
||||
ai := &OpenAIState{
|
||||
profiles: map[string]*OpenAIProfile{
|
||||
"chat": {Config: OpenAIConfig{Name: "chat"}},
|
||||
"router": {Config: OpenAIConfig{Name: "router"}},
|
||||
},
|
||||
order: []string{"chat", "router"},
|
||||
activeName: "chat",
|
||||
}
|
||||
state := &ToolRouterState{cfg: &ToolRouterConfig{
|
||||
OpenAIName: "router",
|
||||
Timeout: 7,
|
||||
MaxTokens: 123,
|
||||
SystemPrompt: "router prompt",
|
||||
Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}},
|
||||
}, ai: ai}
|
||||
state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
|
||||
if profile.Config.Name != "router" {
|
||||
t.Fatalf("profile = %s", profile.Config.Name)
|
||||
}
|
||||
if maxTokens != 123 {
|
||||
t.Fatalf("maxTokens = %d", maxTokens)
|
||||
}
|
||||
if timeout != 7*time.Second {
|
||||
t.Fatalf("timeout = %s", timeout)
|
||||
}
|
||||
return `{"tools":[],"reason":"无需工具"}`, nil
|
||||
}
|
||||
decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(decision.Tools) != 0 || decision.Reason != "无需工具" {
|
||||
t.Fatalf("unexpected decision: %#v", decision)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteToolsCompleterError(t *testing.T) {
|
||||
ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"}
|
||||
state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai}
|
||||
state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) {
|
||||
return "", errors.New("boom")
|
||||
}
|
||||
_, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
|
||||
if err == nil {
|
||||
t.Fatal("expected completer error")
|
||||
}
|
||||
}
|
||||
+50
-29
@@ -158,7 +158,7 @@
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
#btnClear, #btnPreset, #btnSearch {
|
||||
#btnClear, #btnPreset {
|
||||
background: none;
|
||||
border: 1px solid var(--border);
|
||||
color: var(--text-dim);
|
||||
@@ -169,13 +169,7 @@
|
||||
transition: all .15s;
|
||||
}
|
||||
#btnClear:hover { border-color: var(--danger); color: var(--danger); }
|
||||
#btnPreset:hover, #btnSearch:hover { border-color: var(--accent); color: var(--accent); }
|
||||
#btnSearch.active {
|
||||
background: var(--accent-soft);
|
||||
border-color: var(--accent);
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
#btnPreset:hover { border-color: var(--accent); color: var(--accent); }
|
||||
/* ── 消息区 ── */
|
||||
#messages {
|
||||
flex: 1;
|
||||
@@ -292,6 +286,13 @@
|
||||
overflow-x: auto;
|
||||
}
|
||||
.answer-text { display: inline; }
|
||||
.token-stats {
|
||||
margin-top: 8px;
|
||||
color: var(--text-dim);
|
||||
font-size: 0.76rem;
|
||||
white-space: normal;
|
||||
}
|
||||
.token-stats:empty { display: none; }
|
||||
|
||||
/* 错误消息 */
|
||||
.error-msg {
|
||||
@@ -513,7 +514,6 @@
|
||||
<option value="{{ .OpenAIName }}">{{ .Model }}</option>
|
||||
</select>
|
||||
<select id="searchSelect" class="model-badge profile-select" title="切换搜索源"></select>
|
||||
<button id="btnSearch" title="开启后,本轮提问会先联网搜索">联网搜索:关</button>
|
||||
<button id="btnPreset" title="设置预先提示词">预设</button>
|
||||
<button id="btnClear" title="开始新对话">新对话</button>
|
||||
</div>
|
||||
@@ -554,7 +554,7 @@
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
<p class="hint">Enter 发送 · Shift+Enter 换行 · 支持图片多模态</p>
|
||||
<p class="hint">Enter 发送 · Shift+Enter 换行 · 支持图片多模态 · 工具路由会自动判断是否需要联网搜索</p>
|
||||
</footer>
|
||||
</div>
|
||||
|
||||
@@ -576,7 +576,6 @@
|
||||
let history = []; // {role, content, image_url?}
|
||||
let currentConvId = null;
|
||||
let pending = false;
|
||||
let webSearchEnabled = false;
|
||||
let openAIProfiles = [];
|
||||
let activeOpenAIName = '{{ .OpenAIName }}';
|
||||
let searchProfiles = [];
|
||||
@@ -591,7 +590,6 @@ const inputBox = document.getElementById('inputBox');
|
||||
const btnSend = document.getElementById('btnSend');
|
||||
const btnClear = document.getElementById('btnClear');
|
||||
const btnPreset = document.getElementById('btnPreset');
|
||||
const btnSearch = document.getElementById('btnSearch');
|
||||
const modelSelect = document.getElementById('modelSelect');
|
||||
const searchSelect = document.getElementById('searchSelect');
|
||||
const btnNewChat = document.getElementById('btnNewChat');
|
||||
@@ -666,16 +664,10 @@ function setInputDisabled(disabled) {
|
||||
btnSend.disabled = disabled;
|
||||
inputBox.disabled = disabled;
|
||||
fileInput.disabled = disabled;
|
||||
btnSearch.disabled = disabled;
|
||||
modelSelect.disabled = disabled || openAIProfiles.length <= 1;
|
||||
searchSelect.disabled = disabled || searchProfiles.length <= 1;
|
||||
}
|
||||
|
||||
function updateSearchButton() {
|
||||
btnSearch.classList.toggle('active', webSearchEnabled);
|
||||
btnSearch.textContent = webSearchEnabled ? '联网搜索:开' : '联网搜索:关';
|
||||
}
|
||||
|
||||
async function loadOpenAIProfiles() {
|
||||
const res = await fetch('/api/openai');
|
||||
if (!res.ok) {
|
||||
@@ -869,13 +861,42 @@ function addAIBubble() {
|
||||
|
||||
const txt = document.createElement('span');
|
||||
txt.className = 'answer-text';
|
||||
const stats = document.createElement('div');
|
||||
stats.className = 'token-stats';
|
||||
bub.appendChild(trace);
|
||||
bub.appendChild(txt);
|
||||
bub.appendChild(stats);
|
||||
row.appendChild(av);
|
||||
row.appendChild(bub);
|
||||
msgBox.appendChild(row);
|
||||
scrollToBottom();
|
||||
return { bub, txt, trace };
|
||||
return { bub, txt, trace, stats };
|
||||
}
|
||||
|
||||
function formatTokenStats(stats) {
|
||||
if (!stats) return '';
|
||||
const avgSpeed = typeof stats.completion_tokens_per_sec === 'number'
|
||||
? stats.completion_tokens_per_sec.toFixed(1)
|
||||
: '0.0';
|
||||
const peakSpeed = typeof stats.peak_completion_tokens_per_sec === 'number'
|
||||
? stats.peak_completion_tokens_per_sec.toFixed(1)
|
||||
: '0.0';
|
||||
const parts = [
|
||||
`平均 ${avgSpeed} tokens/sec`,
|
||||
`最高 ${peakSpeed} tokens/sec`,
|
||||
`总 token ${stats.total_tokens || 0}`,
|
||||
`输入 ${stats.prompt_tokens || 0}`,
|
||||
`输出 ${stats.completion_tokens || 0}`,
|
||||
];
|
||||
const toolTokens = (stats.tool_prompt_tokens || 0) + (stats.tool_completion_tokens || 0);
|
||||
if (toolTokens) parts.push(`工具 ${toolTokens}`);
|
||||
if (stats.estimated) parts.push('本地估算');
|
||||
return parts.join(' | ');
|
||||
}
|
||||
|
||||
function updateTokenStats(aiBubble, stats) {
|
||||
if (!aiBubble.stats) return;
|
||||
aiBubble.stats.textContent = formatTokenStats(stats);
|
||||
}
|
||||
|
||||
function appendTrace(aiBubble, frame) {
|
||||
@@ -893,6 +914,7 @@ function appendTrace(aiBubble, frame) {
|
||||
if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`);
|
||||
if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`);
|
||||
if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`);
|
||||
if (Array.isArray(data.tools) && data.tools.length) stats.push(`工具: ${data.tools.join(', ')}`);
|
||||
if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''} 行`);
|
||||
if (data.reason) stats.push(`原因: ${data.reason}`);
|
||||
if (data.error) stats.push(`错误: ${data.error}`);
|
||||
@@ -909,7 +931,7 @@ function appendTrace(aiBubble, frame) {
|
||||
scrollToBottom();
|
||||
}
|
||||
|
||||
async function streamChat(messages, aiBubble, webSearch = false) {
|
||||
async function streamChat(messages, aiBubble) {
|
||||
const txtEl = aiBubble.txt;
|
||||
let full = '';
|
||||
|
||||
@@ -919,7 +941,6 @@ async function streamChat(messages, aiBubble, webSearch = false) {
|
||||
body: JSON.stringify({
|
||||
conversation_id: currentConvId,
|
||||
messages,
|
||||
web_search: webSearch,
|
||||
openai_name: activeOpenAIName,
|
||||
}),
|
||||
});
|
||||
@@ -965,8 +986,14 @@ async function streamChat(messages, aiBubble, webSearch = false) {
|
||||
if (delta) {
|
||||
full += delta;
|
||||
txtEl.innerHTML = renderMarkdown(full);
|
||||
scrollToBottom();
|
||||
}
|
||||
updateTokenStats(aiBubble, parsed.stats);
|
||||
scrollToBottom();
|
||||
continue;
|
||||
}
|
||||
if (parsed.type === 'stats') {
|
||||
updateTokenStats(aiBubble, parsed.stats);
|
||||
scrollToBottom();
|
||||
continue;
|
||||
}
|
||||
if (parsed.type === 'trace') {
|
||||
@@ -1070,7 +1097,7 @@ async function sendMessage() {
|
||||
clearImage();
|
||||
|
||||
aiBubble = addAIBubble();
|
||||
const full = await streamChat(history, aiBubble, webSearchEnabled);
|
||||
const full = await streamChat(history, aiBubble);
|
||||
history.push({ role: 'assistant', content: full });
|
||||
scrollToBottom();
|
||||
await loadConversationList();
|
||||
@@ -1145,11 +1172,6 @@ inputBox.addEventListener('keydown', e => {
|
||||
});
|
||||
|
||||
btnSend.addEventListener('click', sendMessage);
|
||||
btnSearch.addEventListener('click', () => {
|
||||
if (pending) return;
|
||||
webSearchEnabled = !webSearchEnabled;
|
||||
updateSearchButton();
|
||||
});
|
||||
modelSelect.addEventListener('change', async () => {
|
||||
if (pending) {
|
||||
modelSelect.value = activeOpenAIName;
|
||||
@@ -1218,7 +1240,6 @@ presetModal.addEventListener('click', e => {
|
||||
});
|
||||
|
||||
// 自动聚焦 & 初始化
|
||||
updateSearchButton();
|
||||
loadOpenAIProfiles().catch(e => alert(e.message));
|
||||
loadSearchProfiles().catch(e => alert(e.message));
|
||||
loadConversationList();
|
||||
|
||||
Reference in New Issue
Block a user