增加转发开关

This commit is contained in:
2026-06-05 20:03:52 +08:00
parent bfdf57cd72
commit fb1971da72
15 changed files with 615 additions and 16 deletions
+1 -1
View File
@@ -5,7 +5,7 @@
每条传入的 `PUBLISH` 都会先进入: 每条传入的 `PUBLISH` 都会先进入:
```go ```go
valid, _, record := mqtpp.MQTTPP(topic, payload, key) valid, _, record := mqtpp.MQTTPP(topic, payload, key, mqtpp.Options{})
``` ```
- `valid == true`:保留原始 topic、payload、QoS、retain 等字段,正常转发给订阅匹配 topic 的客户端 - `valid == true`:保留原始 topic、payload、QoS、retain 等字段,正常转发给订阅匹配 topic 的客户端
+52
View File
@@ -0,0 +1,52 @@
package main
import (
"net/http"
"github.com/gin-gonic/gin"
)
const allowEncryptedForwardingLabel = "Allow encrypted MQTT packets to be forwarded when they cannot be decrypted"
type runtimeSettingsRequest struct {
AllowEncryptedForwarding bool `json:"allow_encrypted_forwarding"`
}
func registerAdminRuntimeSettingsRoutes(r gin.IRouter, store *store, settings *runtimeSettingsCache) {
r.GET("/runtime-settings", func(c *gin.Context) {
snapshot, err := store.GetRuntimeSettings()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"item": runtimeSettingsDTO(snapshot)})
})
r.PUT("/runtime-settings", func(c *gin.Context) {
var req runtimeSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid runtime settings request"})
return
}
if _, err := store.SetBoolRuntimeSetting(runtimeSettingAllowEncryptedForwarding, req.AllowEncryptedForwarding, allowEncryptedForwardingLabel); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if settings != nil {
if err := settings.Reload(store); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
snapshot, err := store.GetRuntimeSettings()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"item": runtimeSettingsDTO(snapshot)})
})
}
func runtimeSettingsDTO(settings runtimeSettingsSnapshot) gin.H {
return gin.H{"allow_encrypted_forwarding": settings.AllowEncryptedForwarding}
}
+14
View File
@@ -102,6 +102,19 @@ func (helpContentRecord) TableName() string {
return "help_content" return "help_content"
} }
type runtimeSettingRecord struct {
Key string `gorm:"column:key;primaryKey;size:128;not null"`
Value string `gorm:"column:value;type:text;not null"`
ValueType string `gorm:"column:value_type;size:32;not null;index"`
Label string `gorm:"column:label"`
CreatedAt time.Time `gorm:"column:created_at;autoCreateTime"`
UpdatedAt time.Time `gorm:"column:updated_at;autoUpdateTime;index"`
}
func (runtimeSettingRecord) TableName() string {
return "runtime_settings"
}
type discardDetailsRecord struct { type discardDetailsRecord struct {
ID uint64 `gorm:"column:id;primaryKey;autoIncrement"` ID uint64 `gorm:"column:id;primaryKey;autoIncrement"`
Topic string `gorm:"column:topic"` Topic string `gorm:"column:topic"`
@@ -401,6 +414,7 @@ func (s *store) migrate() error {
{label: "users", model: &userRecord{}}, {label: "users", model: &userRecord{}},
{label: "login_log", model: &loginLogRecord{}}, {label: "login_log", model: &loginLogRecord{}},
{label: "help_content", model: &helpContentRecord{}}, {label: "help_content", model: &helpContentRecord{}},
{label: "runtime_settings", model: &runtimeSettingRecord{}},
{label: "discard_details", model: &discardDetailsRecord{}}, {label: "discard_details", model: &discardDetailsRecord{}},
{label: "node_blocking", model: &nodeBlockingRecord{}}, {label: "node_blocking", model: &nodeBlockingRecord{}},
{label: "ip_blocking", model: &ipBlockingRecord{}}, {label: "ip_blocking", model: &ipBlockingRecord{}},
+1 -1
View File
@@ -15,7 +15,7 @@ func TestOpenStoreCreatesTables(t *testing.T) {
st := openTestStore(t) st := openTestStore(t)
defer st.Close() defer st.Close()
for _, table := range []string{"users", "login_log", "discard_details", "node_blocking", "ip_blocking", "forbidden_word_blocking", "nodeinfo", "map_report", "text_message", "position", "telemetry", "routing", "traceroute"} { for _, table := range []string{"users", "login_log", "runtime_settings", "discard_details", "node_blocking", "ip_blocking", "forbidden_word_blocking", "nodeinfo", "map_report", "text_message", "position", "telemetry", "routing", "traceroute"} {
var name string var name string
if err := rawTestDB(t, st).QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name); err != nil { if err := rawTestDB(t, st).QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?", table).Scan(&name); err != nil {
t.Fatalf("%s table missing: %v", table, err) t.Fatalf("%s table missing: %v", table, err)
+10 -5
View File
@@ -39,6 +39,7 @@ type meshtasticFilterHook struct {
dbQueue *dbWriteQueue dbQueue *dbWriteQueue
stats *meshtasticMessageStats stats *meshtasticMessageStats
blocking *blockingCache blocking *blockingCache
settings *runtimeSettingsCache
} }
// ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。 // ID 返回用于识别 Meshtastic payload 过滤器的 hook 名称。
@@ -63,7 +64,7 @@ func (h *meshtasticFilterHook) OnConnect(cl *mqtt.Client, pk packets.Packet) err
// OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。 // OnPublish 在 broker 转发消息前校验 payload;无效消息会被拒绝并丢弃。
func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { func (h *meshtasticFilterHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) {
valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key) valid, _, record := mqtpp.MQTTPP(pk.TopicName, pk.Payload, h.key, mqtpp.Options{AllowEncryptedForwarding: h.settings.AllowEncryptedForwarding()})
if !valid { if !valid {
h.rejectPublish(cl, pk, record) h.rejectPublish(cl, pk, record)
return pk, packets.ErrRejectPacket return pk, packets.ErrRejectPacket
@@ -213,9 +214,13 @@ func run(cfg *config) error {
if err != nil { if err != nil {
return err return err
} }
settings, err := newRuntimeSettingsCache(store)
if err != nil {
return err
}
messageStats := &meshtasticMessageStats{} messageStats := &meshtasticMessageStats{}
server, mqttAddr, err := startMQTTServer(cfg, dbQueue, messageStats, blocking) server, mqttAddr, err := startMQTTServer(cfg, dbQueue, messageStats, blocking, settings)
if err != nil { if err != nil {
return err return err
} }
@@ -234,7 +239,7 @@ func run(cfg *config) error {
return err return err
} }
mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats, dbQueue: dbQueue} mqttStatus := mqttRuntimeStatus{server: server, address: mqttAddr, tls: cfg.MQTT.TLS.Enabled, stats: messageStats, dbQueue: dbQueue}
httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking, forwardManager) httpServer = newHTTPServer(cfg.Web, store, sessions, mqttStatus, blocking, forwardManager, settings)
webAddress := httpServer.Addr webAddress := httpServer.Addr
go func() { go func() {
if cfg.Web.SocketPath != "" { if cfg.Web.SocketPath != "" {
@@ -275,12 +280,12 @@ func run(cfg *config) error {
return runErr return runErr
} }
func startMQTTServer(cfg *config, dbQueue *dbWriteQueue, stats *meshtasticMessageStats, blocking *blockingCache) (*mqtt.Server, string, error) { func startMQTTServer(cfg *config, dbQueue *dbWriteQueue, stats *meshtasticMessageStats, blocking *blockingCache, settings *runtimeSettingsCache) (*mqtt.Server, string, error) {
server := mqtt.New(nil) server := mqtt.New(nil)
if err := server.AddHook(new(auth.AllowHook), nil); err != nil { if err := server.AddHook(new(auth.AllowHook), nil); err != nil {
return nil, "", err return nil, "", err
} }
if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, dbQueue: dbQueue, stats: stats, blocking: blocking}, nil); err != nil { if err := server.AddHook(&meshtasticFilterHook{key: cfg.key, dbQueue: dbQueue, stats: stats, blocking: blocking, settings: settings}, nil); err != nil {
return nil, "", err return nil, "", err
} }
+10
View File
@@ -3,6 +3,8 @@ import type {
AdminLoginResponse, AdminLoginResponse,
AdminManagedUserResponse, AdminManagedUserResponse,
AdminMqttStatus, AdminMqttStatus,
AdminRuntimeSettingsPayload,
AdminRuntimeSettingsResponse,
AdminUsersResponse, AdminUsersResponse,
BlockingRuleResponse, BlockingRuleResponse,
DiscardDetails, DiscardDetails,
@@ -163,6 +165,14 @@ export function getAdminMqttStatus(): Promise<AdminMqttStatus> {
return getJSON<AdminMqttStatus>('/api/admin/mqtt/status') return getJSON<AdminMqttStatus>('/api/admin/mqtt/status')
} }
export function getAdminRuntimeSettings(): Promise<AdminRuntimeSettingsResponse> {
return getJSON<AdminRuntimeSettingsResponse>('/api/admin/runtime-settings')
}
export function updateAdminRuntimeSettings(payload: AdminRuntimeSettingsPayload): Promise<AdminRuntimeSettingsResponse> {
return putJSON<AdminRuntimeSettingsResponse>('/api/admin/runtime-settings', payload)
}
export function getAdminHelpContent(): Promise<HelpContentResponse> { export function getAdminHelpContent(): Promise<HelpContentResponse> {
return getJSON<HelpContentResponse>('/api/admin/help') return getJSON<HelpContentResponse>('/api/admin/help')
} }
@@ -1,11 +1,15 @@
<script setup lang="ts"> <script setup lang="ts">
import { onBeforeUnmount, onMounted, ref } from 'vue' import { onBeforeUnmount, onMounted, ref } from 'vue'
import { getAdminMqttStatus } from '../api' import { getAdminMqttStatus, getAdminRuntimeSettings, updateAdminRuntimeSettings } from '../api'
import type { AdminMqttStatus } from '../types' import type { AdminMqttStatus, AdminRuntimeSettings } from '../types'
const status = ref<AdminMqttStatus | null>(null) const status = ref<AdminMqttStatus | null>(null)
const runtimeSettings = ref<AdminRuntimeSettings | null>(null)
const loading = ref(false) const loading = ref(false)
const settingsLoading = ref(false)
const error = ref('') const error = ref('')
const settingsError = ref('')
const settingsMessage = ref('')
let timer: number | undefined let timer: number | undefined
function formatUptime(seconds: number): string { function formatUptime(seconds: number): string {
@@ -27,8 +31,43 @@ async function refreshStatus() {
} }
} }
async function refreshRuntimeSettings() {
settingsLoading.value = true
settingsError.value = ''
try {
const response = await getAdminRuntimeSettings()
runtimeSettings.value = response.item
} catch (err) {
settingsError.value = err instanceof Error ? err.message : String(err)
} finally {
settingsLoading.value = false
}
}
async function saveEncryptedForwarding(value: boolean) {
if (!runtimeSettings.value) {
return
}
const previous = runtimeSettings.value.allow_encrypted_forwarding
runtimeSettings.value.allow_encrypted_forwarding = value
settingsLoading.value = true
settingsError.value = ''
settingsMessage.value = ''
try {
const response = await updateAdminRuntimeSettings({ allow_encrypted_forwarding: value })
runtimeSettings.value = response.item
settingsMessage.value = '设置已保存'
} catch (err) {
runtimeSettings.value.allow_encrypted_forwarding = previous
settingsError.value = err instanceof Error ? err.message : String(err)
} finally {
settingsLoading.value = false
}
}
onMounted(() => { onMounted(() => {
refreshStatus() refreshStatus()
refreshRuntimeSettings()
timer = window.setInterval(refreshStatus, 5000) timer = window.setInterval(refreshStatus, 5000)
}) })
@@ -69,6 +108,44 @@ onBeforeUnmount(() => {
</div> </div>
</div> </div>
<div class="panel admin-status-panel mqtt-control-panel">
<div class="panel-header control-header">
<div class="control-title">
<div>
<p class="eyebrow">MQTT Forwarding</p>
<h2>MQTT 转发控制</h2>
</div>
</div>
<span class="control-badge" :class="{ active: runtimeSettings?.allow_encrypted_forwarding }">
{{ runtimeSettings?.allow_encrypted_forwarding ? '加密包放行' : '默认拦截' }}
</span>
</div>
<div class="control-body">
<div class="control-copy">
<h3>加密转发</h3>
<p>
控制 Broker 在无法解密 Meshtastic 加密包时是否仍允许转发关闭时保持当前行为无法解密的加密包会被丢弃并记录到丢弃详情
</p>
</div>
<div v-if="!runtimeSettings" class="empty control-empty">正在加载转发设置...</div>
<label v-else class="switch-card" :class="{ enabled: runtimeSettings.allow_encrypted_forwarding, saving: settingsLoading }">
<span class="switch-text">
<strong>允许无法解密的加密包继续转发</strong>
<small>{{ runtimeSettings.allow_encrypted_forwarding ? '已开启,原始 payload 将继续转发' : '已关闭,无法解密时会拒绝转发' }}</small>
</span>
<input
type="checkbox"
:checked="runtimeSettings.allow_encrypted_forwarding"
:disabled="settingsLoading"
@change="saveEncryptedForwarding(($event.target as HTMLInputElement).checked)"
/>
<span class="switch-toggle" aria-hidden="true"></span>
</label>
</div>
<p v-if="settingsError" class="error">{{ settingsError }}</p>
<p v-if="settingsMessage" class="success">{{ settingsMessage }}</p>
</div>
<div class="panel admin-status-panel"> <div class="panel admin-status-panel">
<div class="panel-header"> <div class="panel-header">
<div> <div>
@@ -105,3 +182,176 @@ onBeforeUnmount(() => {
</div> </div>
</section> </section>
</template> </template>
<style scoped>
.mqtt-control-panel {
position: relative;
overflow: hidden;
display: flex;
flex-direction: column;
gap: 1rem;
border: 1px solid rgba(37, 99, 235, 0.14);
background:
radial-gradient(circle at top right, rgba(59, 130, 246, 0.16), transparent 32%),
linear-gradient(135deg, #ffffff 0%, #f8fbff 52%, #eef6ff 100%);
}
.control-header {
position: relative;
align-items: flex-start;
}
.control-title {
display: flex;
align-items: center;
gap: 0.85rem;
}
.control-badge {
display: inline-flex;
align-items: center;
border: 1px solid #cbd5e1;
border-radius: 999px;
padding: 6px 12px;
color: #475569;
font-size: 12px;
font-weight: 800;
background: rgba(255, 255, 255, 0.8);
}
.control-badge.active {
border-color: rgba(22, 163, 74, 0.32);
color: #15803d;
background: #dcfce7;
}
.control-body {
position: relative;
display: grid;
grid-template-columns: minmax(0, 1fr) minmax(320px, 0.85fr);
gap: 1rem;
align-items: stretch;
}
.control-copy,
.switch-card {
border: 1px solid rgba(203, 213, 225, 0.78);
border-radius: 18px;
background: rgba(255, 255, 255, 0.86);
box-shadow: 0 14px 36px rgba(15, 23, 42, 0.06);
}
.control-copy {
padding: 1rem;
}
.control-copy h3 {
margin: 0 0 0.45rem;
color: #0f172a;
font-size: 18px;
}
.control-copy p {
margin: 0;
color: #64748b;
line-height: 1.7;
}
.control-empty {
align-self: center;
}
.switch-card {
position: relative;
display: flex;
align-items: center;
justify-content: space-between;
gap: 1rem;
min-height: 108px;
padding: 1rem;
color: #334155;
cursor: pointer;
transition: transform 0.15s ease, border-color 0.15s ease, box-shadow 0.15s ease, background 0.15s ease;
}
.switch-card:hover {
transform: translateY(-1px);
border-color: rgba(37, 99, 235, 0.35);
box-shadow: 0 18px 44px rgba(15, 23, 42, 0.09);
}
.switch-card.enabled {
border-color: rgba(22, 163, 74, 0.35);
background: linear-gradient(135deg, #ffffff 0%, #f0fdf4 100%);
}
.switch-card.saving {
cursor: wait;
opacity: 0.76;
}
.switch-card input {
position: absolute;
opacity: 0;
pointer-events: none;
}
.switch-text {
display: flex;
flex-direction: column;
gap: 0.35rem;
}
.switch-text strong {
color: #0f172a;
font-size: 15px;
}
.switch-text small {
color: #64748b;
font-size: 12px;
line-height: 1.45;
}
.switch-toggle {
position: relative;
flex: 0 0 auto;
width: 54px;
height: 30px;
border-radius: 999px;
background: #cbd5e1;
box-shadow: inset 0 2px 4px rgba(15, 23, 42, 0.14);
transition: background 0.15s ease;
}
.switch-toggle::after {
content: '';
position: absolute;
top: 4px;
left: 4px;
width: 22px;
height: 22px;
border-radius: 999px;
background: #fff;
box-shadow: 0 4px 10px rgba(15, 23, 42, 0.24);
transition: transform 0.15s ease;
}
.switch-card.enabled .switch-toggle {
background: linear-gradient(135deg, #16a34a, #22c55e);
}
.switch-card.enabled .switch-toggle::after {
transform: translateX(24px);
}
@media (max-width: 820px) {
.control-body {
grid-template-columns: 1fr;
}
.control-header {
gap: 0.75rem;
}
}
</style>
+12
View File
@@ -225,6 +225,18 @@ export interface AdminMqttClient {
remote_port: string remote_port: string
} }
export interface AdminRuntimeSettings {
allow_encrypted_forwarding: boolean
}
export interface AdminRuntimeSettingsPayload {
allow_encrypted_forwarding: boolean
}
export interface AdminRuntimeSettingsResponse {
item: AdminRuntimeSettings
}
export interface AdminMqttStatus { export interface AdminMqttStatus {
running: boolean running: boolean
address: string address: string
+6 -2
View File
@@ -32,6 +32,10 @@ var defaultMeshtasticPSK = []byte{
0xCF, 0x4E, 0x69, 0x01, 0xCF, 0x4E, 0x69, 0x01,
} }
type Options struct {
AllowEncryptedForwarding bool
}
type serviceEnvelope struct { type serviceEnvelope struct {
Packet *meshPacket Packet *meshPacket
ChannelID string ChannelID string
@@ -115,7 +119,7 @@ type telemetryInfo struct {
// MQTTPP 处理一个 MQTT 原始 payload,返回合规状态、原始数据和解码后的记录。 // MQTTPP 处理一个 MQTT 原始 payload,返回合规状态、原始数据和解码后的记录。
// 第一个返回值表示数据是否合规;第二个返回值在不合规时为 nil;第三个返回值是解码结果记录。 // 第一个返回值表示数据是否合规;第二个返回值在不合规时为 nil;第三个返回值是解码结果记录。
func MQTTPP(topic string, raw []byte, key []byte) (bool, []byte, map[string]any) { func MQTTPP(topic string, raw []byte, key []byte, opts Options) (bool, []byte, map[string]any) {
env, err := parseServiceEnvelope(raw) env, err := parseServiceEnvelope(raw)
if err != nil { if err != nil {
@@ -127,7 +131,7 @@ func MQTTPP(topic string, raw []byte, key []byte) (bool, []byte, map[string]any)
//解码失败 //解码失败
return false, nil, map[string]any{"topic": topic, "error": err.Error(), "payload_len": len(raw)} return false, nil, map[string]any{"topic": topic, "error": err.Error(), "payload_len": len(raw)}
} }
if record["type"] == "encrypted_packet" { if record["type"] == "encrypted_packet" && !opts.AllowEncryptedForwarding {
record["error"] = "cannot be decrypted" record["error"] = "cannot be decrypted"
return false, nil, record return false, nil, record
} }
+48
View File
@@ -0,0 +1,48 @@
package mqtpp
import (
"testing"
"google.golang.org/protobuf/encoding/protowire"
)
func TestMQTTPPEncryptedPacketDefaultRejected(t *testing.T) {
raw := encryptedServiceEnvelopeTestPayload()
valid, payload, record := MQTTPP("msh/test", raw, nil, Options{})
if valid {
t.Fatalf("valid = true, want false")
}
if payload != nil {
t.Fatalf("payload = %v, want nil", payload)
}
if record["type"] != "encrypted_packet" {
t.Fatalf("type = %v, want encrypted_packet", record["type"])
}
if record["error"] != "cannot be decrypted" {
t.Fatalf("error = %v, want cannot be decrypted", record["error"])
}
}
func TestMQTTPPEncryptedPacketAllowed(t *testing.T) {
raw := encryptedServiceEnvelopeTestPayload()
valid, payload, record := MQTTPP("msh/test", raw, nil, Options{AllowEncryptedForwarding: true})
if !valid {
t.Fatalf("valid = false, want true: %+v", record)
}
if string(payload) != string(raw) {
t.Fatalf("payload = %v, want raw payload", payload)
}
if record["type"] != "encrypted_packet" {
t.Fatalf("type = %v, want encrypted_packet", record["type"])
}
if record["error"] != nil {
t.Fatalf("error = %v, want nil", record["error"])
}
}
func encryptedServiceEnvelopeTestPayload() []byte {
packet := protowire.AppendTag(nil, 5, protowire.BytesType)
packet = protowire.AppendBytes(packet, []byte{1, 2, 3, 4})
envelope := protowire.AppendTag(nil, 1, protowire.BytesType)
return protowire.AppendBytes(envelope, packet)
}
+47
View File
@@ -0,0 +1,47 @@
package main
import (
"fmt"
"sync"
)
type runtimeSettingsCache struct {
mu sync.RWMutex
settings runtimeSettingsSnapshot
}
func newRuntimeSettingsCache(store *store) (*runtimeSettingsCache, error) {
cache := &runtimeSettingsCache{}
if err := cache.Reload(store); err != nil {
return nil, err
}
return cache, nil
}
func (c *runtimeSettingsCache) Reload(store *store) error {
if store == nil {
return fmt.Errorf("store is required")
}
settings, err := store.GetRuntimeSettings()
if err != nil {
return err
}
c.mu.Lock()
c.settings = settings
c.mu.Unlock()
return nil
}
func (c *runtimeSettingsCache) Snapshot() runtimeSettingsSnapshot {
if c == nil {
return runtimeSettingsSnapshot{}
}
c.mu.RLock()
defer c.mu.RUnlock()
return c.settings
}
func (c *runtimeSettingsCache) AllowEncryptedForwarding() bool {
return c.Snapshot().AllowEncryptedForwarding
}
+36
View File
@@ -0,0 +1,36 @@
package main
import "testing"
func TestRuntimeSettingsCacheReload(t *testing.T) {
st := openTestStore(t)
defer st.Close()
cache, err := newRuntimeSettingsCache(st)
if err != nil {
t.Fatalf("newRuntimeSettingsCache() error = %v", err)
}
if cache.AllowEncryptedForwarding() {
t.Fatalf("AllowEncryptedForwarding() = true, want false")
}
if _, err := st.SetBoolRuntimeSetting(runtimeSettingAllowEncryptedForwarding, true, "test setting"); err != nil {
t.Fatalf("SetBoolRuntimeSetting(true) error = %v", err)
}
if err := cache.Reload(st); err != nil {
t.Fatalf("Reload() after true error = %v", err)
}
if !cache.AllowEncryptedForwarding() {
t.Fatalf("AllowEncryptedForwarding() = false, want true")
}
if _, err := st.SetBoolRuntimeSetting(runtimeSettingAllowEncryptedForwarding, false, "test setting"); err != nil {
t.Fatalf("SetBoolRuntimeSetting(false) error = %v", err)
}
if err := cache.Reload(st); err != nil {
t.Fatalf("Reload() after false error = %v", err)
}
if cache.AllowEncryptedForwarding() {
t.Fatalf("AllowEncryptedForwarding() = true, want false")
}
}
+82
View File
@@ -0,0 +1,82 @@
package main
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
const (
runtimeSettingAllowEncryptedForwarding = "mqtt.allow_encrypted_forwarding"
runtimeSettingTypeBool = "bool"
)
type runtimeSettingsSnapshot struct {
AllowEncryptedForwarding bool
}
func (s *store) GetRuntimeSettings() (runtimeSettingsSnapshot, error) {
allowEncrypted, err := s.GetBoolRuntimeSetting(runtimeSettingAllowEncryptedForwarding, false)
if err != nil {
return runtimeSettingsSnapshot{}, err
}
return runtimeSettingsSnapshot{AllowEncryptedForwarding: allowEncrypted}, nil
}
func (s *store) GetBoolRuntimeSetting(key string, defaultValue bool) (bool, error) {
key = strings.TrimSpace(key)
if key == "" {
return false, fmt.Errorf("runtime setting key is required")
}
var row runtimeSettingRecord
err := s.db.Where("key = ?", key).Take(&row).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return defaultValue, nil
}
if err != nil {
return false, err
}
if row.ValueType != "" && row.ValueType != runtimeSettingTypeBool {
return false, fmt.Errorf("runtime setting %s has type %s, want %s", key, row.ValueType, runtimeSettingTypeBool)
}
value, err := strconv.ParseBool(strings.TrimSpace(row.Value))
if err != nil {
return false, fmt.Errorf("parse runtime setting %s: %w", key, err)
}
return value, nil
}
func (s *store) SetBoolRuntimeSetting(key string, value bool, label string) (*runtimeSettingRecord, error) {
key = strings.TrimSpace(key)
if key == "" {
return nil, fmt.Errorf("runtime setting key is required")
}
row := runtimeSettingRecord{
Key: key,
Value: strconv.FormatBool(value),
ValueType: runtimeSettingTypeBool,
Label: strings.TrimSpace(label),
}
if err := s.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "key"}},
DoUpdates: clause.Assignments(map[string]any{
"value": row.Value,
"value_type": row.ValueType,
"label": row.Label,
"updated_at": time.Now(),
}),
}).Create(&row).Error; err != nil {
return nil, err
}
if err := s.db.Where("key = ?", key).Take(&row).Error; err != nil {
return nil, err
}
return &row, nil
}
+38
View File
@@ -0,0 +1,38 @@
package main
import "testing"
func TestRuntimeSettingsDefaultAndUpdates(t *testing.T) {
st := openTestStore(t)
defer st.Close()
settings, err := st.GetRuntimeSettings()
if err != nil {
t.Fatalf("GetRuntimeSettings() error = %v", err)
}
if settings.AllowEncryptedForwarding {
t.Fatalf("AllowEncryptedForwarding = true, want false")
}
if _, err := st.SetBoolRuntimeSetting(runtimeSettingAllowEncryptedForwarding, true, "test setting"); err != nil {
t.Fatalf("SetBoolRuntimeSetting(true) error = %v", err)
}
settings, err = st.GetRuntimeSettings()
if err != nil {
t.Fatalf("GetRuntimeSettings() after true error = %v", err)
}
if !settings.AllowEncryptedForwarding {
t.Fatalf("AllowEncryptedForwarding = false, want true")
}
if _, err := st.SetBoolRuntimeSetting(runtimeSettingAllowEncryptedForwarding, false, "test setting"); err != nil {
t.Fatalf("SetBoolRuntimeSetting(false) error = %v", err)
}
settings, err = st.GetRuntimeSettings()
if err != nil {
t.Fatalf("GetRuntimeSettings() after false error = %v", err)
}
if settings.AllowEncryptedForwarding {
t.Fatalf("AllowEncryptedForwarding = true, want false")
}
}
+6 -5
View File
@@ -14,10 +14,10 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader) *http.Server { func newHTTPServer(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader, settings *runtimeSettingsCache) *http.Server {
return &http.Server{ return &http.Server{
Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)), Addr: net.JoinHostPort(cfg.Host, strconv.Itoa(cfg.Port)),
Handler: newRouter(cfg, store, sessions, mqttStatus, blocking, forwarder), Handler: newRouter(cfg, store, sessions, mqttStatus, blocking, forwarder, settings),
} }
} }
@@ -47,12 +47,12 @@ func serveHTTPUnixSocket(server *http.Server, socketPath string) error {
return server.Serve(listener) return server.Serve(listener)
} }
func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader) *gin.Engine { func newRouter(cfg webConfig, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader, settings *runtimeSettingsCache) *gin.Engine {
r := gin.New() r := gin.New()
r.Use(gin.Logger(), gin.Recovery()) r.Use(gin.Logger(), gin.Recovery())
api := r.Group("/api") api := r.Group("/api")
registerAPIRoutes(api, store) registerAPIRoutes(api, store)
registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking, forwarder) registerAdminRoutes(api.Group("/admin"), store, sessions, mqttStatus, blocking, forwarder, settings)
registerStaticRoutes(r, cfg.StaticDir) registerStaticRoutes(r, cfg.StaticDir)
return r return r
} }
@@ -123,7 +123,7 @@ func registerAPIRoutes(r gin.IRouter, store *store) {
}) })
} }
func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader) { func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager, mqttStatus mqttStatusProvider, blocking *blockingCache, forwarder mqttForwardReloader, settings *runtimeSettingsCache) {
type loginRequest struct { type loginRequest struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
@@ -185,6 +185,7 @@ func registerAdminRoutes(r gin.IRouter, store *store, sessions *sessionManager,
protected.Use(requireAdmin(sessions)) protected.Use(requireAdmin(sessions))
registerAdminBlockingRoutes(protected, store, blocking) registerAdminBlockingRoutes(protected, store, blocking)
registerAdminMQTTForwardRoutes(protected, store, forwarder) registerAdminMQTTForwardRoutes(protected, store, forwarder)
registerAdminRuntimeSettingsRoutes(protected, store, settings)
registerAdminHelpRoutes(protected, store) registerAdminHelpRoutes(protected, store)
protected.GET("/me", func(c *gin.Context) { protected.GET("/me", func(c *gin.Context) {
claims := c.MustGet("admin_claims").(*sessionClaims) claims := c.MustGet("admin_claims").(*sessionClaims)