409 lines
11 KiB
Go
409 lines
11 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
pahomqtt "github.com/eclipse/paho.mqtt.golang"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Runtime constants for MQTT forwarding.
const (
	// mqttForwardDirectionTargetToSource labels the target→source leg of a forward. In this
	// file it is used as the direction component of loop-suppression keys. The source→target
	// and bidirectional direction constants are not declared here (presumably elsewhere in
	// the package).
	mqttForwardDirectionTargetToSource = "target_to_source"
	// mqttForwardLoopTTL is how long an entry written by markSuppressed stays able to
	// suppress a matching message.
	mqttForwardLoopTTL = 15 * time.Second
	// mqttForwardLoopMaxEntries caps the size of a runner's loopCache; markSuppressed evicts
	// entries once this many are stored.
	mqttForwardLoopMaxEntries = 10000
)
|
|
|
|
// mqttForwardReloader is the control surface for running forwarders.
// mqttForwardManager implements all three methods.
type mqttForwardReloader interface {
	// ReloadForwarder restarts forwarder id from its stored configuration (the manager
	// leaves it stopped when the record is gone, disabled, or has no topics).
	ReloadForwarder(id uint64) error
	// StopForwarder stops forwarder id if it is running.
	StopForwarder(id uint64)
	// Status reports the runtime state of the currently registered forwarders.
	Status() []mqttForwardRuntimeStatus
}
|
|
|
|
// mqttForwardManager owns the running forwarders, one mqttForwardRunner per forwarder ID.
type mqttForwardManager struct {
	// store supplies forwarder configurations (ListEnabledMQTTForwarderConfigs,
	// GetMQTTForwarderConfig).
	store *store

	// mu guards runners. The manager releases it before calling a runner's Stop or Status.
	mu sync.Mutex

	// runners maps forwarder ID to its registered runner.
	runners map[uint64]*mqttForwardRunner
}
|
|
|
|
// mqttForwardRuntimeStatus is a JSON-serializable snapshot of one forwarder's runtime state,
// as produced by mqttForwardRunner.Status. Field order determines JSON key order.
type mqttForwardRuntimeStatus struct {
	ForwarderID uint64 `json:"forwarder_id"`

	// Running is always true in snapshots built by mqttForwardRunner.Status.
	Running bool `json:"running"`

	SourceConnected bool `json:"source_connected"`

	TargetConnected bool `json:"target_connected"`

	// LastError is the most recent error message recorded by the runner; nothing in the
	// runner clears it after recovery.
	LastError string `json:"last_error"`

	// StartedAt is the time the runner was constructed.
	StartedAt *time.Time `json:"started_at"`

	MessagesForwarded uint64 `json:"messages_forwarded"`

	// MessagesDropped counts loop-suppressed messages plus publishes that timed out or failed.
	MessagesDropped uint64 `json:"messages_dropped"`
}
|
|
|
|
// mqttForwardRunner bridges one forwarder configuration between a source broker and a
// target broker, using one paho client for each side.
type mqttForwardRunner struct {
	config mqttForwarderConfig

	// ctx is cancelled by Stop; forwardMessage ignores messages once it is done.
	ctx context.Context

	cancel context.CancelFunc

	// source and target are assigned by Start.
	source pahomqtt.Client

	target pahomqtt.Client

	// mu guards the mutable state below (startedAt is only written at construction).
	mu sync.Mutex

	lastError string

	startedAt time.Time

	sourceConnected bool

	targetConnected bool

	messagesForwarded uint64

	messagesDropped uint64

	// loopCache maps an mqttForwardLoopKey to the time after which it no longer suppresses
	// a matching message.
	loopCache map[string]time.Time
}
|
|
|
|
func newMQTTForwardManager(store *store) *mqttForwardManager {
|
|
return &mqttForwardManager{store: store, runners: make(map[uint64]*mqttForwardRunner)}
|
|
}
|
|
|
|
func (m *mqttForwardManager) StartFromStore() error {
|
|
configs, err := m.store.ListEnabledMQTTForwarderConfigs()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, cfg := range configs {
|
|
if len(cfg.Topics) == 0 {
|
|
continue
|
|
}
|
|
runner := newMQTTForwardRunner(cfg)
|
|
runner.Start()
|
|
m.mu.Lock()
|
|
m.runners[cfg.Forwarder.ID] = runner
|
|
m.mu.Unlock()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *mqttForwardManager) ReloadForwarder(id uint64) error {
|
|
m.StopForwarder(id)
|
|
cfg, err := m.store.GetMQTTForwarderConfig(id)
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !cfg.Forwarder.Enabled || len(cfg.Topics) == 0 {
|
|
return nil
|
|
}
|
|
runner := newMQTTForwardRunner(*cfg)
|
|
runner.Start()
|
|
m.mu.Lock()
|
|
m.runners[id] = runner
|
|
m.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (m *mqttForwardManager) StopForwarder(id uint64) {
|
|
m.mu.Lock()
|
|
runner := m.runners[id]
|
|
delete(m.runners, id)
|
|
m.mu.Unlock()
|
|
if runner != nil {
|
|
runner.Stop()
|
|
}
|
|
}
|
|
|
|
func (m *mqttForwardManager) StopAll() {
|
|
m.mu.Lock()
|
|
runners := make([]*mqttForwardRunner, 0, len(m.runners))
|
|
for id, runner := range m.runners {
|
|
runners = append(runners, runner)
|
|
delete(m.runners, id)
|
|
}
|
|
m.mu.Unlock()
|
|
for _, runner := range runners {
|
|
runner.Stop()
|
|
}
|
|
}
|
|
|
|
func (m *mqttForwardManager) Status() []mqttForwardRuntimeStatus {
|
|
m.mu.Lock()
|
|
runners := make([]*mqttForwardRunner, 0, len(m.runners))
|
|
for _, runner := range m.runners {
|
|
runners = append(runners, runner)
|
|
}
|
|
m.mu.Unlock()
|
|
items := make([]mqttForwardRuntimeStatus, 0, len(runners))
|
|
for _, runner := range runners {
|
|
items = append(items, runner.Status())
|
|
}
|
|
sort.Slice(items, func(i, j int) bool { return items[i].ForwarderID < items[j].ForwarderID })
|
|
return items
|
|
}
|
|
|
|
func newMQTTForwardRunner(config mqttForwarderConfig) *mqttForwardRunner {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
return &mqttForwardRunner{config: config, ctx: ctx, cancel: cancel, startedAt: time.Now(), loopCache: make(map[string]time.Time)}
|
|
}
|
|
|
|
func (r *mqttForwardRunner) Start() {
|
|
r.source = r.newClient(true)
|
|
r.target = r.newClient(false)
|
|
r.connectClient(r.target, "target")
|
|
r.connectClient(r.source, "source")
|
|
}
|
|
|
|
func (r *mqttForwardRunner) Stop() {
|
|
r.cancel()
|
|
if r.source != nil && r.source.IsConnected() {
|
|
r.source.Disconnect(250)
|
|
}
|
|
if r.target != nil && r.target.IsConnected() {
|
|
r.target.Disconnect(250)
|
|
}
|
|
}
|
|
|
|
func (r *mqttForwardRunner) Status() mqttForwardRuntimeStatus {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
started := r.startedAt
|
|
return mqttForwardRuntimeStatus{
|
|
ForwarderID: r.config.Forwarder.ID,
|
|
Running: true,
|
|
SourceConnected: r.sourceConnected,
|
|
TargetConnected: r.targetConnected,
|
|
LastError: r.lastError,
|
|
StartedAt: &started,
|
|
MessagesForwarded: r.messagesForwarded,
|
|
MessagesDropped: r.messagesDropped,
|
|
}
|
|
}
|
|
|
|
func (r *mqttForwardRunner) newClient(source bool) pahomqtt.Client {
|
|
forwarder := r.config.Forwarder
|
|
host, port, username, password, clientID, useTLS := forwarder.SourceHost, forwarder.SourcePort, forwarder.SourceUsername, forwarder.SourcePassword, forwarder.SourceClientID, forwarder.SourceTLS
|
|
role := "source"
|
|
if !source {
|
|
host, port, username, password, clientID, useTLS = forwarder.TargetHost, forwarder.TargetPort, forwarder.TargetUsername, forwarder.TargetPassword, forwarder.TargetClientID, forwarder.TargetTLS
|
|
role = "target"
|
|
}
|
|
if clientID == "" {
|
|
clientID = fmt.Sprintf("mesh-forward-%d-%s", forwarder.ID, role)
|
|
}
|
|
scheme := "tcp"
|
|
if useTLS {
|
|
scheme = "ssl"
|
|
}
|
|
opts := pahomqtt.NewClientOptions().
|
|
AddBroker(fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, fmt.Sprint(port)))).
|
|
SetClientID(clientID).
|
|
SetAutoReconnect(true).
|
|
SetConnectRetry(true).
|
|
SetKeepAlive(60 * time.Second).
|
|
SetConnectionLostHandler(func(_ pahomqtt.Client, err error) {
|
|
r.setConnected(source, false)
|
|
r.setError(fmt.Sprintf("%s connection lost: %v", role, err))
|
|
}).
|
|
SetOnConnectHandler(func(client pahomqtt.Client) {
|
|
r.setConnected(source, true)
|
|
r.subscribe(client, source)
|
|
})
|
|
if username != "" {
|
|
opts.SetUsername(username)
|
|
}
|
|
if password != "" {
|
|
opts.SetPassword(password)
|
|
}
|
|
if useTLS {
|
|
opts.SetTLSConfig(&tls.Config{MinVersion: tls.VersionTLS12})
|
|
}
|
|
return pahomqtt.NewClient(opts)
|
|
}
|
|
|
|
func (r *mqttForwardRunner) connectClient(client pahomqtt.Client, label string) {
|
|
token := client.Connect()
|
|
if !token.WaitTimeout(2 * time.Second) {
|
|
r.setError(label + " connect pending")
|
|
return
|
|
}
|
|
if err := token.Error(); err != nil {
|
|
r.setError(fmt.Sprintf("%s connect failed: %v", label, err))
|
|
}
|
|
}
|
|
|
|
func (r *mqttForwardRunner) subscribe(client pahomqtt.Client, source bool) {
|
|
for _, topic := range r.config.Topics {
|
|
filter := topic.Topic
|
|
if !source {
|
|
if topic.Direction != mqttForwardDirectionBidirectional {
|
|
continue
|
|
}
|
|
filter = mapMQTTForwardTopic(topic.Topic, topic.SourcePrefix, topic.TargetPrefix)
|
|
}
|
|
topicRule := topic
|
|
token := client.Subscribe(filter, byte(topic.QoS), func(_ pahomqtt.Client, msg pahomqtt.Message) {
|
|
r.forwardMessage(source, topicRule, msg)
|
|
})
|
|
if !token.WaitTimeout(2 * time.Second) {
|
|
r.setError("subscribe pending: " + filter)
|
|
continue
|
|
}
|
|
if err := token.Error(); err != nil {
|
|
r.setError(fmt.Sprintf("subscribe %s failed: %v", filter, err))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *mqttForwardRunner) forwardMessage(fromSource bool, rule mqttForwardTopicRecord, msg pahomqtt.Message) {
|
|
if r.ctx.Err() != nil {
|
|
return
|
|
}
|
|
fromTopic := msg.Topic()
|
|
if fromSource {
|
|
if !mqttTopicFilterMatches(rule.Topic, fromTopic) {
|
|
return
|
|
}
|
|
} else if !mqttTopicFilterMatches(mapMQTTForwardTopic(rule.Topic, rule.SourcePrefix, rule.TargetPrefix), fromTopic) {
|
|
return
|
|
}
|
|
toTopic := fromTopic
|
|
forwardDirection := mqttForwardDirectionSourceToTarget
|
|
if fromSource {
|
|
toTopic = mapMQTTForwardTopic(fromTopic, rule.SourcePrefix, rule.TargetPrefix)
|
|
} else {
|
|
forwardDirection = mqttForwardDirectionTargetToSource
|
|
toTopic = mapMQTTForwardTopic(fromTopic, rule.TargetPrefix, rule.SourcePrefix)
|
|
}
|
|
if r.isSuppressed(forwardDirection, fromTopic, toTopic, msg.Payload(), rule.QoS, rule.Retain) {
|
|
r.incDropped()
|
|
return
|
|
}
|
|
target := r.target
|
|
reverseDirection := mqttForwardDirectionTargetToSource
|
|
if !fromSource {
|
|
target = r.source
|
|
reverseDirection = mqttForwardDirectionSourceToTarget
|
|
}
|
|
r.markSuppressed(reverseDirection, toTopic, fromTopic, msg.Payload(), rule.QoS, rule.Retain)
|
|
token := target.Publish(toTopic, byte(rule.QoS), rule.Retain, msg.Payload())
|
|
if !token.WaitTimeout(2 * time.Second) {
|
|
r.setError("publish pending: " + toTopic)
|
|
r.incDropped()
|
|
return
|
|
}
|
|
if err := token.Error(); err != nil {
|
|
r.setError(fmt.Sprintf("publish %s failed: %v", toTopic, err))
|
|
r.incDropped()
|
|
return
|
|
}
|
|
r.incForwarded()
|
|
}
|
|
|
|
func (r *mqttForwardRunner) setConnected(source bool, connected bool) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if source {
|
|
r.sourceConnected = connected
|
|
} else {
|
|
r.targetConnected = connected
|
|
}
|
|
}
|
|
|
|
func (r *mqttForwardRunner) setError(message string) {
|
|
r.mu.Lock()
|
|
r.lastError = message
|
|
r.mu.Unlock()
|
|
}
|
|
|
|
func (r *mqttForwardRunner) incForwarded() {
|
|
r.mu.Lock()
|
|
r.messagesForwarded++
|
|
r.mu.Unlock()
|
|
}
|
|
|
|
func (r *mqttForwardRunner) incDropped() {
|
|
r.mu.Lock()
|
|
r.messagesDropped++
|
|
r.mu.Unlock()
|
|
}
|
|
|
|
func (r *mqttForwardRunner) isSuppressed(direction, fromTopic, toTopic string, payload []byte, qos int, retain bool) bool {
|
|
key := mqttForwardLoopKey(direction, fromTopic, toTopic, payload, qos, retain)
|
|
now := time.Now()
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
expires, ok := r.loopCache[key]
|
|
if !ok {
|
|
return false
|
|
}
|
|
if now.After(expires) {
|
|
delete(r.loopCache, key)
|
|
return false
|
|
}
|
|
delete(r.loopCache, key)
|
|
return true
|
|
}
|
|
|
|
func (r *mqttForwardRunner) markSuppressed(direction, fromTopic, toTopic string, payload []byte, qos int, retain bool) {
|
|
key := mqttForwardLoopKey(direction, fromTopic, toTopic, payload, qos, retain)
|
|
now := time.Now()
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if len(r.loopCache) >= mqttForwardLoopMaxEntries {
|
|
for existing, expires := range r.loopCache {
|
|
if now.After(expires) || len(r.loopCache) >= mqttForwardLoopMaxEntries {
|
|
delete(r.loopCache, existing)
|
|
}
|
|
if len(r.loopCache) < mqttForwardLoopMaxEntries {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
r.loopCache[key] = now.Add(mqttForwardLoopTTL)
|
|
}
|
|
|
|
// mqttForwardLoopKey builds the loop-suppression cache key for one forwarded message.
//
// The key is the NUL-joined list: direction, source topic, destination topic, decimal QoS,
// retain flag ("true"/"false") and the hex SHA-256 of the payload. The digest keeps keys
// small regardless of payload size. MQTT topic names may not contain U+0000, so the
// separator is unambiguous.
func mqttForwardLoopKey(direction, fromTopic, toTopic string, payload []byte, qos int, retain bool) string {
	digest := sha256.Sum256(payload)
	fields := []string{
		direction,
		fromTopic,
		toTopic,
		fmt.Sprint(qos),
		fmt.Sprint(retain),
		hex.EncodeToString(digest[:]),
	}
	return strings.Join(fields, "\x00")
}
|
|
|
|
// mapMQTTForwardTopic moves topic (or a topic filter) from the fromPrefix namespace into the
// toPrefix namespace. Both prefixes are compared with leading and trailing '/' removed.
//
//   - topic equal to fromPrefix             -> toPrefix
//   - topic under "fromPrefix/" (rest)      -> "toPrefix/rest", or just "rest" if toPrefix is empty
//   - anything else, or an empty fromPrefix -> topic unchanged
//
// NOTE(review): an empty fromPrefix returns topic unchanged even when toPrefix is set, so a
// target-only prefix has no effect. Confirm this is intended, or that configuration
// validation forbids it.
func mapMQTTForwardTopic(topic, fromPrefix, toPrefix string) string {
	from := strings.Trim(fromPrefix, "/")
	to := strings.Trim(toPrefix, "/")

	switch {
	case from == "":
		return topic
	case topic == from:
		return to
	}

	sep := from + "/"
	if !strings.HasPrefix(topic, sep) {
		return topic
	}
	rest := topic[len(sep):]
	if to == "" {
		return rest
	}
	return to + "/" + rest
}
|
|
|
|
// mqttTopicFilterMatches reports whether topic name topic matches MQTT topic filter
// filter, following the MQTT 3.1.1 / 5.0 matching rules:
//
//   - "+" matches exactly one level;
//   - "#" matches any number of remaining levels, including the parent level itself
//     ("a/#" matches "a"), and is only valid as the last level;
//   - a filter starting with a wildcard never matches a topic name starting with '$'
//     (e.g. "$SYS/..."), per spec section 4.7.2.
//
// The '$' rule matters because the client router dispatches an incoming message to every
// locally matching subscription without applying it. Without this check, a "#" rule would
// also claim messages delivered for an explicit "$SYS/#" rule.
func mqttTopicFilterMatches(filter, topic string) bool {
	if strings.HasPrefix(topic, "$") && (strings.HasPrefix(filter, "#") || strings.HasPrefix(filter, "+")) {
		return false
	}

	filterParts := strings.Split(filter, "/")
	topicParts := strings.Split(topic, "/")
	for i, part := range filterParts {
		switch {
		case part == "#":
			// Multi-level wildcard: valid only as the final level.
			return i == len(filterParts)-1
		case i >= len(topicParts):
			return false
		case part == "+":
			// Single-level wildcard: accepts whatever this level holds.
		case part != topicParts[i]:
			return false
		}
	}
	return len(filterParts) == len(topicParts)
}
|