From ae001b82e65a136370b8b8159f1a89059ae14afe Mon Sep 17 00:00:00 2001 From: kevin Date: Sun, 12 Apr 2026 14:24:26 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0uinx=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.go | 81 +++++++++++++++++++++++++++++++++++++++++++++--- main.go | 2 +- search/server.go | 39 +++++++++++++++++++++-- 3 files changed, 113 insertions(+), 9 deletions(-) diff --git a/config/config.go b/config/config.go index a4f6122..0609e98 100644 --- a/config/config.go +++ b/config/config.go @@ -4,8 +4,10 @@ package config import ( "fmt" + "math" "os" "path/filepath" + "reflect" "sync" "gopkg.in/yaml.v3" @@ -61,7 +63,8 @@ type SearchConfig struct { ServerPort int `yaml:"server_port"` FlushIntervalSeconds int `yaml:"flush_interval_seconds"` StatsRefreshInterval int `yaml:"stats_refresh_interval"` // 统计缓存刷新间隔(秒),默认 30 - MissPenalty float64 `yaml:"miss_penalty"` // 缺词惩罚系数(0=不惩罚,1=完全忽略缺词URL),默认 0.15 + MissPenalty float64 `yaml:"miss_penalty"` // 缺词惩罚系数(0=不惩罚,1=完全忽略缺词URL),默认 0.15 + UnixSocket string `yaml:"unix_socket"` // Unix socket 路径(仅 Linux/macOS),空字符串表示不启用 } // BacklinkConfig 反向链接计算相关配置 @@ -84,22 +87,87 @@ type PrometheusConfig struct { // Global 全局配置实例,加载后可通过此变量访问 var Global Config -// Load 从指定路径加载配置文件 +// Load 从指定路径加载配置文件,并自动补全缺失的字段。 +// 流程:读取 YAML → 与默认值合并 → 写回 config.yml → 赋值 Global func Load(configPath string) error { data, err := os.ReadFile(configPath) if err != nil { return fmt.Errorf("failed to read config file: %v", err) } - var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { + // 先拿到当前 YAML 内容,用于判断哪些字段实际存在于文件中 + var yamlOnly Config + if err := yaml.Unmarshal(data, &yamlOnly); err != nil { return fmt.Errorf("failed to parse config file: %v", err) } - Global = cfg + // 从默认值开始,YAML 中有值的字段会被覆盖 + merged := GetDefaultConfig() + mergeConfig(&merged, &yamlOnly) + + // 写回 config.yml(自动补全缺失字段) + yamlOut, err := yaml.Marshal(&merged) + if err != nil { + return fmt.Errorf("failed to marshal config: %v", err) + } + if err := os.WriteFile(configPath, yamlOut, 0644); err != nil { + return fmt.Errorf("failed to write config file: %v", err) + } + + Global = merged return nil } +// mergeConfig 将 src 中的非零字段合并到 dst(原地修改 dst)。 +// 用于把 YAML 实际配置值覆盖到默认值结构上。 +func mergeConfig(dst, src interface{}) { + if dst == nil || src == nil { + return + } + dstVal := reflect.ValueOf(dst).Elem() + srcVal := reflect.ValueOf(src).Elem() + + for i := 0; i < dstVal.NumField(); i++ { + dstField := dstVal.Field(i) + srcField := srcVal.Field(i) + + switch dstField.Kind() { + case reflect.Struct: + // 递归合并嵌套 struct + mergeConfig(dstField.Addr().Interface(), srcField.Addr().Interface()) + case reflect.Slice: + // slice:仅当 src 非空时才覆盖(避免覆盖用户显式设置的长 0 slice) + if srcField.Len() > 0 { + dstField.Set(srcField) + } + default: + // 其他类型:src 为零值则保留 dst 原值(默认值) + if !isZero(srcField) { + dstField.Set(srcField) + } + } + } +} + +// isZero 检查 reflect.Value 是否为该类型的零值。 +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return math.Float64bits(v.Float()) == 0 + case reflect.String: + return v.String() == "" + case reflect.Ptr, reflect.Interface: + return v.IsNil() + } + return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) +} + // LoadFromSavedata 从 savedata 目录加载 config.yml func LoadFromSavedata() error { configPath := filepath.Join("savedata", "config.yml") @@ -264,6 +332,9 @@ func StatsRefreshInterval() int { // MissPenalty 返回缺词惩罚系数(0~1),值越大对缺少查询词的 URL 惩罚越重。 func MissPenalty() float64 { return Global.Search.MissPenalty } +// UnixSocket 返回 Unix socket 路径,空字符串表示不启用。 +func UnixSocket() string { return Global.Search.UnixSocket } + // BacklinkBaseline 返回配置值 func BacklinkBaseline() int { return Global.Backlink.Baseline } diff --git a/main.go b/main.go index c341423..c7b0d14 100644 --- a/main.go +++ b/main.go @@ -161,7 +161,7 @@ func main() { searchSrv := search.New(db, infoSvc, anal) go func() { addr := fmt.Sprintf(":%d", config.SearchServerPort()) - if err := searchSrv.ListenAndServe(addr); err != nil { + if err := searchSrv.ListenAndServe(addr, config.UnixSocket()); err != nil { log.Fatalf("[search] fatal: %v", err) } }() diff --git a/search/server.go b/search/server.go index f8baf1a..f4d43a4 100644 --- a/search/server.go +++ b/search/server.go @@ -11,10 +11,12 @@ import ( "log" // 日志 "math" // 数学运算(Log、幂) "math/rand" // 随机数(刷盘时打乱顺序、概率性去重/裁剪) + "net" // net.Listen(Unix socket) "net/http" // HTTP 服务端 "net/url" // URL 解析 "os" // 文件系统(静态文件读取) "regexp" // 正则表达式(site: 过滤语法) + "runtime" // runtime.GOOS(平台判断) "sort" // 排序 "strconv" // 字符串转整数 "strings" // 字符串操作(URL 清洗) @@ -324,15 +326,46 @@ func (h spaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // ListenAndServe 启动搜索服务器(带超时保护)。 -func (s *Server) ListenAndServe(addr string) error { - log.Printf("[search] listening on %s", addr) +// 在 Linux/macOS 下,若 unixSocket 非空,则同时在 TCP 和 Unix socket 上监听。 +// 在 Windows 下,unixSocket 参数被忽略,只监听 TCP。 +func (s *Server) ListenAndServe(addr string, unixSocket string) error { + handler := s.Handler() srv := &http.Server{ Addr: addr, - Handler: s.Handler(), + Handler: handler, ReadTimeout: 10 * time.Second, WriteTimeout: 60 * time.Second, IdleTimeout: 120 * time.Second, } + + // Linux/macOS 下且配置了 socket 路径时,额外启动 Unix socket 监听 + if unixSocket != "" && runtime.GOOS != "windows" { + // 清理旧的 socket 文件(上次异常退出可能残留) + _ = os.Remove(unixSocket) + ln, err := net.Listen("unix", unixSocket) + if err != nil { + log.Printf("[search] unix socket failed (%s): %v — continuing with TCP only", unixSocket, err) + } else { + // 设置 socket 文件权限:www 用户和 nginx 都可访问 + if err := os.Chmod(unixSocket, 0660); err != nil { + log.Printf("[search] chmod unix socket: %v", err) + } + log.Printf("[search] also listening on unix:%s", unixSocket) + unixSrv := &http.Server{ + Handler: handler, + ReadTimeout: 10 * time.Second, + WriteTimeout: 60 * time.Second, + IdleTimeout: 120 * time.Second, + } + go func() { + if err := unixSrv.Serve(ln); err != nil && err != http.ErrServerClosed { + log.Printf("[search] unix socket serve error: %v", err) + } + }() + } + } + + log.Printf("[search] listening on %s", addr) return srv.ListenAndServe() }