diff --git a/.gitignore b/.gitignore index 31bf06e..b8726a6 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,9 @@ portal Thumbs.db Desktop.ini +# 配置文件(运行时自动生成) +/conf/ + # 数据库 /data/*.db /data/*.db-shm diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..de4ac21 --- /dev/null +++ b/config/config.go @@ -0,0 +1,216 @@ +package config + +import ( + "fmt" + "log" + "os" + "path/filepath" + "runtime" + + "github.com/pelletier/go-toml/v2" +) + +// Cfg 是全局配置实例,在 Load() 后可用 +var Cfg *Config + +// Config 是完整的配置结构 +type Config struct { + Data DataConfig `toml:"data"` + Database DatabaseConfig `toml:"database"` + Server ServerConfig `toml:"server"` +} + +// DataConfig 数据存储配置 +type DataConfig struct { + // 数据存储根目录,Windows 默认 "data",Linux 默认 "/srv/portal_page" + Dir string `toml:"dir"` +} + +// DatabaseConfig 数据库配置 +type DatabaseConfig struct { + // 数据库类型: "sqlite" 或 "mysql" + Type string `toml:"type"` + // SQLite 数据库文件名(相对于 data.dir) + Path string `toml:"path"` + // MySQL 配置 + MySQL MySQLConfig `toml:"mysql"` +} + +// MySQLConfig MySQL 连接配置 +type MySQLConfig struct { + Host string `toml:"host"` + Port int `toml:"port"` + User string `toml:"user"` + Password string `toml:"password"` + DBName string `toml:"dbname"` +} + +// ServerConfig Web 服务器配置 +type ServerConfig struct { + // 监听地址,格式 ":8080" + Addr string `toml:"addr"` + // Unix socket 路径(设置后优先使用 unix socket,addr 被忽略) + Unix string `toml:"unix"` +} + +// defaultConfig 返回当前平台的默认配置 +func defaultConfig() *Config { + cfg := &Config{ + Data: DataConfig{ + Dir: "data", + }, + Database: DatabaseConfig{ + Type: "sqlite", + Path: "portal.db", + MySQL: MySQLConfig{ + Host: "127.0.0.1", + Port: 3306, + User: "root", + Password: "", + DBName: "portal_page", + }, + }, + Server: ServerConfig{ + Addr: ":8080", + Unix: "", + }, + } + + // Linux 下遵循 FHS 标准 + if runtime.GOOS == "linux" { + cfg.Data.Dir = "/srv/portal_page" + } + + return cfg +} + +// configPath 返回当前平台的配置文件路径 +func configPath() string { + if runtime.GOOS == "windows" { + // Windows: 相对于可执行文件的 conf/config.toml + exePath, err := os.Executable() + if err != nil { + return "conf/config.toml" + } + return filepath.Join(filepath.Dir(exePath), "conf", "config.toml") + } + // Linux: FHS 标准 /etc/portal_page/config.toml + return "/etc/portal_page/config.toml" +} + +// Load 加载配置文件,不存在则自动生成,缺失项自动补全 +func Load() error { + path := configPath() + def := defaultConfig() + + // 如果配置文件不存在,使用默认配置直接生成 + if _, err := os.Stat(path); os.IsNotExist(err) { + log.Printf("配置文件不存在,自动生成: %s", path) + Cfg = def + return saveConfig(path, Cfg) + } + + // 读取并解析配置文件 + data, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("读取配置文件失败: %w", err) + } + + Cfg = &Config{} + if err := toml.Unmarshal(data, Cfg); err != nil { + return fmt.Errorf("解析配置文件失败: %w", err) + } + + // 补全缺失项 + changed := fillDefaults(Cfg, def) + if changed { + log.Printf("配置文件有缺失项,已自动补全: %s", path) + if err := saveConfig(path, Cfg); err != nil { + return fmt.Errorf("保存补全后的配置文件失败: %w", err) + } + } + + return nil +} + +// fillDefaults 用默认值补全零值字段,返回是否有变更 +func fillDefaults(cfg, def *Config) bool { + changed := false + + // Data + if cfg.Data.Dir == "" { + cfg.Data.Dir = def.Data.Dir + changed = true + } + + // Database + if cfg.Database.Type == "" { + cfg.Database.Type = def.Database.Type + changed = true + } + if cfg.Database.Path == "" { + cfg.Database.Path = def.Database.Path + changed = true + } + if cfg.Database.MySQL.Host == "" { + cfg.Database.MySQL.Host = def.Database.MySQL.Host + changed = true + } + if cfg.Database.MySQL.Port == 0 { + cfg.Database.MySQL.Port = def.Database.MySQL.Port + changed = true + } + if cfg.Database.MySQL.User == "" { + cfg.Database.MySQL.User = def.Database.MySQL.User + changed = true + } + // Password 允许为空字符串,不补全 + if cfg.Database.MySQL.DBName == "" { + cfg.Database.MySQL.DBName = def.Database.MySQL.DBName + changed = true + } + + // Server + if cfg.Server.Addr == "" && cfg.Server.Unix == "" { + cfg.Server.Addr = def.Server.Addr + changed = true + } + + return changed +} + +// saveConfig 将配置写入文件 +func saveConfig(path string, cfg *Config) error { + // 确保目录存在 + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("创建配置目录失败: %w", err) + } + + data, err := toml.Marshal(cfg) + if err != nil { + return fmt.Errorf("序列化配置失败: %w", err) + } + + if err := os.WriteFile(path, data, 0644); err != nil { + return fmt.Errorf("写入配置文件失败: %w", err) + } + + return nil +} + +// GetUploadDir 返回上传文件目录的完整路径 +func GetUploadDir() string { + return filepath.Join(Cfg.Data.Dir, "uploads") +} + +// GetDBPath 返回 SQLite 数据库文件的完整路径 +func GetDBPath() string { + return filepath.Join(Cfg.Data.Dir, Cfg.Database.Path) +} + +// GetDSN 返回 MySQL 的 DSN 连接字符串 +func (m *MySQLConfig) GetDSN() string { + return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", + m.User, m.Password, m.Host, m.Port, m.DBName) +} diff --git a/database/db.go b/database/db.go index ea967ac..0b88b61 100644 --- a/database/db.go +++ b/database/db.go @@ -6,50 +6,98 @@ import ( "os" "path/filepath" + "simple_portal/config" + + _ "github.com/go-sql-driver/mysql" _ "modernc.org/sqlite" ) // DB is the global database connection pointer. var DB *sql.DB -// InitDB initializes the SQLite database, creates the data directory, -// opens the connection, sets WAL mode, and creates default tables and data. +// InitDB 根据配置初始化数据库,支持 SQLite 和 MySQL func InitDB() error { - dbPath := filepath.Join(".", "data", "portal.db") + dbType := config.Cfg.Database.Type + dataDir := config.Cfg.Data.Dir - // Create data directory if not exists + switch dbType { + case "mysql": + return initMySQL() + case "sqlite": + return initSQLite(dataDir) + default: + return fmt.Errorf("不支持的数据库类型: %s(可选: sqlite, mysql)", dbType) + } +} + +// initSQLite 初始化 SQLite 数据库 +func initSQLite(dataDir string) error { + dbPath := config.GetDBPath() + + // 创建数据目录 if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil { - return fmt.Errorf("failed to create data directory: %w", err) + return fmt.Errorf("创建数据目录失败: %w", err) } var err error DB, err = sql.Open("sqlite", dbPath) if err != nil { - return fmt.Errorf("failed to open database: %w", err) + return fmt.Errorf("打开 SQLite 数据库失败: %w", err) } - // SQLite requires max open conns = 1 for safe writes + // SQLite 需要限制最大连接数为 1 以保证写安全 DB.SetMaxOpenConns(1) - // Enable WAL mode for better concurrent read performance + // 启用 WAL 模式提升并发读性能 if _, err := DB.Exec("PRAGMA journal_mode=WAL"); err != nil { - return fmt.Errorf("failed to set WAL mode: %w", err) + return fmt.Errorf("设置 WAL 模式失败: %w", err) } - // Create tables + // 创建表 if err := createTables(); err != nil { - return fmt.Errorf("failed to create tables: %w", err) + return fmt.Errorf("创建数据表失败: %w", err) } - // Seed default data + // 初始化默认数据 if err := seedData(); err != nil { - return fmt.Errorf("failed to seed data: %w", err) + return fmt.Errorf("初始化默认数据失败: %w", err) } return nil } -// createTables creates the cards, settings, and admins tables. +// initMySQL 初始化 MySQL 数据库 +func initMySQL() error { + dsn := config.Cfg.Database.MySQL.GetDSN() + + var err error + DB, err = sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("打开 MySQL 数据库失败: %w", err) + } + + // 测试连接 + if err := DB.Ping(); err != nil { + return fmt.Errorf("连接 MySQL 失败: %w", err) + } + + // MySQL 不限制最大连接数(使用默认池) + DB.SetMaxOpenConns(0) + + // 创建表 + if err := createTables(); err != nil { + return fmt.Errorf("创建数据表失败: %w", err) + } + + // 初始化默认数据 + if err := seedData(); err != nil { + return fmt.Errorf("初始化默认数据失败: %w", err) + } + + return nil +} + +// createTables 创建所有数据表 func createTables() error { _, err := DB.Exec(` CREATE TABLE IF NOT EXISTS cards ( @@ -110,9 +158,9 @@ func createTables() error { return err } -// seedData inserts default admin account and search engine setting if not present. +// seedData 插入默认管理员和搜索引擎配置 func seedData() error { - // Insert default admin if not exists + // 插入默认管理员 var count int err := DB.QueryRow("SELECT COUNT(*) FROM admins WHERE username = ?", "admin").Scan(&count) if err != nil { @@ -129,7 +177,7 @@ func seedData() error { } } - // Insert default search engine setting if not exists + // 插入默认搜索引擎设置 var settingCount int err = DB.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "search_engine").Scan(&settingCount) if err != nil { @@ -146,7 +194,7 @@ func seedData() error { } } - // Insert default homepage settings if not exists + // 插入默认主页设置 defaultSettings := []struct { key string value string @@ -172,7 +220,7 @@ func seedData() error { return nil } -// CloseDB closes the database connection. +// CloseDB 关闭数据库连接 func CloseDB() error { if DB != nil { return DB.Close() diff --git a/go.mod b/go.mod index 5f1cdaa..7c6ca0a 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,19 @@ module simple_portal -go 1.22 +go 1.24.0 require ( github.com/disintegration/imaging v1.6.2 github.com/gin-gonic/gin v1.10.0 + github.com/go-sql-driver/mysql v1.10.0 github.com/google/uuid v1.6.0 + github.com/pelletier/go-toml/v2 v2.2.2 golang.org/x/crypto v0.28.0 modernc.org/sqlite v1.34.5 ) require ( + filippo.io/edwards25519 v1.2.0 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect @@ -29,7 +32,6 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect - github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect diff --git a/go.sum b/go.sum index 7ef8a5b..7a3d983 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= @@ -27,6 +29,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw= +github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= diff --git a/handlers/upload.go b/handlers/upload.go index 80ca67e..b705000 100644 --- a/handlers/upload.go +++ b/handlers/upload.go @@ -9,6 +9,8 @@ import ( "path/filepath" "strings" + "simple_portal/config" + "github.com/disintegration/imaging" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -24,9 +26,6 @@ var allowedMIMETypes = map[string]string{ // maxUploadSize is the maximum allowed file size in bytes (5 MB). const maxUploadSize = 5 << 20 -// uploadDir is the directory where uploaded files are stored. -const uploadDir = "./data/uploads" - // thumbSuffix is the suffix appended to compressed image filenames. const thumbSuffix = "_thumb" @@ -94,13 +93,13 @@ func UploadHandler(c *gin.Context) { filename := fileUUID + ext // Ensure upload directory exists - if err := os.MkdirAll(uploadDir, 0755); err != nil { + if err := os.MkdirAll(config.GetUploadDir(), 0755); err != nil { c.JSON(500, gin.H{"error": "创建上传目录失败"}) return } // Save original file - originalPath := filepath.Join(uploadDir, filename) + originalPath := filepath.Join(config.GetUploadDir(), filename) dst, err := os.Create(originalPath) if err != nil { c.JSON(500, gin.H{"error": "保存文件失败"}) @@ -142,14 +141,14 @@ func ServeUploadHandler(c *gin.Context) { return } - filePath := filepath.Join(uploadDir, filename) + filePath := filepath.Join(config.GetUploadDir(), filename) // Check if thumb=1 query parameter is requested if c.Query("thumb") == "1" { // Try to serve the thumbnail version ext := filepath.Ext(filename) baseName := filename[:len(filename)-len(ext)] - thumbPath := filepath.Join(uploadDir, baseName+thumbSuffix+".jpg") + thumbPath := filepath.Join(config.GetUploadDir(), baseName+thumbSuffix+".jpg") if _, err := os.Stat(thumbPath); err == nil { c.File(thumbPath) @@ -221,7 +220,7 @@ func generateThumbnail(originalPath, fileUUID, uploadType string) error { thumb := fitImage(img, maxWidth, maxHeight) // Save thumbnail as JPEG - thumbPath := filepath.Join(uploadDir, fileUUID+thumbSuffix+".jpg") + thumbPath := filepath.Join(config.GetUploadDir(), fileUUID+thumbSuffix+".jpg") thumbFile, err := os.Create(thumbPath) if err != nil { return fmt.Errorf("failed to create thumbnail file: %w", err) diff --git a/main.go b/main.go index c792d94..c6f59e2 100644 --- a/main.go +++ b/main.go @@ -3,10 +3,12 @@ package main import ( "html/template" "log" + "net" "os" "path/filepath" "strings" + "simple_portal/config" "simple_portal/database" "simple_portal/handlers" "simple_portal/middleware" @@ -15,8 +17,8 @@ import ( "github.com/gin-gonic/gin" ) -// loadTemplates loads HTML templates from templates/ directory recursively. -// Custom implementation because Go's ParseGlob has issues with directories on Windows. +// loadTemplates 加载 templates/ 目录下所有 HTML 模板 +// 自定义实现,因为 Go 的 ParseGlob 在 Windows 下有路径问题 func loadTemplates() *template.Template { funcMap := template.FuncMap{ "hasPrefix": strings.HasPrefix, @@ -24,7 +26,6 @@ func loadTemplates() *template.Template { "add": func(a, b int) int { return a + b }, } t := template.New("").Funcs(funcMap) - // 收集所有 .html 模板文件路径 var files []string filepath.Walk("templates", func(path string, info os.FileInfo, err error) error { if err != nil { @@ -38,7 +39,6 @@ func loadTemplates() *template.Template { if len(files) == 0 { log.Fatal("No template files found in templates/") } - // 将 Windows 反斜杠路径转为正斜杠,避免模板名问题 for i, f := range files { files[i] = filepath.ToSlash(f) } @@ -51,24 +51,30 @@ func loadTemplates() *template.Template { } func main() { - // Initialize database + // 加载配置文件(自动生成 + 补全缺失项) + if err := config.Load(); err != nil { + log.Fatalf("加载配置失败: %v", err) + } + log.Printf("配置加载成功,数据目录: %s,数据库: %s", config.Cfg.Data.Dir, config.Cfg.Database.Type) + + // 初始化数据库 if err := database.InitDB(); err != nil { - log.Fatalf("Failed to initialize database: %v", err) + log.Fatalf("初始化数据库失败: %v", err) } defer database.CloseDB() - // Create uploads directory - if err := os.MkdirAll(filepath.Join(".", "data", "uploads"), 0755); err != nil { - log.Fatalf("Failed to create uploads directory: %v", err) + // 创建上传目录 + if err := os.MkdirAll(config.GetUploadDir(), 0755); err != nil { + log.Fatalf("创建上传目录失败: %v", err) } - // Create session store + // 创建 session 存储 sessionStore := session.NewSessionStore() - // Create IP ban guard (in-memory fail counter) + // 创建 IP 封禁守护 ipBanGuard := middleware.NewIPBanGuard() - // Set Gin mode + // 设置 Gin 模式 ginMode := os.Getenv("GIN_MODE") if ginMode == "" { gin.SetMode(gin.DebugMode) @@ -76,43 +82,43 @@ func main() { r := gin.Default() - // Load HTML templates (custom loader for nested directories) + // 加载 HTML 模板 r.SetHTMLTemplate(loadTemplates()) - // Serve static files + // 静态文件 r.Static("/static", "./static") - // Inject session store and IP ban guard into context for handlers + // 注入 session 和 IP 封禁守护 r.Use(func(c *gin.Context) { c.Set("sessionStore", sessionStore) c.Set("ipBanGuard", ipBanGuard) c.Next() }) - // Public routes (home page and uploads — no IP restriction) + // 公开路由 r.GET("/", handlers.HomeHandler) r.GET("/click/:id", handlers.CardClickHandler) r.GET("/search", handlers.SearchHandler) r.GET("/uploads/:filename", handlers.ServeUploadHandler) - // Admin routes with IP whitelist check applied to all /admin/* routes + // 后台路由(IP 白名单) adminGroup := r.Group("/admin") adminGroup.Use(middleware.IPWhitelistRequired(func(sessionID string) bool { return sessionStore.Get(sessionID) != nil })) { - // Public admin routes (login — no auth required, but IP whitelist applies) + // 登录(无需认证,受 IP 白名单限制) adminGroup.GET("/login", handlers.LoginGet) adminGroup.POST("/login", handlers.LoginPost) - // Protected admin routes (auth required) + // 需要认证的后台路由 protected := adminGroup.Group("") protected.Use(middleware.AuthRequired(sessionStore)) { protected.POST("/logout", handlers.Logout) protected.GET("/", handlers.AdminIndex) - // Cards management + // 卡片管理 protected.GET("/cards", handlers.CardsList) protected.GET("/cards/new", handlers.CardCreateGet) protected.POST("/cards", handlers.CardCreatePost) @@ -123,39 +129,50 @@ func main() { protected.POST("/cards/:id/move-up", handlers.CardMoveUp) protected.POST("/cards/:id/move-down", handlers.CardMoveDown) - // Image upload + // 图片上传 protected.POST("/upload", handlers.UploadHandler) - // Settings + // 设置 protected.GET("/settings", handlers.SettingsGet) protected.POST("/settings", handlers.SettingsPost) - // Security: login logs + // 安全:登录日志 protected.GET("/logs", handlers.LoginLogsGet) protected.POST("/logs/unban/:id", handlers.UnbanIP) - // Security: change password + // 安全:修改密码 protected.GET("/password", handlers.ChangePasswordGet) protected.POST("/password", handlers.ChangePasswordPost) - // Security: IP whitelist management + // 安全:IP 白名单 protected.GET("/ip-whitelist", handlers.IPWhitelistGet) protected.POST("/ip-whitelist/add", handlers.IPWhitelistAdd) protected.POST("/ip-whitelist/:id/delete", handlers.IPWhitelistDelete) - // Analytics: access logs + // 分析:访问日志 protected.GET("/access-logs", handlers.AccessLogsGet) } } - // Determine port - port := os.Getenv("PORT") - if port == "" { - port = "8080" - } - - log.Printf("Starting Portal server on :%s", port) - if err := r.Run(":" + port); err != nil { - log.Fatalf("Failed to start server: %v", err) + // 启动服务器 + if config.Cfg.Server.Unix != "" { + // Unix socket 模式 + listener, err := net.Listen("unix", config.Cfg.Server.Unix) + if err != nil { + log.Fatalf("监听 Unix socket 失败: %v", err) + } + // 设置 socket 文件权限,允许 nginx 等其他进程访问 + os.Chmod(config.Cfg.Server.Unix, 0666) + log.Printf("启动 Portal 服务器,监听 Unix socket: %s", config.Cfg.Server.Unix) + if err := r.RunListener(listener); err != nil { + log.Fatalf("服务器启动失败: %v", err) + } + } else { + // TCP 模式 + addr := config.Cfg.Server.Addr + log.Printf("启动 Portal 服务器,监听: %s", addr) + if err := r.Run(addr); err != nil { + log.Fatalf("服务器启动失败: %v", err) + } } }