屏蔽词功能ok

This commit is contained in:
2026-06-04 15:20:40 +08:00
parent c3cdcfd379
commit 2e6eab3e01
8 changed files with 505 additions and 31 deletions
+31 -12
View File
@@ -30,7 +30,14 @@ type forbiddenWordBlockingRequest struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
} }
func registerAdminBlockingRoutes(r gin.IRouter, store *store) { func registerAdminBlockingRoutes(r gin.IRouter, store *store, blocking *blockingCache) {
reloadBlocking := func() error {
if blocking == nil {
return nil
}
return blocking.Reload(store)
}
r.GET("/blocking/nodes", func(c *gin.Context) { r.GET("/blocking/nodes", func(c *gin.Context) {
opts, ok := parseListOptions(c) opts, ok := parseListOptions(c)
if !ok { if !ok {
@@ -51,7 +58,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
return return
} }
row, err := store.CreateNodeBlocking(req.NodeID, req.NodeNum, req.Reason, req.Enabled) row, err := store.CreateNodeBlocking(req.NodeID, req.NodeNum, req.Reason, req.Enabled)
writeBlockingMutationResponse(c, http.StatusCreated, row, err, nodeBlockingDTO) writeBlockingMutationResponse(c, http.StatusCreated, row, err, nodeBlockingDTO, reloadBlocking)
}) })
r.PUT("/blocking/nodes/:id", func(c *gin.Context) { r.PUT("/blocking/nodes/:id", func(c *gin.Context) {
id, ok := parseBlockingID(c) id, ok := parseBlockingID(c)
@@ -64,14 +71,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
return return
} }
row, err := store.UpdateNodeBlocking(id, req.NodeID, req.NodeNum, req.Reason, req.Enabled) row, err := store.UpdateNodeBlocking(id, req.NodeID, req.NodeNum, req.Reason, req.Enabled)
writeBlockingMutationResponse(c, http.StatusOK, row, err, nodeBlockingDTO) writeBlockingMutationResponse(c, http.StatusOK, row, err, nodeBlockingDTO, reloadBlocking)
}) })
r.DELETE("/blocking/nodes/:id", func(c *gin.Context) { r.DELETE("/blocking/nodes/:id", func(c *gin.Context) {
id, ok := parseBlockingID(c) id, ok := parseBlockingID(c)
if !ok { if !ok {
return return
} }
writeBlockingDeleteResponse(c, store.DeleteNodeBlocking(id)) writeBlockingDeleteResponse(c, store.DeleteNodeBlocking(id), reloadBlocking)
}) })
r.GET("/blocking/ips", func(c *gin.Context) { r.GET("/blocking/ips", func(c *gin.Context) {
@@ -94,7 +101,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
return return
} }
row, err := store.CreateIPBlocking(req.IPValue, req.Reason, req.Enabled) row, err := store.CreateIPBlocking(req.IPValue, req.Reason, req.Enabled)
writeBlockingMutationResponse(c, http.StatusCreated, row, err, ipBlockingDTO) writeBlockingMutationResponse(c, http.StatusCreated, row, err, ipBlockingDTO, reloadBlocking)
}) })
r.PUT("/blocking/ips/:id", func(c *gin.Context) { r.PUT("/blocking/ips/:id", func(c *gin.Context) {
id, ok := parseBlockingID(c) id, ok := parseBlockingID(c)
@@ -107,14 +114,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
return return
} }
row, err := store.UpdateIPBlocking(id, req.IPValue, req.Reason, req.Enabled) row, err := store.UpdateIPBlocking(id, req.IPValue, req.Reason, req.Enabled)
writeBlockingMutationResponse(c, http.StatusOK, row, err, ipBlockingDTO) writeBlockingMutationResponse(c, http.StatusOK, row, err, ipBlockingDTO, reloadBlocking)
}) })
r.DELETE("/blocking/ips/:id", func(c *gin.Context) { r.DELETE("/blocking/ips/:id", func(c *gin.Context) {
id, ok := parseBlockingID(c) id, ok := parseBlockingID(c)
if !ok { if !ok {
return return
} }
writeBlockingDeleteResponse(c, store.DeleteIPBlocking(id)) writeBlockingDeleteResponse(c, store.DeleteIPBlocking(id), reloadBlocking)
}) })
r.GET("/blocking/words", func(c *gin.Context) { r.GET("/blocking/words", func(c *gin.Context) {
@@ -137,7 +144,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
return return
} }
row, err := store.CreateForbiddenWordBlocking(req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled) row, err := store.CreateForbiddenWordBlocking(req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled)
writeBlockingMutationResponse(c, http.StatusCreated, row, err, forbiddenWordBlockingDTO) writeBlockingMutationResponse(c, http.StatusCreated, row, err, forbiddenWordBlockingDTO, reloadBlocking)
}) })
r.PUT("/blocking/words/:id", func(c *gin.Context) { r.PUT("/blocking/words/:id", func(c *gin.Context) {
id, ok := parseBlockingID(c) id, ok := parseBlockingID(c)
@@ -150,14 +157,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
return return
} }
row, err := store.UpdateForbiddenWordBlocking(id, req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled) row, err := store.UpdateForbiddenWordBlocking(id, req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled)
writeBlockingMutationResponse(c, http.StatusOK, row, err, forbiddenWordBlockingDTO) writeBlockingMutationResponse(c, http.StatusOK, row, err, forbiddenWordBlockingDTO, reloadBlocking)
}) })
r.DELETE("/blocking/words/:id", func(c *gin.Context) { r.DELETE("/blocking/words/:id", func(c *gin.Context) {
id, ok := parseBlockingID(c) id, ok := parseBlockingID(c)
if !ok { if !ok {
return return
} }
writeBlockingDeleteResponse(c, store.DeleteForbiddenWordBlocking(id)) writeBlockingDeleteResponse(c, store.DeleteForbiddenWordBlocking(id), reloadBlocking)
}) })
} }
@@ -170,7 +177,7 @@ func parseBlockingID(c *gin.Context) (uint64, bool) {
return id, true return id, true
} }
func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, err error, convert func(T) gin.H) { func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, err error, convert func(T) gin.H, afterSuccess func() error) {
if errors.Is(err, errBlockingAlreadyExists) { if errors.Is(err, errBlockingAlreadyExists) {
c.JSON(http.StatusConflict, gin.H{"error": "blocking rule already exists"}) c.JSON(http.StatusConflict, gin.H{"error": "blocking rule already exists"})
return return
@@ -183,10 +190,16 @@ func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, er
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if afterSuccess != nil {
if err := afterSuccess(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "blocking rule saved but cache reload failed: " + err.Error()})
return
}
}
c.JSON(status, gin.H{"item": convert(*row)}) c.JSON(status, gin.H{"item": convert(*row)})
} }
func writeBlockingDeleteResponse(c *gin.Context, err error) { func writeBlockingDeleteResponse(c *gin.Context, err error, afterSuccess func() error) {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
c.JSON(http.StatusNotFound, gin.H{"error": "blocking rule not found"}) c.JSON(http.StatusNotFound, gin.H{"error": "blocking rule not found"})
return return
@@ -195,6 +208,12 @@ func writeBlockingDeleteResponse(c *gin.Context, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if afterSuccess != nil {
if err := afterSuccess(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "blocking rule deleted but cache reload failed: " + err.Error()})
return
}
}
c.JSON(http.StatusOK, gin.H{"status": "ok"}) c.JSON(http.StatusOK, gin.H{"status": "ok"})
} }
+212
View File
@@ -0,0 +1,212 @@
package main
import (
"fmt"
"net"
"strconv"
"strings"
"sync"
)
type blockingCache struct {
mu sync.RWMutex
nodes map[string]struct{}
nodeNums map[int64]struct{}
ips map[string]struct{}
cidrs []*net.IPNet
words []forbiddenWordRule
}
type forbiddenWordRule struct {
word string
foldedWord string
matchType string
caseSensitive bool
}
func newBlockingCache(store *store) (*blockingCache, error) {
cache := &blockingCache{}
if err := cache.Reload(store); err != nil {
return nil, err
}
return cache, nil
}
func (c *blockingCache) Reload(store *store) error {
if store == nil {
return fmt.Errorf("store is required")
}
nodeRows, err := store.ListEnabledNodeBlocking()
if err != nil {
return err
}
ipRows, err := store.ListEnabledIPBlocking()
if err != nil {
return err
}
wordRows, err := store.ListEnabledForbiddenWordBlocking()
if err != nil {
return err
}
nodes := make(map[string]struct{}, len(nodeRows))
nodeNums := make(map[int64]struct{}, len(nodeRows))
for _, row := range nodeRows {
nodeID := strings.TrimSpace(row.NodeID)
if nodeID != "" {
nodes[nodeID] = struct{}{}
}
if row.NodeNum != nil {
nodeNums[*row.NodeNum] = struct{}{}
}
}
ips := make(map[string]struct{}, len(ipRows))
cidrs := make([]*net.IPNet, 0, len(ipRows))
for _, row := range ipRows {
value := strings.TrimSpace(row.IPValue)
if value == "" {
continue
}
if ip := net.ParseIP(value); ip != nil {
ips[ip.String()] = struct{}{}
continue
}
if _, ipNet, err := net.ParseCIDR(value); err == nil {
cidrs = append(cidrs, ipNet)
}
}
words := make([]forbiddenWordRule, 0, len(wordRows))
for _, row := range wordRows {
word := strings.TrimSpace(row.Word)
if word == "" || row.MatchType != forbiddenWordMatchContains {
continue
}
words = append(words, forbiddenWordRule{word: word, foldedWord: strings.ToLower(word), matchType: row.MatchType, caseSensitive: row.CaseSensitive})
}
c.mu.Lock()
c.nodes = nodes
c.nodeNums = nodeNums
c.ips = ips
c.cidrs = cidrs
c.words = words
c.mu.Unlock()
return nil
}
func (c *blockingCache) IsNodeBlocked(nodeID any, nodeNum any) bool {
if c == nil {
return false
}
id, _ := nodeID.(string)
num, hasNum := blockingInt64FromAny(nodeNum)
c.mu.RLock()
defer c.mu.RUnlock()
if id != "" {
if _, ok := c.nodes[id]; ok {
return true
}
}
if hasNum {
_, ok := c.nodeNums[num]
return ok
}
return false
}
func (c *blockingCache) IsIPBlocked(host string) bool {
if c == nil {
return false
}
host = strings.TrimSpace(host)
if host == "" {
return false
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
c.mu.RLock()
defer c.mu.RUnlock()
if _, ok := c.ips[ip.String()]; ok {
return true
}
for _, ipNet := range c.cidrs {
if ipNet.Contains(ip) {
return true
}
}
return false
}
func (c *blockingCache) FindForbiddenWord(text any) (string, bool) {
if c == nil {
return "", false
}
value, ok := text.(string)
if !ok || value == "" {
return "", false
}
c.mu.RLock()
defer c.mu.RUnlock()
foldedText := ""
for _, rule := range c.words {
if rule.matchType != forbiddenWordMatchContains {
continue
}
if rule.caseSensitive {
if strings.Contains(value, rule.word) {
return rule.word, true
}
continue
}
if foldedText == "" {
foldedText = strings.ToLower(value)
}
if strings.Contains(foldedText, rule.foldedWord) {
return rule.word, true
}
}
return "", false
}
func blockingInt64FromAny(value any) (int64, bool) {
switch v := value.(type) {
case int:
return int64(v), true
case int8:
return int64(v), true
case int16:
return int64(v), true
case int32:
return int64(v), true
case int64:
return v, true
case uint:
return int64(v), true
case uint8:
return int64(v), true
case uint16:
return int64(v), true
case uint32:
return int64(v), true
case uint64:
if v > uint64(^uint64(0)>>1) {
return 0, false
}
return int64(v), true
case float64:
return int64(v), v == float64(int64(v))
case string:
n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
return n, err == nil
default:
return 0, false
}
}
+102
View File
@@ -0,0 +1,102 @@
package main
import "testing"
func TestBlockingCacheLoadsEnabledRules(t *testing.T) {
st := openTestStore(t)
defer st.Close()
nodeNum := int64(305419896)
if _, err := st.CreateNodeBlocking("!12345678", &nodeNum, "enabled", true); err != nil {
t.Fatalf("CreateNodeBlocking(enabled) error = %v", err)
}
disabledNodeNum := int64(7)
if _, err := st.CreateNodeBlocking("!00000007", &disabledNodeNum, "disabled", false); err != nil {
t.Fatalf("CreateNodeBlocking(disabled) error = %v", err)
}
if _, err := st.CreateIPBlocking("192.168.1.0/24", "lan", true); err != nil {
t.Fatalf("CreateIPBlocking(cidr) error = %v", err)
}
if _, err := st.CreateIPBlocking("10.0.0.1", "disabled", false); err != nil {
t.Fatalf("CreateIPBlocking(disabled) error = %v", err)
}
if _, err := st.CreateForbiddenWordBlocking("spam", "contains", false, "enabled", true); err != nil {
t.Fatalf("CreateForbiddenWordBlocking(enabled) error = %v", err)
}
if _, err := st.CreateForbiddenWordBlocking("blocked", "contains", false, "disabled", false); err != nil {
t.Fatalf("CreateForbiddenWordBlocking(disabled) error = %v", err)
}
cache, err := newBlockingCache(st)
if err != nil {
t.Fatalf("newBlockingCache() error = %v", err)
}
if !cache.IsNodeBlocked("!12345678", nil) {
t.Fatal("IsNodeBlocked(enabled node id) = false, want true")
}
if !cache.IsNodeBlocked("", uint32(nodeNum)) {
t.Fatal("IsNodeBlocked(enabled node num) = false, want true")
}
if cache.IsNodeBlocked("!00000007", disabledNodeNum) {
t.Fatal("IsNodeBlocked(disabled node) = true, want false")
}
if !cache.IsIPBlocked("192.168.1.42") {
t.Fatal("IsIPBlocked(CIDR member) = false, want true")
}
if cache.IsIPBlocked("10.0.0.1") {
t.Fatal("IsIPBlocked(disabled IP) = true, want false")
}
if word, ok := cache.FindForbiddenWord("This is SPAM text"); !ok || word != "spam" {
t.Fatalf("FindForbiddenWord(case-insensitive) = %q, %v, want spam, true", word, ok)
}
if _, ok := cache.FindForbiddenWord("disabled blocked text"); ok {
t.Fatal("FindForbiddenWord(disabled word) = true, want false")
}
}
func TestBlockingCacheIPExactAndCIDR(t *testing.T) {
st := openTestStore(t)
defer st.Close()
if _, err := st.CreateIPBlocking("127.0.0.1", "loopback", true); err != nil {
t.Fatalf("CreateIPBlocking(ip) error = %v", err)
}
if _, err := st.CreateIPBlocking("2001:db8::/32", "docs", true); err != nil {
t.Fatalf("CreateIPBlocking(ipv6 cidr) error = %v", err)
}
cache, err := newBlockingCache(st)
if err != nil {
t.Fatalf("newBlockingCache() error = %v", err)
}
if !cache.IsIPBlocked("127.0.0.1") {
t.Fatal("IsIPBlocked(exact IPv4) = false, want true")
}
if !cache.IsIPBlocked("2001:db8::1") {
t.Fatal("IsIPBlocked(IPv6 CIDR) = false, want true")
}
if cache.IsIPBlocked("localhost") {
t.Fatal("IsIPBlocked(hostname) = true, want false")
}
}
func TestBlockingCacheForbiddenWordCaseSensitivity(t *testing.T) {
st := openTestStore(t)
defer st.Close()
if _, err := st.CreateForbiddenWordBlocking("Spam", "contains", true, "case-sensitive", true); err != nil {
t.Fatalf("CreateForbiddenWordBlocking(case-sensitive) error = %v", err)
}
cache, err := newBlockingCache(st)
if err != nil {
t.Fatalf("newBlockingCache() error = %v", err)
}
if _, ok := cache.FindForbiddenWord("lowercase spam"); ok {
t.Fatal("FindForbiddenWord(lowercase) = true, want false")
}
if word, ok := cache.FindForbiddenWord("contains Spam"); !ok || word != "Spam" {
t.Fatalf("FindForbiddenWord(exact case) = %q, %v, want Spam, true", word, ok)
}
}
+15
View File
@@ -30,6 +30,11 @@ func (s *store) CountNodeBlocking(opts listOptions) (int64, error) {
return total, s.db.Model(&nodeBlockingRecord{}).Count(&total).Error return total, s.db.Model(&nodeBlockingRecord{}).Count(&total).Error
} }
func (s *store) ListEnabledNodeBlocking() ([]nodeBlockingRecord, error) {
var rows []nodeBlockingRecord
return rows, s.db.Where("enabled = ?", true).Find(&rows).Error
}
func (s *store) CreateNodeBlocking(nodeID string, nodeNum *int64, reason string, enabled bool) (*nodeBlockingRecord, error) { func (s *store) CreateNodeBlocking(nodeID string, nodeNum *int64, reason string, enabled bool) (*nodeBlockingRecord, error) {
nodeID = strings.TrimSpace(nodeID) nodeID = strings.TrimSpace(nodeID)
if nodeID == "" { if nodeID == "" {
@@ -93,6 +98,11 @@ func (s *store) CountIPBlocking(opts listOptions) (int64, error) {
return total, s.db.Model(&ipBlockingRecord{}).Count(&total).Error return total, s.db.Model(&ipBlockingRecord{}).Count(&total).Error
} }
func (s *store) ListEnabledIPBlocking() ([]ipBlockingRecord, error) {
var rows []ipBlockingRecord
return rows, s.db.Where("enabled = ?", true).Find(&rows).Error
}
func (s *store) CreateIPBlocking(ipValue string, reason string, enabled bool) (*ipBlockingRecord, error) { func (s *store) CreateIPBlocking(ipValue string, reason string, enabled bool) (*ipBlockingRecord, error) {
value, err := normalizeIPBlockingValue(ipValue) value, err := normalizeIPBlockingValue(ipValue)
if err != nil { if err != nil {
@@ -156,6 +166,11 @@ func (s *store) CountForbiddenWordBlocking(opts listOptions) (int64, error) {
return total, s.db.Model(&forbiddenWordBlockingRecord{}).Count(&total).Error return total, s.db.Model(&forbiddenWordBlockingRecord{}).Count(&total).Error
} }
func (s *store) ListEnabledForbiddenWordBlocking() ([]forbiddenWordBlockingRecord, error) {
var rows []forbiddenWordBlockingRecord
return rows, s.db.Where("enabled = ?", true).Find(&rows).Error
}
func (s *store) CreateForbiddenWordBlocking(word, matchType string, caseSensitive bool, reason string, enabled bool) (*forbiddenWordBlockingRecord, error) { func (s *store) CreateForbiddenWordBlocking(word, matchType string, caseSensitive bool, reason string, enabled bool) (*forbiddenWordBlockingRecord, error) {
word = strings.TrimSpace(word) word = strings.TrimSpace(word)
if word == "" { if word == "" {
+36
View File
@@ -120,6 +120,42 @@ func TestIPBlockingCRUDAndValidation(t *testing.T) {
} }
} }
func TestListEnabledBlockingRules(t *testing.T) {
st := openTestStore(t)
defer st.Close()
nodeNum := int64(1)
if _, err := st.CreateNodeBlocking("!00000001", &nodeNum, "enabled", true); err != nil {
t.Fatalf("CreateNodeBlocking(enabled) error = %v", err)
}
if _, err := st.CreateNodeBlocking("!00000002", nil, "disabled", false); err != nil {
t.Fatalf("CreateNodeBlocking(disabled) error = %v", err)
}
if rows, err := st.ListEnabledNodeBlocking(); err != nil || len(rows) != 1 || rows[0].NodeID != "!00000001" {
t.Fatalf("ListEnabledNodeBlocking() = %+v, %v, want only enabled node", rows, err)
}
if _, err := st.CreateIPBlocking("127.0.0.1", "enabled", true); err != nil {
t.Fatalf("CreateIPBlocking(enabled) error = %v", err)
}
if _, err := st.CreateIPBlocking("192.168.1.1", "disabled", false); err != nil {
t.Fatalf("CreateIPBlocking(disabled) error = %v", err)
}
if rows, err := st.ListEnabledIPBlocking(); err != nil || len(rows) != 1 || rows[0].IPValue != "127.0.0.1" {
t.Fatalf("ListEnabledIPBlocking() = %+v, %v, want only enabled IP", rows, err)
}
if _, err := st.CreateForbiddenWordBlocking("spam", "contains", false, "enabled", true); err != nil {
t.Fatalf("CreateForbiddenWordBlocking(enabled) error = %v", err)
}
if _, err := st.CreateForbiddenWordBlocking("eggs", "contains", false, "disabled", false); err != nil {
t.Fatalf("CreateForbiddenWordBlocking(disabled) error = %v", err)
}
if rows, err := st.ListEnabledForbiddenWordBlocking(); err != nil || len(rows) != 1 || rows[0].Word != "spam" {
t.Fatalf("ListEnabledForbiddenWordBlocking() = %+v, %v, want only enabled word", rows, err)
}
}
func TestForbiddenWordBlockingCRUDAndValidation(t *testing.T) { func TestForbiddenWordBlockingCRUDAndValidation(t *testing.T) {
st := openTestStore(t) st := openTestStore(t)
defer st.Close() defer st.Close()
+64 -13
View File
@@ -34,9 +34,10 @@ const (
type meshtasticFilterHook struct { type meshtasticFilterHook struct {
mqtt.HookBase mqtt.HookBase
key []byte key []byte
store *store store *store
stats *meshtasticMessageStats stats *meshtasticMessageStats
blocking *blockingCache
} }
// ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。 // ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。
@@ -46,19 +47,31 @@ func (h *meshtasticFilterHook) ID() string {
// Provides 声明该 hook 只处理客户端发布消息。 // Provides 声明该 hook 只处理客户端发布消息。
func (h *meshtasticFilterHook) Provides(b byte) bool { func (h *meshtasticFilterHook) Provides(b byte) bool {
return b == mqtt.OnPublish return b == mqtt.OnConnect || b == mqtt.OnPublish
}
// OnConnect 在 MQTT 会话建立前拒绝命中 IP 屏蔽表的客户端。
func (h *meshtasticFilterHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
info := mqttClientInfoFromClient(cl)
if h.blocking != nil && h.blocking.IsIPBlocked(info.RemoteHost) {
printJSON(map[string]any{"event": "mqtt_client_rejected", "reason": "blocked_ip", "client_id": info.ClientID, "remote_addr": info.RemoteAddr, "remote_host": info.RemoteHost})
return packets.ErrNotAuthorized
}
return nil
} }
// OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。 // OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。
func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key) valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key)
if !valid { if !valid {
h.stats.IncDropped() h.rejectPublish(cl, pk, record)
if h.store != nil { return pk, packets.ErrRejectPacket
if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil { }
printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()}) if violation := blockingViolationForRecord(h.blocking, record); violation != nil {
} for key, value := range violation {
record[key] = value
} }
h.rejectPublish(cl, pk, record)
return pk, packets.ErrRejectPacket return pk, packets.ErrRejectPacket
} }
h.stats.IncForwarded() h.stats.IncForwarded()
@@ -113,6 +126,39 @@ func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (pa
return pk, nil return pk, nil
} }
func (h *meshtasticFilterHook) rejectPublish(cl *mqtt.Client, pk packets.Packet, record map[string]any) {
if h.stats != nil {
h.stats.IncDropped()
}
if h.store != nil {
if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil {
printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()})
}
}
}
func blockingViolationForRecord(blocking *blockingCache, record map[string]any) map[string]any {
if blocking == nil || record == nil {
return nil
}
if blocking.IsNodeBlocked(record["from"], record["from_num"]) {
return map[string]any{"error": "blocked node", "blocking_type": "node"}
}
var field string
switch record["type"] {
case "text_message":
field = "text"
case "nodeinfo", "map_report":
field = "long_name"
default:
return nil
}
if word, ok := blocking.FindForbiddenWord(record[field]); ok {
return map[string]any{"error": "forbidden word", "blocking_type": "forbidden_word", "blocking_field": field, "matched_word": word}
}
return nil
}
func mqttClientInfoFromClient(cl *mqtt.Client) mqttClientInfo { func mqttClientInfoFromClient(cl *mqtt.Client) mqttClientInfo {
if cl == nil { if cl == nil {
return mqttClientInfo{} return mqttClientInfo{}
@@ -201,8 +247,13 @@ func run(cfg *config) error {
return err return err
} }
blocking, err := newBlockingCache(store)
if err != nil {
return err
}
messageStats := &meshtasticMessageStats{} messageStats := &meshtasticMessageStats{}
server, mqttAddr, err := startMQTTServer(cfg, store, messageStats) server, mqttAddr, err := startMQTTServer(cfg, store, messageStats, blocking)
if err != nil { if err != nil {
return err return err
} }
@@ -215,7 +266,7 @@ func run(cfg *config) error {
return err return err
} }
mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats} mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats}
httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus) httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking)
go func() { go func() {
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err errCh <- err
@@ -246,12 +297,12 @@ func run(cfg *config) error {
return runErr return runErr
} }
func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats) (*mqtt.Server, string, error) { func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats, blocking *blockingCache) (*mqtt.Server, string, error) {
server := mqtt.New(nil) server := mqtt.New(nil)
if err := server.AddHook(new(auth.AllowHook), nil); err != nil { if err := server.AddHook(new(auth.AllowHook), nil); err != nil {
return nil, "", err return nil, "", err
} }
if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats}, nil); err != nil { if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats, blocking: blocking}, nil); err != nil {
return nil, "", err return nil, "", err
} }
+39
View File
@@ -41,3 +41,42 @@ func TestMQTTClientInfoFromClientUnsplitRemote(t *testing.T) {
t.Fatalf("remote fields = %#v, want host localhost and empty port", info) t.Fatalf("remote fields = %#v, want host localhost and empty port", info)
} }
} }
func TestBlockingViolationForRecordNode(t *testing.T) {
cache := &blockingCache{nodes: map[string]struct{}{"!12345678": {}}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}}
record := map[string]any{"type": "position", "from": "!12345678", "from_num": uint32(305419896)}
violation := blockingViolationForRecord(cache, record)
if violation == nil || violation["blocking_type"] != "node" {
t.Fatalf("blockingViolationForRecord() = %#v, want node violation", violation)
}
}
func TestBlockingViolationForRecordForbiddenWordFields(t *testing.T) {
cache := &blockingCache{nodes: map[string]struct{}{}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}, words: []forbiddenWordRule{{word: "spam", foldedWord: "spam", matchType: forbiddenWordMatchContains}}}
for _, tc := range []struct {
name string
record map[string]any
field string
}{
{name: "text", record: map[string]any{"type": "text_message", "from": "!1", "text": "has SPAM"}, field: "text"},
{name: "nodeinfo", record: map[string]any{"type": "nodeinfo", "from": "!1", "long_name": "has SPAM"}, field: "long_name"},
{name: "map_report", record: map[string]any{"type": "map_report", "from": "!1", "long_name": "has SPAM"}, field: "long_name"},
} {
t.Run(tc.name, func(t *testing.T) {
violation := blockingViolationForRecord(cache, tc.record)
if violation == nil || violation["blocking_type"] != "forbidden_word" || violation["blocking_field"] != tc.field || violation["matched_word"] != "spam" {
t.Fatalf("blockingViolationForRecord() = %#v, want forbidden word on %s", violation, tc.field)
}
})
}
}
func TestBlockingViolationForRecordAllowed(t *testing.T) {
cache := &blockingCache{nodes: map[string]struct{}{}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}, words: []forbiddenWordRule{{word: "spam", foldedWord: "spam", matchType: forbiddenWordMatchContains}}}
record := map[string]any{"type": "text_message", "from": "!1", "text": "hello"}
if violation := blockingViolationForRecord(cache, record); violation != nil {
t.Fatalf("blockingViolationForRecord() = %#v, want nil", violation)
}
}
+6 -6
View File
@@ -14,19 +14,19 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) *http.Server { func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *http.Server {
return &http.Server{ return &http.Server{
Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)), Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)),
Handler: newRouter(cfg, store, sessions, mqttStatus), Handler: newRouter(cfg, store, sessions, mqttStatus, blocking),
} }
} }
func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) *gin.Engine { func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *gin.Engine {
r := gin.New() r := gin.New()
r.Use(gin.Logger(), gin.Recovery()) r.Use(gin.Logger(), gin.Recovery())
api := r.Group("/api") api := r.Group("/api")
registerAPIRoutes(api, store) registerAPIRoutes(api, store)
registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus) registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking)
registerStaticRoutes(r, cfg.StaticDir) registerStaticRoutes(r, cfg.StaticDir)
return r return r
} }
@@ -96,7 +96,7 @@ func registerAPIRoutes(r gin.IRouter, store *store) {
}) })
} }
func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) { func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) {
type loginRequest struct { type loginRequest struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
@@ -156,7 +156,7 @@ func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager,
protected := r.Group("") protected := r.Group("")
protected.Use(requireAdmin(sessions)) protected.Use(requireAdmin(sessions))
registerAdminBlockingRoutes(protected, store) registerAdminBlockingRoutes(protected, store, blocking)
protected.GET("/me", func(c *gin.Context) { protected.GET("/me", func(c *gin.Context) {
claims := c.MustGet("admin_claims").(*sessionClaims) claims := c.MustGet("admin_claims").(*sessionClaims)
c.JSON(http.StatusOK, gin.H{"user": adminUserDTO{Username: claims.Username, Role: claims.Role}}) c.JSON(http.StatusOK, gin.H{"user": adminUserDTO{Username: claims.Username, Role: claims.Role}})