package main import ( "html/template" "log" "os" "path/filepath" "strings" "simple_portal/database" "simple_portal/handlers" "simple_portal/middleware" "simple_portal/session" "github.com/gin-gonic/gin" ) // loadTemplates loads HTML templates from templates/ directory recursively. // Custom implementation because Go's ParseGlob has issues with directories on Windows. func loadTemplates() *template.Template { funcMap := template.FuncMap{ "hasPrefix": strings.HasPrefix, "sub": func(a, b int) int { return a - b }, "add": func(a, b int) int { return a + b }, } t := template.New("").Funcs(funcMap) // 收集所有 .html 模板文件路径 var files []string filepath.Walk("templates", func(path string, info os.FileInfo, err error) error { if err != nil { return err } if !info.IsDir() && filepath.Ext(path) == ".html" { files = append(files, path) } return nil }) if len(files) == 0 { log.Fatal("No template files found in templates/") } // 将 Windows 反斜杠路径转为正斜杠,避免模板名问题 for i, f := range files { files[i] = filepath.ToSlash(f) } var terr error t, terr = t.ParseFiles(files...) if terr != nil { log.Fatalf("Failed to parse templates: %v", terr) } return t } func main() { // Initialize database if err := database.InitDB(); err != nil { log.Fatalf("Failed to initialize database: %v", err) } defer database.CloseDB() // Create uploads directory if err := os.MkdirAll(filepath.Join(".", "data", "uploads"), 0755); err != nil { log.Fatalf("Failed to create uploads directory: %v", err) } // Create session store sessionStore := session.NewSessionStore() // Create IP ban guard (in-memory fail counter) ipBanGuard := middleware.NewIPBanGuard() // Set Gin mode ginMode := os.Getenv("GIN_MODE") if ginMode == "" { gin.SetMode(gin.DebugMode) } r := gin.Default() // Load HTML templates (custom loader for nested directories) r.SetHTMLTemplate(loadTemplates()) // Serve static files r.Static("/static", "./static") // Inject session store and IP ban guard into context for handlers r.Use(func(c *gin.Context) { c.Set("sessionStore", sessionStore) c.Set("ipBanGuard", ipBanGuard) c.Next() }) // Public routes (home page and uploads — no IP restriction) r.GET("/", handlers.HomeHandler) r.GET("/click/:id", handlers.CardClickHandler) r.GET("/search", handlers.SearchHandler) r.GET("/uploads/:filename", handlers.ServeUploadHandler) // Admin routes with IP whitelist check applied to all /admin/* routes adminGroup := r.Group("/admin") adminGroup.Use(middleware.IPWhitelistRequired(func(sessionID string) bool { return sessionStore.Get(sessionID) != nil })) { // Public admin routes (login — no auth required, but IP whitelist applies) adminGroup.GET("/login", handlers.LoginGet) adminGroup.POST("/login", handlers.LoginPost) // Protected admin routes (auth required) protected := adminGroup.Group("") protected.Use(middleware.AuthRequired(sessionStore)) { protected.POST("/logout", handlers.Logout) protected.GET("/", handlers.AdminIndex) // Cards management protected.GET("/cards", handlers.CardsList) protected.GET("/cards/new", handlers.CardCreateGet) protected.POST("/cards", handlers.CardCreatePost) protected.GET("/cards/:id/edit", handlers.CardEditGet) protected.POST("/cards/:id", handlers.CardEditPost) protected.POST("/cards/:id/delete", handlers.CardDelete) protected.POST("/cards/:id/toggle", handlers.CardToggle) protected.POST("/cards/:id/move-up", handlers.CardMoveUp) protected.POST("/cards/:id/move-down", handlers.CardMoveDown) // Image upload protected.POST("/upload", handlers.UploadHandler) // Settings protected.GET("/settings", handlers.SettingsGet) protected.POST("/settings", handlers.SettingsPost) // Security: login logs protected.GET("/logs", handlers.LoginLogsGet) protected.POST("/logs/unban/:id", handlers.UnbanIP) // Security: change password protected.GET("/password", handlers.ChangePasswordGet) protected.POST("/password", handlers.ChangePasswordPost) // Security: IP whitelist management protected.GET("/ip-whitelist", handlers.IPWhitelistGet) protected.POST("/ip-whitelist/add", handlers.IPWhitelistAdd) protected.POST("/ip-whitelist/:id/delete", handlers.IPWhitelistDelete) // Analytics: access logs protected.GET("/access-logs", handlers.AccessLogsGet) } } // Determine port port := os.Getenv("PORT") if port == "" { port = "8080" } log.Printf("Starting Portal server on :%s", port) if err := r.Run(":" + port); err != nil { log.Fatalf("Failed to start server: %v", err) } }