From 2e6eab3e01313242759def186ece5d2f9f783a69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Thu, 4 Jun 2026 15:20:40 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B1=8F=E8=94=BD=E8=AF=8D=E5=8A=9F=E8=83=BDok?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- admin_blocking_routes.go | 43 +++++--- blocking_cache.go | 212 +++++++++++++++++++++++++++++++++++++++ blocking_cache_test.go | 102 +++++++++++++++++++ blocking_store.go | 15 +++ blocking_store_test.go | 36 +++++++ main.go | 77 +++++++++++--- main_test.go | 39 +++++++ web.go | 12 +-- 8 files changed, 505 insertions(+), 31 deletions(-) create mode 100644 blocking_cache.go create mode 100644 blocking_cache_test.go diff --git a/admin_blocking_routes.go b/admin_blocking_routes.go index d958060..176d293 100644 --- a/admin_blocking_routes.go +++ b/admin_blocking_routes.go @@ -30,7 +30,14 @@ type forbiddenWordBlockingRequest struct { Enabled bool `json:"enabled"` } -func registerAdminBlockingRoutes(r gin.IRouter, store *store) { +func registerAdminBlockingRoutes(r gin.IRouter, store *store, blocking *blockingCache) { + reloadBlocking := func() error { + if blocking == nil { + return nil + } + return blocking.Reload(store) + } + r.GET("/blocking/nodes", func(c *gin.Context) { opts, ok := parseListOptions(c) if !ok { @@ -51,7 +58,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) { return } row, err := store.CreateNodeBlocking(req.NodeID, req.NodeNum, req.Reason, req.Enabled) - writeBlockingMutationResponse(c, http.StatusCreated, row, err, nodeBlockingDTO) + writeBlockingMutationResponse(c, http.StatusCreated, row, err, nodeBlockingDTO, reloadBlocking) }) r.PUT("/blocking/nodes/:id", func(c *gin.Context) { id, ok := parseBlockingID(c) @@ -64,14 +71,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) { return } row, err := store.UpdateNodeBlocking(id, req.NodeID, req.NodeNum, req.Reason, req.Enabled) - writeBlockingMutationResponse(c, http.StatusOK, row, err, nodeBlockingDTO) + writeBlockingMutationResponse(c, http.StatusOK, row, err, nodeBlockingDTO, reloadBlocking) }) r.DELETE("/blocking/nodes/:id", func(c *gin.Context) { id, ok := parseBlockingID(c) if !ok { return } - writeBlockingDeleteResponse(c, store.DeleteNodeBlocking(id)) + writeBlockingDeleteResponse(c, store.DeleteNodeBlocking(id), reloadBlocking) }) r.GET("/blocking/ips", func(c *gin.Context) { @@ -94,7 +101,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) { return } row, err := store.CreateIPBlocking(req.IPValue, req.Reason, req.Enabled) - writeBlockingMutationResponse(c, http.StatusCreated, row, err, ipBlockingDTO) + writeBlockingMutationResponse(c, http.StatusCreated, row, err, ipBlockingDTO, reloadBlocking) }) r.PUT("/blocking/ips/:id", func(c *gin.Context) { id, ok := parseBlockingID(c) @@ -107,14 +114,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) { return } row, err := store.UpdateIPBlocking(id, req.IPValue, req.Reason, req.Enabled) - writeBlockingMutationResponse(c, http.StatusOK, row, err, ipBlockingDTO) + writeBlockingMutationResponse(c, http.StatusOK, row, err, ipBlockingDTO, reloadBlocking) }) r.DELETE("/blocking/ips/:id", func(c *gin.Context) { id, ok := parseBlockingID(c) if !ok { return } - writeBlockingDeleteResponse(c, store.DeleteIPBlocking(id)) + writeBlockingDeleteResponse(c, store.DeleteIPBlocking(id), reloadBlocking) }) r.GET("/blocking/words", func(c *gin.Context) { @@ -137,7 +144,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) { return } row, err := store.CreateForbiddenWordBlocking(req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled) - writeBlockingMutationResponse(c, http.StatusCreated, row, err, forbiddenWordBlockingDTO) + writeBlockingMutationResponse(c, http.StatusCreated, row, err, forbiddenWordBlockingDTO, reloadBlocking) }) r.PUT("/blocking/words/:id", func(c *gin.Context) { id, ok := parseBlockingID(c) @@ -150,14 +157,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) { return } row, err := store.UpdateForbiddenWordBlocking(id, req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled) - writeBlockingMutationResponse(c, http.StatusOK, row, err, forbiddenWordBlockingDTO) + writeBlockingMutationResponse(c, http.StatusOK, row, err, forbiddenWordBlockingDTO, reloadBlocking) }) r.DELETE("/blocking/words/:id", func(c *gin.Context) { id, ok := parseBlockingID(c) if !ok { return } - writeBlockingDeleteResponse(c, store.DeleteForbiddenWordBlocking(id)) + writeBlockingDeleteResponse(c, store.DeleteForbiddenWordBlocking(id), reloadBlocking) }) } @@ -170,7 +177,7 @@ func parseBlockingID(c *gin.Context) (uint64, bool) { return id, true } -func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, err error, convert func(T) gin.H) { +func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, err error, convert func(T) gin.H, afterSuccess func() error) { if errors.Is(err, errBlockingAlreadyExists) { c.JSON(http.StatusConflict, gin.H{"error": "blocking rule already exists"}) return @@ -183,10 +190,16 @@ func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, er c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + if afterSuccess != nil { + if err := afterSuccess(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "blocking rule saved but cache reload failed: " + err.Error()}) + return + } + } c.JSON(status, gin.H{"item": convert(*row)}) } -func writeBlockingDeleteResponse(c *gin.Context, err error) { +func writeBlockingDeleteResponse(c *gin.Context, err error, afterSuccess func() error) { if errors.Is(err, gorm.ErrRecordNotFound) { c.JSON(http.StatusNotFound, gin.H{"error": "blocking rule not found"}) return @@ -195,6 +208,12 @@ func writeBlockingDeleteResponse(c *gin.Context, err error) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } + if afterSuccess != nil { + if err := afterSuccess(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "blocking rule deleted but cache reload failed: " + err.Error()}) + return + } + } c.JSON(http.StatusOK, gin.H{"status": "ok"}) } diff --git a/blocking_cache.go b/blocking_cache.go new file mode 100644 index 0000000..48fe0a4 --- /dev/null +++ b/blocking_cache.go @@ -0,0 +1,212 @@ +package main + +import ( + "fmt" + "net" + "strconv" + "strings" + "sync" +) + +type blockingCache struct { + mu sync.RWMutex + nodes map[string]struct{} + nodeNums map[int64]struct{} + ips map[string]struct{} + cidrs []*net.IPNet + words []forbiddenWordRule +} + +type forbiddenWordRule struct { + word string + foldedWord string + matchType string + caseSensitive bool +} + +func newBlockingCache(store *store) (*blockingCache, error) { + cache := &blockingCache{} + if err := cache.Reload(store); err != nil { + return nil, err + } + return cache, nil +} + +func (c *blockingCache) Reload(store *store) error { + if store == nil { + return fmt.Errorf("store is required") + } + + nodeRows, err := store.ListEnabledNodeBlocking() + if err != nil { + return err + } + ipRows, err := store.ListEnabledIPBlocking() + if err != nil { + return err + } + wordRows, err := store.ListEnabledForbiddenWordBlocking() + if err != nil { + return err + } + + nodes := make(map[string]struct{}, len(nodeRows)) + nodeNums := make(map[int64]struct{}, len(nodeRows)) + for _, row := range nodeRows { + nodeID := strings.TrimSpace(row.NodeID) + if nodeID != "" { + nodes[nodeID] = struct{}{} + } + if row.NodeNum != nil { + nodeNums[*row.NodeNum] = struct{}{} + } + } + + ips := make(map[string]struct{}, len(ipRows)) + cidrs := make([]*net.IPNet, 0, len(ipRows)) + for _, row := range ipRows { + value := strings.TrimSpace(row.IPValue) + if value == "" { + continue + } + if ip := net.ParseIP(value); ip != nil { + ips[ip.String()] = struct{}{} + continue + } + if _, ipNet, err := net.ParseCIDR(value); err == nil { + cidrs = append(cidrs, ipNet) + } + } + + words := make([]forbiddenWordRule, 0, len(wordRows)) + for _, row := range wordRows { + word := strings.TrimSpace(row.Word) + if word == "" || row.MatchType != forbiddenWordMatchContains { + continue + } + words = append(words, forbiddenWordRule{word: word, foldedWord: strings.ToLower(word), matchType: row.MatchType, caseSensitive: row.CaseSensitive}) + } + + c.mu.Lock() + c.nodes = nodes + c.nodeNums = nodeNums + c.ips = ips + c.cidrs = cidrs + c.words = words + c.mu.Unlock() + return nil +} + +func (c *blockingCache) IsNodeBlocked(nodeID any, nodeNum any) bool { + if c == nil { + return false + } + id, _ := nodeID.(string) + num, hasNum := blockingInt64FromAny(nodeNum) + + c.mu.RLock() + defer c.mu.RUnlock() + if id != "" { + if _, ok := c.nodes[id]; ok { + return true + } + } + if hasNum { + _, ok := c.nodeNums[num] + return ok + } + return false +} + +func (c *blockingCache) IsIPBlocked(host string) bool { + if c == nil { + return false + } + host = strings.TrimSpace(host) + if host == "" { + return false + } + ip := net.ParseIP(host) + if ip == nil { + return false + } + + c.mu.RLock() + defer c.mu.RUnlock() + if _, ok := c.ips[ip.String()]; ok { + return true + } + for _, ipNet := range c.cidrs { + if ipNet.Contains(ip) { + return true + } + } + return false +} + +func (c *blockingCache) FindForbiddenWord(text any) (string, bool) { + if c == nil { + return "", false + } + value, ok := text.(string) + if !ok || value == "" { + return "", false + } + + c.mu.RLock() + defer c.mu.RUnlock() + foldedText := "" + for _, rule := range c.words { + if rule.matchType != forbiddenWordMatchContains { + continue + } + if rule.caseSensitive { + if strings.Contains(value, rule.word) { + return rule.word, true + } + continue + } + if foldedText == "" { + foldedText = strings.ToLower(value) + } + if strings.Contains(foldedText, rule.foldedWord) { + return rule.word, true + } + } + return "", false +} + +func blockingInt64FromAny(value any) (int64, bool) { + switch v := value.(type) { + case int: + return int64(v), true + case int8: + return int64(v), true + case int16: + return int64(v), true + case int32: + return int64(v), true + case int64: + return v, true + case uint: + return int64(v), true + case uint8: + return int64(v), true + case uint16: + return int64(v), true + case uint32: + return int64(v), true + case uint64: + if v > uint64(^uint64(0)>>1) { + return 0, false + } + return int64(v), true + case float64: + return int64(v), v == float64(int64(v)) + case string: + n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64) + return n, err == nil + default: + return 0, false + } +} diff --git a/blocking_cache_test.go b/blocking_cache_test.go new file mode 100644 index 0000000..691c1cf --- /dev/null +++ b/blocking_cache_test.go @@ -0,0 +1,102 @@ +package main + +import "testing" + +func TestBlockingCacheLoadsEnabledRules(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + nodeNum := int64(305419896) + if _, err := st.CreateNodeBlocking("!12345678", &nodeNum, "enabled", true); err != nil { + t.Fatalf("CreateNodeBlocking(enabled) error = %v", err) + } + disabledNodeNum := int64(7) + if _, err := st.CreateNodeBlocking("!00000007", &disabledNodeNum, "disabled", false); err != nil { + t.Fatalf("CreateNodeBlocking(disabled) error = %v", err) + } + if _, err := st.CreateIPBlocking("192.168.1.0/24", "lan", true); err != nil { + t.Fatalf("CreateIPBlocking(cidr) error = %v", err) + } + if _, err := st.CreateIPBlocking("10.0.0.1", "disabled", false); err != nil { + t.Fatalf("CreateIPBlocking(disabled) error = %v", err) + } + if _, err := st.CreateForbiddenWordBlocking("spam", "contains", false, "enabled", true); err != nil { + t.Fatalf("CreateForbiddenWordBlocking(enabled) error = %v", err) + } + if _, err := st.CreateForbiddenWordBlocking("blocked", "contains", false, "disabled", false); err != nil { + t.Fatalf("CreateForbiddenWordBlocking(disabled) error = %v", err) + } + + cache, err := newBlockingCache(st) + if err != nil { + t.Fatalf("newBlockingCache() error = %v", err) + } + + if !cache.IsNodeBlocked("!12345678", nil) { + t.Fatal("IsNodeBlocked(enabled node id) = false, want true") + } + if !cache.IsNodeBlocked("", uint32(nodeNum)) { + t.Fatal("IsNodeBlocked(enabled node num) = false, want true") + } + if cache.IsNodeBlocked("!00000007", disabledNodeNum) { + t.Fatal("IsNodeBlocked(disabled node) = true, want false") + } + if !cache.IsIPBlocked("192.168.1.42") { + t.Fatal("IsIPBlocked(CIDR member) = false, want true") + } + if cache.IsIPBlocked("10.0.0.1") { + t.Fatal("IsIPBlocked(disabled IP) = true, want false") + } + if word, ok := cache.FindForbiddenWord("This is SPAM text"); !ok || word != "spam" { + t.Fatalf("FindForbiddenWord(case-insensitive) = %q, %v, want spam, true", word, ok) + } + if _, ok := cache.FindForbiddenWord("disabled blocked text"); ok { + t.Fatal("FindForbiddenWord(disabled word) = true, want false") + } +} + +func TestBlockingCacheIPExactAndCIDR(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + if _, err := st.CreateIPBlocking("127.0.0.1", "loopback", true); err != nil { + t.Fatalf("CreateIPBlocking(ip) error = %v", err) + } + if _, err := st.CreateIPBlocking("2001:db8::/32", "docs", true); err != nil { + t.Fatalf("CreateIPBlocking(ipv6 cidr) error = %v", err) + } + cache, err := newBlockingCache(st) + if err != nil { + t.Fatalf("newBlockingCache() error = %v", err) + } + + if !cache.IsIPBlocked("127.0.0.1") { + t.Fatal("IsIPBlocked(exact IPv4) = false, want true") + } + if !cache.IsIPBlocked("2001:db8::1") { + t.Fatal("IsIPBlocked(IPv6 CIDR) = false, want true") + } + if cache.IsIPBlocked("localhost") { + t.Fatal("IsIPBlocked(hostname) = true, want false") + } +} + +func TestBlockingCacheForbiddenWordCaseSensitivity(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + if _, err := st.CreateForbiddenWordBlocking("Spam", "contains", true, "case-sensitive", true); err != nil { + t.Fatalf("CreateForbiddenWordBlocking(case-sensitive) error = %v", err) + } + cache, err := newBlockingCache(st) + if err != nil { + t.Fatalf("newBlockingCache() error = %v", err) + } + + if _, ok := cache.FindForbiddenWord("lowercase spam"); ok { + t.Fatal("FindForbiddenWord(lowercase) = true, want false") + } + if word, ok := cache.FindForbiddenWord("contains Spam"); !ok || word != "Spam" { + t.Fatalf("FindForbiddenWord(exact case) = %q, %v, want Spam, true", word, ok) + } +} diff --git a/blocking_store.go b/blocking_store.go index 764d37b..ae3df15 100644 --- a/blocking_store.go +++ b/blocking_store.go @@ -30,6 +30,11 @@ func (s *store) CountNodeBlocking(opts listOptions) (int64, error) { return total, s.db.Model(&nodeBlockingRecord{}).Count(&total).Error } +func (s *store) ListEnabledNodeBlocking() ([]nodeBlockingRecord, error) { + var rows []nodeBlockingRecord + return rows, s.db.Where("enabled = ?", true).Find(&rows).Error +} + func (s *store) CreateNodeBlocking(nodeID string, nodeNum *int64, reason string, enabled bool) (*nodeBlockingRecord, error) { nodeID = strings.TrimSpace(nodeID) if nodeID == "" { @@ -93,6 +98,11 @@ func (s *store) CountIPBlocking(opts listOptions) (int64, error) { return total, s.db.Model(&ipBlockingRecord{}).Count(&total).Error } +func (s *store) ListEnabledIPBlocking() ([]ipBlockingRecord, error) { + var rows []ipBlockingRecord + return rows, s.db.Where("enabled = ?", true).Find(&rows).Error +} + func (s *store) CreateIPBlocking(ipValue string, reason string, enabled bool) (*ipBlockingRecord, error) { value, err := normalizeIPBlockingValue(ipValue) if err != nil { @@ -156,6 +166,11 @@ func (s *store) CountForbiddenWordBlocking(opts listOptions) (int64, error) { return total, s.db.Model(&forbiddenWordBlockingRecord{}).Count(&total).Error } +func (s *store) ListEnabledForbiddenWordBlocking() ([]forbiddenWordBlockingRecord, error) { + var rows []forbiddenWordBlockingRecord + return rows, s.db.Where("enabled = ?", true).Find(&rows).Error +} + func (s *store) CreateForbiddenWordBlocking(word, matchType string, caseSensitive bool, reason string, enabled bool) (*forbiddenWordBlockingRecord, error) { word = strings.TrimSpace(word) if word == "" { diff --git a/blocking_store_test.go b/blocking_store_test.go index 1feec3c..f35a784 100644 --- a/blocking_store_test.go +++ b/blocking_store_test.go @@ -120,6 +120,42 @@ func TestIPBlockingCRUDAndValidation(t *testing.T) { } } +func TestListEnabledBlockingRules(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + nodeNum := int64(1) + if _, err := st.CreateNodeBlocking("!00000001", &nodeNum, "enabled", true); err != nil { + t.Fatalf("CreateNodeBlocking(enabled) error = %v", err) + } + if _, err := st.CreateNodeBlocking("!00000002", nil, "disabled", false); err != nil { + t.Fatalf("CreateNodeBlocking(disabled) error = %v", err) + } + if rows, err := st.ListEnabledNodeBlocking(); err != nil || len(rows) != 1 || rows[0].NodeID != "!00000001" { + t.Fatalf("ListEnabledNodeBlocking() = %+v, %v, want only enabled node", rows, err) + } + + if _, err := st.CreateIPBlocking("127.0.0.1", "enabled", true); err != nil { + t.Fatalf("CreateIPBlocking(enabled) error = %v", err) + } + if _, err := st.CreateIPBlocking("192.168.1.1", "disabled", false); err != nil { + t.Fatalf("CreateIPBlocking(disabled) error = %v", err) + } + if rows, err := st.ListEnabledIPBlocking(); err != nil || len(rows) != 1 || rows[0].IPValue != "127.0.0.1" { + t.Fatalf("ListEnabledIPBlocking() = %+v, %v, want only enabled IP", rows, err) + } + + if _, err := st.CreateForbiddenWordBlocking("spam", "contains", false, "enabled", true); err != nil { + t.Fatalf("CreateForbiddenWordBlocking(enabled) error = %v", err) + } + if _, err := st.CreateForbiddenWordBlocking("eggs", "contains", false, "disabled", false); err != nil { + t.Fatalf("CreateForbiddenWordBlocking(disabled) error = %v", err) + } + if rows, err := st.ListEnabledForbiddenWordBlocking(); err != nil || len(rows) != 1 || rows[0].Word != "spam" { + t.Fatalf("ListEnabledForbiddenWordBlocking() = %+v, %v, want only enabled word", rows, err) + } +} + func TestForbiddenWordBlockingCRUDAndValidation(t *testing.T) { st := openTestStore(t) defer st.Close() diff --git a/main.go b/main.go index 2e51deb..3692bff 100644 --- a/main.go +++ b/main.go @@ -34,9 +34,10 @@ const ( type meshtasticFilterHook struct { mqtt.HookBase - key []byte - store *store - stats *meshtasticMessageStats + key []byte + store *store + stats *meshtasticMessageStats + blocking *blockingCache } // ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。 @@ -46,19 +47,31 @@ func (h *meshtasticFilterHook) ID() string { // Provides 声明该 hook 只处理客户端发布消息。 func (h *meshtasticFilterHook) Provides(b byte) bool { - return b == mqtt.OnPublish + return b == mqtt.OnConnect || b == mqtt.OnPublish +} + +// OnConnect 在 MQTT 会话建立前拒绝命中 IP 屏蔽表的客户端。 +func (h *meshtasticFilterHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { + info := mqttClientInfoFromClient(cl) + if h.blocking != nil && h.blocking.IsIPBlocked(info.RemoteHost) { + printJSON(map[string]any{"event": "mqtt_client_rejected", "reason": "blocked_ip", "client_id": info.ClientID, "remote_addr": info.RemoteAddr, "remote_host": info.RemoteHost}) + return packets.ErrNotAuthorized + } + return nil } // OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。 func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key) if !valid { - h.stats.IncDropped() - if h.store != nil { - if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()}) - } + h.rejectPublish(cl, pk, record) + return pk, packets.ErrRejectPacket + } + if violation := blockingViolationForRecord(h.blocking, record); violation != nil { + for key, value := range violation { + record[key] = value } + h.rejectPublish(cl, pk, record) return pk, packets.ErrRejectPacket } h.stats.IncForwarded() @@ -113,6 +126,39 @@ func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (pa return pk, nil } +func (h *meshtasticFilterHook) rejectPublish(cl *mqtt.Client, pk packets.Packet, record map[string]any) { + if h.stats != nil { + h.stats.IncDropped() + } + if h.store != nil { + if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil { + printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()}) + } + } +} + +func blockingViolationForRecord(blocking *blockingCache, record map[string]any) map[string]any { + if blocking == nil || record == nil { + return nil + } + if blocking.IsNodeBlocked(record["from"], record["from_num"]) { + return map[string]any{"error": "blocked node", "blocking_type": "node"} + } + var field string + switch record["type"] { + case "text_message": + field = "text" + case "nodeinfo", "map_report": + field = "long_name" + default: + return nil + } + if word, ok := blocking.FindForbiddenWord(record[field]); ok { + return map[string]any{"error": "forbidden word", "blocking_type": "forbidden_word", "blocking_field": field, "matched_word": word} + } + return nil +} + func mqttClientInfoFromClient(cl *mqtt.Client) mqttClientInfo { if cl == nil { return mqttClientInfo{} @@ -201,8 +247,13 @@ func run(cfg *config) error { return err } + blocking, err := newBlockingCache(store) + if err != nil { + return err + } + messageStats := &meshtasticMessageStats{} - server, mqttAddr, err := startMQTTServer(cfg, store, messageStats) + server, mqttAddr, err := startMQTTServer(cfg, store, messageStats, blocking) if err != nil { return err } @@ -215,7 +266,7 @@ func run(cfg *config) error { return err } mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats} - httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus) + httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking) go func() { if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { errCh <- err @@ -246,12 +297,12 @@ func run(cfg *config) error { return runErr } -func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats) (*mqtt.Server, string, error) { +func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats, blocking *blockingCache) (*mqtt.Server, string, error) { server := mqtt.New(nil) if err := server.AddHook(new(auth.AllowHook), nil); err != nil { return nil, "", err } - if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats}, nil); err != nil { + if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats, blocking: blocking}, nil); err != nil { return nil, "", err } diff --git a/main_test.go b/main_test.go index c603584..108ff0f 100644 --- a/main_test.go +++ b/main_test.go @@ -41,3 +41,42 @@ func TestMQTTClientInfoFromClientUnsplitRemote(t *testing.T) { t.Fatalf("remote fields = %#v, want host localhost and empty port", info) } } + +func TestBlockingViolationForRecordNode(t *testing.T) { + cache := &blockingCache{nodes: map[string]struct{}{"!12345678": {}}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}} + record := map[string]any{"type": "position", "from": "!12345678", "from_num": uint32(305419896)} + + violation := blockingViolationForRecord(cache, record) + if violation == nil || violation["blocking_type"] != "node" { + t.Fatalf("blockingViolationForRecord() = %#v, want node violation", violation) + } +} + +func TestBlockingViolationForRecordForbiddenWordFields(t *testing.T) { + cache := &blockingCache{nodes: map[string]struct{}{}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}, words: []forbiddenWordRule{{word: "spam", foldedWord: "spam", matchType: forbiddenWordMatchContains}}} + + for _, tc := range []struct { + name string + record map[string]any + field string + }{ + {name: "text", record: map[string]any{"type": "text_message", "from": "!1", "text": "has SPAM"}, field: "text"}, + {name: "nodeinfo", record: map[string]any{"type": "nodeinfo", "from": "!1", "long_name": "has SPAM"}, field: "long_name"}, + {name: "map_report", record: map[string]any{"type": "map_report", "from": "!1", "long_name": "has SPAM"}, field: "long_name"}, + } { + t.Run(tc.name, func(t *testing.T) { + violation := blockingViolationForRecord(cache, tc.record) + if violation == nil || violation["blocking_type"] != "forbidden_word" || violation["blocking_field"] != tc.field || violation["matched_word"] != "spam" { + t.Fatalf("blockingViolationForRecord() = %#v, want forbidden word on %s", violation, tc.field) + } + }) + } +} + +func TestBlockingViolationForRecordAllowed(t *testing.T) { + cache := &blockingCache{nodes: map[string]struct{}{}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}, words: []forbiddenWordRule{{word: "spam", foldedWord: "spam", matchType: forbiddenWordMatchContains}}} + record := map[string]any{"type": "text_message", "from": "!1", "text": "hello"} + if violation := blockingViolationForRecord(cache, record); violation != nil { + t.Fatalf("blockingViolationForRecord() = %#v, want nil", violation) + } +} diff --git a/web.go b/web.go index 3fee417..d63cd9c 100644 --- a/web.go +++ b/web.go @@ -14,19 +14,19 @@ import ( "gorm.io/gorm" ) -func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) *http.Server { +func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *http.Server { return &http.Server{ Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)), - Handler: newRouter(cfg, store, sessions, mqttStatus), + Handler: newRouter(cfg, store, sessions, mqttStatus, blocking), } } -func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) *gin.Engine { +func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *gin.Engine { r := gin.New() r.Use(gin.Logger(), gin.Recovery()) api := r.Group("/api") registerAPIRoutes(api, store) - registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus) + registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking) registerStaticRoutes(r, cfg.StaticDir) return r } @@ -96,7 +96,7 @@ func registerAPIRoutes(r gin.IRouter, store *store) { }) } -func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) { +func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) { type loginRequest struct { Username string `json:"username"` Password string `json:"password"` @@ -156,7 +156,7 @@ func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, protected := r.Group("") protected.Use(requireAdmin(sessions)) - registerAdminBlockingRoutes(protected, store) + registerAdminBlockingRoutes(protected, store, blocking) protected.GET("/me", func(c *gin.Context) { claims := c.MustGet("admin_claims").(*sessionClaims) c.JSON(http.StatusOK, gin.H{"user": adminUserDTO{Username: claims.Username, Role: claims.Role}})