// Package database owns the portal's SQL connection: it opens SQLite or MySQL
// according to the configuration, creates the schema and seeds default rows.
package database

import (
	"database/sql"
	"fmt"
	"os"
	"path/filepath"

	"simple_portal/config"

	_ "github.com/go-sql-driver/mysql"
	_ "modernc.org/sqlite"
)

// DB is the shared database handle. It is set by InitDB (and left nil if
// initialization fails) and released by CloseDB.
var DB *sql.DB

// InitDB initializes DB according to config.Cfg.Database.Type.
// Supported types are "sqlite" and "mysql"; any other value yields an error.
// On success the schema exists and the default rows have been seeded.
func InitDB() error {
	dbType := config.Cfg.Database.Type
	dataDir := config.Cfg.Data.Dir

	switch dbType {
	case "mysql":
		return initMySQL()
	case "sqlite":
		return initSQLite(dataDir)
	default:
		return fmt.Errorf("不支持的数据库类型: %s(可选: sqlite, mysql)", dbType)
	}
}

// initSQLite opens the SQLite database at config.GetDBPath() (creating its
// parent directory if needed), enables WAL mode, then creates the schema and
// seeds default rows.
//
// NOTE(review): dataDir is not used; the database location comes solely from
// config.GetDBPath(). The parameter is kept so existing callers still compile.
func initSQLite(dataDir string) error {
	dbPath := config.GetDBPath()

	// Create the directory that will hold the database file.
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return fmt.Errorf("创建数据目录失败: %w", err)
	}

	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return fmt.Errorf("打开 SQLite 数据库失败: %w", err)
	}

	// Restrict the pool to a single connection so writes are serialized.
	db.SetMaxOpenConns(1)

	// Enable WAL journaling to improve concurrent read performance.
	if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
		_ = db.Close() // best effort; the pragma error is the one worth reporting
		return fmt.Errorf("设置 WAL 模式失败: %w", err)
	}

	return setupSchema(db)
}

// initMySQL opens the MySQL database described by
// config.Cfg.Database.MySQL.GetDSN(), verifies connectivity with Ping, then
// creates the schema and seeds default rows.
func initMySQL() error {
	db, err := sql.Open("mysql", config.Cfg.Database.MySQL.GetDSN())
	if err != nil {
		return fmt.Errorf("打开 MySQL 数据库失败: %w", err)
	}

	// sql.Open may only validate its arguments; Ping makes connection problems
	// surface here instead of on first use.
	if err := db.Ping(); err != nil {
		_ = db.Close() // best effort; the ping error is the one worth reporting
		return fmt.Errorf("连接 MySQL 失败: %w", err)
	}

	// 0 means no limit on open connections (the database/sql default).
	db.SetMaxOpenConns(0)

	return setupSchema(db)
}

// setupSchema publishes db as the package-level DB, creates the tables and
// seeds default rows. If either step fails, db is closed and DB is reset to
// nil so a half-initialized handle is not left behind.
func setupSchema(db *sql.DB) error {
	DB = db

	fail := func(err error) error {
		_ = db.Close() // best effort; err is the failure worth reporting
		DB = nil
		return err
	}

	if err := createTables(); err != nil {
		return fail(fmt.Errorf("创建数据表失败: %w", err))
	}
	if err := seedData(); err != nil {
		return fail(fmt.Errorf("初始化默认数据失败: %w", err))
	}
	return nil
}

// createTables creates every table and index used by the portal if they do
// not exist yet, using DDL for the dialect named by config.Cfg.Database.Type.
//
// The SQLite schema is not valid MySQL, so MySQL gets its own statements:
// MySQL spells the keyword AUTO_INCREMENT, reserves the word KEY, cannot put a
// PRIMARY KEY/UNIQUE/index on a TEXT column without a length, rejects literal
// defaults on TEXT columns, and has no CREATE INDEX IF NOT EXISTS (indexes are
// therefore declared inside CREATE TABLE).
//
// Statements are executed one at a time because go-sql-driver/mysql refuses
// multi-statement queries unless the DSN sets multiStatements=true.
func createTables() error {
	sqliteDDL := []string{
		`CREATE TABLE IF NOT EXISTS cards (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			icon TEXT,
			title TEXT NOT NULL,
			subtitle TEXT,
			url TEXT NOT NULL,
			sort INTEGER DEFAULT 0,
			enabled INTEGER DEFAULT 1,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS settings (
			key TEXT PRIMARY KEY,
			value TEXT
		)`,
		`CREATE TABLE IF NOT EXISTS admins (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			username TEXT UNIQUE NOT NULL,
			password TEXT NOT NULL
		)`,
		`CREATE TABLE IF NOT EXISTS login_logs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			admin_id INTEGER,
			username TEXT NOT NULL,
			ip TEXT NOT NULL,
			user_agent TEXT,
			success INTEGER NOT NULL DEFAULT 0,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS ip_bans (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			reason TEXT,
			fail_count INTEGER DEFAULT 0,
			banned_until DATETIME NOT NULL,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS ip_whitelist (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			comment TEXT,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE TABLE IF NOT EXISTS access_logs (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			user_agent TEXT,
			action_type TEXT NOT NULL,
			detail TEXT DEFAULT '',
			referer TEXT DEFAULT '',
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)`,
		`CREATE INDEX IF NOT EXISTS idx_access_logs_ip ON access_logs(ip)`,
		`CREATE INDEX IF NOT EXISTS idx_access_logs_action_type ON access_logs(action_type)`,
		`CREATE INDEX IF NOT EXISTS idx_access_logs_created_at ON access_logs(created_at)`,
	}

	// utf8mb4 so arbitrary Unicode text round-trips. VARCHAR(191) keys and
	// 64-character index prefixes stay under InnoDB's 767-byte index limit
	// for utf8mb4 on older row formats.
	// NOTE(review): DATETIME DEFAULT CURRENT_TIMESTAMP needs MySQL >= 5.6.5.
	const mysqlOpts = " ENGINE=InnoDB DEFAULT CHARSET=utf8mb4"
	mysqlDDL := []string{
		`CREATE TABLE IF NOT EXISTS cards (
			id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
			icon TEXT,
			title TEXT NOT NULL,
			subtitle TEXT,
			url TEXT NOT NULL,
			sort INTEGER DEFAULT 0,
			enabled INTEGER DEFAULT 1,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)` + mysqlOpts,
		// Interpreted string: the back-quoted reserved word cannot live in a raw string.
		"CREATE TABLE IF NOT EXISTS settings (\n" +
			"\t`key` VARCHAR(191) NOT NULL PRIMARY KEY,\n" +
			"\tvalue TEXT\n" +
			")" + mysqlOpts,
		`CREATE TABLE IF NOT EXISTS admins (
			id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
			username VARCHAR(191) NOT NULL UNIQUE,
			password TEXT NOT NULL
		)` + mysqlOpts,
		`CREATE TABLE IF NOT EXISTS login_logs (
			id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
			admin_id BIGINT,
			username TEXT NOT NULL,
			ip TEXT NOT NULL,
			user_agent TEXT,
			success INTEGER NOT NULL DEFAULT 0,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)` + mysqlOpts,
		`CREATE TABLE IF NOT EXISTS ip_bans (
			id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
			ip TEXT NOT NULL,
			reason TEXT,
			fail_count INTEGER DEFAULT 0,
			banned_until DATETIME NOT NULL,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)` + mysqlOpts,
		`CREATE TABLE IF NOT EXISTS ip_whitelist (
			id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
			ip TEXT NOT NULL,
			comment TEXT,
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP
		)` + mysqlOpts,
		// detail/referer are VARCHAR because MySQL rejects DEFAULT '' on TEXT.
		// NOTE(review): values longer than 2048 characters are rejected in
		// strict SQL mode — confirm this cap suits the callers.
		`CREATE TABLE IF NOT EXISTS access_logs (
			id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
			ip TEXT NOT NULL,
			user_agent TEXT,
			action_type TEXT NOT NULL,
			detail VARCHAR(2048) DEFAULT '',
			referer VARCHAR(2048) DEFAULT '',
			created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
			INDEX idx_access_logs_ip (ip(64)),
			INDEX idx_access_logs_action_type (action_type(64)),
			INDEX idx_access_logs_created_at (created_at)
		)` + mysqlOpts,
	}

	stmts := sqliteDDL
	if config.Cfg.Database.Type == "mysql" {
		stmts = mysqlDDL
	}
	for _, stmt := range stmts {
		if _, err := DB.Exec(stmt); err != nil {
			return err
		}
	}
	return nil
}

// seedData inserts the default admin account and the default settings rows
// when they are missing. Existing rows are never overwritten, so it can run on
// every start-up.
func seedData() error {
	// Default administrator "admin"; the stored value is the bcrypt hash of
	// "admin123".
	// NOTE(review): this is a publicly known default credential — make sure
	// deployments change it.
	var count int
	if err := DB.QueryRow("SELECT COUNT(*) FROM admins WHERE username = ?", "admin").Scan(&count); err != nil {
		return err
	}
	if count == 0 {
		if _, err := DB.Exec(
			"INSERT INTO admins (username, password) VALUES (?, ?)",
			"admin",
			"$2a$10$h3Csm2HmWUtvim3MJ8VG0OHx/tevZorlUXQVDtN2EgWhROtiM3Sg.", // bcrypt hash for "admin123"
		); err != nil {
			return err
		}
	}

	// Default settings (search engine, then homepage), inserted in this order.
	// The column name is back-quoted because KEY is reserved in MySQL; SQLite
	// accepts the same quoting. The "%s" in the search URL is stored verbatim.
	defaults := []struct{ key, value string }{
		{"search_engine", "https://www.google.com/search?q=%s"},
		{"homepage_title", "Portal"},
		{"homepage_subtitle", ""},
		{"homepage_background", ""},
	}
	for _, s := range defaults {
		var n int
		if err := DB.QueryRow("SELECT COUNT(*) FROM settings WHERE `key` = ?", s.key).Scan(&n); err != nil {
			return err
		}
		if n > 0 {
			continue
		}
		if _, err := DB.Exec("INSERT INTO settings (`key`, value) VALUES (?, ?)", s.key, s.value); err != nil {
			return err
		}
	}
	return nil
}

// CloseDB closes DB. It returns nil without doing anything when DB is nil.
func CloseDB() error {
	if DB == nil {
		return nil
	}
	return DB.Close()
}