diff --git a/db_write_queue.go b/db_write_queue.go new file mode 100644 index 0000000..5fc76c1 --- /dev/null +++ b/db_write_queue.go @@ -0,0 +1,123 @@ +package main + +import "sync" + +type dbWriteQueue struct { + store *store + jobs chan dbWriteJob + wg sync.WaitGroup +} + +type dbWriteJob struct { + typeName string + from any + run func() error + errorEvent map[string]any +} + +func newDBWriteQueue(store *store) *dbWriteQueue { + if store == nil { + return nil + } + q := &dbWriteQueue{ + store: store, + jobs: make(chan dbWriteJob, 1024), + } + q.wg.Add(1) + go q.run() + return q +} + +func (q *dbWriteQueue) EnqueueRecord(record map[string]any, clientInfo mqttClientInfo) { + if q == nil { + return + } + record = cloneDBWriteRecord(record) + switch record["type"] { + case "nodeinfo": + q.enqueue(dbWriteJob{typeName: "nodeinfo", from: record["from"], run: func() error { + return q.store.UpsertNodeInfo(record) + }}) + case "map_report": + q.enqueue(dbWriteJob{typeName: "map_report", from: record["from"], run: func() error { + return q.store.UpsertMapReport(record) + }}) + case "text_message": + q.enqueue(dbWriteJob{typeName: "text_message", from: record["from"], run: func() error { + return q.store.InsertTextMessage(record, clientInfo) + }}) + case "position": + q.enqueue(dbWriteJob{typeName: "position", from: record["from"], run: func() error { + return q.store.InsertPosition(record, clientInfo) + }}) + case "telemetry": + q.enqueue(dbWriteJob{typeName: "telemetry", from: record["from"], run: func() error { + return q.store.InsertTelemetry(record, clientInfo) + }}) + case "routing": + q.enqueue(dbWriteJob{typeName: "routing", from: record["from"], run: func() error { + return q.store.InsertRouting(record, clientInfo) + }}) + case "traceroute": + q.enqueue(dbWriteJob{typeName: "traceroute", from: record["from"], run: func() error { + return q.store.InsertTraceroute(record, clientInfo) + }}) + } +} + +func (q *dbWriteQueue) EnqueueDiscard(record map[string]any, raw []byte, clientInfo mqttClientInfo) { + if q == nil { + return + } + record = cloneDBWriteRecord(record) + raw = append([]byte(nil), raw...) + q.enqueue(dbWriteJob{typeName: "discard_details", from: record["from"], errorEvent: map[string]any{"event": "db_error", "type": "discard_details", "topic": record["topic"]}, run: func() error { + return q.store.InsertDiscardDetails(record, raw, clientInfo) + }}) +} + +func (q *dbWriteQueue) Close() { + if q == nil { + return + } + close(q.jobs) + q.wg.Wait() +} + +func (q *dbWriteQueue) Len() int { + if q == nil { + return 0 + } + return len(q.jobs) +} + +func (q *dbWriteQueue) enqueue(job dbWriteJob) { + q.jobs <- job +} + +func (q *dbWriteQueue) run() { + defer q.wg.Done() + for job := range q.jobs { + if err := job.run(); err != nil { + event := job.errorEvent + if event == nil { + event = map[string]any{"event": "db_error", "type": job.typeName, "from": job.from} + } else { + event = cloneDBWriteRecord(event) + } + event["error"] = err.Error() + printJSON(event) + } + } +} + +func cloneDBWriteRecord(record map[string]any) map[string]any { + if record == nil { + return nil + } + cloned := make(map[string]any, len(record)) + for key, value := range record { + cloned[key] = value + } + return cloned +} diff --git a/db_write_queue_test.go b/db_write_queue_test.go new file mode 100644 index 0000000..a553b57 --- /dev/null +++ b/db_write_queue_test.go @@ -0,0 +1,104 @@ +package main + +import ( + "database/sql" + "testing" +) + +func TestDBWriteQueueWritesRecordsAsync(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + queue := newDBWriteQueue(st) + record := textMessageTestRecord("queued") + queue.EnqueueRecord(record, mqttClientInfo{ClientID: "client-1"}) + record["text"] = "mutated after enqueue" + queue.Close() + + var text, clientID string + if err := rawTestDB(t, st).QueryRow("SELECT text, mqtt_client_id FROM text_message WHERE from_id = ?", "!12345678").Scan(&text, &clientID); err != nil { + t.Fatal(err) + } + if text != "queued" || clientID != "client-1" { + t.Fatalf("queued row = text %q client %q, want queued/client-1", text, clientID) + } +} + +func TestDBWriteQueueWritesDiscardAsync(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + queue := newDBWriteQueue(st) + record := map[string]any{"topic": "msh/test", "error": "bad packet"} + queue.EnqueueDiscard(record, []byte{1, 2, 3}, mqttClientInfo{RemoteAddr: "127.0.0.1:1883"}) + record["error"] = "mutated after enqueue" + queue.Close() + + var topic, reason, rawBase64, remoteAddr string + if err := rawTestDB(t, st).QueryRow("SELECT topic, error, raw_base64, mqtt_remote_addr FROM discard_details").Scan(&topic, &reason, &rawBase64, &remoteAddr); err != nil { + t.Fatal(err) + } + if topic != "msh/test" || reason != "bad packet" || rawBase64 != "AQID" || remoteAddr != "127.0.0.1:1883" { + t.Fatalf("discard row = %q/%q/%q/%q, want queued values", topic, reason, rawBase64, remoteAddr) + } +} + +func TestDBWriteQueueLen(t *testing.T) { + queue := &dbWriteQueue{jobs: make(chan dbWriteJob, 1)} + queue.enqueue(dbWriteJob{run: func() error { return nil }}) + if queue.Len() != 1 { + t.Fatalf("queue.Len() = %d, want 1", queue.Len()) + } +} + +func TestDBWriteQueueIgnoresUnsupportedRecordType(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + queue := newDBWriteQueue(st) + queue.EnqueueRecord(map[string]any{"type": "empty_packet", "from": "!12345678"}, mqttClientInfo{}) + queue.Close() + + var count int + if err := rawTestDB(t, st).QueryRow("SELECT COUNT(*) FROM text_message").Scan(&count); err != nil { + t.Fatal(err) + } + if count != 0 { + t.Fatalf("text_message count = %d, want 0", count) + } +} + +func TestDBWriteQueueNilStore(t *testing.T) { + if queue := newDBWriteQueue(nil); queue != nil { + t.Fatalf("newDBWriteQueue(nil) = %#v, want nil", queue) + } + var queue *dbWriteQueue + queue.EnqueueRecord(textMessageTestRecord("ignored"), mqttClientInfo{}) + queue.EnqueueDiscard(map[string]any{"topic": "ignored"}, []byte{1}, mqttClientInfo{}) + queue.Close() +} + +func TestDBWriteQueueRecordValidationErrorDoesNotStopWorker(t *testing.T) { + st := openTestStore(t) + defer st.Close() + + queue := newDBWriteQueue(st) + badRecord := textMessageTestRecord("bad") + delete(badRecord, "from") + queue.EnqueueRecord(badRecord, mqttClientInfo{}) + queue.EnqueueRecord(textMessageTestRecord("good"), mqttClientInfo{}) + queue.Close() + + var text string + if err := rawTestDB(t, st).QueryRow("SELECT text FROM text_message").Scan(&text); err != nil { + t.Fatal(err) + } + if text != "good" { + t.Fatalf("text = %q, want good", text) + } + + var missing sql.NullString + if err := rawTestDB(t, st).QueryRow("SELECT text FROM text_message WHERE text = ?", "bad").Scan(&missing); err != sql.ErrNoRows { + t.Fatalf("bad row error = %v, want sql.ErrNoRows", err) + } +} diff --git a/main.go b/main.go index c34d45e..a236b60 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,7 @@ const ( type meshtasticFilterHook struct { mqtt.HookBase key []byte - store *store + dbQueue *dbWriteQueue stats *meshtasticMessageStats blocking *blockingCache } @@ -77,50 +77,7 @@ func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (pa } h.stats.IncForwarded() - switch record["type"] { - case "nodeinfo": - if h.store != nil { - if err := h.store.UpsertNodeInfo(record); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - case "map_report": - if h.store != nil { - if err := h.store.UpsertMapReport(record); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - case "text_message": - if h.store != nil { - if err := h.store.InsertTextMessage(record, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - case "position": - if h.store != nil { - if err := h.store.InsertPosition(record, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - case "telemetry": - if h.store != nil { - if err := h.store.InsertTelemetry(record, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - case "routing": - if h.store != nil { - if err := h.store.InsertRouting(record, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - case "traceroute": - if h.store != nil { - if err := h.store.InsertTraceroute(record, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": record["type"], "from": record["from"], "error": err.Error()}) - } - } - } + h.dbQueue.EnqueueRecord(record, mqttClientInfoFromClient(cl)) if record["type"] != "empty_packet" { printJSON(record) } @@ -131,11 +88,11 @@ func (h *meshtasticFilterHook) rejectPublish(cl *mqtt.Client, pk packets.Packet, if h.stats != nil { h.stats.IncDropped() } - if h.store != nil { - if err := h.store.InsertDiscardDetails(record, pk.Payload, mqttClientInfoFromClient(cl)); err != nil { - printJSON(map[string]any{"event": "db_error", "type": "discard_details", "topic": pk.TopicName, "error": err.Error()}) - } + if record == nil { + record = map[string]any{} } + record["topic"] = pk.TopicName + h.dbQueue.EnqueueDiscard(record, pk.Payload, mqttClientInfoFromClient(cl)) } func blockingViolationForRecord(blocking *blockingCache, record map[string]any) map[string]any { @@ -246,6 +203,8 @@ func run(cfg *config) error { return err } defer store.Close() + dbQueue := newDBWriteQueue(store) + defer dbQueue.Close() if err := store.EnsureDefaultAdmin(cfg.Web.Admin.Username, cfg.Web.Admin.Password); err != nil { return err } @@ -256,7 +215,7 @@ func run(cfg *config) error { } messageStats := &meshtasticMessageStats{} - server, mqttAddr, err := startMQTTServer(cfg, store, messageStats, blocking) + server, mqttAddr, err := startMQTTServer(cfg, dbQueue, messageStats, blocking) if err != nil { return err } @@ -268,7 +227,7 @@ func run(cfg *config) error { if err != nil { return err } - mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats} + mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats, dbQueue: dbQueue} httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking) webAddress := httpServer.Addr go func() { @@ -310,12 +269,12 @@ func run(cfg *config) error { return runErr } -func startMQTTServer(cfg *config, store *store, stats *meshtasticMessageStats, blocking *blockingCache) (*mqtt.Server, string, error) { +func startMQTTServer(cfg *config, dbQueue *dbWriteQueue, stats *meshtasticMessageStats, blocking *blockingCache) (*mqtt.Server, string, error) { server := mqtt.New(nil) if err := server.AddHook(new(auth.AllowHook), nil); err != nil { return nil, "", err } - if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, store: store, stats: stats, blocking: blocking}, nil); err != nil { + if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, dbQueue: dbQueue, stats: stats, blocking: blocking}, nil); err != nil { return nil, "", err } diff --git a/meshmap_frontend/src/components/AdminDashboard.vue b/meshmap_frontend/src/components/AdminDashboard.vue index e4804cb..46c3f12 100644 --- a/meshmap_frontend/src/components/AdminDashboard.vue +++ b/meshmap_frontend/src/components/AdminDashboard.vue @@ -62,6 +62,7 @@ onBeforeUnmount(() => {
当前连接{{ status.clients_connected }}
订阅数{{ status.subscriptions }}
转发消息{{ status.messages_sent }}
+
数据库队列{{ status.db_write_queue_length }}
丢弃消息{{ status.messages_dropped }}
收到包{{ status.packets_received }}
发送包{{ status.packets_sent }}
diff --git a/meshmap_frontend/src/types.ts b/meshmap_frontend/src/types.ts index 335864b..dc8605b 100644 --- a/meshmap_frontend/src/types.ts +++ b/meshmap_frontend/src/types.ts @@ -220,6 +220,7 @@ export interface AdminMqttStatus { messages_received: number messages_sent: number messages_dropped: number + db_write_queue_length: number retained: number inflight: number inflight_dropped: number diff --git a/mqtt_status.go b/mqtt_status.go index 642a399..4b3a53c 100644 --- a/mqtt_status.go +++ b/mqtt_status.go @@ -13,6 +13,7 @@ type mqttRuntimeStatus struct { address string tls bool stats *meshtasticMessageStats + dbQueue *dbWriteQueue } type adminMqttStatus struct { @@ -31,6 +32,7 @@ type adminMqttStatus struct { MessagesReceived int64 `json:"messages_received"` MessagesSent int64 `json:"messages_sent"` MessagesDropped int64 `json:"messages_dropped"` + DBWriteQueueLength int `json:"db_write_queue_length"` Retained int64 `json:"retained"` Inflight int64 `json:"inflight"` InflightDropped int64 `json:"inflight_dropped"` @@ -51,7 +53,7 @@ type adminMqttClient struct { func (m mqttRuntimeStatus) Status() adminMqttStatus { if m.server == nil || m.server.Info == nil { - return adminMqttStatus{Running: false, Address: m.address, TLS: m.tls} + return adminMqttStatus{Running: false, Address: m.address, TLS: m.tls, DBWriteQueueLength: m.dbQueue.Len()} } info := m.server.Info.Clone() status := adminMqttStatus{ @@ -70,6 +72,7 @@ func (m mqttRuntimeStatus) Status() adminMqttStatus { MessagesReceived: info.MessagesReceived, MessagesSent: m.stats.Forwarded(), MessagesDropped: m.stats.Dropped(), + DBWriteQueueLength: m.dbQueue.Len(), Retained: info.Retained, Inflight: info.Inflight, InflightDropped: info.InflightDropped,