From f0d6c5a96a24af129065cbd11b9716b7b84f5e60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Wed, 3 Jun 2026 15:33:56 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=BA=93=E6=94=B9GORM=20?= =?UTF-8?q?=E6=93=8D=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- db.go | 535 ++++++++++++++++++++--------------------------------- db_test.go | 33 ++-- go.mod | 11 +- go.sum | 14 ++ 5 files changed, 243 insertions(+), 352 deletions(-) diff --git a/README.md b/README.md index e6250e4..8cc9be6 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ meshtastic: ## 数据库持久化 -程序默认启用 SQLite,并持久化以下数据: +程序默认启用 SQLite,数据库表迁移和操作由 GORM 执行,并持久化以下数据: - `nodeinfo_map`:融合 `type == "nodeinfo"` 和 `type == "map_report"` 的节点信息 - `text_message`:追加保存 `type == "text_message"` 的聊天消息 diff --git a/db.go b/db.go index 56b4dc6..9933cf7 100644 --- a/db.go +++ b/db.go @@ -1,14 +1,16 @@ package main import ( - "database/sql" "encoding/json" + "errors" "fmt" "os" "path/filepath" + "time" - _ "github.com/go-sql-driver/mysql" - _ "modernc.org/sqlite" + "github.com/glebarez/sqlite" + "gorm.io/driver/mysql" + "gorm.io/gorm" ) const ( @@ -17,15 +19,10 @@ const ( ) type store struct { - db *sql.DB + db *gorm.DB driver string } -type migrationQuery struct { - name string - query string -} - type mqttClientInfo struct { ClientID string Username string @@ -36,81 +33,97 @@ type mqttClientInfo struct { } type nodeInfoMapRecord struct { - NodeID string - NodeNum int64 - LatestType string - UserID any - LongName any - ShortName any - HWModel any - Role any - IsLicensed any - PublicKey any - FirmwareVersion any - Region any - ModemPreset any - Latitude any - Longitude any - Altitude any - PositionPrecision any - NumOnlineLocalNodes any - HasOptedReportLocation any - ContentJSON []byte + NodeID string `gorm:"column:node_id;primaryKey;not null"` + NodeNum int64 `gorm:"column:node_num;not null"` + LatestType string `gorm:"column:latest_type;not null"` + UserID *string `gorm:"column:user_id"` + LongName *string `gorm:"column:long_name"` + ShortName *string `gorm:"column:short_name"` + HWModel *string `gorm:"column:hw_model"` + Role *string `gorm:"column:role"` + IsLicensed *bool `gorm:"column:is_licensed"` + PublicKey *string `gorm:"column:public_key"` + FirmwareVersion *string `gorm:"column:firmware_version"` + Region *string `gorm:"column:region"` + ModemPreset *string `gorm:"column:modem_preset"` + Latitude *float64 `gorm:"column:latitude"` + Longitude *float64 `gorm:"column:longitude"` + Altitude *int64 `gorm:"column:altitude"` + PositionPrecision *int64 `gorm:"column:position_precision"` + NumOnlineLocalNodes *int64 `gorm:"column:num_online_local_nodes"` + HasOptedReportLocation *bool `gorm:"column:has_opted_report_location"` + ContentJSON string `gorm:"column:content_json;not null"` + FirstSeenAt time.Time `gorm:"column:first_seen_at;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime"` +} + +func (nodeInfoMapRecord) TableName() string { + return "nodeinfo_map" } type textMessageRecord struct { - FromID string - FromNum int64 - Text any - PayloadHex any - Topic string - ChannelID any - GatewayID any - PacketID any - PacketTo any - PacketToNum any - Portnum any - PayloadLen any - PayloadVariant any - ViaMQTT any - PKIEncrypted any - DecryptSuccess any - DecryptStatus any - MQTTClientID any - MQTTUsername any - MQTTListener any - MQTTRemoteAddr any - MQTTRemoteHost any - MQTTRemotePort any - ContentJSON []byte + ID uint64 `gorm:"column:id;primaryKey;autoIncrement"` + FromID string `gorm:"column:from_id;not null"` + FromNum int64 `gorm:"column:from_num;not null;index:idx_text_message_from_num_created_at,priority:1"` + Text *string `gorm:"column:text"` + PayloadHex *string `gorm:"column:payload_hex"` + Topic string `gorm:"column:topic;not null"` + ChannelID *string `gorm:"column:channel_id"` + GatewayID *string `gorm:"column:gateway_id"` + PacketID *int64 `gorm:"column:packet_id;index:idx_text_message_packet_id"` + PacketTo *string `gorm:"column:packet_to"` + PacketToNum *int64 `gorm:"column:packet_to_num"` + Portnum *string `gorm:"column:portnum"` + PayloadLen *int64 `gorm:"column:payload_len"` + PayloadVariant *string `gorm:"column:payload_variant"` + ViaMQTT *bool `gorm:"column:via_mqtt"` + PKIEncrypted *bool `gorm:"column:pki_encrypted"` + DecryptSuccess *bool `gorm:"column:decrypt_success"` + DecryptStatus *string `gorm:"column:decrypt_status"` + MQTTClientID *string `gorm:"column:mqtt_client_id"` + MQTTUsername *string `gorm:"column:mqtt_username"` + MQTTListener *string `gorm:"column:mqtt_listener"` + MQTTRemoteAddr *string `gorm:"column:mqtt_remote_addr"` + MQTTRemoteHost *string `gorm:"column:mqtt_remote_host"` + MQTTRemotePort *string `gorm:"column:mqtt_remote_port"` + ContentJSON string `gorm:"column:content_json;not null"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime;index:idx_text_message_from_num_created_at,priority:2;index:idx_text_message_created_at"` +} + +func (textMessageRecord) TableName() string { + return "text_message" } func openStore(cfg databaseConfig) (*store, error) { - var dsn string + var dialector gorm.Dialector switch cfg.Driver { case databaseDriverSQLite: if err := os.MkdirAll(filepath.Dir(cfg.SQLite.Path), 0755); err != nil { return nil, fmt.Errorf("create sqlite directory %s: %w", filepath.Dir(cfg.SQLite.Path), err) } - dsn = cfg.SQLite.Path + dialector = sqlite.Open(cfg.SQLite.Path) case databaseDriverMySQL: - dsn = cfg.MySQL.DSN + dialector = mysql.Open(cfg.MySQL.DSN) default: return nil, fmt.Errorf("unsupported database driver %q", cfg.Driver) } - db, err := sql.Open(cfg.Driver, dsn) + db, err := gorm.Open(dialector, &gorm.Config{}) if err != nil { return nil, fmt.Errorf("open %s database: %w", cfg.Driver, err) } - if err := db.Ping(); err != nil { - db.Close() + sqlDB, err := db.DB() + if err != nil { + return nil, fmt.Errorf("get %s database handle: %w", cfg.Driver, err) + } + if err := sqlDB.Ping(); err != nil { + sqlDB.Close() return nil, fmt.Errorf("ping %s database: %w", cfg.Driver, err) } s := &store{db: db, driver: cfg.Driver} if err := s.migrate(); err != nil { - db.Close() + sqlDB.Close() return nil, err } return s, nil @@ -120,143 +133,39 @@ func (s *store) Close() error { if s == nil || s.db == nil { return nil } - return s.db.Close() -} - -func (s *store) migrate() error { - queries, err := s.migrationQueries() + sqlDB, err := s.db.DB() if err != nil { return err } - for _, q := range queries { - if _, err := s.db.Exec(q.query); err != nil { - return fmt.Errorf("migrate %s: %w", q.name, err) - } - } - return nil + return sqlDB.Close() } -func (s *store) migrationQueries() ([]migrationQuery, error) { - switch s.driver { - case databaseDriverSQLite: - return []migrationQuery{ - {name: "nodeinfo_map table", query: `CREATE TABLE IF NOT EXISTS nodeinfo_map ( - node_id TEXT PRIMARY KEY, - node_num INTEGER NOT NULL, - latest_type TEXT NOT NULL, - user_id TEXT, - long_name TEXT, - short_name TEXT, - hw_model TEXT, - role TEXT, - is_licensed BOOLEAN, - public_key TEXT, - firmware_version TEXT, - region TEXT, - modem_preset TEXT, - latitude REAL, - longitude REAL, - altitude INTEGER, - position_precision INTEGER, - num_online_local_nodes INTEGER, - has_opted_report_location BOOLEAN, - content_json TEXT NOT NULL, - first_seen_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -);`}, - {name: "text_message table", query: `CREATE TABLE IF NOT EXISTS text_message ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - from_id TEXT NOT NULL, - from_num INTEGER NOT NULL, - text TEXT, - payload_hex TEXT, - topic TEXT NOT NULL, - channel_id TEXT, - gateway_id TEXT, - packet_id INTEGER, - packet_to TEXT, - packet_to_num INTEGER, - portnum TEXT, - payload_len INTEGER, - payload_variant TEXT, - via_mqtt BOOLEAN, - pki_encrypted BOOLEAN, - decrypt_success BOOLEAN, - decrypt_status TEXT, - mqtt_client_id TEXT, - mqtt_username TEXT, - mqtt_listener TEXT, - mqtt_remote_addr TEXT, - mqtt_remote_host TEXT, - mqtt_remote_port TEXT, - content_json TEXT NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP -);`}, - {name: "text_message from_num index", query: `CREATE INDEX IF NOT EXISTS idx_text_message_from_num_created_at ON text_message (from_num, created_at);`}, - {name: "text_message created_at index", query: `CREATE INDEX IF NOT EXISTS idx_text_message_created_at ON text_message (created_at);`}, - {name: "text_message packet_id index", query: `CREATE INDEX IF NOT EXISTS idx_text_message_packet_id ON text_message (packet_id);`}, - }, nil - case databaseDriverMySQL: - return []migrationQuery{ - {name: "nodeinfo_map table", query: `CREATE TABLE IF NOT EXISTS nodeinfo_map ( - node_id VARCHAR(32) NOT NULL PRIMARY KEY, - node_num BIGINT UNSIGNED NOT NULL, - latest_type VARCHAR(32) NOT NULL, - user_id VARCHAR(128), - long_name TEXT, - short_name VARCHAR(64), - hw_model VARCHAR(128), - role VARCHAR(128), - is_licensed BOOLEAN, - public_key TEXT, - firmware_version VARCHAR(128), - region VARCHAR(128), - modem_preset VARCHAR(128), - latitude DOUBLE, - longitude DOUBLE, - altitude INT, - position_precision INT UNSIGNED, - num_online_local_nodes INT UNSIGNED, - has_opted_report_location BOOLEAN, - content_json JSON NOT NULL, - first_seen_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP -);`}, - {name: "text_message table", query: `CREATE TABLE IF NOT EXISTS text_message ( - id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, - from_id VARCHAR(32) NOT NULL, - from_num BIGINT UNSIGNED NOT NULL, - text TEXT, - payload_hex TEXT, - topic TEXT NOT NULL, - channel_id VARCHAR(128), - gateway_id VARCHAR(128), - packet_id BIGINT UNSIGNED, - packet_to VARCHAR(32), - packet_to_num BIGINT UNSIGNED, - portnum VARCHAR(64), - payload_len INT UNSIGNED, - payload_variant VARCHAR(32), - via_mqtt BOOLEAN, - pki_encrypted BOOLEAN, - decrypt_success BOOLEAN, - decrypt_status VARCHAR(255), - mqtt_client_id VARCHAR(255), - mqtt_username VARCHAR(255), - mqtt_listener VARCHAR(128), - mqtt_remote_addr VARCHAR(255), - mqtt_remote_host VARCHAR(255), - mqtt_remote_port VARCHAR(16), - content_json JSON NOT NULL, - created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - INDEX idx_text_message_from_num_created_at (from_num, created_at), - INDEX idx_text_message_created_at (created_at), - INDEX idx_text_message_packet_id (packet_id) -);`}, - }, nil - default: - return nil, fmt.Errorf("unsupported database driver %q", s.driver) - } +func (s *store) migrate() error { + return s.db.Transaction(func(tx *gorm.DB) error { + migrator := tx.Migrator() + if !migrator.HasTable(&nodeInfoMapRecord{}) { + if err := migrator.CreateTable(&nodeInfoMapRecord{}); err != nil { + return fmt.Errorf("migrate nodeinfo_map table: %w", err) + } + } + if !migrator.HasTable(&textMessageRecord{}) { + if err := migrator.CreateTable(&textMessageRecord{}); err != nil { + return fmt.Errorf("migrate text_message table: %w", err) + } + } + for _, indexName := range []string{ + "idx_text_message_from_num_created_at", + "idx_text_message_created_at", + "idx_text_message_packet_id", + } { + if !migrator.HasIndex(&textMessageRecord{}, indexName) { + if err := migrator.CreateIndex(&textMessageRecord{}, indexName); err != nil { + return fmt.Errorf("migrate text_message index %s: %w", indexName, err) + } + } + } + return nil + }) } func (s *store) UpsertNodeInfoMap(record map[string]any) error { @@ -264,140 +173,66 @@ func (s *store) UpsertNodeInfoMap(record map[string]any) error { if err != nil { return err } - - var query string - switch s.driver { - case databaseDriverSQLite: - query = `INSERT INTO nodeinfo_map ( - node_id, node_num, latest_type, user_id, long_name, short_name, - hw_model, role, is_licensed, public_key, firmware_version, - region, modem_preset, latitude, longitude, altitude, - position_precision, num_online_local_nodes, has_opted_report_location, - content_json, first_seen_at, updated_at -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) -ON CONFLICT(node_id) DO UPDATE SET - node_num = excluded.node_num, - latest_type = excluded.latest_type, - user_id = COALESCE(excluded.user_id, nodeinfo_map.user_id), - long_name = COALESCE(excluded.long_name, nodeinfo_map.long_name), - short_name = COALESCE(excluded.short_name, nodeinfo_map.short_name), - hw_model = COALESCE(excluded.hw_model, nodeinfo_map.hw_model), - role = COALESCE(excluded.role, nodeinfo_map.role), - is_licensed = COALESCE(excluded.is_licensed, nodeinfo_map.is_licensed), - public_key = COALESCE(excluded.public_key, nodeinfo_map.public_key), - firmware_version = COALESCE(excluded.firmware_version, nodeinfo_map.firmware_version), - region = COALESCE(excluded.region, nodeinfo_map.region), - modem_preset = COALESCE(excluded.modem_preset, nodeinfo_map.modem_preset), - latitude = COALESCE(excluded.latitude, nodeinfo_map.latitude), - longitude = COALESCE(excluded.longitude, nodeinfo_map.longitude), - altitude = COALESCE(excluded.altitude, nodeinfo_map.altitude), - position_precision = COALESCE(excluded.position_precision, nodeinfo_map.position_precision), - num_online_local_nodes = COALESCE(excluded.num_online_local_nodes, nodeinfo_map.num_online_local_nodes), - has_opted_report_location = COALESCE(excluded.has_opted_report_location, nodeinfo_map.has_opted_report_location), - content_json = excluded.content_json, - updated_at = CURRENT_TIMESTAMP;` - case databaseDriverMySQL: - query = `INSERT INTO nodeinfo_map ( - node_id, node_num, latest_type, user_id, long_name, short_name, - hw_model, role, is_licensed, public_key, firmware_version, - region, modem_preset, latitude, longitude, altitude, - position_precision, num_online_local_nodes, has_opted_report_location, - content_json, first_seen_at, updated_at -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) -ON DUPLICATE KEY UPDATE - node_num = VALUES(node_num), - latest_type = VALUES(latest_type), - user_id = COALESCE(VALUES(user_id), user_id), - long_name = COALESCE(VALUES(long_name), long_name), - short_name = COALESCE(VALUES(short_name), short_name), - hw_model = COALESCE(VALUES(hw_model), hw_model), - role = COALESCE(VALUES(role), role), - is_licensed = COALESCE(VALUES(is_licensed), is_licensed), - public_key = COALESCE(VALUES(public_key), public_key), - firmware_version = COALESCE(VALUES(firmware_version), firmware_version), - region = COALESCE(VALUES(region), region), - modem_preset = COALESCE(VALUES(modem_preset), modem_preset), - latitude = COALESCE(VALUES(latitude), latitude), - longitude = COALESCE(VALUES(longitude), longitude), - altitude = COALESCE(VALUES(altitude), altitude), - position_precision = COALESCE(VALUES(position_precision), position_precision), - num_online_local_nodes = COALESCE(VALUES(num_online_local_nodes), num_online_local_nodes), - has_opted_report_location = COALESCE(VALUES(has_opted_report_location), has_opted_report_location), - content_json = VALUES(content_json), - updated_at = CURRENT_TIMESTAMP;` - default: - return fmt.Errorf("unsupported database driver %q", s.driver) - } - - _, err = s.db.Exec(query, - node.NodeID, - node.NodeNum, - node.LatestType, - node.UserID, - node.LongName, - node.ShortName, - node.HWModel, - node.Role, - node.IsLicensed, - node.PublicKey, - node.FirmwareVersion, - node.Region, - node.ModemPreset, - node.Latitude, - node.Longitude, - node.Altitude, - node.PositionPrecision, - node.NumOnlineLocalNodes, - node.HasOptedReportLocation, - string(node.ContentJSON), - ) - if err != nil { + if err := s.upsertNodeInfoMapRecord(node); err != nil { return fmt.Errorf("upsert nodeinfo_map %s: %w", node.NodeID, err) } return nil } +func (s *store) upsertNodeInfoMapRecord(node *nodeInfoMapRecord) error { + return s.db.Transaction(func(tx *gorm.DB) error { + var existing nodeInfoMapRecord + err := tx.Where("node_id = ?", node.NodeID).Take(&existing).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + if err := tx.Create(node).Error; err != nil { + return s.updateNodeInfoMapRecord(tx, node) + } + return nil + } + if err != nil { + return err + } + return s.updateNodeInfoMapRecord(tx, node) + }) +} + +func (s *store) updateNodeInfoMapRecord(tx *gorm.DB, node *nodeInfoMapRecord) error { + updates := nodeInfoMapUpdates(node) + return tx.Model(&nodeInfoMapRecord{}).Where("node_id = ?", node.NodeID).Updates(updates).Error +} + +func nodeInfoMapUpdates(node *nodeInfoMapRecord) map[string]any { + updates := map[string]any{ + "node_num": node.NodeNum, + "latest_type": node.LatestType, + "content_json": node.ContentJSON, + "updated_at": time.Now(), + } + addStringUpdate(updates, "user_id", node.UserID) + addStringUpdate(updates, "long_name", node.LongName) + addStringUpdate(updates, "short_name", node.ShortName) + addStringUpdate(updates, "hw_model", node.HWModel) + addStringUpdate(updates, "role", node.Role) + addBoolUpdate(updates, "is_licensed", node.IsLicensed) + addStringUpdate(updates, "public_key", node.PublicKey) + addStringUpdate(updates, "firmware_version", node.FirmwareVersion) + addStringUpdate(updates, "region", node.Region) + addStringUpdate(updates, "modem_preset", node.ModemPreset) + addFloat64Update(updates, "latitude", node.Latitude) + addFloat64Update(updates, "longitude", node.Longitude) + addInt64Update(updates, "altitude", node.Altitude) + addInt64Update(updates, "position_precision", node.PositionPrecision) + addInt64Update(updates, "num_online_local_nodes", node.NumOnlineLocalNodes) + addBoolUpdate(updates, "has_opted_report_location", node.HasOptedReportLocation) + return updates +} + func (s *store) InsertTextMessage(record map[string]any, clientInfo mqttClientInfo) error { message, err := textMessageFromRecord(record, clientInfo) if err != nil { return err } - - query := `INSERT INTO text_message ( - from_id, from_num, text, payload_hex, topic, channel_id, gateway_id, - packet_id, packet_to, packet_to_num, portnum, payload_len, - payload_variant, via_mqtt, pki_encrypted, decrypt_success, decrypt_status, - mqtt_client_id, mqtt_username, mqtt_listener, mqtt_remote_addr, - mqtt_remote_host, mqtt_remote_port, content_json -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);` - - _, err = s.db.Exec(query, - message.FromID, - message.FromNum, - message.Text, - message.PayloadHex, - message.Topic, - message.ChannelID, - message.GatewayID, - message.PacketID, - message.PacketTo, - message.PacketToNum, - message.Portnum, - message.PayloadLen, - message.PayloadVariant, - message.ViaMQTT, - message.PKIEncrypted, - message.DecryptSuccess, - message.DecryptStatus, - message.MQTTClientID, - message.MQTTUsername, - message.MQTTListener, - message.MQTTRemoteAddr, - message.MQTTRemoteHost, - message.MQTTRemotePort, - string(message.ContentJSON), - ) - if err != nil { + if err := s.db.Create(message).Error; err != nil { return fmt.Errorf("insert text_message from %s: %w", message.FromID, err) } return nil @@ -441,7 +276,7 @@ func nodeInfoMapFromRecord(record map[string]any) (*nodeInfoMapRecord, error) { PositionPrecision: nullableInt64(record["position_precision"]), NumOnlineLocalNodes: nullableInt64(record["num_online_local_nodes"]), HasOptedReportLocation: nullableBool(record["has_opted_report_location"]), - ContentJSON: contentJSON, + ContentJSON: string(contentJSON), }, nil } @@ -491,7 +326,7 @@ func textMessageFromRecord(record map[string]any, clientInfo mqttClientInfo) (*t MQTTRemoteAddr: nullableString(clientInfo.RemoteAddr), MQTTRemoteHost: nullableString(clientInfo.RemoteHost), MQTTRemotePort: nullableString(clientInfo.RemotePort), - ContentJSON: contentJSON, + ContentJSON: string(contentJSON), }, nil } @@ -524,7 +359,7 @@ func int64FromAny(value any) (int64, error) { } } -func nullableString(value any) any { +func nullableString(value any) *string { if value == nil { return nil } @@ -532,18 +367,18 @@ func nullableString(value any) any { if !ok || s == "" { return nil } - return s + return &s } -func nullableBool(value any) any { +func nullableBool(value any) *bool { b, ok := value.(bool) if !ok { return nil } - return b + return &b } -func nullableInt64(value any) any { +func nullableInt64(value any) *int64 { if value == nil { return nil } @@ -551,36 +386,62 @@ func nullableInt64(value any) any { if err != nil { return nil } - return v + return &v } -func nullableFloat64(value any) any { +func nullableFloat64(value any) *float64 { + var out float64 switch v := value.(type) { case float32: - return float64(v) + out = float64(v) case float64: - return v + out = v case int: - return float64(v) + out = float64(v) case int8: - return float64(v) + out = float64(v) case int16: - return float64(v) + out = float64(v) case int32: - return float64(v) + out = float64(v) case int64: - return float64(v) + out = float64(v) case uint: - return float64(v) + out = float64(v) case uint8: - return float64(v) + out = float64(v) case uint16: - return float64(v) + out = float64(v) case uint32: - return float64(v) + out = float64(v) case uint64: - return float64(v) + out = float64(v) default: return nil } + return &out +} + +func addStringUpdate(updates map[string]any, column string, value *string) { + if value != nil { + updates[column] = *value + } +} + +func addBoolUpdate(updates map[string]any, column string, value *bool) { + if value != nil { + updates[column] = *value + } +} + +func addInt64Update(updates map[string]any, column string, value *int64) { + if value != nil { + updates[column] = *value + } +} + +func addFloat64Update(updates map[string]any, column string, value *float64) { + if value != nil { + updates[column] = *value + } } diff --git a/db_test.go b/db_test.go index 8e7ef4d..b9f5030 100644 --- a/db_test.go +++ b/db_test.go @@ -13,7 +13,7 @@ func TestOpenStoreCreatesTables(t *testing.T) { for _, table := range []string{"nodeinfo_map", "text_message"} { var name string - if err := st.db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name); err != nil { t.Fatalf("%s table missing: %v", table, err) } if name != table { @@ -22,7 +22,7 @@ func TestOpenStoreCreatesTables(t *testing.T) { } var oldCount int - if err := st.db.QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = 'nodeinfo'").Scan(&oldCount); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = 'nodeinfo'").Scan(&oldCount); err != nil { t.Fatal(err) } if oldCount != 0 { @@ -46,7 +46,7 @@ func TestUpsertNodeInfoMapInsertsAndUpdatesSameNode(t *testing.T) { } var count int - if err := st.db.QueryRow("SELECT COUNT(*) FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&count); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT COUNT(*) FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&count); err != nil { t.Fatal(err) } if count != 1 { @@ -54,7 +54,7 @@ func TestUpsertNodeInfoMapInsertsAndUpdatesSameNode(t *testing.T) { } var latestType, longName, content string - if err := st.db.QueryRow("SELECT latest_type, long_name, content_json FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &longName, &content); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT latest_type, long_name, content_json FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &longName, &content); err != nil { t.Fatal(err) } if latestType != "nodeinfo" { @@ -80,7 +80,7 @@ func TestUpsertNodeInfoMapMergesNodeInfoThenMapReport(t *testing.T) { } var count int - if err := st.db.QueryRow("SELECT COUNT(*) FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&count); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT COUNT(*) FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&count); err != nil { t.Fatal(err) } if count != 1 { @@ -90,7 +90,7 @@ func TestUpsertNodeInfoMapMergesNodeInfoThenMapReport(t *testing.T) { var latestType, userID, publicKey, longName, firmware, content string var latitude float64 var opted sql.NullBool - if err := st.db.QueryRow("SELECT latest_type, user_id, public_key, long_name, firmware_version, latitude, has_opted_report_location, content_json FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &userID, &publicKey, &longName, &firmware, &latitude, &opted, &content); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT latest_type, user_id, public_key, long_name, firmware_version, latitude, has_opted_report_location, content_json FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &userID, &publicKey, &longName, &firmware, &latitude, &opted, &content); err != nil { t.Fatal(err) } if latestType != "map_report" { @@ -129,7 +129,7 @@ func TestUpsertNodeInfoMapMergesMapReportThenNodeInfo(t *testing.T) { var latestType, userID, longName, firmware string var latitude float64 - if err := st.db.QueryRow("SELECT latest_type, user_id, long_name, firmware_version, latitude FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &userID, &longName, &firmware, &latitude); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT latest_type, user_id, long_name, firmware_version, latitude FROM nodeinfo_map WHERE node_id = ?", "!12345678").Scan(&latestType, &userID, &longName, &firmware, &latitude); err != nil { t.Fatal(err) } if latestType != "nodeinfo" { @@ -175,7 +175,7 @@ func TestNodeInfoMapNullablePublicKey(t *testing.T) { } var publicKey sql.NullString - if err := st.db.QueryRow("SELECT public_key FROM nodeinfo_map WHERE node_id = ?", "!00000001").Scan(&publicKey); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT public_key FROM nodeinfo_map WHERE node_id = ?", "!00000001").Scan(&publicKey); err != nil { t.Fatal(err) } if publicKey.Valid { @@ -196,14 +196,14 @@ func TestInsertTextMessageAppendsRows(t *testing.T) { } var count int - if err := st.db.QueryRow("SELECT COUNT(*) FROM text_message WHERE from_id = ?", "!12345678").Scan(&count); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT COUNT(*) FROM text_message WHERE from_id = ?", "!12345678").Scan(&count); err != nil { t.Fatal(err) } if count != 2 { t.Fatalf("text_message count = %d, want 2", count) } - rows, err := st.db.Query("SELECT id FROM text_message ORDER BY id") + rows, err := rawTestDB(t, st).Query("SELECT id FROM text_message ORDER BY id") if err != nil { t.Fatal(err) } @@ -234,7 +234,7 @@ func TestInsertTextMessageStoresClientInfo(t *testing.T) { } var clientID, username, listener, remoteAddr, remoteHost, remotePort string - if err := st.db.QueryRow("SELECT mqtt_client_id, mqtt_username, mqtt_listener, mqtt_remote_addr, mqtt_remote_host, mqtt_remote_port FROM text_message LIMIT 1").Scan(&clientID, &username, &listener, &remoteAddr, &remoteHost, &remotePort); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT mqtt_client_id, mqtt_username, mqtt_listener, mqtt_remote_addr, mqtt_remote_host, mqtt_remote_port FROM text_message LIMIT 1").Scan(&clientID, &username, &listener, &remoteAddr, &remoteHost, &remotePort); err != nil { t.Fatal(err) } if clientID != "client-1" || username != "user-1" || listener != "tcp" || remoteAddr != "127.0.0.1:54321" || remoteHost != "127.0.0.1" || remotePort != "54321" { @@ -254,7 +254,7 @@ func TestInsertTextMessageStoresPayloadHex(t *testing.T) { var text sql.NullString var payloadHex string - if err := st.db.QueryRow("SELECT text, payload_hex FROM text_message LIMIT 1").Scan(&text, &payloadHex); err != nil { + if err := rawTestDB(t, st).QueryRow("SELECT text, payload_hex FROM text_message LIMIT 1").Scan(&text, &payloadHex); err != nil { t.Fatal(err) } if text.Valid { @@ -295,6 +295,15 @@ func openTestStore(t *testing.T) *store { return st } +func rawTestDB(t *testing.T, st *store) *sql.DB { + t.Helper() + db, err := st.db.DB() + if err != nil { + t.Fatalf("st.db.DB() error = %v", err) + } + return db +} + func nodeInfoRecord(longName string) map[string]any { return map[string]any{ "type": "nodeinfo", diff --git a/go.mod b/go.mod index fc60410..b645c61 100644 --- a/go.mod +++ b/go.mod @@ -3,24 +3,31 @@ module meshtastic_mqtt_server go 1.25.0 require ( - github.com/go-sql-driver/mysql v1.10.0 + github.com/glebarez/sqlite v1.11.0 github.com/mochi-mqtt/server/v2 v2.7.9 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 - modernc.org/sqlite v1.51.0 + gorm.io/driver/mysql v1.6.0 + gorm.io/gorm v1.31.1 ) require ( filippo.io/edwards25519 v1.2.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/glebarez/go-sqlite v1.21.2 // indirect + github.com/go-sql-driver/mysql v1.10.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rs/xid v1.4.0 // indirect golang.org/x/sys v0.42.0 // indirect + golang.org/x/text v0.21.0 // indirect modernc.org/libc v1.72.3 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.51.0 // indirect ) diff --git a/go.sum b/go.sum index 96ea3b3..5f9e36e 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= +github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= +github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= +github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw= github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= @@ -16,6 +20,10 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jinzhu/copier v0.3.5 h1:GlvfUwHk62RokgqVNvYsku0TATCF7bAHVwEXoBh3iJg= github.com/jinzhu/copier v0.3.5/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mochi-mqtt/server/v2 v2.7.9 h1:y0g4vrSLAag7T07l2oCzOa/+nKVLoazKEWAArwqBNYI= @@ -37,6 +45,8 @@ golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= @@ -45,6 +55,10 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY= modernc.org/cc/v4 v4.28.2/go.mod h1:OnovgIhbbMXMu1aISnJ0wvVD1KnW+cAUJkIrAWh+kVI= modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=