116 lines
3.4 KiB
Go
116 lines
3.4 KiB
Go
package store
|
|
|
|
import (
	"errors"
	"time"

	"mail_go/internal/db"

	"gorm.io/gorm"
)
|
|
|
|
// BanStore defines the interface for IP ban operations.
|
|
type BanStore interface {
|
|
Create(entry *db.BanEntry) error
|
|
GetByIP(ip string) (*db.BanEntry, error)
|
|
Delete(id uint) error
|
|
List(page, size int) ([]db.BanEntry, int64, error)
|
|
IsBanned(ip string) (bool, *db.BanEntry)
|
|
IncrementFail(ip string) (int, error)
|
|
ResetFail(ip string) error
|
|
Cleanup() error
|
|
}
|
|
|
|
// banStoreGorm implements BanStore using GORM.
|
|
type banStoreGorm struct {
|
|
db *gorm.DB
|
|
}
|
|
|
|
// newBanStore creates a new GORM-backed BanStore.
|
|
func newBanStore(database *gorm.DB) BanStore {
|
|
return &banStoreGorm{db: database}
|
|
}
|
|
|
|
// Create inserts a new ban entry record.
|
|
func (s *banStoreGorm) Create(entry *db.BanEntry) error {
|
|
return s.db.Create(entry).Error
|
|
}
|
|
|
|
// GetByIP retrieves the most recent ban entry for a given IP address.
|
|
func (s *banStoreGorm) GetByIP(ip string) (*db.BanEntry, error) {
|
|
var entry db.BanEntry
|
|
if err := s.db.Where("ip_address = ?", ip).Order("id DESC").First(&entry).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
return &entry, nil
|
|
}
|
|
|
|
// Delete removes a ban entry by ID.
|
|
func (s *banStoreGorm) Delete(id uint) error {
|
|
return s.db.Delete(&db.BanEntry{}, id).Error
|
|
}
|
|
|
|
// List retrieves a paginated list of ban entries.
|
|
func (s *banStoreGorm) List(page, size int) ([]db.BanEntry, int64, error) {
|
|
var entries []db.BanEntry
|
|
var total int64
|
|
|
|
if err := s.db.Model(&db.BanEntry{}).Count(&total).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
offset := (page - 1) * size
|
|
if err := s.db.Order("id DESC").Offset(offset).Limit(size).Find(&entries).Error; err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return entries, total, nil
|
|
}
|
|
|
|
// IsBanned reports whether ip has a ban entry whose expires_at is later than
// the current time, and returns that entry when it does.
//
// Any query error yields (false, nil). That covers "no matching row", but
// also genuine database failures, so this check fails open. The signature
// gives it no way to surface the error.
func (s *banStoreGorm) IsBanned(ip string) (bool, *db.BanEntry) {
	var active db.BanEntry
	now := time.Now()
	result := s.db.Where("ip_address = ? AND expires_at > ?", ip, now).First(&active)
	if result.Error == nil {
		return true, &active
	}
	return false, nil
}
|
|
|
|
// IncrementFail increments the failure counter stored for an IP address and
// returns the updated count.
//
// If GORM reports gorm.ErrRecordNotFound for ip, a new record is created with
// FailCount=1 and a zero ExpiresAt (counted, but not banned), and 1 is
// returned. Any other lookup error is returned as-is. It is not mistaken for
// a missing record, which previously caused an INSERT of a possibly duplicate
// row after a transient database failure.
func (s *banStoreGorm) IncrementFail(ip string) (int, error) {
	var entry db.BanEntry
	err := s.db.Where("ip_address = ?", ip).First(&entry).Error
	switch {
	case errors.Is(err, gorm.ErrRecordNotFound):
		// No record yet: start counting at 1.
		entry = db.BanEntry{
			IPAddress: ip,
			FailCount: 1,
			ExpiresAt: time.Time{}, // zero time: not yet banned
		}
		if createErr := s.db.Create(&entry).Error; createErr != nil {
			return 0, createErr
		}
		return 1, nil
	case err != nil:
		// A real query failure (connection, SQL, ...). Surface it instead of
		// inserting a fresh counter row.
		return 0, err
	}

	// Record exists: bump its counter.
	// NOTE(review): this is a read-modify-write. Concurrent calls for the same
	// ip can lose increments. Consider gorm.Expr("fail_count + 1") if that
	// matters.
	newCount := entry.FailCount + 1
	if updateErr := s.db.Model(&entry).Update("fail_count", newCount).Error; updateErr != nil {
		return 0, updateErr
	}
	return newCount, nil
}
|
|
|
|
// ResetFail resets the fail count for an IP address by deleting its record.
|
|
func (s *banStoreGorm) ResetFail(ip string) error {
|
|
return s.db.Where("ip_address = ?", ip).Delete(&db.BanEntry{}).Error
|
|
}
|
|
|
|
// Cleanup removes expired ban entries.
|
|
// It deletes records where expires_at is in the past and is not zero
|
|
// (preserving records that have fail counts but are not yet banned).
|
|
func (s *banStoreGorm) Cleanup() error {
|
|
return s.db.Where("expires_at < ? AND expires_at > ?", time.Now(), time.Time{}).Delete(&db.BanEntry{}).Error
|
|
}
|