From 8c1e1ef414fef8c4391bdc2ff4aab28ba3ad2ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Wed, 3 Jun 2026 15:08:32 +0800 Subject: [PATCH] =?UTF-8?q?text=5Fmessage=20=E8=A1=A8=E8=AE=BE=E8=AE=A1?= =?UTF-8?q?=E4=B8=8E=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 30 +++- db.go | 470 +++++++++++++++++++++++++++++++++++++++++++-------- db_test.go | 346 ++++++++++++++++++++++++++++++------- main.go | 40 ++++- main_test.go | 43 +++++ 5 files changed, 791 insertions(+), 138 deletions(-) create mode 100644 main_test.go diff --git a/README.md b/README.md index fdadc13..e6250e4 100644 --- a/README.md +++ b/README.md @@ -99,13 +99,35 @@ meshtastic: ## 数据库持久化 -程序默认启用 SQLite,并在收到 `nodeinfo` 数据包时写入 `nodeinfo` 表。 +程序默认启用 SQLite,并持久化以下数据: -- 当前只持久化 `type == "nodeinfo"` 的记录 +- `nodeinfo_map`:融合 `type == "nodeinfo"` 和 `type == "map_report"` 的节点信息 +- `text_message`:追加保存 `type == "text_message"` 的聊天消息 + +`nodeinfo_map` 规则: + +- `nodeinfo` 表不再使用;如果旧数据库中已经存在该表,程序不会自动删除它 - 同一节点以 `node_id`(即解析结果中的 `from`,例如 `!a8dfd867`)作为主键 -- 重复收到同一节点时不会插入重复行,只更新节点字段、`content_json` 和 `updated_at` +- 重复收到同一节点时不会插入重复行,只更新 `updated_at`、`content_json`、`latest_type` 和本次记录中有值的字段 +- `nodeinfo` 独有字段和 `map_report` 独有字段会互相保留;例如后续 `map_report` 不会清空已有的 `public_key` - `first_seen_at` 保留第一次写入时间 -- `content_json` 保存完整的解析结果 JSON +- `content_json` 保存最新一次 `nodeinfo` 或 `map_report` 的完整解析结果 JSON + +`text_message` 规则: + +- 使用自增 `id` 作为主键 +- 每条聊天消息都会新增一行,不做去重 +- 保存 `from_id`、`from_num`、`text`、`payload_hex`、topic、packet 元数据和完整 `content_json` +- 保存 MQTT 客户端信息:`mqtt_client_id`、`mqtt_username`、`mqtt_listener`、`mqtt_remote_addr`、`mqtt_remote_host`、`mqtt_remote_port` + +查询最近聊天消息示例: + +```sql +SELECT id, created_at, from_id, text, mqtt_remote_host +FROM text_message +ORDER BY id DESC +LIMIT 20; +``` SQLite 默认路径: diff --git a/db.go b/db.go index 201718e..56b4dc6 100644 --- a/db.go +++ b/db.go @@ -21,17 +21,68 @@ type store struct { driver string } -type nodeInfoRecord struct { - NodeID string - NodeNum int64 - UserID any - LongName any - ShortName any - HWModel any - Role any - IsLicensed bool - PublicKey any - ContentJSON []byte +type migrationQuery struct { + name string + query string +} + +type mqttClientInfo struct { + ClientID string + Username string + Listener string + RemoteAddr string + RemoteHost string + RemotePort string +} + +type nodeInfoMapRecord struct { + NodeID string + NodeNum int64 + LatestType string + UserID any + LongName any + ShortName any + HWModel any + Role any + IsLicensed any + PublicKey any + FirmwareVersion any + Region any + ModemPreset any + Latitude any + Longitude any + Altitude any + PositionPrecision any + NumOnlineLocalNodes any + HasOptedReportLocation any + ContentJSON []byte +} + +type textMessageRecord struct { + FromID string + FromNum int64 + Text any + PayloadHex any + Topic string + ChannelID any + GatewayID any + PacketID any + PacketTo any + PacketToNum any + Portnum any + PayloadLen any + PayloadVariant any + ViaMQTT any + PKIEncrypted any + DecryptSuccess any + DecryptStatus any + MQTTClientID any + MQTTUsername any + MQTTListener any + MQTTRemoteAddr any + MQTTRemoteHost any + MQTTRemotePort any + ContentJSON []byte } func openStore(cfg databaseConfig) (*store, error) { @@ -73,50 +124,143 @@ func (s *store) Close() error { } func (s *store) migrate() error { - var query string + queries, err := s.migrationQueries() + if err != nil { + return err + } + for _, q := range queries { + if _, err := s.db.Exec(q.query); err != nil { + return fmt.Errorf("migrate %s: %w", q.name, err) + } + } + return nil +} + +func (s *store) migrationQueries() ([]migrationQuery, error) { switch s.driver { case databaseDriverSQLite: - query = `CREATE TABLE IF NOT EXISTS nodeinfo ( + return []migrationQuery{ + {name: "nodeinfo_map table", query: `CREATE TABLE IF NOT EXISTS nodeinfo_map ( node_id TEXT PRIMARY KEY, node_num INTEGER NOT NULL, + latest_type TEXT NOT NULL, user_id TEXT, long_name TEXT, short_name TEXT, hw_model TEXT, role TEXT, - is_licensed BOOLEAN NOT NULL DEFAULT FALSE, + is_licensed BOOLEAN, public_key TEXT, + firmware_version TEXT, + region TEXT, + modem_preset TEXT, + latitude REAL, + longitude REAL, + altitude INTEGER, + position_precision INTEGER, + num_online_local_nodes INTEGER, + has_opted_report_location BOOLEAN, content_json TEXT NOT NULL, first_seen_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -);` +);`}, + {name: "text_message table", query: `CREATE TABLE IF NOT EXISTS text_message ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + from_id TEXT NOT NULL, + from_num INTEGER NOT NULL, + text TEXT, + payload_hex TEXT, + topic TEXT NOT NULL, + channel_id TEXT, + gateway_id TEXT, + packet_id INTEGER, + packet_to TEXT, + packet_to_num INTEGER, + portnum TEXT, + payload_len INTEGER, + payload_variant TEXT, + via_mqtt BOOLEAN, + pki_encrypted BOOLEAN, + decrypt_success BOOLEAN, + decrypt_status TEXT, + mqtt_client_id TEXT, + mqtt_username TEXT, + mqtt_listener TEXT, + mqtt_remote_addr TEXT, + mqtt_remote_host TEXT, + mqtt_remote_port TEXT, + content_json TEXT NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +);`}, + {name: "text_message from_num index", query: `CREATE INDEX IF NOT EXISTS idx_text_message_from_num_created_at ON text_message (from_num, created_at);`}, + {name: "text_message created_at index", query: `CREATE INDEX IF NOT EXISTS idx_text_message_created_at ON text_message (created_at);`}, + {name: "text_message packet_id index", query: `CREATE INDEX IF NOT EXISTS idx_text_message_packet_id ON text_message (packet_id);`}, + }, nil case databaseDriverMySQL: - query = `CREATE TABLE IF NOT EXISTS nodeinfo ( + return []migrationQuery{ + {name: "nodeinfo_map table", query: `CREATE TABLE IF NOT EXISTS nodeinfo_map ( node_id VARCHAR(32) NOT NULL PRIMARY KEY, node_num BIGINT UNSIGNED NOT NULL, + latest_type VARCHAR(32) NOT NULL, user_id VARCHAR(128), long_name TEXT, short_name VARCHAR(64), hw_model VARCHAR(128), role VARCHAR(128), - is_licensed BOOLEAN NOT NULL DEFAULT FALSE, + is_licensed BOOLEAN, public_key TEXT, + firmware_version VARCHAR(128), + region VARCHAR(128), + modem_preset VARCHAR(128), + latitude DOUBLE, + longitude DOUBLE, + altitude INT, + position_precision INT UNSIGNED, + num_online_local_nodes INT UNSIGNED, + has_opted_report_location BOOLEAN, content_json JSON NOT NULL, first_seen_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP -);` +);`}, + {name: "text_message table", query: `CREATE TABLE IF NOT EXISTS text_message ( + id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, + from_id VARCHAR(32) NOT NULL, + from_num BIGINT UNSIGNED NOT NULL, + text TEXT, + payload_hex TEXT, + topic TEXT NOT NULL, + channel_id VARCHAR(128), + gateway_id VARCHAR(128), + packet_id BIGINT UNSIGNED, + packet_to VARCHAR(32), + packet_to_num BIGINT UNSIGNED, + portnum VARCHAR(64), + payload_len INT UNSIGNED, + payload_variant VARCHAR(32), + via_mqtt BOOLEAN, + pki_encrypted BOOLEAN, + decrypt_success BOOLEAN, + decrypt_status VARCHAR(255), + mqtt_client_id VARCHAR(255), + mqtt_username VARCHAR(255), + mqtt_listener VARCHAR(128), + mqtt_remote_addr VARCHAR(255), + mqtt_remote_host VARCHAR(255), + mqtt_remote_port VARCHAR(16), + content_json JSON NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + INDEX idx_text_message_from_num_created_at (from_num, created_at), + INDEX idx_text_message_created_at (created_at), + INDEX idx_text_message_packet_id (packet_id) +);`}, + }, nil default: - return fmt.Errorf("unsupported database driver %q", s.driver) + return nil, fmt.Errorf("unsupported database driver %q", s.driver) } - - if _, err := s.db.Exec(query); err != nil { - return fmt.Errorf("migrate nodeinfo table: %w", err) - } - return nil } -func (s *store) UpsertNodeInfo(record map[string]any) error { - node, err := nodeInfoFromRecord(record) +func (s *store) UpsertNodeInfoMap(record map[string]any) error { + node, err := nodeInfoMapFromRecord(record) if err != nil { return err } @@ -124,37 +268,61 @@ func (s *store) UpsertNodeInfo(record map[string]any) error { var query string switch s.driver { case databaseDriverSQLite: - query = `INSERT INTO nodeinfo ( - node_id, node_num, user_id, long_name, short_name, - hw_model, role, is_licensed, public_key, content_json, - first_seen_at, updated_at -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + query = `INSERT INTO nodeinfo_map ( + node_id, node_num, latest_type, user_id, long_name, short_name, + hw_model, role, is_licensed, public_key, firmware_version, + region, modem_preset, latitude, longitude, altitude, + position_precision, num_online_local_nodes, has_opted_report_location, + content_json, first_seen_at, updated_at +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) ON CONFLICT(node_id) DO UPDATE SET node_num = excluded.node_num, - user_id = excluded.user_id, - long_name = excluded.long_name, - short_name = excluded.short_name, - hw_model = excluded.hw_model, - role = excluded.role, - is_licensed = excluded.is_licensed, - public_key = excluded.public_key, + latest_type = excluded.latest_type, + user_id = COALESCE(excluded.user_id, nodeinfo_map.user_id), + long_name = COALESCE(excluded.long_name, nodeinfo_map.long_name), + short_name = COALESCE(excluded.short_name, nodeinfo_map.short_name), + hw_model = COALESCE(excluded.hw_model, nodeinfo_map.hw_model), + role = COALESCE(excluded.role, nodeinfo_map.role), + is_licensed = COALESCE(excluded.is_licensed, nodeinfo_map.is_licensed), + public_key = COALESCE(excluded.public_key, nodeinfo_map.public_key), + firmware_version = COALESCE(excluded.firmware_version, nodeinfo_map.firmware_version), + region = COALESCE(excluded.region, nodeinfo_map.region), + modem_preset = COALESCE(excluded.modem_preset, nodeinfo_map.modem_preset), + latitude = COALESCE(excluded.latitude, nodeinfo_map.latitude), + longitude = COALESCE(excluded.longitude, nodeinfo_map.longitude), + altitude = COALESCE(excluded.altitude, nodeinfo_map.altitude), + position_precision = COALESCE(excluded.position_precision, nodeinfo_map.position_precision), + num_online_local_nodes = COALESCE(excluded.num_online_local_nodes, nodeinfo_map.num_online_local_nodes), + has_opted_report_location = COALESCE(excluded.has_opted_report_location, nodeinfo_map.has_opted_report_location), content_json = excluded.content_json, updated_at = CURRENT_TIMESTAMP;` case databaseDriverMySQL: - query = `INSERT INTO nodeinfo ( - node_id, node_num, user_id, long_name, short_name, - hw_model, role, is_licensed, public_key, content_json, - first_seen_at, updated_at -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + query = `INSERT INTO nodeinfo_map ( + node_id, node_num, latest_type, user_id, long_name, short_name, + hw_model, role, is_licensed, public_key, firmware_version, + region, modem_preset, latitude, longitude, altitude, + position_precision, num_online_local_nodes, has_opted_report_location, + content_json, first_seen_at, updated_at +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) ON DUPLICATE KEY UPDATE node_num = VALUES(node_num), - user_id = VALUES(user_id), - long_name = VALUES(long_name), - short_name = VALUES(short_name), - hw_model = VALUES(hw_model), - role = VALUES(role), - is_licensed = VALUES(is_licensed), - public_key = VALUES(public_key), + latest_type = VALUES(latest_type), + user_id = COALESCE(VALUES(user_id), user_id), + long_name = COALESCE(VALUES(long_name), long_name), + short_name = COALESCE(VALUES(short_name), short_name), + hw_model = COALESCE(VALUES(hw_model), hw_model), + role = COALESCE(VALUES(role), role), + is_licensed = COALESCE(VALUES(is_licensed), is_licensed), + public_key = COALESCE(VALUES(public_key), public_key), + firmware_version = COALESCE(VALUES(firmware_version), firmware_version), + region = COALESCE(VALUES(region), region), + modem_preset = COALESCE(VALUES(modem_preset), modem_preset), + latitude = COALESCE(VALUES(latitude), latitude), + longitude = COALESCE(VALUES(longitude), longitude), + altitude = COALESCE(VALUES(altitude), altitude), + position_precision = COALESCE(VALUES(position_precision), position_precision), + num_online_local_nodes = COALESCE(VALUES(num_online_local_nodes), num_online_local_nodes), + has_opted_report_location = COALESCE(VALUES(has_opted_report_location), has_opted_report_location), content_json = VALUES(content_json), updated_at = CURRENT_TIMESTAMP;` default: @@ -164,6 +332,7 @@ ON DUPLICATE KEY UPDATE _, err = s.db.Exec(query, node.NodeID, node.NodeNum, + node.LatestType, node.UserID, node.LongName, node.ShortName, @@ -171,42 +340,158 @@ ON DUPLICATE KEY UPDATE node.Role, node.IsLicensed, node.PublicKey, + node.FirmwareVersion, + node.Region, + node.ModemPreset, + node.Latitude, + node.Longitude, + node.Altitude, + node.PositionPrecision, + node.NumOnlineLocalNodes, + node.HasOptedReportLocation, string(node.ContentJSON), ) if err != nil { - return fmt.Errorf("upsert nodeinfo %s: %w", node.NodeID, err) + return fmt.Errorf("upsert nodeinfo_map %s: %w", node.NodeID, err) } return nil } -func nodeInfoFromRecord(record map[string]any) (*nodeInfoRecord, error) { - if record["type"] != "nodeinfo" { - return nil, fmt.Errorf("record type %v is not nodeinfo", record["type"]) +func (s *store) InsertTextMessage(record map[string]any, clientInfo mqttClientInfo) error { + message, err := textMessageFromRecord(record, clientInfo) + if err != nil { + return err + } + + query := `INSERT INTO text_message ( + from_id, from_num, text, payload_hex, topic, channel_id, gateway_id, + packet_id, packet_to, packet_to_num, portnum, payload_len, + payload_variant, via_mqtt, pki_encrypted, decrypt_success, decrypt_status, + mqtt_client_id, mqtt_username, mqtt_listener, mqtt_remote_addr, + mqtt_remote_host, mqtt_remote_port, content_json +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` + + _, err = s.db.Exec(query, + message.FromID, + message.FromNum, + message.Text, + message.PayloadHex, + message.Topic, + message.ChannelID, + message.GatewayID, + message.PacketID, + message.PacketTo, + message.PacketToNum, + message.Portnum, + message.PayloadLen, + message.PayloadVariant, + message.ViaMQTT, + message.PKIEncrypted, + message.DecryptSuccess, + message.DecryptStatus, + message.MQTTClientID, + message.MQTTUsername, + message.MQTTListener, + message.MQTTRemoteAddr, + message.MQTTRemoteHost, + message.MQTTRemotePort, + string(message.ContentJSON), + ) + if err != nil { + return fmt.Errorf("insert text_message from %s: %w", message.FromID, err) + } + return nil +} + +func nodeInfoMapFromRecord(record map[string]any) (*nodeInfoMapRecord, error) { + latestType, ok := record["type"].(string) + if !ok || (latestType != "nodeinfo" && latestType != "map_report") { + return nil, fmt.Errorf("record type %v is not nodeinfo or map_report", record["type"]) } nodeID, ok := record["from"].(string) if !ok || nodeID == "" { - return nil, fmt.Errorf("nodeinfo missing from") + return nil, fmt.Errorf("nodeinfo_map missing from") } nodeNum, err := int64FromAny(record["from_num"]) if err != nil { - return nil, fmt.Errorf("nodeinfo from_num: %w", err) + return nil, fmt.Errorf("nodeinfo_map from_num: %w", err) } contentJSON, err := json.Marshal(record) if err != nil { - return nil, fmt.Errorf("encode nodeinfo content_json: %w", err) + return nil, fmt.Errorf("encode nodeinfo_map content_json: %w", err) } - return &nodeInfoRecord{ - NodeID: nodeID, - NodeNum: nodeNum, - UserID: nullableString(record["user_id"]), - LongName: nullableString(record["long_name"]), - ShortName: nullableString(record["short_name"]), - HWModel: nullableString(record["hw_model"]), - Role: nullableString(record["role"]), - IsLicensed: boolFromAny(record["is_licensed"]), - PublicKey: nullableString(record["public_key"]), - ContentJSON: contentJSON, + return &nodeInfoMapRecord{ + NodeID: nodeID, + NodeNum: nodeNum, + LatestType: latestType, + UserID: nullableString(record["user_id"]), + LongName: nullableString(record["long_name"]), + ShortName: nullableString(record["short_name"]), + HWModel: nullableString(record["hw_model"]), + Role: nullableString(record["role"]), + IsLicensed: nullableBool(record["is_licensed"]), + PublicKey: nullableString(record["public_key"]), + FirmwareVersion: nullableString(record["firmware_version"]), + Region: nullableString(record["region"]), + ModemPreset: nullableString(record["modem_preset"]), + Latitude: nullableFloat64(record["latitude"]), + Longitude: nullableFloat64(record["longitude"]), + Altitude: nullableInt64(record["altitude"]), + PositionPrecision: nullableInt64(record["position_precision"]), + NumOnlineLocalNodes: nullableInt64(record["num_online_local_nodes"]), + HasOptedReportLocation: nullableBool(record["has_opted_report_location"]), + ContentJSON: contentJSON, + }, nil +} + +func textMessageFromRecord(record map[string]any, clientInfo mqttClientInfo) (*textMessageRecord, error) { + recordType, ok := record["type"].(string) + if !ok || recordType != "text_message" { + return nil, fmt.Errorf("record type %v is not text_message", record["type"]) + } + fromID, ok := record["from"].(string) + if !ok || fromID == "" { + return nil, fmt.Errorf("text_message missing from") + } + fromNum, err := int64FromAny(record["from_num"]) + if err != nil { + return nil, fmt.Errorf("text_message from_num: %w", err) + } + topic, ok := record["topic"].(string) + if !ok || topic == "" { + return nil, fmt.Errorf("text_message missing topic") + } + contentJSON, err := json.Marshal(record) + if err != nil { + return nil, fmt.Errorf("encode text_message content_json: %w", err) + } + + return &textMessageRecord{ + FromID: fromID, + FromNum: fromNum, + Text: nullableString(record["text"]), + PayloadHex: nullableString(record["payload_hex"]), + Topic: topic, + ChannelID: nullableString(record["channel_id"]), + GatewayID: nullableString(record["gateway_id"]), + PacketID: nullableInt64(record["packet_id"]), + PacketTo: nullableString(record["packet_to"]), + PacketToNum: nullableInt64(record["packet_to_num"]), + Portnum: nullableString(record["portnum"]), + PayloadLen: nullableInt64(record["payload_len"]), + PayloadVariant: nullableString(record["payload_variant"]), + ViaMQTT: nullableBool(record["via_mqtt"]), + PKIEncrypted: nullableBool(record["pki_encrypted"]), + DecryptSuccess: nullableBool(record["decrypt_success"]), + DecryptStatus: nullableString(record["decrypt_status"]), + MQTTClientID: nullableString(clientInfo.ClientID), + MQTTUsername: nullableString(clientInfo.Username), + MQTTListener: nullableString(clientInfo.Listener), + MQTTRemoteAddr: nullableString(clientInfo.RemoteAddr), + MQTTRemoteHost: nullableString(clientInfo.RemoteHost), + MQTTRemotePort: nullableString(clientInfo.RemotePort), + ContentJSON: contentJSON, }, nil } @@ -250,7 +535,52 @@ func nullableString(value any) any { return s } -func boolFromAny(value any) bool { - b, _ := value.(bool) +func nullableBool(value any) any { + b, ok := value.(bool) + if !ok { + return nil + } return b } + +func nullableInt64(value any) any { + if value == nil { + return nil + } + v, err := int64FromAny(value) + if err != nil { + return nil + } + return v +} + +func nullableFloat64(value any) any { + switch v := value.(type) { + case float32: + return float64(v) + case float64: + return v + case int: + return float64(v) + case int8: + return float64(v) + case int16: + return float64(v) + case int32: + return float64(v) + case int64: + return float64(v) + case uint: + return float64(v) + case uint8: + return float64(v) + case uint16: + return float64(v) + case uint32: + return float64(v) + case uint64: + return float64(v) + default: + return nil + } +} diff --git a/db_test.go b/db_test.go index 6034dbf..8e7ef4d 100644 --- a/db_test.go +++ b/db_test.go @@ -7,67 +7,59 @@ import ( "testing" ) -func TestOpenStoreCreatesNodeInfoTable(t *testing.T) { +func TestOpenStoreCreatesTables(t *testing.T) { st := openTestStore(t) defer st.Close() - var name string - if err := st.db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'nodeinfo'").Scan(&name); err != nil { - t.Fatalf("nodeinfo table missing: %v", err) + for _, table := range []string{"nodeinfo_map", "text_message"} { + var name string + if err := st.db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name); err != nil { + t.Fatalf("%s table missing: %v", table, err) + } + if name != table { + t.Fatalf("table name = %q, want %s", name, table) + } } - if name != "nodeinfo" { - t.Fatalf("table name = %q, want nodeinfo", name) + + var oldCount int + if err := st.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = 'nodeinfo'").Scan(&oldCount); err != nil { + t.Fatal(err) + } + if oldCount != 0 { + t.Fatalf("old nodeinfo table count = %d, want 0", oldCount) } } -func TestUpsertNodeInfoInsertsAndUpdatesSameNode(t *testing.T) { +func TestUpsertNodeInfoMapInsertsAndUpdatesSameNode(t *testing.T) { st := openTestStore(t) defer st.Close() - first := map[string]any{ - "type": "nodeinfo", - "from": "!12345678", - "from_num": uint32(0x12345678), - "user_id": "!12345678", - "long_name": "first name", - "short_name": "fst", - "hw_model": "TEST_HW", - "role": "CLIENT", - "is_licensed": true, - "public_key": "abcd", - } - if err := st.UpsertNodeInfo(first); err != nil { - t.Fatalf("first UpsertNodeInfo() error = %v", err) + first := nodeInfoRecord("first name") + if err := st.UpsertNodeInfoMap(first); err != nil { + t.Fatalf("first UpsertNodeInfoMap() error = %v", err) } - second := map[string]any{ - "type": "nodeinfo", - "from": "!12345678", - "from_num": uint32(0x12345678), - "user_id": "!12345678", - "long_name": "second name", - "short_name": "snd", - "hw_model": "TEST_HW_2", - "role": "CLIENT_MUTE", - "is_licensed": false, - "public_key": nil, - } - if err := st.UpsertNodeInfo(second); err != nil { - t.Fatalf("second UpsertNodeInfo() error = %v", err) + second := nodeInfoRecord("second name") + second["short_name"] = "snd" + if err := st.UpsertNodeInfoMap(second); err != nil { + t.Fatalf("second UpsertNodeInfoMap() error = %v", err) } var count int - if err := st.db.QueryRow("SELECT COUNT(*) FROM nodeinfo WHERE node_id = ?", "!12345678").Scan(&count); err != nil { + if err := st.db.QueryRow("SELECT COUNT(*) FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&count); err != nil { t.Fatal(err) } if count != 1 { t.Fatalf("node row count = %d, want 1", count) } - var longName, content string - if err := st.db.QueryRow("SELECT long_name, content_json FROM nodeinfo WHERE node_id = ?", "!12345678").Scan(&longName, &content); err != nil { + var latestType, longName, content string + if err := st.db.QueryRow("SELECT latest_type, long_name, content_json FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &longName, &content); err != nil { t.Fatal(err) } + if latestType != "nodeinfo" { + t.Fatalf("latest_type = %q, want nodeinfo", latestType) + } if longName != "second name" { t.Fatalf("long_name = %q, want second name", longName) } @@ -76,18 +68,221 @@ func TestUpsertNodeInfoInsertsAndUpdatesSameNode(t *testing.T) { } } -func TestUpsertNodeInfoRequiresNodeFields(t *testing.T) { +func TestUpsertNodeInfoMapMergesNodeInfoThenMapReport(t *testing.T) { st := openTestStore(t) defer st.Close() - if err := st.UpsertNodeInfo(map[string]any{"type": "nodeinfo", "from_num": 1}); err == nil || !strings.Contains(err.Error(), "from") { + if err := st.UpsertNodeInfoMap(nodeInfoRecord("node name")); err != nil { + t.Fatalf("nodeinfo UpsertNodeInfoMap() error = %v", err) + } + if err := st.UpsertNodeInfoMap(mapReportRecord("map name")); err != nil { + t.Fatalf("map_report UpsertNodeInfoMap() error = %v", err) + } + + var count int + if err := st.db.QueryRow("SELECT COUNT(*) FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&count); err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatalf("node row count = %d, want 1", count) + } + + var latestType, userID, publicKey, longName, firmware, content string + var latitude float64 + var opted sql.NullBool + if err := st.db.QueryRow("SELECT latest_type, user_id, public_key, long_name, firmware_version, latitude, has_opted_report_location, content_json FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &userID, &publicKey, &longName, &firmware, &latitude, &opted, &content); err != nil { + t.Fatal(err) + } + if latestType != "map_report" { + t.Fatalf("latest_type = %q, want map_report", latestType) + } + if userID != "!12345678" || publicKey != "abcd" { + t.Fatalf("nodeinfo fields not preserved: user_id=%q public_key=%q", userID, publicKey) + } + if longName != "map name" { + t.Fatalf("long_name = %q, want map name", longName) + } + if firmware != "1.2.3" { + t.Fatalf("firmware = %q, want 1.2.3", firmware) + } + if latitude != 42.5 { + t.Fatalf("latitude = %v, want 42.5", latitude) + } + if !opted.Valid || opted.Bool { + t.Fatalf("has_opted_report_location = %+v, want valid false", opted) + } + if !strings.Contains(content, "map_report") { + t.Fatalf("content_json = %q, want latest map_report content", content) + } +} + +func TestUpsertNodeInfoMapMergesMapReportThenNodeInfo(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + if err := st.UpsertNodeInfoMap(mapReportRecord("map name")); err != nil { + t.Fatalf("map_report UpsertNodeInfoMap() error = %v", err) + } + if err := st.UpsertNodeInfoMap(nodeInfoRecord("node name")); err != nil { + t.Fatalf("nodeinfo UpsertNodeInfoMap() error = %v", err) + } + + var latestType, userID, longName, firmware string + var latitude float64 + if err := st.db.QueryRow("SELECT latest_type, user_id, long_name, firmware_version, latitude FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &userID, &longName, &firmware, &latitude); err != nil { + t.Fatal(err) + } + if latestType != "nodeinfo" { + t.Fatalf("latest_type = %q, want nodeinfo", latestType) + } + if userID != "!12345678" { + t.Fatalf("user_id = %q, want !12345678", userID) + } + if longName != "node name" { + t.Fatalf("long_name = %q, want node name", longName) + } + if firmware != "1.2.3" || latitude != 42.5 { + t.Fatalf("map fields not preserved: firmware=%q latitude=%v", firmware, latitude) + } +} + +func TestUpsertNodeInfoMapRequiresNodeFields(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + if err := st.UpsertNodeInfoMap(map[string]any{"type": "nodeinfo", "from_num": 1}); err == nil || !strings.Contains(err.Error(), "from") { t.Fatalf("missing from error = %v, want from error", err) } - if err := st.UpsertNodeInfo(map[string]any{"type": "nodeinfo", "from": "!00000001"}); err == nil || !strings.Contains(err.Error(), "from_num") { + if err := st.UpsertNodeInfoMap(map[string]any{"type": "nodeinfo", "from": "!00000001"}); err == nil || !strings.Contains(err.Error(), "from_num") { t.Fatalf("missing from_num error = %v, want from_num error", err) } } +func TestNodeInfoMapFromRecordRejectsWrongType(t *testing.T) { + _, err := nodeInfoMapFromRecord(map[string]any{"type": "text_message"}) + if err == nil { + t.Fatalf("nodeInfoMapFromRecord() error = nil, want error") + } +} + +func TestNodeInfoMapNullablePublicKey(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + record := map[string]any{"type": "nodeinfo", "from": "!00000001", "from_num": 1, "public_key": nil} + if err := st.UpsertNodeInfoMap(record); err != nil { + t.Fatalf("UpsertNodeInfoMap() error = %v", err) + } + + var publicKey sql.NullString + if err := st.db.QueryRow("SELECT public_key FROM nodeinfo_map WHERE node_id = ?", "!00000001").Scan(&publicKey); err != nil { + t.Fatal(err) + } + if publicKey.Valid { + t.Fatalf("public_key valid = true, want null") + } +} + +func TestInsertTextMessageAppendsRows(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + clientInfo := mqttClientInfo{ClientID: "client-1", Username: "user-1", Listener: "tcp", RemoteAddr: "127.0.0.1:54321", RemoteHost: "127.0.0.1", RemotePort: "54321"} + if err := st.InsertTextMessage(textMessageTestRecord("hello"), clientInfo); err != nil { + t.Fatalf("first InsertTextMessage() error = %v", err) + } + if err := st.InsertTextMessage(textMessageTestRecord("hello again"), clientInfo); err != nil { + t.Fatalf("second InsertTextMessage() error = %v", err) + } + + var count int + if err := st.db.QueryRow("SELECT COUNT(*) FROM text_message WHERE from_id = ?", "!12345678").Scan(&count); err != nil { + t.Fatal(err) + } + if count != 2 { + t.Fatalf("text_message count = %d, want 2", count) + } + + rows, err := st.db.Query("SELECT id FROM text_message ORDER BY id") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + var ids []int64 + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + t.Fatal(err) + } + ids = append(ids, id) + } + if err := rows.Err(); err != nil { + t.Fatal(err) + } + if len(ids) != 2 || ids[0] <= 0 || ids[1] <= ids[0] { + t.Fatalf("ids = %v, want increasing positive ids", ids) + } +} + +func TestInsertTextMessageStoresClientInfo(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + clientInfo := mqttClientInfo{ClientID: "client-1", Username: "user-1", Listener: "tcp", RemoteAddr: "127.0.0.1:54321", RemoteHost: "127.0.0.1", RemotePort: "54321"} + if err := st.InsertTextMessage(textMessageTestRecord("hello"), clientInfo); err != nil { + t.Fatalf("InsertTextMessage() error = %v", err) + } + + var clientID, username, listener, remoteAddr, remoteHost, remotePort string + if err := st.db.QueryRow("SELECT mqtt_client_id, mqtt_username, mqtt_listener, mqtt_remote_addr, mqtt_remote_host, mqtt_remote_port FROM text_message LIMIT 1").Scan(&clientID, &username, &listener, &remoteAddr, &remoteHost, &remotePort); err != nil { + t.Fatal(err) + } + if clientID != "client-1" || username != "user-1" || listener != "tcp" || remoteAddr != "127.0.0.1:54321" || remoteHost != "127.0.0.1" || remotePort != "54321" { + t.Fatalf("client info = %q %q %q %q %q %q", clientID, username, listener, remoteAddr, remoteHost, remotePort) + } +} + +func TestInsertTextMessageStoresPayloadHex(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + record := textMessageTestRecord(nil) + record["payload_hex"] = "fffefd" + if err := st.InsertTextMessage(record, mqttClientInfo{}); err != nil { + t.Fatalf("InsertTextMessage() error = %v", err) + } + + var text sql.NullString + var payloadHex string + if err := st.db.QueryRow("SELECT text, payload_hex FROM text_message LIMIT 1").Scan(&text, &payloadHex); err != nil { + t.Fatal(err) + } + if text.Valid { + t.Fatalf("text valid = true, want null") + } + if payloadHex != "fffefd" { + t.Fatalf("payload_hex = %q, want fffefd", payloadHex) + } +} + +func TestInsertTextMessageRequiresFields(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + if err := st.InsertTextMessage(map[string]any{"type": "nodeinfo"}, mqttClientInfo{}); err == nil || !strings.Contains(err.Error(), "text_message") { + t.Fatalf("wrong type error = %v, want text_message error", err) + } + if err := st.InsertTextMessage(map[string]any{"type": "text_message", "from_num": 1, "topic": "msh/test"}, mqttClientInfo{}); err == nil || !strings.Contains(err.Error(), "from") { + t.Fatalf("missing from error = %v, want from error", err) + } + if err := st.InsertTextMessage(map[string]any{"type": "text_message", "from": "!00000001", "topic": "msh/test"}, mqttClientInfo{}); err == nil || !strings.Contains(err.Error(), "from_num") { + t.Fatalf("missing from_num error = %v, want from_num error", err) + } + if err := st.InsertTextMessage(map[string]any{"type": "text_message", "from": "!00000001", "from_num": 1}, mqttClientInfo{}); err == nil || !strings.Contains(err.Error(), "topic") { + t.Fatalf("missing topic error = %v, want topic error", err) + } +} + func openTestStore(t *testing.T) *store { t.Helper() st, err := openStore(databaseConfig{ @@ -100,27 +295,58 @@ func openTestStore(t *testing.T) *store { return st } -func TestNodeInfoFromRecordRejectsWrongType(t *testing.T) { - _, err := nodeInfoFromRecord(map[string]any{"type": "text_message"}) - if err == nil { - t.Fatalf("nodeInfoFromRecord() error = nil, want error") +func nodeInfoRecord(longName string) map[string]any { + return map[string]any{ + "type": "nodeinfo", + "from": "!12345678", + "from_num": uint32(0x12345678), + "user_id": "!12345678", + "long_name": longName, + "short_name": "nod", + "hw_model": "TEST_HW", + "role": "CLIENT", + "is_licensed": true, + "public_key": "abcd", } } -func TestNodeInfoNullablePublicKey(t *testing.T) { - st := openTestStore(t) - defer st.Close() - - record := map[string]any{"type": "nodeinfo", "from": "!00000001", "from_num": 1, "public_key": nil} - if err := st.UpsertNodeInfo(record); err != nil { - t.Fatalf("UpsertNodeInfo() error = %v", err) - } - - var publicKey sql.NullString - if err := st.db.QueryRow("SELECT public_key FROM nodeinfo WHERE node_id = ?", "!00000001").Scan(&publicKey); err != nil { - t.Fatal(err) - } - if publicKey.Valid { - t.Fatalf("public_key valid = true, want null") +func mapReportRecord(longName string) map[string]any { + return map[string]any{ + "type": "map_report", + "from": "!12345678", + "from_num": uint32(0x12345678), + "long_name": longName, + "short_name": "map", + "role": "CLIENT_MUTE", + "hw_model": "TEST_HW_2", + "firmware_version": "1.2.3", + "region": "US", + "modem_preset": "LONG_FAST", + "latitude": 42.5, + "longitude": -83.1, + "altitude": int32(200), + "position_precision": uint32(12), + "num_online_local_nodes": uint32(3), + "has_opted_report_location": false, + } +} + +func textMessageTestRecord(text any) map[string]any { + return map[string]any{ + "type": "text_message", + "topic": "msh/US/test", + "channel_id": "LongFast", + "gateway_id": "!gateway", + "from": "!12345678", + "from_num": uint32(0x12345678), + "text": text, + "packet_id": uint32(42), + "packet_to": "!ffffffff", + "packet_to_num": uint32(0xffffffff), + "portnum": "TEXT_MESSAGE_APP", + "payload_len": 5, + "payload_variant": "decoded", + "via_mqtt": true, + "pki_encrypted": false, } } diff --git a/main.go b/main.go index 9546a83..0422e35 100644 --- a/main.go +++ b/main.go @@ -45,15 +45,24 @@ func (h *meshtasticFilterHook) Provides(b byte) bool { } // OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。 -func (h *meshtasticFilterHook) OnPublish(_ *mqtt.Client, pk packets.Packet) (packets.Packet, error) { +func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key) if !valid { return pk, packets.ErrRejectPacket } - if record["type"] == "nodeinfo" && h.store != nil { - if err := h.store.UpsertNodeInfo(record); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) + switch record["type"] { + case "nodeinfo", "map_report": + if h.store != nil { + if err := h.store.UpsertNodeInfoMap(record); err != nil { + printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) + } + } + case "text_message": + if h.store != nil { + if err := h.store.InsertTextMessage(record, mqttClientInfoFromClient(cl)); err != nil { + printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) + } } } if record["type"] != "empty_packet" { @@ -62,6 +71,29 @@ func (h *meshtasticFilterHook) OnPublish(_ *mqtt.Client, pk packets.Packet) (pac return pk, nil } +func mqttClientInfoFromClient(cl *mqtt.Client) mqttClientInfo { + if cl == nil { + return mqttClientInfo{} + } + + info := mqttClientInfo{ + ClientID: cl.ID, + Username: string(cl.Properties.Username), + Listener: cl.Net.Listener, + RemoteAddr: cl.Net.Remote, + } + if info.RemoteAddr == "" && cl.Net.Conn != nil && cl.Net.Conn.RemoteAddr() != nil { + info.RemoteAddr = cl.Net.Conn.RemoteAddr().String() + } + if host, port, err := net.SplitHostPort(info.RemoteAddr); err == nil { + info.RemoteHost = host + info.RemotePort = port + } else { + info.RemoteHost = info.RemoteAddr + } + return info +} + // main 是程序入口,负责解析参数并启动 MQTT broker。 func main() { cfg, err := parseArgs() diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..c603584 --- /dev/null +++ b/main_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "testing" + + mqtt "github.com/mochi-mqtt/server/v2" +) + +func TestMQTTClientInfoFromClientNil(t *testing.T) { + info := mqttClientInfoFromClient(nil) + if info != (mqttClientInfo{}) { + t.Fatalf("info = %#v, want zero value", info) + } +} + +func TestMQTTClientInfoFromClientIPv4(t *testing.T) { + info := mqttClientInfoFromClient(&mqtt.Client{ + ID: "client-1", + Properties: mqtt.ClientProperties{Username: []byte("user-1")}, + Net: mqtt.ClientConnection{Listener: "tcp", Remote: "127.0.0.1:1234"}, + }) + + if info.ClientID != "client-1" || info.Username != "user-1" || info.Listener != "tcp" { + t.Fatalf("client fields = %#v", info) + } + if info.RemoteAddr != "127.0.0.1:1234" || info.RemoteHost != "127.0.0.1" || info.RemotePort != "1234" { + t.Fatalf("remote fields = %#v", info) + } +} + +func TestMQTTClientInfoFromClientIPv6(t *testing.T) { + info := mqttClientInfoFromClient(&mqtt.Client{Net: mqtt.ClientConnection{Remote: "[::1]:1234"}}) + if info.RemoteHost != "::1" || info.RemotePort != "1234" { + t.Fatalf("remote fields = %#v, want host ::1 and port 1234", info) + } +} + +func TestMQTTClientInfoFromClientUnsplitRemote(t *testing.T) { + info := mqttClientInfoFromClient(&mqtt.Client{Net: mqtt.ClientConnection{Remote: "localhost"}}) + if info.RemoteHost != "localhost" || info.RemotePort != "" { + t.Fatalf("remote fields = %#v, want host localhost and empty port", info) + } +}