This commit is contained in:
2026-05-15 22:03:46 +08:00
parent e6a7565745
commit e047eacfdc
6 changed files with 1239 additions and 272 deletions
+492 -23
View File
@@ -2,14 +2,200 @@ package stats
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"fmt"
"sync"
"sync/atomic"
"time"
mqtt "github.com/mochi-mqtt/server/v2"
"github.com/mochi-mqtt/server/v2/packets"
"google.golang.org/protobuf/encoding/protowire"
)
// ---------------------------------------------------------------------------
// Meshtastic Protobuf 手动解析 (简化版)
// ---------------------------------------------------------------------------
// MeshPacket 简化版,只包含解密所需字段
type MeshPacket struct {
Id uint64
From uint32
WhichPayloadVariant int // 10=decoded, 11=encrypted
Encrypted []byte
DecryptedPayload []byte // field 7: decoded.payload
PkiEncrypted bool
}
// ServiceEnvelope 简化版
type ServiceEnvelope struct {
ChannelId string
GatewayId string
Packet *MeshPacket
}
// ParseServiceEnvelope 解析 ServiceEnvelope 二进制 protobuf
func ParseServiceEnvelope(data []byte) (*ServiceEnvelope, error) {
env := &ServiceEnvelope{}
pos := 0
for pos < len(data) {
fieldNum, wireType, n := protowire.ConsumeTag(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid wire format at pos %d", pos)
}
pos += n
switch {
case int(fieldNum) == 1 && wireType == protowire.BytesType:
msgData, n := protowire.ConsumeBytes(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid bytes at pos %d", pos)
}
pos += n
packet, err := parseMeshPacket(msgData)
if err != nil {
return nil, fmt.Errorf("failed to parse packet: %w", err)
}
env.Packet = packet
case int(fieldNum) == 2 && wireType == protowire.BytesType:
val, n := protowire.ConsumeBytes(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid bytes at pos %d", pos)
}
env.ChannelId = string(val)
pos += n
case int(fieldNum) == 3 && wireType == protowire.BytesType:
val, n := protowire.ConsumeBytes(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid bytes at pos %d", pos)
}
env.GatewayId = string(val)
pos += n
default:
n, ok := skipField(data[pos:], int(fieldNum), wireType)
if !ok {
return nil, fmt.Errorf("skip failed at pos %d", pos)
}
pos += n
}
}
return env, nil
}
// parseMeshPacket 解析 MeshPacket
func parseMeshPacket(data []byte) (*MeshPacket, error) {
packet := &MeshPacket{}
pos := 0
for pos < len(data) {
fieldNum, wireType, n := protowire.ConsumeTag(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid wire format at pos %d", pos)
}
pos += n
switch {
case int(fieldNum) == 1 && wireType == protowire.VarintType:
val, n := protowire.ConsumeVarint(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid varint at pos %d", pos)
}
packet.Id = val
pos += n
case int(fieldNum) == 3 && wireType == protowire.VarintType:
val, n := protowire.ConsumeVarint(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid varint at pos %d", pos)
}
packet.From = uint32(val)
pos += n
case int(fieldNum) == 8 && wireType == protowire.VarintType:
val, n := protowire.ConsumeVarint(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid varint at pos %d", pos)
}
packet.WhichPayloadVariant = int(val)
pos += n
case int(fieldNum) == 11 && wireType == protowire.BytesType:
// encrypted 字段 (variant 11)
val, n := protowire.ConsumeBytes(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid bytes at pos %d", pos)
}
packet.Encrypted = val
pos += n
case int(fieldNum) == 7 && wireType == protowire.BytesType:
// decoded.payload 字段 (variant 10) - 已经是解密的数据
val, n := protowire.ConsumeBytes(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid bytes at pos %d", pos)
}
packet.DecryptedPayload = val
pos += n
case int(fieldNum) == 15 && wireType == protowire.VarintType:
val, n := protowire.ConsumeVarint(data[pos:])
if n < 0 {
return nil, fmt.Errorf("invalid varint at pos %d", pos)
}
packet.PkiEncrypted = val != 0
pos += n
default:
skipped, ok := skipField(data[pos:], int(fieldNum), wireType)
if !ok {
return nil, fmt.Errorf("skip failed at pos %d", pos)
}
pos += skipped
}
}
return packet, nil
}
// skipField 跳过未知 protobuf 字段
func skipField(data []byte, fieldNum int, wireType protowire.Type) (int, bool) {
switch wireType {
case protowire.VarintType:
_, n := protowire.ConsumeVarint(data)
if n < 0 {
return 0, false
}
return n, true
case protowire.Fixed32Type:
if len(data) < 4 {
return 0, false
}
return 4, true
case protowire.Fixed64Type:
if len(data) < 8 {
return 0, false
}
return 8, true
case protowire.BytesType:
_, n := protowire.ConsumeBytes(data)
if n < 0 {
return 0, false
}
return n, true
case protowire.StartGroupType, protowire.EndGroupType:
return 0, false
default:
return 0, false
}
}
// ---------------------------------------------------------------------------
// 数据结构
// ---------------------------------------------------------------------------
@@ -25,31 +211,51 @@ type ClientInfo struct {
// Stats 当前统计快照
type Stats struct {
Connections int64 `json:"connections"` // 当前连接数
MessagesTotal int64 `json:"messages_total"` // 累计消息数(所有主题)
MessagesMsh int64 `json:"messages_msh"` // msh/# 消息数
Uptime int64 `json:"uptime"` // 服务运行时长(秒)
Clients []ClientInfo `json:"clients"` // 在线客户端列表
Topics map[string]int64 `json:"topics"` // 各主题消息数
Connections int64 `json:"connections"`
MessagesTotal int64 `json:"messages_total"`
MessagesMsh int64 `json:"messages_msh"`
Uptime int64 `json:"uptime"`
Clients []ClientInfo `json:"clients"`
Topics map[string]int64 `json:"topics"`
}
// DecryptedMessage 解密后的消息结构
type DecryptedMessage struct {
ChannelId string `json:"channel_id"`
GatewayId string `json:"gateway_id"`
PacketId uint64 `json:"packet_id"`
From uint32 `json:"from"`
PortNum uint32 `json:"port_num"`
Payload []byte `json:"payload"`
}
// MessageRecord 一条MQTT消息记录
type MessageRecord struct {
Topic string `json:"topic"`
Payload string `json:"payload"`
Time time.Time `json:"time"`
Decrypted *DecryptedMessage `json:"decrypted,omitempty"`
}
// ---------------------------------------------------------------------------
// 全局统计atomic + mutex 无锁热点路径)
// 全局统计
// ---------------------------------------------------------------------------
var (
connections atomic.Int64
messagesTotal atomic.Int64
messagesMsh atomic.Int64
startTime = time.Now()
clientsMu sync.RWMutex
clients = make(map[string]ClientInfo) // clientID → info
subs = make(map[string][]string) // clientID → []filter
topicsMu sync.RWMutex
topics = make(map[string]int64) // topic → count
connections atomic.Int64
messagesTotal atomic.Int64
messagesMsh atomic.Int64
startTime = time.Now()
clientsMu sync.RWMutex
clients = make(map[string]ClientInfo)
subs = make(map[string][]string)
topicsMu sync.RWMutex
topics = make(map[string]int64)
msgMu sync.RWMutex
msgBuf []MessageRecord
)
// GetStats 返回当前统计快照(只读副本)
// GetStats 返回当前统计快照
func GetStats() Stats {
clientsMu.RLock()
clientList := make([]ClientInfo, 0, len(clients))
@@ -76,11 +282,29 @@ func GetStats() Stats {
}
}
// GetClient 返回指定客户端的详细信息
func GetClient(id string) *ClientInfo {
clientsMu.RLock()
defer clientsMu.RUnlock()
info, ok := clients[id]
if !ok {
return nil
}
info.SubsCount = len(subs[id])
return &info
}
// GetClientSubs 返回指定客户端的订阅主题列表
func GetClientSubs(id string) []string {
clientsMu.RLock()
defer clientsMu.RUnlock()
return subs[id]
}
// ---------------------------------------------------------------------------
// Hook 实现
// ---------------------------------------------------------------------------
// Hook 收集 MQTT 运行统计
type Hook struct {
mqtt.HookBase
}
@@ -97,7 +321,6 @@ func (h *Hook) Provides(b byte) bool {
}, []byte{b})
}
// OnSessionEstablished 客户端连接成功
func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
username := string(pk.Connect.Username)
if username == "" {
@@ -115,7 +338,6 @@ func (h *Hook) OnSessionEstablished(cl *mqtt.Client, pk packets.Packet) {
connections.Add(1)
}
// OnDisconnect 客户端断开
func (h *Hook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
clientsMu.Lock()
delete(clients, cl.ID)
@@ -124,11 +346,11 @@ func (h *Hook) OnDisconnect(cl *mqtt.Client, err error, expire bool) {
connections.Add(-1)
}
// OnPublish 收到发布消息
func (h *Hook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
messagesTotal.Add(1)
if len(pk.TopicName) >= 4 && pk.TopicName[:4] == "msh/" {
messagesMsh.Add(1)
addMessage(pk.TopicName, pk.Payload)
}
topicsMu.Lock()
topics[pk.TopicName]++
@@ -136,7 +358,6 @@ func (h *Hook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, er
return pk, nil
}
// OnSubscribe 客户端订阅
func (h *Hook) OnSubscribe(cl *mqtt.Client, pk packets.Packet) packets.Packet {
clientsMu.Lock()
for _, f := range pk.Filters {
@@ -146,7 +367,6 @@ func (h *Hook) OnSubscribe(cl *mqtt.Client, pk packets.Packet) packets.Packet {
return pk
}
// OnUnsubscribe 客户端取消订阅
func (h *Hook) OnUnsubscribe(cl *mqtt.Client, pk packets.Packet) packets.Packet {
clientsMu.Lock()
for _, f := range pk.Filters {
@@ -163,3 +383,252 @@ func (h *Hook) OnUnsubscribe(cl *mqtt.Client, pk packets.Packet) packets.Packet
}
var _ mqtt.Hook = (*Hook)(nil)
func addMessage(topic string, payload []byte) {
rec := MessageRecord{
Topic: topic,
Payload: base64.StdEncoding.EncodeToString(payload),
Time: time.Now(),
}
msgMu.Lock()
defer msgMu.Unlock()
msgBuf = append(msgBuf, rec)
if len(msgBuf) > 200 {
msgBuf = msgBuf[len(msgBuf)-200:]
}
}
func GetMessages() []MessageRecord {
msgMu.RLock()
defer msgMu.RUnlock()
out := make([]MessageRecord, len(msgBuf))
copy(out, msgBuf)
return out
}
// ParseServiceEnvelopeDebug 解析但不解密(用于调试)
func ParseServiceEnvelopeDebug(payloadB64 string) (*ServiceEnvelope, error) {
data, err := base64.StdEncoding.DecodeString(payloadB64)
if err != nil {
return nil, fmt.Errorf("failed to decode base64: %w", err)
}
return ParseServiceEnvelope(data)
}
// ---------------------------------------------------------------------------
// Meshtastic AES-CTR 解密
// ---------------------------------------------------------------------------
// 默认 PSK (索引1 = 不变)
var DefaultPSK = []byte{0xd4, 0xf1, 0xbb, 0x3a, 0x20, 0x29, 0x07, 0x59, 0xf0, 0xbc, 0xff, 0xab, 0xcf, 0x4e, 0x69, 0x01}
// 默认 PSK 索引 (1-8)
var defaultPSKIndex byte = 1
// Payload variant tags
const (
MeshPacket_decoded_tag = 10
MeshPacket_encrypted_tag = 11
)
// ExpandPSK 将 1 字节 PSK 索引扩展为 16 字节 AES128 密钥
func ExpandPSK(pskIndex byte) ([]byte, error) {
if pskIndex == 0 {
return nil, nil // 无加密
}
if pskIndex > 8 {
return nil, fmt.Errorf("PSK index must be 0-8, got %d", pskIndex)
}
key := make([]byte, 16)
copy(key, DefaultPSK)
// 索引1不变,索引2-8在最后一位累加
if pskIndex > 1 {
key[15] += pskIndex - 1
}
return key, nil
}
// buildNonce 构建 AES-CTR 用的 nonce (16字节)
// nonce 结构: packetId(8字节小端) + fromNode(4字节小端) + counter(4字节,通常为0)
func buildNonce(packetId uint64, fromNode uint32) [16]byte {
var nonce [16]byte
// packetId: 8字节,小端序
nonce[0] = byte(packetId)
nonce[1] = byte(packetId >> 8)
nonce[2] = byte(packetId >> 16)
nonce[3] = byte(packetId >> 24)
nonce[4] = byte(packetId >> 32)
nonce[5] = byte(packetId >> 40)
nonce[6] = byte(packetId >> 48)
nonce[7] = byte(packetId >> 56)
// fromNode: 4字节,小端序
nonce[8] = byte(fromNode)
nonce[9] = byte(fromNode >> 8)
nonce[10] = byte(fromNode >> 16)
nonce[11] = byte(fromNode >> 24)
// counter: 4字节,默认为0 (nonce[12-15] 已经是0)
return nonce
}
// decryptAESCtr 使用 AES-CTR 解密
func decryptAESCtr(key []byte, nonce [16]byte, ciphertext []byte) ([]byte, error) {
if len(key) != 16 {
return nil, fmt.Errorf("key must be 16 bytes, got %d", len(key))
}
if len(ciphertext) == 0 {
return nil, fmt.Errorf("ciphertext is empty")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
stream := cipher.NewCTR(block, nonce[:])
plaintext := make([]byte, len(ciphertext))
stream.XORKeyStream(plaintext, ciphertext)
return plaintext, nil
}
// DecryptMeshPacket 解密 MeshPacket
func DecryptMeshPacket(psk []byte, packetId uint64, fromNode uint32, encrypted []byte) ([]byte, error) {
if psk == nil {
return nil, fmt.Errorf("no PSK configured")
}
if len(encrypted) == 0 {
return nil, fmt.Errorf("encrypted payload is empty")
}
nonce := buildNonce(packetId, fromNode)
return decryptAESCtr(psk, nonce, encrypted)
}
// SetDefaultPSKIndex 设置默认 PSK 索引
func SetDefaultPSKIndex(index byte) {
defaultPSKIndex = index
}
// GetDefaultPSK 返回当前 PSK 的 16 字节密钥
func GetDefaultPSK() []byte {
key, _ := ExpandPSK(defaultPSKIndex)
return key
}
// GetDefaultPSKIndex 返回当前 PSK 索引
func GetDefaultPSKIndex() byte {
return defaultPSKIndex
}
// TryDecryptServiceEnvelope 尝试解密 ServiceEnvelope
func TryDecryptServiceEnvelope(data []byte, pskIndex byte) (*DecryptedMessage, error) {
psk, err := ExpandPSK(pskIndex)
if err != nil {
return nil, err
}
env, err := ParseServiceEnvelope(data)
if err != nil {
return nil, err
}
if env.Packet == nil {
return nil, fmt.Errorf("ServiceEnvelope has no packet")
}
msg := &DecryptedMessage{
ChannelId: env.ChannelId,
GatewayId: env.GatewayId,
PacketId: env.Packet.Id,
From: env.Packet.From,
}
// variant 10: 已经解密的数据 (decoded.payload)
if env.Packet.WhichPayloadVariant == MeshPacket_decoded_tag {
if len(env.Packet.DecryptedPayload) > 0 {
msg.Payload = env.Packet.DecryptedPayload
// portnum 是解密数据的第一个字节
msg.PortNum = uint32(env.Packet.DecryptedPayload[0])
}
return msg, nil
}
// variant 11: 加密数据,需要解密
if env.Packet.WhichPayloadVariant == MeshPacket_encrypted_tag && !env.Packet.PkiEncrypted {
plaintext, err := DecryptMeshPacket(psk, env.Packet.Id, env.Packet.From, env.Packet.Encrypted)
if err != nil {
return msg, fmt.Errorf("decryption failed: %w", err)
}
msg.Payload = plaintext
// 解析 portnum (第一个字节)
if len(plaintext) > 0 {
msg.PortNum = uint32(plaintext[0])
}
} else if env.Packet.WhichPayloadVariant == MeshPacket_encrypted_tag && env.Packet.PkiEncrypted {
return msg, fmt.Errorf("PKI encrypted packet (not supported)")
} else {
return msg, fmt.Errorf("unknown packet variant: %d", env.Packet.WhichPayloadVariant)
}
return msg, nil
}
// TryDecryptMessage 尝试解密消息
func TryDecryptMessage(payloadB64 string, pskIndex byte) (*DecryptedMessage, error) {
data, err := base64.StdEncoding.DecodeString(payloadB64)
if err != nil {
return nil, fmt.Errorf("failed to decode base64: %w", err)
}
return TryDecryptServiceEnvelope(data, pskIndex)
}
// PortNumName 返回 PortNum 对应的名称
func PortNumName(portNum uint32) string {
names := map[uint32]string{
0: "Reserved",
1: "TEXT_MESSAGE_APP",
2: "REMOTE_HARDWARE_APP",
3: "POSITION_APP",
4: "NODEINFO_APP",
5: "ROUTING_APP",
6: "ADMIN_APP",
7: "TEXT_MESSAGE_APP2",
8: "WAYPOINT_APP",
9: "WIFI_APP",
10: "MXT_AI_APP",
11: "RANGE_TEST_APP",
12: "DETECTION_SENSOR_APP",
13: "REPLY_APP",
14: "IP_TUNNEL_APP",
15: "SERIAL_APP",
16: "STORE_FORWARD_APP",
17: "TELEMETRY_APP",
18: "ZPS_APP",
19: "SIMULATOR_APP",
20: "TRACEROUTE_APP",
21: "NEIGHBORINFO_APP",
22: "AUDIO_APP",
23: "DUPLICATE_MESSAGES_APP",
24: "ACKNOWLEDGEMENT_APP",
25: "CONFIG_APP",
26: "IPLY_CONFIG_APP",
27: "MAP_REPORT_APP",
28: "PaxCounter_APP",
32: "PRIVATE_APP",
256: "ATAK_PLUGIN",
257: "HALP",
258: "RPC_APP",
259: "XMPP_APP",
260: "STREAM_APP",
261: "TUNNEL_APP",
}
if name, ok := names[portNum]; ok {
return name
}
return fmt.Sprintf("Unknown(%d)", portNum)
}