更新工具链
This commit is contained in:
@@ -3,6 +3,7 @@ package sqlquery
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -15,11 +16,13 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
"gopkg.in/yaml.v3"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
const (
|
||||
ToolName = "sql"
|
||||
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
|
||||
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
|
||||
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
|
||||
@@ -80,6 +83,19 @@ type QueryResult struct {
|
||||
MaxRows int `json:"max_rows"`
|
||||
}
|
||||
|
||||
type ToolArgs struct {
|
||||
Question string `json:"question"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type SQLGenerator func(ctx context.Context, prompt string, maxTokens int) (string, error)
|
||||
|
||||
type GenerationResult struct {
|
||||
Database string `json:"database"`
|
||||
SQL string `json:"sql"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
if _, err := os.Stat(path); err != nil {
|
||||
if !os.IsNotExist(err) {
|
||||
@@ -180,6 +196,73 @@ func (s *State) ActivationPrompt() string {
|
||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||
}
|
||||
|
||||
func (s *State) ToolDefinition(description string) *model.Tool {
|
||||
description = strings.TrimSpace(description)
|
||||
if description == "" {
|
||||
description = s.ActivationPrompt()
|
||||
}
|
||||
return &model.Tool{
|
||||
Type: model.ToolTypeFunction,
|
||||
Function: &model.FunctionDefinition{
|
||||
Name: ToolName,
|
||||
Description: description,
|
||||
Parameters: map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"question": map[string]any{
|
||||
"type": "string",
|
||||
"description": "需要查询数据库的问题。若已有 time 工具结果,应使用其中的绝对日期范围解释相对时间。",
|
||||
},
|
||||
"reason": map[string]any{
|
||||
"type": "string",
|
||||
"description": "调用数据库查询工具的原因。",
|
||||
},
|
||||
},
|
||||
"required": []string{"question"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *State) ExecuteTool(ctx context.Context, args string, generator SQLGenerator) (string, error) {
|
||||
if generator == nil {
|
||||
return "", errors.New("SQL 生成器未配置")
|
||||
}
|
||||
var parsed ToolArgs
|
||||
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
|
||||
return "", fmt.Errorf("解析 SQL 工具参数失败: %w", err)
|
||||
}
|
||||
question := strings.TrimSpace(parsed.Question)
|
||||
if question == "" {
|
||||
return "", errors.New("数据库查询问题不能为空")
|
||||
}
|
||||
schemaContext, err := s.SchemaContext(ctx)
|
||||
if err != nil {
|
||||
return BuildErrorContext(question, err), nil
|
||||
}
|
||||
generated, err := GenerateSQL(ctx, generator, question, schemaContext)
|
||||
if err != nil {
|
||||
return BuildErrorContext(question, err), nil
|
||||
}
|
||||
generated.Database = strings.TrimSpace(generated.Database)
|
||||
generated.SQL = strings.TrimSpace(generated.SQL)
|
||||
if generated.SQL == "" {
|
||||
return BuildErrorContext(question, fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)), nil
|
||||
}
|
||||
result, err := s.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
||||
if err != nil {
|
||||
return BuildErrorContext(question, err), nil
|
||||
}
|
||||
contextText := BuildResultContext(question, generated.SQL, result)
|
||||
if strings.TrimSpace(parsed.Reason) != "" {
|
||||
contextText += "\n调用原因:" + strings.TrimSpace(parsed.Reason)
|
||||
}
|
||||
if strings.TrimSpace(generated.Reason) != "" {
|
||||
contextText += "\nSQL 生成原因:" + strings.TrimSpace(generated.Reason)
|
||||
}
|
||||
return contextText, nil
|
||||
}
|
||||
|
||||
func (s *State) DefaultDatabase() string {
|
||||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" {
|
||||
return defaultDatabaseName
|
||||
@@ -220,6 +303,37 @@ func (s *State) SchemaContext(ctx context.Context) (string, error) {
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func GenerateSQL(ctx context.Context, generator SQLGenerator, userQuery string, schemaContext string) (*GenerationResult, error) {
|
||||
prompt := BuildSQLGenerationPrompt(userQuery, schemaContext)
|
||||
text, err := generator(ctx, prompt, 1024)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var generated GenerationResult
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
||||
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
||||
}
|
||||
return &generated, nil
|
||||
}
|
||||
|
||||
func BuildSQLGenerationPrompt(userQuery string, schemaContext string) string {
|
||||
return fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、工具结果上下文和数据库 schema 生成一条只读 SQL。
|
||||
要求:
|
||||
- 只能返回 JSON,不要使用 Markdown。
|
||||
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
||||
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
||||
- 必须只使用 schema 中出现的数据库、表和字段。
|
||||
- 如果工具结果中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
|
||||
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
|
||||
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
|
||||
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
||||
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
||||
|
||||
%s
|
||||
|
||||
用户问题:%s`, schemaContext, userQuery)
|
||||
}
|
||||
|
||||
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
|
||||
if !s.Enabled() {
|
||||
return nil, errors.New("SQL 查询插件未启用")
|
||||
@@ -588,6 +702,16 @@ func (d *database) rejectExcludedTables(query string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractJSONObject(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start >= 0 && end > start {
|
||||
return text[start : end+1]
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
|
||||
defer rows.Close()
|
||||
columns, err := rows.Columns()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package sqlquery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@@ -43,3 +44,42 @@ func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolDefinition(t *testing.T) {
|
||||
state := &State{}
|
||||
definition := state.ToolDefinition("custom sql")
|
||||
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" {
|
||||
t.Fatalf("unexpected definition: %#v", definition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) {
|
||||
prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)")
|
||||
for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} {
|
||||
if !strings.Contains(prompt, want) {
|
||||
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) {
|
||||
generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) {
|
||||
if !strings.Contains(prompt, "schema") || maxTokens != 1024 {
|
||||
t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens)
|
||||
}
|
||||
return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil
|
||||
}, "查事件", "schema")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" {
|
||||
t.Fatalf("unexpected generated SQL: %#v", generated)
|
||||
}
|
||||
|
||||
_, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) {
|
||||
return "not json", nil
|
||||
}, "查事件", "schema")
|
||||
if err == nil {
|
||||
t.Fatal("expected malformed JSON error")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user