更新工具链
This commit is contained in:
@@ -14,10 +14,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolName = "search"
|
||||
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
|
||||
当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
|
||||
当用户询问“历史上的今天”、某日期历史事件、需要按当前日期动态确定查询词的常识资料时,也应调用 search;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
|
||||
@@ -91,6 +93,11 @@ type ListResponse struct {
|
||||
Profiles []ProfileConfig `json:"profiles"`
|
||||
}
|
||||
|
||||
type ToolArgs struct {
|
||||
Query string `json:"query"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type braveSearchResponse struct {
|
||||
Web struct {
|
||||
Results []Result `json:"results"`
|
||||
@@ -208,6 +215,53 @@ func (s *State) ActivationPrompt() string {
|
||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||
}
|
||||
|
||||
func (s *State) ToolDefinition(description string) *model.Tool {
|
||||
description = strings.TrimSpace(description)
|
||||
if description == "" {
|
||||
description = s.ActivationPrompt()
|
||||
}
|
||||
return &model.Tool{
|
||||
Type: model.ToolTypeFunction,
|
||||
Function: &model.FunctionDefinition{
|
||||
Name: ToolName,
|
||||
Description: description,
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "要联网搜索的关键词。若问题包含相对日期,应先调用 time 工具后使用绝对日期改写查询词。",
|
||||
},
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "调用联网搜索的原因。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) ExecuteTool(ctx context.Context, args string) (string, error) {
|
||||
var parsed ToolArgs
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
|
||||
return "", fmt.Errorf("解析搜索工具参数失败: %w", err)
|
||||
}
|
||||
query := strings.TrimSpace(parsed.Query)
|
||||
if query == "" {
|
||||
return "", errors.New("搜索关键词不能为空")
|
||||
}
|
||||
results, profile, err := s.Search(ctx, query)
|
||||
if err != nil {
|
||||
return BuildErrorContext(query, err), nil
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return BuildFallbackContext(profile, query, parsed.Reason, errors.New("未搜索到相关网页结果")), nil
|
||||
}
|
||||
return BuildResultContext(profile, query, results, parsed.Reason), nil
|
||||
}
|
||||
|
||||
func (s *State) ActiveProfile() ProfileConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -121,3 +121,42 @@ func TestLoadConfigWritesLegacyProfiles(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinitionAndExecuteTool(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("q") != "golang" {
|
||||
t.Fatalf("query = %s", r.URL.RawQuery)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"Heading":"Go","Abstract":"Go language","AbstractURL":"https://go.dev"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: server.URL, Count: 1, Timeout: 1}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
definition := state.ToolDefinition("custom search")
|
||||
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom search" {
|
||||
t.Fatalf("unexpected definition: %#v", definition)
|
||||
}
|
||||
text, err := state.ExecuteTool(context.Background(), `{"query":"golang","reason":"测试搜索"}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, want := range []string{"联网搜索", "golang", "Go", "https://go.dev", "测试搜索"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("tool result missing %q:\n%s", want, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolRejectsEmptyQuery(t *testing.T) {
|
||||
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: defaultBaseURL, Count: 1, Timeout: 1}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = state.ExecuteTool(context.Background(), `{"query":" "}`)
|
||||
if err == nil || !strings.Contains(err.Error(), "不能为空") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user