增加uinx配置

This commit is contained in:
2026-04-12 14:24:26 +08:00
parent c90c58ae05
commit ae001b82e6
3 changed files with 113 additions and 9 deletions
+75 -4
View File
@@ -4,8 +4,10 @@ package config
import ( import (
"fmt" "fmt"
"math"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"sync" "sync"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -62,6 +64,7 @@ type SearchConfig struct {
FlushIntervalSeconds int `yaml:"flush_interval_seconds"` FlushIntervalSeconds int `yaml:"flush_interval_seconds"`
StatsRefreshInterval int `yaml:"stats_refresh_interval"` // 统计缓存刷新间隔(秒),默认 30 StatsRefreshInterval int `yaml:"stats_refresh_interval"` // 统计缓存刷新间隔(秒),默认 30
MissPenalty float64 `yaml:"miss_penalty"` // 缺词惩罚系数(0=不惩罚,1=完全忽略缺词URL),默认 0.15 MissPenalty float64 `yaml:"miss_penalty"` // 缺词惩罚系数(0=不惩罚,1=完全忽略缺词URL),默认 0.15
UnixSocket string `yaml:"unix_socket"` // Unix socket 路径(仅 Linux/macOS),空字符串表示不启用
} }
// BacklinkConfig 反向链接计算相关配置 // BacklinkConfig 反向链接计算相关配置
@@ -84,22 +87,87 @@ type PrometheusConfig struct {
// Global 全局配置实例,加载后可通过此变量访问 // Global 全局配置实例,加载后可通过此变量访问
var Global Config var Global Config
// Load 从指定路径加载配置文件 // Load 从指定路径加载配置文件,并自动补全缺失的字段。
// 流程:读取 YAML → 与默认值合并 → 写回 config.yml → 赋值 Global
func Load(configPath string) error { func Load(configPath string) error {
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
if err != nil { if err != nil {
return fmt.Errorf("failed to read config file: %v", err) return fmt.Errorf("failed to read config file: %v", err)
} }
var cfg Config // 先拿到当前 YAML 内容,用于判断哪些字段实际存在于文件中
if err := yaml.Unmarshal(data, &cfg); err != nil { var yamlOnly Config
if err := yaml.Unmarshal(data, &yamlOnly); err != nil {
return fmt.Errorf("failed to parse config file: %v", err) return fmt.Errorf("failed to parse config file: %v", err)
} }
Global = cfg // 从默认值开始,YAML 中有值的字段会被覆盖
merged := GetDefaultConfig()
mergeConfig(&merged, &yamlOnly)
// 写回 config.yml(自动补全缺失字段)
yamlOut, err := yaml.Marshal(&merged)
if err != nil {
return fmt.Errorf("failed to marshal config: %v", err)
}
if err := os.WriteFile(configPath, yamlOut, 0644); err != nil {
return fmt.Errorf("failed to write config file: %v", err)
}
Global = merged
return nil return nil
} }
// mergeConfig 将 src 中的非零字段合并到 dst(原地修改 dst)。
// 用于把 YAML 实际配置值覆盖到默认值结构上。
func mergeConfig(dst, src interface{}) {
if dst == nil || src == nil {
return
}
dstVal := reflect.ValueOf(dst).Elem()
srcVal := reflect.ValueOf(src).Elem()
for i := 0; i < dstVal.NumField(); i++ {
dstField := dstVal.Field(i)
srcField := srcVal.Field(i)
switch dstField.Kind() {
case reflect.Struct:
// 递归合并嵌套 struct
mergeConfig(dstField.Addr().Interface(), srcField.Addr().Interface())
case reflect.Slice:
// slice:仅当 src 非空时才覆盖(避免覆盖用户显式设置的长 0 slice)
if srcField.Len() > 0 {
dstField.Set(srcField)
}
default:
// 其他类型:src 为零值则保留 dst 原值(默认值)
if !isZero(srcField) {
dstField.Set(srcField)
}
}
}
}
// isZero 检查 reflect.Value 是否为该类型的零值。
func isZero(v reflect.Value) bool {
switch v.Kind() {
case reflect.Bool:
return !v.Bool()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int() == 0
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return math.Float64bits(v.Float()) == 0
case reflect.String:
return v.String() == ""
case reflect.Ptr, reflect.Interface:
return v.IsNil()
}
return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
}
// LoadFromSavedata 从 savedata 目录加载 config.yml // LoadFromSavedata 从 savedata 目录加载 config.yml
func LoadFromSavedata() error { func LoadFromSavedata() error {
configPath := filepath.Join("savedata", "config.yml") configPath := filepath.Join("savedata", "config.yml")
@@ -264,6 +332,9 @@ func StatsRefreshInterval() int {
// MissPenalty 返回缺词惩罚系数(0~1),值越大对缺少查询词的 URL 惩罚越重。 // MissPenalty 返回缺词惩罚系数(0~1),值越大对缺少查询词的 URL 惩罚越重。
func MissPenalty() float64 { return Global.Search.MissPenalty } func MissPenalty() float64 { return Global.Search.MissPenalty }
// UnixSocket 返回 Unix socket 路径,空字符串表示不启用。
func UnixSocket() string { return Global.Search.UnixSocket }
// BacklinkBaseline 返回配置值 // BacklinkBaseline 返回配置值
func BacklinkBaseline() int { return Global.Backlink.Baseline } func BacklinkBaseline() int { return Global.Backlink.Baseline }
+1 -1
View File
@@ -161,7 +161,7 @@ func main() {
searchSrv := search.New(db, infoSvc, anal) searchSrv := search.New(db, infoSvc, anal)
go func() { go func() {
addr := fmt.Sprintf(":%d", config.SearchServerPort()) addr := fmt.Sprintf(":%d", config.SearchServerPort())
if err := searchSrv.ListenAndServe(addr); err != nil { if err := searchSrv.ListenAndServe(addr, config.UnixSocket()); err != nil {
log.Fatalf("[search] fatal: %v", err) log.Fatalf("[search] fatal: %v", err)
} }
}() }()
+36 -3
View File
@@ -11,10 +11,12 @@ import (
"log" // 日志 "log" // 日志
"math" // 数学运算(Log、幂) "math" // 数学运算(Log、幂)
"math/rand" // 随机数(刷盘时打乱顺序、概率性去重/裁剪) "math/rand" // 随机数(刷盘时打乱顺序、概率性去重/裁剪)
"net" // net.ListenUnix socket
"net/http" // HTTP 服务端 "net/http" // HTTP 服务端
"net/url" // URL 解析 "net/url" // URL 解析
"os" // 文件系统(静态文件读取) "os" // 文件系统(静态文件读取)
"regexp" // 正则表达式(site: 过滤语法) "regexp" // 正则表达式(site: 过滤语法)
"runtime" // runtime.GOOS(平台判断)
"sort" // 排序 "sort" // 排序
"strconv" // 字符串转整数 "strconv" // 字符串转整数
"strings" // 字符串操作(URL 清洗) "strings" // 字符串操作(URL 清洗)
@@ -324,15 +326,46 @@ func (h spaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
// ListenAndServe 启动搜索服务器(带超时保护)。 // ListenAndServe 启动搜索服务器(带超时保护)。
func (s *Server) ListenAndServe(addr string) error { // 在 Linux/macOS 下,若 unixSocket 非空,则同时在 TCP 和 Unix socket 上监听。
log.Printf("[search] listening on %s", addr) // 在 Windows 下,unixSocket 参数被忽略,只监听 TCP。
func (s *Server) ListenAndServe(addr string, unixSocket string) error {
handler := s.Handler()
srv := &http.Server{ srv := &http.Server{
Addr: addr, Addr: addr,
Handler: s.Handler(), Handler: handler,
ReadTimeout: 10 * time.Second, ReadTimeout: 10 * time.Second,
WriteTimeout: 60 * time.Second, WriteTimeout: 60 * time.Second,
IdleTimeout: 120 * time.Second, IdleTimeout: 120 * time.Second,
} }
// Linux/macOS 下且配置了 socket 路径时,额外启动 Unix socket 监听
if unixSocket != "" && runtime.GOOS != "windows" {
// 清理旧的 socket 文件(上次异常退出可能残留)
_ = os.Remove(unixSocket)
ln, err := net.Listen("unix", unixSocket)
if err != nil {
log.Printf("[search] unix socket failed (%s): %v — continuing with TCP only", unixSocket, err)
} else {
// 设置 socket 文件权限:www 用户和 nginx 都可访问
if err := os.Chmod(unixSocket, 0660); err != nil {
log.Printf("[search] chmod unix socket: %v", err)
}
log.Printf("[search] also listening on unix:%s", unixSocket)
unixSrv := &http.Server{
Handler: handler,
ReadTimeout: 10 * time.Second,
WriteTimeout: 60 * time.Second,
IdleTimeout: 120 * time.Second,
}
go func() {
if err := unixSrv.Serve(ln); err != nil && err != http.ErrServerClosed {
log.Printf("[search] unix socket serve error: %v", err)
}
}()
}
}
log.Printf("[search] listening on %s", addr)
return srv.ListenAndServe() return srv.ListenAndServe()
} }