屏蔽词功能ok
This commit is contained in:
+31
-12
@@ -30,7 +30,14 @@ type forbiddenWordBlockingRequest struct {
|
|||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
func registerAdminBlockingRoutes(r gin.IRouter, store *store, blocking *blockingCache) {
|
||||||
|
reloadBlocking := func() error {
|
||||||
|
if blocking == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return blocking.Reload(store)
|
||||||
|
}
|
||||||
|
|
||||||
r.GET("/blocking/nodes", func(c *gin.Context) {
|
r.GET("/blocking/nodes", func(c *gin.Context) {
|
||||||
opts, ok := parseListOptions(c)
|
opts, ok := parseListOptions(c)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -51,7 +58,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row, err := store.CreateNodeBlocking(req.NodeID, req.NodeNum, req.Reason, req.Enabled)
|
row, err := store.CreateNodeBlocking(req.NodeID, req.NodeNum, req.Reason, req.Enabled)
|
||||||
writeBlockingMutationResponse(c, http.StatusCreated, row, err, nodeBlockingDTO)
|
writeBlockingMutationResponse(c, http.StatusCreated, row, err, nodeBlockingDTO, reloadBlocking)
|
||||||
})
|
})
|
||||||
r.PUT("/blocking/nodes/:id", func(c *gin.Context) {
|
r.PUT("/blocking/nodes/:id", func(c *gin.Context) {
|
||||||
id, ok := parseBlockingID(c)
|
id, ok := parseBlockingID(c)
|
||||||
@@ -64,14 +71,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row, err := store.UpdateNodeBlocking(id, req.NodeID, req.NodeNum, req.Reason, req.Enabled)
|
row, err := store.UpdateNodeBlocking(id, req.NodeID, req.NodeNum, req.Reason, req.Enabled)
|
||||||
writeBlockingMutationResponse(c, http.StatusOK, row, err, nodeBlockingDTO)
|
writeBlockingMutationResponse(c, http.StatusOK, row, err, nodeBlockingDTO, reloadBlocking)
|
||||||
})
|
})
|
||||||
r.DELETE("/blocking/nodes/:id", func(c *gin.Context) {
|
r.DELETE("/blocking/nodes/:id", func(c *gin.Context) {
|
||||||
id, ok := parseBlockingID(c)
|
id, ok := parseBlockingID(c)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeBlockingDeleteResponse(c, store.DeleteNodeBlocking(id))
|
writeBlockingDeleteResponse(c, store.DeleteNodeBlocking(id), reloadBlocking)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.GET("/blocking/ips", func(c *gin.Context) {
|
r.GET("/blocking/ips", func(c *gin.Context) {
|
||||||
@@ -94,7 +101,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row, err := store.CreateIPBlocking(req.IPValue, req.Reason, req.Enabled)
|
row, err := store.CreateIPBlocking(req.IPValue, req.Reason, req.Enabled)
|
||||||
writeBlockingMutationResponse(c, http.StatusCreated, row, err, ipBlockingDTO)
|
writeBlockingMutationResponse(c, http.StatusCreated, row, err, ipBlockingDTO, reloadBlocking)
|
||||||
})
|
})
|
||||||
r.PUT("/blocking/ips/:id", func(c *gin.Context) {
|
r.PUT("/blocking/ips/:id", func(c *gin.Context) {
|
||||||
id, ok := parseBlockingID(c)
|
id, ok := parseBlockingID(c)
|
||||||
@@ -107,14 +114,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row, err := store.UpdateIPBlocking(id, req.IPValue, req.Reason, req.Enabled)
|
row, err := store.UpdateIPBlocking(id, req.IPValue, req.Reason, req.Enabled)
|
||||||
writeBlockingMutationResponse(c, http.StatusOK, row, err, ipBlockingDTO)
|
writeBlockingMutationResponse(c, http.StatusOK, row, err, ipBlockingDTO, reloadBlocking)
|
||||||
})
|
})
|
||||||
r.DELETE("/blocking/ips/:id", func(c *gin.Context) {
|
r.DELETE("/blocking/ips/:id", func(c *gin.Context) {
|
||||||
id, ok := parseBlockingID(c)
|
id, ok := parseBlockingID(c)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeBlockingDeleteResponse(c, store.DeleteIPBlocking(id))
|
writeBlockingDeleteResponse(c, store.DeleteIPBlocking(id), reloadBlocking)
|
||||||
})
|
})
|
||||||
|
|
||||||
r.GET("/blocking/words", func(c *gin.Context) {
|
r.GET("/blocking/words", func(c *gin.Context) {
|
||||||
@@ -137,7 +144,7 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row, err := store.CreateForbiddenWordBlocking(req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled)
|
row, err := store.CreateForbiddenWordBlocking(req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled)
|
||||||
writeBlockingMutationResponse(c, http.StatusCreated, row, err, forbiddenWordBlockingDTO)
|
writeBlockingMutationResponse(c, http.StatusCreated, row, err, forbiddenWordBlockingDTO, reloadBlocking)
|
||||||
})
|
})
|
||||||
r.PUT("/blocking/words/:id", func(c *gin.Context) {
|
r.PUT("/blocking/words/:id", func(c *gin.Context) {
|
||||||
id, ok := parseBlockingID(c)
|
id, ok := parseBlockingID(c)
|
||||||
@@ -150,14 +157,14 @@ func registerAdminBlockingRoutes(r gin.IRouter, store *store) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row, err := store.UpdateForbiddenWordBlocking(id, req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled)
|
row, err := store.UpdateForbiddenWordBlocking(id, req.Word, req.MatchType, req.CaseSensitive, req.Reason, req.Enabled)
|
||||||
writeBlockingMutationResponse(c, http.StatusOK, row, err, forbiddenWordBlockingDTO)
|
writeBlockingMutationResponse(c, http.StatusOK, row, err, forbiddenWordBlockingDTO, reloadBlocking)
|
||||||
})
|
})
|
||||||
r.DELETE("/blocking/words/:id", func(c *gin.Context) {
|
r.DELETE("/blocking/words/:id", func(c *gin.Context) {
|
||||||
id, ok := parseBlockingID(c)
|
id, ok := parseBlockingID(c)
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
writeBlockingDeleteResponse(c, store.DeleteForbiddenWordBlocking(id))
|
writeBlockingDeleteResponse(c, store.DeleteForbiddenWordBlocking(id), reloadBlocking)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,7 +177,7 @@ func parseBlockingID(c *gin.Context) (uint64, bool) {
|
|||||||
return id, true
|
return id, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, err error, convert func(T) gin.H) {
|
func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, err error, convert func(T) gin.H, afterSuccess func() error) {
|
||||||
if errors.Is(err, errBlockingAlreadyExists) {
|
if errors.Is(err, errBlockingAlreadyExists) {
|
||||||
c.JSON(http.StatusConflict, gin.H{"error": "blocking rule already exists"})
|
c.JSON(http.StatusConflict, gin.H{"error": "blocking rule already exists"})
|
||||||
return
|
return
|
||||||
@@ -183,10 +190,16 @@ func writeBlockingMutationResponse[T any](c *gin.Context, status int, row *T, er
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if afterSuccess != nil {
|
||||||
|
if err := afterSuccess(); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "blocking rule saved but cache reload failed: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
c.JSON(status, gin.H{"item": convert(*row)})
|
c.JSON(status, gin.H{"item": convert(*row)})
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeBlockingDeleteResponse(c *gin.Context, err error) {
|
func writeBlockingDeleteResponse(c *gin.Context, err error, afterSuccess func() error) {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "blocking rule not found"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "blocking rule not found"})
|
||||||
return
|
return
|
||||||
@@ -195,6 +208,12 @@ func writeBlockingDeleteResponse(c *gin.Context, err error) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if afterSuccess != nil {
|
||||||
|
if err := afterSuccess(); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "blocking rule deleted but cache reload failed: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,212 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type blockingCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
nodes map[string]struct{}
|
||||||
|
nodeNums map[int64]struct{}
|
||||||
|
ips map[string]struct{}
|
||||||
|
cidrs []*net.IPNet
|
||||||
|
words []forbiddenWordRule
|
||||||
|
}
|
||||||
|
|
||||||
|
type forbiddenWordRule struct {
|
||||||
|
word string
|
||||||
|
foldedWord string
|
||||||
|
matchType string
|
||||||
|
caseSensitive bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBlockingCache(store *store) (*blockingCache, error) {
|
||||||
|
cache := &blockingCache{}
|
||||||
|
if err := cache.Reload(store); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return cache, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *blockingCache) Reload(store *store) error {
|
||||||
|
if store == nil {
|
||||||
|
return fmt.Errorf("store is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeRows, err := store.ListEnabledNodeBlocking()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ipRows, err := store.ListEnabledIPBlocking()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
wordRows, err := store.ListEnabledForbiddenWordBlocking()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes := make(map[string]struct{}, len(nodeRows))
|
||||||
|
nodeNums := make(map[int64]struct{}, len(nodeRows))
|
||||||
|
for _, row := range nodeRows {
|
||||||
|
nodeID := strings.TrimSpace(row.NodeID)
|
||||||
|
if nodeID != "" {
|
||||||
|
nodes[nodeID] = struct{}{}
|
||||||
|
}
|
||||||
|
if row.NodeNum != nil {
|
||||||
|
nodeNums[*row.NodeNum] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ips := make(map[string]struct{}, len(ipRows))
|
||||||
|
cidrs := make([]*net.IPNet, 0, len(ipRows))
|
||||||
|
for _, row := range ipRows {
|
||||||
|
value := strings.TrimSpace(row.IPValue)
|
||||||
|
if value == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(value); ip != nil {
|
||||||
|
ips[ip.String()] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ipNet, err := net.ParseCIDR(value); err == nil {
|
||||||
|
cidrs = append(cidrs, ipNet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
words := make([]forbiddenWordRule, 0, len(wordRows))
|
||||||
|
for _, row := range wordRows {
|
||||||
|
word := strings.TrimSpace(row.Word)
|
||||||
|
if word == "" || row.MatchType != forbiddenWordMatchContains {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
words = append(words, forbiddenWordRule{word: word, foldedWord: strings.ToLower(word), matchType: row.MatchType, caseSensitive: row.CaseSensitive})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
c.nodes = nodes
|
||||||
|
c.nodeNums = nodeNums
|
||||||
|
c.ips = ips
|
||||||
|
c.cidrs = cidrs
|
||||||
|
c.words = words
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *blockingCache) IsNodeBlocked(nodeID any, nodeNum any) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
id, _ := nodeID.(string)
|
||||||
|
num, hasNum := blockingInt64FromAny(nodeNum)
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
if id != "" {
|
||||||
|
if _, ok := c.nodes[id]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if hasNum {
|
||||||
|
_, ok := c.nodeNums[num]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *blockingCache) IsIPBlocked(host string) bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
host = strings.TrimSpace(host)
|
||||||
|
if host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
if _, ok := c.ips[ip.String()]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, ipNet := range c.cidrs {
|
||||||
|
if ipNet.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *blockingCache) FindForbiddenWord(text any) (string, bool) {
|
||||||
|
if c == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
value, ok := text.(string)
|
||||||
|
if !ok || value == "" {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
foldedText := ""
|
||||||
|
for _, rule := range c.words {
|
||||||
|
if rule.matchType != forbiddenWordMatchContains {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if rule.caseSensitive {
|
||||||
|
if strings.Contains(value, rule.word) {
|
||||||
|
return rule.word, true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if foldedText == "" {
|
||||||
|
foldedText = strings.ToLower(value)
|
||||||
|
}
|
||||||
|
if strings.Contains(foldedText, rule.foldedWord) {
|
||||||
|
return rule.word, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
func blockingInt64FromAny(value any) (int64, bool) {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int:
|
||||||
|
return int64(v), true
|
||||||
|
case int8:
|
||||||
|
return int64(v), true
|
||||||
|
case int16:
|
||||||
|
return int64(v), true
|
||||||
|
case int32:
|
||||||
|
return int64(v), true
|
||||||
|
case int64:
|
||||||
|
return v, true
|
||||||
|
case uint:
|
||||||
|
return int64(v), true
|
||||||
|
case uint8:
|
||||||
|
return int64(v), true
|
||||||
|
case uint16:
|
||||||
|
return int64(v), true
|
||||||
|
case uint32:
|
||||||
|
return int64(v), true
|
||||||
|
case uint64:
|
||||||
|
if v > uint64(^uint64(0)>>1) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return int64(v), true
|
||||||
|
case float64:
|
||||||
|
return int64(v), v == float64(int64(v))
|
||||||
|
case string:
|
||||||
|
n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
|
||||||
|
return n, err == nil
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestBlockingCacheLoadsEnabledRules(t *testing.T) {
|
||||||
|
st := openTestStore(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
nodeNum := int64(305419896)
|
||||||
|
if _, err := st.CreateNodeBlocking("!12345678", &nodeNum, "enabled", true); err != nil {
|
||||||
|
t.Fatalf("CreateNodeBlocking(enabled) error = %v", err)
|
||||||
|
}
|
||||||
|
disabledNodeNum := int64(7)
|
||||||
|
if _, err := st.CreateNodeBlocking("!00000007", &disabledNodeNum, "disabled", false); err != nil {
|
||||||
|
t.Fatalf("CreateNodeBlocking(disabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateIPBlocking("192.168.1.0/24", "lan", true); err != nil {
|
||||||
|
t.Fatalf("CreateIPBlocking(cidr) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateIPBlocking("10.0.0.1", "disabled", false); err != nil {
|
||||||
|
t.Fatalf("CreateIPBlocking(disabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateForbiddenWordBlocking("spam", "contains", false, "enabled", true); err != nil {
|
||||||
|
t.Fatalf("CreateForbiddenWordBlocking(enabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateForbiddenWordBlocking("blocked", "contains", false, "disabled", false); err != nil {
|
||||||
|
t.Fatalf("CreateForbiddenWordBlocking(disabled) error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cache, err := newBlockingCache(st)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newBlockingCache() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cache.IsNodeBlocked("!12345678", nil) {
|
||||||
|
t.Fatal("IsNodeBlocked(enabled node id) = false, want true")
|
||||||
|
}
|
||||||
|
if !cache.IsNodeBlocked("", uint32(nodeNum)) {
|
||||||
|
t.Fatal("IsNodeBlocked(enabled node num) = false, want true")
|
||||||
|
}
|
||||||
|
if cache.IsNodeBlocked("!00000007", disabledNodeNum) {
|
||||||
|
t.Fatal("IsNodeBlocked(disabled node) = true, want false")
|
||||||
|
}
|
||||||
|
if !cache.IsIPBlocked("192.168.1.42") {
|
||||||
|
t.Fatal("IsIPBlocked(CIDR member) = false, want true")
|
||||||
|
}
|
||||||
|
if cache.IsIPBlocked("10.0.0.1") {
|
||||||
|
t.Fatal("IsIPBlocked(disabled IP) = true, want false")
|
||||||
|
}
|
||||||
|
if word, ok := cache.FindForbiddenWord("This is SPAM text"); !ok || word != "spam" {
|
||||||
|
t.Fatalf("FindForbiddenWord(case-insensitive) = %q, %v, want spam, true", word, ok)
|
||||||
|
}
|
||||||
|
if _, ok := cache.FindForbiddenWord("disabled blocked text"); ok {
|
||||||
|
t.Fatal("FindForbiddenWord(disabled word) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlockingCacheIPExactAndCIDR(t *testing.T) {
|
||||||
|
st := openTestStore(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
if _, err := st.CreateIPBlocking("127.0.0.1", "loopback", true); err != nil {
|
||||||
|
t.Fatalf("CreateIPBlocking(ip) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateIPBlocking("2001:db8::/32", "docs", true); err != nil {
|
||||||
|
t.Fatalf("CreateIPBlocking(ipv6 cidr) error = %v", err)
|
||||||
|
}
|
||||||
|
cache, err := newBlockingCache(st)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newBlockingCache() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cache.IsIPBlocked("127.0.0.1") {
|
||||||
|
t.Fatal("IsIPBlocked(exact IPv4) = false, want true")
|
||||||
|
}
|
||||||
|
if !cache.IsIPBlocked("2001:db8::1") {
|
||||||
|
t.Fatal("IsIPBlocked(IPv6 CIDR) = false, want true")
|
||||||
|
}
|
||||||
|
if cache.IsIPBlocked("localhost") {
|
||||||
|
t.Fatal("IsIPBlocked(hostname) = true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlockingCacheForbiddenWordCaseSensitivity(t *testing.T) {
|
||||||
|
st := openTestStore(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
if _, err := st.CreateForbiddenWordBlocking("Spam", "contains", true, "case-sensitive", true); err != nil {
|
||||||
|
t.Fatalf("CreateForbiddenWordBlocking(case-sensitive) error = %v", err)
|
||||||
|
}
|
||||||
|
cache, err := newBlockingCache(st)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("newBlockingCache() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := cache.FindForbiddenWord("lowercase spam"); ok {
|
||||||
|
t.Fatal("FindForbiddenWord(lowercase) = true, want false")
|
||||||
|
}
|
||||||
|
if word, ok := cache.FindForbiddenWord("contains Spam"); !ok || word != "Spam" {
|
||||||
|
t.Fatalf("FindForbiddenWord(exact case) = %q, %v, want Spam, true", word, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -30,6 +30,11 @@ func (s *store) CountNodeBlocking(opts listOptions) (int64, error) {
|
|||||||
return total, s.db.Model(&nodeBlockingRecord{}).Count(&total).Error
|
return total, s.db.Model(&nodeBlockingRecord{}).Count(&total).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *store) ListEnabledNodeBlocking() ([]nodeBlockingRecord, error) {
|
||||||
|
var rows []nodeBlockingRecord
|
||||||
|
return rows, s.db.Where("enabled = ?", true).Find(&rows).Error
|
||||||
|
}
|
||||||
|
|
||||||
func (s *store) CreateNodeBlocking(nodeID string, nodeNum *int64, reason string, enabled bool) (*nodeBlockingRecord, error) {
|
func (s *store) CreateNodeBlocking(nodeID string, nodeNum *int64, reason string, enabled bool) (*nodeBlockingRecord, error) {
|
||||||
nodeID = strings.TrimSpace(nodeID)
|
nodeID = strings.TrimSpace(nodeID)
|
||||||
if nodeID == "" {
|
if nodeID == "" {
|
||||||
@@ -93,6 +98,11 @@ func (s *store) CountIPBlocking(opts listOptions) (int64, error) {
|
|||||||
return total, s.db.Model(&ipBlockingRecord{}).Count(&total).Error
|
return total, s.db.Model(&ipBlockingRecord{}).Count(&total).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *store) ListEnabledIPBlocking() ([]ipBlockingRecord, error) {
|
||||||
|
var rows []ipBlockingRecord
|
||||||
|
return rows, s.db.Where("enabled = ?", true).Find(&rows).Error
|
||||||
|
}
|
||||||
|
|
||||||
func (s *store) CreateIPBlocking(ipValue string, reason string, enabled bool) (*ipBlockingRecord, error) {
|
func (s *store) CreateIPBlocking(ipValue string, reason string, enabled bool) (*ipBlockingRecord, error) {
|
||||||
value, err := normalizeIPBlockingValue(ipValue)
|
value, err := normalizeIPBlockingValue(ipValue)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -156,6 +166,11 @@ func (s *store) CountForbiddenWordBlocking(opts listOptions) (int64, error) {
|
|||||||
return total, s.db.Model(&forbiddenWordBlockingRecord{}).Count(&total).Error
|
return total, s.db.Model(&forbiddenWordBlockingRecord{}).Count(&total).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *store) ListEnabledForbiddenWordBlocking() ([]forbiddenWordBlockingRecord, error) {
|
||||||
|
var rows []forbiddenWordBlockingRecord
|
||||||
|
return rows, s.db.Where("enabled = ?", true).Find(&rows).Error
|
||||||
|
}
|
||||||
|
|
||||||
func (s *store) CreateForbiddenWordBlocking(word, matchType string, caseSensitive bool, reason string, enabled bool) (*forbiddenWordBlockingRecord, error) {
|
func (s *store) CreateForbiddenWordBlocking(word, matchType string, caseSensitive bool, reason string, enabled bool) (*forbiddenWordBlockingRecord, error) {
|
||||||
word = strings.TrimSpace(word)
|
word = strings.TrimSpace(word)
|
||||||
if word == "" {
|
if word == "" {
|
||||||
|
|||||||
@@ -120,6 +120,42 @@ func TestIPBlockingCRUDAndValidation(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestListEnabledBlockingRules(t *testing.T) {
|
||||||
|
st := openTestStore(t)
|
||||||
|
defer st.Close()
|
||||||
|
|
||||||
|
nodeNum := int64(1)
|
||||||
|
if _, err := st.CreateNodeBlocking("!00000001", &nodeNum, "enabled", true); err != nil {
|
||||||
|
t.Fatalf("CreateNodeBlocking(enabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateNodeBlocking("!00000002", nil, "disabled", false); err != nil {
|
||||||
|
t.Fatalf("CreateNodeBlocking(disabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if rows, err := st.ListEnabledNodeBlocking(); err != nil || len(rows) != 1 || rows[0].NodeID != "!00000001" {
|
||||||
|
t.Fatalf("ListEnabledNodeBlocking() = %+v, %v, want only enabled node", rows, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := st.CreateIPBlocking("127.0.0.1", "enabled", true); err != nil {
|
||||||
|
t.Fatalf("CreateIPBlocking(enabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateIPBlocking("192.168.1.1", "disabled", false); err != nil {
|
||||||
|
t.Fatalf("CreateIPBlocking(disabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if rows, err := st.ListEnabledIPBlocking(); err != nil || len(rows) != 1 || rows[0].IPValue != "127.0.0.1" {
|
||||||
|
t.Fatalf("ListEnabledIPBlocking() = %+v, %v, want only enabled IP", rows, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := st.CreateForbiddenWordBlocking("spam", "contains", false, "enabled", true); err != nil {
|
||||||
|
t.Fatalf("CreateForbiddenWordBlocking(enabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if _, err := st.CreateForbiddenWordBlocking("eggs", "contains", false, "disabled", false); err != nil {
|
||||||
|
t.Fatalf("CreateForbiddenWordBlocking(disabled) error = %v", err)
|
||||||
|
}
|
||||||
|
if rows, err := st.ListEnabledForbiddenWordBlocking(); err != nil || len(rows) != 1 || rows[0].Word != "spam" {
|
||||||
|
t.Fatalf("ListEnabledForbiddenWordBlocking() = %+v, %v, want only enabled word", rows, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestForbiddenWordBlockingCRUDAndValidation(t *testing.T) {
|
func TestForbiddenWordBlockingCRUDAndValidation(t *testing.T) {
|
||||||
st := openTestStore(t)
|
st := openTestStore(t)
|
||||||
defer st.Close()
|
defer st.Close()
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ type meshtasticFilterHook struct {
|
|||||||
key []byte
|
key []byte
|
||||||
store *store
|
store *store
|
||||||
stats *meshtasticMessageStats
|
stats *meshtasticMessageStats
|
||||||
|
blocking *blockingCache
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。
|
// ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。
|
||||||
@@ -46,19 +47,31 @@ func (h *meshtasticFilterHook) ID() string {
|
|||||||
|
|
||||||
// Provides 声明该 hook 只处理客户端发布消息。
|
// Provides 声明该 hook 只处理客户端发布消息。
|
||||||
func (h *meshtasticFilterHook) Provides(b byte) bool {
|
func (h *meshtasticFilterHook) Provides(b byte) bool {
|
||||||
return b == mqtt.OnPublish
|
return b == mqtt.OnConnect || b == mqtt.OnPublish
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConnect 在 MQTT 会话建立前拒绝命中 IP 屏蔽表的客户端。
|
||||||
|
func (h *meshtasticFilterHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error {
|
||||||
|
info := mqttClientInfoFromClient(cl)
|
||||||
|
if h.blocking != nil && h.blocking.IsIPBlocked(info.RemoteHost) {
|
||||||
|
printJSON(map[string]any{"event": "mqtt_client_rejected", "reason": "blocked_ip", "client_id": info.ClientID, "remote_addr": info.RemoteAddr, "remote_host": info.RemoteHost})
|
||||||
|
return packets.ErrNotAuthorized
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。
|
// OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。
|
||||||
func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
|
||||||
valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key)
|
valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key)
|
||||||
if !valid {
|
if !valid {
|
||||||
h.stats.IncDropped()
|
h.rejectPublish(cl, pk, record)
|
||||||
if h.store != nil {
|
return pk, packets.ErrRejectPacket
|
||||||
if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil {
|
|
||||||
printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()})
|
|
||||||
}
|
}
|
||||||
|
if violation := blockingViolationForRecord(h.blocking, record); violation != nil {
|
||||||
|
for key, value := range violation {
|
||||||
|
record[key] = value
|
||||||
}
|
}
|
||||||
|
h.rejectPublish(cl, pk, record)
|
||||||
return pk, packets.ErrRejectPacket
|
return pk, packets.ErrRejectPacket
|
||||||
}
|
}
|
||||||
h.stats.IncForwarded()
|
h.stats.IncForwarded()
|
||||||
@@ -113,6 +126,39 @@ func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (pa
|
|||||||
return pk, nil
|
return pk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *meshtasticFilterHook) rejectPublish(cl *mqtt.Client, pk packets.Packet, record map[string]any) {
|
||||||
|
if h.stats != nil {
|
||||||
|
h.stats.IncDropped()
|
||||||
|
}
|
||||||
|
if h.store != nil {
|
||||||
|
if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil {
|
||||||
|
printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func blockingViolationForRecord(blocking *blockingCache, record map[string]any) map[string]any {
|
||||||
|
if blocking == nil || record == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if blocking.IsNodeBlocked(record["from"], record["from_num"]) {
|
||||||
|
return map[string]any{"error": "blocked node", "blocking_type": "node"}
|
||||||
|
}
|
||||||
|
var field string
|
||||||
|
switch record["type"] {
|
||||||
|
case "text_message":
|
||||||
|
field = "text"
|
||||||
|
case "nodeinfo", "map_report":
|
||||||
|
field = "long_name"
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if word, ok := blocking.FindForbiddenWord(record[field]); ok {
|
||||||
|
return map[string]any{"error": "forbidden word", "blocking_type": "forbidden_word", "blocking_field": field, "matched_word": word}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func mqttClientInfoFromClient(cl *mqtt.Client) mqttClientInfo {
|
func mqttClientInfoFromClient(cl *mqtt.Client) mqttClientInfo {
|
||||||
if cl == nil {
|
if cl == nil {
|
||||||
return mqttClientInfo{}
|
return mqttClientInfo{}
|
||||||
@@ -201,8 +247,13 @@ func run(cfg *config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
blocking, err := newBlockingCache(store)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
messageStats := &meshtasticMessageStats{}
|
messageStats := &meshtasticMessageStats{}
|
||||||
server, mqttAddr, err := startMQTTServer(cfg, store, messageStats)
|
server, mqttAddr, err := startMQTTServer(cfg, store, messageStats, blocking)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -215,7 +266,7 @@ func run(cfg *config) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats}
|
mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats}
|
||||||
httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus)
|
httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking)
|
||||||
go func() {
|
go func() {
|
||||||
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
errCh <- err
|
errCh <- err
|
||||||
@@ -246,12 +297,12 @@ func run(cfg *config) error {
|
|||||||
return runErr
|
return runErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats) (*mqtt.Server, string, error) {
|
func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats, blocking *blockingCache) (*mqtt.Server, string, error) {
|
||||||
server := mqtt.New(nil)
|
server := mqtt.New(nil)
|
||||||
if err := server.AddHook(new(auth.AllowHook), nil); err != nil {
|
if err := server.AddHook(new(auth.AllowHook), nil); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats}, nil); err != nil {
|
if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats, blocking: blocking}, nil); err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -41,3 +41,42 @@ func TestMQTTClientInfoFromClientUnsplitRemote(t *testing.T) {
|
|||||||
t.Fatalf("remote fields = %#v, want host localhost and empty port", info)
|
t.Fatalf("remote fields = %#v, want host localhost and empty port", info)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBlockingViolationForRecordNode(t *testing.T) {
|
||||||
|
cache := &blockingCache{nodes: map[string]struct{}{"!12345678": {}}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}}
|
||||||
|
record := map[string]any{"type": "position", "from": "!12345678", "from_num": uint32(305419896)}
|
||||||
|
|
||||||
|
violation := blockingViolationForRecord(cache, record)
|
||||||
|
if violation == nil || violation["blocking_type"] != "node" {
|
||||||
|
t.Fatalf("blockingViolationForRecord() = %#v, want node violation", violation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlockingViolationForRecordForbiddenWordFields(t *testing.T) {
|
||||||
|
cache := &blockingCache{nodes: map[string]struct{}{}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}, words: []forbiddenWordRule{{word: "spam", foldedWord: "spam", matchType: forbiddenWordMatchContains}}}
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
name string
|
||||||
|
record map[string]any
|
||||||
|
field string
|
||||||
|
}{
|
||||||
|
{name: "text", record: map[string]any{"type": "text_message", "from": "!1", "text": "has SPAM"}, field: "text"},
|
||||||
|
{name: "nodeinfo", record: map[string]any{"type": "nodeinfo", "from": "!1", "long_name": "has SPAM"}, field: "long_name"},
|
||||||
|
{name: "map_report", record: map[string]any{"type": "map_report", "from": "!1", "long_name": "has SPAM"}, field: "long_name"},
|
||||||
|
} {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
violation := blockingViolationForRecord(cache, tc.record)
|
||||||
|
if violation == nil || violation["blocking_type"] != "forbidden_word" || violation["blocking_field"] != tc.field || violation["matched_word"] != "spam" {
|
||||||
|
t.Fatalf("blockingViolationForRecord() = %#v, want forbidden word on %s", violation, tc.field)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlockingViolationForRecordAllowed(t *testing.T) {
|
||||||
|
cache := &blockingCache{nodes: map[string]struct{}{}, nodeNums: map[int64]struct{}{}, ips: map[string]struct{}{}, words: []forbiddenWordRule{{word: "spam", foldedWord: "spam", matchType: forbiddenWordMatchContains}}}
|
||||||
|
record := map[string]any{"type": "text_message", "from": "!1", "text": "hello"}
|
||||||
|
if violation := blockingViolationForRecord(cache, record); violation != nil {
|
||||||
|
t.Fatalf("blockingViolationForRecord() = %#v, want nil", violation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,19 +14,19 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) *http.Server {
|
func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *http.Server {
|
||||||
return &http.Server{
|
return &http.Server{
|
||||||
Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)),
|
Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)),
|
||||||
Handler: newRouter(cfg, store, sessions, mqttStatus),
|
Handler: newRouter(cfg, store, sessions, mqttStatus, blocking),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) *gin.Engine {
|
func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *gin.Engine {
|
||||||
r := gin.New()
|
r := gin.New()
|
||||||
r.Use(gin.Logger(), gin.Recovery())
|
r.Use(gin.Logger(), gin.Recovery())
|
||||||
api := r.Group("/api")
|
api := r.Group("/api")
|
||||||
registerAPIRoutes(api, store)
|
registerAPIRoutes(api, store)
|
||||||
registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus)
|
registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking)
|
||||||
registerStaticRoutes(r, cfg.StaticDir)
|
registerStaticRoutes(r, cfg.StaticDir)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ func registerAPIRoutes(r gin.IRouter, store *store) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider) {
|
func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) {
|
||||||
type loginRequest struct {
|
type loginRequest struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
@@ -156,7 +156,7 @@ func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager,
|
|||||||
|
|
||||||
protected := r.Group("")
|
protected := r.Group("")
|
||||||
protected.Use(requireAdmin(sessions))
|
protected.Use(requireAdmin(sessions))
|
||||||
registerAdminBlockingRoutes(protected, store)
|
registerAdminBlockingRoutes(protected, store, blocking)
|
||||||
protected.GET("/me", func(c *gin.Context) {
|
protected.GET("/me", func(c *gin.Context) {
|
||||||
claims := c.MustGet("admin_claims").(*sessionClaims)
|
claims := c.MustGet("admin_claims").(*sessionClaims)
|
||||||
c.JSON(http.StatusOK, gin.H{"user": adminUserDTO{Username: claims.Username, Role: claims.Role}})
|
c.JSON(http.StatusOK, gin.H{"user": adminUserDTO{Username: claims.Username, Role: claims.Role}})
|
||||||
|
|||||||
Reference in New Issue
Block a user