package main import ( "crypto/ecdh" "crypto/rand" "encoding/base64" "encoding/binary" "errors" "fmt" "strings" "time" "unicode/utf8" "meshtastic_mqtt_server/mqtpp" "gorm.io/gorm" ) const ( botDefaultTopicPrefix = "msh/CN" botDefaultPSK = "AQ==" botDefaultNodeInfoBroadcastSeconds = int64(3600) botMessageTypeChannel = "channel" botMessageTypeDirect = "direct" botMessageStatusPending = "pending" botMessageStatusPublished = "published" botMessageStatusFailed = "failed" ) var errBotNodeAlreadyExists = errors.New("bot node already exists") type botNodeInput struct { NodeNum *int64 LongName string ShortName string Enabled bool DefaultChannelID string TopicPrefix string PSK string NodeInfoBroadcastEnabled bool NodeInfoBroadcastIntervalSeconds int64 } type botMessageListOptions struct { listOptions BotID uint64 MessageType string ChannelID string } func (s *store) ListBotNodes(opts listOptions) ([]botNodeRecord, error) { opts = normalizeListOptions(opts) var rows []botNodeRecord q := s.db.Model(&botNodeRecord{}). Order("updated_at DESC"). Order("id DESC"). Limit(opts.Limit). Offset(opts.Offset) return rows, q.Find(&rows).Error } func (s *store) CountBotNodes(opts listOptions) (int64, error) { var total int64 return total, s.db.Model(&botNodeRecord{}).Count(&total).Error } func (s *store) GetBotNode(id uint64) (*botNodeRecord, error) { var row botNodeRecord if err := s.db.Where("id = ?", id).Take(&row).Error; err != nil { return nil, err } return &row, nil } func (s *store) CreateBotNode(input botNodeInput) (*botNodeRecord, error) { row, err := s.normalizedBotNodeRecord(input) if err != nil { return nil, err } if err := s.ensureBotNodeUnique(0, row.NodeID, row.NodeNum); err != nil { return nil, err } if err := s.ensureBotNodeDoesNotConflictWithNodeInfo(row.NodeNum); err != nil { return nil, err } if err := populateBotNodeKeys(row); err != nil { return nil, err } if err := s.db.Create(row).Error; err != nil { return nil, err } return row, nil } func (s *store) UpdateBotNode(id uint64, input botNodeInput) (*botNodeRecord, error) { if id == 0 { return nil, fmt.Errorf("bot node id is required") } if _, err := s.GetBotNode(id); err != nil { return nil, err } row, err := s.normalizedBotNodeRecord(input) if err != nil { return nil, err } if err := s.ensureBotNodeUnique(id, row.NodeID, row.NodeNum); err != nil { return nil, err } if err := s.ensureBotNodeDoesNotConflictWithNodeInfo(row.NodeNum); err != nil { return nil, err } updates := map[string]any{ "node_id": row.NodeID, "node_num": row.NodeNum, "long_name": row.LongName, "short_name": row.ShortName, "enabled": row.Enabled, "default_channel_id": row.DefaultChannelID, "topic_prefix": row.TopicPrefix, "psk": row.PSK, "nodeinfo_broadcast_enabled": row.NodeInfoBroadcastEnabled, "nodeinfo_broadcast_interval_seconds": row.NodeInfoBroadcastIntervalSeconds, "updated_at": time.Now(), } if err := s.db.Model(&botNodeRecord{}).Where("id = ?", id).Updates(updates).Error; err != nil { return nil, err } return s.GetBotNode(id) } func (s *store) DeleteBotNode(id uint64) error { result := s.db.Where("id = ?", id).Delete(&botNodeRecord{}) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return gorm.ErrRecordNotFound } return nil } func (s *store) InsertBotMessage(row *botMessageRecord) error { return s.db.Create(row).Error } func (s *store) UpdateBotMessageStatus(id uint64, status, errText string, publishedAt *time.Time) error { updates := map[string]any{"status": status, "error": strings.TrimSpace(errText), "published_at": publishedAt} result := s.db.Model(&botMessageRecord{}).Where("id = ?", id).Updates(updates) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return gorm.ErrRecordNotFound } return nil } func (s *store) UpdateBotNodeInfoBroadcastAt(id uint64, t time.Time) error { result := s.db.Model(&botNodeRecord{}).Where("id = ?", id).Updates(map[string]any{"last_nodeinfo_broadcast_at": &t, "updated_at": time.Now()}) if result.Error != nil { return result.Error } if result.RowsAffected == 0 { return gorm.ErrRecordNotFound } return nil } func (s *store) RegenerateBotNodeKeys(id uint64) (*botNodeRecord, error) { if id == 0 { return nil, fmt.Errorf("bot node id is required") } row, err := s.GetBotNode(id) if err != nil { return nil, err } if err := populateBotNodeKeys(row); err != nil { return nil, err } updates := map[string]any{"public_key": row.PublicKey, "private_key": row.PrivateKey, "updated_at": time.Now()} if err := s.db.Model(&botNodeRecord{}).Where("id = ?", id).Updates(updates).Error; err != nil { return nil, err } return s.GetBotNode(id) } func (s *store) ListBotMessages(opts botMessageListOptions) ([]botMessageRecord, error) { opts.listOptions = normalizeListOptions(opts.listOptions) var rows []botMessageRecord q := applyBotMessageFilters(s.db.Model(&botMessageRecord{}), opts). Order("created_at DESC"). Order("id DESC"). Limit(opts.Limit). Offset(opts.Offset) return rows, q.Find(&rows).Error } func (s *store) CountBotMessages(opts botMessageListOptions) (int64, error) { var total int64 q := applyBotMessageFilters(s.db.Model(&botMessageRecord{}), opts) return total, q.Count(&total).Error } func applyBotMessageFilters(q *gorm.DB, opts botMessageListOptions) *gorm.DB { if opts.BotID != 0 { q = q.Where("bot_id = ?", opts.BotID) } if opts.MessageType != "" { q = q.Where("message_type = ?", opts.MessageType) } if opts.ChannelID != "" { q = q.Where("channel_id = ?", opts.ChannelID) } if opts.Since != nil { q = q.Where("created_at >= ?", *opts.Since) } if opts.Until != nil { q = q.Where("created_at <= ?", *opts.Until) } return q } func (s *store) normalizedBotNodeRecord(input botNodeInput) (*botNodeRecord, error) { longName := strings.TrimSpace(input.LongName) shortName := strings.TrimSpace(input.ShortName) channelID := strings.TrimSpace(input.DefaultChannelID) psk := strings.TrimSpace(input.PSK) if psk == "" { psk = botDefaultPSK } if _, err := mqtpp.ExpandPSK(psk); err != nil { return nil, err } topicPrefix := strings.Trim(strings.TrimSpace(input.TopicPrefix), "/") if topicPrefix == "" { topicPrefix = botDefaultTopicPrefix } if longName == "" { return nil, fmt.Errorf("long name is required") } if !utf8.ValidString(longName) { return nil, fmt.Errorf("long name must be valid utf-8") } if shortName == "" { return nil, fmt.Errorf("short name is required") } if !utf8.ValidString(shortName) { return nil, fmt.Errorf("short name must be valid utf-8") } if channelID == "" { return nil, fmt.Errorf("default channel id is required") } interval := input.NodeInfoBroadcastIntervalSeconds if interval <= 0 { interval = botDefaultNodeInfoBroadcastSeconds } if interval < 60 { return nil, fmt.Errorf("nodeinfo broadcast interval must be at least 60 seconds") } var nodeNum int64 if input.NodeNum == nil || *input.NodeNum == 0 { generated, err := s.generateBotNodeNum() if err != nil { return nil, err } nodeNum = generated } else { nodeNum = *input.NodeNum } if err := validateBotNodeNum(nodeNum); err != nil { return nil, err } return &botNodeRecord{NodeID: mqtpp.NodeNumToID(uint32(nodeNum)), NodeNum: nodeNum, LongName: longName, ShortName: shortName, Enabled: input.Enabled, DefaultChannelID: channelID, TopicPrefix: topicPrefix, PSK: psk, NodeInfoBroadcastEnabled: input.NodeInfoBroadcastEnabled, NodeInfoBroadcastIntervalSeconds: interval}, nil } func populateBotNodeKeys(row *botNodeRecord) error { privateKey, err := ecdh.X25519().GenerateKey(rand.Reader) if err != nil { return err } row.PrivateKey = base64.StdEncoding.EncodeToString(privateKey.Bytes()) row.PublicKey = base64.StdEncoding.EncodeToString(privateKey.PublicKey().Bytes()) return nil } func decodeBotPublicKey(row botNodeRecord) ([]byte, error) { if strings.TrimSpace(row.PublicKey) == "" { return nil, nil } key, err := base64.StdEncoding.DecodeString(row.PublicKey) if err != nil { return nil, fmt.Errorf("invalid bot public key: %w", err) } return key, nil } func validateBotNodeNum(nodeNum int64) error { if nodeNum <= 0 || nodeNum >= int64(mqtpp.NodeNumBroadcast) { return fmt.Errorf("node num must be between 1 and 4294967294") } return nil } func (s *store) generateBotNodeNum() (int64, error) { for i := 0; i < 32; i++ { var buf [4]byte if _, err := rand.Read(buf[:]); err != nil { return 0, err } nodeNum := int64(binary.LittleEndian.Uint32(buf[:]) & 0x7fffffff) if err := validateBotNodeNum(nodeNum); err != nil { continue } if err := s.ensureBotNodeUnique(0, mqtpp.NodeNumToID(uint32(nodeNum)), nodeNum); err != nil { if errors.Is(err, errBotNodeAlreadyExists) { continue } return 0, err } if err := s.ensureBotNodeDoesNotConflictWithNodeInfo(nodeNum); err != nil { if errors.Is(err, errBotNodeAlreadyExists) { continue } return 0, err } return nodeNum, nil } return 0, fmt.Errorf("generate bot node num failed") } func (s *store) ensureBotNodeUnique(id uint64, nodeID string, nodeNum int64) error { var existing botNodeRecord q := s.db.Where("node_id = ? OR node_num = ?", nodeID, nodeNum) if id != 0 { q = q.Where("id <> ?", id) } err := q.Take(&existing).Error if err == nil { return errBotNodeAlreadyExists } if errors.Is(err, gorm.ErrRecordNotFound) { return nil } return err } func (s *store) ensureBotNodeDoesNotConflictWithNodeInfo(nodeNum int64) error { var existing nodeInfoRecord err := s.db.Where("node_num = ?", nodeNum).Take(&existing).Error if err == nil { return errBotNodeAlreadyExists } if errors.Is(err, gorm.ErrRecordNotFound) { return nil } return err }