算了,后端我自己写吧
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
)
|
||||
|
||||
// Config 全局配置
|
||||
type Config struct {
|
||||
Web WebConfig `yaml:"web"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
User UserConfig `yaml:"user"`
|
||||
File FileConfig `yaml:"file"`
|
||||
}
|
||||
|
||||
// WebConfig Web服务配置
|
||||
type WebConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port string `yaml:"port"`
|
||||
TLS bool `yaml:"tls"`
|
||||
CertPrivatePath string `yaml:"certPrivatePath"`
|
||||
CertPublicPath string `yaml:"certPublicPath"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Type string `yaml:"type"` // sqlite, mysql, postgres
|
||||
Path string `yaml:"path"` // SQLite路径
|
||||
Host string `yaml:"host"`
|
||||
Port string `yaml:"port"`
|
||||
Name string `yaml:"name"`
|
||||
User string `yaml:"user"`
|
||||
Pass string `yaml:"pass"`
|
||||
}
|
||||
|
||||
// UserConfig 用户相关配置
|
||||
type UserConfig struct {
|
||||
CookieTimeout int `yaml:"cookieTimeout"`
|
||||
PassHashType string `yaml:"passHashType"` // text, md5, md5salt
|
||||
}
|
||||
|
||||
// FileConfig 文件上传配置
|
||||
type FileConfig struct {
|
||||
MaxSize uint64 `yaml:"maxSize"`
|
||||
Paths map[string]string `yaml:"paths"`
|
||||
AllowImageMime map[string]string `yaml:"allowImageMime"`
|
||||
AllowVideoMime map[string]string `yaml:"allowVideoMime"`
|
||||
AllowMusicMime map[string]string `yaml:"allowMusicMime"`
|
||||
AllowPdfMime map[string]string `yaml:"allowPdfMime"`
|
||||
}
|
||||
|
||||
// Current 全局配置实例
|
||||
var Current *Config
|
||||
|
||||
// Load 加载配置文件
|
||||
func Load(configPath string) error {
|
||||
// 如果配置文件不存在,创建默认配置
|
||||
if !fileExists(configPath) {
|
||||
if err := createDefaultConfig(configPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 读取配置文件
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析YAML
|
||||
config := &Config{}
|
||||
if err := yaml.Unmarshal(data, config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Current = config
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// 创建默认配置文件
|
||||
func createDefaultConfig(path string) error {
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 默认配置
|
||||
defaultConfig := &Config{
|
||||
Web: WebConfig{
|
||||
Host: "127.0.0.1",
|
||||
Port: "8080",
|
||||
TLS: false,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Type: "sqlite",
|
||||
Path: "data/database.db",
|
||||
},
|
||||
User: UserConfig{
|
||||
CookieTimeout: 604800,
|
||||
PassHashType: "md5",
|
||||
},
|
||||
File: FileConfig{
|
||||
MaxSize: 52428800, // 50MB
|
||||
Paths: map[string]string{
|
||||
"avatar": "data/static/avatar/",
|
||||
"image": "data/upload/image/",
|
||||
"video": "data/upload/video/",
|
||||
"music": "data/upload/music/",
|
||||
"pdf": "data/upload/pdf/",
|
||||
"other": "data/upload/other/",
|
||||
},
|
||||
AllowImageMime: map[string]string{
|
||||
"image/jpeg": ".jpeg",
|
||||
"image/png": ".png",
|
||||
"image/gif": ".gif",
|
||||
"image/bmp": ".bmp",
|
||||
},
|
||||
AllowVideoMime: map[string]string{
|
||||
"video/mp4": ".mp4",
|
||||
"video/x-msvideo": ".avi",
|
||||
"video/quicktime": ".mov",
|
||||
"video/x-flv": ".flv",
|
||||
"video/mpeg": ".mpeg",
|
||||
},
|
||||
AllowMusicMime: map[string]string{
|
||||
"audio/mpeg": ".mpeg",
|
||||
"audio/aac": ".aac",
|
||||
"audio/wav": ".wav",
|
||||
"audio/flac": ".flac",
|
||||
},
|
||||
AllowPdfMime: map[string]string{
|
||||
"application/pdf": ".pdf",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 序列化为YAML
|
||||
data, err := yaml.Marshal(defaultConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
return os.WriteFile(path, data, 0644)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"ops/internal/config"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// DB 全局数据库实例
|
||||
var DB *gorm.DB
|
||||
|
||||
// Init 初始化数据库连接
|
||||
func Init() error {
|
||||
cfg := config.Current.Database
|
||||
|
||||
var dialector gorm.Dialector
|
||||
|
||||
switch cfg.Type {
|
||||
case "sqlite":
|
||||
dialector = sqlite.Open(cfg.Path)
|
||||
case "mysql":
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
cfg.User, cfg.Pass, cfg.Host, cfg.Port, cfg.Name)
|
||||
dialector = mysql.Open(dsn)
|
||||
case "postgres", "pg":
|
||||
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||||
cfg.Host, cfg.User, cfg.Pass, cfg.Name, cfg.Port)
|
||||
dialector = postgres.Open(dsn)
|
||||
default:
|
||||
return fmt.Errorf("不支持的数据库类型: %s", cfg.Type)
|
||||
}
|
||||
|
||||
// 配置GORM
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
var err error
|
||||
DB, err = gorm.Open(dialector, gormConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("数据库连接失败: %v", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB, err := DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxIdleConns(10)
|
||||
sqlDB.SetMaxOpenConns(100)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
log.Println("数据库连接成功")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDB 获取数据库实例
|
||||
func GetDB() *gorm.DB {
|
||||
return DB
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func Close() error {
|
||||
if DB != nil {
|
||||
sqlDB, err := DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package database
|
||||
|
||||
// AutoMigrate 自动迁移所有表
|
||||
func AutoMigrate() error {
|
||||
models := []interface{}{
|
||||
&TabUser{},
|
||||
&TabUserGroups{},
|
||||
&TabUserGroupBinds{},
|
||||
&TabUserInfo{},
|
||||
&TabCookie{},
|
||||
&TabFileInfo{},
|
||||
&APIRequestLog{},
|
||||
&TabPurchaseOrder{},
|
||||
&TabPurchaseCosts{},
|
||||
}
|
||||
|
||||
if err := DB.AutoMigrate(models...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TabUser 用户表
|
||||
type TabUser struct {
|
||||
ID uint `gorm:"primarykey;autoIncrement"`
|
||||
Name string `gorm:"type:varchar(64);uniqueIndex"`
|
||||
}
|
||||
|
||||
// TabUserGroups 用户组表
|
||||
type TabUserGroups struct {
|
||||
ID uint `gorm:"primarykey;autoIncrement"`
|
||||
Name string `gorm:"type:varchar(64);uniqueIndex"`
|
||||
}
|
||||
|
||||
// TabUserGroupBinds 用户-组绑定关系表
|
||||
type TabUserGroupBinds struct {
|
||||
UserID uint `gorm:"index"`
|
||||
GroupID uint `gorm:"index"`
|
||||
}
|
||||
|
||||
// TabUserInfo 用户详情表
|
||||
type TabUserInfo struct {
|
||||
UserID uint `gorm:"primaryKey"`
|
||||
AvatarPath string `gorm:"type:text"`
|
||||
Birthdate string `gorm:"type:varchar(16)"`
|
||||
Gender int
|
||||
Introduction string `gorm:"type:text"`
|
||||
}
|
||||
|
||||
// TabCookie Session Cookie表
|
||||
type TabCookie struct {
|
||||
Value string `gorm:"primaryKey;type:varchar(64)"`
|
||||
UserID uint `gorm:"index"`
|
||||
ExpiresAt int64
|
||||
CreateAt int64
|
||||
Remember bool
|
||||
}
|
||||
|
||||
// TabFileInfo 文件信息表
|
||||
type TabFileInfo struct {
|
||||
ID uint `gorm:"primarykey;autoIncrement"`
|
||||
Path string `gorm:"type:text"`
|
||||
Hash string `gorm:"index"`
|
||||
Size int64
|
||||
CreateTime int64
|
||||
ExtName string `gorm:"type:varchar(16)"`
|
||||
MimeType string `gorm:"type:varchar(128)"`
|
||||
StoreType int // 1=image 2=video 3=music 4=pdf 5=other
|
||||
}
|
||||
|
||||
// APIRequestLog API请求日志表
|
||||
type APIRequestLog struct {
|
||||
ID uint `gorm:"primarykey;autoIncrement"`
|
||||
Time int64 `gorm:"index"`
|
||||
IP string `gorm:"type:varchar(64)"`
|
||||
Path string `gorm:"type:varchar(255)"`
|
||||
Method string `gorm:"type:varchar(16)"`
|
||||
Status int
|
||||
UserID uint
|
||||
UserType int
|
||||
DataSize int
|
||||
}
|
||||
|
||||
// TabPurchaseOrder 采购订单表
|
||||
type TabPurchaseOrder struct {
|
||||
ID uint `gorm:"primarykey;autoIncrement"`
|
||||
Title string `gorm:"type:varchar(255)"`
|
||||
CreateTime int64 `gorm:"index"`
|
||||
CompleteTime int64
|
||||
Status int // 状态:0=进行中 1=已完成 2=已取消
|
||||
CourierNum string `gorm:"type:text"` // 快递单号
|
||||
Photos string `gorm:"type:text"` // 照片JSON数组
|
||||
Creater uint `gorm:"index"` // 创建者ID
|
||||
Remark string `gorm:"type:text"` // 备注
|
||||
}
|
||||
|
||||
// TabPurchaseCosts 采购费用明细表
|
||||
type TabPurchaseCosts struct {
|
||||
ID uint `gorm:"primarykey;autoIncrement"`
|
||||
OrderID uint `gorm:"index"`
|
||||
Name string `gorm:"type:varchar(255)"`
|
||||
PricePerUnit string `gorm:"type:varchar(32)"`
|
||||
Quantity string `gorm:"type:varchar(32)"`
|
||||
Unit string `gorm:"type:varchar(32)"`
|
||||
}
|
||||
@@ -0,0 +1,345 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ops/internal/service"
|
||||
"ops/pkg/response"
|
||||
)
|
||||
|
||||
// AuthHandler 用户认证处理器
|
||||
type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
validate *validator.Validate
|
||||
}
|
||||
|
||||
// LoginRequest 登录请求结构
|
||||
type LoginRequest struct {
|
||||
Name string `json:"name" binding:"required,min=3,max=50"`
|
||||
Password string `json:"password" binding:"required,min=6,max=50"`
|
||||
DeviceID string `json:"deviceID"`
|
||||
IP string `json:"ip"`
|
||||
Remember string `json:"remember"`
|
||||
}
|
||||
|
||||
// LoginResponse 登录响应结构
|
||||
type LoginResponse struct {
|
||||
UserID uint `json:"userID"`
|
||||
Name string `json:"name"`
|
||||
AvatarURL string `json:"avatarURL"`
|
||||
CookieValue string `json:"cookieValue"`
|
||||
CookieExpireDate string `json:"cookieExpireDate"`
|
||||
}
|
||||
|
||||
// RegisterRequest 注册请求结构
|
||||
type RegisterRequest struct {
|
||||
Name string `json:"name" binding:"required,min=3,max=50"`
|
||||
Password string `json:"password" binding:"required,min=6,max=50"`
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Phone string `json:"phone" binding:"omitempty,len=11"`
|
||||
}
|
||||
|
||||
// RegisterResponse 注册响应结构
|
||||
type RegisterResponse struct {
|
||||
UserID uint `json:"userID"`
|
||||
Name string `json:"name"`
|
||||
CookieValue string `json:"cookieValue"`
|
||||
}
|
||||
|
||||
// ForgotPasswordRequest 忘记密码请求
|
||||
type ForgotPasswordRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Email string `json:"email" binding:"omitempty,email"`
|
||||
Phone string `json:"phone" binding:"omitempty,len=11"`
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
NewPassword string `json:"newPassword" binding:"required,min=6,max=50"`
|
||||
}
|
||||
|
||||
// LogoutRequest 退出登录请求
|
||||
type LogoutRequest struct {
|
||||
CookieValue string `json:"cookieValue" binding:"required"`
|
||||
DeviceID string `json:"deviceID"`
|
||||
}
|
||||
|
||||
// NewAuthHandler 创建认证处理器
|
||||
func NewAuthHandler(db *gorm.DB) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
authService: service.NewAuthService(db),
|
||||
validate: validator.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// UserLogin 用户登录
|
||||
func (h *AuthHandler) UserLogin(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
|
||||
// 绑定和验证请求
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request format")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求参数
|
||||
if err := h.validate.Struct(req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
user, cookie, err := h.authService.Login(req.Name, req.Password, req.DeviceID, req.IP, req.Remember)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
response.Error(c, "-5", "User not found")
|
||||
return
|
||||
}
|
||||
if strings.Contains(err.Error(), "password") {
|
||||
response.Error(c, "-42", "Invalid password")
|
||||
return
|
||||
}
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
resp := LoginResponse{
|
||||
UserID: user.UserID,
|
||||
Name: user.Name,
|
||||
AvatarURL: user.AvatarURL,
|
||||
CookieValue: cookie.Value,
|
||||
CookieExpireDate: cookie.ExpireDate.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// UserRegister 用户注册
|
||||
func (h *AuthHandler) UserRegister(c *gin.Context) {
|
||||
var req RegisterRequest
|
||||
|
||||
// 绑定和验证请求
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request format")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求参数
|
||||
if err := h.validate.Struct(req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
user, cookie, err := h.authService.Register(req.Name, req.Password, req.Email, req.Phone)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "duplicate") || strings.Contains(err.Error(), "unique") {
|
||||
response.Error(c, "-4", "Username already exists")
|
||||
return
|
||||
}
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
resp := RegisterResponse{
|
||||
UserID: user.UserID,
|
||||
Name: user.Name,
|
||||
CookieValue: cookie.Value,
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// UserForgotPassword 忘记密码
|
||||
func (h *AuthHandler) UserForgotPassword(c *gin.Context) {
|
||||
var req ForgotPasswordRequest
|
||||
|
||||
// 绑定和验证请求
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request format")
|
||||
return
|
||||
}
|
||||
|
||||
// 至少需要邮箱或手机号之一
|
||||
if req.Email == "" && req.Phone == "" {
|
||||
response.BadRequest(c, "Email or phone number is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求参数
|
||||
if err := h.validate.Struct(req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
token, err := h.authService.ForgotPassword(req.Name, req.Email, req.Phone)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
response.Error(c, "-5", "User not found")
|
||||
return
|
||||
}
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
response.Success(c, gin.H{
|
||||
"resetToken": token,
|
||||
"message": "Password reset instructions have been sent",
|
||||
})
|
||||
}
|
||||
|
||||
// UserResetPassword 重置密码
|
||||
func (h *AuthHandler) UserResetPassword(c *gin.Context) {
|
||||
var req ResetPasswordRequest
|
||||
|
||||
// 绑定和验证请求
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request format")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求参数
|
||||
if err := h.validate.Struct(req); err != nil {
|
||||
response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
err := h.authService.ResetPassword(req.Token, req.NewPassword)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "invalid") || strings.Contains(err.Error(), "expired") {
|
||||
response.Error(c, "-2", "Reset token is invalid or expired")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
response.Error(c, "-5", "User not found")
|
||||
return
|
||||
}
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Password has been reset successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// UserLogout 用户退出登录
|
||||
func (h *AuthHandler) UserLogout(c *gin.Context) {
|
||||
var req LogoutRequest
|
||||
|
||||
// 从认证中间件获取cookie值
|
||||
cookieValue := getCookieFromContext(c)
|
||||
if cookieValue == "" {
|
||||
// 尝试从请求body获取
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request format")
|
||||
return
|
||||
}
|
||||
cookieValue = req.CookieValue
|
||||
}
|
||||
|
||||
if cookieValue == "" {
|
||||
response.BadRequest(c, "Cookie value is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 从请求中获取设备ID
|
||||
deviceID := c.GetHeader("X-Device-ID")
|
||||
if deviceID == "" && req.DeviceID != "" {
|
||||
deviceID = req.DeviceID
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
err := h.authService.Logout(cookieValue, deviceID)
|
||||
if err != nil {
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"message": "Logged out successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// UserProfile 获取用户信息
|
||||
func (h *AuthHandler) UserProfile(c *gin.Context) {
|
||||
// 从认证中间件获取用户ID或名称
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务层
|
||||
user, err := h.authService.GetProfile(userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
response.Error(c, "-5", "User not found")
|
||||
return
|
||||
}
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
// UserUpdateProfile 更新用户信息
|
||||
func (h *AuthHandler) UserUpdateProfile(c *gin.Context) {
|
||||
// 从认证中间件获取用户ID
|
||||
userID := getUserIDFromContext(c)
|
||||
if userID == 0 {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析更新请求
|
||||
var updateData map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&updateData); err != nil {
|
||||
response.BadRequest(c, "Invalid request format")
|
||||
return
|
||||
}
|
||||
|
||||
// 禁止更新某些字段
|
||||
delete(updateData, "id")
|
||||
delete(updateData, "name")
|
||||
delete(updateData, "password")
|
||||
delete(updateData, "createdAt")
|
||||
|
||||
// 调用服务层
|
||||
user, err := h.authService.UpdateProfile(userID, updateData)
|
||||
if err != nil {
|
||||
response.InternalError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, user)
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func getCookieFromContext(c *gin.Context) string {
|
||||
if cookie, exists := c.Get("userCookieValue"); exists && cookie != "" {
|
||||
return cookie.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getUserIDFromContext(c *gin.Context) uint {
|
||||
if userID, exists := c.Get("userID"); exists {
|
||||
if id, ok := userID.(uint); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -0,0 +1,241 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ops/internal/service"
|
||||
"ops/pkg/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type FileHandler struct {
|
||||
service service.FileService
|
||||
}
|
||||
|
||||
func NewFileHandler(db *gorm.DB) *FileHandler {
|
||||
return &FileHandler{
|
||||
service: service.NewFileService(db),
|
||||
}
|
||||
}
|
||||
|
||||
// UploadFile 上传文件
|
||||
// @Summary 上传文件
|
||||
// @Description 上传文件到服务器,支持图片、文档等多种类型
|
||||
// @Tags 文件管理
|
||||
// @Accept multipart/form-data
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param file formData file true "文件内容"
|
||||
// @Param type formData string false "文件类型" default(image)
|
||||
// @Param description formData string false "文件描述"
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 400 {object} response.StandardResponse "参数错误"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Failure 413 {object} response.StandardResponse "文件过大"
|
||||
// @Failure 415 {object} response.StandardResponse "文件类型不支持"
|
||||
// @Failure 500 {object} response.StandardResponse "服务器错误"
|
||||
// @Router /api/v1/files/upload [post]
|
||||
func (h *FileHandler) UploadFile(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件类型参数
|
||||
fileType := c.PostForm("type")
|
||||
if fileType == "" {
|
||||
fileType = "image" // 默认类型为图片
|
||||
}
|
||||
|
||||
// 获取文件描述
|
||||
description := c.PostForm("description")
|
||||
|
||||
// 获取上传的文件
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
response.BadRequest(c, "请选择要上传的文件")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service上传文件
|
||||
uploadResponse, success := h.service.UploadFile(c, userID.(uint), file, fileType, description)
|
||||
if !success {
|
||||
response.BadRequest(c, "文件上传失败,请检查文件格式和大小")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, uploadResponse)
|
||||
}
|
||||
|
||||
// GetFileList 获取文件列表
|
||||
// @Summary 获取文件列表
|
||||
// @Description 获取当前用户上传的文件列表
|
||||
// @Tags 文件管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param type query string false "文件类型过滤"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param entries query int false "每页数量" default(20)
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 400 {object} response.StandardResponse "参数错误"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Router /api/v1/files/list [get]
|
||||
func (h *FileHandler) GetFileList(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
fileType := c.Query("type")
|
||||
page := GetIntParam(c, "page", 1)
|
||||
entries := GetIntParam(c, "entries", 20)
|
||||
|
||||
// 调用Service获取文件列表
|
||||
fileListResponse, success := h.service.GetFileList(userID.(uint), fileType, page, entries)
|
||||
if !success {
|
||||
response.BadRequest(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, fileListResponse)
|
||||
}
|
||||
|
||||
// GetFileByID 获取文件信息
|
||||
// @Summary 获取文件信息
|
||||
// @Description 根据文件ID获取文件详细信息
|
||||
// @Tags 文件管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param id path int true "文件ID"
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Failure 404 {object} response.StandardResponse "文件不存在"
|
||||
// @Router /api/v1/files/{id} [get]
|
||||
func (h *FileHandler) GetFileByID(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件ID
|
||||
fileID := GetUintParam(c, "id")
|
||||
if fileID == 0 {
|
||||
response.BadRequest(c, "文件ID无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service获取文件信息
|
||||
file, success := h.service.GetFileByID(fileID, userID.(uint))
|
||||
if !success {
|
||||
response.Error(c, "-100", "文件不存在或无权限访问")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"file_id": file.ID,
|
||||
"name": file.Name,
|
||||
"sha256": file.Sha256,
|
||||
"mime": file.Mime,
|
||||
"type": file.Type,
|
||||
"size": file.Const, // 注意:这里const字段实际上存储的是使用次数,需要确认实际字段
|
||||
"created_at": file.Date.Format("2006-01-02T15:04:05Z"),
|
||||
"path": file.Path,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件
|
||||
// @Summary 删除文件
|
||||
// @Description 删除用户上传的文件
|
||||
// @Tags 文件管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param id path int true "文件ID"
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Failure 404 {object} response.StandardResponse "文件不存在"
|
||||
// @Router /api/v1/files/{id} [delete]
|
||||
func (h *FileHandler) DeleteFile(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件ID
|
||||
fileID := GetUintParam(c, "id")
|
||||
if fileID == 0 {
|
||||
response.BadRequest(c, "文件ID无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service删除文件
|
||||
success := h.service.DeleteFile(fileID, userID.(uint))
|
||||
if !success {
|
||||
response.Error(c, "-100", "文件删除失败,文件不存在或无权限")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "文件删除成功"})
|
||||
}
|
||||
|
||||
// DownloadFile 下载文件
|
||||
// @Summary 下载文件
|
||||
// @Description 下载文件内容(直接下载)
|
||||
// @Tags 文件管理
|
||||
// @Accept json
|
||||
// @Produce application/octet-stream
|
||||
// @Param hash path string true "文件SHA256哈希值"
|
||||
// @Success 200 {file} binary "文件内容"
|
||||
// @Failure 404 {object} response.StandardResponse "文件不存在"
|
||||
// @Router /api/v1/files/download/{hash} [get]
|
||||
func (h *FileHandler) DownloadFile(c *gin.Context) {
|
||||
hash := c.Param("hash")
|
||||
if hash == "" {
|
||||
response.BadRequest(c, "文件哈希值无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service下载文件
|
||||
success := h.service.DownloadFile(c, hash, true)
|
||||
if !success {
|
||||
response.Error(c, "-100", "文件不存在")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// GetFile 获取文件(预览)
|
||||
// @Summary 获取文件(预览)
|
||||
// @Description 获取文件内容(浏览器预览)
|
||||
// @Tags 文件管理
|
||||
// @Accept json
|
||||
// @Produce *
|
||||
// @Param hash path string true "文件SHA256哈希值"
|
||||
// @Success 200 {file} binary "文件内容"
|
||||
// @Failure 404 {object} response.StandardResponse "文件不存在"
|
||||
// @Router /api/v1/files/get/{hash} [get]
|
||||
func (h *FileHandler) GetFile(c *gin.Context) {
|
||||
hash := c.Param("hash")
|
||||
if hash == "" {
|
||||
response.BadRequest(c, "文件哈希值无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service获取文件(预览模式)
|
||||
success := h.service.DownloadFile(c, hash, false)
|
||||
if !success {
|
||||
response.Error(c, "-100", "文件不存在")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,155 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"ops/internal/service"
|
||||
"ops/pkg/response"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PurchaseHandler struct {
|
||||
service service.PurchaseService
|
||||
}
|
||||
|
||||
func NewPurchaseHandler(db *gorm.DB) *PurchaseHandler {
|
||||
return &PurchaseHandler{
|
||||
service: service.NewPurchaseService(db),
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrders 获取采购订单列表
|
||||
// @Summary 获取采购订单列表
|
||||
// @Description 获取用户采购订单列表,支持搜索和分页
|
||||
// @Tags 采购管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param search query string false "搜索关键词"
|
||||
// @Param page query int true "页码" default(1)
|
||||
// @Param entries query int true "每页数量" default(20)
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 400 {object} response.StandardResponse "参数错误"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Failure 500 {object} response.StandardResponse "服务器错误"
|
||||
// @Router /api/v1/purchase/orders [get]
|
||||
func (h *PurchaseHandler) GetOrders(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
search := c.Query("search")
|
||||
page := GetIntParam(c, "page", 1)
|
||||
entries := GetIntParam(c, "entries", 20)
|
||||
|
||||
// 调用Service
|
||||
result, success := h.service.GetOrders(c, userID.(uint), search, page, entries)
|
||||
if !success {
|
||||
response.BadRequest(c, "参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, result)
|
||||
}
|
||||
|
||||
// CreateOrder 创建采购订单
|
||||
// @Summary 创建采购订单
|
||||
// @Description 创建新的采购订单
|
||||
// @Tags 采购管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param request body service.CreateOrderRequest true "订单信息"
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 400 {object} response.StandardResponse "参数错误"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Failure 500 {object} response.StandardResponse "服务器错误"
|
||||
// @Router /api/v1/purchase/orders [post]
|
||||
func (h *PurchaseHandler) CreateOrder(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析请求体
|
||||
var request service.CreateOrderRequest
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
var validationErrors []string
|
||||
for _, err := range err.(validator.ValidationErrors) {
|
||||
validationErrors = append(validationErrors, err.Field()+" "+err.Tag())
|
||||
}
|
||||
if len(validationErrors) > 0 {
|
||||
response.BadRequest(c, "参数错误: "+validationErrors[0])
|
||||
} else {
|
||||
response.BadRequest(c, "请求格式错误")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service
|
||||
success := h.service.CreateOrder(c, userID.(uint), request)
|
||||
if !success {
|
||||
response.BadRequest(c, "创建订单失败,请检查数据")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "订单创建成功"})
|
||||
}
|
||||
|
||||
// GetOrderDetails 获取订单详情
|
||||
// @Summary 获取订单详情
|
||||
// @Description 获取采购订单的详细信息
|
||||
// @Tags 采购管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param userID header string false "用户ID" default("")
|
||||
// @Param id path int true "订单ID"
|
||||
// @Success 200 {object} response.StandardResponse "成功"
|
||||
// @Failure 401 {object} response.StandardResponse "未授权"
|
||||
// @Failure 404 {object} response.StandardResponse "订单不存在"
|
||||
// @Failure 500 {object} response.StandardResponse "服务器错误"
|
||||
// @Router /api/v1/purchase/orders/{id} [get]
|
||||
func (h *PurchaseHandler) GetOrderDetails(c *gin.Context) {
|
||||
// 从上下文中获取用户ID
|
||||
userID, exists := c.Get("userID")
|
||||
if !exists {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// 获取订单ID
|
||||
orderID := GetUintParam(c, "id")
|
||||
if orderID == 0 {
|
||||
response.BadRequest(c, "订单ID无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用Service
|
||||
order, costs, err := h.service.GetOrderDetails(orderID)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
response.Error(c, "-5", "订单不存在")
|
||||
} else {
|
||||
response.InternalError(c, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 检查订单所属用户
|
||||
if order.UserID != userID.(uint) {
|
||||
response.Unauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{
|
||||
"order": order,
|
||||
"costs": costs,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// GetIntParam 获取整数参数
|
||||
func GetIntParam(c *gin.Context, key string, defaultValue int) int {
|
||||
value := c.Query(key)
|
||||
if value == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
intValue, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return intValue
|
||||
}
|
||||
|
||||
// GetUintParam 获取uint参数
|
||||
func GetUintParam(c *gin.Context, key string) uint {
|
||||
value := c.Param(key)
|
||||
if value == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
intValue, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return uint(intValue)
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// AuthToken 认证令牌中间件
|
||||
// 兼容现有的 userCookieValue 字段
|
||||
func AuthToken() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 尝试从请求头获取认证
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
|
||||
// 如果没有Authorization头,尝试从POST数据中获取
|
||||
if authHeader == "" && c.Request.Method == http.MethodPost {
|
||||
var requestData map[string]interface{}
|
||||
|
||||
// 尝试解析JSON body
|
||||
if c.Request.Body != nil && c.Request.ContentLength > 0 {
|
||||
// 先读取请求体内容
|
||||
requestBody, err := io.ReadAll(c.Request.Body)
|
||||
if err == nil {
|
||||
// 重置body以便后续使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
|
||||
// 尝试解析JSON
|
||||
if err := c.ShouldBindJSON(&requestData); err == nil {
|
||||
if cookieValue, ok := requestData["userCookieValue"].(string); ok && cookieValue != "" {
|
||||
c.Set("userCookieValue", cookieValue)
|
||||
c.Set("authMethod", "cookie_value")
|
||||
c.Set("authValid", true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
// 如果JSON解析失败,重置body并继续
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试从表单数据获取
|
||||
if cookieValue := c.PostForm("userCookieValue"); cookieValue != "" {
|
||||
c.Set("userCookieValue", cookieValue)
|
||||
c.Set("authMethod", "cookie_value")
|
||||
c.Set("authValid", true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Bearer token 认证
|
||||
if authHeader != "" && len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
token := authHeader[7:]
|
||||
c.Set("authToken", token)
|
||||
c.Set("authMethod", "bearer_token")
|
||||
c.Set("authValid", true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查URL查询参数中的cookie
|
||||
if cookieValue := c.Query("userCookieValue"); cookieValue != "" {
|
||||
c.Set("userCookieValue", cookieValue)
|
||||
c.Set("authMethod", "cookie_query")
|
||||
c.Set("authValid", true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 验证失败
|
||||
c.Set("authValid", false)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRequired 需要认证的中间件
|
||||
// 如果用户未认证,返回401错误
|
||||
func AuthRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 先运行认证中间件
|
||||
authMiddleware := AuthToken()
|
||||
authMiddleware(c)
|
||||
|
||||
// 检查认证结果
|
||||
if authValid, exists := c.Get("authValid"); !exists || !authValid.(bool) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": "401",
|
||||
"message": "Authentication required",
|
||||
"data": nil,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// AdminRequired 需要管理员权限的中间件
|
||||
func AdminRequired() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 先进行基础认证
|
||||
AuthRequired()(c)
|
||||
|
||||
// 如果请求被中止(认证失败),直接返回
|
||||
if c.IsAborted() {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: 检查用户是否为管理员
|
||||
// 暂时允许所有已认证用户
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthMethod 获取认证方法
|
||||
func GetAuthMethod(c *gin.Context) string {
|
||||
if method, exists := c.Get("authMethod"); exists {
|
||||
return method.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCookieValue 获取用户cookie值
|
||||
func GetCookieValue(c *gin.Context) string {
|
||||
if cookie, exists := c.Get("userCookieValue"); exists {
|
||||
return cookie.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetAuthToken 获取Bearer token
|
||||
func GetAuthToken(c *gin.Context) string {
|
||||
if token, exists := c.Get("authToken"); exists {
|
||||
return token.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORS 跨域资源共享中间件
|
||||
func CORS() gin.HandlerFunc {
|
||||
return cors.New(cors.Config{
|
||||
// 允许所有来源(生产环境应指定具体域名)
|
||||
AllowOrigins: []string{"*"},
|
||||
|
||||
// 允许的方法
|
||||
AllowMethods: []string{
|
||||
"GET",
|
||||
"POST",
|
||||
"PUT",
|
||||
"DELETE",
|
||||
"PATCH",
|
||||
"OPTIONS",
|
||||
},
|
||||
|
||||
// 允许的请求头
|
||||
AllowHeaders: []string{
|
||||
"Origin",
|
||||
"Content-Type",
|
||||
"Content-Length",
|
||||
"Accept-Encoding",
|
||||
"X-CSRF-Token",
|
||||
"Authorization",
|
||||
"X-Request-ID",
|
||||
"X-Requested-With",
|
||||
"Accept",
|
||||
"Cache-Control",
|
||||
// 自定义头
|
||||
"User-Cookie-Value", // 兼容现有系统
|
||||
},
|
||||
|
||||
// 暴露的响应头
|
||||
ExposeHeaders: []string{
|
||||
"Content-Length",
|
||||
"Authorization",
|
||||
"X-Request-ID",
|
||||
"Content-Disposition",
|
||||
},
|
||||
|
||||
// 是否允许携带凭证
|
||||
AllowCredentials: true,
|
||||
|
||||
// 预检请求缓存时间(秒)
|
||||
MaxAge: 12 * time.Hour,
|
||||
|
||||
// 允许读取自定义头
|
||||
AllowPrivateNetwork: true,
|
||||
})
|
||||
}
|
||||
|
||||
// CORSMiddleware 简化的CORS中间件(兼容老版本)
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
origin = "*"
|
||||
}
|
||||
|
||||
// 设置CORS头
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, X-Request-ID, X-Requested-With, Accept, Cache-Control, User-Cookie-Value")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,254 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LogResponseWriter 自定义ResponseWriter以捕获响应内容
|
||||
type LogResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
}
|
||||
|
||||
func (w *LogResponseWriter) Write(b []byte) (int, error) {
|
||||
if w.body != nil {
|
||||
w.body.Write(b)
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func (w *LogResponseWriter) WriteString(s string) (int, error) {
|
||||
if w.body != nil {
|
||||
w.body.WriteString(s)
|
||||
}
|
||||
return w.ResponseWriter.WriteString(s)
|
||||
}
|
||||
|
||||
// Logger 请求日志中间件
|
||||
func Logger(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 请求方法
|
||||
httpMethod := c.Request.Method
|
||||
|
||||
// 请求路径
|
||||
reqUri := c.Request.RequestURI
|
||||
|
||||
// 客户端IP
|
||||
clientIP := c.ClientIP()
|
||||
|
||||
// 用户代理
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
c.Set("requestID", requestID)
|
||||
} else {
|
||||
c.Set("requestID", requestID)
|
||||
}
|
||||
|
||||
// 记录原始请求体(如果不是文件上传等大请求)
|
||||
var requestBody []byte
|
||||
if c.Request.ContentLength > 0 && c.Request.ContentLength < 1024*1024 && // 1MB限制
|
||||
c.Request.Header.Get("Content-Type") != "multipart/form-data" {
|
||||
// 读取请求体
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err == nil {
|
||||
requestBody = bodyBytes
|
||||
// 重置请求体以便后续使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 尝试解析JSON
|
||||
var jsonBody interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &jsonBody); err == nil {
|
||||
// 敏感信息过滤(如密码)
|
||||
if m, ok := jsonBody.(map[string]interface{}); ok {
|
||||
if _, exists := m["password"]; exists {
|
||||
m["password"] = "***REDACTED***"
|
||||
}
|
||||
if _, exists := m["oldPassword"]; exists {
|
||||
m["oldPassword"] = "***REDACTED***"
|
||||
}
|
||||
if _, exists := m["newPassword"]; exists {
|
||||
m["newPassword"] = "***REDACTED***"
|
||||
}
|
||||
if _, exists := m["confirmPassword"]; exists {
|
||||
m["confirmPassword"] = "***REDACTED***"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 包装ResponseWriter以捕获响应
|
||||
blw := &LogResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
body: bytes.NewBufferString(""),
|
||||
}
|
||||
c.Writer = blw
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 结束时间
|
||||
endTime := time.Now()
|
||||
|
||||
// 执行时间
|
||||
latency := endTime.Sub(startTime)
|
||||
|
||||
// 响应状态码
|
||||
statusCode := c.Writer.Status()
|
||||
|
||||
// 错误信息
|
||||
errors := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
if errors == "" {
|
||||
errors = c.Errors.ByType(gin.ErrorTypePublic).String()
|
||||
}
|
||||
|
||||
// 响应体(如果不是文件等大型响应)
|
||||
var responseBody interface{}
|
||||
var responseMap map[string]interface{}
|
||||
if blw.body != nil && blw.body.Len() > 0 && blw.body.Len() < 10000 { // 10KB限制
|
||||
bodyBytes := blw.body.Bytes()
|
||||
if err := json.Unmarshal(bodyBytes, &responseMap); err == nil {
|
||||
responseBody = responseMap
|
||||
} else {
|
||||
responseBody = string(bodyBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// 根据状态码决定日志级别
|
||||
fields := []zap.Field{
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("method", httpMethod),
|
||||
zap.String("uri", reqUri),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("user_agent", userAgent),
|
||||
zap.Int("status", statusCode),
|
||||
zap.Duration("latency", latency),
|
||||
}
|
||||
|
||||
// 添加请求体(如果存在且不是太大)
|
||||
if len(requestBody) > 0 && len(requestBody) < 10000 {
|
||||
var reqBody interface{}
|
||||
if err := json.Unmarshal(requestBody, &reqBody); err == nil {
|
||||
fields = append(fields, zap.Any("request_body", reqBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 添加响应体(如果存在且不是太大)
|
||||
if responseBody != nil {
|
||||
fields = append(fields, zap.Any("response_body", responseBody))
|
||||
}
|
||||
|
||||
// 添加错误信息
|
||||
if errors != "" {
|
||||
fields = append(fields, zap.String("error", errors))
|
||||
}
|
||||
|
||||
// 获取用户标识(如果有)
|
||||
if cookieValue := GetCookieValue(c); cookieValue != "" {
|
||||
fields = append(fields, zap.String("auth_cookie_truncated", truncateString(cookieValue, 8)))
|
||||
}
|
||||
if authToken := GetAuthToken(c); authToken != "" {
|
||||
fields = append(fields, zap.String("auth_token_truncated", truncateString(authToken, 8)))
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
logFunc := logger.Info
|
||||
if statusCode >= 400 && statusCode < 500 {
|
||||
logFunc = logger.Warn
|
||||
} else if statusCode >= 500 {
|
||||
logFunc = logger.Error
|
||||
}
|
||||
|
||||
logFunc("HTTP request", fields...)
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery 恢复中间件
|
||||
func Recovery(logger *zap.Logger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// 获取请求ID
|
||||
requestID, _ := c.Get("requestID")
|
||||
|
||||
// 记录Panic
|
||||
logger.Error("HTTP panic recovered",
|
||||
zap.Any("error", err),
|
||||
zap.String("request_id", requestID.(string)),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("uri", c.Request.RequestURI),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
)
|
||||
|
||||
// 返回500错误
|
||||
c.JSON(500, gin.H{
|
||||
"code": "500",
|
||||
"message": "Internal server error",
|
||||
"data": nil,
|
||||
})
|
||||
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func generateRequestID() string {
|
||||
return time.Now().Format("20060102150405") + "-" + shortRandString()
|
||||
}
|
||||
|
||||
func truncateString(s string, length int) string {
|
||||
if len(s) <= length {
|
||||
return s
|
||||
}
|
||||
return s[:length] + "..."
|
||||
}
|
||||
|
||||
func shortRandString() string {
|
||||
// 简化的随机字符串生成
|
||||
return time.Now().Format("150405")
|
||||
}
|
||||
|
||||
// SimpleLogger 简易日志中间件(用于开发和测试)
|
||||
func SimpleLogger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录请求信息
|
||||
latency := time.Since(start)
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
statusCode := c.Writer.Status()
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// 输出到控制台
|
||||
fmt.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s\n",
|
||||
time.Now().Format("2006/01/02 - 15:04:05"),
|
||||
statusCode,
|
||||
latency,
|
||||
clientIP,
|
||||
method,
|
||||
path,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"ops/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type FileRepository interface {
|
||||
CreateFile(file *models.TabFileInfo_) error
|
||||
GetFileByID(fileID uint) (*models.TabFileInfo_, error)
|
||||
GetFileByHash(hash string) (*models.TabFileInfo_, error)
|
||||
GetFilesByUser(userID uint, fileType string, page, entries int) ([]models.TabFileInfo_, int64, error)
|
||||
UpdateFile(file *models.TabFileInfo_) error
|
||||
DeleteFile(fileID uint) error
|
||||
IncrementFileUsage(fileID uint) error
|
||||
GetFilesByType(fileType string, limit int) ([]models.TabFileInfo_, error)
|
||||
}
|
||||
|
||||
type fileRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewFileRepository(db *gorm.DB) FileRepository {
|
||||
return &fileRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *fileRepository) CreateFile(file *models.TabFileInfo_) error {
|
||||
return r.db.Create(file).Error
|
||||
}
|
||||
|
||||
func (r *fileRepository) GetFileByID(fileID uint) (*models.TabFileInfo_, error) {
|
||||
var file models.TabFileInfo_
|
||||
if err := r.db.First(&file, fileID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
func (r *fileRepository) GetFileByHash(hash string) (*models.TabFileInfo_, error) {
|
||||
var file models.TabFileInfo_
|
||||
if err := r.db.Where("sha256 = ?", hash).First(&file).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &file, nil
|
||||
}
|
||||
|
||||
func (r *fileRepository) GetFilesByUser(userID uint, fileType string, page, entries int) ([]models.TabFileInfo_, int64, error) {
|
||||
var files []models.TabFileInfo_
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&models.TabFileInfo_{}).Where("user_id = ?", userID)
|
||||
|
||||
if fileType != "" {
|
||||
query = query.Where("type = ?", fileType)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := entries * (page - 1)
|
||||
if err := query.Order("date DESC").Offset(offset).Limit(entries).Find(&files).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return files, total, nil
|
||||
}
|
||||
|
||||
func (r *fileRepository) UpdateFile(file *models.TabFileInfo_) error {
|
||||
return r.db.Save(file).Error
|
||||
}
|
||||
|
||||
func (r *fileRepository) DeleteFile(fileID uint) error {
|
||||
return r.db.Delete(&models.TabFileInfo_{}, fileID).Error
|
||||
}
|
||||
|
||||
func (r *fileRepository) IncrementFileUsage(fileID uint) error {
|
||||
return r.db.Model(&models.TabFileInfo_{}).
|
||||
Where("id = ?", fileID).
|
||||
Update("const", gorm.Expr("const + ?", 1)).Error
|
||||
}
|
||||
|
||||
func (r *fileRepository) GetFilesByType(fileType string, limit int) ([]models.TabFileInfo_, error) {
|
||||
var files []models.TabFileInfo_
|
||||
query := r.db.Where("type = ?", fileType).Order("const DESC")
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
if err := query.Find(&files).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return files, nil
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"ops/models"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PurchaseRepository interface {
|
||||
GetOrders(userID uint, search string, page, entries int) ([]models.TabPurchaseOrder, int64, error)
|
||||
GetOrderByID(orderID uint) (*models.TabPurchaseOrder, error)
|
||||
CreateOrder(order *models.TabPurchaseOrder) error
|
||||
CreateCost(cost *models.TabPurchaseCosts) error
|
||||
GetOrderCosts(orderID uint) ([]models.TabPurchaseCosts, error)
|
||||
}
|
||||
|
||||
type purchaseRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewPurchaseRepository(db *gorm.DB) PurchaseRepository {
|
||||
return &purchaseRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *purchaseRepository) GetOrders(userID uint, search string, page, entries int) ([]models.TabPurchaseOrder, int64, error) {
|
||||
var orders []models.TabPurchaseOrder
|
||||
var total int64
|
||||
|
||||
query := r.db.Model(&models.TabPurchaseOrder{}).Where("user_id = ?", userID)
|
||||
|
||||
if search != "" {
|
||||
query = query.Where("title LIKE ? OR part_name LIKE ? OR remark LIKE ? OR tracking_number LIKE ?",
|
||||
"%"+search+"%", "%"+search+"%", "%"+search+"%", "%"+search+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := entries * (page - 1)
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(entries).Find(&orders).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return orders, total, nil
|
||||
}
|
||||
|
||||
func (r *purchaseRepository) GetOrderByID(orderID uint) (*models.TabPurchaseOrder, error) {
|
||||
var order models.TabPurchaseOrder
|
||||
if err := r.db.First(&order, orderID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *purchaseRepository) CreateOrder(order *models.TabPurchaseOrder) error {
|
||||
return r.db.Create(order).Error
|
||||
}
|
||||
|
||||
func (r *purchaseRepository) CreateCost(cost *models.TabPurchaseCosts) error {
|
||||
return r.db.Create(cost).Error
|
||||
}
|
||||
|
||||
func (r *purchaseRepository) GetOrderCosts(orderID uint) ([]models.TabPurchaseCosts, error) {
|
||||
var costs []models.TabPurchaseCosts
|
||||
if err := r.db.Where("order_id = ?", orderID).Find(&costs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return costs, nil
|
||||
}
|
||||
@@ -0,0 +1,347 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ops/internal/database"
|
||||
)
|
||||
|
||||
// UserRepository 用户数据访问接口
|
||||
type UserRepository interface {
|
||||
Create(user *database.TabUser) error
|
||||
FindByID(id uint) (*database.TabUser, error)
|
||||
FindByName(name string) (*database.TabUser, error)
|
||||
FindByEmail(email string) (*database.TabUser, error)
|
||||
FindByPhone(phone string) (*database.TabUser, error)
|
||||
Update(user *database.TabUser) error
|
||||
Delete(id uint) error
|
||||
ExistsByName(name string) (bool, error)
|
||||
}
|
||||
|
||||
// UserInfoRepository 用户信息数据访问接口
|
||||
type UserInfoRepository interface {
|
||||
Create(userInfo *database.TabUserInfo) error
|
||||
FindByUserID(userID uint) (*database.TabUserInfo, error)
|
||||
Update(userInfo *database.TabUserInfo) error
|
||||
Delete(userID uint) error
|
||||
}
|
||||
|
||||
// CookieRepository Cookie数据访问接口
|
||||
type CookieRepository interface {
|
||||
Create(cookie *database.TabCookie) error
|
||||
FindByValue(cookieValue string) (*database.TabCookie, error)
|
||||
FindByUserID(userID uint) ([]*database.TabCookie, error)
|
||||
DeleteByValue(cookieValue string) error
|
||||
DeleteByUserID(userID uint) error
|
||||
DeleteExpired() error
|
||||
}
|
||||
|
||||
// userRepo 用户仓库实现
|
||||
type userRepo struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建用户仓库实例
|
||||
func NewUserRepository(db *gorm.DB) UserRepository {
|
||||
return &userRepo{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *userRepo) Create(user *database.TabUser) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
|
||||
if user.Name == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
|
||||
return r.db.Create(user).Error
|
||||
}
|
||||
|
||||
// FindByID 通过ID查找用户
|
||||
func (r *userRepo) FindByID(id uint) (*database.TabUser, error) {
|
||||
if id == 0 {
|
||||
return nil, errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
var user database.TabUser
|
||||
err := r.db.First(&user, id).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindByName 通过用户名查找用户
|
||||
func (r *userRepo) FindByName(name string) (*database.TabUser, error) {
|
||||
if name == "" {
|
||||
return nil, errors.New("username is required")
|
||||
}
|
||||
|
||||
var user database.TabUser
|
||||
err := r.db.Where("name = ?", name).First(&user).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindByEmail 通过邮箱查找用户
|
||||
func (r *userRepo) FindByEmail(email string) (*database.TabUser, error) {
|
||||
// TabUser表目前没有email字段,这里返回nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// FindByPhone 通过手机号查找用户
|
||||
func (r *userRepo) FindByPhone(phone string) (*database.TabUser, error) {
|
||||
// TabUser表目前没有phone字段,这里返回nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Update 更新用户信息
|
||||
func (r *userRepo) Update(user *database.TabUser) error {
|
||||
if user == nil {
|
||||
return errors.New("user is nil")
|
||||
}
|
||||
|
||||
if user.ID == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
|
||||
return r.db.Save(user).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *userRepo) Delete(id uint) error {
|
||||
if id == 0 {
|
||||
return errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
return r.db.Delete(&database.TabUser{}, id).Error
|
||||
}
|
||||
|
||||
// ExistsByName 检查用户名是否存在
|
||||
func (r *userRepo) ExistsByName(name string) (bool, error) {
|
||||
if name == "" {
|
||||
return false, errors.New("username is required")
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := r.db.Model(&database.TabUser{}).Where("name = ?", name).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// userInfoRepo 用户信息仓库实现
|
||||
type userInfoRepo struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserInfoRepository 创建用户信息仓库实例
|
||||
func NewUserInfoRepository(db *gorm.DB) UserInfoRepository {
|
||||
return &userInfoRepo{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户信息
|
||||
func (r *userInfoRepo) Create(userInfo *database.TabUserInfo) error {
|
||||
if userInfo == nil {
|
||||
return errors.New("user info is nil")
|
||||
}
|
||||
|
||||
if userInfo.UserID == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
|
||||
return r.db.Create(userInfo).Error
|
||||
}
|
||||
|
||||
// FindByUserID 通过用户ID查找用户信息
|
||||
func (r *userInfoRepo) FindByUserID(userID uint) (*database.TabUserInfo, error) {
|
||||
if userID == 0 {
|
||||
return nil, errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
var userInfo database.TabUserInfo
|
||||
err := r.db.Where("user_id = ?", userID).First(&userInfo).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &userInfo, nil
|
||||
}
|
||||
|
||||
// Update 更新用户信息
|
||||
func (r *userInfoRepo) Update(userInfo *database.TabUserInfo) error {
|
||||
if userInfo == nil {
|
||||
return errors.New("user info is nil")
|
||||
}
|
||||
|
||||
if userInfo.UserID == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
|
||||
return r.db.Save(userInfo).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户信息
|
||||
func (r *userInfoRepo) Delete(userID uint) error {
|
||||
if userID == 0 {
|
||||
return errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
return r.db.Where("user_id = ?", userID).Delete(&database.TabUserInfo{}).Error
|
||||
}
|
||||
|
||||
// cookieRepo Cookie仓库实现
|
||||
type cookieRepo struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewCookieRepository 创建Cookie仓库实例
|
||||
func NewCookieRepository(db *gorm.DB) CookieRepository {
|
||||
return &cookieRepo{db: db}
|
||||
}
|
||||
|
||||
// Create 创建Cookie
|
||||
func (r *cookieRepo) Create(cookie *database.TabCookie) error {
|
||||
if cookie == nil {
|
||||
return errors.New("cookie is nil")
|
||||
}
|
||||
|
||||
if cookie.Value == "" {
|
||||
return errors.New("cookie value is required")
|
||||
}
|
||||
|
||||
if cookie.UserID == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
|
||||
if cookie.ExpiresAt == 0 {
|
||||
cookie.ExpiresAt = time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||
}
|
||||
|
||||
if cookie.CreateAt == 0 {
|
||||
cookie.CreateAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
return r.db.Create(cookie).Error
|
||||
}
|
||||
|
||||
// FindByValue 通过Cookie值查找
|
||||
func (r *cookieRepo) FindByValue(cookieValue string) (*database.TabCookie, error) {
|
||||
if cookieValue == "" {
|
||||
return nil, errors.New("cookie value is required")
|
||||
}
|
||||
|
||||
var cookie database.TabCookie
|
||||
err := r.db.Where("value = ?", cookieValue).First(&cookie).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cookie, nil
|
||||
}
|
||||
|
||||
// FindByUserID 通过用户ID查找所有Cookie
|
||||
func (r *cookieRepo) FindByUserID(userID uint) ([]*database.TabCookie, error) {
|
||||
if userID == 0 {
|
||||
return nil, errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
var cookies []*database.TabCookie
|
||||
err := r.db.Where("user_id = ?", userID).Find(&cookies).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cookies, nil
|
||||
}
|
||||
|
||||
// DeleteByValue 通过Cookie值删除
|
||||
func (r *cookieRepo) DeleteByValue(cookieValue string) error {
|
||||
if cookieValue == "" {
|
||||
return errors.New("cookie value is required")
|
||||
}
|
||||
|
||||
return r.db.Where("value = ?", cookieValue).Delete(&database.TabCookie{}).Error
|
||||
}
|
||||
|
||||
// DeleteByUserID 通过用户ID删除所有Cookie
|
||||
func (r *cookieRepo) DeleteByUserID(userID uint) error {
|
||||
if userID == 0 {
|
||||
return errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
return r.db.Where("user_id = ?", userID).Delete(&database.TabCookie{}).Error
|
||||
}
|
||||
|
||||
// DeleteExpired 删除过期的Cookie
|
||||
func (r *cookieRepo) DeleteExpired() error {
|
||||
now := time.Now().Unix()
|
||||
return r.db.Where("expires_at < ?", now).Delete(&database.TabCookie{}).Error
|
||||
}
|
||||
|
||||
// EnhancedUserInfo 增强的用户信息结构
|
||||
type EnhancedUserInfo struct {
|
||||
database.TabUser
|
||||
UserInfo database.TabUserInfo
|
||||
AvatarURL string
|
||||
}
|
||||
|
||||
// GetEnhancedUserInfo 获取增强的用户信息
|
||||
func GetEnhancedUserInfo(db *gorm.DB, userID uint) (*EnhancedUserInfo, error) {
|
||||
if userID == 0 {
|
||||
return nil, errors.New("invalid user ID")
|
||||
}
|
||||
|
||||
var user database.TabUser
|
||||
var userInfo database.TabUserInfo
|
||||
|
||||
// 获取用户基本信息
|
||||
err := db.First(&user, userID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户详细信息
|
||||
err = db.Where("user_id = ?", userID).First(&userInfo).Error
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建头像URL
|
||||
avatarURL := "/static/default_avatar.png"
|
||||
if userInfo.AvatarPath != "" {
|
||||
avatarURL = "/file/" + userInfo.AvatarPath
|
||||
}
|
||||
|
||||
return &EnhancedUserInfo{
|
||||
TabUser: user,
|
||||
UserInfo: userInfo,
|
||||
AvatarURL: avatarURL,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,448 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"ops/internal/database"
|
||||
"ops/internal/repository"
|
||||
)
|
||||
|
||||
// AuthService 用户认证服务结构
|
||||
type AuthService struct {
|
||||
userRepo repository.UserRepository
|
||||
userInfoRepo repository.UserInfoRepository
|
||||
cookieRepo repository.CookieRepository
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// UserWithInfo 用户信息结构
|
||||
type UserWithInfo struct {
|
||||
UserID uint `json:"userID"`
|
||||
Name string `json:"name"`
|
||||
AvatarURL string `json:"avatarURL"`
|
||||
CookieValue string `json:"cookieValue"`
|
||||
}
|
||||
|
||||
// CookieInfo Cookie信息结构
|
||||
type CookieInfo struct {
|
||||
Value string `json:"value"`
|
||||
ExpireDate time.Time `json:"expireDate"`
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(db *gorm.DB) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: repository.NewUserRepository(db),
|
||||
userInfoRepo: repository.NewUserInfoRepository(db),
|
||||
cookieRepo: repository.NewCookieRepository(db),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (s *AuthService) Login(name, password, deviceID, ip, remember string) (*UserWithInfo, *CookieInfo, error) {
|
||||
if name == "" || password == "" {
|
||||
return nil, nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
user, err := s.userRepo.FindByName(name)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("find user error: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
// TODO: 密码验证逻辑(需要查看现有系统的密码加密方式)
|
||||
// 假设这里使用MD5加密,需要根据实际情况调整
|
||||
hashedPassword := hashPassword(password)
|
||||
|
||||
// 临时跳过密码验证,因为现有系统的用户没有密码字段
|
||||
fmt.Printf("DEBUG: Trying to login user %s (password: %s, hashed: %s)\n", name, password, hashedPassword)
|
||||
|
||||
// 生成Cookie
|
||||
cookieValue := generateCookieValue(user.ID, name, deviceID)
|
||||
|
||||
// 设置过期时间
|
||||
expiresAt := time.Now()
|
||||
if remember == "1" || remember == "true" {
|
||||
expiresAt = expiresAt.Add(30 * 24 * time.Hour) // 30天
|
||||
} else {
|
||||
expiresAt = expiresAt.Add(24 * time.Hour) // 24小时
|
||||
}
|
||||
|
||||
cookie := &database.TabCookie{
|
||||
Value: cookieValue,
|
||||
UserID: user.ID,
|
||||
ExpiresAt: expiresAt.Unix(),
|
||||
CreateAt: time.Now().Unix(),
|
||||
Remember: (remember == "1" || remember == "true"),
|
||||
}
|
||||
|
||||
// 保存Cookie到数据库
|
||||
if err := s.cookieRepo.Create(cookie); err != nil {
|
||||
return nil, nil, fmt.Errorf("create cookie error: %w", err)
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userInfo, err := s.userInfoRepo.FindByUserID(user.ID)
|
||||
if err != nil {
|
||||
fmt.Printf("WARN: user info not found for user %s: %v\n", name, err)
|
||||
}
|
||||
|
||||
// 构建头像URL
|
||||
avatarURL := "/static/default_avatar.png"
|
||||
if userInfo != nil && userInfo.AvatarPath != "" {
|
||||
avatarURL = "/static/uploads/" + userInfo.AvatarPath
|
||||
}
|
||||
|
||||
// 返回用户信息和Cookie
|
||||
userWithInfo := &UserWithInfo{
|
||||
UserID: user.ID,
|
||||
Name: user.Name,
|
||||
AvatarURL: avatarURL,
|
||||
CookieValue: cookieValue,
|
||||
}
|
||||
|
||||
cookieInfo := &CookieInfo{
|
||||
Value: cookieValue,
|
||||
ExpireDate: expiresAt,
|
||||
}
|
||||
|
||||
return userWithInfo, cookieInfo, nil
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
func (s *AuthService) Register(name, password, email, phone string) (*UserWithInfo, *CookieInfo, error) {
|
||||
if name == "" || password == "" {
|
||||
return nil, nil, errors.New("username and password are required")
|
||||
}
|
||||
|
||||
// 检查用户名是否已存在
|
||||
exists, err := s.userRepo.ExistsByName(name)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("check username exists error: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return nil, nil, errors.New("username already exists")
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := &database.TabUser{
|
||||
Name: name,
|
||||
// 注意:现有TabUser表只有ID和Name字段,没有密码字段
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(user); err != nil {
|
||||
return nil, nil, fmt.Errorf("create user error: %w", err)
|
||||
}
|
||||
|
||||
// 创建用户信息
|
||||
userInfo := &database.TabUserInfo{
|
||||
UserID: user.ID,
|
||||
AvatarPath: "", // 默认空
|
||||
Birthdate: "",
|
||||
Gender: 0,
|
||||
Introduction: "",
|
||||
}
|
||||
|
||||
if err := s.userInfoRepo.Create(userInfo); err != nil {
|
||||
// 如果创建用户信息失败,删除用户(可选)
|
||||
s.userRepo.Delete(user.ID)
|
||||
return nil, nil, fmt.Errorf("create user info error: %w", err)
|
||||
}
|
||||
|
||||
// 生成Cookie
|
||||
cookieValue := generateCookieValue(user.ID, name, "register")
|
||||
expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7天
|
||||
|
||||
cookie := &database.TabCookie{
|
||||
Value: cookieValue,
|
||||
UserID: user.ID,
|
||||
ExpiresAt: expiresAt.Unix(),
|
||||
CreateAt: time.Now().Unix(),
|
||||
Remember: true,
|
||||
}
|
||||
|
||||
if err := s.cookieRepo.Create(cookie); err != nil {
|
||||
return nil, nil, fmt.Errorf("create cookie error: %w", err)
|
||||
}
|
||||
|
||||
// 返回用户信息和Cookie
|
||||
userWithInfo := &UserWithInfo{
|
||||
UserID: user.ID,
|
||||
Name: user.Name,
|
||||
AvatarURL: "/static/default_avatar.png",
|
||||
CookieValue: cookieValue,
|
||||
}
|
||||
|
||||
cookieInfo := &CookieInfo{
|
||||
Value: cookieValue,
|
||||
ExpireDate: expiresAt,
|
||||
}
|
||||
|
||||
return userWithInfo, cookieInfo, nil
|
||||
}
|
||||
|
||||
// ForgotPassword 忘记密码
|
||||
func (s *AuthService) ForgotPassword(name, email, phone string) (string, error) {
|
||||
if name == "" {
|
||||
return "", errors.New("username is required")
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
user, err := s.userRepo.FindByName(name)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("find user error: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return "", errors.New("user not found")
|
||||
}
|
||||
|
||||
// 生成重置令牌
|
||||
resetToken := generateResetToken(user.ID, name)
|
||||
|
||||
// TODO: 发送重置密码邮件或短信
|
||||
// 这里应该实现邮件发送或短信发送逻辑
|
||||
|
||||
fmt.Printf("DEBUG: Password reset token for user %s: %s\n", name, resetToken)
|
||||
|
||||
return resetToken, nil
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
func (s *AuthService) ResetPassword(token, newPassword string) error {
|
||||
if token == "" || newPassword == "" {
|
||||
return errors.New("token and new password are required")
|
||||
}
|
||||
|
||||
// TODO: 验证重置令牌并获取用户ID
|
||||
// 这里应该解析token获取用户ID
|
||||
userID := parseResetToken(token)
|
||||
if userID == 0 {
|
||||
return errors.New("invalid reset token")
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("find user error: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
// TODO: 更新密码
|
||||
// 注意:现有TabUser表没有密码字段,这里可能需要扩展表结构或使用其他方式存储密码
|
||||
|
||||
fmt.Printf("DEBUG: Password reset for user %s (ID: %d)\n", user.Name, user.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logout 用户退出登录
|
||||
func (s *AuthService) Logout(cookieValue, deviceID string) error {
|
||||
if cookieValue == "" {
|
||||
return errors.New("cookie value is required")
|
||||
}
|
||||
|
||||
return s.cookieRepo.DeleteByValue(cookieValue)
|
||||
}
|
||||
|
||||
// GetProfile 获取用户信息
|
||||
func (s *AuthService) GetProfile(userID uint) (*UserWithInfo, error) {
|
||||
if userID == 0 {
|
||||
return nil, errors.New("user ID is required")
|
||||
}
|
||||
|
||||
// 获取增强的用户信息
|
||||
enhancedUser, err := repository.GetEnhancedUserInfo(s.db, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get enhanced user info error: %w", err)
|
||||
}
|
||||
if enhancedUser == nil {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
return &UserWithInfo{
|
||||
UserID: enhancedUser.TabUser.ID,
|
||||
Name: enhancedUser.TabUser.Name,
|
||||
AvatarURL: enhancedUser.AvatarURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateProfile 更新用户信息
|
||||
func (s *AuthService) UpdateProfile(userID uint, updateData map[string]interface{}) (*UserWithInfo, error) {
|
||||
if userID == 0 {
|
||||
return nil, errors.New("user ID is required")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
enhancedUser, err := repository.GetEnhancedUserInfo(s.db, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get enhanced user info error: %w", err)
|
||||
}
|
||||
if enhancedUser == nil {
|
||||
return nil, errors.New("user not found")
|
||||
}
|
||||
|
||||
// 更新用户信息
|
||||
// 检查是否有avatar字段
|
||||
if avatarPath, ok := updateData["avatar"]; ok {
|
||||
avatarStr, isString := avatarPath.(string)
|
||||
if isString && avatarStr != "" {
|
||||
enhancedUser.UserInfo.AvatarPath = avatarStr
|
||||
if err := s.userInfoRepo.Update(&enhancedUser.UserInfo); err != nil {
|
||||
return nil, fmt.Errorf("update user info error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查其他可更新字段
|
||||
if gender, ok := updateData["gender"]; ok {
|
||||
if genderNum, isNum := gender.(float64); isNum {
|
||||
enhancedUser.UserInfo.Gender = int(genderNum)
|
||||
}
|
||||
}
|
||||
|
||||
if birthdate, ok := updateData["birthdate"]; ok {
|
||||
if birthdateStr, isString := birthdate.(string); isString {
|
||||
enhancedUser.UserInfo.Birthdate = birthdateStr
|
||||
}
|
||||
}
|
||||
|
||||
if intro, ok := updateData["introduction"]; ok {
|
||||
if introStr, isString := intro.(string); isString {
|
||||
enhancedUser.UserInfo.Introduction = introStr
|
||||
}
|
||||
}
|
||||
|
||||
// 保存更新后的用户信息
|
||||
if err := s.userInfoRepo.Update(&enhancedUser.UserInfo); err != nil {
|
||||
return nil, fmt.Errorf("update user info error: %w", err)
|
||||
}
|
||||
|
||||
// 构建头像URL
|
||||
avatarURL := "/static/default_avatar.png"
|
||||
if enhancedUser.UserInfo.AvatarPath != "" {
|
||||
avatarURL = "/static/uploads/" + enhancedUser.UserInfo.AvatarPath
|
||||
}
|
||||
|
||||
return &UserWithInfo{
|
||||
UserID: userID,
|
||||
Name: enhancedUser.TabUser.Name,
|
||||
AvatarURL: avatarURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateCookie 验证Cookie有效性
|
||||
func (s *AuthService) ValidateCookie(cookieValue string) (uint, error) {
|
||||
if cookieValue == "" {
|
||||
return 0, errors.New("cookie value is required")
|
||||
}
|
||||
|
||||
cookie, err := s.cookieRepo.FindByValue(cookieValue)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("find cookie error: %w", err)
|
||||
}
|
||||
if cookie == nil {
|
||||
return 0, errors.New("cookie not found")
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if cookie.ExpiresAt < time.Now().Unix() {
|
||||
// 删除过期的Cookie
|
||||
s.cookieRepo.DeleteByValue(cookieValue)
|
||||
return 0, errors.New("cookie expired")
|
||||
}
|
||||
|
||||
return cookie.UserID, nil
|
||||
}
|
||||
|
||||
// 辅助函数
|
||||
func hashPassword(password string) string {
|
||||
// 使用MD5哈希(根据现有系统可能使用其他方式)
|
||||
hash := md5.Sum([]byte(password))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func generateCookieValue(userID uint, username, deviceID string) string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
data := fmt.Sprintf("%d%s%s%d", userID, username, deviceID, timestamp)
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func generateResetToken(userID uint, username string) string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
random := fmt.Sprintf("%d", timestamp)
|
||||
data := fmt.Sprintf("%d%s%s%d", userID, username, random, timestamp)
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
token := hex.EncodeToString(hash[:])
|
||||
|
||||
// 存储到数据库或Redis(这里简化处理)
|
||||
// 在实际应用中应该存储token并设置过期时间
|
||||
return token
|
||||
}
|
||||
|
||||
func parseResetToken(token string) uint {
|
||||
// 简化的token解析,实际应该从数据库或Redis验证
|
||||
// 这里返回0表示无效
|
||||
if len(token) < 32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// TODO: 实现token解析逻辑
|
||||
// 暂时返回0,需要根据具体token格式实现
|
||||
return 0
|
||||
}
|
||||
|
||||
// CleanupExpiredCookies 清理过期Cookie
|
||||
func (s *AuthService) CleanupExpiredCookies() error {
|
||||
return s.cookieRepo.DeleteExpired()
|
||||
}
|
||||
|
||||
// GetUserByCookie 通过Cookie获取用户信息
|
||||
func (s *AuthService) GetUserByCookie(cookieValue string) (*UserWithInfo, error) {
|
||||
userID, err := s.ValidateCookie(cookieValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s.GetProfile(userID)
|
||||
}
|
||||
|
||||
// UpdateUserPassword 更新用户密码
|
||||
func (s *AuthService) UpdateUserPassword(userID uint, oldPassword, newPassword string) error {
|
||||
if userID == 0 {
|
||||
return errors.New("user ID is required")
|
||||
}
|
||||
|
||||
if oldPassword == "" || newPassword == "" {
|
||||
return errors.New("old password and new password are required")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.FindByID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("find user error: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
// TODO: 验证旧密码
|
||||
// 现有系统没有密码字段,需要扩展
|
||||
|
||||
// TODO: 更新密码
|
||||
// 现有系统没有密码字段,需要扩展
|
||||
|
||||
return errors.New("password update not supported in current schema")
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"ops/internal/repository"
|
||||
"ops/models"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type FileService interface {
|
||||
UploadFile(c *gin.Context, userID uint, fileHeader *multipart.FileHeader, fileType, description string) (UploadResponse, bool)
|
||||
GetFileList(userID uint, fileType string, page, entries int) (FileListResponse, bool)
|
||||
GetFileByID(fileID uint, userID uint) (*models.TabFileInfo_, bool)
|
||||
GetFileByHash(hash string) (*models.TabFileInfo_, bool)
|
||||
DeleteFile(fileID uint, userID uint) bool
|
||||
DownloadFile(c *gin.Context, hash string, download bool) bool
|
||||
}
|
||||
|
||||
type fileService struct {
|
||||
repo repository.FileRepository
|
||||
}
|
||||
|
||||
func NewFileService(db *gorm.DB) FileService {
|
||||
return &fileService{
|
||||
repo: repository.NewFileRepository(db),
|
||||
}
|
||||
}
|
||||
|
||||
// 响应结构体
|
||||
type UploadResponse struct {
|
||||
FileID uint `json:"file_id"`
|
||||
Name string `json:"name"`
|
||||
SHA256 string `json:"sha256"`
|
||||
Mime string `json:"mime"`
|
||||
Size int64 `json:"size"`
|
||||
DownloadURL string `json:"download_url"`
|
||||
PreviewURL string `json:"preview_url"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type FileListResponse struct {
|
||||
Files []FileInfo `json:"files"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
Pages int `json:"pages"`
|
||||
}
|
||||
|
||||
type FileInfo struct {
|
||||
FileID uint `json:"file_id"`
|
||||
Name string `json:"name"`
|
||||
SHA256 string `json:"sha256"`
|
||||
Mime string `json:"mime"`
|
||||
Size int64 `json:"size"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
func (s *fileService) UploadFile(c *gin.Context, userID uint, fileHeader *multipart.FileHeader, fileType, description string) (UploadResponse, bool) {
|
||||
// 验证文件大小
|
||||
if fileHeader.Size > int64(models.ConfigsFile.MaxSize) {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
|
||||
// 验证文件最小大小
|
||||
if fileHeader.Size < 512 {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
|
||||
// 验证文件名
|
||||
if fileHeader.Filename == "" {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
|
||||
// 安全处理文件名
|
||||
filename := filepath.Base(fileHeader.Filename)
|
||||
|
||||
// 计算文件哈希
|
||||
hashStr, err := models.SHA256HashFile(fileHeader)
|
||||
if err != nil {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
|
||||
// 获取文件MIME类型
|
||||
mimeType, err := models.GetFileMime(fileHeader)
|
||||
if err != nil {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
|
||||
// 验证MIME类型(如果是图片)
|
||||
if fileType == "image" {
|
||||
if models.ConfigsFile.AllowImageMime[mimeType] == "" {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// 构建文件保存路径
|
||||
var savePath string
|
||||
switch fileType {
|
||||
case "image":
|
||||
savePath = filepath.Join(models.ConfigsFile.Pahts["image"], hashStr)
|
||||
default:
|
||||
savePath = filepath.Join(models.ConfigsFile.Pahts["default"], hashStr)
|
||||
}
|
||||
|
||||
// 检查文件是否已存在
|
||||
if models.FileExists(savePath) {
|
||||
// 如果文件已存在,增加使用计数
|
||||
existingFile, err := s.repo.GetFileByHash(hashStr)
|
||||
if err == nil && existingFile != nil {
|
||||
s.repo.IncrementFileUsage(existingFile.ID)
|
||||
}
|
||||
} else {
|
||||
// 保存文件到磁盘
|
||||
if err := c.SaveUploadedFile(fileHeader, savePath); err != nil {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查数据库中是否已存在该文件
|
||||
existingFile, _ := s.repo.GetFileByHash(hashStr)
|
||||
if existingFile != nil {
|
||||
// 更新使用计数
|
||||
s.repo.IncrementFileUsage(existingFile.ID)
|
||||
|
||||
return UploadResponse{
|
||||
FileID: existingFile.ID,
|
||||
Name: filename,
|
||||
SHA256: hashStr,
|
||||
Mime: mimeType,
|
||||
Size: fileHeader.Size,
|
||||
DownloadURL: "/api/v1/files/download/" + hashStr,
|
||||
PreviewURL: "/api/v1/files/get/" + hashStr,
|
||||
CreatedAt: existingFile.Date.Format("2006-01-02T15:04:05Z"),
|
||||
}, true
|
||||
}
|
||||
|
||||
// 创建新的文件记录
|
||||
newFile := &models.TabFileInfo_{
|
||||
Name: filename,
|
||||
Path: savePath,
|
||||
Sha256: hashStr,
|
||||
Mime: mimeType,
|
||||
Type: fileType,
|
||||
UserID: userID,
|
||||
Date: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.repo.CreateFile(newFile); err != nil {
|
||||
return UploadResponse{}, false
|
||||
}
|
||||
|
||||
return UploadResponse{
|
||||
FileID: newFile.ID,
|
||||
Name: filename,
|
||||
SHA256: hashStr,
|
||||
Mime: mimeType,
|
||||
Size: fileHeader.Size,
|
||||
DownloadURL: "/api/v1/files/download/" + hashStr,
|
||||
PreviewURL: "/api/v1/files/get/" + hashStr,
|
||||
CreatedAt: newFile.Date.Format("2006-01-02T15:04:05Z"),
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *fileService) GetFileList(userID uint, fileType string, page, entries int) (FileListResponse, bool) {
|
||||
// 验证分页参数
|
||||
if entries <= 0 || entries > 100 {
|
||||
return FileListResponse{}, false
|
||||
}
|
||||
if page <= 0 {
|
||||
return FileListResponse{}, false
|
||||
}
|
||||
|
||||
files, total, err := s.repo.GetFilesByUser(userID, fileType, page, entries)
|
||||
if err != nil {
|
||||
return FileListResponse{}, false
|
||||
}
|
||||
|
||||
// 计算总页数
|
||||
pages := int(total) / entries
|
||||
if int(total)%entries > 0 {
|
||||
pages++
|
||||
}
|
||||
|
||||
// 转换文件信息
|
||||
fileInfos := make([]FileInfo, 0, len(files))
|
||||
for _, file := range files {
|
||||
fileInfos = append(fileInfos, FileInfo{
|
||||
FileID: file.ID,
|
||||
Name: file.Name,
|
||||
SHA256: file.Sha256,
|
||||
Mime: file.Mime,
|
||||
Type: file.Type,
|
||||
CreatedAt: file.Date.Format("2006-01-02T15:04:05Z"),
|
||||
})
|
||||
}
|
||||
|
||||
return FileListResponse{
|
||||
Files: fileInfos,
|
||||
Total: total,
|
||||
Page: page,
|
||||
Pages: pages,
|
||||
}, true
|
||||
}
|
||||
|
||||
func (s *fileService) GetFileByID(fileID uint, userID uint) (*models.TabFileInfo_, bool) {
|
||||
file, err := s.repo.GetFileByID(fileID)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 检查文件所有权
|
||||
if file.UserID != userID {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return file, true
|
||||
}
|
||||
|
||||
func (s *fileService) GetFileByHash(hash string) (*models.TabFileInfo_, bool) {
|
||||
file, err := s.repo.GetFileByHash(hash)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return file, true
|
||||
}
|
||||
|
||||
func (s *fileService) DeleteFile(fileID uint, userID uint) bool {
|
||||
// 首先检查文件所有权
|
||||
file, err := s.repo.GetFileByID(fileID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if file.UserID != userID {
|
||||
return false
|
||||
}
|
||||
|
||||
// 删除文件记录
|
||||
if err := s.repo.DeleteFile(fileID); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 注意:这里不删除物理文件,因为可能还有其他引用
|
||||
// 如果需要删除物理文件,需要检查引用计数
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *fileService) DownloadFile(c *gin.Context, hash string, download bool) bool {
|
||||
file, err := s.repo.GetFileByHash(hash)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查文件是否存在
|
||||
if !models.FileExists(file.Path) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
if download {
|
||||
// 下载模式
|
||||
c.Header("Content-Disposition", "attachment; filename=\""+file.Name+"\"")
|
||||
} else {
|
||||
// 预览模式
|
||||
ext := filepath.Ext(file.Name)
|
||||
if ext != "" {
|
||||
mimeType := mime.TypeByExtension(ext)
|
||||
if mimeType != "" {
|
||||
c.Header("Content-Type", mimeType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "application/octet-stream")
|
||||
c.File(file.Path)
|
||||
|
||||
// 增加使用计数
|
||||
s.repo.IncrementFileUsage(file.ID)
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"ops/internal/repository"
|
||||
"ops/models"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/datatypes"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PurchaseService interface {
|
||||
GetOrders(c *gin.Context, userID uint, search string, page, entries int) (gin.H, bool)
|
||||
CreateOrder(c *gin.Context, userID uint, request CreateOrderRequest) bool
|
||||
GetOrderDetails(orderID uint) (*models.TabPurchaseOrder, []models.TabPurchaseCosts, error)
|
||||
}
|
||||
|
||||
type purchaseService struct {
|
||||
repo repository.PurchaseRepository
|
||||
}
|
||||
|
||||
func NewPurchaseService(db *gorm.DB) PurchaseService {
|
||||
return &purchaseService{
|
||||
repo: repository.NewPurchaseRepository(db),
|
||||
}
|
||||
}
|
||||
|
||||
// 请求结构体
|
||||
type CostItem struct {
|
||||
Cost int `json:"cost" binding:"required,min=1"`
|
||||
CostT int `json:"costt" binding:"required,min=0"`
|
||||
CurrencyType string `json:"currencytype" binding:"required"`
|
||||
Int int `json:"int" binding:"required,min=1"`
|
||||
Type string `json:"type" binding:"required"`
|
||||
}
|
||||
|
||||
type CreateOrderRequest struct {
|
||||
Costs []CostItem `json:"costs" binding:"required,min=1,dive"`
|
||||
Link string `json:"link"`
|
||||
OrderStatus string `json:"order_status" binding:"required"`
|
||||
PartName string `json:"partname"`
|
||||
Photos []string `json:"photos"`
|
||||
Remark string `json:"remark"`
|
||||
Styles string `json:"styles"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
TrackingNumber string `json:"tracking_number"`
|
||||
UpdateTime string `json:"update_time"`
|
||||
}
|
||||
|
||||
func (s *purchaseService) GetOrders(c *gin.Context, userID uint, search string, page, entries int) (gin.H, bool) {
|
||||
// 验证分页参数
|
||||
if entries <= 0 || entries > 300 {
|
||||
return nil, false
|
||||
}
|
||||
if page <= 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
orders, total, err := s.repo.GetOrders(userID, search, page, entries)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
result := gin.H{
|
||||
"all_count": total,
|
||||
"all_orders": orders,
|
||||
}
|
||||
|
||||
return result, true
|
||||
}
|
||||
|
||||
func (s *purchaseService) CreateOrder(c *gin.Context, userID uint, request CreateOrderRequest) bool {
|
||||
// 验证数据
|
||||
if request.Title == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证价格和数量
|
||||
for _, cost := range request.Costs {
|
||||
if cost.Cost <= 0 {
|
||||
return false
|
||||
}
|
||||
if cost.Int <= 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 验证图片哈希(简单检查是否包含特殊字符)
|
||||
for _, photo := range request.Photos {
|
||||
if models.IsContainsSpecialChar(photo) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 解析更新时间
|
||||
var updateTime *time.Time
|
||||
if request.UpdateTime != "" {
|
||||
parsedTime, err := models.StringToTimePtr(request.UpdateTime)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
updateTime = parsedTime
|
||||
}
|
||||
|
||||
// 转换照片数组为JSON
|
||||
var photosJSON datatypes.JSON
|
||||
if len(request.Photos) > 0 {
|
||||
photosBytes, err := json.Marshal(request.Photos)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
photosJSON = datatypes.JSON(photosBytes)
|
||||
}
|
||||
|
||||
// 创建订单
|
||||
order := &models.TabPurchaseOrder{
|
||||
UserID: userID,
|
||||
Title: request.Title,
|
||||
Remark: request.Remark,
|
||||
Photos: photosJSON,
|
||||
Link: request.Link,
|
||||
PartName: request.PartName,
|
||||
Styles: request.Styles,
|
||||
UpdateTime: updateTime,
|
||||
TrackingNumber: request.TrackingNumber,
|
||||
OrderStatus: request.OrderStatus,
|
||||
}
|
||||
|
||||
if err := s.repo.CreateOrder(order); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 创建费用明细
|
||||
for _, costItem := range request.Costs {
|
||||
cost := &models.TabPurchaseCosts{
|
||||
UserID: userID,
|
||||
OrderID: order.ID,
|
||||
Price: costItem.Cost,
|
||||
Quantity: costItem.Int,
|
||||
}
|
||||
|
||||
if err := s.repo.CreateCost(cost); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *purchaseService) GetOrderDetails(orderID uint) (*models.TabPurchaseOrder, []models.TabPurchaseCosts, error) {
|
||||
order, err := s.repo.GetOrderByID(orderID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
costs, err := s.repo.GetOrderCosts(orderID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return order, costs, nil
|
||||
}
|
||||
Reference in New Issue
Block a user