diff --git a/admin_mqtt_forward_routes.go b/admin_mqtt_forward_routes.go new file mode 100644 index 0000000..bf946ea --- /dev/null +++ b/admin_mqtt_forward_routes.go @@ -0,0 +1,281 @@ +package main + +import ( + "errors" + "net/http" + "strconv" + + "github.com/gin-gonic/gin" + "gorm.io/gorm" +) + +type mqttForwarderRequest struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + SourceHost string `json:"source_host"` + SourcePort int `json:"source_port"` + SourceUsername string `json:"source_username"` + SourcePassword *string `json:"source_password"` + SourcePasswordClear bool `json:"source_password_clear"` + SourceClientID string `json:"source_client_id"` + SourceTLS bool `json:"source_tls"` + TargetHost string `json:"target_host"` + TargetPort int `json:"target_port"` + TargetUsername string `json:"target_username"` + TargetPassword *string `json:"target_password"` + TargetPasswordClear bool `json:"target_password_clear"` + TargetClientID string `json:"target_client_id"` + TargetTLS bool `json:"target_tls"` +} + +type mqttForwardTopicRequest struct { + Topic string `json:"topic"` + Enabled bool `json:"enabled"` + Direction string `json:"direction"` + SourcePrefix string `json:"source_prefix"` + TargetPrefix string `json:"target_prefix"` + QoS int `json:"qos"` + Retain bool `json:"retain"` +} + +func registerAdminMQTTForwardRoutes(r gin.IRouter, store *store, forwarder mqttForwardReloader) { + r.GET("/mqtt-forward/forwarders", func(c *gin.Context) { + opts, ok := parseListOptions(c) + if !ok { + return + } + rows, err := store.ListMQTTForwarders(opts) + if err != nil { + writeListResponse(c, rows, opts, err, mqttForwarderDTO) + return + } + total, err := store.CountMQTTForwarders(opts) + writeListResponseWithTotal(c, rows, opts, total, err, mqttForwarderDTO) + }) + r.POST("/mqtt-forward/forwarders", func(c *gin.Context) { + var req mqttForwarderRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mqtt forwarder request"}) + return + } + input := mqttForwarderInputFromRequest(req) + row, err := store.CreateMQTTForwarder(input) + writeMQTTForwardMutationResponse(c, http.StatusCreated, row, err, func() error { + return reloadMQTTForwarder(forwarder, row.ID) + }) + }) + r.PUT("/mqtt-forward/forwarders/:id", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forwarder id") + if !ok { + return + } + var req mqttForwarderRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mqtt forwarder request"}) + return + } + input := mqttForwarderInputFromRequest(req) + row, err := store.UpdateMQTTForwarder(id, input) + writeMQTTForwardMutationResponse(c, http.StatusOK, row, err, func() error { + return reloadMQTTForwarder(forwarder, id) + }) + }) + r.DELETE("/mqtt-forward/forwarders/:id", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forwarder id") + if !ok { + return + } + if forwarder != nil { + forwarder.StopForwarder(id) + } + writeMQTTForwardDeleteResponse(c, store.DeleteMQTTForwarder(id), nil) + }) + r.POST("/mqtt-forward/forwarders/:id/restart", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forwarder id") + if !ok { + return + } + if err := reloadMQTTForwarder(forwarder, id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + r.GET("/mqtt-forward/forwarders/:id/topics", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forwarder id") + if !ok { + return + } + opts, ok := parseListOptions(c) + if !ok { + return + } + rows, err := store.ListMQTTForwardTopics(id, opts) + if err != nil { + writeListResponse(c, rows, opts, err, mqttForwardTopicDTO) + return + } + total, err := store.CountMQTTForwardTopics(id) + writeListResponseWithTotal(c, rows, opts, total, err, mqttForwardTopicDTO) + }) + r.POST("/mqtt-forward/forwarders/:id/topics", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forwarder id") + if !ok { + return + } + var req mqttForwardTopicRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mqtt forward topic request"}) + return + } + row, err := store.CreateMQTTForwardTopic(id, mqttForwardTopicInputFromRequest(req)) + writeMQTTForwardTopicMutationResponse(c, http.StatusCreated, row, err, func() error { + return reloadMQTTForwarder(forwarder, id) + }) + }) + r.PUT("/mqtt-forward/topics/:id", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forward topic id") + if !ok { + return + } + var req mqttForwardTopicRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid mqtt forward topic request"}) + return + } + row, err := store.UpdateMQTTForwardTopic(id, mqttForwardTopicInputFromRequest(req)) + writeMQTTForwardTopicMutationResponse(c, http.StatusOK, row, err, func() error { + return reloadMQTTForwarder(forwarder, row.ForwarderID) + }) + }) + r.DELETE("/mqtt-forward/topics/:id", func(c *gin.Context) { + id, ok := parseMQTTForwardID(c, "invalid mqtt forward topic id") + if !ok { + return + } + row, err := store.GetMQTTForwardTopic(id) + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"error": "mqtt forward topic not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + parentID := row.ForwarderID + writeMQTTForwardDeleteResponse(c, store.DeleteMQTTForwardTopic(id), func() error { + return reloadMQTTForwarder(forwarder, parentID) + }) + }) + r.GET("/mqtt-forward/status", func(c *gin.Context) { + items := []mqttForwardRuntimeStatus{} + if forwarder != nil { + items = forwarder.Status() + } + c.JSON(http.StatusOK, gin.H{"items": items}) + }) +} + +func mqttForwarderInputFromRequest(req mqttForwarderRequest) mqttForwarderInput { + sourcePassword := req.SourcePassword + if req.SourcePasswordClear { + empty := "" + sourcePassword = &empty + } + targetPassword := req.TargetPassword + if req.TargetPasswordClear { + empty := "" + targetPassword = &empty + } + return mqttForwarderInput{Name: req.Name, Enabled: req.Enabled, SourceHost: req.SourceHost, SourcePort: req.SourcePort, SourceUsername: req.SourceUsername, SourcePassword: sourcePassword, SourceClientID: req.SourceClientID, SourceTLS: req.SourceTLS, TargetHost: req.TargetHost, TargetPort: req.TargetPort, TargetUsername: req.TargetUsername, TargetPassword: targetPassword, TargetClientID: req.TargetClientID, TargetTLS: req.TargetTLS} +} + +func mqttForwardTopicInputFromRequest(req mqttForwardTopicRequest) mqttForwardTopicInput { + return mqttForwardTopicInput{Topic: req.Topic, Enabled: req.Enabled, Direction: req.Direction, SourcePrefix: req.SourcePrefix, TargetPrefix: req.TargetPrefix, QoS: req.QoS, Retain: req.Retain} +} + +func parseMQTTForwardID(c *gin.Context, message string) (uint64, bool) { + id, err := strconv.ParseUint(c.Param("id"), 10, 64) + if err != nil || id == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": message}) + return 0, false + } + return id, true +} + +func reloadMQTTForwarder(forwarder mqttForwardReloader, id uint64) error { + if forwarder == nil { + return nil + } + return forwarder.ReloadForwarder(id) +} + +func writeMQTTForwardMutationResponse(c *gin.Context, status int, row *mqttForwarderRecord, err error, afterSuccess func() error) { + if errors.Is(err, errMQTTForwarderAlreadyExists) { + c.JSON(http.StatusConflict, gin.H{"error": "mqtt forwarder already exists"}) + return + } + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"error": "mqtt forwarder not found"}) + return + } + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if afterSuccess != nil { + if err := afterSuccess(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "mqtt forwarder saved but reload failed: " + err.Error()}) + return + } + } + c.JSON(status, gin.H{"item": mqttForwarderDTO(*row)}) +} + +func writeMQTTForwardTopicMutationResponse(c *gin.Context, status int, row *mqttForwardTopicRecord, err error, afterSuccess func() error) { + if errors.Is(err, errMQTTForwardTopicAlreadyExists) { + c.JSON(http.StatusConflict, gin.H{"error": "mqtt forward topic already exists"}) + return + } + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"error": "mqtt forward topic not found"}) + return + } + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if afterSuccess != nil { + if err := afterSuccess(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "mqtt forward topic saved but reload failed: " + err.Error()}) + return + } + } + c.JSON(status, gin.H{"item": mqttForwardTopicDTO(*row)}) +} + +func writeMQTTForwardDeleteResponse(c *gin.Context, err error, afterSuccess func() error) { + if errors.Is(err, gorm.ErrRecordNotFound) { + c.JSON(http.StatusNotFound, gin.H{"error": "mqtt forward item not found"}) + return + } + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + if afterSuccess != nil { + if err := afterSuccess(); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "mqtt forward item deleted but reload failed: " + err.Error()}) + return + } + } + c.JSON(http.StatusOK, gin.H{"status": "ok"}) +} + +func mqttForwarderDTO(row mqttForwarderRecord) gin.H { + return gin.H{"id": row.ID, "name": row.Name, "enabled": row.Enabled, "source_host": row.SourceHost, "source_port": row.SourcePort, "source_username": row.SourceUsername, "source_password_set": row.SourcePassword != "", "source_client_id": row.SourceClientID, "source_tls": row.SourceTLS, "target_host": row.TargetHost, "target_port": row.TargetPort, "target_username": row.TargetUsername, "target_password_set": row.TargetPassword != "", "target_client_id": row.TargetClientID, "target_tls": row.TargetTLS, "created_at": row.CreatedAt, "updated_at": row.UpdatedAt} +} + +func mqttForwardTopicDTO(row mqttForwardTopicRecord) gin.H { + return gin.H{"id": row.ID, "forwarder_id": row.ForwarderID, "topic": row.Topic, "enabled": row.Enabled, "direction": row.Direction, "source_prefix": row.SourcePrefix, "target_prefix": row.TargetPrefix, "qos": row.QoS, "retain": row.Retain, "created_at": row.CreatedAt, "updated_at": row.UpdatedAt} +} diff --git a/db.go b/db.go index 5d14548..9453890 100644 --- a/db.go +++ b/db.go @@ -153,6 +153,48 @@ func (forbiddenWordBlockingRecord) TableName() string { return "forbidden_word_blocking" } +type mqttForwarderRecord struct { + ID uint64 `gorm:"column:id;primaryKey;autoIncrement"` + Name string `gorm:"column:name;not null;uniqueIndex"` + Enabled bool `gorm:"column:enabled;not null;index"` + SourceHost string `gorm:"column:source_host;not null"` + SourcePort int `gorm:"column:source_port;not null"` + SourceUsername string `gorm:"column:source_username"` + SourcePassword string `gorm:"column:source_password"` + SourceClientID string `gorm:"column:source_client_id"` + SourceTLS bool `gorm:"column:source_tls;not null"` + TargetHost string `gorm:"column:target_host;not null"` + TargetPort int `gorm:"column:target_port;not null"` + TargetUsername string `gorm:"column:target_username"` + TargetPassword string `gorm:"column:target_password"` + TargetClientID string `gorm:"column:target_client_id"` + TargetTLS bool `gorm:"column:target_tls;not null"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime;index"` +} + +func (mqttForwarderRecord) TableName() string { + return "mqtt_forwarders" +} + +type mqttForwardTopicRecord struct { + ID uint64 `gorm:"column:id;primaryKey;autoIncrement"` + ForwarderID uint64 `gorm:"column:forwarder_id;not null;index;uniqueIndex:idx_mqtt_forward_topic_unique,priority:1"` + Topic string `gorm:"column:topic;not null;uniqueIndex:idx_mqtt_forward_topic_unique,priority:2"` + Enabled bool `gorm:"column:enabled;not null;index"` + Direction string `gorm:"column:direction;not null;index"` + SourcePrefix string `gorm:"column:source_prefix"` + TargetPrefix string `gorm:"column:target_prefix"` + QoS int `gorm:"column:qos;not null"` + Retain bool `gorm:"column:retain;not null"` + CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"` + UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime;index"` +} + +func (mqttForwardTopicRecord) TableName() string { + return "mqtt_forward_topics" +} + type nodeInfoRecord struct { NodeID string `gorm:"column:node_id;primaryKey;not null"` NodeNum int64 `gorm:"column:node_num;not null;index"` @@ -351,6 +393,8 @@ func (s *store) migrate() error { {label: "node_blocking", model: &nodeBlockingRecord{}}, {label: "ip_blocking", model: &ipBlockingRecord{}}, {label: "forbidden_word_blocking", model: &forbiddenWordBlockingRecord{}}, + {label: "mqtt_forwarders", model: &mqttForwarderRecord{}}, + {label: "mqtt_forward_topics", model: &mqttForwardTopicRecord{}}, {label: "nodeinfo", model: &nodeInfoRecord{}}, {label: "map_report", model: &mapReportRecord{}}, {label: "text_message", model: &textMessageRecord{}}, diff --git a/go.mod b/go.mod index 04491df..4fa472b 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/eclipse/paho.mqtt.golang v1.5.1 // indirect github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect @@ -50,6 +51,7 @@ require ( golang.org/x/arch v0.22.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.51.0 // indirect + golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.34.0 // indirect modernc.org/libc v1.72.3 // indirect diff --git a/go.sum b/go.sum index 934b335..2fb93ad 100644 --- a/go.sum +++ b/go.sum @@ -13,6 +13,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/eclipse/paho.mqtt.golang v1.5.1 h1:/VSOv3oDLlpqR2Epjn1Q7b2bSTplJIeV2ISgCl2W7nE= +github.com/eclipse/paho.mqtt.golang v1.5.1/go.mod h1:1/yJCneuyOoCOzKSsOTUc0AJfpsItBGWvYpBLimhArU= github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= diff --git a/main.go b/main.go index a236b60..c093c52 100644 --- a/main.go +++ b/main.go @@ -219,6 +219,12 @@ func run(cfg *config) error { if err != nil { return err } + forwardManager := newMQTTForwardManager(store) + if err := forwardManager.StartFromStore(); err != nil { + server.Close() + return err + } + defer forwardManager.StopAll() var httpServer *http.Server errCh := make(chan error, 1) @@ -228,7 +234,7 @@ func run(cfg *config) error { return err } mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats, dbQueue: dbQueue} - httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking) + httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking, forwardManager) webAddress := httpServer.Addr go func() { if cfg.Web.SocketPath != "" { diff --git a/meshmap_frontend/src/App.vue b/meshmap_frontend/src/App.vue index e2d11d4..f08297c 100644 --- a/meshmap_frontend/src/App.vue +++ b/meshmap_frontend/src/App.vue @@ -6,6 +6,7 @@ import AdminDashboard from './components/AdminDashboard.vue' import AdminDiscardDetails from './components/AdminDiscardDetails.vue' import AdminLogin from './components/AdminLogin.vue' import AdminLoginLogs from './components/AdminLoginLogs.vue' +import AdminMqttForward from './components/AdminMqttForward.vue' import AdminUsers from './components/AdminUsers.vue' import ChatPanel from './components/ChatPanel.vue' import ConfirmDeleteModal from './components/ConfirmDeleteModal.vue' @@ -18,6 +19,7 @@ import type { AdminUser, HealthStatus, MapBoundsChangePayload, MapBoundsQuery, M const currentPath = window.location.pathname const adminPath = currentPath const isAdminPage = adminPath.startsWith('/admin') +const isMqttForwardAdminPage = adminPath === '/admin/mqtt_forward' || adminPath === '/admin/mqtt_forward/' const detailMatch = currentPath.match(/^\/detailed\/(.+)$/) const detailedNodeId = detailMatch ? decodeURIComponent(detailMatch[1]) : '' const isDetailedPage = !!detailedNodeId @@ -454,6 +456,7 @@ onBeforeUnmount(() => { 服务状态 用户管理 屏蔽管理 + MQTT转发 登录日志 丢弃数据 @@ -491,6 +494,7 @@ onBeforeUnmount(() => { + diff --git a/meshmap_frontend/src/api.ts b/meshmap_frontend/src/api.ts index ff8cecd..68c9d5b 100644 --- a/meshmap_frontend/src/api.ts +++ b/meshmap_frontend/src/api.ts @@ -15,6 +15,12 @@ import type { MapBoundsQuery, MapReport, MapViewportResponse, + MQTTForwarder, + MQTTForwarderPayload, + MQTTForwardMutationResponse, + MQTTForwardStatusResponse, + MQTTForwardTopic, + MQTTForwardTopicPayload, NodeBlockingRule, NodeBlockingRulePayload, NodeInfo, @@ -214,3 +220,43 @@ export function updateForbiddenWordBlockingRule(id: number, payload: ForbiddenWo export function deleteForbiddenWordBlockingRule(id: number): Promise<{ status: string }> { return deleteJSON<{ status: string }>(`/api/admin/blocking/words/${id}`) } + +export function getMQTTForwarders(limit = 100, offset = 0): Promise> { + return getJSON>(listPath('/api/admin/mqtt-forward/forwarders', limit, offset)) +} + +export function createMQTTForwarder(payload: MQTTForwarderPayload): Promise> { + return postJSON>('/api/admin/mqtt-forward/forwarders', payload) +} + +export function updateMQTTForwarder(id: number, payload: MQTTForwarderPayload): Promise> { + return putJSON>(`/api/admin/mqtt-forward/forwarders/${id}`, payload) +} + +export function deleteMQTTForwarder(id: number): Promise<{ status: string }> { + return deleteJSON<{ status: string }>(`/api/admin/mqtt-forward/forwarders/${id}`) +} + +export function restartMQTTForwarder(id: number): Promise<{ status: string }> { + return postJSON<{ status: string }>(`/api/admin/mqtt-forward/forwarders/${id}/restart`) +} + +export function getMQTTForwardTopics(forwarderId: number, limit = 100, offset = 0): Promise> { + return getJSON>(listPath(`/api/admin/mqtt-forward/forwarders/${forwarderId}/topics`, limit, offset)) +} + +export function createMQTTForwardTopic(forwarderId: number, payload: MQTTForwardTopicPayload): Promise> { + return postJSON>(`/api/admin/mqtt-forward/forwarders/${forwarderId}/topics`, payload) +} + +export function updateMQTTForwardTopic(id: number, payload: MQTTForwardTopicPayload): Promise> { + return putJSON>(`/api/admin/mqtt-forward/topics/${id}`, payload) +} + +export function deleteMQTTForwardTopic(id: number): Promise<{ status: string }> { + return deleteJSON<{ status: string }>(`/api/admin/mqtt-forward/topics/${id}`) +} + +export function getMQTTForwardStatus(): Promise { + return getJSON('/api/admin/mqtt-forward/status') +} diff --git a/meshmap_frontend/src/components/AdminMqttForward.vue b/meshmap_frontend/src/components/AdminMqttForward.vue new file mode 100644 index 0000000..6b86acf --- /dev/null +++ b/meshmap_frontend/src/components/AdminMqttForward.vue @@ -0,0 +1,923 @@ + + + + + diff --git a/meshmap_frontend/src/types.ts b/meshmap_frontend/src/types.ts index 83a1a02..63e275a 100644 --- a/meshmap_frontend/src/types.ts +++ b/meshmap_frontend/src/types.ts @@ -285,3 +285,87 @@ export interface ForbiddenWordBlockingRulePayload { export interface BlockingRuleResponse { item: T } + +export type MQTTForwardDirection = 'source_to_target' | 'bidirectional' + +export interface MQTTForwarder { + id: number + name: string + enabled: boolean + source_host: string + source_port: number + source_username: string + source_password_set: boolean + source_client_id: string + source_tls: boolean + target_host: string + target_port: number + target_username: string + target_password_set: boolean + target_client_id: string + target_tls: boolean + created_at: string + updated_at: string +} + +export interface MQTTForwarderPayload { + name: string + enabled: boolean + source_host: string + source_port: number + source_username: string + source_password?: string + source_password_clear?: boolean + source_client_id: string + source_tls: boolean + target_host: string + target_port: number + target_username: string + target_password?: string + target_password_clear?: boolean + target_client_id: string + target_tls: boolean +} + +export interface MQTTForwardTopic { + id: number + forwarder_id: number + topic: string + enabled: boolean + direction: MQTTForwardDirection + source_prefix: string + target_prefix: string + qos: number + retain: boolean + created_at: string + updated_at: string +} + +export interface MQTTForwardTopicPayload { + topic: string + enabled: boolean + direction: MQTTForwardDirection + source_prefix: string + target_prefix: string + qos: number + retain: boolean +} + +export interface MQTTForwardRuntimeStatus { + forwarder_id: number + running: boolean + source_connected: boolean + target_connected: boolean + last_error: string + started_at: string | null + messages_forwarded: number + messages_dropped: number +} + +export interface MQTTForwardMutationResponse { + item: T +} + +export interface MQTTForwardStatusResponse { + items: MQTTForwardRuntimeStatus[] +} diff --git a/mqtt_forward_manager.go b/mqtt_forward_manager.go new file mode 100644 index 0000000..83e31fa --- /dev/null +++ b/mqtt_forward_manager.go @@ -0,0 +1,408 @@ +package main + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "net" + "sort" + "strings" + "sync" + "time" + + pahomqtt "github.com/eclipse/paho.mqtt.golang" + "gorm.io/gorm" +) + +const ( + mqttForwardDirectionTargetToSource = "target_to_source" + mqttForwardLoopTTL = 15 * time.Second + mqttForwardLoopMaxEntries = 10000 +) + +type mqttForwardReloader interface { + ReloadForwarder(id uint64) error + StopForwarder(id uint64) + Status() []mqttForwardRuntimeStatus +} + +type mqttForwardManager struct { + store *store + mu sync.Mutex + runners map[uint64]*mqttForwardRunner +} + +type mqttForwardRuntimeStatus struct { + ForwarderID uint64 `json:"forwarder_id"` + Running bool `json:"running"` + SourceConnected bool `json:"source_connected"` + TargetConnected bool `json:"target_connected"` + LastError string `json:"last_error"` + StartedAt *time.Time `json:"started_at"` + MessagesForwarded uint64 `json:"messages_forwarded"` + MessagesDropped uint64 `json:"messages_dropped"` +} + +type mqttForwardRunner struct { + config mqttForwarderConfig + ctx context.Context + cancel context.CancelFunc + source pahomqtt.Client + target pahomqtt.Client + + mu sync.Mutex + lastError string + startedAt time.Time + sourceConnected bool + targetConnected bool + messagesForwarded uint64 + messagesDropped uint64 + loopCache map[string]time.Time +} + +func newMQTTForwardManager(store *store) *mqttForwardManager { + return &mqttForwardManager{store: store, runners: make(map[uint64]*mqttForwardRunner)} +} + +func (m *mqttForwardManager) StartFromStore() error { + configs, err := m.store.ListEnabledMQTTForwarderConfigs() + if err != nil { + return err + } + for _, cfg := range configs { + if len(cfg.Topics) == 0 { + continue + } + runner := newMQTTForwardRunner(cfg) + runner.Start() + m.mu.Lock() + m.runners[cfg.Forwarder.ID] = runner + m.mu.Unlock() + } + return nil +} + +func (m *mqttForwardManager) ReloadForwarder(id uint64) error { + m.StopForwarder(id) + cfg, err := m.store.GetMQTTForwarderConfig(id) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + if err != nil { + return err + } + if !cfg.Forwarder.Enabled || len(cfg.Topics) == 0 { + return nil + } + runner := newMQTTForwardRunner(*cfg) + runner.Start() + m.mu.Lock() + m.runners[id] = runner + m.mu.Unlock() + return nil +} + +func (m *mqttForwardManager) StopForwarder(id uint64) { + m.mu.Lock() + runner := m.runners[id] + delete(m.runners, id) + m.mu.Unlock() + if runner != nil { + runner.Stop() + } +} + +func (m *mqttForwardManager) StopAll() { + m.mu.Lock() + runners := make([]*mqttForwardRunner, 0, len(m.runners)) + for id, runner := range m.runners { + runners = append(runners, runner) + delete(m.runners, id) + } + m.mu.Unlock() + for _, runner := range runners { + runner.Stop() + } +} + +func (m *mqttForwardManager) Status() []mqttForwardRuntimeStatus { + m.mu.Lock() + runners := make([]*mqttForwardRunner, 0, len(m.runners)) + for _, runner := range m.runners { + runners = append(runners, runner) + } + m.mu.Unlock() + items := make([]mqttForwardRuntimeStatus, 0, len(runners)) + for _, runner := range runners { + items = append(items, runner.Status()) + } + sort.Slice(items, func(i, j int) bool { return items[i].ForwarderID < items[j].ForwarderID }) + return items +} + +func newMQTTForwardRunner(config mqttForwarderConfig) *mqttForwardRunner { + ctx, cancel := context.WithCancel(context.Background()) + return &mqttForwardRunner{config: config, ctx: ctx, cancel: cancel, startedAt: time.Now(), loopCache: make(map[string]time.Time)} +} + +func (r *mqttForwardRunner) Start() { + r.source = r.newClient(true) + r.target = r.newClient(false) + r.connectClient(r.target, "target") + r.connectClient(r.source, "source") +} + +func (r *mqttForwardRunner) Stop() { + r.cancel() + if r.source != nil && r.source.IsConnected() { + r.source.Disconnect(250) + } + if r.target != nil && r.target.IsConnected() { + r.target.Disconnect(250) + } +} + +func (r *mqttForwardRunner) Status() mqttForwardRuntimeStatus { + r.mu.Lock() + defer r.mu.Unlock() + started := r.startedAt + return mqttForwardRuntimeStatus{ + ForwarderID: r.config.Forwarder.ID, + Running: true, + SourceConnected: r.sourceConnected, + TargetConnected: r.targetConnected, + LastError: r.lastError, + StartedAt: &started, + MessagesForwarded: r.messagesForwarded, + MessagesDropped: r.messagesDropped, + } +} + +func (r *mqttForwardRunner) newClient(source bool) pahomqtt.Client { + forwarder := r.config.Forwarder + host, port, username, password, clientID, useTLS := forwarder.SourceHost, forwarder.SourcePort, forwarder.SourceUsername, forwarder.SourcePassword, forwarder.SourceClientID, forwarder.SourceTLS + role := "source" + if !source { + host, port, username, password, clientID, useTLS = forwarder.TargetHost, forwarder.TargetPort, forwarder.TargetUsername, forwarder.TargetPassword, forwarder.TargetClientID, forwarder.TargetTLS + role = "target" + } + if clientID == "" { + clientID = fmt.Sprintf("mesh-forward-%d-%s", forwarder.ID, role) + } + scheme := "tcp" + if useTLS { + scheme = "ssl" + } + opts := pahomqtt.NewClientOptions(). + AddBroker(fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, fmt.Sprint(port)))). + SetClientID(clientID). + SetAutoReconnect(true). + SetConnectRetry(true). + SetKeepAlive(60 * time.Second). + SetConnectionLostHandler(func(_ pahomqtt.Client, err error) { + r.setConnected(source, false) + r.setError(fmt.Sprintf("%s connection lost: %v", role, err)) + }). + SetOnConnectHandler(func(client pahomqtt.Client) { + r.setConnected(source, true) + r.subscribe(client, source) + }) + if username != "" { + opts.SetUsername(username) + } + if password != "" { + opts.SetPassword(password) + } + if useTLS { + opts.SetTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12}) + } + return pahomqtt.NewClient(opts) +} + +func (r *mqttForwardRunner) connectClient(client pahomqtt.Client, label string) { + token := client.Connect() + if !token.WaitTimeout(2 * time.Second) { + r.setError(label + " connect pending") + return + } + if err := token.Error(); err != nil { + r.setError(fmt.Sprintf("%s connect failed: %v", label, err)) + } +} + +func (r *mqttForwardRunner) subscribe(client pahomqtt.Client, source bool) { + for _, topic := range r.config.Topics { + filter := topic.Topic + if !source { + if topic.Direction != mqttForwardDirectionBidirectional { + continue + } + filter = mapMQTTForwardTopic(topic.Topic, topic.SourcePrefix, topic.TargetPrefix) + } + topicRule := topic + token := client.Subscribe(filter, byte(topic.QoS), func(_ pahomqtt.Client, msg pahomqtt.Message) { + r.forwardMessage(source, topicRule, msg) + }) + if !token.WaitTimeout(2 * time.Second) { + r.setError("subscribe pending: " + filter) + continue + } + if err := token.Error(); err != nil { + r.setError(fmt.Sprintf("subscribe %s failed: %v", filter, err)) + } + } +} + +func (r *mqttForwardRunner) forwardMessage(fromSource bool, rule mqttForwardTopicRecord, msg pahomqtt.Message) { + if r.ctx.Err() != nil { + return + } + fromTopic := msg.Topic() + if fromSource { + if !mqttTopicFilterMatches(rule.Topic, fromTopic) { + return + } + } else if !mqttTopicFilterMatches(mapMQTTForwardTopic(rule.Topic, rule.SourcePrefix, rule.TargetPrefix), fromTopic) { + return + } + toTopic := fromTopic + forwardDirection := mqttForwardDirectionSourceToTarget + if fromSource { + toTopic = mapMQTTForwardTopic(fromTopic, rule.SourcePrefix, rule.TargetPrefix) + } else { + forwardDirection = mqttForwardDirectionTargetToSource + toTopic = mapMQTTForwardTopic(fromTopic, rule.TargetPrefix, rule.SourcePrefix) + } + if r.isSuppressed(forwardDirection, fromTopic, toTopic, msg.Payload(), rule.QoS, rule.Retain) { + r.incDropped() + return + } + target := r.target + reverseDirection := mqttForwardDirectionTargetToSource + if !fromSource { + target = r.source + reverseDirection = mqttForwardDirectionSourceToTarget + } + r.markSuppressed(reverseDirection, toTopic, fromTopic, msg.Payload(), rule.QoS, rule.Retain) + token := target.Publish(toTopic, byte(rule.QoS), rule.Retain, msg.Payload()) + if !token.WaitTimeout(2 * time.Second) { + r.setError("publish pending: " + toTopic) + r.incDropped() + return + } + if err := token.Error(); err != nil { + r.setError(fmt.Sprintf("publish %s failed: %v", toTopic, err)) + r.incDropped() + return + } + r.incForwarded() +} + +func (r *mqttForwardRunner) setConnected(source bool, connected bool) { + r.mu.Lock() + defer r.mu.Unlock() + if source { + r.sourceConnected = connected + } else { + r.targetConnected = connected + } +} + +func (r *mqttForwardRunner) setError(message string) { + r.mu.Lock() + r.lastError = message + r.mu.Unlock() +} + +func (r *mqttForwardRunner) incForwarded() { + r.mu.Lock() + r.messagesForwarded++ + r.mu.Unlock() +} + +func (r *mqttForwardRunner) incDropped() { + r.mu.Lock() + r.messagesDropped++ + r.mu.Unlock() +} + +func (r *mqttForwardRunner) isSuppressed(direction, fromTopic, toTopic string, payload []byte, qos int, retain bool) bool { + key := mqttForwardLoopKey(direction, fromTopic, toTopic, payload, qos, retain) + now := time.Now() + r.mu.Lock() + defer r.mu.Unlock() + expires, ok := r.loopCache[key] + if !ok { + return false + } + if now.After(expires) { + delete(r.loopCache, key) + return false + } + delete(r.loopCache, key) + return true +} + +func (r *mqttForwardRunner) markSuppressed(direction, fromTopic, toTopic string, payload []byte, qos int, retain bool) { + key := mqttForwardLoopKey(direction, fromTopic, toTopic, payload, qos, retain) + now := time.Now() + r.mu.Lock() + defer r.mu.Unlock() + if len(r.loopCache) >= mqttForwardLoopMaxEntries { + for existing, expires := range r.loopCache { + if now.After(expires) || len(r.loopCache) >= mqttForwardLoopMaxEntries { + delete(r.loopCache, existing) + } + if len(r.loopCache) < mqttForwardLoopMaxEntries { + break + } + } + } + r.loopCache[key] = now.Add(mqttForwardLoopTTL) +} + +func mqttForwardLoopKey(direction, fromTopic, toTopic string, payload []byte, qos int, retain bool) string { + sum := sha256.Sum256(payload) + return fmt.Sprintf("%s\x00%s\x00%s\x00%d\x00%t\x00%s", direction, fromTopic, toTopic, qos, retain, hex.EncodeToString(sum[:])) +} + +func mapMQTTForwardTopic(topic, fromPrefix, toPrefix string) string { + fromPrefix = strings.Trim(fromPrefix, "/") + toPrefix = strings.Trim(toPrefix, "/") + if fromPrefix == "" { + return topic + } + if topic == fromPrefix { + return toPrefix + } + if strings.HasPrefix(topic, fromPrefix+"/") { + if toPrefix == "" { + return strings.TrimPrefix(topic, fromPrefix+"/") + } + return toPrefix + strings.TrimPrefix(topic, fromPrefix) + } + return topic +} + +func mqttTopicFilterMatches(filter, topic string) bool { + filterParts := strings.Split(filter, "/") + topicParts := strings.Split(topic, "/") + for i, filterPart := range filterParts { + if filterPart == "#" { + return i == len(filterParts)-1 + } + if i >= len(topicParts) { + return false + } + if filterPart == "+" { + continue + } + if filterPart != topicParts[i] { + return false + } + } + return len(filterParts) == len(topicParts) +} diff --git a/mqtt_forward_store.go b/mqtt_forward_store.go new file mode 100644 index 0000000..37854f5 --- /dev/null +++ b/mqtt_forward_store.go @@ -0,0 +1,378 @@ +package main + +import ( + "errors" + "fmt" + "strings" + "time" + + "gorm.io/gorm" +) + +const ( + mqttForwardDirectionSourceToTarget = "source_to_target" + mqttForwardDirectionBidirectional = "bidirectional" +) + +var ( + errMQTTForwarderAlreadyExists = errors.New("mqtt forwarder already exists") + errMQTTForwardTopicAlreadyExists = errors.New("mqtt forward topic already exists") +) + +type mqttForwarderInput struct { + Name string + Enabled bool + SourceHost string + SourcePort int + SourceUsername string + SourcePassword *string + SourceClientID string + SourceTLS bool + TargetHost string + TargetPort int + TargetUsername string + TargetPassword *string + TargetClientID string + TargetTLS bool +} + +type mqttForwardTopicInput struct { + Topic string + Enabled bool + Direction string + SourcePrefix string + TargetPrefix string + QoS int + Retain bool +} + +type mqttForwarderConfig struct { + Forwarder mqttForwarderRecord + Topics []mqttForwardTopicRecord +} + +func (s *store) ListMQTTForwarders(opts listOptions) ([]mqttForwarderRecord, error) { + opts = normalizeListOptions(opts) + var rows []mqttForwarderRecord + q := s.db.Model(&mqttForwarderRecord{}). + Order("updated_at DESC"). + Order("id DESC"). + Limit(opts.Limit). + Offset(opts.Offset) + return rows, q.Find(&rows).Error +} + +func (s *store) CountMQTTForwarders(opts listOptions) (int64, error) { + var total int64 + return total, s.db.Model(&mqttForwarderRecord{}).Count(&total).Error +} + +func (s *store) GetMQTTForwarder(id uint64) (*mqttForwarderRecord, error) { + var row mqttForwarderRecord + if err := s.db.Where("id = ?", id).Take(&row).Error; err != nil { + return nil, err + } + return &row, nil +} + +func (s *store) CreateMQTTForwarder(input mqttForwarderInput) (*mqttForwarderRecord, error) { + row, err := mqttForwarderFromInput(input, nil) + if err != nil { + return nil, err + } + if err := s.ensureMQTTForwarderNameUnique(0, row.Name); err != nil { + return nil, err + } + if err := s.db.Create(row).Error; err != nil { + return nil, err + } + return row, nil +} + +func (s *store) UpdateMQTTForwarder(id uint64, input mqttForwarderInput) (*mqttForwarderRecord, error) { + if id == 0 { + return nil, fmt.Errorf("mqtt forwarder id is required") + } + existing, err := s.GetMQTTForwarder(id) + if err != nil { + return nil, err + } + row, err := mqttForwarderFromInput(input, existing) + if err != nil { + return nil, err + } + if err := s.ensureMQTTForwarderNameUnique(id, row.Name); err != nil { + return nil, err + } + updates := map[string]any{ + "name": row.Name, "enabled": row.Enabled, + "source_host": row.SourceHost, "source_port": row.SourcePort, "source_username": row.SourceUsername, + "source_password": row.SourcePassword, "source_client_id": row.SourceClientID, "source_tls": row.SourceTLS, + "target_host": row.TargetHost, "target_port": row.TargetPort, "target_username": row.TargetUsername, + "target_password": row.TargetPassword, "target_client_id": row.TargetClientID, "target_tls": row.TargetTLS, + "updated_at": time.Now(), + } + if err := s.db.Model(&mqttForwarderRecord{}).Where("id = ?", id).Updates(updates).Error; err != nil { + return nil, err + } + return s.GetMQTTForwarder(id) +} + +func (s *store) DeleteMQTTForwarder(id uint64) error { + if id == 0 { + return fmt.Errorf("mqtt forwarder id is required") + } + return s.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Where("forwarder_id = ?", id).Delete(&mqttForwardTopicRecord{}).Error; err != nil { + return err + } + result := tx.Where("id = ?", id).Delete(&mqttForwarderRecord{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + return nil + }) +} + +func (s *store) ListMQTTForwardTopics(forwarderID uint64, opts listOptions) ([]mqttForwardTopicRecord, error) { + opts = normalizeListOptions(opts) + var rows []mqttForwardTopicRecord + q := s.db.Model(&mqttForwardTopicRecord{}). + Where("forwarder_id = ?", forwarderID). + Order("updated_at DESC"). + Order("id DESC"). + Limit(opts.Limit). + Offset(opts.Offset) + return rows, q.Find(&rows).Error +} + +func (s *store) CountMQTTForwardTopics(forwarderID uint64) (int64, error) { + var total int64 + return total, s.db.Model(&mqttForwardTopicRecord{}).Where("forwarder_id = ?", forwarderID).Count(&total).Error +} + +func (s *store) GetMQTTForwardTopic(id uint64) (*mqttForwardTopicRecord, error) { + var row mqttForwardTopicRecord + if err := s.db.Where("id = ?", id).Take(&row).Error; err != nil { + return nil, err + } + return &row, nil +} + +func (s *store) CreateMQTTForwardTopic(forwarderID uint64, input mqttForwardTopicInput) (*mqttForwardTopicRecord, error) { + if _, err := s.GetMQTTForwarder(forwarderID); err != nil { + return nil, err + } + row, err := mqttForwardTopicFromInput(forwarderID, input) + if err != nil { + return nil, err + } + if err := s.ensureMQTTForwardTopicUnique(0, forwarderID, row.Topic); err != nil { + return nil, err + } + if err := s.db.Create(row).Error; err != nil { + return nil, err + } + return row, nil +} + +func (s *store) UpdateMQTTForwardTopic(id uint64, input mqttForwardTopicInput) (*mqttForwardTopicRecord, error) { + if id == 0 { + return nil, fmt.Errorf("mqtt forward topic id is required") + } + existing, err := s.GetMQTTForwardTopic(id) + if err != nil { + return nil, err + } + row, err := mqttForwardTopicFromInput(existing.ForwarderID, input) + if err != nil { + return nil, err + } + if err := s.ensureMQTTForwardTopicUnique(id, existing.ForwarderID, row.Topic); err != nil { + return nil, err + } + updates := map[string]any{ + "topic": row.Topic, "enabled": row.Enabled, "direction": row.Direction, + "source_prefix": row.SourcePrefix, "target_prefix": row.TargetPrefix, + "qos": row.QoS, "retain": row.Retain, "updated_at": time.Now(), + } + if err := s.db.Model(&mqttForwardTopicRecord{}).Where("id = ?", id).Updates(updates).Error; err != nil { + return nil, err + } + return s.GetMQTTForwardTopic(id) +} + +func (s *store) DeleteMQTTForwardTopic(id uint64) error { + result := s.db.Where("id = ?", id).Delete(&mqttForwardTopicRecord{}) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return gorm.ErrRecordNotFound + } + return nil +} + +func (s *store) GetMQTTForwarderConfig(id uint64) (*mqttForwarderConfig, error) { + forwarder, err := s.GetMQTTForwarder(id) + if err != nil { + return nil, err + } + var topics []mqttForwardTopicRecord + if err := s.db.Where("forwarder_id = ? AND enabled = ?", id, true).Order("id ASC").Find(&topics).Error; err != nil { + return nil, err + } + return &mqttForwarderConfig{Forwarder: *forwarder, Topics: topics}, nil +} + +func (s *store) ListEnabledMQTTForwarderConfigs() ([]mqttForwarderConfig, error) { + var forwarders []mqttForwarderRecord + if err := s.db.Where("enabled = ?", true).Order("id ASC").Find(&forwarders).Error; err != nil { + return nil, err + } + configs := make([]mqttForwarderConfig, 0, len(forwarders)) + for _, forwarder := range forwarders { + var topics []mqttForwardTopicRecord + if err := s.db.Where("forwarder_id = ? AND enabled = ?", forwarder.ID, true).Order("id ASC").Find(&topics).Error; err != nil { + return nil, err + } + if len(topics) == 0 { + continue + } + configs = append(configs, mqttForwarderConfig{Forwarder: forwarder, Topics: topics}) + } + return configs, nil +} + +func (s *store) ensureMQTTForwarderNameUnique(id uint64, name string) error { + var existing mqttForwarderRecord + q := s.db.Where("name = ?", name) + if id != 0 { + q = q.Where("id <> ?", id) + } + err := q.Take(&existing).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + if err != nil { + return err + } + return errMQTTForwarderAlreadyExists +} + +func (s *store) ensureMQTTForwardTopicUnique(id, forwarderID uint64, topic string) error { + var existing mqttForwardTopicRecord + q := s.db.Where("forwarder_id = ? AND topic = ?", forwarderID, topic) + if id != 0 { + q = q.Where("id <> ?", id) + } + err := q.Take(&existing).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + if err != nil { + return err + } + return errMQTTForwardTopicAlreadyExists +} + +func mqttForwarderFromInput(input mqttForwarderInput, existing *mqttForwarderRecord) (*mqttForwarderRecord, error) { + name := strings.TrimSpace(input.Name) + if name == "" { + return nil, fmt.Errorf("mqtt forwarder name is required") + } + sourceHost := strings.TrimSpace(input.SourceHost) + if sourceHost == "" { + return nil, fmt.Errorf("source host is required") + } + if err := validateMQTTForwardPort(input.SourcePort, "source port"); err != nil { + return nil, err + } + targetHost := strings.TrimSpace(input.TargetHost) + if targetHost == "" { + return nil, fmt.Errorf("target host is required") + } + if err := validateMQTTForwardPort(input.TargetPort, "target port"); err != nil { + return nil, err + } + row := &mqttForwarderRecord{ + Name: name, Enabled: input.Enabled, + SourceHost: sourceHost, SourcePort: input.SourcePort, SourceUsername: strings.TrimSpace(input.SourceUsername), SourceClientID: strings.TrimSpace(input.SourceClientID), SourceTLS: input.SourceTLS, + TargetHost: targetHost, TargetPort: input.TargetPort, TargetUsername: strings.TrimSpace(input.TargetUsername), TargetClientID: strings.TrimSpace(input.TargetClientID), TargetTLS: input.TargetTLS, + } + if input.SourcePassword != nil { + row.SourcePassword = *input.SourcePassword + } else if existing != nil { + row.SourcePassword = existing.SourcePassword + } + if input.TargetPassword != nil { + row.TargetPassword = *input.TargetPassword + } else if existing != nil { + row.TargetPassword = existing.TargetPassword + } + return row, nil +} + +func mqttForwardTopicFromInput(forwarderID uint64, input mqttForwardTopicInput) (*mqttForwardTopicRecord, error) { + if forwarderID == 0 { + return nil, fmt.Errorf("mqtt forwarder id is required") + } + topic := strings.TrimSpace(input.Topic) + if err := validateMQTTTopicFilter(topic); err != nil { + return nil, err + } + direction, err := normalizeMQTTForwardDirection(input.Direction) + if err != nil { + return nil, err + } + if input.QoS < 0 || input.QoS > 2 { + return nil, fmt.Errorf("qos must be 0, 1, or 2") + } + return &mqttForwardTopicRecord{ + ForwarderID: forwarderID, Topic: topic, Enabled: input.Enabled, Direction: direction, + SourcePrefix: strings.Trim(strings.TrimSpace(input.SourcePrefix), "/"), + TargetPrefix: strings.Trim(strings.TrimSpace(input.TargetPrefix), "/"), + QoS: input.QoS, Retain: input.Retain, + }, nil +} + +func validateMQTTForwardPort(port int, label string) error { + if port <= 0 || port > 65535 { + return fmt.Errorf("%s must be between 1 and 65535", label) + } + return nil +} + +func normalizeMQTTForwardDirection(direction string) (string, error) { + direction = strings.TrimSpace(direction) + if direction == "" { + direction = mqttForwardDirectionSourceToTarget + } + switch direction { + case mqttForwardDirectionSourceToTarget, mqttForwardDirectionBidirectional: + return direction, nil + default: + return "", fmt.Errorf("invalid mqtt forward direction") + } +} + +func validateMQTTTopicFilter(topic string) error { + if topic == "" { + return fmt.Errorf("topic is required") + } + parts := strings.Split(topic, "/") + for i, part := range parts { + if strings.Contains(part, "#") { + if part != "#" || i != len(parts)-1 { + return fmt.Errorf("invalid topic filter: # must be the last level") + } + } + if strings.Contains(part, "+") && part != "+" { + return fmt.Errorf("invalid topic filter: + must occupy an entire level") + } + } + return nil +} diff --git a/web.go b/web.go index 788785e..4510da7 100644 --- a/web.go +++ b/web.go @@ -14,10 +14,10 @@ import ( "gorm.io/gorm" ) -func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *http.Server { +func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader) *http.Server { return &http.Server{ Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)), - Handler: newRouter(cfg, store, sessions, mqttStatus, blocking), + Handler: newRouter(cfg, store, sessions, mqttStatus, blocking, forwarder), } } @@ -47,12 +47,12 @@ func serveHTTPUnixSocket(server *http.Server, socketPath string) error { return server.Serve(listener) } -func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) *gin.Engine { +func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader) *gin.Engine { r := gin.New() r.Use(gin.Logger(), gin.Recovery()) api := r.Group("/api") registerAPIRoutes(api, store) - registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking) + registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking, forwarder) registerStaticRoutes(r, cfg.StaticDir) return r } @@ -122,7 +122,7 @@ func registerAPIRoutes(r gin.IRouter, store *store) { }) } -func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache) { +func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader) { type loginRequest struct { Username string `json:"username"` Password string `json:"password"` @@ -183,6 +183,7 @@ func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, protected := r.Group("") protected.Use(requireAdmin(sessions)) registerAdminBlockingRoutes(protected, store, blocking) + registerAdminMQTTForwardRoutes(protected, store, forwarder) protected.GET("/me", func(c *gin.Context) { claims := c.MustGet("admin_claims").(*sessionClaims) c.JSON(http.StatusOK, gin.H{"user": adminUserDTO{Username: claims.Username, Role: claims.Role}})