完善私聊

This commit is contained in:
2026-06-14 19:26:43 +08:00
parent a2d838d556
commit 5d4aced3e0
5 changed files with 650 additions and 34 deletions
+111 -9
View File
@@ -3,7 +3,10 @@ package main
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"strings"
"time"
@@ -12,6 +15,7 @@ import (
"meshtastic_mqtt_server/mqtpp"
mqtt "github.com/mochi-mqtt/server/v2"
"gorm.io/gorm"
)
const botMaxTextBytes = 200
@@ -73,13 +77,6 @@ func (s *botService) SendText(_ context.Context, req botSendTextRequest) (*botMe
if len([]byte(text)) > botMaxTextBytes {
return nil, fmt.Errorf("text is too long, max %d bytes", botMaxTextBytes)
}
channelID := strings.TrimSpace(req.ChannelID)
if channelID == "" {
channelID = bot.DefaultChannelID
}
if channelID == "" {
return nil, fmt.Errorf("channel id is required")
}
toNodeNum, toNodeID, err := botMessageTarget(messageType, req)
if err != nil {
return nil, err
@@ -88,6 +85,20 @@ func (s *botService) SendText(_ context.Context, req botSendTextRequest) (*botMe
if err != nil {
return nil, err
}
fromNodeNum := uint32(bot.NodeNum)
// direct 私聊走 PKIchannel 群聊保留旧的 AES-CTR + PSK 路径
if messageType == botMessageTypeDirect {
return s.sendPKIDirect(bot, fromNodeNum, uint32(toNodeNum), toNodeID, packetID, text, req.CreatedBy)
}
channelID := strings.TrimSpace(req.ChannelID)
if channelID == "" {
channelID = bot.DefaultChannelID
}
if channelID == "" {
return nil, fmt.Errorf("channel id is required")
}
psk := strings.TrimSpace(bot.PSK)
if psk == "" {
psk = botDefaultPSK
@@ -96,7 +107,6 @@ func (s *botService) SendText(_ context.Context, req botSendTextRequest) (*botMe
if err != nil {
return nil, err
}
fromNodeNum := uint32(bot.NodeNum)
raw, err := mqtpp.BuildTextMessageServiceEnvelope(mqtpp.TextMessageBuildOptions{
PacketBuildOptions: mqtpp.PacketBuildOptions{
FromNodeNum: fromNodeNum,
@@ -121,7 +131,7 @@ func (s *botService) SendText(_ context.Context, req botSendTextRequest) (*botMe
MessageType: messageType,
ChannelID: channelID,
ToNodeID: toNodeID,
ToNodeNum: int64PtrOrNil(toNodeNum, messageType == botMessageTypeDirect),
ToNodeNum: int64PtrOrNil(toNodeNum, false),
Topic: topic,
PacketID: int64(packetID),
Text: text,
@@ -130,6 +140,98 @@ func (s *botService) SendText(_ context.Context, req botSendTextRequest) (*botMe
Status: botMessageStatusPending,
CreatedBy: strings.TrimSpace(req.CreatedBy),
}
return s.persistAndPublish(row, topic, raw)
}
// sendPKIDirect 按固件 PKI 流程发送私聊:
// - 从 nodeinfo 中查目标节点的 X25519 公钥
// - 用 bot 自身私钥与对端公钥派生共享密钥,AES-CCM(M=8,L=2) 加密
// - ServiceEnvelope.channel_id = "PKI"topic 也用 "PKI"
func (s *botService) sendPKIDirect(bot *botNodeRecord, fromNodeNum, toNodeNum uint32, toNodeID *string, packetID uint32, text, createdBy string) (*botMessageRecord, error) {
if toNodeID == nil {
return nil, fmt.Errorf("target node id is required for pki direct message")
}
privateKeyB64 := strings.TrimSpace(bot.PrivateKey)
if privateKeyB64 == "" {
return nil, fmt.Errorf("bot has no private key, regenerate keys first")
}
privateKey, err := base64.StdEncoding.DecodeString(privateKeyB64)
if err != nil {
return nil, fmt.Errorf("invalid bot private key: %w", err)
}
senderPublic, err := decodeBotPublicKey(*bot)
if err != nil {
return nil, err
}
recipientPublic, err := s.lookupRecipientPublicKey(*toNodeID)
if err != nil {
return nil, err
}
raw, err := mqtpp.BuildPKITextMessageServiceEnvelope(mqtpp.PKITextMessageBuildOptions{
FromNodeNum: fromNodeNum,
ToNodeNum: toNodeNum,
PacketID: packetID,
GatewayID: bot.NodeID,
ViaMQTT: true,
SenderPrivate: privateKey,
RecipientPub: recipientPublic,
SenderPublic: senderPublic,
Text: text,
})
if err != nil {
return nil, err
}
topic := botMQTTTopic(bot.TopicPrefix, mqtpp.PKIChannelID, bot.NodeID)
row := &botMessageRecord{
BotID: bot.ID,
BotNodeID: bot.NodeID,
BotNodeNum: bot.NodeNum,
MessageType: botMessageTypeDirect,
ChannelID: mqtpp.PKIChannelID,
ToNodeID: toNodeID,
ToNodeNum: int64PtrOrNil(int64(toNodeNum), true),
Topic: topic,
PacketID: int64(packetID),
Text: text,
PayloadLen: int64(len(raw)),
Encrypted: true,
Status: botMessageStatusPending,
CreatedBy: strings.TrimSpace(createdBy),
}
return s.persistAndPublish(row, topic, raw)
}
// lookupRecipientPublicKey 从 nodeinfo 表中按 node_id 查询目标节点的 X25519 公钥(hex 编码)。
func (s *botService) lookupRecipientPublicKey(nodeID string) ([]byte, error) {
node, err := s.store.GetNodeInfo(nodeID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("recipient node %s not found in nodeinfo, cannot send PKI message", nodeID)
}
return nil, err
}
if node.PublicKey == nil || strings.TrimSpace(*node.PublicKey) == "" {
return nil, fmt.Errorf("recipient node %s has no public key on file", nodeID)
}
keyHex := strings.TrimSpace(*node.PublicKey)
keyBytes, err := hex.DecodeString(keyHex)
if err != nil {
// 兼容历史上可能存储为 base64 的情况
if alt, altErr := base64.StdEncoding.DecodeString(keyHex); altErr == nil {
keyBytes = alt
} else {
return nil, fmt.Errorf("invalid recipient public key for %s: %w", nodeID, err)
}
}
if len(keyBytes) != 32 {
return nil, fmt.Errorf("recipient public key for %s has unexpected length %d", nodeID, len(keyBytes))
}
return keyBytes, nil
}
// persistAndPublish 把消息记录入库后发布到 MQTT,统一处理失败状态写回。
func (s *botService) persistAndPublish(row *botMessageRecord, topic string, raw []byte) (*botMessageRecord, error) {
if err := s.store.InsertBotMessage(row); err != nil {
return nil, err
}
+60 -8
View File
@@ -69,10 +69,33 @@
"node": ">=6.9.0"
}
},
"node_modules/@emnapi/core": {
"version": "1.11.1",
"resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.11.1.tgz",
"integrity": "sha512-RSvbQmHzdKzNsLYa/wHrbc3KN4sYLKAdPZxqiM2HATqv/SBk2/ENSHpvXGaLOMcsAyz0poEGqkmmKYG3OWiJEQ==",
"dev": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/wasi-threads": "1.2.2",
"tslib": "^2.4.0"
}
},
"node_modules/@emnapi/runtime": {
"version": "1.11.1",
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.11.1.tgz",
"integrity": "sha512-vgj7R3y3Wgx24IQaGPA/R6YFXLHVMOZ0uVEyIQPaWs+rd1AzfEMXlAC22FYwO1XkKR6NPsq7mUandH8oIRdZFw==",
"dev": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@emnapi/wasi-threads": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz",
"integrity": "sha512-uTII7OYF+/Mes/MrcIOYp5yOtSMLBWSIoLPpcgwipoiKbli6k322tcoFsxoIIxPDqW01SQGAgko4EzZi2BNv2w==",
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.2.tgz",
"integrity": "sha512-c95qOXkHdydNKhscBTebqEC1CVAZpyqOfVfBzQ1qgzyl3gfeldUjIggDbIZgDKsHLgnsM+igH7TJ/eAasaVuMA==",
"dev": true,
"license": "MIT",
"optional": true,
@@ -381,6 +404,40 @@
"node": "^20.19.0 || >=22.12.0"
}
},
"node_modules/@rolldown/binding-wasm32-wasi/node_modules/@emnapi/core": {
"version": "1.10.0",
"resolved": "https://registry.npmjs.org/@emnapi/core/-/core-1.10.0.tgz",
"integrity": "sha512-yq6OkJ4p82CAfPl0u9mQebQHKPJkY7WrIuk205cTYnYe+k2Z8YBh11FrbRG/H6ihirqcacOgl2BIO8oyMQLeXw==",
"dev": true,
"license": "MIT",
"optional": true,
"dependencies": {
"@emnapi/wasi-threads": "1.2.1",
"tslib": "^2.4.0"
}
},
"node_modules/@rolldown/binding-wasm32-wasi/node_modules/@emnapi/runtime": {
"version": "1.10.0",
"resolved": "https://registry.npmjs.org/@emnapi/runtime/-/runtime-1.10.0.tgz",
"integrity": "sha512-ewvYlk86xUoGI0zQRNq/mC+16R1QeDlKQy21Ki3oSYXNgLb45GV1P6A0M+/s6nyCuNDqe5VpaY84BzXGwVbwFA==",
"dev": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@rolldown/binding-wasm32-wasi/node_modules/@emnapi/wasi-threads": {
"version": "1.2.1",
"resolved": "https://registry.npmjs.org/@emnapi/wasi-threads/-/wasi-threads-1.2.1.tgz",
"integrity": "sha512-uTII7OYF+/Mes/MrcIOYp5yOtSMLBWSIoLPpcgwipoiKbli6k322tcoFsxoIIxPDqW01SQGAgko4EzZi2BNv2w==",
"dev": true,
"license": "MIT",
"optional": true,
"dependencies": {
"tslib": "^2.4.0"
}
},
"node_modules/@rolldown/binding-win32-arm64-msvc": {
"version": "1.0.3",
"resolved": "https://registry.npmjs.org/@rolldown/binding-win32-arm64-msvc/-/binding-win32-arm64-msvc-1.0.3.tgz",
@@ -728,7 +785,6 @@
"integrity": "sha512-GUUEShf+PBCGW2KaXwcIt3Yk+e3pkKwWKb9GSyM9WQVE+ep2jzmHdGsHzu4wgcZy5fN9FBdVzjpBQsYlpfpgLA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~7.16.0"
}
@@ -1339,7 +1395,6 @@
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -1470,7 +1525,6 @@
"integrity": "sha512-y2TvuxSZPDyQakkFRPZHKFm+KKVqIisdg9/CZwm9ftvKXLP8NRWj38/ODjNbr43SsoXqNuAisEf1GdCxqWcdBw==",
"devOptional": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -1492,7 +1546,6 @@
"integrity": "sha512-h9bXPmJichP5fLmVQo3PyaGSDE2n3aPuomeAlVRm0JLmt4rY6zmPKd59HYI4LNW8oTK7tlTsuC7l/m7awx9Jcw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"lightningcss": "^1.32.0",
"picomatch": "^4.0.4",
@@ -1577,7 +1630,6 @@
"resolved": "https://registry.npmjs.org/vue/-/vue-3.5.35.tgz",
"integrity": "sha512-cx89fnr+0kVGHiNFG6y6s0bdjypJRFNZn6x3WPstNdQR1bi1mbB7h4v5IBGTsPJU3nK1+0Iqj3Zf+hZWMieR4Q==",
"license": "MIT",
"peer": true,
"dependencies": {
"@vue/compiler-dom": "3.5.35",
"@vue/compiler-sfc": "3.5.35",
@@ -7,13 +7,14 @@ const chatPageSize = 30
const maxTextBytes = 200
const topThreshold = 8
const bottomThreshold = 40
// 私聊固定走 PKIchannel_id 与固件 ServiceEnvelope 保持一致
const directChannelId = 'PKI'
const bots = ref<BotNode[]>([])
const targets = ref<NodeInfo[]>([])
const messages = ref<TextMessage[]>([])
const selectedBotId = ref<number | null>(null)
const selectedTargetId = ref('')
const channelId = ref('LongFast')
const text = ref('')
const loading = ref(false)
const sending = ref(false)
@@ -33,7 +34,7 @@ let restoreMessageCount = 0
const selectedBot = computed(() => bots.value.find((item) => item.id === selectedBotId.value) ?? null)
const selectedTarget = computed(() => targets.value.find((item) => item.node_id === selectedTargetId.value) ?? null)
const directTextBytes = computed(() => new TextEncoder().encode(text.value).length)
const canSend = computed(() => !!selectedBot.value && !!selectedTarget.value && !!channelId.value.trim() && !!text.value.trim() && directTextBytes.value <= maxTextBytes && !sending.value)
const canSend = computed(() => !!selectedBot.value && !!selectedTarget.value && !!text.value.trim() && directTextBytes.value <= maxTextBytes && !sending.value)
const groupedMessages = computed(() => {
const groups = new Map<string, TextMessage & { mergedCount: number; mergedMessages: TextMessage[] }>()
for (const item of messages.value) {
@@ -49,8 +50,7 @@ const groupedMessages = computed(() => {
return Array.from(groups.values())
})
watch(selectedBot, (bot) => {
if (bot) channelId.value = bot.default_channel_id
watch(selectedBot, () => {
resetChat()
loadInitialMessages()
})
@@ -60,11 +60,6 @@ watch(selectedTargetId, () => {
loadInitialMessages()
})
watch(channelId, () => {
resetChat()
loadInitialMessages()
})
function resetChat() {
messages.value = []
hasMore.value = true
@@ -110,7 +105,6 @@ async function refreshLists() {
targets.value = nodeResponse.items
if (!selectedBotId.value && bots.value.length > 0) {
selectedBotId.value = bots.value[0].id
channelId.value = bots.value[0].default_channel_id
}
} catch (err) {
error.value = err instanceof Error ? err.message : String(err)
@@ -123,7 +117,7 @@ async function loadInitialMessages() {
if (!selectedBot.value || !selectedTarget.value) return
loadingOlder.value = true
try {
const response = await getBotDirectTextMessages(selectedBot.value.id, selectedTarget.value.node_num, chatPageSize, 0, channelId.value)
const response = await getBotDirectTextMessages(selectedBot.value.id, selectedTarget.value.node_num, chatPageSize, 0, directChannelId)
messages.value = toChronological(response.items)
hasMore.value = response.items.length === chatPageSize
initialized.value = true
@@ -141,7 +135,7 @@ async function loadOlderMessages() {
if (!selectedBot.value || !selectedTarget.value || loadingOlder.value || !hasMore.value) return
loadingOlder.value = true
try {
const response = await getBotDirectTextMessages(selectedBot.value.id, selectedTarget.value.node_num, chatPageSize, messages.value.length, channelId.value)
const response = await getBotDirectTextMessages(selectedBot.value.id, selectedTarget.value.node_num, chatPageSize, messages.value.length, directChannelId)
messages.value = mergeMessages(messages.value, toChronological(response.items))
hasMore.value = response.items.length === chatPageSize
} catch (err) {
@@ -153,7 +147,7 @@ async function loadOlderMessages() {
async function pollLatestMessages() {
if (!selectedBot.value || !selectedTarget.value) return
const response = await getBotDirectTextMessages(selectedBot.value.id, selectedTarget.value.node_num, chatPageSize, 0, channelId.value)
const response = await getBotDirectTextMessages(selectedBot.value.id, selectedTarget.value.node_num, chatPageSize, 0, directChannelId)
messages.value = mergeMessages(messages.value, toChronological(response.items))
}
@@ -163,7 +157,7 @@ async function sendDirectMessage() {
error.value = ''
notice.value = ''
try {
const response = await sendBotMessage({ bot_id: selectedBot.value.id, message_type: 'direct', channel_id: channelId.value, to_node_id: selectedTarget.value.node_id, text: text.value })
const response = await sendBotMessage({ bot_id: selectedBot.value.id, message_type: 'direct', channel_id: directChannelId, to_node_id: selectedTarget.value.node_id, text: text.value })
if (response.error) {
error.value = response.error
} else {
@@ -240,7 +234,7 @@ onBeforeUnmount(() => {
<div class="direct-header">
<div>
<p class="eyebrow">Direct Bot Chat</p>
<h2>机器人私聊功能未完成</h2>
<h2>机器人私聊 <span class="pki-badge" title="使用 X25519 + AES-CCM 与目标节点端到端加密">PKI 加密</span></h2>
</div>
<div class="direct-actions">
<a class="admin-button secondary" href="/admin/bot">返回频道聊天</a>
@@ -264,9 +258,10 @@ onBeforeUnmount(() => {
<option v-for="node in targets" :key="node.node_id" :value="node.node_id">{{ node.long_name || node.short_name || node.node_id }} · {{ node.node_id }}</option>
</select>
</label>
<label>频道 ID<input v-model="channelId" /></label>
</div>
<p class="direct-hint">私聊固定走 PKIchannel_id = "PKI"需要目标节点已上报 NodeInfo 公钥才能加密</p>
<div ref="panelRef" class="direct-chat-list" @scroll.passive="handleScroll">
<div v-if="loadingOlder" class="chat-loading">正在加载更早消息...</div>
<div v-else-if="!hasMore && messages.length > 0" class="chat-end">没有更多历史消息</div>
@@ -293,7 +288,9 @@ onBeforeUnmount(() => {
<style scoped>
.direct-page { display: grid; gap: 12px; padding: 16px; }
.direct-header, .direct-actions, .send-actions { display: flex; align-items: center; justify-content: space-between; gap: 10px; flex-wrap: wrap; }
.direct-selectors { display: grid; grid-template-columns: repeat(3, minmax(180px, 1fr)); gap: 12px; }
.direct-selectors { display: grid; grid-template-columns: repeat(2, minmax(180px, 1fr)); gap: 12px; }
.direct-hint { color: #475569; font-size: 12px; margin: 0; }
.pki-badge { display: inline-flex; align-items: center; margin-left: 8px; border-radius: 999px; padding: 2px 10px; color: #1d4ed8; background: #dbeafe; font-size: 12px; font-weight: 700; vertical-align: middle; }
label { display: grid; gap: 5px; color: #334155; font-size: 13px; font-weight: 800; }
input, select, textarea { box-sizing: border-box; width: 100%; border: 1px solid #cbd5e1; border-radius: 10px; padding: 9px 11px; color: #0f172a; font: inherit; background: #fff; }
.direct-chat-list { min-height: 420px; max-height: 560px; overflow: auto; display: flex; flex-direction: column; gap: 10px; border: 1px solid #e2e8f0; border-radius: 14px; padding: 14px; background: linear-gradient(180deg, #f8fafc 0%, #eef4ff 100%); }
+304
View File
@@ -0,0 +1,304 @@
package mqtpp
import (
"crypto/aes"
"crypto/ecdh"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/binary"
"fmt"
"strings"
"unicode/utf8"
"google.golang.org/protobuf/encoding/protowire"
)
// PKIChannelID 是固件在 ServiceEnvelope/MQTT topic 中标识 PKI 加密包时使用的字面量
// 见 firmware/src/mqtt/MQTT.cpp 中 `channelId = isPKIEncrypted ? "PKI" : channels.getGlobalId(chIndex);`
const PKIChannelID = "PKI"
// pkcOverhead 与固件 MESHTASTIC_PKC_OVERHEAD 一致:8 字节 AES-CCM 认证标签 + 4 字节 extraNonce
const pkcOverhead = 12
// PKITextMessageBuildOptions 描述构造一条 PKI 加密 DM 所需的全部上下文
type PKITextMessageBuildOptions struct {
FromNodeNum uint32
ToNodeNum uint32
PacketID uint32
GatewayID string
ViaMQTT bool
SenderPrivate []byte // X25519 32 字节私钥
RecipientPub []byte // X25519 32 字节公钥
SenderPublic []byte // 可选;附在 MeshPacket.public_key (tag 16)
Text string
}
// BuildPKITextMessageServiceEnvelope 构造一条遵循固件实现的 PKI 私聊文本消息:
// - data 包: portnum=TEXT_MESSAGE_APP, payload=text
// - 共享密钥: SHA256(X25519(senderPriv, recipientPub))
// - AES-CCM(M=8,L=2,AAD=0); nonce = packetId(8B LE) | fromNode(4B LE) | extraNonce(4B LE,覆盖 fromNode 后续 4 字节)
// - encrypted bytes 末尾追加 8 字节 auth + 4 字节 extraNonce(LE)
// - MeshPacket.channel = 0, pki_encrypted(tag17)=1
// - ServiceEnvelope.channel_id 固定 "PKI"
func BuildPKITextMessageServiceEnvelope(opts PKITextMessageBuildOptions) ([]byte, error) {
if opts.FromNodeNum == 0 {
return nil, fmt.Errorf("from node number is required")
}
if opts.ToNodeNum == 0 || opts.ToNodeNum == NodeNumBroadcast {
return nil, fmt.Errorf("pki direct message requires a non-broadcast destination")
}
if opts.PacketID == 0 {
return nil, fmt.Errorf("packet id is required")
}
if opts.Text == "" {
return nil, fmt.Errorf("text is required")
}
if !utf8.ValidString(opts.Text) {
return nil, fmt.Errorf("text must be valid utf-8")
}
if len(opts.SenderPrivate) != 32 {
return nil, fmt.Errorf("sender private key must be 32 bytes")
}
if len(opts.RecipientPub) != 32 {
return nil, fmt.Errorf("recipient public key must be 32 bytes")
}
if strings.TrimSpace(opts.GatewayID) == "" {
opts.GatewayID = NodeNumToID(opts.FromNodeNum)
}
plaintext := buildDataPacket(textMessageApp, []byte(opts.Text))
sharedKey, err := pkiSharedKey(opts.SenderPrivate, opts.RecipientPub)
if err != nil {
return nil, err
}
var extraNonceBuf [4]byte
if _, err := rand.Read(extraNonceBuf[:]); err != nil {
return nil, err
}
extraNonce := binary.LittleEndian.Uint32(extraNonceBuf[:])
ciphertext, auth, err := aesCCMEncrypt(sharedKey, pkiNonce(opts.PacketID, opts.FromNodeNum, extraNonce), plaintext)
if err != nil {
return nil, err
}
encrypted := make([]byte, 0, len(ciphertext)+pkcOverhead)
encrypted = append(encrypted, ciphertext...)
encrypted = append(encrypted, auth...)
encrypted = append(encrypted, extraNonceBuf[:]...)
packet := buildPKIMeshPacket(opts.FromNodeNum, opts.ToNodeNum, opts.PacketID, opts.ViaMQTT, encrypted, opts.SenderPublic)
return buildServiceEnvelope(packet, PKIChannelID, opts.GatewayID), nil
}
// pkiSharedKey 用 X25519 计算共享密钥,再做一次 SHA-256(与固件一致)。
func pkiSharedKey(privateKey, publicKey []byte) ([]byte, error) {
curve := ecdh.X25519()
priv, err := curve.NewPrivateKey(privateKey)
if err != nil {
return nil, fmt.Errorf("invalid sender private key: %w", err)
}
pub, err := curve.NewPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("invalid recipient public key: %w", err)
}
shared, err := priv.ECDH(pub)
if err != nil {
return nil, fmt.Errorf("x25519 ecdh failed: %w", err)
}
digest := sha256.Sum256(shared)
return digest[:], nil
}
// pkiNonce 完整复刻固件 CryptoEngine::initNonce(fromNode, packetId, extraNonce) 的字节布局。
// 固件实现(mesh/CryptoEngine.cpp):
//
// memcpy(nonce + 0, &packetId, 8); // packetId 是 uint64,写入 nonce[0..8)
// memcpy(nonce + 8, &fromNode, 4); // fromNode 写入 nonce[8..12)
// if (extraNonce)
// memcpy(nonce + 4, &extraNonce, 4); // extraNonce 覆盖 nonce[4..8)
//
// 因此 13 字节 nonce 布局为:packetId_lo(4B LE) | extraNonce_or_packetId_hi(4B LE) | fromNode(4B LE) | 0x00
func pkiNonce(packetID, fromNode, extraNonce uint32) []byte {
nonce := make([]byte, 16)
binary.LittleEndian.PutUint64(nonce[0:8], uint64(packetID)) // packetId 是 uint64,高 32 位为 0
binary.LittleEndian.PutUint32(nonce[8:12], fromNode)
if extraNonce != 0 {
binary.LittleEndian.PutUint32(nonce[4:8], extraNonce)
}
// CCM L=2 → nonce 占 15-L=13 字节
return nonce[:13]
}
// aesCCMEncrypt 使用与固件相同的参数(AES-CCM, M=8 即 8 字节 tag, L=2, 无 AAD)。
func aesCCMEncrypt(key, nonce, plaintext []byte) (ciphertext []byte, auth []byte, err error) {
if len(nonce) != 13 {
return nil, nil, fmt.Errorf("ccm nonce must be 13 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}
const tagLen = 8
if len(plaintext) > 0xffff {
return nil, nil, fmt.Errorf("plaintext too large for L=2 ccm")
}
// CBC-MAC 鉴权
var x [aes.BlockSize]byte
var b [aes.BlockSize]byte
b[0] = byte((tagLen-2)/2) << 3 // M', AAD=0 时 Adata=0
b[0] |= byte(2 - 1) // L'=L-1
copy(b[1:], nonce[:13])
binary.BigEndian.PutUint16(b[14:], uint16(len(plaintext)))
block.Encrypt(x[:], b[:])
// 鉴权明文
for offset := 0; offset < len(plaintext); offset += aes.BlockSize {
end := offset + aes.BlockSize
if end > len(plaintext) {
end = len(plaintext)
}
var blk [aes.BlockSize]byte
copy(blk[:], plaintext[offset:end])
for i := range x {
x[i] ^= blk[i]
}
block.Encrypt(x[:], x[:])
}
// CTR 流:A_i = L' | nonce | counter_be16
var a [aes.BlockSize]byte
a[0] = byte(2 - 1)
copy(a[1:], nonce[:13])
encryptCounter := func(i uint16) [aes.BlockSize]byte {
var ai [aes.BlockSize]byte
copy(ai[:], a[:])
binary.BigEndian.PutUint16(ai[14:], i)
var s [aes.BlockSize]byte
block.Encrypt(s[:], ai[:])
return s
}
ciphertext = make([]byte, len(plaintext))
for i, offset := 1, 0; offset < len(plaintext); i, offset = i+1, offset+aes.BlockSize {
s := encryptCounter(uint16(i))
end := offset + aes.BlockSize
if end > len(plaintext) {
end = len(plaintext)
}
for j := offset; j < end; j++ {
ciphertext[j] = plaintext[j] ^ s[j-offset]
}
}
// auth = T XOR S_0
s0 := encryptCounter(0)
auth = make([]byte, tagLen)
for i := 0; i < tagLen; i++ {
auth[i] = x[i] ^ s0[i]
}
return ciphertext, auth, nil
}
// aesCCMDecrypt 与 encrypt 对称,验证标签后返回明文。仅用于测试与可能的回程解密。
func aesCCMDecrypt(key, nonce, ciphertext, auth []byte) ([]byte, error) {
if len(nonce) != 13 {
return nil, fmt.Errorf("ccm nonce must be 13 bytes")
}
if len(auth) != 8 {
return nil, fmt.Errorf("ccm auth tag must be 8 bytes")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// 先 CTR 解密
var a [aes.BlockSize]byte
a[0] = byte(2 - 1)
copy(a[1:], nonce[:13])
encryptCounter := func(i uint16) [aes.BlockSize]byte {
var ai [aes.BlockSize]byte
copy(ai[:], a[:])
binary.BigEndian.PutUint16(ai[14:], i)
var s [aes.BlockSize]byte
block.Encrypt(s[:], ai[:])
return s
}
plain := make([]byte, len(ciphertext))
for i, offset := 1, 0; offset < len(ciphertext); i, offset = i+1, offset+aes.BlockSize {
s := encryptCounter(uint16(i))
end := offset + aes.BlockSize
if end > len(ciphertext) {
end = len(ciphertext)
}
for j := offset; j < end; j++ {
plain[j] = ciphertext[j] ^ s[j-offset]
}
}
// 再 CBC-MAC 校验
var x [aes.BlockSize]byte
var b [aes.BlockSize]byte
b[0] = byte((8-2)/2) << 3
b[0] |= byte(2 - 1)
copy(b[1:], nonce[:13])
binary.BigEndian.PutUint16(b[14:], uint16(len(plain)))
block.Encrypt(x[:], b[:])
for offset := 0; offset < len(plain); offset += aes.BlockSize {
end := offset + aes.BlockSize
if end > len(plain) {
end = len(plain)
}
var blk [aes.BlockSize]byte
copy(blk[:], plain[offset:end])
for i := range x {
x[i] ^= blk[i]
}
block.Encrypt(x[:], x[:])
}
s0 := encryptCounter(0)
expected := make([]byte, 8)
for i := 0; i < 8; i++ {
expected[i] = x[i] ^ s0[i]
}
if subtle.ConstantTimeCompare(expected, auth) != 1 {
return nil, fmt.Errorf("aes-ccm auth mismatch")
}
return plain, nil
}
// buildPKIMeshPacket 构造一个 PKI 加密的 MeshPacket
// - tag 1/2: from/to (fixed32)
// - tag 3 channel = 0 (省略,默认即为 0)
// - tag 5 encrypted (含 ciphertext|auth|extraNonce)
// - tag 6 packet_id
// - tag 14 via_mqtt
// - tag 16 public_key(可选,附带发送者公钥)
// - tag 17 pki_encrypted = 1
func buildPKIMeshPacket(from, to, packetID uint32, viaMQTT bool, encrypted []byte, senderPublic []byte) []byte {
var out []byte
out = protowire.AppendTag(out, 1, protowire.Fixed32Type)
out = protowire.AppendFixed32(out, from)
out = protowire.AppendTag(out, 2, protowire.Fixed32Type)
out = protowire.AppendFixed32(out, to)
out = protowire.AppendTag(out, 5, protowire.BytesType)
out = protowire.AppendBytes(out, encrypted)
out = protowire.AppendTag(out, 6, protowire.Fixed32Type)
out = protowire.AppendFixed32(out, packetID)
if viaMQTT {
out = protowire.AppendTag(out, 14, protowire.VarintType)
out = protowire.AppendVarint(out, 1)
}
if len(senderPublic) == 32 {
out = protowire.AppendTag(out, 16, protowire.BytesType)
out = protowire.AppendBytes(out, senderPublic)
}
out = protowire.AppendTag(out, 17, protowire.VarintType)
out = protowire.AppendVarint(out, 1)
return out
}
+161
View File
@@ -0,0 +1,161 @@
package mqtpp
import (
"bytes"
"crypto/ecdh"
"crypto/rand"
"encoding/binary"
"testing"
"google.golang.org/protobuf/encoding/protowire"
)
func TestBuildPKITextMessageRoundTrip(t *testing.T) {
curve := ecdh.X25519()
senderPriv, err := curve.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate sender key: %v", err)
}
recipientPriv, err := curve.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate recipient key: %v", err)
}
const text = "hello over PKI 你好"
const fromNum uint32 = 0x12345678
const toNum uint32 = 0xa1b2c3d4
const packetID uint32 = 0xdeadbeef
raw, err := BuildPKITextMessageServiceEnvelope(PKITextMessageBuildOptions{
FromNodeNum: fromNum,
ToNodeNum: toNum,
PacketID: packetID,
GatewayID: NodeNumToID(fromNum),
ViaMQTT: true,
SenderPrivate: senderPriv.Bytes(),
RecipientPub: recipientPriv.PublicKey().Bytes(),
SenderPublic: senderPriv.PublicKey().Bytes(),
Text: text,
})
if err != nil {
t.Fatalf("BuildPKITextMessageServiceEnvelope: %v", err)
}
env, err := parseServiceEnvelope(raw)
if err != nil {
t.Fatalf("parseServiceEnvelope: %v", err)
}
if env.ChannelID != PKIChannelID {
t.Fatalf("channel_id = %q want %q", env.ChannelID, PKIChannelID)
}
if env.GatewayID != NodeNumToID(fromNum) {
t.Fatalf("gateway_id = %q", env.GatewayID)
}
pkt := env.Packet
if pkt.From != fromNum || pkt.To != toNum || pkt.ID != packetID {
t.Fatalf("packet header mismatch: %+v", pkt)
}
if !pkt.PKIEncrypted {
t.Fatalf("pki_encrypted = false")
}
if !pkt.ViaMQTT {
t.Fatalf("via_mqtt = false")
}
if pkt.Channel != 0 {
t.Fatalf("channel = %d want 0", pkt.Channel)
}
if pkt.PayloadVariant != "encrypted" || len(pkt.Encrypted) <= pkcOverhead {
t.Fatalf("encrypted payload missing: %+v", pkt)
}
// 收件人用对端私钥 + 发件人公钥推导共享密钥并解密
sharedKey, err := pkiSharedKey(recipientPriv.Bytes(), senderPriv.PublicKey().Bytes())
if err != nil {
t.Fatalf("pkiSharedKey: %v", err)
}
encryptedLen := len(pkt.Encrypted) - pkcOverhead
ciphertext := pkt.Encrypted[:encryptedLen]
auth := pkt.Encrypted[encryptedLen : encryptedLen+8]
extraNonce := binary.LittleEndian.Uint32(pkt.Encrypted[encryptedLen+8:])
plaintext, err := aesCCMDecrypt(sharedKey, pkiNonce(packetID, fromNum, extraNonce), ciphertext, auth)
if err != nil {
t.Fatalf("aesCCMDecrypt: %v", err)
}
data, err := parseDataPacket(plaintext)
if err != nil {
t.Fatalf("parseDataPacket: %v", err)
}
if data.Portnum != textMessageApp {
t.Fatalf("portnum = %d", data.Portnum)
}
if string(data.Payload) != text {
t.Fatalf("text = %q want %q", string(data.Payload), text)
}
// 同样用 MQTTPP 解析路径:PKI 包对外应被识别为 encrypted_packet(无法解密),
// 但用错的 PSK 不应误报“channel hash mismatch” 之外的奇怪错误。
dummyPSK, _ := ExpandPSK("AQ==")
_, _, record := MQTTPP("msh/2/e/PKI/!12345678", raw, dummyPSK, Options{AllowEncryptedForwarding: true})
if record["channel_id"] != PKIChannelID {
t.Fatalf("MQTTPP record channel_id = %v", record["channel_id"])
}
if record["pki_encrypted"] != true {
t.Fatalf("pki_encrypted record = %v", record["pki_encrypted"])
}
}
func TestPKINonceLayoutMatchesFirmware(t *testing.T) {
// 复刻 firmware initNonce(fromNode, packetId, extraNonce) 期望的字节布局:
// nonce[0..8) = packetId(uint64 LE)
// nonce[4..8) 被 extraNonce(uint32 LE) 覆盖(当 extraNonce != 0
// nonce[8..12) = fromNode(uint32 LE)
// nonce[12] = 0
got := pkiNonce(0xaabbccdd, 0x11223344, 0x55667788)
want := []byte{
0xdd, 0xcc, 0xbb, 0xaa, // packetId low 4 bytes,未被 extraNonce 覆盖前
0x88, 0x77, 0x66, 0x55, // extraNonce 覆盖 nonce[4..8)
0x44, 0x33, 0x22, 0x11, // fromNode
0x00,
}
if !bytes.Equal(got, want) {
t.Fatalf("pkiNonce = % x\nwant % x", got, want)
}
}
func TestBuildPKITextMessageRejectsBroadcast(t *testing.T) {
curve := ecdh.X25519()
priv, _ := curve.GenerateKey(rand.Reader)
pub, _ := curve.GenerateKey(rand.Reader)
if _, err := BuildPKITextMessageServiceEnvelope(PKITextMessageBuildOptions{
FromNodeNum: 0x1,
ToNodeNum: NodeNumBroadcast,
PacketID: 0x2,
SenderPrivate: priv.Bytes(),
RecipientPub: pub.PublicKey().Bytes(),
Text: "hi",
}); err == nil {
t.Fatalf("expected error for broadcast destination")
}
}
// 确认 MeshPacket 中确实带上 pki_encrypted (tag 17) 与 public_key (tag 16)
func TestBuildPKIMeshPacketTags(t *testing.T) {
encrypted := []byte{0x01, 0x02, 0x03}
pub := make([]byte, 32)
for i := range pub {
pub[i] = byte(i)
}
raw := buildPKIMeshPacket(0x11, 0x22, 0x33, true, encrypted, pub)
tags := map[protowire.Number]bool{}
if err := walkFields(raw, func(num protowire.Number, _ protowire.Type, _ any) error {
tags[num] = true
return nil
}); err != nil {
t.Fatalf("walkFields: %v", err)
}
for _, want := range []protowire.Number{1, 2, 5, 6, 14, 16, 17} {
if !tags[want] {
t.Fatalf("missing tag %d", want)
}
}
}