package main
import (
"html/template"
"log"
"os"
"path/filepath"
"strings"
"simple_portal/database"
"simple_portal/handlers"
"simple_portal/middleware"
"simple_portal/session"
"github.com/gin-gonic/gin"
)
// loadTemplates parses every ".html" file found anywhere beneath the
// templates/ directory into one template set, with the helper functions
// hasPrefix, sub and add registered.
//
// A custom loader is used because filepath.Glob (and therefore ParseGlob) has
// no recursive wildcard, so nested template directories cannot be matched by a
// single pattern.
//
// Templates are named the html/template.ParseFiles way, i.e. by file base
// name: templates/admin/login.html is registered as "login.html" (plus any
// {{define}} blocks it contains). When two files share a base name, the one
// walked later wins, and a warning is logged.
//
// Any failure (unreadable directory, no templates, parse error) is fatal,
// because the server cannot render pages without its templates.
func loadTemplates() *template.Template {
	funcMap := template.FuncMap{
		"hasPrefix": strings.HasPrefix,
		"sub":       func(a, b int) int { return a - b },
		"add":       func(a, b int) int { return a + b },
	}

	files, err := findTemplateFiles("templates")
	if err != nil {
		// Previously ignored: a partial walk would start the server with a
		// silently incomplete template set.
		log.Fatalf("Failed to scan templates/: %v", err)
	}
	if len(files) == 0 {
		log.Fatal("No template files found in templates/")
	}

	// ParseFiles keys file-level templates by base name, so same-named files in
	// different directories overwrite each other. Surface that instead of
	// letting the wrong page render.
	seen := make(map[string]string, len(files))
	for _, f := range files {
		name := filepath.Base(f)
		if prev, dup := seen[name]; dup {
			log.Printf("Warning: template %q from %s is replaced by %s (same base name)", name, prev, f)
		}
		seen[name] = f
	}

	t, err := template.New("").Funcs(funcMap).ParseFiles(files...)
	if err != nil {
		log.Fatalf("Failed to parse templates: %v", err)
	}
	return t
}

// findTemplateFiles returns the paths of all non-directory entries under root
// whose extension is ".html", in filepath.Walk's lexical order. Any error hit
// while walking, including a missing root, is returned and the walk stops.
func findTemplateFiles(root string) ([]string, error) {
	var files []string
	err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if !info.IsDir() && filepath.Ext(path) == ".html" {
			files = append(files, path)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	return files, nil
}
// main wires up and runs the portal server. It initializes the database,
// ensures data/uploads exists, builds the Gin router with the public routes and
// the /admin routes (IP-whitelisted, with a session-protected subgroup), and
// listens on $PORT.
//
// Environment variables read here:
//   - GIN_MODE: when unset, the router is explicitly put in debug mode.
//   - PORT: listen port; defaults to "8080".
//   - TRUSTED_PROXIES: comma-separated IPs/CIDRs of reverse proxies whose
//     forwarding headers (X-Forwarded-For, X-Real-IP) may be trusted when Gin
//     resolves the client IP. When unset, only loopback ("127.0.0.1", "::1")
//     is trusted. When set but empty, no proxy is trusted and the TCP peer
//     address is always used.
func main() {
	// Initialize database
	if err := database.InitDB(); err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}
	// NOTE: log.Fatal* exits via os.Exit, so this deferred CloseDB does not run
	// on any of the fatal paths below.
	defer database.CloseDB()

	// Create uploads directory
	if err := os.MkdirAll(filepath.Join(".", "data", "uploads"), 0755); err != nil {
		log.Fatalf("Failed to create uploads directory: %v", err)
	}

	// Create session store
	sessionStore := session.NewSessionStore()
	// Create IP ban guard (in-memory fail counter)
	ipBanGuard := middleware.NewIPBanGuard()

	// Set Gin mode
	ginMode := os.Getenv("GIN_MODE")
	if ginMode == "" {
		gin.SetMode(gin.DebugMode)
	}

	r := gin.Default()

	// By default Gin trusts forwarding headers from *every* peer. Its ClientIP
	// then returns the left-most, client-supplied X-Forwarded-For entry, so any
	// client could forge its IP. That would defeat the /admin IP whitelist and
	// the IP ban guard, which presumably identify clients via c.ClientIP().
	// NOTE(review): confirm in the middleware package.
	trustedProxies := []string{"127.0.0.1", "::1"}
	if v, ok := os.LookupEnv("TRUSTED_PROXIES"); ok {
		trustedProxies = nil
		for _, p := range strings.Split(v, ",") {
			if p = strings.TrimSpace(p); p != "" {
				trustedProxies = append(trustedProxies, p)
			}
		}
	}
	if err := r.SetTrustedProxies(trustedProxies); err != nil {
		log.Fatalf("Invalid TRUSTED_PROXIES: %v", err)
	}

	// Load HTML templates (custom loader for nested directories)
	r.SetHTMLTemplate(loadTemplates())
	// Serve static files
	r.Static("/static", "./static")

	// Inject session store and IP ban guard into context for handlers
	r.Use(func(c *gin.Context) {
		c.Set("sessionStore", sessionStore)
		c.Set("ipBanGuard", ipBanGuard)
		c.Next()
	})

	// Public routes (home page and uploads — no IP restriction)
	r.GET("/", handlers.HomeHandler)
	r.GET("/click/:id", handlers.CardClickHandler)
	r.GET("/search", handlers.SearchHandler)
	r.GET("/uploads/:filename", handlers.ServeUploadHandler)

	// Admin routes with IP whitelist check applied to all /admin/* routes.
	// The callback reports whether a session ID refers to a live session.
	adminGroup := r.Group("/admin")
	adminGroup.Use(middleware.IPWhitelistRequired(func(sessionID string) bool {
		return sessionStore.Get(sessionID) != nil
	}))
	{
		// Public admin routes (login — no auth required, but IP whitelist applies)
		adminGroup.GET("/login", handlers.LoginGet)
		adminGroup.POST("/login", handlers.LoginPost)

		// Protected admin routes (auth required)
		protected := adminGroup.Group("")
		protected.Use(middleware.AuthRequired(sessionStore))
		{
			protected.POST("/logout", handlers.Logout)
			protected.GET("/", handlers.AdminIndex)
			// Cards management
			protected.GET("/cards", handlers.CardsList)
			protected.GET("/cards/new", handlers.CardCreateGet)
			protected.POST("/cards", handlers.CardCreatePost)
			protected.GET("/cards/:id/edit", handlers.CardEditGet)
			protected.POST("/cards/:id", handlers.CardEditPost)
			protected.POST("/cards/:id/delete", handlers.CardDelete)
			protected.POST("/cards/:id/toggle", handlers.CardToggle)
			protected.POST("/cards/:id/move-up", handlers.CardMoveUp)
			protected.POST("/cards/:id/move-down", handlers.CardMoveDown)
			// Image upload
			protected.POST("/upload", handlers.UploadHandler)
			// Settings
			protected.GET("/settings", handlers.SettingsGet)
			protected.POST("/settings", handlers.SettingsPost)
			// Security: login logs
			protected.GET("/logs", handlers.LoginLogsGet)
			protected.POST("/logs/unban/:id", handlers.UnbanIP)
			// Security: change password
			protected.GET("/password", handlers.ChangePasswordGet)
			protected.POST("/password", handlers.ChangePasswordPost)
			// Security: IP whitelist management
			protected.GET("/ip-whitelist", handlers.IPWhitelistGet)
			protected.POST("/ip-whitelist/add", handlers.IPWhitelistAdd)
			protected.POST("/ip-whitelist/:id/delete", handlers.IPWhitelistDelete)
			// Analytics: access logs
			protected.GET("/access-logs", handlers.AccessLogsGet)
		}
	}

	// Determine port
	port := os.Getenv("PORT")
	if port == "" {
		port = "8080"
	}
	log.Printf("Starting Portal server on :%s", port)
	if err := r.Run(":" + port); err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
}