Files
aichat/agents/sql/sql_query.go
T
2026-06-10 12:07:07 +08:00

765 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sqlquery
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"time"
"unicode/utf8"
_ "github.com/go-sql-driver/mysql"
"gopkg.in/yaml.v3"
_ "modernc.org/sqlite"
)
// Built-in defaults used when the plugin configuration omits a value.
const (
// defaultActivationPrompt is the prompt returned by State.ActivationPrompt and
// filled in by normalizeConfig when the configured prompt is blank. The text
// is a raw string and is sent as-is, so it must not be edited casually.
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
普通知识问答、代码问题、闲聊、联网搜索问题返回 activate=false。
只返回 JSON: {"activate": true/false, "reason": "..."}`
// defaultDatabaseName is the database name used when none is configured.
defaultDatabaseName = "default"
// defaultSQLiteDSN is the DSN of the database entry in the generated default config.
defaultSQLiteDSN = "file:data/app.db?mode=ro"
// defaultTimeout is the per-database timeout, in seconds.
defaultTimeout = 10
// defaultMaxRows caps the number of rows kept from one query.
defaultMaxRows = 50
// defaultMaxCellBytes caps the byte length of one formatted cell.
defaultMaxCellBytes = 4096
// defaultSchemaCacheSecond is how long, in seconds, the rendered schema text is reused.
defaultSchemaCacheSecond = 300
)
// Config is the top-level configuration of the SQL query plugin, loaded from
// YAML by LoadConfig and also serialisable as JSON.
type Config struct {
// Enabled turns the plugin on; NewState opens no databases when it is false.
Enabled bool `yaml:"enabled" json:"enabled"`
// ActivationPrompt overrides defaultActivationPrompt when non-blank.
ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"`
// DefaultDatabase names the database used when a request names none (or an unknown one).
DefaultDatabase string `yaml:"default_database" json:"default_database"`
// SchemaCacheSeconds is how long State.SchemaContext reuses its rendered text.
SchemaCacheSeconds int `yaml:"schema_cache_seconds" json:"schema_cache_seconds"`
// Databases lists every configured database connection.
Databases []DatabaseConfig `yaml:"databases" json:"databases"`
}
// DatabaseConfig describes one queryable database.
type DatabaseConfig struct {
// Name identifies the database; normalizeConfig rejects duplicates and
// generates "database-N" for blank names.
Name string `yaml:"name" json:"name"`
// Active marks the preferred database; normalizeConfig keeps at most one
// entry active and activates the first entry when none is marked.
Active bool `yaml:"active" json:"active"`
// Driver is "sqlite" or "mysql" (blank is normalised to "sqlite").
Driver string `yaml:"driver" json:"driver"`
// DSN is the driver connection string; it is excluded from JSON output.
DSN string `yaml:"dsn" json:"-"`
// Timeout bounds connection checks, schema reads and queries, in seconds.
Timeout int `yaml:"timeout" json:"timeout"`
// MaxRows caps the number of rows returned by one query.
MaxRows int `yaml:"max_rows" json:"max_rows"`
// MaxCellBytes caps the byte length of each formatted cell value.
MaxCellBytes int `yaml:"max_cell_bytes" json:"max_cell_bytes"`
// Schema holds the table include/exclude filters.
Schema SchemaConfig `yaml:"schema" json:"schema"`
}
// SchemaConfig filters which tables are exposed. Names are matched
// case-insensitively (see lowerSet).
type SchemaConfig struct {
// IncludeTables, when non-empty, is an allow-list of table names.
IncludeTables []string `yaml:"include_tables" json:"include_tables"`
// ExcludeTables lists table names that must never be exposed or queried.
ExcludeTables []string `yaml:"exclude_tables" json:"exclude_tables"`
}
// State is the runtime state of the plugin: the configuration, the opened
// database handles and a cache of the rendered schema description.
type State struct {
// cfg is the configuration the state was built from; may be nil.
cfg *Config
// dbs maps database name to its opened handle.
dbs map[string]*database
// order lists database names in the order they were opened by NewState.
order []string
// cacheMu guards cacheText and cacheAt.
cacheMu sync.Mutex
// cacheText is the last schema text produced by SchemaContext.
cacheText string
// cacheAt is when cacheText was stored.
cacheAt time.Time
}
// database pairs one database's configuration with its open connection pool.
type database struct {
// cfg is a copy of the configuration entry this handle was opened from.
cfg DatabaseConfig
// db is the database/sql handle used for schema reads and queries.
db *sql.DB
}
// QueryResult is the outcome of one read-only query, with every cell already
// rendered as text.
type QueryResult struct {
// Database is the name of the database the query ran against.
Database string `json:"database"`
// SQL is the statement that was executed.
SQL string `json:"sql"`
// Columns holds the result column names in order.
Columns []string `json:"columns"`
// Rows holds one string slice per returned row, aligned with Columns.
Rows [][]string `json:"rows"`
// Truncated is true when more rows were available than MaxRows allows.
Truncated bool `json:"truncated"`
// MaxRows is the row cap that was applied.
MaxRows int `json:"max_rows"`
}
// LoadConfig reads the SQL query plugin configuration from the YAML file at
// path and returns it normalised.
//
// If path does not exist, the default configuration is first written there
// (creating parent directories as needed) and then loaded like any other
// file. The generated file is created with mode 0600 because the
// configuration carries database DSNs, which may embed credentials.
//
// An error is returned when the file cannot be inspected, created, read or
// parsed, or when normalizeConfig rejects its contents.
func LoadConfig(path string) (*Config, error) {
	if _, err := os.Stat(path); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return nil, fmt.Errorf("检查 SQL 查询插件配置失败: %w", err)
		}
		// First run: materialise the defaults so the operator has a file to edit.
		cfg := defaultConfig()
		if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
			return nil, fmt.Errorf("创建 SQL 查询插件目录失败: %w", err)
		}
		data, err := yaml.Marshal(&cfg)
		if err != nil {
			return nil, fmt.Errorf("生成 SQL 查询插件配置失败: %w", err)
		}
		// Owner-only permissions: the file is expected to hold DSNs.
		if err := os.WriteFile(path, data, 0600); err != nil {
			return nil, fmt.Errorf("写入 SQL 查询插件配置失败: %w", err)
		}
	}

	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("读取 SQL 查询插件配置失败: %w", err)
	}
	var cfg Config
	if err := yaml.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("解析 SQL 查询插件配置失败: %w", err)
	}
	if err := normalizeConfig(&cfg); err != nil {
		return nil, err
	}
	return &cfg, nil
}
// NewState opens every database listed in cfg and returns the plugin state.
//
// When cfg is nil or disabled, an empty state is returned and nothing is
// opened. If any database fails to open, every pool opened so far is closed
// before the error is returned, so a failed start does not leak connections.
func NewState(cfg *Config) (*State, error) {
	state := &State{cfg: cfg, dbs: map[string]*database{}}
	if cfg == nil || !cfg.Enabled {
		return state, nil
	}
	for _, item := range cfg.Databases {
		db, err := openConfiguredDB(item)
		if err != nil {
			// Release the pools opened by earlier iterations; their close
			// errors are secondary to the failure being reported.
			for _, opened := range state.dbs {
				_ = opened.db.Close()
			}
			return nil, err
		}
		state.dbs[item.Name] = &database{cfg: item, db: db}
		state.order = append(state.order, item.Name)
	}
	return state, nil
}

// openConfiguredDB opens, tunes and verifies the connection pool for one
// configured database. On any failure the pool is closed before returning.
func openConfiguredDB(item DatabaseConfig) (*sql.DB, error) {
	db, err := sql.Open(driverName(item.Driver), item.DSN)
	if err != nil {
		return nil, fmt.Errorf("打开数据库 %s 失败: %w", item.Name, err)
	}
	if item.Driver == "sqlite" {
		// PRAGMA query_only applies to the connection it was issued on, so the
		// pool is pinned to one connection that is never recycled by age;
		// a replacement connection would come up without the pragma.
		db.SetMaxOpenConns(1)
		db.SetMaxIdleConns(1)
		db.SetConnMaxLifetime(0)
	} else {
		db.SetMaxOpenConns(5)
		db.SetMaxIdleConns(2)
		db.SetConnMaxLifetime(30 * time.Minute)
	}

	// A non-positive timeout would yield an already-expired context.
	seconds := item.Timeout
	if seconds <= 0 {
		seconds = defaultTimeout
	}
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(seconds)*time.Second)
	defer cancel()

	if err := db.PingContext(ctx); err != nil {
		db.Close()
		return nil, fmt.Errorf("连接数据库 %s 失败: %w", item.Name, err)
	}
	if item.Driver == "sqlite" {
		if _, err := db.ExecContext(ctx, "PRAGMA query_only = ON"); err != nil {
			db.Close()
			return nil, fmt.Errorf("启用 SQLite 只读模式失败: %w", err)
		}
	}
	return db, nil
}
// Close closes every opened database pool. It is safe to call on a nil
// State. All pools are closed even if some fail; the individual error
// messages are joined with "; " into a single error.
func (s *State) Close() error {
	if s == nil {
		return nil
	}
	var failures []string
	for _, handle := range s.dbs {
		err := handle.db.Close()
		if err == nil {
			continue
		}
		failures = append(failures, err.Error())
	}
	if len(failures) == 0 {
		return nil
	}
	return errors.New(strings.Join(failures, "; "))
}
// Enabled reports whether the plugin is usable: the state and its
// configuration exist, the configuration is enabled, and at least one
// database has been opened.
func (s *State) Enabled() bool {
	if s == nil || s.cfg == nil {
		return false
	}
	return s.cfg.Enabled && len(s.dbs) > 0
}
// ActivationPrompt returns the configured activation prompt with surrounding
// whitespace removed, or defaultActivationPrompt when the state, its
// configuration or the configured prompt is missing or blank.
func (s *State) ActivationPrompt() string {
	if s != nil && s.cfg != nil {
		if prompt := strings.TrimSpace(s.cfg.ActivationPrompt); prompt != "" {
			return prompt
		}
	}
	return defaultActivationPrompt
}
// DefaultDatabase returns the configured default database name with
// surrounding whitespace removed, or defaultDatabaseName when the state, its
// configuration or the configured name is missing or blank.
func (s *State) DefaultDatabase() string {
	if s != nil && s.cfg != nil {
		if name := strings.TrimSpace(s.cfg.DefaultDatabase); name != "" {
			return name
		}
	}
	return defaultDatabaseName
}
// SchemaContext returns a textual description of every opened database and
// its tables, intended to be handed to the model that writes the SQL.
//
// The rendered text is cached for cfg.SchemaCacheSeconds. The mutex is held
// only while reading or writing the cache, never while databases are being
// queried, so concurrent callers that miss the cache may each rebuild it.
// An error is returned when the plugin is disabled or any schema read fails;
// nothing is cached in that case.
func (s *State) SchemaContext(ctx context.Context) (string, error) {
	if !s.Enabled() {
		return "", errors.New("SQL 查询插件未启用")
	}

	ttl := time.Duration(s.cfg.SchemaCacheSeconds) * time.Second
	s.cacheMu.Lock()
	cached := s.cacheText
	fresh := cached != "" && time.Since(s.cacheAt) < ttl
	s.cacheMu.Unlock()
	if fresh {
		return cached, nil
	}

	var out strings.Builder
	out.WriteString("可查询数据库列表(只能生成 SELECT/WITH 查询):\n")
	for _, name := range s.order {
		handle := s.dbs[name]
		fmt.Fprintf(&out, "\n数据库 %s,类型 %s,单次最多返回 %d 行:\n", handle.cfg.Name, handle.cfg.Driver, handle.cfg.MaxRows)
		schema, err := handle.schemaContext(ctx)
		if err != nil {
			return "", err
		}
		out.WriteString(schema)
	}
	rendered := out.String()

	s.cacheMu.Lock()
	s.cacheText, s.cacheAt = rendered, time.Now()
	s.cacheMu.Unlock()
	return rendered, nil
}
// ExecuteReadOnly validates query, runs it against the named database and
// returns the rows rendered as text.
//
// The query must pass ValidateReadOnlySQL and the database's table filters.
// Execution is bounded by the database's configured timeout. MySQL queries
// run inside a read-only transaction; other drivers query the pool directly.
// databaseByName resolves databaseName, falling back to the default database
// for blank or unknown names.
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
	if !s.Enabled() {
		return nil, errors.New("SQL 查询插件未启用")
	}
	if err := ValidateReadOnlySQL(query); err != nil {
		return nil, err
	}
	target := s.databaseByName(databaseName)
	if target == nil {
		return nil, fmt.Errorf("数据库配置不存在: %s", databaseName)
	}
	if err := target.rejectExcludedTables(query); err != nil {
		return nil, err
	}

	queryCtx, cancel := context.WithTimeout(ctx, time.Duration(target.cfg.Timeout)*time.Second)
	defer cancel()

	if target.cfg.Driver != "mysql" {
		rows, err := target.db.QueryContext(queryCtx, query)
		if err != nil {
			return nil, fmt.Errorf("执行 SQL 查询失败: %w", err)
		}
		return scanRows(rows, target.cfg, query)
	}

	tx, err := target.db.BeginTx(queryCtx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		return nil, fmt.Errorf("开启只读事务失败: %w", err)
	}
	// Covers every early return below; a rollback after a successful commit
	// only reports that the transaction is already done.
	defer tx.Rollback()

	rows, err := tx.QueryContext(queryCtx, query)
	if err != nil {
		return nil, fmt.Errorf("执行 SQL 查询失败: %w", err)
	}
	result, err := scanRows(rows, target.cfg, query)
	if err != nil {
		return nil, err
	}
	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("提交只读事务失败: %w", err)
	}
	return result, nil
}
// BuildResultContext renders a query result as prompt text for the answering
// model: the user's question, the database name, the SQL that ran, and the
// rows as a Markdown table (or a note that nothing matched).
//
// A nil result is treated as an empty result instead of panicking. When
// result.Truncated is set, a trailing note states the row cap.
func BuildResultContext(userQuery string, generatedSQL string, result *QueryResult) string {
	if result == nil {
		result = &QueryResult{}
	}
	var b strings.Builder
	b.WriteString("用户问题需要查询本地数据库。请仅根据以下 SQL 查询结果回答;不要编造结果中不存在的记录。\n")
	fmt.Fprintf(&b, "用户问题：%s\n", userQuery)
	fmt.Fprintf(&b, "数据库：%s\n", result.Database)
	// The separator keeps the label from running straight into the statement.
	fmt.Fprintf(&b, "已执行只读 SQL：%s\n", generatedSQL)
	if len(result.Columns) == 0 {
		b.WriteString("查询没有返回列。\n")
		return b.String()
	}
	if len(result.Rows) == 0 {
		b.WriteString("查询结果:没有匹配记录。\n")
		return b.String()
	}
	b.WriteString("查询结果:\n")
	b.WriteString(markdownTable(result.Columns, result.Rows))
	if result.Truncated {
		fmt.Fprintf(&b, "\n结果已按配置截断,只展示前 %d 行。\n", result.MaxRows)
	}
	return b.String()
}
// BuildErrorContext renders prompt text telling the answering model that the
// database query failed, quoting the error and the user's question, and
// instructing it not to invent records. err must be non-nil.
func BuildErrorContext(userQuery string, err error) string {
	const template = "用户问题可能需要查询本地数据库,但 SQL 查询插件执行失败：%s\n用户问题：%s\n请向用户说明无法完成数据库查询,不要编造数据库记录。"
	reason := err.Error()
	return fmt.Sprintf(template, reason, userQuery)
}
// ValidateReadOnlySQL performs a conservative, text-based check that query is
// a single read-only SELECT/WITH statement. It returns nil when the query is
// acceptable and a descriptive error otherwise.
//
// The checks are deliberately stricter than SQL requires, because they work
// on text rather than on a parse tree:
//   - comments are rejected outright ("--", "/*", "*/" and MySQL's "#"), even
//     inside string literals, since a comment can hide a quote character from
//     the literal stripper;
//   - backslashes are rejected because MySQL treats \' as an escaped quote
//     while SQLite does not, so the stripper cannot know where a literal ends;
//   - only one statement is allowed (a single trailing ";" is tolerated);
//   - the statement must start with SELECT or WITH;
//   - write/DDL/session keywords and risky functions are rejected when they
//     appear outside single-quoted literals.
func ValidateReadOnlySQL(query string) error {
	trimmed := strings.TrimSpace(query)
	if trimmed == "" {
		return errors.New("SQL 不能为空")
	}
	for _, marker := range []string{"--", "/*", "*/", "#"} {
		if strings.Contains(trimmed, marker) {
			return errors.New("SQL 不允许包含注释")
		}
	}
	if strings.Contains(trimmed, `\`) {
		return errors.New("SQL 不允许包含反斜杠")
	}
	body := strings.TrimSuffix(trimmed, ";")
	if strings.Contains(body, ";") {
		return errors.New("SQL 只允许单条语句")
	}
	upper := strings.ToUpper(body)
	first := firstToken(upper)
	if first != "SELECT" && first != "WITH" {
		return fmt.Errorf("SQL 只允许 SELECT/WITH 查询,当前为 %s", first)
	}
	stripped := stripSingleQuotedStrings(upper)
	forbidden := []string{
		"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE", "MERGE",
		"GRANT", "REVOKE", "VACUUM", "ANALYZE", "ATTACH", "DETACH", "LOAD", "CALL", "EXEC",
		"SET", "USE", "LOCK", "UNLOCK", "BEGIN", "COMMIT", "ROLLBACK",
	}
	for _, word := range forbidden {
		if hasSQLWord(stripped, word) {
			return fmt.Errorf("SQL 包含禁止关键词: %s", word)
		}
	}
	// Collapse every whitespace run to one space so that multi-word phrases
	// such as "INTO OUTFILE" cannot be smuggled past the check with extra
	// spaces, tabs or newlines between the words.
	normalized := strings.Join(strings.Fields(stripped), " ")
	risky := []string{"INTO OUTFILE", "INTO DUMPFILE", "LOAD_FILE", "SLEEP", "BENCHMARK", "LOAD_EXTENSION"}
	for _, phrase := range risky {
		if strings.Contains(normalized, phrase) {
			return fmt.Errorf("SQL 包含禁止函数或语法: %s", phrase)
		}
	}
	return nil
}
// defaultConfig returns the configuration written on first run: the plugin
// is disabled and a single active SQLite database is defined with the
// built-in defaults.
func defaultConfig() Config {
	primary := DatabaseConfig{
		Name:         defaultDatabaseName,
		Active:       true,
		Driver:       "sqlite",
		DSN:          defaultSQLiteDSN,
		Timeout:      defaultTimeout,
		MaxRows:      defaultMaxRows,
		MaxCellBytes: defaultMaxCellBytes,
	}
	cfg := Config{
		ActivationPrompt:   defaultActivationPrompt,
		DefaultDatabase:    defaultDatabaseName,
		SchemaCacheSeconds: defaultSchemaCacheSecond,
	}
	cfg.Enabled = false
	cfg.Databases = []DatabaseConfig{primary}
	return cfg
}
// normalizeConfig fills in defaults and validates cfg in place.
//
// It defaults the activation prompt, default database name and schema cache
// duration; supplies the default database list when none is configured;
// trims and validates every database entry (unique non-empty name, supported
// driver, non-empty DSN, positive limits, cleaned table filters); ensures
// exactly one entry is marked Active (the first marked one, or the first
// entry when none is); and points DefaultDatabase at the active entry when it
// names no configured database.
//
// It returns an error for duplicate names, unsupported drivers and missing
// DSNs. cfg must be non-nil.
func normalizeConfig(cfg *Config) error {
	if strings.TrimSpace(cfg.ActivationPrompt) == "" {
		cfg.ActivationPrompt = defaultActivationPrompt
	}
	// Database names are trimmed below, so the default must be trimmed too;
	// otherwise a valid name with stray whitespace never matches and is
	// silently replaced by the active database.
	cfg.DefaultDatabase = strings.TrimSpace(cfg.DefaultDatabase)
	if cfg.DefaultDatabase == "" {
		cfg.DefaultDatabase = defaultDatabaseName
	}
	if cfg.SchemaCacheSeconds <= 0 {
		cfg.SchemaCacheSeconds = defaultSchemaCacheSecond
	}
	if len(cfg.Databases) == 0 {
		cfg.Databases = defaultConfig().Databases
	}

	seen := make(map[string]bool, len(cfg.Databases))
	activeIndex := -1
	for i := range cfg.Databases {
		item := &cfg.Databases[i]

		item.Name = strings.TrimSpace(item.Name)
		if item.Name == "" {
			item.Name = fmt.Sprintf("database-%d", i+1)
		}
		if seen[item.Name] {
			return fmt.Errorf("SQL 查询插件数据库名称重复: %s", item.Name)
		}
		seen[item.Name] = true

		item.Driver = strings.ToLower(strings.TrimSpace(item.Driver))
		if item.Driver == "" {
			item.Driver = "sqlite"
		}
		if item.Driver != "sqlite" && item.Driver != "mysql" {
			return fmt.Errorf("SQL 查询插件暂不支持数据库类型: %s", item.Driver)
		}
		if strings.TrimSpace(item.DSN) == "" {
			return fmt.Errorf("数据库 %s 缺少 dsn", item.Name)
		}

		if item.Timeout <= 0 {
			item.Timeout = defaultTimeout
		}
		if item.MaxRows <= 0 {
			item.MaxRows = defaultMaxRows
		}
		if item.MaxCellBytes <= 0 {
			item.MaxCellBytes = defaultMaxCellBytes
		}
		item.Schema.IncludeTables = cleanList(item.Schema.IncludeTables)
		item.Schema.ExcludeTables = cleanList(item.Schema.ExcludeTables)

		// Only the first entry marked active stays active.
		if item.Active {
			if activeIndex == -1 {
				activeIndex = i
			} else {
				item.Active = false
			}
		}
	}
	if activeIndex == -1 {
		cfg.Databases[0].Active = true
	}

	if !seen[cfg.DefaultDatabase] {
		for _, item := range cfg.Databases {
			if item.Active {
				cfg.DefaultDatabase = item.Name
				break
			}
		}
	}
	return nil
}
// driverName maps a configured driver identifier to the name passed to
// sql.Open. Every identifier currently maps to itself; "sqlite" is listed
// explicitly as the one place where a different registration name would go.
func driverName(driver string) string {
	switch driver {
	case "sqlite":
		return "sqlite"
	default:
		return driver
	}
}
// databaseByName resolves name to an opened database handle. A blank or
// unknown name falls back to the default database; nil is returned only when
// the default database is not open either.
func (s *State) databaseByName(name string) *database {
	fallback := s.DefaultDatabase()
	key := strings.TrimSpace(name)
	if key == "" {
		key = fallback
	}
	if handle, ok := s.dbs[key]; ok && handle != nil {
		return handle
	}
	return s.dbs[fallback]
}
// schemaContext renders this database's table and column listing, bounded by
// the database's configured timeout. MySQL has its own reader; every other
// driver is read as SQLite.
func (d *database) schemaContext(ctx context.Context) (string, error) {
	bounded, cancel := context.WithTimeout(ctx, time.Duration(d.cfg.Timeout)*time.Second)
	defer cancel()

	switch d.cfg.Driver {
	case "mysql":
		return d.mysqlSchemaContext(bounded)
	default:
		return d.sqliteSchemaContext(bounded)
	}
}
// sqliteSchemaContext lists the SQLite tables that pass the include/exclude
// filters, followed by each table's columns (name, declared type and a
// primary-key marker). It returns a placeholder line when no table is
// visible, and an error when any catalog query fails.
func (d *database) sqliteSchemaContext(ctx context.Context) (string, error) {
	rows, err := d.db.QueryContext(ctx, `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`)
	if err != nil {
		return "", fmt.Errorf("读取 SQLite 表列表失败: %w", err)
	}
	defer rows.Close()
	var tables []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return "", err
		}
		if d.tableAllowed(name) {
			tables = append(tables, name)
		}
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	var b strings.Builder
	for _, table := range tables {
		fmt.Fprintf(&b, "- 表 %s\n", table)
		if err := d.writeSQLiteColumns(ctx, &b, table); err != nil {
			return "", err
		}
	}
	if len(tables) == 0 {
		b.WriteString("- 没有可查询表,或表被 include/exclude 规则过滤。\n")
	}
	return b.String(), nil
}

// writeSQLiteColumns appends one line per column of table to b, using
// PRAGMA table_info. Iteration errors are reported through Rows.Err, which
// Rows.Close does not surface, so a failed read cannot pass as a short
// column list.
func (d *database) writeSQLiteColumns(ctx context.Context, b *strings.Builder, table string) error {
	colRows, err := d.db.QueryContext(ctx, "PRAGMA table_info("+quoteSQLiteString(table)+")")
	if err != nil {
		return fmt.Errorf("读取 SQLite 表 %s 字段失败: %w", table, err)
	}
	defer colRows.Close()
	for colRows.Next() {
		var (
			cid          int
			name, typ    string
			notNull      int
			defaultValue any
			pk           int
		)
		if err := colRows.Scan(&cid, &name, &typ, &notNull, &defaultValue, &pk); err != nil {
			return err
		}
		extra := ""
		if pk > 0 {
			extra = " primary_key"
		}
		fmt.Fprintf(b, " - %s %s%s\n", name, typ, extra)
	}
	if err := colRows.Err(); err != nil {
		return err
	}
	return colRows.Close()
}
// mysqlSchemaContext lists the base tables of the current MySQL schema that
// pass the include/exclude filters, followed by each table's columns (name,
// data type, primary-key and not-null markers). It returns a placeholder
// line when no table is visible, and an error when any catalog query fails.
func (d *database) mysqlSchemaContext(ctx context.Context) (string, error) {
	rows, err := d.db.QueryContext(ctx, `SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE' ORDER BY table_name`)
	if err != nil {
		return "", fmt.Errorf("读取 MySQL 表列表失败: %w", err)
	}
	defer rows.Close()
	var tables []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return "", err
		}
		if d.tableAllowed(name) {
			tables = append(tables, name)
		}
	}
	if err := rows.Err(); err != nil {
		return "", err
	}

	var b strings.Builder
	for _, table := range tables {
		fmt.Fprintf(&b, "- 表 %s\n", table)
		if err := d.writeMySQLColumns(ctx, &b, table); err != nil {
			return "", err
		}
	}
	if len(tables) == 0 {
		b.WriteString("- 没有可查询表,或表被 include/exclude 规则过滤。\n")
	}
	return b.String(), nil
}

// writeMySQLColumns appends one line per column of table to b, read from
// information_schema.columns. Iteration errors are reported through
// Rows.Err, which Rows.Close does not surface, so a failed read cannot pass
// as a short column list.
func (d *database) writeMySQLColumns(ctx context.Context, b *strings.Builder, table string) error {
	colRows, err := d.db.QueryContext(ctx, `SELECT column_name, data_type, is_nullable, column_key FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ? ORDER BY ordinal_position`, table)
	if err != nil {
		return fmt.Errorf("读取 MySQL 表 %s 字段失败: %w", table, err)
	}
	defer colRows.Close()
	for colRows.Next() {
		var name, typ, nullable, key string
		if err := colRows.Scan(&name, &typ, &nullable, &key); err != nil {
			return err
		}
		extra := ""
		if key == "PRI" {
			extra = " primary_key"
		}
		if nullable == "NO" {
			extra += " not_null"
		}
		fmt.Fprintf(b, " - %s %s%s\n", name, typ, extra)
	}
	if err := colRows.Err(); err != nil {
		return err
	}
	return colRows.Close()
}
// tableAllowed reports whether table may appear in the schema listing. The
// comparison ignores case and surrounding whitespace. A non-empty include
// list acts as an allow-list; the exclude list always wins.
func (d *database) tableAllowed(table string) bool {
	key := strings.ToLower(strings.TrimSpace(table))
	if allow := lowerSet(d.cfg.Schema.IncludeTables); len(allow) > 0 && !allow[key] {
		return false
	}
	if lowerSet(d.cfg.Schema.ExcludeTables)[key] {
		return false
	}
	return true
}
// rejectExcludedTables applies the table filters to the text of query, with
// single-quoted literals blanked out first. It returns an error when an
// include list is configured and the query mentions none of its tables, or
// when the query mentions any excluded table as a whole word.
//
// NOTE(review): the include check only requires that at least one allowed
// table is mentioned; it does not prove that every table referenced by the
// query is on the list.
func (d *database) rejectExcludedTables(query string) error {
	text := strings.ToLower(stripSingleQuotedStrings(query))

	if allowed := lowerSet(d.cfg.Schema.IncludeTables); len(allowed) > 0 {
		found := false
		for name := range allowed {
			if hasSQLWord(text, name) {
				found = true
				break
			}
		}
		if !found {
			return errors.New("SQL 未访问 include_tables 中允许的表")
		}
	}

	for name := range lowerSet(d.cfg.Schema.ExcludeTables) {
		if hasSQLWord(text, name) {
			return fmt.Errorf("SQL 访问了被排除的表: %s", name)
		}
	}
	return nil
}
// scanRows drains rows into a QueryResult, rendering every cell with
// formatCell, and always closes rows before returning.
//
// At most cfg.MaxRows rows are kept. One further row is scanned to detect
// that more data exists, in which case Truncated is set and reading stops.
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("读取查询列失败: %w", err)
	}
	out := &QueryResult{
		Database: cfg.Name,
		SQL:      query,
		Columns:  columns,
		MaxRows:  cfg.MaxRows,
	}

	width := len(columns)
	for rows.Next() {
		cells := make([]any, width)
		dests := make([]any, width)
		for i := range cells {
			dests[i] = &cells[i]
		}
		if err := rows.Scan(dests...); err != nil {
			return nil, fmt.Errorf("读取查询结果失败: %w", err)
		}
		if len(out.Rows) >= cfg.MaxRows {
			out.Truncated = true
			break
		}
		rendered := make([]string, width)
		for i := range cells {
			rendered[i] = formatCell(cells[i], cfg.MaxCellBytes)
		}
		out.Rows = append(out.Rows, rendered)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("读取查询结果失败: %w", err)
	}
	return out, nil
}
// formatCell renders one scanned column value as text. nil becomes "NULL",
// times are formatted as RFC 3339, byte slices that are not valid UTF-8 are
// summarised as "<binary N bytes>", and everything else is stringified and
// truncated to maxBytes by truncateString.
func formatCell(value any, maxBytes int) string {
	if value == nil {
		return "NULL"
	}
	var text string
	switch v := value.(type) {
	case time.Time:
		// Timestamps are short and are returned untruncated.
		return v.Format(time.RFC3339)
	case []byte:
		if !utf8.Valid(v) {
			return fmt.Sprintf("<binary %d bytes>", len(v))
		}
		text = string(v)
	case string:
		text = v
	default:
		text = fmt.Sprint(v)
	}
	return truncateString(text, maxBytes)
}
// truncateString limits s to at most maxBytes bytes of content and appends
// "...<truncated>" when anything was cut. A non-positive maxBytes disables
// truncation.
//
// The cut never splits a UTF-8 sequence: it backs up to the nearest rune
// start, which is at most utf8.UTFMax-1 bytes away. Bounding the back-off
// keeps the work constant and means a stray invalid byte earlier in s cannot
// cause the whole prefix to be discarded.
func truncateString(s string, maxBytes int) string {
	if maxBytes <= 0 || len(s) <= maxBytes {
		return s
	}
	// s[cut] exists because len(s) > maxBytes. If it is a continuation byte,
	// the boundary falls inside a rune, so step back to that rune's start.
	cut := maxBytes
	for back := 0; back < utf8.UTFMax-1 && cut > 0 && !utf8.RuneStart(s[cut]); back++ {
		cut--
	}
	return s[:cut] + "...<truncated>"
}
// markdownTable renders columns and rows as a Markdown table: a header line,
// a "---" separator line and one line per row. Every cell is passed through
// escapeMarkdownCell. Rows are emitted as given, without padding or trimming
// to the column count.
func markdownTable(columns []string, rows [][]string) string {
	var out strings.Builder
	line := func(cells []string) {
		out.WriteString("| ")
		out.WriteString(strings.Join(cells, " | "))
		out.WriteString(" |\n")
	}

	header := make([]string, len(columns))
	rule := make([]string, len(columns))
	for i, col := range columns {
		header[i] = escapeMarkdownCell(col)
		rule[i] = "---"
	}
	line(header)
	line(rule)

	for _, row := range rows {
		escaped := make([]string, len(row))
		for i, cell := range row {
			escaped[i] = escapeMarkdownCell(cell)
		}
		line(escaped)
	}
	return out.String()
}
// escapeMarkdownCell makes s safe to place inside a Markdown table cell:
// pipes are backslash-escaped and each line break (CRLF, LF or CR) becomes a
// single space.
func escapeMarkdownCell(s string) string {
	// Pairs are tried in argument order at each position, so "\r\n" is
	// matched as one unit before a lone "\r" or "\n".
	cell := strings.NewReplacer(
		"|", "\\|",
		"\r\n", " ",
		"\n", " ",
		"\r", " ",
	)
	return cell.Replace(s)
}
// firstToken returns the first whitespace-delimited word of s with any
// leading or trailing "(", ")" and ";" removed, or "" when s has no words.
func firstToken(s string) string {
	if words := strings.Fields(s); len(words) > 0 {
		return strings.Trim(words[0], "();")
	}
	return ""
}
// stripSingleQuotedStrings blanks out the contents of single-quoted SQL
// string literals so that keyword and table-name checks only see real SQL.
//
// The quotes and every byte between them are replaced by spaces; a doubled
// quote ('') inside a literal is treated as an escaped quote and dropped.
// Text inside double quotes or backticks is copied through unchanged, and an
// apostrophe there is not treated as a literal delimiter. Without that rule,
// a query such as  SELECT "'", f(), "'"  would make the stripper believe
// f() sits inside a string and hide it from the checks.
//
// Backslash escapes are not interpreted.
func stripSingleQuotedStrings(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	var quote byte // 0 outside any quoted region, otherwise the opening quote byte
	for i := 0; i < len(s); i++ {
		ch := s[i]
		switch quote {
		case 0:
			switch ch {
			case '\'':
				quote = ch
				b.WriteByte(' ')
			case '"', '`':
				quote = ch
				b.WriteByte(ch)
			default:
				b.WriteByte(ch)
			}
		case '\'':
			if ch != '\'' {
				b.WriteByte(' ')
				continue
			}
			if i+1 < len(s) && s[i+1] == '\'' {
				i++ // escaped quote: stay inside the literal
				continue
			}
			quote = 0
			b.WriteByte(' ')
		default:
			// Inside "..." or `...`: keep the text visible to later checks.
			if ch == quote {
				quote = 0
			}
			b.WriteByte(ch)
		}
	}
	return b.String()
}
// sqlWordPatterns caches the compiled whole-word pattern for each word passed
// to hasSQLWord. The words come from the fixed keyword list and from the
// configured table names, so the cache stays small.
var sqlWordPatterns sync.Map // string -> *regexp.Regexp

// hasSQLWord reports whether word occurs in s as a whole word, ignoring
// case. A match must be bounded on both sides by the start or end of s or by
// a byte outside [A-Za-z0-9_]. word is matched literally, not as a pattern.
//
// The pattern for each word is compiled once and reused; the validator calls
// this for every keyword on every query.
func hasSQLWord(s string, word string) bool {
	if cached, ok := sqlWordPatterns.Load(word); ok {
		return cached.(*regexp.Regexp).MatchString(s)
	}
	// QuoteMeta guarantees a valid pattern, so MustCompile cannot panic here.
	re := regexp.MustCompile(`(?i)(^|[^a-zA-Z0-9_])` + regexp.QuoteMeta(word) + `([^a-zA-Z0-9_]|$)`)
	sqlWordPatterns.Store(word, re)
	return re.MatchString(s)
}
// quoteSQLiteString returns s as a single-quoted SQL string literal, with
// every embedded single quote doubled.
func quoteSQLiteString(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			b.WriteString("''")
			continue
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('\'')
	return b.String()
}
// lowerSet returns the set of items after trimming whitespace and
// lower-casing each one; blank entries are skipped. The result is never nil.
func lowerSet(items []string) map[string]bool {
	set := make(map[string]bool, len(items))
	for i := range items {
		if key := strings.ToLower(strings.TrimSpace(items[i])); key != "" {
			set[key] = true
		}
	}
	return set
}
// cleanList returns items trimmed, without blanks and without duplicates
// (compared case-insensitively; the first spelling is kept), sorted in byte
// order. It returns nil when nothing remains.
func cleanList(items []string) []string {
	var result []string
	known := make(map[string]struct{})
	for _, raw := range items {
		value := strings.TrimSpace(raw)
		if value == "" {
			continue
		}
		folded := strings.ToLower(value)
		if _, dup := known[folded]; dup {
			continue
		}
		known[folded] = struct{}{}
		result = append(result, value)
	}
	sort.Strings(result)
	return result
}