package sqlquery import ( "context" "database/sql" "errors" "fmt" "os" "path/filepath" "regexp" "sort" "strings" "sync" "time" "unicode/utf8" _ "github.com/go-sql-driver/mysql" "gopkg.in/yaml.v3" _ "modernc.org/sqlite" ) const ( defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。 仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。 当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。 普通知识问答、代码问题、闲聊、联网搜索问题返回 activate=false。 只返回 JSON: {"activate": true/false, "reason": "..."}` defaultDatabaseName = "default" defaultSQLiteDSN = "file:data/app.db?mode=ro" defaultTimeout = 10 defaultMaxRows = 50 defaultMaxCellBytes = 4096 defaultSchemaCacheSecond = 300 ) type Config struct { Enabled bool `yaml:"enabled" json:"enabled"` ActivationPrompt string `yaml:"activation_prompt" json:"activation_prompt"` DefaultDatabase string `yaml:"default_database" json:"default_database"` SchemaCacheSeconds int `yaml:"schema_cache_seconds" json:"schema_cache_seconds"` Databases []DatabaseConfig `yaml:"databases" json:"databases"` } type DatabaseConfig struct { Name string `yaml:"name" json:"name"` Active bool `yaml:"active" json:"active"` Driver string `yaml:"driver" json:"driver"` DSN string `yaml:"dsn" json:"-"` Timeout int `yaml:"timeout" json:"timeout"` MaxRows int `yaml:"max_rows" json:"max_rows"` MaxCellBytes int `yaml:"max_cell_bytes" json:"max_cell_bytes"` Schema SchemaConfig `yaml:"schema" json:"schema"` } type SchemaConfig struct { IncludeTables []string `yaml:"include_tables" json:"include_tables"` ExcludeTables []string `yaml:"exclude_tables" json:"exclude_tables"` } type State struct { cfg *Config dbs map[string]*database order []string cacheMu sync.Mutex cacheText string cacheAt time.Time } type database struct { cfg DatabaseConfig db *sql.DB } type QueryResult struct { Database string `json:"database"` SQL string `json:"sql"` Columns []string `json:"columns"` Rows [][]string `json:"rows"` Truncated bool `json:"truncated"` MaxRows int `json:"max_rows"` } func LoadConfig(path string) (*Config, error) { if _, err := os.Stat(path); err != nil { if !os.IsNotExist(err) { return nil, fmt.Errorf("检查 SQL 查询插件配置失败: %w", err) } cfg := defaultConfig() if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { return nil, fmt.Errorf("创建 SQL 查询插件目录失败: %w", err) } data, err := yaml.Marshal(&cfg) if err != nil { return nil, fmt.Errorf("生成 SQL 查询插件配置失败: %w", err) } if err := os.WriteFile(path, data, 0644); err != nil { return nil, fmt.Errorf("写入 SQL 查询插件配置失败: %w", err) } } data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("读取 SQL 查询插件配置失败: %w", err) } var cfg Config if err := yaml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("解析 SQL 查询插件配置失败: %w", err) } if err := normalizeConfig(&cfg); err != nil { return nil, err } return &cfg, nil } func NewState(cfg *Config) (*State, error) { state := &State{cfg: cfg, dbs: map[string]*database{}} if cfg == nil || !cfg.Enabled { return state, nil } for _, item := range cfg.Databases { db, err := sql.Open(driverName(item.Driver), item.DSN) if err != nil { return nil, fmt.Errorf("打开数据库 %s 失败: %w", item.Name, err) } if item.Driver == "sqlite" { db.SetMaxOpenConns(1) db.SetMaxIdleConns(1) } else { db.SetMaxOpenConns(5) db.SetMaxIdleConns(2) } db.SetConnMaxLifetime(30 * time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(item.Timeout)*time.Second) if err := db.PingContext(ctx); err != nil { cancel() db.Close() return nil, fmt.Errorf("连接数据库 %s 失败: %w", item.Name, err) } if item.Driver == "sqlite" { if _, err := db.ExecContext(ctx, "PRAGMA query_only = ON"); err != nil { cancel() db.Close() return nil, fmt.Errorf("启用 SQLite 只读模式失败: %w", err) } } cancel() state.dbs[item.Name] = &database{cfg: item, db: db} state.order = append(state.order, item.Name) } return state, nil } func (s *State) Close() error { if s == nil { return nil } var errs []string for _, db := range s.dbs { if err := db.db.Close(); err != nil { errs = append(errs, err.Error()) } } if len(errs) > 0 { return errors.New(strings.Join(errs, "; ")) } return nil } func (s *State) Enabled() bool { return s != nil && s.cfg != nil && s.cfg.Enabled && len(s.dbs) > 0 } func (s *State) ActivationPrompt() string { if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.ActivationPrompt) == "" { return defaultActivationPrompt } return strings.TrimSpace(s.cfg.ActivationPrompt) } func (s *State) DefaultDatabase() string { if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" { return defaultDatabaseName } return strings.TrimSpace(s.cfg.DefaultDatabase) } func (s *State) SchemaContext(ctx context.Context) (string, error) { if !s.Enabled() { return "", errors.New("SQL 查询插件未启用") } s.cacheMu.Lock() if s.cacheText != "" && time.Since(s.cacheAt) < time.Duration(s.cfg.SchemaCacheSeconds)*time.Second { text := s.cacheText s.cacheMu.Unlock() return text, nil } s.cacheMu.Unlock() var b strings.Builder fmt.Fprintf(&b, "可查询数据库列表(只能生成 SELECT/WITH 查询):\n") for _, name := range s.order { handle := s.dbs[name] fmt.Fprintf(&b, "\n数据库 %s,类型 %s,单次最多返回 %d 行:\n", handle.cfg.Name, handle.cfg.Driver, handle.cfg.MaxRows) schema, err := handle.schemaContext(ctx) if err != nil { return "", err } b.WriteString(schema) } text := b.String() s.cacheMu.Lock() s.cacheText = text s.cacheAt = time.Now() s.cacheMu.Unlock() return text, nil } func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) { if !s.Enabled() { return nil, errors.New("SQL 查询插件未启用") } if err := ValidateReadOnlySQL(query); err != nil { return nil, err } handle := s.databaseByName(databaseName) if handle == nil { return nil, fmt.Errorf("数据库配置不存在: %s", databaseName) } if err := handle.rejectExcludedTables(query); err != nil { return nil, err } timeout := time.Duration(handle.cfg.Timeout) * time.Second queryCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if handle.cfg.Driver == "mysql" { tx, err := handle.db.BeginTx(queryCtx, &sql.TxOptions{ReadOnly: true}) if err != nil { return nil, fmt.Errorf("开启只读事务失败: %w", err) } defer tx.Rollback() rows, err := tx.QueryContext(queryCtx, query) if err != nil { return nil, fmt.Errorf("执行 SQL 查询失败: %w", err) } result, err := scanRows(rows, handle.cfg, query) if err != nil { return nil, err } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("提交只读事务失败: %w", err) } return result, nil } rows, err := handle.db.QueryContext(queryCtx, query) if err != nil { return nil, fmt.Errorf("执行 SQL 查询失败: %w", err) } return scanRows(rows, handle.cfg, query) } func BuildResultContext(userQuery string, generatedSQL string, result *QueryResult) string { var b strings.Builder fmt.Fprintf(&b, "用户问题需要查询本地数据库。请仅根据以下 SQL 查询结果回答;不要编造结果中不存在的记录。\n") fmt.Fprintf(&b, "用户问题:%s\n", userQuery) fmt.Fprintf(&b, "数据库:%s\n", result.Database) fmt.Fprintf(&b, "已执行只读 SQL:%s\n", generatedSQL) if len(result.Columns) == 0 { b.WriteString("查询没有返回列。\n") return b.String() } if len(result.Rows) == 0 { b.WriteString("查询结果:没有匹配记录。\n") return b.String() } b.WriteString("查询结果:\n") b.WriteString(markdownTable(result.Columns, result.Rows)) if result.Truncated { fmt.Fprintf(&b, "\n结果已按配置截断,只展示前 %d 行。\n", result.MaxRows) } return b.String() } func BuildErrorContext(userQuery string, err error) string { return fmt.Sprintf("用户问题可能需要查询本地数据库,但 SQL 查询插件执行失败:%s\n用户问题:%s\n请向用户说明无法完成数据库查询,不要编造数据库记录。", err.Error(), userQuery) } func ValidateReadOnlySQL(query string) error { trimmed := strings.TrimSpace(query) if trimmed == "" { return errors.New("SQL 不能为空") } if strings.Contains(trimmed, "--") || strings.Contains(trimmed, "/*") || strings.Contains(trimmed, "*/") { return errors.New("SQL 不允许包含注释") } body := strings.TrimSuffix(trimmed, ";") if strings.Contains(body, ";") { return errors.New("SQL 只允许单条语句") } upper := strings.ToUpper(body) first := firstToken(upper) if first != "SELECT" && first != "WITH" { return fmt.Errorf("SQL 只允许 SELECT/WITH 查询,当前为 %s", first) } stripped := stripSingleQuotedStrings(upper) forbidden := []string{ "INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE", "MERGE", "GRANT", "REVOKE", "VACUUM", "ANALYZE", "ATTACH", "DETACH", "LOAD", "CALL", "EXEC", "SET", "USE", "LOCK", "UNLOCK", "BEGIN", "COMMIT", "ROLLBACK", } for _, word := range forbidden { if hasSQLWord(stripped, word) { return fmt.Errorf("SQL 包含禁止关键词: %s", word) } } risky := []string{"INTO OUTFILE", "INTO DUMPFILE", "LOAD_FILE", "SLEEP", "BENCHMARK", "LOAD_EXTENSION"} for _, phrase := range risky { if strings.Contains(stripped, phrase) { return fmt.Errorf("SQL 包含禁止函数或语法: %s", phrase) } } return nil } func defaultConfig() Config { return Config{ Enabled: false, ActivationPrompt: defaultActivationPrompt, DefaultDatabase: defaultDatabaseName, SchemaCacheSeconds: defaultSchemaCacheSecond, Databases: []DatabaseConfig{{ Name: defaultDatabaseName, Active: true, Driver: "sqlite", DSN: defaultSQLiteDSN, Timeout: defaultTimeout, MaxRows: defaultMaxRows, MaxCellBytes: defaultMaxCellBytes, }}, } } func normalizeConfig(cfg *Config) error { if strings.TrimSpace(cfg.ActivationPrompt) == "" { cfg.ActivationPrompt = defaultActivationPrompt } if strings.TrimSpace(cfg.DefaultDatabase) == "" { cfg.DefaultDatabase = defaultDatabaseName } if cfg.SchemaCacheSeconds <= 0 { cfg.SchemaCacheSeconds = defaultSchemaCacheSecond } if len(cfg.Databases) == 0 { cfg.Databases = defaultConfig().Databases } seen := map[string]bool{} activeIndex := -1 for i := range cfg.Databases { item := &cfg.Databases[i] item.Name = strings.TrimSpace(item.Name) if item.Name == "" { item.Name = fmt.Sprintf("database-%d", i+1) } if seen[item.Name] { return fmt.Errorf("SQL 查询插件数据库名称重复: %s", item.Name) } seen[item.Name] = true item.Driver = strings.ToLower(strings.TrimSpace(item.Driver)) if item.Driver == "" { item.Driver = "sqlite" } if item.Driver != "sqlite" && item.Driver != "mysql" { return fmt.Errorf("SQL 查询插件暂不支持数据库类型: %s", item.Driver) } if strings.TrimSpace(item.DSN) == "" { return fmt.Errorf("数据库 %s 缺少 dsn", item.Name) } if item.Timeout <= 0 { item.Timeout = defaultTimeout } if item.MaxRows <= 0 { item.MaxRows = defaultMaxRows } if item.MaxCellBytes <= 0 { item.MaxCellBytes = defaultMaxCellBytes } item.Schema.IncludeTables = cleanList(item.Schema.IncludeTables) item.Schema.ExcludeTables = cleanList(item.Schema.ExcludeTables) if item.Active { if activeIndex == -1 { activeIndex = i } else { item.Active = false } } } if activeIndex == -1 { cfg.Databases[0].Active = true } if !seen[cfg.DefaultDatabase] { for _, item := range cfg.Databases { if item.Active { cfg.DefaultDatabase = item.Name break } } } return nil } func driverName(driver string) string { if driver == "sqlite" { return "sqlite" } return driver } func (s *State) databaseByName(name string) *database { name = strings.TrimSpace(name) if name == "" { name = s.DefaultDatabase() } if db := s.dbs[name]; db != nil { return db } return s.dbs[s.DefaultDatabase()] } func (d *database) schemaContext(ctx context.Context) (string, error) { timeout := time.Duration(d.cfg.Timeout) * time.Second schemaCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if d.cfg.Driver == "mysql" { return d.mysqlSchemaContext(schemaCtx) } return d.sqliteSchemaContext(schemaCtx) } func (d *database) sqliteSchemaContext(ctx context.Context) (string, error) { rows, err := d.db.QueryContext(ctx, `SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name`) if err != nil { return "", fmt.Errorf("读取 SQLite 表列表失败: %w", err) } defer rows.Close() var tables []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return "", err } if d.tableAllowed(name) { tables = append(tables, name) } } if err := rows.Err(); err != nil { return "", err } var b strings.Builder for _, table := range tables { fmt.Fprintf(&b, "- 表 %s\n", table) colRows, err := d.db.QueryContext(ctx, "PRAGMA table_info("+quoteSQLiteString(table)+")") if err != nil { return "", fmt.Errorf("读取 SQLite 表 %s 字段失败: %w", table, err) } for colRows.Next() { var cid int var name, typ string var notNull int var defaultValue any var pk int if err := colRows.Scan(&cid, &name, &typ, ¬Null, &defaultValue, &pk); err != nil { colRows.Close() return "", err } extra := "" if pk > 0 { extra = " primary_key" } fmt.Fprintf(&b, " - %s %s%s\n", name, typ, extra) } if err := colRows.Close(); err != nil { return "", err } } if len(tables) == 0 { b.WriteString("- 没有可查询表,或表被 include/exclude 规则过滤。\n") } return b.String(), nil } func (d *database) mysqlSchemaContext(ctx context.Context) (string, error) { rows, err := d.db.QueryContext(ctx, `SELECT table_name FROM information_schema.tables WHERE table_schema = DATABASE() AND table_type = 'BASE TABLE' ORDER BY table_name`) if err != nil { return "", fmt.Errorf("读取 MySQL 表列表失败: %w", err) } defer rows.Close() var tables []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return "", err } if d.tableAllowed(name) { tables = append(tables, name) } } if err := rows.Err(); err != nil { return "", err } var b strings.Builder for _, table := range tables { fmt.Fprintf(&b, "- 表 %s\n", table) colRows, err := d.db.QueryContext(ctx, `SELECT column_name, data_type, is_nullable, column_key FROM information_schema.columns WHERE table_schema = DATABASE() AND table_name = ? ORDER BY ordinal_position`, table) if err != nil { return "", fmt.Errorf("读取 MySQL 表 %s 字段失败: %w", table, err) } for colRows.Next() { var name, typ, nullable, key string if err := colRows.Scan(&name, &typ, &nullable, &key); err != nil { colRows.Close() return "", err } extra := "" if key == "PRI" { extra = " primary_key" } if nullable == "NO" { extra += " not_null" } fmt.Fprintf(&b, " - %s %s%s\n", name, typ, extra) } if err := colRows.Close(); err != nil { return "", err } } if len(tables) == 0 { b.WriteString("- 没有可查询表,或表被 include/exclude 规则过滤。\n") } return b.String(), nil } func (d *database) tableAllowed(table string) bool { name := strings.ToLower(strings.TrimSpace(table)) include := lowerSet(d.cfg.Schema.IncludeTables) if len(include) > 0 && !include[name] { return false } exclude := lowerSet(d.cfg.Schema.ExcludeTables) return !exclude[name] } func (d *database) rejectExcludedTables(query string) error { cleaned := strings.ToLower(stripSingleQuotedStrings(query)) include := lowerSet(d.cfg.Schema.IncludeTables) if len(include) > 0 { matched := false for table := range include { if hasSQLWord(cleaned, table) { matched = true break } } if !matched { return errors.New("SQL 未访问 include_tables 中允许的表") } } exclude := lowerSet(d.cfg.Schema.ExcludeTables) for table := range exclude { if hasSQLWord(cleaned, table) { return fmt.Errorf("SQL 访问了被排除的表: %s", table) } } return nil } func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) { defer rows.Close() columns, err := rows.Columns() if err != nil { return nil, fmt.Errorf("读取查询列失败: %w", err) } result := &QueryResult{ Database: cfg.Name, SQL: query, Columns: columns, MaxRows: cfg.MaxRows, } for rows.Next() { values := make([]any, len(columns)) ptrs := make([]any, len(columns)) for i := range values { ptrs[i] = &values[i] } if err := rows.Scan(ptrs...); err != nil { return nil, fmt.Errorf("读取查询结果失败: %w", err) } if len(result.Rows) >= cfg.MaxRows { result.Truncated = true break } row := make([]string, len(columns)) for i, value := range values { row[i] = formatCell(value, cfg.MaxCellBytes) } result.Rows = append(result.Rows, row) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("读取查询结果失败: %w", err) } return result, nil } func formatCell(value any, maxBytes int) string { if value == nil { return "NULL" } switch v := value.(type) { case time.Time: return v.Format(time.RFC3339) case []byte: if !utf8.Valid(v) { return fmt.Sprintf("", len(v)) } return truncateString(string(v), maxBytes) case string: return truncateString(v, maxBytes) default: return truncateString(fmt.Sprint(v), maxBytes) } } func truncateString(s string, maxBytes int) string { if maxBytes <= 0 || len(s) <= maxBytes { return s } cut := s[:maxBytes] for !utf8.ValidString(cut) && len(cut) > 0 { cut = cut[:len(cut)-1] } return cut + "..." } func markdownTable(columns []string, rows [][]string) string { var b strings.Builder b.WriteString("| ") for i, col := range columns { if i > 0 { b.WriteString(" | ") } b.WriteString(escapeMarkdownCell(col)) } b.WriteString(" |\n| ") for i := range columns { if i > 0 { b.WriteString(" | ") } b.WriteString("---") } b.WriteString(" |\n") for _, row := range rows { b.WriteString("| ") for i, cell := range row { if i > 0 { b.WriteString(" | ") } b.WriteString(escapeMarkdownCell(cell)) } b.WriteString(" |\n") } return b.String() } func escapeMarkdownCell(s string) string { s = strings.ReplaceAll(s, "|", "\\|") s = strings.ReplaceAll(s, "\r\n", " ") s = strings.ReplaceAll(s, "\n", " ") s = strings.ReplaceAll(s, "\r", " ") return s } func firstToken(s string) string { fields := strings.Fields(s) if len(fields) == 0 { return "" } return strings.Trim(fields[0], "();") } func stripSingleQuotedStrings(s string) string { var b strings.Builder inString := false for i := 0; i < len(s); i++ { ch := s[i] if ch == '\'' { if inString && i+1 < len(s) && s[i+1] == '\'' { i++ continue } inString = !inString b.WriteByte(' ') continue } if inString { b.WriteByte(' ') } else { b.WriteByte(ch) } } return b.String() } func hasSQLWord(s string, word string) bool { pattern := `(?i)(^|[^a-zA-Z0-9_])` + regexp.QuoteMeta(word) + `([^a-zA-Z0-9_]|$)` return regexp.MustCompile(pattern).FindStringIndex(s) != nil } func quoteSQLiteString(s string) string { return "'" + strings.ReplaceAll(s, "'", "''") + "'" } func lowerSet(items []string) map[string]bool { set := map[string]bool{} for _, item := range items { item = strings.ToLower(strings.TrimSpace(item)) if item != "" { set[item] = true } } return set } func cleanList(items []string) []string { seen := map[string]bool{} var cleaned []string for _, item := range items { item = strings.TrimSpace(item) if item == "" { continue } key := strings.ToLower(item) if seen[key] { continue } seen[key] = true cleaned = append(cleaned, item) } sort.Strings(cleaned) return cleaned }