Files
meshtastic_mqtt_server/mqtpp/pki_test.go
T
2026-06-14 19:41:52 +08:00

210 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mqtpp
import (
"bytes"
"crypto/ecdh"
"crypto/rand"
"encoding/binary"
"testing"
"google.golang.org/protobuf/encoding/protowire"
)
func TestBuildPKITextMessageRoundTrip(t *testing.T) {
curve := ecdh.X25519()
senderPriv, err := curve.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate sender key: %v", err)
}
recipientPriv, err := curve.GenerateKey(rand.Reader)
if err != nil {
t.Fatalf("generate recipient key: %v", err)
}
const text = "hello over PKI 你好"
const fromNum uint32 = 0x12345678
const toNum uint32 = 0xa1b2c3d4
const packetID uint32 = 0xdeadbeef
raw, err := BuildPKITextMessageServiceEnvelope(PKITextMessageBuildOptions{
FromNodeNum: fromNum,
ToNodeNum: toNum,
PacketID: packetID,
GatewayID: NodeNumToID(fromNum),
ViaMQTT: true,
SenderPrivate: senderPriv.Bytes(),
RecipientPub: recipientPriv.PublicKey().Bytes(),
SenderPublic: senderPriv.PublicKey().Bytes(),
Text: text,
})
if err != nil {
t.Fatalf("BuildPKITextMessageServiceEnvelope: %v", err)
}
env, err := parseServiceEnvelope(raw)
if err != nil {
t.Fatalf("parseServiceEnvelope: %v", err)
}
if env.ChannelID != PKIChannelID {
t.Fatalf("channel_id = %q want %q", env.ChannelID, PKIChannelID)
}
if env.GatewayID != NodeNumToID(fromNum) {
t.Fatalf("gateway_id = %q", env.GatewayID)
}
pkt := env.Packet
if pkt.From != fromNum || pkt.To != toNum || pkt.ID != packetID {
t.Fatalf("packet header mismatch: %+v", pkt)
}
if !pkt.PKIEncrypted {
t.Fatalf("pki_encrypted = false")
}
if !pkt.ViaMQTT {
t.Fatalf("via_mqtt = false")
}
if pkt.Channel != 0 {
t.Fatalf("channel = %d want 0", pkt.Channel)
}
if pkt.PayloadVariant != "encrypted" || len(pkt.Encrypted) <= pkcOverhead {
t.Fatalf("encrypted payload missing: %+v", pkt)
}
// 收件人用对端私钥 + 发件人公钥推导共享密钥并解密
sharedKey, err := pkiSharedKey(recipientPriv.Bytes(), senderPriv.PublicKey().Bytes())
if err != nil {
t.Fatalf("pkiSharedKey: %v", err)
}
encryptedLen := len(pkt.Encrypted) - pkcOverhead
ciphertext := pkt.Encrypted[:encryptedLen]
auth := pkt.Encrypted[encryptedLen : encryptedLen+8]
extraNonce := binary.LittleEndian.Uint32(pkt.Encrypted[encryptedLen+8:])
plaintext, err := aesCCMDecrypt(sharedKey, pkiNonce(packetID, fromNum, extraNonce), ciphertext, auth)
if err != nil {
t.Fatalf("aesCCMDecrypt: %v", err)
}
data, err := parseDataPacket(plaintext)
if err != nil {
t.Fatalf("parseDataPacket: %v", err)
}
if data.Portnum != textMessageApp {
t.Fatalf("portnum = %d", data.Portnum)
}
if string(data.Payload) != text {
t.Fatalf("text = %q want %q", string(data.Payload), text)
}
// 同样用 MQTTPP 解析路径:PKI 包对外应被识别为 encrypted_packet(无法解密),
// 但用错的 PSK 不应误报“channel hash mismatch” 之外的奇怪错误。
dummyPSK, _ := ExpandPSK("AQ==")
_, _, record := MQTTPP("msh/2/e/PKI/!12345678", raw, dummyPSK, Options{AllowEncryptedForwarding: true})
if record["channel_id"] != PKIChannelID {
t.Fatalf("MQTTPP record channel_id = %v", record["channel_id"])
}
if record["pki_encrypted"] != true {
t.Fatalf("pki_encrypted record = %v", record["pki_encrypted"])
}
}
// TestPKINonceLayoutMatchesFirmware pins the exact 13-byte output of pkiNonce
// for one fixed set of arguments.
//
// Per the original author's note, this mirrors the byte layout expected by
// the firmware's initNonce(fromNode, packetId, extraNonce):
//
//	nonce[0..8)  = packetId (uint64 LE)
//	nonce[4..8)  overwritten by extraNonce (uint32 LE) when extraNonce != 0
//	nonce[8..12) = fromNode (uint32 LE)
//	nonce[12]    = 0
func TestPKINonceLayoutMatchesFirmware(t *testing.T) {
	expected := []byte{
		0xdd, 0xcc, 0xbb, 0xaa, // low 4 bytes of packetId, not touched by extraNonce
		0x88, 0x77, 0x66, 0x55, // extraNonce written over nonce[4..8)
		0x44, 0x33, 0x22, 0x11, // fromNode
		0x00,
	}
	actual := pkiNonce(0xaabbccdd, 0x11223344, 0x55667788)
	if bytes.Equal(actual, expected) {
		return
	}
	t.Fatalf("pkiNonce = % x\nwant % x", actual, expected)
}
// TestBuildPKITextMessageRejectsBroadcast checks that
// BuildPKITextMessageServiceEnvelope returns an error when the destination is
// NodeNumBroadcast.
func TestBuildPKITextMessageRejectsBroadcast(t *testing.T) {
	curve := ecdh.X25519()
	// Check key-generation errors explicitly: a discarded error would leave a
	// nil key and turn into a nil-pointer panic below instead of a clear
	// test failure.
	senderPriv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate sender key: %v", err)
	}
	recipientPriv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate recipient key: %v", err)
	}
	if _, err := BuildPKITextMessageServiceEnvelope(PKITextMessageBuildOptions{
		FromNodeNum:   0x1,
		ToNodeNum:     NodeNumBroadcast,
		PacketID:      0x2,
		SenderPrivate: senderPriv.Bytes(),
		RecipientPub:  recipientPriv.PublicKey().Bytes(),
		Text:          "hi",
	}); err == nil {
		t.Fatalf("expected error for broadcast destination")
	}
}
// TestBuildPKIMeshPacketTags confirms that the MeshPacket produced by
// buildPKIMeshPacket really carries pki_encrypted (tag 17) and public_key
// (tag 16), along with the other field numbers listed below.
func TestBuildPKIMeshPacketTags(t *testing.T) {
	// Recognizable 32-byte key: 0x00, 0x01, ..., 0x1f.
	publicKey := make([]byte, 32)
	for i := 0; i < len(publicKey); i++ {
		publicKey[i] = byte(i)
	}
	payload := []byte{0x01, 0x02, 0x03}
	raw := buildPKIMeshPacket(0x11, 0x22, 0x33, true, payload, publicKey)

	// Record every field number that walkFields reports for the packet.
	seen := make(map[protowire.Number]bool)
	collect := func(num protowire.Number, _ protowire.Type, _ any) error {
		seen[num] = true
		return nil
	}
	if err := walkFields(raw, collect); err != nil {
		t.Fatalf("walkFields: %v", err)
	}

	required := [...]protowire.Number{1, 2, 5, 6, 14, 16, 17}
	for _, tag := range required {
		if seen[tag] {
			continue
		}
		t.Fatalf("missing tag %d", tag)
	}
}
// TestMQTTPPDecryptsPKIWithResolver is an end-to-end check: the sender builds
// a PKI packet, and the receiving side decrypts it through
// Options.PKIKeyResolver and recovers the text-message record.
func TestMQTTPPDecryptsPKIWithResolver(t *testing.T) {
	curve := ecdh.X25519()
	// Check key-generation errors explicitly so a failure is reported as
	// such rather than as a nil-pointer panic further down.
	senderPriv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate sender key: %v", err)
	}
	recipientPriv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate recipient key: %v", err)
	}
	const text = "hello PKI inbound"
	const fromNum uint32 = 0xaaaa1111
	const toNum uint32 = 0xbbbb2222
	const packetID uint32 = 0x77777777
	raw, err := BuildPKITextMessageServiceEnvelope(PKITextMessageBuildOptions{
		FromNodeNum:   fromNum,
		ToNodeNum:     toNum,
		PacketID:      packetID,
		GatewayID:     NodeNumToID(fromNum),
		ViaMQTT:       true,
		SenderPrivate: senderPriv.Bytes(),
		RecipientPub:  recipientPriv.PublicKey().Bytes(),
		SenderPublic:  senderPriv.PublicKey().Bytes(),
		Text:          text,
	})
	if err != nil {
		t.Fatalf("build: %v", err)
	}
	// The resolver only answers for the exact (to, from) pair used above and
	// hands back the recipient's private key and the sender's public key.
	resolver := func(to, from uint32) ([]byte, []byte, bool) {
		if to != toNum || from != fromNum {
			return nil, nil, false
		}
		return recipientPriv.Bytes(), senderPriv.PublicKey().Bytes(), true
	}
	// NOTE(review): the second result of ExpandPSK is discarded as in the
	// original test; if it is an error, consider checking it.
	dummyPSK, _ := ExpandPSK("AQ==")
	valid, _, record := MQTTPP("msh/2/e/PKI/!aaaa1111", raw, dummyPSK, Options{PKIKeyResolver: resolver})
	if !valid {
		t.Fatalf("MQTTPP not valid: %#v", record)
	}
	if record["type"] != "text_message" {
		t.Fatalf("type = %v, want text_message", record["type"])
	}
	if record["text"] != text {
		t.Fatalf("text = %v", record["text"])
	}
	if record["pki_encrypted"] != true {
		t.Fatalf("pki_encrypted = %v", record["pki_encrypted"])
	}
}