更新工具链
This commit is contained in:
@@ -14,10 +14,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolName = "search"
|
||||
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
|
||||
当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
|
||||
当用户询问“历史上的今天”、某日期历史事件、需要按当前日期动态确定查询词的常识资料时,也应调用 search;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
|
||||
@@ -91,6 +93,11 @@ type ListResponse struct {
|
||||
Profiles []ProfileConfig `json:"profiles"`
|
||||
}
|
||||
|
||||
type ToolArgs struct {
|
||||
Query string `json:"query"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type braveSearchResponse struct {
|
||||
Web struct {
|
||||
Results []Result `json:"results"`
|
||||
@@ -208,6 +215,53 @@ func (s *State) ActivationPrompt() string {
|
||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||
}
|
||||
|
||||
func (s *State) ToolDefinition(description string) *model.Tool {
|
||||
description = strings.TrimSpace(description)
|
||||
if description == "" {
|
||||
description = s.ActivationPrompt()
|
||||
}
|
||||
return &model.Tool{
|
||||
Type: model.ToolTypeFunction,
|
||||
Function: &model.FunctionDefinition{
|
||||
Name: ToolName,
|
||||
Description: description,
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"query": map[string]any{
|
||||
"type": "string",
|
||||
"description": "要联网搜索的关键词。若问题包含相对日期,应先调用 time 工具后使用绝对日期改写查询词。",
|
||||
},
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "调用联网搜索的原因。",
|
||||
},
|
||||
},
|
||||
"required": []string{"query"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) ExecuteTool(ctx context.Context, args string) (string, error) {
|
||||
var parsed ToolArgs
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
|
||||
return "", fmt.Errorf("解析搜索工具参数失败: %w", err)
|
||||
}
|
||||
query := strings.TrimSpace(parsed.Query)
|
||||
if query == "" {
|
||||
return "", errors.New("搜索关键词不能为空")
|
||||
}
|
||||
results, profile, err := s.Search(ctx, query)
|
||||
if err != nil {
|
||||
return BuildErrorContext(query, err), nil
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return BuildFallbackContext(profile, query, parsed.Reason, errors.New("未搜索到相关网页结果")), nil
|
||||
}
|
||||
return BuildResultContext(profile, query, results, parsed.Reason), nil
|
||||
}
|
||||
|
||||
func (s *State) ActiveProfile() ProfileConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
@@ -121,3 +121,42 @@ func TestLoadConfigWritesLegacyProfiles(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinitionAndExecuteTool(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("q") != "golang" {
|
||||
t.Fatalf("query = %s", r.URL.RawQuery)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{"Heading":"Go","Abstract":"Go language","AbstractURL":"https://go.dev"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: server.URL, Count: 1, Timeout: 1}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
definition := state.ToolDefinition("custom search")
|
||||
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom search" {
|
||||
t.Fatalf("unexpected definition: %#v", definition)
|
||||
}
|
||||
text, err := state.ExecuteTool(context.Background(), `{"query":"golang","reason":"测试搜索"}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, want := range []string{"联网搜索", "golang", "Go", "https://go.dev", "测试搜索"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("tool result missing %q:\n%s", want, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteToolRejectsEmptyQuery(t *testing.T) {
|
||||
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: defaultBaseURL, Count: 1, Timeout: 1}}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = state.ExecuteTool(context.Background(), `{"query":" "}`)
|
||||
if err == nil || !strings.Contains(err.Error(), "不能为空") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package sqlquery
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -15,11 +16,13 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"gopkg.in/yaml.v3"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolName = "sql"
|
||||
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
|
||||
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
|
||||
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
|
||||
@@ -80,6 +83,19 @@ type QueryResult struct {
|
||||
MaxRows int `json:"max_rows"`
|
||||
}
|
||||
|
||||
type ToolArgs struct {
|
||||
Question string `json:"question"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type SQLGenerator func(ctx context.Context, prompt string, maxTokens int) (string, error)
|
||||
|
||||
type GenerationResult struct {
|
||||
Database string `json:"database"`
|
||||
SQL string `json:"sql"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
@@ -180,6 +196,73 @@ func (s *State) ActivationPrompt() string {
|
||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||
}
|
||||
|
||||
func (s *State) ToolDefinition(description string) *model.Tool {
|
||||
description = strings.TrimSpace(description)
|
||||
if description == "" {
|
||||
description = s.ActivationPrompt()
|
||||
}
|
||||
return &model.Tool{
|
||||
Type: model.ToolTypeFunction,
|
||||
Function: &model.FunctionDefinition{
|
||||
Name: ToolName,
|
||||
Description: description,
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"question": map[string]any{
|
||||
"type": "string",
|
||||
"description": "需要查询数据库的问题。若已有 time 工具结果,应使用其中的绝对日期范围解释相对时间。",
|
||||
},
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "调用数据库查询工具的原因。",
|
||||
},
|
||||
},
|
||||
"required": []string{"question"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) ExecuteTool(ctx context.Context, args string, generator SQLGenerator) (string, error) {
|
||||
if generator == nil {
|
||||
return "", errors.New("SQL 生成器未配置")
|
||||
}
|
||||
var parsed ToolArgs
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
|
||||
return "", fmt.Errorf("解析 SQL 工具参数失败: %w", err)
|
||||
}
|
||||
question := strings.TrimSpace(parsed.Question)
|
||||
if question == "" {
|
||||
return "", errors.New("数据库查询问题不能为空")
|
||||
}
|
||||
schemaContext, err := s.SchemaContext(ctx)
|
||||
if err != nil {
|
||||
return BuildErrorContext(question, err), nil
|
||||
}
|
||||
generated, err := GenerateSQL(ctx, generator, question, schemaContext)
|
||||
if err != nil {
|
||||
return BuildErrorContext(question, err), nil
|
||||
}
|
||||
generated.Database = strings.TrimSpace(generated.Database)
|
||||
generated.SQL = strings.TrimSpace(generated.SQL)
|
||||
if generated.SQL == "" {
|
||||
return BuildErrorContext(question, fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)), nil
|
||||
}
|
||||
result, err := s.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
||||
if err != nil {
|
||||
return BuildErrorContext(question, err), nil
|
||||
}
|
||||
contextText := BuildResultContext(question, generated.SQL, result)
|
||||
if strings.TrimSpace(parsed.Reason) != "" {
|
||||
contextText += "\n调用原因:" + strings.TrimSpace(parsed.Reason)
|
||||
}
|
||||
if strings.TrimSpace(generated.Reason) != "" {
|
||||
contextText += "\nSQL 生成原因:" + strings.TrimSpace(generated.Reason)
|
||||
}
|
||||
return contextText, nil
|
||||
}
|
||||
|
||||
func (s *State) DefaultDatabase() string {
|
||||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" {
|
||||
return defaultDatabaseName
|
||||
@@ -220,6 +303,37 @@ func (s *State) SchemaContext(ctx context.Context) (string, error) {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func GenerateSQL(ctx context.Context, generator SQLGenerator, userQuery string, schemaContext string) (*GenerationResult, error) {
|
||||
prompt := BuildSQLGenerationPrompt(userQuery, schemaContext)
|
||||
text, err := generator(ctx, prompt, 1024)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var generated GenerationResult
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
||||
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
||||
}
|
||||
return &generated, nil
|
||||
}
|
||||
|
||||
func BuildSQLGenerationPrompt(userQuery string, schemaContext string) string {
|
||||
return fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、工具结果上下文和数据库 schema 生成一条只读 SQL。
|
||||
要求:
|
||||
- 只能返回 JSON,不要使用 Markdown。
|
||||
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
||||
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
||||
- 必须只使用 schema 中出现的数据库、表和字段。
|
||||
- 如果工具结果中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
|
||||
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
|
||||
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
|
||||
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
||||
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
||||
|
||||
%s
|
||||
|
||||
用户问题:%s`, schemaContext, userQuery)
|
||||
}
|
||||
|
||||
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
|
||||
if !s.Enabled() {
|
||||
return nil, errors.New("SQL 查询插件未启用")
|
||||
@@ -588,6 +702,16 @@ func (d *database) rejectExcludedTables(query string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractJSONObject(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start >= 0 && end > start {
|
||||
return text[start : end+1]
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
|
||||
defer rows.Close()
|
||||
columns, err := rows.Columns()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -43,3 +44,42 @@ func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinition(t *testing.T) {
|
||||
state := &State{}
|
||||
definition := state.ToolDefinition("custom sql")
|
||||
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" {
|
||||
t.Fatalf("unexpected definition: %#v", definition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) {
|
||||
prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)")
|
||||
for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} {
|
||||
if !strings.Contains(prompt, want) {
|
||||
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) {
|
||||
generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) {
|
||||
if !strings.Contains(prompt, "schema") || maxTokens != 1024 {
|
||||
t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens)
|
||||
}
|
||||
return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil
|
||||
}, "查事件", "schema")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" {
|
||||
t.Fatalf("unexpected generated SQL: %#v", generated)
|
||||
}
|
||||
|
||||
_, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) {
|
||||
return "not json", nil
|
||||
}, "查事件", "schema")
|
||||
if err == nil {
|
||||
t.Fatal("expected malformed JSON error")
|
||||
}
|
||||
}
|
||||
|
||||
+45
-2
@@ -1,12 +1,22 @@
|
||||
package timeagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
)
|
||||
|
||||
const ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近、历史上的今天、日程安排等相对时间表达时,应先调用此工具;如果后续还需要联网搜索或查数据库,可继续调用 search 或 sql。"
|
||||
const (
|
||||
ToolName = "time"
|
||||
ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近、历史上的今天、日程安排等相对时间表达时,应先调用此工具;如果后续还需要联网搜索或查数据库,可继续调用 search 或 sql。"
|
||||
)
|
||||
|
||||
type ToolArgs struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type Range struct {
|
||||
Start time.Time
|
||||
@@ -23,6 +33,39 @@ type Context struct {
|
||||
ThisYear Range
|
||||
}
|
||||
|
||||
func ToolDefinition(description string) *model.Tool {
|
||||
description = strings.TrimSpace(description)
|
||||
if description == "" {
|
||||
description = ActivationPrompt
|
||||
}
|
||||
return &model.Tool{
|
||||
Type: model.ToolTypeFunction,
|
||||
Function: &model.FunctionDefinition{
|
||||
Name: ToolName,
|
||||
Description: description,
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "调用时间工具的原因。",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ExecuteTool(args string, now time.Time) (string, error) {
|
||||
var parsed ToolArgs
|
||||
if strings.TrimSpace(args) != "" {
|
||||
if err := json.Unmarshal([]byte(args), &parsed); err != nil {
|
||||
return "", fmt.Errorf("解析时间工具参数失败: %w", err)
|
||||
}
|
||||
}
|
||||
return BuildContext(Resolve(now), parsed.Reason), nil
|
||||
}
|
||||
|
||||
func Resolve(now time.Time) Context {
|
||||
loc := now.Location()
|
||||
today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)}
|
||||
@@ -47,7 +90,7 @@ func Resolve(now time.Time) Context {
|
||||
|
||||
func BuildContext(ctx Context, routeReason string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("时间工具结果。后续工具必须优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n")
|
||||
b.WriteString("时间工具结果。请优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n")
|
||||
fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST"))
|
||||
fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today))
|
||||
fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow))
|
||||
|
||||
@@ -29,3 +29,19 @@ func TestBuildContextIncludesSQLHints(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinitionAndExecuteTool(t *testing.T) {
|
||||
definition := ToolDefinition("custom description")
|
||||
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom description" {
|
||||
t.Fatalf("unexpected definition: %#v", definition)
|
||||
}
|
||||
text, err := ExecuteTool(`{"reason":"测试原因"}`, time.Date(2026, 6, 10, 13, 14, 15, 0, time.UTC))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, want := range []string{"时间工具结果", "2026-06-10", "本月", "start=", "end_exclusive=", "测试原因"} {
|
||||
if !strings.Contains(text, want) {
|
||||
t.Fatalf("tool result missing %q:\n%s", want, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user