更新工具链

This commit is contained in:
2026-06-11 18:04:47 +08:00
parent 440f83f6a7
commit d1324dc2f2
9 changed files with 718 additions and 570 deletions
+124
View File
@@ -3,6 +3,7 @@ package sqlquery
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
@@ -15,11 +16,13 @@ import (
"unicode/utf8"
_ "github.com/go-sql-driver/mysql"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
"gopkg.in/yaml.v3"
_ "modernc.org/sqlite"
)
const (
ToolName = "sql"
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
@@ -80,6 +83,19 @@ type QueryResult struct {
MaxRows int `json:"max_rows"`
}
type ToolArgs struct {
Question string `json:"question"`
Reason string `json:"reason"`
}
type SQLGenerator func(ctx context.Context, prompt string, maxTokens int) (string, error)
type GenerationResult struct {
Database string `json:"database"`
SQL string `json:"sql"`
Reason string `json:"reason"`
}
func LoadConfig(path string) (*Config, error) {
if _, err := os.Stat(path); err != nil {
if !os.IsNotExist(err) {
@@ -180,6 +196,73 @@ func (s *State) ActivationPrompt() string {
return strings.TrimSpace(s.cfg.ActivationPrompt)
}
func (s *State) ToolDefinition(description string) *model.Tool {
description = strings.TrimSpace(description)
if description == "" {
description = s.ActivationPrompt()
}
return &model.Tool{
Type: model.ToolTypeFunction,
Function: &model.FunctionDefinition{
Name: ToolName,
Description: description,
Parameters: map[string]any{
"type": "object",
"properties": map[string]any{
"question": map[string]any{
"type": "string",
"description": "需要查询数据库的问题。若已有 time 工具结果,应使用其中的绝对日期范围解释相对时间。",
},
"reason": map[string]any{
"type": "string",
"description": "调用数据库查询工具的原因。",
},
},
"required": []string{"question"},
},
},
}
}
func (s *State) ExecuteTool(ctx context.Context, args string, generator SQLGenerator) (string, error) {
if generator == nil {
return "", errors.New("SQL 生成器未配置")
}
var parsed ToolArgs
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
return "", fmt.Errorf("解析 SQL 工具参数失败: %w", err)
}
question := strings.TrimSpace(parsed.Question)
if question == "" {
return "", errors.New("数据库查询问题不能为空")
}
schemaContext, err := s.SchemaContext(ctx)
if err != nil {
return BuildErrorContext(question, err), nil
}
generated, err := GenerateSQL(ctx, generator, question, schemaContext)
if err != nil {
return BuildErrorContext(question, err), nil
}
generated.Database = strings.TrimSpace(generated.Database)
generated.SQL = strings.TrimSpace(generated.SQL)
if generated.SQL == "" {
return BuildErrorContext(question, fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)), nil
}
result, err := s.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
if err != nil {
return BuildErrorContext(question, err), nil
}
contextText := BuildResultContext(question, generated.SQL, result)
if strings.TrimSpace(parsed.Reason) != "" {
contextText += "\n调用原因:" + strings.TrimSpace(parsed.Reason)
}
if strings.TrimSpace(generated.Reason) != "" {
contextText += "\nSQL 生成原因:" + strings.TrimSpace(generated.Reason)
}
return contextText, nil
}
func (s *State) DefaultDatabase() string {
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" {
return defaultDatabaseName
@@ -220,6 +303,37 @@ func (s *State) SchemaContext(ctx context.Context) (string, error) {
return text, nil
}
func GenerateSQL(ctx context.Context, generator SQLGenerator, userQuery string, schemaContext string) (*GenerationResult, error) {
prompt := BuildSQLGenerationPrompt(userQuery, schemaContext)
text, err := generator(ctx, prompt, 1024)
if err != nil {
return nil, err
}
var generated GenerationResult
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
}
return &generated, nil
}
func BuildSQLGenerationPrompt(userQuery string, schemaContext string) string {
return fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、工具结果上下文和数据库 schema 生成一条只读 SQL。
要求:
- 只能返回 JSON,不要使用 Markdown。
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
- 必须只使用 schema 中出现的数据库、表和字段。
- 如果工具结果中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
%s
用户问题:%s`, schemaContext, userQuery)
}
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
if !s.Enabled() {
return nil, errors.New("SQL 查询插件未启用")
@@ -588,6 +702,16 @@ func (d *database) rejectExcludedTables(query string) error {
return nil
}
func extractJSONObject(text string) string {
text = strings.TrimSpace(text)
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start >= 0 && end > start {
return text[start : end+1]
}
return text
}
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
defer rows.Close()
columns, err := rows.Columns()