// Package mysql 提供 MySQL 数据库连接和管理功能。 // 支持 Unix Socket 和 TCP 两种连接方式,自动初始化数据表和恢复数据。 package mysql import ( "context" "database/sql" "fmt" "log" "os" "path/filepath" "strings" "time" goredis "github.com/redis/go-redis/v9" _ "github.com/go-sql-driver/mysql" "sese-engine/config" ) // DB 是 MySQL 数据库连接池 var DB *sql.DB // Open 初始化 MySQL 连接 // 根据配置自动选择 Unix Socket 或 TCP 连接 func Open() error { dsn := config.MySQLDSN() db, err := sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("mysql.Open: %w", err) } // 配置连接池 db.SetConnMaxLifetime(time.Duration(config.MySQLConnMaxLifetime()) * time.Second) db.SetMaxIdleConns(config.MySQLMaxIdleConns()) db.SetMaxOpenConns(config.MySQLMaxOpenConns()) // 验证连接 if err := db.Ping(); err != nil { return fmt.Errorf("mysql.Ping: %w", err) } DB = db log.Printf("[mysql] connected via %s", formatDSN(dsn)) // 自动初始化数据表 if err := initSchema(); err != nil { return fmt.Errorf("mysql init schema: %w", err) } return nil } // initSchema 自动执行 init_db.sql 初始化数据表 func initSchema() error { // 查找 init_db.sql 文件 execPath, err := os.Executable() if err != nil { execPath = os.Args[0] } sqlFile := filepath.Join(filepath.Dir(execPath), "mysql", "init_db.sql") if _, err := os.Stat(sqlFile); os.IsNotExist(err) { // 尝试从当前工作目录查找 cwd, _ := os.Getwd() sqlFile = filepath.Join(cwd, "mysql", "init_db.sql") } data, err := os.ReadFile(sqlFile) if err != nil { return fmt.Errorf("read init_db.sql: %w", err) } // 获取配置的数据库名 dbName := config.Global.MySQL.Database if dbName == "" { dbName = "sese_engine" } log.Printf("[mysql] init schema: database=%s", dbName) // 先切换到目标数据库 if _, err := DB.Exec("USE " + dbName); err != nil { return fmt.Errorf("mysql USE database: %w", err) } // 分割 SQL 语句(按分号分割) statements := splitStatements(string(data)) log.Printf("[mysql] found %d SQL statements to execute", len(statements)) execed := 0 for i, stmt := range statements { trimmed := strings.TrimSpace(stmt) // 跳过空行和注释 if trimmed == "" || strings.HasPrefix(trimmed, "--") || strings.HasPrefix(trimmed, "/*") { log.Printf("[mysql] [%d/%d] SKIP (empty/comment): %s", i+1, len(statements), truncate(trimmed, 60)) continue } if _, err := DB.Exec(trimmed); err != nil { log.Printf("[mysql] [%d/%d] FAILED: %v\n SQL: %s", i+1, len(statements), err, truncate(trimmed, 200)) continue } execed++ log.Printf("[mysql] [%d/%d] OK: %s", i+1, len(statements), truncate(trimmed, 60)) } log.Printf("[mysql] init schema done, executed=%d statements", execed) return nil } // splitStatements 按分号分割 SQL 语句(处理多行 CREATE TABLE) func splitStatements(sql string) []string { var statements []string var buf strings.Builder inComment := false for _, line := range strings.Split(sql, "\n") { trimmed := strings.TrimSpace(line) // 单行注释 if strings.HasPrefix(trimmed, "--") || strings.HasPrefix(trimmed, "//") { continue } // 多行注释开始/结束 if strings.Contains(trimmed, "/*") { inComment = true } if inComment { if strings.Contains(trimmed, "*/") { inComment = false } continue } // 空行跳过 if trimmed == "" { continue } buf.WriteString(line) buf.WriteString("\n") // 检查是否以分号结尾 trimmed = strings.TrimSpace(buf.String()) if strings.HasSuffix(trimmed, ";") { statements = append(statements, trimmed) buf.Reset() } } // 处理最后一条(可能没有分号) if buf.Len() > 0 { trimmed := strings.TrimSpace(buf.String()) if trimmed != "" { statements = append(statements, trimmed) } } return statements } // truncate 截断字符串 func truncate(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." } // Close 关闭 MySQL 连接 func Close() error { if DB != nil { return DB.Close() } return nil } // Ping 检查 MySQL 连接是否正常 func Ping() error { if DB == nil { return fmt.Errorf("mysql not initialized") } return DB.Ping() } // formatDSN 格式化 DSN 用于日志(隐藏密码) func formatDSN(dsn string) string { // 简化日志输出 cfg := config.Global.MySQL if cfg.UnixSocket != "" { return fmt.Sprintf("unix_socket=%s database=%s", cfg.UnixSocket, cfg.Database) } return fmt.Sprintf("tcp=%s:%d database=%s", cfg.Host, cfg.Port, cfg.Database) } // RestoreFromMySQLToRedis 从 MySQL 恢复数据到 Redis // 用于 Redis 数据丢失后重建索引 func RestoreFromMySQLToRedis(redisDB *goredis.Client) error { if DB == nil { return fmt.Errorf("mysql not initialized") } start := time.Now() log.Printf("[mysql-restore] starting restoration from MySQL to Redis...") ctx := context.Background() // 1. 恢复 index_entries → Redis idx:* ZSet if err := restoreIndexEntries(ctx, redisDB); err != nil { return fmt.Errorf("restore index_entries: %w", err) } // 2. 恢复 url_snippets → Redis gate:* + url2hash:* if err := restoreUrlSnippets(ctx, redisDB); err != nil { return fmt.Errorf("restore url_snippets: %w", err) } // 3. 恢复 site_info → Redis site:* if err := restoreSiteInfo(ctx, redisDB); err != nil { return fmt.Errorf("restore site_info: %w", err) } // 4. 恢复 priority_urls → Redis priority:* if err := restorePriorityURLs(ctx, redisDB); err != nil { return fmt.Errorf("restore priority_urls: %w", err) } log.Printf("[mysql-restore] restoration completed in %v", time.Since(start)) return nil } // restoreIndexEntries 恢复倒排索引 func restoreIndexEntries(ctx context.Context, redisDB *goredis.Client) error { rows, err := DB.Query("SELECT keyword, url, weight FROM index_entries") if err != nil { // 表不存在时跳过 log.Printf("[mysql-restore][index] skip: %v", err) return nil } defer rows.Close() // 按 keyword 分组 type indexRow struct { URL string Weight float32 } keywordMap := make(map[string][]indexRow) count := 0 for rows.Next() { var keyword, url string var weight float32 if err := rows.Scan(&keyword, &url, &weight); err != nil { continue } keywordMap[keyword] = append(keywordMap[keyword], indexRow{URL: url, Weight: weight}) count++ } // 批量写入 Redis for keyword, entries := range keywordMap { if len(entries) == 0 { continue } zSlice := make([]goredis.Z, len(entries)) for i, e := range entries { zSlice[i] = goredis.Z{Score: float64(e.Weight), Member: e.URL} } if err := redisDB.ZAdd(ctx, "idx:"+keyword, zSlice...).Err(); err != nil { log.Printf("[mysql-restore][index] failed to restore %s: %v", keyword, err) } } log.Printf("[mysql-restore][index] restored %d entries (%d keywords)", count, len(keywordMap)) return nil } // restoreUrlSnippets 恢复 URL 摘要 func restoreUrlSnippets(ctx context.Context, redisDB *goredis.Client) error { rows, err := DB.Query("SELECT url, url_hash, title, description, text, timestamp, content_hash FROM url_snippets") if err != nil { // 表不存在时跳过 log.Printf("[mysql-restore][snippets] skip: %v", err) return nil } defer rows.Close() count := 0 for rows.Next() { var url, urlHash, title, description, text, contentHash sql.NullString var timestamp sql.NullInt64 if err := rows.Scan(&url, &urlHash, &title, &description, &text, ×tamp, &contentHash); err != nil { continue } if !url.Valid || urlHash.Valid == false { continue } fields := map[string]interface{}{ "url": url.String, "title": nullString(title), "desc": nullString(description), "text": nullString(text), "ts": nullInt64(timestamp), "hash": nullString(contentHash), } if err := redisDB.HMSet(ctx, "gate:"+urlHash.String, fields).Err(); err != nil { continue } // 同时写入 URL→hash 映射 redisDB.Set(ctx, "url2hash:"+url.String, urlHash.String, 0) count++ } log.Printf("[mysql-restore][snippets] restored %d entries", count) return nil } // restoreSiteInfo 恢复网站信息 func restoreSiteInfo(ctx context.Context, redisDB *goredis.Client) error { rows, err := DB.Query("SELECT host, visit_count, last_visit_time, success_rate, https_available FROM site_info") if err != nil { // 表不存在时跳过 log.Printf("[mysql-restore][site] skip: %v", err) return nil } defer rows.Close() count := 0 for rows.Next() { var host string var visitCount sql.NullInt64 var lastVisitTime sql.NullInt64 var successRate sql.NullFloat64 var httpsAvailable sql.NullInt64 if err := rows.Scan(&host, &visitCount, &lastVisitTime, &successRate, &httpsAvailable); err != nil { continue } if host == "" { continue } fields := map[string]interface{}{ "visit_count": nullInt64(visitCount), "last_visit_time": nullInt64(lastVisitTime), } if successRate.Valid { fields["success_rate"] = successRate.Float64 } if httpsAvailable.Valid { fields["https_available"] = httpsAvailable.Int64 } if err := redisDB.HMSet(ctx, "site:"+host, fields).Err(); err != nil { continue } count++ } log.Printf("[mysql-restore][site] restored %d entries", count) return nil } // restorePriorityURLs 恢复优先 URL func restorePriorityURLs(ctx context.Context, redisDB *goredis.Client) error { rows, err := DB.Query("SELECT url FROM priority_urls") if err != nil { // 表不存在时跳过 log.Printf("[mysql-restore][priority] skip: %v", err) return nil } defer rows.Close() count := 0 for rows.Next() { var url string if err := rows.Scan(&url); err != nil { continue } if url == "" { continue } fields := map[string]interface{}{ "url": url, "is_domain": "0", "added_at": time.Now().Unix(), "visited": "0", } if err := redisDB.HMSet(ctx, "priority:"+url, fields).Err(); err != nil { continue } count++ } log.Printf("[mysql-restore][priority] restored %d entries", count) return nil } // ---- 辅助函数 ---- func nullString(v sql.NullString) string { if v.Valid { return v.String } return "" } func nullInt64(v sql.NullInt64) int64 { if v.Valid { return v.Int64 } return 0 }