765 lines
20 KiB
Go
765 lines
20 KiB
Go
package sqlquery
|
||
|
||
import (
|
||
"context"
|
||
"database/sql"
|
||
"errors"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"regexp"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
"unicode/utf8"
|
||
|
||
_ "github.com/go-sql-driver/mysql"
|
||
"gopkg.in/yaml.v3"
|
||
_ "modernc.org/sqlite"
|
||
)
|
||
|
||
const (
	// defaultActivationPrompt is the fallback instruction used to decide
	// whether a user question should trigger a database query. It asks the
	// model to reply with JSON of the form {"activate": bool, "reason": string}.
	defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
普通知识问答、代码问题、闲聊、联网搜索问题返回 activate=false。
只返回 JSON: {"activate": true/false, "reason": "..."}`

	// defaultDatabaseName names the generated default database entry and is
	// the fallback value of Config.DefaultDatabase.
	defaultDatabaseName = "default"

	// defaultSQLiteDSN opens data/app.db with SQLite's mode=ro URI parameter.
	defaultSQLiteDSN = "file:data/app.db?mode=ro"

	// defaultTimeout is the per-database timeout, in seconds.
	defaultTimeout = 10

	// defaultMaxRows caps how many result rows a single query returns.
	defaultMaxRows = 50

	// defaultMaxCellBytes caps the byte length of each formatted result cell.
	defaultMaxCellBytes = 4096

	// defaultSchemaCacheSecond is how long, in seconds, the rendered schema
	// description is reused before it is rebuilt.
	defaultSchemaCacheSecond = 300
)
|
||
|
||
// Config is the top-level SQL query plugin configuration. LoadConfig reads it
// from YAML and normalizeConfig fills in defaults.
type Config struct {
	// Enabled switches the plugin on; when false NewState opens no databases.
	Enabled bool `yaml:"enabled" json:"enabled"`
	// ActivationPrompt overrides defaultActivationPrompt when non-blank.
	ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"`
	// DefaultDatabase names the database used when a query does not name
	// one, or names one that was not opened.
	DefaultDatabase string `yaml:"default_database" json:"default_database"`
	// SchemaCacheSeconds is how long State.SchemaContext reuses its
	// rendered schema text.
	SchemaCacheSeconds int `yaml:"schema_cache_seconds" json:"schema_cache_seconds"`
	// Databases lists every database the plugin may open and query.
	Databases []DatabaseConfig `yaml:"databases" json:"databases"`
}
|
||
|
||
// DatabaseConfig describes one queryable database connection.
type DatabaseConfig struct {
	// Name identifies the database; normalizeConfig requires it to be unique.
	Name string `yaml:"name" json:"name"`
	// Active marks the preferred database. normalizeConfig keeps at most one
	// entry active and uses it as the DefaultDatabase fallback; NewState
	// opens every entry regardless of this flag.
	Active bool `yaml:"active" json:"active"`
	// Driver is "sqlite" or "mysql"; normalizeConfig rejects anything else.
	Driver string `yaml:"driver" json:"driver"`
	// DSN is the driver connection string. It is excluded from JSON output
	// (json:"-"), presumably so credentials are not exposed.
	DSN string `yaml:"dsn" json:"-"`
	// Timeout bounds connect, schema and query operations, in seconds.
	Timeout int `yaml:"timeout" json:"timeout"`
	// MaxRows caps how many rows a single query returns.
	MaxRows int `yaml:"max_rows" json:"max_rows"`
	// MaxCellBytes caps the byte length of each formatted cell.
	MaxCellBytes int `yaml:"max_cell_bytes" json:"max_cell_bytes"`
	// Schema filters which tables are described and may be queried.
	Schema SchemaConfig `yaml:"schema" json:"schema"`
}
|
||
|
||
// SchemaConfig restricts which tables of a database are exposed. Names are
// compared case-insensitively.
type SchemaConfig struct {
	// IncludeTables, when non-empty, limits the described schema to these
	// tables.
	IncludeTables []string `yaml:"include_tables" json:"include_tables"`
	// ExcludeTables hides these tables from the schema and rejects queries
	// that mention them.
	ExcludeTables []string `yaml:"exclude_tables" json:"exclude_tables"`
}
|
||
|
||
// State holds the opened database handles plus a cached rendering of their
// schemas. It is created by NewState and released with Close.
type State struct {
	cfg *Config
	// dbs maps a database name to its opened handle.
	dbs map[string]*database
	// order keeps database names in configuration order for stable output.
	order []string

	// cacheMu guards cacheText and cacheAt.
	cacheMu sync.Mutex
	// cacheText is the last rendered schema context; empty means "none yet".
	cacheText string
	// cacheAt records when cacheText was rendered.
	cacheAt time.Time
}
|
||
|
||
// database pairs one database's configuration with its opened *sql.DB.
type database struct {
	cfg DatabaseConfig
	db  *sql.DB
}
|
||
|
||
// QueryResult is the outcome of State.ExecuteReadOnly: the executed SQL and
// its result set, with every cell already rendered to a string.
type QueryResult struct {
	// Database is the name of the database the query ran against.
	Database string `json:"database"`
	// SQL is the query text that was executed.
	SQL string `json:"sql"`
	// Columns are the result column names, in order.
	Columns []string `json:"columns"`
	// Rows holds at most MaxRows rows of formatted cells.
	Rows [][]string `json:"rows"`
	// Truncated is true when more rows were available than MaxRows.
	Truncated bool `json:"truncated"`
	// MaxRows is the row cap that was applied.
	MaxRows int `json:"max_rows"`
}
|
||
|
||
func LoadConfig(path string) (*Config, error) {
|
||
if _, err := os.Stat(path); err != nil {
|
||
if !os.IsNotExist(err) {
|
||
return nil, fmt.Errorf("检查 SQL 查询插件配置失败: %w", err)
|
||
}
|
||
cfg := defaultConfig()
|
||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||
return nil, fmt.Errorf("创建 SQL 查询插件目录失败: %w", err)
|
||
}
|
||
data, err := yaml.Marshal(&cfg)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成 SQL 查询插件配置失败: %w", err)
|
||
}
|
||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||
return nil, fmt.Errorf("写入 SQL 查询插件配置失败: %w", err)
|
||
}
|
||
}
|
||
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取 SQL 查询插件配置失败: %w", err)
|
||
}
|
||
var cfg Config
|
||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||
return nil, fmt.Errorf("解析 SQL 查询插件配置失败: %w", err)
|
||
}
|
||
if err := normalizeConfig(&cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
return &cfg, nil
|
||
}
|
||
|
||
func NewState(cfg *Config) (*State, error) {
|
||
state := &State{cfg: cfg, dbs: map[string]*database{}}
|
||
if cfg == nil || !cfg.Enabled {
|
||
return state, nil
|
||
}
|
||
|
||
for _, item := range cfg.Databases {
|
||
db, err := sql.Open(driverName(item.Driver), item.DSN)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开数据库 %s 失败: %w", item.Name, err)
|
||
}
|
||
if item.Driver == "sqlite" {
|
||
db.SetMaxOpenConns(1)
|
||
db.SetMaxIdleConns(1)
|
||
} else {
|
||
db.SetMaxOpenConns(5)
|
||
db.SetMaxIdleConns(2)
|
||
}
|
||
db.SetConnMaxLifetime(30 * time.Minute)
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(item.Timeout)*time.Second)
|
||
if err := db.PingContext(ctx); err != nil {
|
||
cancel()
|
||
db.Close()
|
||
return nil, fmt.Errorf("连接数据库 %s 失败: %w", item.Name, err)
|
||
}
|
||
if item.Driver == "sqlite" {
|
||
if _, err := db.ExecContext(ctx, "PRAGMA query_only = ON"); err != nil {
|
||
cancel()
|
||
db.Close()
|
||
return nil, fmt.Errorf("启用 SQLite 只读模式失败: %w", err)
|
||
}
|
||
}
|
||
cancel()
|
||
|
||
state.dbs[item.Name] = &database{cfg: item, db: db}
|
||
state.order = append(state.order, item.Name)
|
||
}
|
||
return state, nil
|
||
}
|
||
|
||
func (s *State) Close() error {
|
||
if s == nil {
|
||
return nil
|
||
}
|
||
var errs []string
|
||
for _, db := range s.dbs {
|
||
if err := db.db.Close(); err != nil {
|
||
errs = append(errs, err.Error())
|
||
}
|
||
}
|
||
if len(errs) > 0 {
|
||
return errors.New(strings.Join(errs, "; "))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (s *State) Enabled() bool {
|
||
return s != nil && s.cfg != nil && s.cfg.Enabled && len(s.dbs) > 0
|
||
}
|
||
|
||
func (s *State) ActivationPrompt() string {
|
||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.ActivationPrompt) == "" {
|
||
return defaultActivationPrompt
|
||
}
|
||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||
}
|
||
|
||
func (s *State) DefaultDatabase() string {
|
||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" {
|
||
return defaultDatabaseName
|
||
}
|
||
return strings.TrimSpace(s.cfg.DefaultDatabase)
|
||
}
|
||
|
||
func (s *State) SchemaContext(ctx context.Context) (string, error) {
|
||
if !s.Enabled() {
|
||
return "", errors.New("SQL 查询插件未启用")
|
||
}
|
||
|
||
s.cacheMu.Lock()
|
||
if s.cacheText != "" && time.Since(s.cacheAt) < time.Duration(s.cfg.SchemaCacheSeconds)*time.Second {
|
||
text := s.cacheText
|
||
s.cacheMu.Unlock()
|
||
return text, nil
|
||
}
|
||
s.cacheMu.Unlock()
|
||
|
||
var b strings.Builder
|
||
fmt.Fprintf(&b, "可查询数据库列表(只能生成 SELECT/WITH 查询):\n")
|
||
for _, name := range s.order {
|
||
handle := s.dbs[name]
|
||
fmt.Fprintf(&b, "\n数据库 %s,类型 %s,单次最多返回 %d 行:\n", handle.cfg.Name, handle.cfg.Driver, handle.cfg.MaxRows)
|
||
schema, err := handle.schemaContext(ctx)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
b.WriteString(schema)
|
||
}
|
||
text := b.String()
|
||
|
||
s.cacheMu.Lock()
|
||
s.cacheText = text
|
||
s.cacheAt = time.Now()
|
||
s.cacheMu.Unlock()
|
||
return text, nil
|
||
}
|
||
|
||
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
|
||
if !s.Enabled() {
|
||
return nil, errors.New("SQL 查询插件未启用")
|
||
}
|
||
if err := ValidateReadOnlySQL(query); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
handle := s.databaseByName(databaseName)
|
||
if handle == nil {
|
||
return nil, fmt.Errorf("数据库配置不存在: %s", databaseName)
|
||
}
|
||
if err := handle.rejectExcludedTables(query); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
timeout := time.Duration(handle.cfg.Timeout) * time.Second
|
||
queryCtx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
if handle.cfg.Driver == "mysql" {
|
||
tx, err := handle.db.BeginTx(queryCtx, &sql.TxOptions{ReadOnly: true})
|
||
if err != nil {
|
||
return nil, fmt.Errorf("开启只读事务失败: %w", err)
|
||
}
|
||
defer tx.Rollback()
|
||
rows, err := tx.QueryContext(queryCtx, query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("执行 SQL 查询失败: %w", err)
|
||
}
|
||
result, err := scanRows(rows, handle.cfg, query)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err := tx.Commit(); err != nil {
|
||
return nil, fmt.Errorf("提交只读事务失败: %w", err)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
rows, err := handle.db.QueryContext(queryCtx, query)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("执行 SQL 查询失败: %w", err)
|
||
}
|
||
return scanRows(rows, handle.cfg, query)
|
||
}
|
||
|
||
func BuildResultContext(userQuery string, generatedSQL string, result *QueryResult) string {
|
||
var b strings.Builder
|
||
fmt.Fprintf(&b, "用户问题需要查询本地数据库。请仅根据以下 SQL 查询结果回答;不要编造结果中不存在的记录。\n")
|
||
fmt.Fprintf(&b, "用户问题:%s\n", userQuery)
|
||
fmt.Fprintf(&b, "数据库:%s\n", result.Database)
|
||
fmt.Fprintf(&b, "已执行只读 SQL:%s\n", generatedSQL)
|
||
if len(result.Columns) == 0 {
|
||
b.WriteString("查询没有返回列。\n")
|
||
return b.String()
|
||
}
|
||
if len(result.Rows) == 0 {
|
||
b.WriteString("查询结果:没有匹配记录。\n")
|
||
return b.String()
|
||
}
|
||
b.WriteString("查询结果:\n")
|
||
b.WriteString(markdownTable(result.Columns, result.Rows))
|
||
if result.Truncated {
|
||
fmt.Fprintf(&b, "\n结果已按配置截断,只展示前 %d 行。\n", result.MaxRows)
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
// BuildErrorContext renders the prompt context used when the SQL plugin could
// not answer userQuery. It instructs the model to say the query failed
// instead of inventing records. A nil err is rendered as "<nil>" rather than
// causing a panic.
func BuildErrorContext(userQuery string, err error) string {
	// %v prints err.Error() for a non-nil error and "<nil>" for nil.
	return fmt.Sprintf("用户问题可能需要查询本地数据库,但 SQL 查询插件执行失败:%v\n用户问题:%s\n请向用户说明无法完成数据库查询,不要编造数据库记录。", err, userQuery)
}
|
||
|
||
func ValidateReadOnlySQL(query string) error {
|
||
trimmed := strings.TrimSpace(query)
|
||
if trimmed == "" {
|
||
return errors.New("SQL 不能为空")
|
||
}
|
||
if strings.Contains(trimmed, "--") || strings.Contains(trimmed, "/*") || strings.Contains(trimmed, "*/") {
|
||
return errors.New("SQL 不允许包含注释")
|
||
}
|
||
body := strings.TrimSuffix(trimmed, ";")
|
||
if strings.Contains(body, ";") {
|
||
return errors.New("SQL 只允许单条语句")
|
||
}
|
||
upper := strings.ToUpper(body)
|
||
first := firstToken(upper)
|
||
if first != "SELECT" && first != "WITH" {
|
||
return fmt.Errorf("SQL 只允许 SELECT/WITH 查询,当前为 %s", first)
|
||
}
|
||
|
||
stripped := stripSingleQuotedStrings(upper)
|
||
forbidden := []string{
|
||
"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE", "MERGE",
|
||
"GRANT", "REVOKE", "VACUUM", "ANALYZE", "ATTACH", "DETACH", "LOAD", "CALL", "EXEC",
|
||
"SET", "USE", "LOCK", "UNLOCK", "BEGIN", "COMMIT", "ROLLBACK",
|
||
}
|
||
for _, word := range forbidden {
|
||
if hasSQLWord(stripped, word) {
|
||
return fmt.Errorf("SQL 包含禁止关键词: %s", word)
|
||
}
|
||
}
|
||
risky := []string{"INTO OUTFILE", "INTO DUMPFILE", "LOAD_FILE", "SLEEP", "BENCHMARK", "LOAD_EXTENSION"}
|
||
for _, phrase := range risky {
|
||
if strings.Contains(stripped, phrase) {
|
||
return fmt.Errorf("SQL 包含禁止函数或语法: %s", phrase)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func defaultConfig() Config {
|
||
return Config{
|
||
Enabled: false,
|
||
ActivationPrompt: defaultActivationPrompt,
|
||
DefaultDatabase: defaultDatabaseName,
|
||
SchemaCacheSeconds: defaultSchemaCacheSecond,
|
||
Databases: []DatabaseConfig{{
|
||
Name: defaultDatabaseName,
|
||
Active: true,
|
||
Driver: "sqlite",
|
||
DSN: defaultSQLiteDSN,
|
||
Timeout: defaultTimeout,
|
||
MaxRows: defaultMaxRows,
|
||
MaxCellBytes: defaultMaxCellBytes,
|
||
}},
|
||
}
|
||
}
|
||
|
||
func normalizeConfig(cfg *Config) error {
|
||
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
|
||
cfg.ActivationPrompt = defaultActivationPrompt
|
||
}
|
||
if strings.TrimSpace(cfg.DefaultDatabase) == "" {
|
||
cfg.DefaultDatabase = defaultDatabaseName
|
||
}
|
||
if cfg.SchemaCacheSeconds <= 0 {
|
||
cfg.SchemaCacheSeconds = defaultSchemaCacheSecond
|
||
}
|
||
if len(cfg.Databases) == 0 {
|
||
cfg.Databases = defaultConfig().Databases
|
||
}
|
||
|
||
seen := map[string]bool{}
|
||
activeIndex := -1
|
||
for i := range cfg.Databases {
|
||
item := &cfg.Databases[i]
|
||
item.Name = strings.TrimSpace(item.Name)
|
||
if item.Name == "" {
|
||
item.Name = fmt.Sprintf("database-%d", i+1)
|
||
}
|
||
if seen[item.Name] {
|
||
return fmt.Errorf("SQL 查询插件数据库名称重复: %s", item.Name)
|
||
}
|
||
seen[item.Name] = true
|
||
|
||
item.Driver = strings.ToLower(strings.TrimSpace(item.Driver))
|
||
if item.Driver == "" {
|
||
item.Driver = "sqlite"
|
||
}
|
||
if item.Driver != "sqlite" && item.Driver != "mysql" {
|
||
return fmt.Errorf("SQL 查询插件暂不支持数据库类型: %s", item.Driver)
|
||
}
|
||
if strings.TrimSpace(item.DSN) == "" {
|
||
return fmt.Errorf("数据库 %s 缺少 dsn", item.Name)
|
||
}
|
||
if item.Timeout <= 0 {
|
||
item.Timeout = defaultTimeout
|
||
}
|
||
if item.MaxRows <= 0 {
|
||
item.MaxRows = defaultMaxRows
|
||
}
|
||
if item.MaxCellBytes <= 0 {
|
||
item.MaxCellBytes = defaultMaxCellBytes
|
||
}
|
||
item.Schema.IncludeTables = cleanList(item.Schema.IncludeTables)
|
||
item.Schema.ExcludeTables = cleanList(item.Schema.ExcludeTables)
|
||
if item.Active {
|
||
if activeIndex == -1 {
|
||
activeIndex = i
|
||
} else {
|
||
item.Active = false
|
||
}
|
||
}
|
||
}
|
||
if activeIndex == -1 {
|
||
cfg.Databases[0].Active = true
|
||
}
|
||
if !seen[cfg.DefaultDatabase] {
|
||
for _, item := range cfg.Databases {
|
||
if item.Active {
|
||
cfg.DefaultDatabase = item.Name
|
||
break
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// driverName maps a normalized DatabaseConfig.Driver value to the name passed
// to sql.Open. Both supported values ("sqlite" and "mysql") are passed through
// unchanged. This presumes the imported driver packages register under those
// names — confirm if a driver is swapped.
func driverName(driver string) string {
	switch driver {
	case "sqlite":
		return "sqlite"
	default:
		return driver
	}
}
|
||
|
||
func (s *State) databaseByName(name string) *database {
|
||
name = strings.TrimSpace(name)
|
||
if name == "" {
|
||
name = s.DefaultDatabase()
|
||
}
|
||
if db := s.dbs[name]; db != nil {
|
||
return db
|
||
}
|
||
return s.dbs[s.DefaultDatabase()]
|
||
}
|
||
|
||
func (d *database) schemaContext(ctx context.Context) (string, error) {
|
||
timeout := time.Duration(d.cfg.Timeout) * time.Second
|
||
schemaCtx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
if d.cfg.Driver == "mysql" {
|
||
return d.mysqlSchemaContext(schemaCtx)
|
||
}
|
||
return d.sqliteSchemaContext(schemaCtx)
|
||
}
|
||
|
||
func (d *database) sqliteSchemaContext(ctx context.Context) (string, error) {
|
||
rows, err := d.db.QueryContext(ctx, `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取 SQLite 表列表失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var tables []string
|
||
for rows.Next() {
|
||
var name string
|
||
if err := rows.Scan(&name); err != nil {
|
||
return "", err
|
||
}
|
||
if d.tableAllowed(name) {
|
||
tables = append(tables, name)
|
||
}
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
var b strings.Builder
|
||
for _, table := range tables {
|
||
fmt.Fprintf(&b, "- 表 %s\n", table)
|
||
colRows, err := d.db.QueryContext(ctx, "PRAGMA table_info("+quoteSQLiteString(table)+")")
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取 SQLite 表 %s 字段失败: %w", table, err)
|
||
}
|
||
for colRows.Next() {
|
||
var cid int
|
||
var name, typ string
|
||
var notNull int
|
||
var defaultValue any
|
||
var pk int
|
||
if err := colRows.Scan(&cid, &name, &typ, ¬Null, &defaultValue, &pk); err != nil {
|
||
colRows.Close()
|
||
return "", err
|
||
}
|
||
extra := ""
|
||
if pk > 0 {
|
||
extra = " primary_key"
|
||
}
|
||
fmt.Fprintf(&b, " - %s %s%s\n", name, typ, extra)
|
||
}
|
||
if err := colRows.Close(); err != nil {
|
||
return "", err
|
||
}
|
||
}
|
||
if len(tables) == 0 {
|
||
b.WriteString("- 没有可查询表,或表被 include/exclude 规则过滤。\n")
|
||
}
|
||
return b.String(), nil
|
||
}
|
||
|
||
func (d *database) mysqlSchemaContext(ctx context.Context) (string, error) {
|
||
rows, err := d.db.QueryContext(ctx, `SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE' ORDER BY table_name`)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取 MySQL 表列表失败: %w", err)
|
||
}
|
||
defer rows.Close()
|
||
|
||
var tables []string
|
||
for rows.Next() {
|
||
var name string
|
||
if err := rows.Scan(&name); err != nil {
|
||
return "", err
|
||
}
|
||
if d.tableAllowed(name) {
|
||
tables = append(tables, name)
|
||
}
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
var b strings.Builder
|
||
for _, table := range tables {
|
||
fmt.Fprintf(&b, "- 表 %s\n", table)
|
||
colRows, err := d.db.QueryContext(ctx, `SELECT column_name, data_type, is_nullable, column_key FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ? ORDER BY ordinal_position`, table)
|
||
if err != nil {
|
||
return "", fmt.Errorf("读取 MySQL 表 %s 字段失败: %w", table, err)
|
||
}
|
||
for colRows.Next() {
|
||
var name, typ, nullable, key string
|
||
if err := colRows.Scan(&name, &typ, &nullable, &key); err != nil {
|
||
colRows.Close()
|
||
return "", err
|
||
}
|
||
extra := ""
|
||
if key == "PRI" {
|
||
extra = " primary_key"
|
||
}
|
||
if nullable == "NO" {
|
||
extra += " not_null"
|
||
}
|
||
fmt.Fprintf(&b, " - %s %s%s\n", name, typ, extra)
|
||
}
|
||
if err := colRows.Close(); err != nil {
|
||
return "", err
|
||
}
|
||
}
|
||
if len(tables) == 0 {
|
||
b.WriteString("- 没有可查询表,或表被 include/exclude 规则过滤。\n")
|
||
}
|
||
return b.String(), nil
|
||
}
|
||
|
||
func (d *database) tableAllowed(table string) bool {
|
||
name := strings.ToLower(strings.TrimSpace(table))
|
||
include := lowerSet(d.cfg.Schema.IncludeTables)
|
||
if len(include) > 0 && !include[name] {
|
||
return false
|
||
}
|
||
exclude := lowerSet(d.cfg.Schema.ExcludeTables)
|
||
return !exclude[name]
|
||
}
|
||
|
||
func (d *database) rejectExcludedTables(query string) error {
|
||
cleaned := strings.ToLower(stripSingleQuotedStrings(query))
|
||
include := lowerSet(d.cfg.Schema.IncludeTables)
|
||
if len(include) > 0 {
|
||
matched := false
|
||
for table := range include {
|
||
if hasSQLWord(cleaned, table) {
|
||
matched = true
|
||
break
|
||
}
|
||
}
|
||
if !matched {
|
||
return errors.New("SQL 未访问 include_tables 中允许的表")
|
||
}
|
||
}
|
||
exclude := lowerSet(d.cfg.Schema.ExcludeTables)
|
||
for table := range exclude {
|
||
if hasSQLWord(cleaned, table) {
|
||
return fmt.Errorf("SQL 访问了被排除的表: %s", table)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
|
||
defer rows.Close()
|
||
columns, err := rows.Columns()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取查询列失败: %w", err)
|
||
}
|
||
result := &QueryResult{
|
||
Database: cfg.Name,
|
||
SQL: query,
|
||
Columns: columns,
|
||
MaxRows: cfg.MaxRows,
|
||
}
|
||
for rows.Next() {
|
||
values := make([]any, len(columns))
|
||
ptrs := make([]any, len(columns))
|
||
for i := range values {
|
||
ptrs[i] = &values[i]
|
||
}
|
||
if err := rows.Scan(ptrs...); err != nil {
|
||
return nil, fmt.Errorf("读取查询结果失败: %w", err)
|
||
}
|
||
if len(result.Rows) >= cfg.MaxRows {
|
||
result.Truncated = true
|
||
break
|
||
}
|
||
row := make([]string, len(columns))
|
||
for i, value := range values {
|
||
row[i] = formatCell(value, cfg.MaxCellBytes)
|
||
}
|
||
result.Rows = append(result.Rows, row)
|
||
}
|
||
if err := rows.Err(); err != nil {
|
||
return nil, fmt.Errorf("读取查询结果失败: %w", err)
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
func formatCell(value any, maxBytes int) string {
|
||
if value == nil {
|
||
return "NULL"
|
||
}
|
||
switch v := value.(type) {
|
||
case time.Time:
|
||
return v.Format(time.RFC3339)
|
||
case []byte:
|
||
if !utf8.Valid(v) {
|
||
return fmt.Sprintf("<binary %d bytes>", len(v))
|
||
}
|
||
return truncateString(string(v), maxBytes)
|
||
case string:
|
||
return truncateString(v, maxBytes)
|
||
default:
|
||
return truncateString(fmt.Sprint(v), maxBytes)
|
||
}
|
||
}
|
||
|
||
// truncateString shortens s to at most maxBytes bytes, followed by the marker
// "...<truncated>", without splitting a multi-byte UTF-8 sequence at the cut.
// A non-positive maxBytes, or a string already within the limit, is returned
// unchanged.
func truncateString(s string, maxBytes int) string {
	if maxBytes <= 0 || len(s) <= maxBytes {
		return s
	}
	// Back up from the cut point to the start of the rune it falls in. Only
	// the boundary matters. Re-validating the whole prefix was quadratic,
	// and it discarded everything after an invalid byte anywhere in the prefix.
	cut := maxBytes
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut] + "...<truncated>"
}
|
||
|
||
func markdownTable(columns []string, rows [][]string) string {
|
||
var b strings.Builder
|
||
b.WriteString("| ")
|
||
for i, col := range columns {
|
||
if i > 0 {
|
||
b.WriteString(" | ")
|
||
}
|
||
b.WriteString(escapeMarkdownCell(col))
|
||
}
|
||
b.WriteString(" |\n| ")
|
||
for i := range columns {
|
||
if i > 0 {
|
||
b.WriteString(" | ")
|
||
}
|
||
b.WriteString("---")
|
||
}
|
||
b.WriteString(" |\n")
|
||
for _, row := range rows {
|
||
b.WriteString("| ")
|
||
for i, cell := range row {
|
||
if i > 0 {
|
||
b.WriteString(" | ")
|
||
}
|
||
b.WriteString(escapeMarkdownCell(cell))
|
||
}
|
||
b.WriteString(" |\n")
|
||
}
|
||
return b.String()
|
||
}
|
||
|
||
// markdownCellEscaper escapes pipes and flattens CR/LF line breaks in a single
// pass. "\r\n" is listed before "\r" so a Windows line break becomes one space.
var markdownCellEscaper = strings.NewReplacer(
	"|", "\\|",
	"\r\n", " ",
	"\n", " ",
	"\r", " ",
)

// escapeMarkdownCell makes s safe to place inside one markdown table cell.
// Pipes are backslash-escaped, and every line break becomes a single space.
func escapeMarkdownCell(s string) string {
	return markdownCellEscaper.Replace(s)
}
|
||
|
||
// firstToken returns the first whitespace-separated field of s, with any
// leading or trailing '(', ')' and ';' removed. It returns "" when s has no
// fields.
func firstToken(s string) string {
	for _, field := range strings.Fields(s) {
		return strings.Trim(field, "();")
	}
	return ""
}
|
||
|
||
// stripSingleQuotedStrings blanks out single-quoted literals in s. Each quote
// and each byte inside a literal becomes a space, while text outside literals
// is kept. A doubled quote ('') inside a literal is an escaped quote; both of
// its bytes are dropped. An unterminated literal blanks the rest of s.
func stripSingleQuotedStrings(s string) string {
	out := make([]byte, 0, len(s))
	quoted := false
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '\'' && quoted && i+1 < len(s) && s[i+1] == '\'':
			i++ // escaped quote inside a literal: emit nothing for either byte
		case c == '\'':
			quoted = !quoted
			out = append(out, ' ')
		case quoted:
			out = append(out, ' ')
		default:
			out = append(out, c)
		}
	}
	return string(out)
}
|
||
|
||
// sqlWordPatterns caches compiled whole-word matchers keyed by word
// (map[string]*regexp.Regexp). The keys are forbidden keywords and configured
// table names, so the cache stays small.
var sqlWordPatterns sync.Map

// hasSQLWord reports whether word appears in s as a whole word,
// case-insensitively. A word boundary is any character outside [a-zA-Z0-9_],
// or the string edge. Patterns are compiled once per word and reused;
// *regexp.Regexp is safe for concurrent use.
func hasSQLWord(s string, word string) bool {
	cached, ok := sqlWordPatterns.Load(word)
	if !ok {
		pattern := `(?i)(^|[^a-zA-Z0-9_])` + regexp.QuoteMeta(word) + `([^a-zA-Z0-9_]|$)`
		cached, _ = sqlWordPatterns.LoadOrStore(word, regexp.MustCompile(pattern))
	}
	return cached.(*regexp.Regexp).MatchString(s)
}
|
||
|
||
// quoteSQLiteString wraps s in single quotes and doubles every embedded single
// quote, producing a SQLite string literal. Bytes are copied verbatim, so
// non-UTF-8 input is preserved exactly.
func quoteSQLiteString(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			b.WriteByte('\'')
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('\'')
	return b.String()
}
|
||
|
||
// lowerSet returns the set of trimmed, lower-cased, non-empty entries of
// items. The result is always a non-nil map.
func lowerSet(items []string) map[string]bool {
	set := make(map[string]bool, len(items))
	for _, raw := range items {
		if key := strings.ToLower(strings.TrimSpace(raw)); key != "" {
			set[key] = true
		}
	}
	return set
}
|
||
|
||
// cleanList trims items, drops empty entries and removes case-insensitive
// duplicates (the first spelling wins), then sorts the result. It returns
// nil when nothing remains.
func cleanList(items []string) []string {
	var out []string
	seen := make(map[string]struct{}, len(items))
	for _, raw := range items {
		value := strings.TrimSpace(raw)
		if value == "" {
			continue
		}
		folded := strings.ToLower(value)
		if _, dup := seen[folded]; dup {
			continue
		}
		seen[folded] = struct{}{}
		out = append(out, value)
	}
	sort.Strings(out)
	return out
}
|