Files
2026-04-20 18:26:54 +08:00

417 lines
10 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package mysql 提供 MySQL 数据库连接和管理功能。
// 支持 Unix Socket 和 TCP 两种连接方式,自动初始化数据表和恢复数据。
package mysql
import (
"context"
"database/sql"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
goredis "github.com/redis/go-redis/v9"
_ "github.com/go-sql-driver/mysql"
"sese-engine/config"
)
// DB is the package-level MySQL connection pool. It is nil until Open
// assigns it after a successful ping.
var DB *sql.DB
// Open initializes the package-level MySQL connection pool.
//
// The DSN is taken from config.MySQLDSN() and the pool limits from the
// matching config accessors. After a successful ping the pool is published
// as DB and initSchema is run against it.
//
// It returns a wrapped error when the pool cannot be created, the server
// cannot be pinged, or initSchema fails. In the last case DB stays assigned
// because the connection itself is healthy.
func Open() error {
	dsn := config.MySQLDSN()
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return fmt.Errorf("mysql.Open: %w", err)
	}

	// Configure the connection pool.
	db.SetConnMaxLifetime(time.Duration(config.MySQLConnMaxLifetime()) * time.Second)
	db.SetMaxIdleConns(config.MySQLMaxIdleConns())
	db.SetMaxOpenConns(config.MySQLMaxOpenConns())

	// Verify that the server is actually reachable.
	if err := db.Ping(); err != nil {
		// The pool is never published on this path, so release it here.
		// The Close error is dropped on purpose: the ping failure is the
		// error the caller needs to see.
		_ = db.Close()
		return fmt.Errorf("mysql.Ping: %w", err)
	}

	DB = db
	log.Printf("[mysql] connected via %s", formatDSN(dsn))

	// Create the tables automatically.
	if err := initSchema(); err != nil {
		return fmt.Errorf("mysql init schema: %w", err)
	}
	return nil
}
// initSchema executes the statements in mysql/init_db.sql against the
// configured database.
//
// The file is looked up next to the executable first and, if it does not
// exist there, under the current working directory. The database name comes
// from config.Global.MySQL.Database and defaults to "sese_engine".
//
// All statements run on one dedicated connection: USE only affects the
// connection it is executed on, so sending it through the pool would not
// guarantee that the following statements see the selected database.
//
// A statement that fails is logged and skipped; only a missing file, a
// failure to obtain a connection, or a failing USE make the function return
// an error.
func initSchema() error {
	// Locate init_db.sql relative to the executable.
	execPath, err := os.Executable()
	if err != nil {
		execPath = os.Args[0]
	}
	sqlFile := filepath.Join(filepath.Dir(execPath), "mysql", "init_db.sql")
	if _, err := os.Stat(sqlFile); os.IsNotExist(err) {
		// Fall back to the current working directory. A Getwd error leaves
		// cwd empty, which turns this into a relative path; ReadFile below
		// reports the problem if that does not exist either.
		cwd, _ := os.Getwd()
		sqlFile = filepath.Join(cwd, "mysql", "init_db.sql")
	}
	data, err := os.ReadFile(sqlFile)
	if err != nil {
		return fmt.Errorf("read init_db.sql: %w", err)
	}

	// Resolve the configured database name.
	dbName := config.Global.MySQL.Database
	if dbName == "" {
		dbName = "sese_engine"
	}
	log.Printf("[mysql] init schema: database=%s", dbName)

	// Pin a single connection so USE and the schema statements share a session.
	ctx := context.Background()
	conn, err := DB.Conn(ctx)
	if err != nil {
		return fmt.Errorf("mysql acquire connection: %w", err)
	}
	defer conn.Close()

	// Quote the identifier so names containing characters such as '-' work
	// and the name cannot terminate the statement; embedded backticks are
	// doubled.
	useStmt := "USE `" + strings.ReplaceAll(dbName, "`", "``") + "`"
	if _, err := conn.ExecContext(ctx, useStmt); err != nil {
		return fmt.Errorf("mysql USE database: %w", err)
	}

	// Split the file into individual statements.
	statements := splitStatements(string(data))
	log.Printf("[mysql] found %d SQL statements to execute", len(statements))

	execed := 0
	for i, stmt := range statements {
		trimmed := strings.TrimSpace(stmt)
		// Skip anything that is empty or starts with a comment marker.
		if trimmed == "" || strings.HasPrefix(trimmed, "--") || strings.HasPrefix(trimmed, "/*") {
			log.Printf("[mysql] [%d/%d] SKIP (empty/comment): %s", i+1, len(statements), truncate(trimmed, 60))
			continue
		}
		if _, err := conn.ExecContext(ctx, trimmed); err != nil {
			// Best effort: log the failure and continue with the next statement.
			log.Printf("[mysql] [%d/%d] FAILED: %v\n SQL: %s", i+1, len(statements), err, truncate(trimmed, 200))
			continue
		}
		execed++
		log.Printf("[mysql] [%d/%d] OK: %s", i+1, len(statements), truncate(trimmed, 60))
	}
	log.Printf("[mysql] init schema done, executed=%d statements", execed)
	return nil
}
// splitStatements splits an SQL script into individual statements.
//
// The script is processed line by line. A statement ends at a line whose
// code portion ends with ';' (so a multi-line CREATE TABLE stays in one
// piece). A trailing statement without ';' is returned as well.
//
// Comment handling:
//   - lines starting with "--" or "//" are dropped;
//   - "/* ... */" comments, single- or multi-line, are removed while SQL on
//     the same line is kept (this also removes "/*! ... */" blocks);
//   - a trailing "-- comment" (the dashes followed by whitespace or end of
//     line) and a "#" comment are removed up to the end of the line.
//
// Comment markers and semicolons inside '...', "..." or `...` literals are
// left alone. Blank lines outside literals are skipped.
func splitStatements(sql string) []string {
	var (
		statements []string
		buf        strings.Builder
		inComment  bool // inside a /* ... */ comment that spans lines
		quote      byte // open quote character, 0 when outside a literal
	)
	for _, line := range strings.Split(sql, "\n") {
		// Whole-line comments only count outside comments and literals.
		if !inComment && quote == 0 {
			head := strings.TrimSpace(line)
			if strings.HasPrefix(head, "--") || strings.HasPrefix(head, "//") {
				continue
			}
		}

		startedInQuote := quote != 0
		var code string
		code, inComment, quote = stripSQLComments(line, inComment, quote)

		// Drop lines that carry no SQL; blank lines inside a literal are data.
		if !startedInQuote && strings.TrimSpace(code) == "" {
			continue
		}
		buf.WriteString(code)
		buf.WriteString("\n")

		// A ';' at the end of a line only terminates the statement when no
		// literal is still open.
		if quote != 0 {
			continue
		}
		if stmt := strings.TrimSpace(buf.String()); strings.HasSuffix(stmt, ";") {
			statements = append(statements, stmt)
			buf.Reset()
		}
	}

	// Emit the last statement even if it has no terminating semicolon.
	if stmt := strings.TrimSpace(buf.String()); stmt != "" {
		statements = append(statements, stmt)
	}
	return statements
}

// stripSQLComments returns the part of line that is not inside a comment,
// together with the block-comment and quote state to carry into the next
// line. Scanning is byte-wise; every delimiter is ASCII, so multi-byte UTF-8
// sequences pass through unchanged.
func stripSQLComments(line string, inComment bool, quote byte) (string, bool, byte) {
	var out strings.Builder
	for i := 0; i < len(line); i++ {
		c := line[i]
		switch {
		case inComment:
			if c == '*' && i+1 < len(line) && line[i+1] == '/' {
				inComment = false
				i++
				// A comment separates tokens, so leave a space in its place.
				out.WriteByte(' ')
			}
		case quote != 0:
			out.WriteByte(c)
			if c == '\\' && quote != '`' && i+1 < len(line) {
				// Backslash escape inside a string literal: copy the escaped
				// byte so an escaped quote does not close the literal.
				i++
				out.WriteByte(line[i])
			} else if c == quote {
				quote = 0
			}
		case c == '\'' || c == '"' || c == '`':
			quote = c
			out.WriteByte(c)
		case c == '/' && i+1 < len(line) && line[i+1] == '*':
			inComment = true
			i++
		case c == '-' && i+1 < len(line) && line[i+1] == '-' &&
			(i+2 == len(line) || line[i+2] == ' ' || line[i+2] == '\t' || line[i+2] == '\r'):
			// "-- " starts a comment that runs to the end of the line.
			return out.String(), inComment, quote
		case c == '#':
			// "#" starts a comment that runs to the end of the line.
			return out.String(), inComment, quote
		default:
			out.WriteByte(c)
		}
	}
	return out.String(), inComment, quote
}
// truncate shortens s to at most maxLen bytes and appends "..." when
// anything was cut off. Strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back
// to the start of the rune it would otherwise split, so the result can be
// slightly shorter than maxLen. A negative maxLen is treated as 0.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step back
	// until the cut sits on a rune boundary.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
// Close shuts down the MySQL connection pool held in DB.
// It does nothing and returns nil when DB is nil.
func Close() error {
	if DB == nil {
		return nil
	}
	return DB.Close()
}
// Ping reports whether the MySQL connection is usable. It returns an error
// when DB is nil, otherwise the result of DB.Ping.
func Ping() error {
	if DB != nil {
		return DB.Ping()
	}
	return fmt.Errorf("mysql not initialized")
}
// formatDSN returns a short description of the connection target for log
// output. The dsn argument is not used: the description is built from
// config.Global.MySQL and contains only the socket path or host and port
// plus the database name, so no password ends up in the log.
func formatDSN(dsn string) string {
	mysqlCfg := config.Global.MySQL
	switch {
	case mysqlCfg.UnixSocket != "":
		return fmt.Sprintf("unix_socket=%s database=%s", mysqlCfg.UnixSocket, mysqlCfg.Database)
	default:
		return fmt.Sprintf("tcp=%s:%d database=%s", mysqlCfg.Host, mysqlCfg.Port, mysqlCfg.Database)
	}
}
// RestoreFromMySQLToRedis copies the data stored in MySQL back into Redis,
// for rebuilding the Redis side after its data has been lost.
//
// The four restore steps run in a fixed order and the first one that
// returns an error aborts the run; that error is returned wrapped with the
// name of the table. An error is also returned when DB is nil.
func RestoreFromMySQLToRedis(redisDB *goredis.Client) error {
	if DB == nil {
		return fmt.Errorf("mysql not initialized")
	}

	began := time.Now()
	log.Printf("[mysql-restore] starting restoration from MySQL to Redis...")
	ctx := context.Background()

	var err error

	// Step 1: index_entries -> Redis idx:* sorted sets.
	if err = restoreIndexEntries(ctx, redisDB); err != nil {
		return fmt.Errorf("restore index_entries: %w", err)
	}

	// Step 2: url_snippets -> Redis gate:* hashes and url2hash:* keys.
	if err = restoreUrlSnippets(ctx, redisDB); err != nil {
		return fmt.Errorf("restore url_snippets: %w", err)
	}

	// Step 3: site_info -> Redis site:* hashes.
	if err = restoreSiteInfo(ctx, redisDB); err != nil {
		return fmt.Errorf("restore site_info: %w", err)
	}

	// Step 4: priority_urls -> Redis priority:* hashes.
	if err = restorePriorityURLs(ctx, redisDB); err != nil {
		return fmt.Errorf("restore priority_urls: %w", err)
	}

	log.Printf("[mysql-restore] restoration completed in %v", time.Since(began))
	return nil
}
// restoreIndexEntries rebuilds the inverted index in Redis from the
// index_entries table: every keyword becomes the sorted set "idx:<keyword>"
// with the URLs as members and the weights as scores.
//
// A failing query (for example because the table does not exist) is logged
// and treated as "nothing to restore". Rows that cannot be scanned are
// skipped and counted, and a failing ZADD is logged per keyword. If reading
// the result set is aborted by an error, that error is returned and nothing
// is written, so a truncated read is not reported as success.
func restoreIndexEntries(ctx context.Context, redisDB *goredis.Client) error {
	rows, err := DB.QueryContext(ctx, "SELECT keyword, url, weight FROM index_entries")
	if err != nil {
		// Best effort: skip this step when the query cannot run.
		log.Printf("[mysql-restore][index] skip: %v", err)
		return nil
	}
	defer rows.Close()

	// Group the rows by keyword so each sorted set is written with one call.
	type indexRow struct {
		URL    string
		Weight float32
	}
	keywordMap := make(map[string][]indexRow)
	count, failed := 0, 0
	for rows.Next() {
		var keyword, url string
		var weight float32
		if err := rows.Scan(&keyword, &url, &weight); err != nil {
			failed++
			continue
		}
		keywordMap[keyword] = append(keywordMap[keyword], indexRow{URL: url, Weight: weight})
		count++
	}
	// rows.Next also returns false when iteration stops on an error.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterate index_entries after %d rows: %w", count, err)
	}
	if failed > 0 {
		log.Printf("[mysql-restore][index] %d rows could not be scanned", failed)
	}

	// Write one sorted set per keyword.
	for keyword, entries := range keywordMap {
		members := make([]goredis.Z, 0, len(entries))
		for _, e := range entries {
			members = append(members, goredis.Z{Score: float64(e.Weight), Member: e.URL})
		}
		if err := redisDB.ZAdd(ctx, "idx:"+keyword, members...).Err(); err != nil {
			log.Printf("[mysql-restore][index] failed to restore %s: %v", keyword, err)
		}
	}
	log.Printf("[mysql-restore][index] restored %d entries (%d keywords)", count, len(keywordMap))
	return nil
}
// restoreUrlSnippets restores the url_snippets table into Redis. Each row
// becomes the hash "gate:<url_hash>" with the fields url, title, desc, text,
// ts and hash, plus the key "url2hash:<url>" pointing at the hash value.
//
// A failing query (for example because the table does not exist) is logged
// and treated as "nothing to restore". Rows with a NULL url or url_hash are
// ignored. Rows that fail to scan or to be written are skipped and counted.
// If reading the result set is aborted by an error, that error is returned;
// rows written before that point stay in Redis.
func restoreUrlSnippets(ctx context.Context, redisDB *goredis.Client) error {
	rows, err := DB.QueryContext(ctx, "SELECT url, url_hash, title, description, text, timestamp, content_hash FROM url_snippets")
	if err != nil {
		// Best effort: skip this step when the query cannot run.
		log.Printf("[mysql-restore][snippets] skip: %v", err)
		return nil
	}
	defer rows.Close()

	count, failed := 0, 0
	for rows.Next() {
		var url, urlHash, title, description, text, contentHash sql.NullString
		var timestamp sql.NullInt64
		if err := rows.Scan(&url, &urlHash, &title, &description, &text, &timestamp, &contentHash); err != nil {
			failed++
			continue
		}
		// Both values are needed to build the Redis keys.
		if !url.Valid || !urlHash.Valid {
			continue
		}
		fields := map[string]interface{}{
			"url":   url.String,
			"title": nullString(title),
			"desc":  nullString(description),
			"text":  nullString(text),
			"ts":    nullInt64(timestamp),
			"hash":  nullString(contentHash),
		}
		if err := redisDB.HMSet(ctx, "gate:"+urlHash.String, fields).Err(); err != nil {
			failed++
			continue
		}
		// Also write the URL -> hash mapping; only count the row once both
		// keys are in place.
		if err := redisDB.Set(ctx, "url2hash:"+url.String, urlHash.String, 0).Err(); err != nil {
			failed++
			continue
		}
		count++
	}
	// rows.Next also returns false when iteration stops on an error.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterate url_snippets after %d entries: %w", count, err)
	}
	if failed > 0 {
		log.Printf("[mysql-restore][snippets] %d rows failed to restore", failed)
	}
	log.Printf("[mysql-restore][snippets] restored %d entries", count)
	return nil
}
// restoreSiteInfo restores the site_info table into Redis. Each row becomes
// the hash "site:<host>" with visit_count and last_visit_time (0 when NULL);
// success_rate and https_available are only written when they are not NULL.
//
// A failing query (for example because the table does not exist) is logged
// and treated as "nothing to restore". Rows with an empty host are ignored.
// Rows that fail to scan or to be written are skipped and counted. If
// reading the result set is aborted by an error, that error is returned;
// rows written before that point stay in Redis.
func restoreSiteInfo(ctx context.Context, redisDB *goredis.Client) error {
	rows, err := DB.QueryContext(ctx, "SELECT host, visit_count, last_visit_time, success_rate, https_available FROM site_info")
	if err != nil {
		// Best effort: skip this step when the query cannot run.
		log.Printf("[mysql-restore][site] skip: %v", err)
		return nil
	}
	defer rows.Close()

	count, failed := 0, 0
	for rows.Next() {
		var (
			host           string
			visitCount     sql.NullInt64
			lastVisitTime  sql.NullInt64
			successRate    sql.NullFloat64
			httpsAvailable sql.NullInt64
		)
		if err := rows.Scan(&host, &visitCount, &lastVisitTime, &successRate, &httpsAvailable); err != nil {
			failed++
			continue
		}
		if host == "" {
			continue
		}
		fields := map[string]interface{}{
			"visit_count":     nullInt64(visitCount),
			"last_visit_time": nullInt64(lastVisitTime),
		}
		// Optional columns are left out of the hash when they are NULL.
		if successRate.Valid {
			fields["success_rate"] = successRate.Float64
		}
		if httpsAvailable.Valid {
			fields["https_available"] = httpsAvailable.Int64
		}
		if err := redisDB.HMSet(ctx, "site:"+host, fields).Err(); err != nil {
			failed++
			continue
		}
		count++
	}
	// rows.Next also returns false when iteration stops on an error.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterate site_info after %d entries: %w", count, err)
	}
	if failed > 0 {
		log.Printf("[mysql-restore][site] %d rows failed to restore", failed)
	}
	log.Printf("[mysql-restore][site] restored %d entries", count)
	return nil
}
// restorePriorityURLs restores the priority_urls table into Redis. Each URL
// becomes the hash "priority:<url>" with is_domain and visited set to "0"
// and added_at set to the time of the restore.
//
// A failing query (for example because the table does not exist) is logged
// and treated as "nothing to restore". Empty URLs are ignored. Rows that
// fail to scan or to be written are skipped and counted. If reading the
// result set is aborted by an error, that error is returned; rows written
// before that point stay in Redis.
func restorePriorityURLs(ctx context.Context, redisDB *goredis.Client) error {
	rows, err := DB.QueryContext(ctx, "SELECT url FROM priority_urls")
	if err != nil {
		// Best effort: skip this step when the query cannot run.
		log.Printf("[mysql-restore][priority] skip: %v", err)
		return nil
	}
	defer rows.Close()

	count, failed := 0, 0
	for rows.Next() {
		var url string
		if err := rows.Scan(&url); err != nil {
			failed++
			continue
		}
		if url == "" {
			continue
		}
		fields := map[string]interface{}{
			"url":       url,
			"is_domain": "0",
			"added_at":  time.Now().Unix(),
			"visited":   "0",
		}
		if err := redisDB.HMSet(ctx, "priority:"+url, fields).Err(); err != nil {
			failed++
			continue
		}
		count++
	}
	// rows.Next also returns false when iteration stops on an error.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterate priority_urls after %d entries: %w", count, err)
	}
	if failed > 0 {
		log.Printf("[mysql-restore][priority] %d rows failed to restore", failed)
	}
	log.Printf("[mysql-restore][priority] restored %d entries", count)
	return nil
}
// ---- 辅助函数 ----
// nullString unwraps a sql.NullString, mapping NULL to the empty string.
func nullString(v sql.NullString) string {
	if !v.Valid {
		return ""
	}
	return v.String
}
// nullInt64 unwraps a sql.NullInt64, mapping NULL to 0.
func nullInt64(v sql.NullInt64) int64 {
	if !v.Valid {
		return 0
	}
	return v.Int64
}