218 lines
5.1 KiB
Go
218 lines
5.1 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"path/filepath"
|
||
"runtime"
|
||
|
||
"github.com/pelletier/go-toml/v2"
|
||
)
|
||
|
||
// Cfg 是全局配置实例,在 Load() 后可用
|
||
var Cfg *Config
|
||
|
||
// Config 是完整的配置结构
|
||
type Config struct {
|
||
Data DataConfig `toml:"data"`
|
||
Database DatabaseConfig `toml:"database"`
|
||
Server ServerConfig `toml:"server"`
|
||
}
|
||
|
||
// DataConfig 数据存储配置
|
||
type DataConfig struct {
|
||
// 数据存储根目录,Windows 默认 "data",Linux 默认 "/srv/portal_page"
|
||
Dir string `toml:"dir"`
|
||
}
|
||
|
||
// DatabaseConfig 数据库配置
|
||
type DatabaseConfig struct {
|
||
// 数据库类型: "sqlite" 或 "mysql"
|
||
Type string `toml:"type"`
|
||
// SQLite 数据库文件名(相对于 data.dir)
|
||
Path string `toml:"path"`
|
||
// MySQL 配置
|
||
MySQL MySQLConfig `toml:"mysql"`
|
||
}
|
||
|
||
// MySQLConfig MySQL 连接配置
|
||
type MySQLConfig struct {
|
||
Host string `toml:"host"`
|
||
Port int `toml:"port"`
|
||
User string `toml:"user"`
|
||
Password string `toml:"password"`
|
||
DBName string `toml:"dbname"`
|
||
}
|
||
|
||
// ServerConfig Web 服务器配置
|
||
type ServerConfig struct {
|
||
// 监听地址,格式 ":8080"
|
||
Addr string `toml:"addr"`
|
||
// Unix socket 路径(设置后优先使用 unix socket,addr 被忽略)
|
||
Unix string `toml:"unix"`
|
||
}
|
||
|
||
// defaultConfig 返回当前平台的默认配置
|
||
func defaultConfig() *Config {
|
||
cfg := &Config{
|
||
Data: DataConfig{
|
||
Dir: "data",
|
||
},
|
||
Database: DatabaseConfig{
|
||
Type: "sqlite",
|
||
Path: "portal.db",
|
||
MySQL: MySQLConfig{
|
||
Host: "127.0.0.1",
|
||
Port: 3306,
|
||
User: "root",
|
||
Password: "",
|
||
DBName: "portal_page",
|
||
},
|
||
},
|
||
Server: ServerConfig{
|
||
Addr: ":8080",
|
||
Unix: "",
|
||
},
|
||
}
|
||
|
||
// Linux 下遵循 FHS 标准
|
||
if runtime.GOOS == "linux" {
|
||
cfg.Data.Dir = "/srv/portal_page"
|
||
}
|
||
|
||
return cfg
|
||
}
|
||
|
||
// configPath 返回当前平台的配置文件路径
|
||
func configPath() string {
|
||
if runtime.GOOS == "windows" {
|
||
// Windows: 相对于工作目录的 conf/config.toml
|
||
// 注意: 使用工作目录而非 os.Executable(),因为 go run 时 exe 在临时目录
|
||
wd, err := os.Getwd()
|
||
if err != nil {
|
||
return "conf/config.toml"
|
||
}
|
||
return filepath.Join(wd, "conf", "config.toml")
|
||
}
|
||
// Linux: FHS 标准 /etc/portal_page/config.toml
|
||
return "/etc/portal_page/config.toml"
|
||
}
|
||
|
||
// Load 加载配置文件,不存在则自动生成,缺失项自动补全
|
||
func Load() error {
|
||
path := configPath()
|
||
def := defaultConfig()
|
||
|
||
// 如果配置文件不存在,使用默认配置直接生成
|
||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||
log.Printf("配置文件不存在,自动生成: %s", path)
|
||
Cfg = def
|
||
return saveConfig(path, Cfg)
|
||
}
|
||
|
||
// 读取并解析配置文件
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return fmt.Errorf("读取配置文件失败: %w", err)
|
||
}
|
||
|
||
Cfg = &Config{}
|
||
if err := toml.Unmarshal(data, Cfg); err != nil {
|
||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||
}
|
||
|
||
// 补全缺失项
|
||
changed := fillDefaults(Cfg, def)
|
||
if changed {
|
||
log.Printf("配置文件有缺失项,已自动补全: %s", path)
|
||
if err := saveConfig(path, Cfg); err != nil {
|
||
return fmt.Errorf("保存补全后的配置文件失败: %w", err)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// fillDefaults 用默认值补全零值字段,返回是否有变更
|
||
func fillDefaults(cfg, def *Config) bool {
|
||
changed := false
|
||
|
||
// Data
|
||
if cfg.Data.Dir == "" {
|
||
cfg.Data.Dir = def.Data.Dir
|
||
changed = true
|
||
}
|
||
|
||
// Database
|
||
if cfg.Database.Type == "" {
|
||
cfg.Database.Type = def.Database.Type
|
||
changed = true
|
||
}
|
||
if cfg.Database.Path == "" {
|
||
cfg.Database.Path = def.Database.Path
|
||
changed = true
|
||
}
|
||
if cfg.Database.MySQL.Host == "" {
|
||
cfg.Database.MySQL.Host = def.Database.MySQL.Host
|
||
changed = true
|
||
}
|
||
if cfg.Database.MySQL.Port == 0 {
|
||
cfg.Database.MySQL.Port = def.Database.MySQL.Port
|
||
changed = true
|
||
}
|
||
if cfg.Database.MySQL.User == "" {
|
||
cfg.Database.MySQL.User = def.Database.MySQL.User
|
||
changed = true
|
||
}
|
||
// Password 允许为空字符串,不补全
|
||
if cfg.Database.MySQL.DBName == "" {
|
||
cfg.Database.MySQL.DBName = def.Database.MySQL.DBName
|
||
changed = true
|
||
}
|
||
|
||
// Server
|
||
if cfg.Server.Addr == "" && cfg.Server.Unix == "" {
|
||
cfg.Server.Addr = def.Server.Addr
|
||
changed = true
|
||
}
|
||
|
||
return changed
|
||
}
|
||
|
||
// saveConfig 将配置写入文件
|
||
func saveConfig(path string, cfg *Config) error {
|
||
// 确保目录存在
|
||
dir := filepath.Dir(path)
|
||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||
return fmt.Errorf("创建配置目录失败: %w", err)
|
||
}
|
||
|
||
data, err := toml.Marshal(cfg)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化配置失败: %w", err)
|
||
}
|
||
|
||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||
return fmt.Errorf("写入配置文件失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetUploadDir 返回上传文件目录的完整路径
|
||
func GetUploadDir() string {
|
||
return filepath.Join(Cfg.Data.Dir, "uploads")
|
||
}
|
||
|
||
// GetDBPath 返回 SQLite 数据库文件的完整路径
|
||
func GetDBPath() string {
|
||
return filepath.Join(Cfg.Data.Dir, Cfg.Database.Path)
|
||
}
|
||
|
||
// GetDSN 返回 MySQL 的 DSN 连接字符串
|
||
func (m *MySQLConfig) GetDSN() string {
|
||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
|
||
m.User, m.Password, m.Host, m.Port, m.DBName)
|
||
}
|