From 44dfb14cf4a0d92e2842dbe2b19fa7134074d3bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Wed, 3 Jun 2026 13:41:54 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=85=8D=E7=BD=AE=E6=96=87?= =?UTF-8?q?=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 55 ++++++++++++- config.go | 219 +++++++++++++++++++++++++++++++++++++++++++++++++ config_test.go | 134 ++++++++++++++++++++++++++++++ go.mod | 2 +- main.go | 42 +++++----- 5 files changed, 428 insertions(+), 24 deletions(-) create mode 100644 config.go create mode 100644 config_test.go diff --git a/README.md b/README.md index 1fdd255..0425db3 100644 --- a/README.md +++ b/README.md @@ -24,8 +24,36 @@ go run . - host:`0.0.0.0` - port:`1883` - PSK:`AQ==` +- TLS:关闭 -也可以指定监听地址和 PSK: +首次启动会自动生成配置文件;之后每次启动都会检查配置项,缺失项会自动补全并写回。 + +配置文件路径: + +- Unix/Linux:`/etc/mesh_mqtt_go/config.yaml` +- Windows 测试:`./win/etc/mesh_mqtt_go/config.yaml` + +默认配置内容: + +```yaml +mqtt: + host: 0.0.0.0 + port: 1883 + tls: + enabled: false + cert_file: "" + key_file: "" +meshtastic: + psk: AQ== +``` + +配置优先级: + +```text +内置默认值 < 配置文件 < 命令行参数 +``` + +也可以用命令行临时覆盖监听地址、PSK 和 TLS 设置: ```bash go run . --host 127.0.0.1 --port 1883 --psk AQ== @@ -34,11 +62,30 @@ go run . --host 127.0.0.1 --port 1883 --psk AQ== ## 参数 ```text ---host MQTT broker listen host ---port MQTT broker listen port ---psk Base64 channel PSK used to try decrypting encrypted packets +--host MQTT broker listen host +--port MQTT broker listen port +--psk Base64 channel PSK used to try decrypting encrypted packets +--tls Enable MQTT TLS listener +--tls-cert MQTT TLS certificate file +--tls-key MQTT TLS private key file ``` +## TLS 配置示例 + +```yaml +mqtt: + host: 0.0.0.0 + port: 8883 + tls: + enabled: true + cert_file: ./certs/server.crt + key_file: ./certs/server.key +meshtastic: + psk: AQ== +``` + +启用 TLS 后,`cert_file` 和 `key_file` 必须指向可读取的证书和私钥文件。 + ## 转发规则 程序监听所有传入 publish。payload 能被 `mqtpp.MQTTPP` 解析时,认为 `valid == true`,broker 会继续把原始 MQTT 消息转发给订阅者;解析失败时,认为 `valid == false`,broker 会拒绝并丢弃该 publish。 diff --git a/config.go b/config.go new file mode 100644 index 0000000..32579be --- /dev/null +++ b/config.go @@ -0,0 +1,219 @@ +package main + +import ( + cryptotls "crypto/tls" + "fmt" + "os" + "path/filepath" + "runtime" + + "gopkg.in/yaml.v3" +) + +const configFileName = "config.yaml" + +type config struct { + MQTT mqttConfig `yaml:"mqtt"` + Meshtastic meshtasticConfig `yaml:"meshtastic"` + key []byte +} + +type mqttConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + TLS tlsConfig `yaml:"tls"` +} + +type tlsConfig struct { + Enabled bool `yaml:"enabled"` + CertFile string `yaml:"cert_file"` + KeyFile string `yaml:"key_file"` +} + +type meshtasticConfig struct { + PSK string `yaml:"psk"` +} + +type rawConfig struct { + MQTT *rawMQTTConfig `yaml:"mqtt"` + Meshtastic *rawMeshtasticConfig `yaml:"meshtastic"` +} + +type rawMQTTConfig struct { + Host *string `yaml:"host"` + Port *int `yaml:"port"` + TLS *rawTLSConfig `yaml:"tls"` +} + +type rawTLSConfig struct { + Enabled *bool `yaml:"enabled"` + CertFile *string `yaml:"cert_file"` + KeyFile *string `yaml:"key_file"` +} + +type rawMeshtasticConfig struct { + PSK *string `yaml:"psk"` +} + +// defaultConfig 返回内置默认配置。 +func defaultConfig() *config { + return &config{ + MQTT: mqttConfig{ + Host: "0.0.0.0", + Port: 1883, + TLS: tlsConfig{ + Enabled: false, + CertFile: "", + KeyFile: "", + }, + }, + Meshtastic: meshtasticConfig{ + PSK: "AQ==", + }, + } +} + +// defaultConfigDir 根据操作系统返回配置目录。 +func defaultConfigDir() string { + if runtime.GOOS == "windows" { + return filepath.Join(".", "win", "etc", "mesh_mqtt_go") + } + return filepath.Join(string(filepath.Separator), "etc", "mesh_mqtt_go") +} + +// defaultConfigPath 返回默认配置文件路径。 +func defaultConfigPath() string { + return filepath.Join(defaultConfigDir(), configFileName) +} + +// loadConfig 加载配置文件;文件不存在时生成,字段缺失时自动补全并写回。 +func loadConfig(path string) (*config, error) { + if path == "" { + path = defaultConfigPath() + } + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return nil, fmt.Errorf("create config directory %s: %w", filepath.Dir(path), err) + } + + if _, err := os.Stat(path); err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("stat config file %s: %w", path, err) + } + cfg := defaultConfig() + if err := writeConfig(path, cfg); err != nil { + return nil, err + } + return cfg, nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read config file %s: %w", path, err) + } + + var raw rawConfig + if err := yaml.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("parse config file %s: %w", path, err) + } + + cfg, changed := normalizeConfig(raw) + if err := validateConfig(cfg); err != nil { + return nil, err + } + if changed { + if err := writeConfig(path, cfg); err != nil { + return nil, err + } + } + return cfg, nil +} + +// normalizeConfig 将原始配置合并到默认配置,并标记是否补齐了缺失项。 +func normalizeConfig(raw rawConfig) (*config, bool) { + cfg := defaultConfig() + changed := false + + if raw.MQTT == nil { + changed = true + } else { + if raw.MQTT.Host == nil { + changed = true + } else { + cfg.MQTT.Host = *raw.MQTT.Host + } + if raw.MQTT.Port == nil { + changed = true + } else { + cfg.MQTT.Port = *raw.MQTT.Port + } + if raw.MQTT.TLS == nil { + changed = true + } else { + if raw.MQTT.TLS.Enabled == nil { + changed = true + } else { + cfg.MQTT.TLS.Enabled = *raw.MQTT.TLS.Enabled + } + if raw.MQTT.TLS.CertFile == nil { + changed = true + } else { + cfg.MQTT.TLS.CertFile = *raw.MQTT.TLS.CertFile + } + if raw.MQTT.TLS.KeyFile == nil { + changed = true + } else { + cfg.MQTT.TLS.KeyFile = *raw.MQTT.TLS.KeyFile + } + } + } + + if raw.Meshtastic == nil { + changed = true + } else if raw.Meshtastic.PSK == nil { + changed = true + } else { + cfg.Meshtastic.PSK = *raw.Meshtastic.PSK + } + + return cfg, changed +} + +func validateConfig(cfg *config) error { + if cfg.MQTT.Port <= 0 || cfg.MQTT.Port > 65535 { + return fmt.Errorf("invalid mqtt port %d: must be 1-65535", cfg.MQTT.Port) + } + return nil +} + +func writeConfig(path string, cfg *config) error { + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("encode config file %s: %w", path, err) + } + if err := os.WriteFile(path, data, 0644); err != nil { + return fmt.Errorf("write config file %s: %w", path, err) + } + return nil +} + +// buildTLSConfig 根据配置构造 mochi listener 使用的 TLS 设置。 +func buildTLSConfig(cfg tlsConfig) (*cryptotls.Config, error) { + if !cfg.Enabled { + return nil, nil + } + if cfg.CertFile == "" { + return nil, fmt.Errorf("mqtt tls cert_file is required when tls is enabled") + } + if cfg.KeyFile == "" { + return nil, fmt.Errorf("mqtt tls key_file is required when tls is enabled") + } + + cert, err := cryptotls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) + if err != nil { + return nil, fmt.Errorf("load mqtt tls certificate: %w", err) + } + return &cryptotls.Config{ + MinVersion: cryptotls.VersionTLS12, + Certificates: []cryptotls.Certificate{cert}, + }, nil +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..4c0f4a7 --- /dev/null +++ b/config_test.go @@ -0,0 +1,134 @@ +package main + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadConfigCreatesDefaultFile(t *testing.T) { + path := filepath.Join(t.TempDir(), "mesh_mqtt_go", configFileName) + + cfg, err := loadConfig(path) + if err != nil { + t.Fatalf("loadConfig() error = %v", err) + } + if cfg.MQTT.Host != "0.0.0.0" { + t.Fatalf("host = %q, want 0.0.0.0", cfg.MQTT.Host) + } + if cfg.MQTT.Port != 1883 { + t.Fatalf("port = %d, want 1883", cfg.MQTT.Port) + } + if cfg.MQTT.TLS.Enabled { + t.Fatalf("tls enabled = true, want false") + } + if cfg.Meshtastic.PSK != "AQ==" { + t.Fatalf("psk = %q, want AQ==", cfg.Meshtastic.PSK) + } + if _, err := os.Stat(path); err != nil { + t.Fatalf("default config was not written: %v", err) + } +} + +func TestLoadConfigFillsMissingFields(t *testing.T) { + path := filepath.Join(t.TempDir(), "mesh_mqtt_go", configFileName) + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte("mqtt:\n port: 1884\n"), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := loadConfig(path) + if err != nil { + t.Fatalf("loadConfig() error = %v", err) + } + if cfg.MQTT.Port != 1884 { + t.Fatalf("port = %d, want 1884", cfg.MQTT.Port) + } + if cfg.MQTT.Host != "0.0.0.0" { + t.Fatalf("host = %q, want 0.0.0.0", cfg.MQTT.Host) + } + if cfg.Meshtastic.PSK != "AQ==" { + t.Fatalf("psk = %q, want AQ==", cfg.Meshtastic.PSK) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + text := string(data) + for _, want := range []string{"host:", "tls:", "enabled:", "cert_file:", "key_file:", "meshtastic:", "psk:"} { + if !strings.Contains(text, want) { + t.Fatalf("completed config missing %q in:\n%s", want, text) + } + } +} + +func TestLoadConfigPreservesExplicitFalse(t *testing.T) { + path := filepath.Join(t.TempDir(), "mesh_mqtt_go", configFileName) + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + t.Fatal(err) + } + content := "mqtt:\n host: 127.0.0.1\n port: 1885\n tls:\n enabled: false\n cert_file: cert.pem\n key_file: key.pem\nmeshtastic:\n psk: AQ==\n" + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := loadConfig(path) + if err != nil { + t.Fatalf("loadConfig() error = %v", err) + } + if cfg.MQTT.TLS.Enabled { + t.Fatalf("tls enabled = true, want explicit false") + } + if cfg.MQTT.TLS.CertFile != "cert.pem" || cfg.MQTT.TLS.KeyFile != "key.pem" { + t.Fatalf("tls paths = %q/%q, want cert.pem/key.pem", cfg.MQTT.TLS.CertFile, cfg.MQTT.TLS.KeyFile) + } +} + +func TestLoadConfigMalformedYAMLDoesNotOverwrite(t *testing.T) { + path := filepath.Join(t.TempDir(), "mesh_mqtt_go", configFileName) + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + t.Fatal(err) + } + content := "mqtt:\n port: [\n" + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + _, err := loadConfig(path) + if err == nil { + t.Fatalf("loadConfig() error = nil, want parse error") + } + data, readErr := os.ReadFile(path) + if readErr != nil { + t.Fatal(readErr) + } + if string(data) != content { + t.Fatalf("malformed config was overwritten: %q", string(data)) + } +} + +func TestBuildTLSConfigDisabled(t *testing.T) { + cfg, err := buildTLSConfig(tlsConfig{}) + if err != nil { + t.Fatalf("buildTLSConfig() error = %v", err) + } + if cfg != nil { + t.Fatalf("buildTLSConfig() = %#v, want nil", cfg) + } +} + +func TestBuildTLSConfigRequiresCertAndKey(t *testing.T) { + _, err := buildTLSConfig(tlsConfig{Enabled: true}) + if err == nil || !strings.Contains(err.Error(), "cert_file") { + t.Fatalf("missing cert error = %v, want cert_file error", err) + } + + _, err = buildTLSConfig(tlsConfig{Enabled: true, CertFile: "cert.pem"}) + if err == nil || !strings.Contains(err.Error(), "key_file") { + t.Fatalf("missing key error = %v, want key_file error", err) + } +} diff --git a/go.mod b/go.mod index 9c7dd90..35c793d 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,10 @@ go 1.23 require ( github.com/mochi-mqtt/server/v2 v2.7.9 google.golang.org/protobuf v1.36.11 + gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/gorilla/websocket v1.5.3 // indirect github.com/rs/xid v1.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/main.go b/main.go index d3814ab..ee319ab 100644 --- a/main.go +++ b/main.go @@ -18,9 +18,6 @@ import ( ) const ( - defaultHost = "0.0.0.0" - defaultPSK = "AQ==" - ansiGreenBGWhiteText = "\033[42;37m" ansiBlueBGWhiteText = "\033[44;37m" ansiPurpleBGWhiteText = "\033[45;37m" @@ -31,13 +28,6 @@ const ( ansiReset = "\033[0m" ) -type config struct { - host string - port int - psk string - key []byte -} - type meshtasticFilterHook struct { mqtt.HookBase key []byte @@ -79,15 +69,25 @@ func main() { } } -// parseArgs 解析命令行参数,并展开 Meshtastic channel PSK。 +// parseArgs 加载配置文件、解析命令行覆盖项,并展开 Meshtastic channel PSK。 func parseArgs() (*config, error) { - cfg := &config{} - flag.StringVar(&cfg.host, "host", defaultHost, "MQTT broker listen host") - flag.IntVar(&cfg.port, "port", 1883, "MQTT broker listen port") - flag.StringVar(&cfg.psk, "psk", defaultPSK, "Base64 channel PSK used to try decrypting encrypted packets") + cfg, err := loadConfig(defaultConfigPath()) + if err != nil { + return nil, err + } + + flag.StringVar(&cfg.MQTT.Host, "host", cfg.MQTT.Host, "MQTT broker listen host") + flag.IntVar(&cfg.MQTT.Port, "port", cfg.MQTT.Port, "MQTT broker listen port") + flag.StringVar(&cfg.Meshtastic.PSK, "psk", cfg.Meshtastic.PSK, "Base64 channel PSK used to try decrypting encrypted packets") + flag.BoolVar(&cfg.MQTT.TLS.Enabled, "tls", cfg.MQTT.TLS.Enabled, "Enable MQTT TLS listener") + flag.StringVar(&cfg.MQTT.TLS.CertFile, "tls-cert", cfg.MQTT.TLS.CertFile, "MQTT TLS certificate file") + flag.StringVar(&cfg.MQTT.TLS.KeyFile, "tls-key", cfg.MQTT.TLS.KeyFile, "MQTT TLS private key file") flag.Parse() - key, err := mqtpp.ExpandPSK(cfg.psk) + if err := validateConfig(cfg); err != nil { + return nil, err + } + key, err := mqtpp.ExpandPSK(cfg.Meshtastic.PSK) if err != nil { return nil, err } @@ -105,15 +105,19 @@ func run(cfg *config) error { return err } - addr := net.JoinHostPort(cfg.host, strconv.Itoa(cfg.port)) - listener := listeners.NewTCP(listeners.Config{ID: "tcp", Address: addr}) + addr := net.JoinHostPort(cfg.MQTT.Host, strconv.Itoa(cfg.MQTT.Port)) + tlsConfig, err := buildTLSConfig(cfg.MQTT.TLS) + if err != nil { + return err + } + listener := listeners.NewTCP(listeners.Config{ID: "tcp", Address: addr, TLSConfig: tlsConfig}) if err := server.AddListener(listener); err != nil { return err } if err := server.Serve(); err != nil { return err } - printJSON(map[string]any{"event": "broker_started", "address": addr}) + printJSON(map[string]any{"event": "broker_started", "address": addr, "tls": cfg.MQTT.TLS.Enabled}) sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)