230 lines
5.4 KiB
Go
230 lines
5.4 KiB
Go
package database
|
||
|
||
import (
|
||
"database/sql"
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
|
||
"simple_portal/config"
|
||
|
||
_ "github.com/go-sql-driver/mysql"
|
||
_ "modernc.org/sqlite"
|
||
)
|
||
|
||
// DB is the package-wide *sql.DB connection pool. It holds its zero value
// (nil) until one of this package's initialisation functions assigns it.
var (
	DB *sql.DB
)
|
||
|
||
// InitDB 根据配置初始化数据库,支持 SQLite 和 MySQL
|
||
func InitDB() error {
|
||
dbType := config.Cfg.Database.Type
|
||
dataDir := config.Cfg.Data.Dir
|
||
|
||
switch dbType {
|
||
case "mysql":
|
||
return initMySQL()
|
||
case "sqlite":
|
||
return initSQLite(dataDir)
|
||
default:
|
||
return fmt.Errorf("不支持的数据库类型: %s(可选: sqlite, mysql)", dbType)
|
||
}
|
||
}
|
||
|
||
// initSQLite 初始化 SQLite 数据库
|
||
func initSQLite(dataDir string) error {
|
||
dbPath := config.GetDBPath()
|
||
|
||
// 创建数据目录
|
||
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
||
return fmt.Errorf("创建数据目录失败: %w", err)
|
||
}
|
||
|
||
var err error
|
||
DB, err = sql.Open("sqlite", dbPath)
|
||
if err != nil {
|
||
return fmt.Errorf("打开 SQLite 数据库失败: %w", err)
|
||
}
|
||
|
||
// SQLite 需要限制最大连接数为 1 以保证写安全
|
||
DB.SetMaxOpenConns(1)
|
||
|
||
// 启用 WAL 模式提升并发读性能
|
||
if _, err := DB.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||
return fmt.Errorf("设置 WAL 模式失败: %w", err)
|
||
}
|
||
|
||
// 创建表
|
||
if err := createTables(); err != nil {
|
||
return fmt.Errorf("创建数据表失败: %w", err)
|
||
}
|
||
|
||
// 初始化默认数据
|
||
if err := seedData(); err != nil {
|
||
return fmt.Errorf("初始化默认数据失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// initMySQL 初始化 MySQL 数据库
|
||
func initMySQL() error {
|
||
dsn := config.Cfg.Database.MySQL.GetDSN()
|
||
|
||
var err error
|
||
DB, err = sql.Open("mysql", dsn)
|
||
if err != nil {
|
||
return fmt.Errorf("打开 MySQL 数据库失败: %w", err)
|
||
}
|
||
|
||
// 测试连接
|
||
if err := DB.Ping(); err != nil {
|
||
return fmt.Errorf("连接 MySQL 失败: %w", err)
|
||
}
|
||
|
||
// MySQL 不限制最大连接数(使用默认池)
|
||
DB.SetMaxOpenConns(0)
|
||
|
||
// 创建表
|
||
if err := createTables(); err != nil {
|
||
return fmt.Errorf("创建数据表失败: %w", err)
|
||
}
|
||
|
||
// 初始化默认数据
|
||
if err := seedData(); err != nil {
|
||
return fmt.Errorf("初始化默认数据失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// createTables 创建所有数据表
|
||
func createTables() error {
|
||
_, err := DB.Exec(`
|
||
CREATE TABLE IF NOT EXISTS cards (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
icon TEXT,
|
||
title TEXT NOT NULL,
|
||
subtitle TEXT,
|
||
url TEXT NOT NULL,
|
||
sort INTEGER DEFAULT 0,
|
||
enabled INTEGER DEFAULT 1,
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
CREATE TABLE IF NOT EXISTS settings (
|
||
key TEXT PRIMARY KEY,
|
||
value TEXT
|
||
);
|
||
CREATE TABLE IF NOT EXISTS admins (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
username TEXT UNIQUE NOT NULL,
|
||
password TEXT NOT NULL
|
||
);
|
||
CREATE TABLE IF NOT EXISTS login_logs (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
admin_id INTEGER,
|
||
username TEXT NOT NULL,
|
||
ip TEXT NOT NULL,
|
||
user_agent TEXT,
|
||
success INTEGER NOT NULL DEFAULT 0,
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
CREATE TABLE IF NOT EXISTS ip_bans (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
ip TEXT NOT NULL,
|
||
reason TEXT,
|
||
fail_count INTEGER DEFAULT 0,
|
||
banned_until DATETIME NOT NULL,
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
CREATE TABLE IF NOT EXISTS ip_whitelist (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
ip TEXT NOT NULL,
|
||
comment TEXT,
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
CREATE TABLE IF NOT EXISTS access_logs (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
ip TEXT NOT NULL,
|
||
user_agent TEXT,
|
||
action_type TEXT NOT NULL,
|
||
detail TEXT DEFAULT '',
|
||
referer TEXT DEFAULT '',
|
||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||
);
|
||
CREATE INDEX IF NOT EXISTS idx_access_logs_ip ON access_logs(ip);
|
||
CREATE INDEX IF NOT EXISTS idx_access_logs_action_type ON access_logs(action_type);
|
||
CREATE INDEX IF NOT EXISTS idx_access_logs_created_at ON access_logs(created_at);
|
||
`)
|
||
return err
|
||
}
|
||
|
||
// seedData 插入默认管理员和搜索引擎配置
|
||
func seedData() error {
|
||
// 插入默认管理员
|
||
var count int
|
||
err := DB.QueryRow("SELECT COUNT(*) FROM admins WHERE username = ?", "admin").Scan(&count)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if count == 0 {
|
||
_, err = DB.Exec(
|
||
"INSERT INTO admins (username, password) VALUES (?, ?)",
|
||
"admin",
|
||
"$2a$10$h3Csm2HmWUtvim3MJ8VG0OHx/tevZorlUXQVDtN2EgWhROtiM3Sg.", // bcrypt hash for "admin123"
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 插入默认搜索引擎设置
|
||
var settingCount int
|
||
err = DB.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "search_engine").Scan(&settingCount)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if settingCount == 0 {
|
||
_, err = DB.Exec(
|
||
"INSERT INTO settings (key, value) VALUES (?, ?)",
|
||
"search_engine",
|
||
"https://www.google.com/search?q=%s",
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
// 插入默认主页设置
|
||
defaultSettings := []struct {
|
||
key string
|
||
value string
|
||
}{
|
||
{"homepage_title", "Portal"},
|
||
{"homepage_subtitle", ""},
|
||
{"homepage_background", ""},
|
||
}
|
||
for _, s := range defaultSettings {
|
||
var cnt int
|
||
err = DB.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", s.key).Scan(&cnt)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if cnt == 0 {
|
||
_, err = DB.Exec("INSERT INTO settings (key, value) VALUES (?, ?)", s.key, s.value)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// CloseDB 关闭数据库连接
|
||
func CloseDB() error {
|
||
if DB != nil {
|
||
return DB.Close()
|
||
}
|
||
return nil
|
||
}
|