This commit is contained in:
2026-05-28 14:43:13 +08:00
parent c16a8dfbc4
commit 957a594a0f
7 changed files with 353 additions and 64 deletions
+3
View File
@@ -18,6 +18,9 @@ portal
Thumbs.db
Desktop.ini
# 配置文件(运行时自动生成)
/conf/
# 数据库
/data/*.db
/data/*.db-shm
+216
View File
@@ -0,0 +1,216 @@
package config
import (
"fmt"
"log"
"os"
"path/filepath"
"runtime"
"github.com/pelletier/go-toml/v2"
)
// Cfg 是全局配置实例,在 Load() 后可用
var Cfg *Config
// Config 是完整的配置结构
type Config struct {
Data DataConfig `toml:"data"`
Database DatabaseConfig `toml:"database"`
Server ServerConfig `toml:"server"`
}
// DataConfig 数据存储配置
type DataConfig struct {
// 数据存储根目录,Windows 默认 "data"Linux 默认 "/srv/portal_page"
Dir string `toml:"dir"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
// 数据库类型: "sqlite" 或 "mysql"
Type string `toml:"type"`
// SQLite 数据库文件名(相对于 data.dir)
Path string `toml:"path"`
// MySQL 配置
MySQL MySQLConfig `toml:"mysql"`
}
// MySQLConfig MySQL 连接配置
type MySQLConfig struct {
Host string `toml:"host"`
Port int `toml:"port"`
User string `toml:"user"`
Password string `toml:"password"`
DBName string `toml:"dbname"`
}
// ServerConfig Web 服务器配置
type ServerConfig struct {
// 监听地址,格式 ":8080"
Addr string `toml:"addr"`
// Unix socket 路径(设置后优先使用 unix socketaddr 被忽略)
Unix string `toml:"unix"`
}
// defaultConfig 返回当前平台的默认配置
func defaultConfig() *Config {
cfg := &Config{
Data: DataConfig{
Dir: "data",
},
Database: DatabaseConfig{
Type: "sqlite",
Path: "portal.db",
MySQL: MySQLConfig{
Host: "127.0.0.1",
Port: 3306,
User: "root",
Password: "",
DBName: "portal_page",
},
},
Server: ServerConfig{
Addr: ":8080",
Unix: "",
},
}
// Linux 下遵循 FHS 标准
if runtime.GOOS == "linux" {
cfg.Data.Dir = "/srv/portal_page"
}
return cfg
}
// configPath 返回当前平台的配置文件路径
func configPath() string {
if runtime.GOOS == "windows" {
// Windows: 相对于可执行文件的 conf/config.toml
exePath, err := os.Executable()
if err != nil {
return "conf/config.toml"
}
return filepath.Join(filepath.Dir(exePath), "conf", "config.toml")
}
// Linux: FHS 标准 /etc/portal_page/config.toml
return "/etc/portal_page/config.toml"
}
// Load 加载配置文件,不存在则自动生成,缺失项自动补全
func Load() error {
path := configPath()
def := defaultConfig()
// 如果配置文件不存在,使用默认配置直接生成
if _, err := os.Stat(path); os.IsNotExist(err) {
log.Printf("配置文件不存在,自动生成: %s", path)
Cfg = def
return saveConfig(path, Cfg)
}
// 读取并解析配置文件
data, err := os.ReadFile(path)
if err != nil {
return fmt.Errorf("读取配置文件失败: %w", err)
}
Cfg = &Config{}
if err := toml.Unmarshal(data, Cfg); err != nil {
return fmt.Errorf("解析配置文件失败: %w", err)
}
// 补全缺失项
changed := fillDefaults(Cfg, def)
if changed {
log.Printf("配置文件有缺失项,已自动补全: %s", path)
if err := saveConfig(path, Cfg); err != nil {
return fmt.Errorf("保存补全后的配置文件失败: %w", err)
}
}
return nil
}
// fillDefaults 用默认值补全零值字段,返回是否有变更
func fillDefaults(cfg, def *Config) bool {
changed := false
// Data
if cfg.Data.Dir == "" {
cfg.Data.Dir = def.Data.Dir
changed = true
}
// Database
if cfg.Database.Type == "" {
cfg.Database.Type = def.Database.Type
changed = true
}
if cfg.Database.Path == "" {
cfg.Database.Path = def.Database.Path
changed = true
}
if cfg.Database.MySQL.Host == "" {
cfg.Database.MySQL.Host = def.Database.MySQL.Host
changed = true
}
if cfg.Database.MySQL.Port == 0 {
cfg.Database.MySQL.Port = def.Database.MySQL.Port
changed = true
}
if cfg.Database.MySQL.User == "" {
cfg.Database.MySQL.User = def.Database.MySQL.User
changed = true
}
// Password 允许为空字符串,不补全
if cfg.Database.MySQL.DBName == "" {
cfg.Database.MySQL.DBName = def.Database.MySQL.DBName
changed = true
}
// Server
if cfg.Server.Addr == "" && cfg.Server.Unix == "" {
cfg.Server.Addr = def.Server.Addr
changed = true
}
return changed
}
// saveConfig 将配置写入文件
func saveConfig(path string, cfg *Config) error {
// 确保目录存在
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("创建配置目录失败: %w", err)
}
data, err := toml.Marshal(cfg)
if err != nil {
return fmt.Errorf("序列化配置失败: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("写入配置文件失败: %w", err)
}
return nil
}
// GetUploadDir 返回上传文件目录的完整路径
func GetUploadDir() string {
return filepath.Join(Cfg.Data.Dir, "uploads")
}
// GetDBPath 返回 SQLite 数据库文件的完整路径
func GetDBPath() string {
return filepath.Join(Cfg.Data.Dir, Cfg.Database.Path)
}
// GetDSN 返回 MySQL 的 DSN 连接字符串
func (m *MySQLConfig) GetDSN() string {
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
m.User, m.Password, m.Host, m.Port, m.DBName)
}
+67 -19
View File
@@ -6,50 +6,98 @@ import (
"os"
"path/filepath"
"simple_portal/config"
_ "github.com/go-sql-driver/mysql"
_ "modernc.org/sqlite"
)
// DB is the global database connection pointer.
var DB *sql.DB
// InitDB initializes the SQLite database, creates the data directory,
// opens the connection, sets WAL mode, and creates default tables and data.
// InitDB 根据配置初始化数据库,支持 SQLite 和 MySQL
func InitDB() error {
dbPath := filepath.Join(".", "data", "portal.db")
dbType := config.Cfg.Database.Type
dataDir := config.Cfg.Data.Dir
// Create data directory if not exists
switch dbType {
case "mysql":
return initMySQL()
case "sqlite":
return initSQLite(dataDir)
default:
return fmt.Errorf("不支持的数据库类型: %s(可选: sqlite, mysql", dbType)
}
}
// initSQLite 初始化 SQLite 数据库
func initSQLite(dataDir string) error {
dbPath := config.GetDBPath()
// 创建数据目录
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
return fmt.Errorf("failed to create data directory: %w", err)
return fmt.Errorf("创建数据目录失败: %w", err)
}
var err error
DB, err = sql.Open("sqlite", dbPath)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
return fmt.Errorf("打开 SQLite 数据库失败: %w", err)
}
// SQLite requires max open conns = 1 for safe writes
// SQLite 需要限制最大连接数为 1 以保证写安全
DB.SetMaxOpenConns(1)
// Enable WAL mode for better concurrent read performance
// 启用 WAL 模式提升并发读性能
if _, err := DB.Exec("PRAGMA journal_mode=WAL"); err != nil {
return fmt.Errorf("failed to set WAL mode: %w", err)
return fmt.Errorf("设置 WAL 模式失败: %w", err)
}
// Create tables
// 创建表
if err := createTables(); err != nil {
return fmt.Errorf("failed to create tables: %w", err)
return fmt.Errorf("创建数据表失败: %w", err)
}
// Seed default data
// 初始化默认数据
if err := seedData(); err != nil {
return fmt.Errorf("failed to seed data: %w", err)
return fmt.Errorf("初始化默认数据失败: %w", err)
}
return nil
}
// createTables creates the cards, settings, and admins tables.
// initMySQL 初始化 MySQL 数据库
func initMySQL() error {
dsn := config.Cfg.Database.MySQL.GetDSN()
var err error
DB, err = sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("打开 MySQL 数据库失败: %w", err)
}
// 测试连接
if err := DB.Ping(); err != nil {
return fmt.Errorf("连接 MySQL 失败: %w", err)
}
// MySQL 不限制最大连接数(使用默认池)
DB.SetMaxOpenConns(0)
// 创建表
if err := createTables(); err != nil {
return fmt.Errorf("创建数据表失败: %w", err)
}
// 初始化默认数据
if err := seedData(); err != nil {
return fmt.Errorf("初始化默认数据失败: %w", err)
}
return nil
}
// createTables 创建所有数据表
func createTables() error {
_, err := DB.Exec(`
CREATE TABLE IF NOT EXISTS cards (
@@ -110,9 +158,9 @@ func createTables() error {
return err
}
// seedData inserts default admin account and search engine setting if not present.
// seedData 插入默认管理员和搜索引擎配置
func seedData() error {
// Insert default admin if not exists
// 插入默认管理员
var count int
err := DB.QueryRow("SELECT COUNT(*) FROM admins WHERE username = ?", "admin").Scan(&count)
if err != nil {
@@ -129,7 +177,7 @@ func seedData() error {
}
}
// Insert default search engine setting if not exists
// 插入默认搜索引擎设置
var settingCount int
err = DB.QueryRow("SELECT COUNT(*) FROM settings WHERE key = ?", "search_engine").Scan(&settingCount)
if err != nil {
@@ -146,7 +194,7 @@ func seedData() error {
}
}
// Insert default homepage settings if not exists
// 插入默认主页设置
defaultSettings := []struct {
key string
value string
@@ -172,7 +220,7 @@ func seedData() error {
return nil
}
// CloseDB closes the database connection.
// CloseDB 关闭数据库连接
func CloseDB() error {
if DB != nil {
return DB.Close()
+4 -2
View File
@@ -1,16 +1,19 @@
module simple_portal
go 1.22
go 1.24.0
require (
github.com/disintegration/imaging v1.6.2
github.com/gin-gonic/gin v1.10.0
github.com/go-sql-driver/mysql v1.10.0
github.com/google/uuid v1.6.0
github.com/pelletier/go-toml/v2 v2.2.2
golang.org/x/crypto v0.28.0
modernc.org/sqlite v1.34.5
)
require (
filippo.io/edwards25519 v1.2.0 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
@@ -29,7 +32,6 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
+4
View File
@@ -1,3 +1,5 @@
filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -27,6 +29,8 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw=
github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
+7 -8
View File
@@ -9,6 +9,8 @@ import (
"path/filepath"
"strings"
"simple_portal/config"
"github.com/disintegration/imaging"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -24,9 +26,6 @@ var allowedMIMETypes = map[string]string{
// maxUploadSize is the maximum allowed file size in bytes (5 MB).
const maxUploadSize = 5 << 20
// uploadDir is the directory where uploaded files are stored.
const uploadDir = "./data/uploads"
// thumbSuffix is the suffix appended to compressed image filenames.
const thumbSuffix = "_thumb"
@@ -94,13 +93,13 @@ func UploadHandler(c *gin.Context) {
filename := fileUUID + ext
// Ensure upload directory exists
if err := os.MkdirAll(uploadDir, 0755); err != nil {
if err := os.MkdirAll(config.GetUploadDir(), 0755); err != nil {
c.JSON(500, gin.H{"error": "创建上传目录失败"})
return
}
// Save original file
originalPath := filepath.Join(uploadDir, filename)
originalPath := filepath.Join(config.GetUploadDir(), filename)
dst, err := os.Create(originalPath)
if err != nil {
c.JSON(500, gin.H{"error": "保存文件失败"})
@@ -142,14 +141,14 @@ func ServeUploadHandler(c *gin.Context) {
return
}
filePath := filepath.Join(uploadDir, filename)
filePath := filepath.Join(config.GetUploadDir(), filename)
// Check if thumb=1 query parameter is requested
if c.Query("thumb") == "1" {
// Try to serve the thumbnail version
ext := filepath.Ext(filename)
baseName := filename[:len(filename)-len(ext)]
thumbPath := filepath.Join(uploadDir, baseName+thumbSuffix+".jpg")
thumbPath := filepath.Join(config.GetUploadDir(), baseName+thumbSuffix+".jpg")
if _, err := os.Stat(thumbPath); err == nil {
c.File(thumbPath)
@@ -221,7 +220,7 @@ func generateThumbnail(originalPath, fileUUID, uploadType string) error {
thumb := fitImage(img, maxWidth, maxHeight)
// Save thumbnail as JPEG
thumbPath := filepath.Join(uploadDir, fileUUID+thumbSuffix+".jpg")
thumbPath := filepath.Join(config.GetUploadDir(), fileUUID+thumbSuffix+".jpg")
thumbFile, err := os.Create(thumbPath)
if err != nil {
return fmt.Errorf("failed to create thumbnail file: %w", err)
+51 -34
View File
@@ -3,10 +3,12 @@ package main
import (
"html/template"
"log"
"net"
"os"
"path/filepath"
"strings"
"simple_portal/config"
"simple_portal/database"
"simple_portal/handlers"
"simple_portal/middleware"
@@ -15,8 +17,8 @@ import (
"github.com/gin-gonic/gin"
)
// loadTemplates loads HTML templates from templates/ directory recursively.
// Custom implementation because Go's ParseGlob has issues with directories on Windows.
// loadTemplates 加载 templates/ 目录下所有 HTML 模板
// 自定义实现,因为 Go ParseGlob Windows 下有路径问题
func loadTemplates() *template.Template {
funcMap := template.FuncMap{
"hasPrefix": strings.HasPrefix,
@@ -24,7 +26,6 @@ func loadTemplates() *template.Template {
"add": func(a, b int) int { return a + b },
}
t := template.New("").Funcs(funcMap)
// 收集所有 .html 模板文件路径
var files []string
filepath.Walk("templates", func(path string, info os.FileInfo, err error) error {
if err != nil {
@@ -38,7 +39,6 @@ func loadTemplates() *template.Template {
if len(files) == 0 {
log.Fatal("No template files found in templates/")
}
// 将 Windows 反斜杠路径转为正斜杠,避免模板名问题
for i, f := range files {
files[i] = filepath.ToSlash(f)
}
@@ -51,24 +51,30 @@ func loadTemplates() *template.Template {
}
func main() {
// Initialize database
// 加载配置文件(自动生成 + 补全缺失项)
if err := config.Load(); err != nil {
log.Fatalf("加载配置失败: %v", err)
}
log.Printf("配置加载成功,数据目录: %s,数据库: %s", config.Cfg.Data.Dir, config.Cfg.Database.Type)
// 初始化数据库
if err := database.InitDB(); err != nil {
log.Fatalf("Failed to initialize database: %v", err)
log.Fatalf("初始化数据库失败: %v", err)
}
defer database.CloseDB()
// Create uploads directory
if err := os.MkdirAll(filepath.Join(".", "data", "uploads"), 0755); err != nil {
log.Fatalf("Failed to create uploads directory: %v", err)
// 创建上传目录
if err := os.MkdirAll(config.GetUploadDir(), 0755); err != nil {
log.Fatalf("创建上传目录失败: %v", err)
}
// Create session store
// 创建 session 存储
sessionStore := session.NewSessionStore()
// Create IP ban guard (in-memory fail counter)
// 创建 IP 封禁守护
ipBanGuard := middleware.NewIPBanGuard()
// Set Gin mode
// 设置 Gin 模式
ginMode := os.Getenv("GIN_MODE")
if ginMode == "" {
gin.SetMode(gin.DebugMode)
@@ -76,43 +82,43 @@ func main() {
r := gin.Default()
// Load HTML templates (custom loader for nested directories)
// 加载 HTML 模板
r.SetHTMLTemplate(loadTemplates())
// Serve static files
// 静态文件
r.Static("/static", "./static")
// Inject session store and IP ban guard into context for handlers
// 注入 session 和 IP 封禁守护
r.Use(func(c *gin.Context) {
c.Set("sessionStore", sessionStore)
c.Set("ipBanGuard", ipBanGuard)
c.Next()
})
// Public routes (home page and uploads — no IP restriction)
// 公开路由
r.GET("/", handlers.HomeHandler)
r.GET("/click/:id", handlers.CardClickHandler)
r.GET("/search", handlers.SearchHandler)
r.GET("/uploads/:filename", handlers.ServeUploadHandler)
// Admin routes with IP whitelist check applied to all /admin/* routes
// 后台路由(IP 白名单)
adminGroup := r.Group("/admin")
adminGroup.Use(middleware.IPWhitelistRequired(func(sessionID string) bool {
return sessionStore.Get(sessionID) != nil
}))
{
// Public admin routes (login — no auth required, but IP whitelist applies)
// 登录(无需认证,受 IP 白名单限制)
adminGroup.GET("/login", handlers.LoginGet)
adminGroup.POST("/login", handlers.LoginPost)
// Protected admin routes (auth required)
// 需要认证的后台路由
protected := adminGroup.Group("")
protected.Use(middleware.AuthRequired(sessionStore))
{
protected.POST("/logout", handlers.Logout)
protected.GET("/", handlers.AdminIndex)
// Cards management
// 卡片管理
protected.GET("/cards", handlers.CardsList)
protected.GET("/cards/new", handlers.CardCreateGet)
protected.POST("/cards", handlers.CardCreatePost)
@@ -123,39 +129,50 @@ func main() {
protected.POST("/cards/:id/move-up", handlers.CardMoveUp)
protected.POST("/cards/:id/move-down", handlers.CardMoveDown)
// Image upload
// 图片上传
protected.POST("/upload", handlers.UploadHandler)
// Settings
// 设置
protected.GET("/settings", handlers.SettingsGet)
protected.POST("/settings", handlers.SettingsPost)
// Security: login logs
// 安全:登录日志
protected.GET("/logs", handlers.LoginLogsGet)
protected.POST("/logs/unban/:id", handlers.UnbanIP)
// Security: change password
// 安全:修改密码
protected.GET("/password", handlers.ChangePasswordGet)
protected.POST("/password", handlers.ChangePasswordPost)
// Security: IP whitelist management
// 安全:IP 白名单
protected.GET("/ip-whitelist", handlers.IPWhitelistGet)
protected.POST("/ip-whitelist/add", handlers.IPWhitelistAdd)
protected.POST("/ip-whitelist/:id/delete", handlers.IPWhitelistDelete)
// Analytics: access logs
// 分析:访问日志
protected.GET("/access-logs", handlers.AccessLogsGet)
}
}
// Determine port
port := os.Getenv("PORT")
if port == "" {
port = "8080"
// 启动服务器
if config.Cfg.Server.Unix != "" {
// Unix socket 模式
listener, err := net.Listen("unix", config.Cfg.Server.Unix)
if err != nil {
log.Fatalf("监听 Unix socket 失败: %v", err)
}
// 设置 socket 文件权限,允许 nginx 等其他进程访问
os.Chmod(config.Cfg.Server.Unix, 0666)
log.Printf("启动 Portal 服务器,监听 Unix socket: %s", config.Cfg.Server.Unix)
if err := r.RunListener(listener); err != nil {
log.Fatalf("服务器启动失败: %v", err)
}
} else {
// TCP 模式
addr := config.Cfg.Server.Addr
log.Printf("启动 Portal 服务器,监听: %s", addr)
if err := r.Run(addr); err != nil {
log.Fatalf("服务器启动失败: %v", err)
}
log.Printf("Starting Portal server on :%s", port)
if err := r.Run(":" + port); err != nil {
log.Fatalf("Failed to start server: %v", err)
}
}