支持工具链
This commit is contained in:
@@ -0,0 +1,520 @@
|
||||
package search
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
|
||||
仅当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
|
||||
普通知识、闲聊、代码推理、已有上下文可回答的问题不要调用。`
|
||||
defaultBaseURL = "https://api.duckduckgo.com/"
|
||||
defaultTimeout = 10
|
||||
defaultCount = 5
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"`
|
||||
Profiles ProfileConfigs `yaml:"profiles" json:"profiles"`
|
||||
}
|
||||
|
||||
type ProfileConfig struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Active bool `yaml:"active,omitempty" json:"active"`
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Provider string `yaml:"provider" json:"provider"`
|
||||
APIKey string `yaml:"api_key" json:"-"`
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Count int `yaml:"count" json:"count"`
|
||||
Timeout int `yaml:"timeout" json:"timeout"`
|
||||
}
|
||||
|
||||
type ProfileConfigs []ProfileConfig
|
||||
|
||||
func (configs *ProfileConfigs) UnmarshalYAML(value *yaml.Node) error {
|
||||
switch value.Kind {
|
||||
case yaml.SequenceNode:
|
||||
var items []ProfileConfig
|
||||
if err := value.Decode(&items); err != nil {
|
||||
return err
|
||||
}
|
||||
*configs = items
|
||||
case yaml.MappingNode:
|
||||
var item ProfileConfig
|
||||
if err := value.Decode(&item); err != nil {
|
||||
return err
|
||||
}
|
||||
*configs = []ProfileConfig{item}
|
||||
case yaml.ScalarNode:
|
||||
if value.Tag == "!!null" {
|
||||
*configs = nil
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("search 配置格式无效")
|
||||
default:
|
||||
return fmt.Errorf("search 配置格式无效")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type State struct {
|
||||
mu sync.RWMutex
|
||||
cfg *Config
|
||||
profiles map[string]ProfileConfig
|
||||
order []string
|
||||
activeName string
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
type ListResponse struct {
|
||||
Active string `json:"active"`
|
||||
Profiles []ProfileConfig `json:"profiles"`
|
||||
}
|
||||
|
||||
type braveSearchResponse struct {
|
||||
Web struct {
|
||||
Results []Result `json:"results"`
|
||||
} `json:"web"`
|
||||
}
|
||||
|
||||
type duckDuckGoResponse struct {
|
||||
Abstract string `json:"Abstract"`
|
||||
AbstractSource string `json:"AbstractSource"`
|
||||
AbstractURL string `json:"AbstractURL"`
|
||||
Heading string `json:"Heading"`
|
||||
RelatedTopics []struct {
|
||||
Text string `json:"Text"`
|
||||
FirstURL string `json:"FirstURL"`
|
||||
Topics []struct {
|
||||
Text string `json:"Text"`
|
||||
FirstURL string `json:"FirstURL"`
|
||||
} `json:"Topics"`
|
||||
} `json:"RelatedTopics"`
|
||||
Infobox struct {
|
||||
Content []struct {
|
||||
Label string `json:"label"`
|
||||
Value string `json:"value"`
|
||||
} `json:"content"`
|
||||
} `json:"Infobox"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string, legacyProfiles ...[]ProfileConfig) (*Config, error) {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("检查联网搜索配置失败: %w", err)
|
||||
}
|
||||
cfg := defaultConfig()
|
||||
if len(legacyProfiles) > 0 && len(legacyProfiles[0]) > 0 {
|
||||
cfg.Profiles = legacyProfiles[0]
|
||||
}
|
||||
if err := normalizeConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建联网搜索配置目录失败: %w", err)
|
||||
}
|
||||
data, err := yaml.Marshal(&cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成联网搜索配置失败: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return nil, fmt.Errorf("写入联网搜索配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取联网搜索配置失败: %w", err)
|
||||
}
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("解析联网搜索配置失败: %w", err)
|
||||
}
|
||||
if err := normalizeConfig(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyEnv(&cfg)
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func NewState(cfg *Config) (*State, error) {
|
||||
state := &State{cfg: cfg, profiles: map[string]ProfileConfig{}}
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return state, nil
|
||||
}
|
||||
for _, config := range cfg.Profiles {
|
||||
if strings.TrimSpace(config.Name) == "" {
|
||||
return nil, errors.New("search.profiles.name 不能为空")
|
||||
}
|
||||
if strings.TrimSpace(config.Provider) == "" {
|
||||
return nil, fmt.Errorf("search.%s.provider 未配置", config.Name)
|
||||
}
|
||||
if strings.TrimSpace(config.BaseURL) == "" {
|
||||
return nil, fmt.Errorf("search.%s.base_url 未配置", config.Name)
|
||||
}
|
||||
if config.Timeout <= 0 {
|
||||
return nil, fmt.Errorf("search.%s.timeout 必须大于 0", config.Name)
|
||||
}
|
||||
if _, ok := state.profiles[config.Name]; ok {
|
||||
return nil, fmt.Errorf("search 配置名称重复: %s", config.Name)
|
||||
}
|
||||
state.profiles[config.Name] = config
|
||||
state.order = append(state.order, config.Name)
|
||||
if config.Active && state.activeName == "" {
|
||||
state.activeName = config.Name
|
||||
}
|
||||
}
|
||||
if len(state.order) == 0 {
|
||||
return state, nil
|
||||
}
|
||||
if state.activeName == "" {
|
||||
state.activeName = state.order[0]
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *State) Enabled() bool {
|
||||
if s == nil || s.cfg == nil || !s.cfg.Enabled || s.activeName == "" {
|
||||
return false
|
||||
}
|
||||
profile := s.ActiveProfile()
|
||||
return profile.Enabled
|
||||
}
|
||||
|
||||
func (s *State) ActivationPrompt() string {
|
||||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.ActivationPrompt) == "" {
|
||||
return defaultActivationPrompt
|
||||
}
|
||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||
}
|
||||
|
||||
func (s *State) ActiveProfile() ProfileConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.profiles[s.activeName]
|
||||
}
|
||||
|
||||
func (s *State) SwitchActive(name string) (ProfileConfig, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return ProfileConfig{}, errors.New("搜索配置名称不能为空")
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
profile, ok := s.profiles[name]
|
||||
if !ok {
|
||||
return ProfileConfig{}, fmt.Errorf("搜索配置不存在: %s", name)
|
||||
}
|
||||
s.activeName = name
|
||||
return profile, nil
|
||||
}
|
||||
|
||||
func (s *State) ListProfiles() ListResponse {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
profiles := make([]ProfileConfig, 0, len(s.order))
|
||||
for _, name := range s.order {
|
||||
profile := s.profiles[name]
|
||||
profile.APIKey = ""
|
||||
profile.Active = name == s.activeName
|
||||
profiles = append(profiles, profile)
|
||||
}
|
||||
return ListResponse{Active: s.activeName, Profiles: profiles}
|
||||
}
|
||||
|
||||
func (s *State) Search(ctx context.Context, query string) ([]Result, ProfileConfig, error) {
|
||||
if !s.Enabled() {
|
||||
return nil, ProfileConfig{}, errors.New("联网搜索未启用,请先在 agents/search/config.yaml 中启用 search")
|
||||
}
|
||||
profile := s.ActiveProfile()
|
||||
switch strings.ToLower(profile.Provider) {
|
||||
case "duckduckgo", "ddg":
|
||||
results, err := duckDuckGoSearch(ctx, profile, query)
|
||||
return results, profile, err
|
||||
case "brave":
|
||||
results, err := braveWebSearch(ctx, profile, query)
|
||||
return results, profile, err
|
||||
default:
|
||||
return nil, profile, fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func BuildResultContext(config ProfileConfig, query string, results []Result, routeReason string) string {
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "工具路由调用了联网搜索。当前搜索源: %s(%s)。请优先根据以下搜索结果回答,并在合适位置标注来源链接。\n", config.Name, config.Provider)
|
||||
if strings.ToLower(config.Provider) == "duckduckgo" || strings.ToLower(config.Provider) == "ddg" {
|
||||
fmt.Fprintln(&b, "注意:DuckDuckGo 即时答案不是全量网页搜索,结果可能较少。")
|
||||
}
|
||||
fmt.Fprintf(&b, "搜索时间: %s\n", time.Now().Format("2006-01-02 15:04:05"))
|
||||
fmt.Fprintf(&b, "搜索词: %s\n", query)
|
||||
if strings.TrimSpace(routeReason) != "" {
|
||||
fmt.Fprintf(&b, "调用原因: %s\n", strings.TrimSpace(routeReason))
|
||||
}
|
||||
fmt.Fprintln(&b, "\n搜索结果:")
|
||||
for i, r := range results {
|
||||
fmt.Fprintf(&b, "%d. 标题: %s\n", i+1, strings.TrimSpace(r.Title))
|
||||
if strings.TrimSpace(r.URL) != "" {
|
||||
fmt.Fprintf(&b, " 链接: %s\n", strings.TrimSpace(r.URL))
|
||||
}
|
||||
if strings.TrimSpace(r.Description) != "" {
|
||||
fmt.Fprintf(&b, " 摘要: %s\n", strings.TrimSpace(r.Description))
|
||||
}
|
||||
}
|
||||
fmt.Fprintln(&b, "\n如果搜索结果不足以回答,请明确说明不确定,不要编造。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func BuildErrorContext(query string, err error) string {
|
||||
return fmt.Sprintf("工具路由尝试联网搜索但失败。用户问题:%s\n错误:%v\n请向用户说明联网搜索失败,不要编造搜索结果。", query, err)
|
||||
}
|
||||
|
||||
func defaultConfig() Config {
|
||||
return Config{
|
||||
Enabled: true,
|
||||
ActivationPrompt: defaultActivationPrompt,
|
||||
Profiles: ProfileConfigs{defaultProfileConfig()},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultProfileConfig() ProfileConfig {
|
||||
return ProfileConfig{
|
||||
Name: "duckduckgo",
|
||||
Active: true,
|
||||
Enabled: true,
|
||||
Provider: "duckduckgo",
|
||||
BaseURL: defaultBaseURL,
|
||||
Count: defaultCount,
|
||||
Timeout: defaultTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeConfig(cfg *Config) error {
|
||||
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
|
||||
cfg.ActivationPrompt = defaultActivationPrompt
|
||||
} else {
|
||||
cfg.ActivationPrompt = strings.TrimSpace(cfg.ActivationPrompt)
|
||||
}
|
||||
if len(cfg.Profiles) == 0 {
|
||||
cfg.Profiles = ProfileConfigs{defaultProfileConfig()}
|
||||
}
|
||||
activeIndex := -1
|
||||
seen := map[string]bool{}
|
||||
for i := range cfg.Profiles {
|
||||
profile := &cfg.Profiles[i]
|
||||
profile.Provider = strings.ToLower(strings.TrimSpace(profile.Provider))
|
||||
if profile.Provider == "" {
|
||||
profile.Provider = "duckduckgo"
|
||||
}
|
||||
name := strings.TrimSpace(profile.Name)
|
||||
if name == "" {
|
||||
name = profile.Provider
|
||||
if seen[name] {
|
||||
name = fmt.Sprintf("%s-%d", profile.Provider, i+1)
|
||||
}
|
||||
}
|
||||
profile.Name = name
|
||||
if seen[name] {
|
||||
return fmt.Errorf("search 配置名称重复: %s", name)
|
||||
}
|
||||
seen[name] = true
|
||||
if strings.TrimSpace(profile.BaseURL) == "" {
|
||||
switch profile.Provider {
|
||||
case "duckduckgo", "ddg":
|
||||
profile.BaseURL = defaultBaseURL
|
||||
case "brave":
|
||||
profile.BaseURL = "https://api.search.brave.com/res/v1/web/search"
|
||||
default:
|
||||
return fmt.Errorf("暂不支持搜索服务: %s", profile.Provider)
|
||||
}
|
||||
}
|
||||
if profile.Count <= 0 {
|
||||
profile.Count = defaultCount
|
||||
}
|
||||
if profile.Timeout <= 0 {
|
||||
profile.Timeout = defaultTimeout
|
||||
}
|
||||
if profile.Active {
|
||||
if activeIndex == -1 {
|
||||
activeIndex = i
|
||||
} else {
|
||||
profile.Active = false
|
||||
}
|
||||
}
|
||||
}
|
||||
if activeIndex == -1 {
|
||||
cfg.Profiles[0].Active = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func applyEnv(cfg *Config) {
|
||||
if key := os.Getenv("BRAVE_SEARCH_API_KEY"); key != "" {
|
||||
for i := range cfg.Profiles {
|
||||
if strings.ToLower(cfg.Profiles[i].Provider) == "brave" {
|
||||
cfg.Profiles[i].APIKey = key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func duckDuckGoSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) {
|
||||
searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
u, err := url.Parse(config.BaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索服务地址无效: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
q.Set("format", "json")
|
||||
q.Set("no_html", "1")
|
||||
q.Set("skip_disambig", "1")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建搜索请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("User-Agent", "aichat/1.0")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("联网搜索失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取搜索响应失败: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed duckDuckGoResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("解析搜索响应失败: %w", err)
|
||||
}
|
||||
|
||||
limit := config.Count
|
||||
if limit <= 0 {
|
||||
limit = defaultCount
|
||||
}
|
||||
results := make([]Result, 0, limit)
|
||||
if strings.TrimSpace(parsed.Abstract) != "" {
|
||||
title := strings.TrimSpace(parsed.Heading)
|
||||
if title == "" {
|
||||
title = strings.TrimSpace(parsed.AbstractSource)
|
||||
}
|
||||
if title == "" {
|
||||
title = "DuckDuckGo 摘要"
|
||||
}
|
||||
results = append(results, Result{Title: title, URL: strings.TrimSpace(parsed.AbstractURL), Description: strings.TrimSpace(parsed.Abstract)})
|
||||
}
|
||||
for _, item := range parsed.Infobox.Content {
|
||||
if len(results) >= limit {
|
||||
break
|
||||
}
|
||||
if strings.TrimSpace(item.Label) == "" || strings.TrimSpace(item.Value) == "" {
|
||||
continue
|
||||
}
|
||||
results = append(results, Result{Title: item.Label, Description: item.Value})
|
||||
}
|
||||
appendRelated := func(text, firstURL string) {
|
||||
if len(results) >= limit || strings.TrimSpace(text) == "" {
|
||||
return
|
||||
}
|
||||
title, desc := splitDuckDuckGoText(text)
|
||||
results = append(results, Result{Title: title, URL: strings.TrimSpace(firstURL), Description: desc})
|
||||
}
|
||||
for _, topic := range parsed.RelatedTopics {
|
||||
appendRelated(topic.Text, topic.FirstURL)
|
||||
for _, nested := range topic.Topics {
|
||||
appendRelated(nested.Text, nested.FirstURL)
|
||||
}
|
||||
if len(results) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func splitDuckDuckGoText(text string) (string, string) {
|
||||
text = strings.TrimSpace(text)
|
||||
parts := strings.SplitN(text, " - ", 2)
|
||||
if len(parts) == 2 {
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
runes := []rune(text)
|
||||
if len(runes) > 42 {
|
||||
return string(runes[:42]) + "...", text
|
||||
}
|
||||
return text, text
|
||||
}
|
||||
|
||||
func braveWebSearch(ctx context.Context, config ProfileConfig, query string) ([]Result, error) {
|
||||
if config.APIKey == "" {
|
||||
return nil, errors.New("Brave 搜索未配置 API Key,请设置 agents/search/config.yaml 中的 api_key 或环境变量 BRAVE_SEARCH_API_KEY")
|
||||
}
|
||||
searchCtx, cancel := context.WithTimeout(ctx, time.Duration(config.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
u, err := url.Parse(config.BaseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("搜索服务地址无效: %w", err)
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("q", query)
|
||||
q.Set("count", fmt.Sprintf("%d", config.Count))
|
||||
q.Set("search_lang", "zh-hans")
|
||||
q.Set("country", "CN")
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
req, err := http.NewRequestWithContext(searchCtx, http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建搜索请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("X-Subscription-Token", config.APIKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("联网搜索失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 2*1024*1024))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取搜索响应失败: %w", err)
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("搜索服务返回错误 %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var parsed braveSearchResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("解析搜索响应失败: %w", err)
|
||||
}
|
||||
return parsed.Web.Results, nil
|
||||
}
|
||||
Reference in New Issue
Block a user