Files
2026-04-20 18:26:54 +08:00

652 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// Package storage provides unified storage interface backed by Redis + MySQL.
// Redis is used as the primary high-performance store, MySQL is used for persistence.
package storage
import (
"context"
"crypto/md5"
"fmt"
"sync"
"sync/atomic"
goredis "github.com/redis/go-redis/v9"
extredis "sese-engine/redis"
)
// PriorityEntry records a URL or bare domain queued for priority crawling.
// RedisStoreV2 persists each entry as a Redis hash keyed by "priority:" + URL.
type PriorityEntry struct {
URL string `json:"url"` // URL or domain submitted by the user; presumably normalized to a URL with a scheme before it gets here (not done in this file — confirm at the caller)
IsDomain bool `json:"domain"` // true: bare domain only; false: a full URL
AddedAt int64 `json:"added_at"` // Unix timestamp of when the entry was added (presumably seconds — confirm at the caller)
Visited bool `json:"visited"` // whether the crawler has already processed this entry (set after crawling)
}
// Store is the unified storage interface; all data access goes through it.
type Store interface {
// Inverted index operations: keyword -> posting list of IndexEntry.
GetIndex(keyword string) ([]IndexEntry, error)
BatchSetIndex(batch map[string][]IndexEntry) error
ForEachIndex(fn func(keyword string, entries []IndexEntry) error) error
// URL snippet (page summary) operations.
GetSnippet(url string) (*SnippetEntry, error)
SetSnippet(url string, entry *SnippetEntry) error
ForEachSnippet(fn func(url string, entry *SnippetEntry) error) error
// Per-site (host) information operations.
GetSiteInfo(host string) (*SiteInfo, error)
SetSiteInfo(host string, info *SiteInfo) error
UpdateSiteInfo(host string, fn func(*SiteInfo)) error
ForEachSite(fn func(host string, info *SiteInfo) error) error
// Priority URL queue operations.
GetPriorityURLs() ([]PriorityEntry, error)
AddPriorityURL(entry PriorityEntry) error
RemovePriorityURL(url string) error
MarkPriorityURLVisited(url string) error
ClearVisitedPriorityURLs() error
// Lifecycle.
Close() error
}
// RedisStoreV2 implements the Store interface on top of Redis.
//
// Index writes can be buffered in memory (see SetMemIndex) and later merged
// into Redis by FlushMemToRedis. Site records are additionally cached in
// memory so repeated lookups avoid a Redis round trip.
type RedisStoreV2 struct {
client *goredis.Client
// In-memory aggregation of index entries awaiting a flush (used for writes); guarded by memMu.
mem map[string][]IndexEntry
memMu sync.RWMutex
// Number of index entries not yet flushed; read and written via sync/atomic.
rowCount int64
// In-memory cache of site records keyed by host; guarded by siteCacheMu.
siteCache map[string]*SiteInfo
siteCacheMu sync.RWMutex
}
// NewRedisStoreV2 returns a store with empty in-memory buffers. Call Init
// before use: that is where the Redis client is attached.
func NewRedisStoreV2() *RedisStoreV2 {
	s := new(RedisStoreV2)
	s.mem = make(map[string][]IndexEntry)
	s.siteCache = make(map[string]*SiteInfo)
	return s
}
// Init binds the store to the shared client exposed by the sese-engine/redis
// package. It fails if that client has not been created yet.
func (r *RedisStoreV2) Init() error {
	shared := extredis.Client
	if shared == nil {
		return fmt.Errorf("redis not initialized, call redis.Open() first")
	}
	r.client = shared
	return nil
}
// Close flushes any index entries still buffered in memory to Redis. It does
// not close the shared Redis client.
func (r *RedisStoreV2) Close() error {
	return r.FlushMemToRedis()
}
// ---- Inverted index operations ----

// GetIndex returns the posting list stored in the sorted set "idx:<keyword>",
// ordered from highest to lowest weight. Redis reports a missing key as an
// empty range, so an unknown keyword yields an empty slice and a nil error.
func (r *RedisStoreV2) GetIndex(keyword string) ([]IndexEntry, error) {
	ranked, err := r.client.ZRevRangeWithScores(context.Background(), "idx:"+keyword, 0, -1).Result()
	switch {
	case err == goredis.Nil:
		return nil, nil
	case err != nil:
		return nil, err
	}
	postings := make([]IndexEntry, len(ranked))
	for i, z := range ranked {
		// Members are written as URL strings and scores as the float64 form of
		// the float32 weight, so narrow the score back here.
		postings[i] = IndexEntry{URL: z.Member.(string), Weight: float32(z.Score)}
	}
	return postings, nil
}
// BatchSetIndex replaces the stored posting list of every keyword in batch,
// using setIndexEntries for each one. Keywords mapped to an empty slice are
// skipped, so their existing index is left untouched.
//
// Processing stops at the first write error. Keywords written before the
// failure keep their new contents. Because map iteration order is random,
// which keywords those are is unspecified.
func (r *RedisStoreV2) BatchSetIndex(batch map[string][]IndexEntry) error {
	for keyword, entries := range batch {
		if len(entries) == 0 {
			continue
		}
		if err := r.setIndexEntries(keyword, entries); err != nil {
			return fmt.Errorf("set index %q: %w", keyword, err)
		}
	}
	return nil
}
// ForEachIndex walks every "idx:*" key with SCAN and calls fn with the keyword
// (the key without its "idx:" prefix) and its entries as returned by GetIndex.
// Keys whose entries cannot be read are skipped. Iteration stops at the first
// error returned by SCAN or by fn.
//
// Redis SCAN may report the same key more than once, so fn can be invoked
// repeatedly for one keyword.
func (r *RedisStoreV2) ForEachIndex(fn func(keyword string, entries []IndexEntry) error) error {
	const prefix = "idx:"
	ctx := context.Background()
	for cursor := uint64(0); ; {
		keys, next, err := r.client.Scan(ctx, cursor, "idx:*", 1000).Result()
		if err != nil {
			return err
		}
		for _, key := range keys {
			keyword := key[len(prefix):]
			entries, readErr := r.GetIndex(keyword)
			if readErr != nil {
				continue
			}
			if cbErr := fn(keyword, entries); cbErr != nil {
				return cbErr
			}
		}
		if next == 0 {
			return nil
		}
		cursor = next
	}
}
// ---- URL snippet operations ----

// GetSnippet loads the snippet for url from the hash "gate:<md5 hex of url>".
// It returns (nil, nil) when no such hash exists.
func (r *RedisStoreV2) GetSnippet(url string) (*SnippetEntry, error) {
	key := "gate:" + urlHash(url)
	fields, err := r.client.HGetAll(context.Background(), key).Result()
	if err != nil {
		return nil, err
	}
	if len(fields) == 0 {
		return nil, nil
	}
	snippet := &SnippetEntry{}
	snippet.Title = fields["title"]
	snippet.Description = fields["desc"]
	snippet.Text = fields["text"]
	snippet.Timestamp = parseInt64(fields["ts"])
	snippet.ContentHash = fields["hash"]
	return snippet, nil
}
// SetSnippet stores entry for url and records the url -> hash mapping.
//
// The snippet goes into the hash "gate:<md5 hex of url>" with fields url,
// title, desc, text, ts and hash. The mapping goes into the string key
// "url2hash:<url>". Both writes are sent as one MULTI/EXEC transaction, so
// they are applied together and a failure of either one is reported to the
// caller. A nil entry is rejected with an error.
func (r *RedisStoreV2) SetSnippet(url string, entry *SnippetEntry) error {
	if entry == nil {
		return fmt.Errorf("set snippet %q: nil entry", url)
	}
	ctx := context.Background()
	hash := urlHash(url)
	fields := map[string]interface{}{
		"url":   url,
		"title": entry.Title,
		"desc":  entry.Description,
		"text":  entry.Text,
		"ts":    entry.Timestamp,
		"hash":  entry.ContentHash,
	}
	_, err := r.client.TxPipelined(ctx, func(pipe goredis.Pipeliner) error {
		// HSet accepts a field map; HMSet is deprecated in go-redis v9.
		pipe.HSet(ctx, "gate:"+hash, fields)
		pipe.Set(ctx, "url2hash:"+url, hash, 0)
		return nil
	})
	if err != nil {
		return fmt.Errorf("set snippet %q: %w", url, err)
	}
	return nil
}
// ForEachSnippet scans every "gate:*" hash and calls fn with the URL stored in
// the hash's "url" field and the decoded snippet. Hashes that cannot be read
// or are empty are skipped. Iteration stops at the first error from SCAN or
// from fn. Redis SCAN may return a key more than once, so fn can see
// duplicates.
func (r *RedisStoreV2) ForEachSnippet(fn func(url string, entry *SnippetEntry) error) error {
	ctx := context.Background()
	cursor := uint64(0)
	for {
		batch, next, err := r.client.Scan(ctx, cursor, "gate:*", 1000).Result()
		if err != nil {
			return err
		}
		for _, key := range batch {
			fields, readErr := r.client.HGetAll(ctx, key).Result()
			if readErr != nil || len(fields) == 0 {
				continue
			}
			snippet := &SnippetEntry{
				Title:       fields["title"],
				Description: fields["desc"],
				Text:        fields["text"],
				Timestamp:   parseInt64(fields["ts"]),
				ContentHash: fields["hash"],
			}
			if cbErr := fn(fields["url"], snippet); cbErr != nil {
				return cbErr
			}
		}
		if next == 0 {
			return nil
		}
		cursor = next
	}
}
// ---- Site information operations ----

// GetSiteInfo returns the record for host, or (nil, nil) if none exists.
//
// The in-memory cache is consulted first. On a miss, the "site:<host>" hash is
// read from Redis, decoded and cached. The returned pointer is the cached
// object itself, so callers must not modify it. Use SetSiteInfo or
// UpdateSiteInfo to change a record.
func (r *RedisStoreV2) GetSiteInfo(host string) (*SiteInfo, error) {
	r.siteCacheMu.RLock()
	cached, ok := r.siteCache[host]
	r.siteCacheMu.RUnlock()
	if ok {
		return cached, nil
	}

	ctx := context.Background()
	data, err := r.client.HGetAll(ctx, "site:"+host).Result()
	if err != nil {
		return nil, err
	}
	if len(data) == 0 {
		return nil, nil
	}
	info := &SiteInfo{
		VisitCount:    int(parseInt64(data["visit_count"])),
		LastVisitTime: parseInt64(data["last_visit_time"]),
	}
	// Optional fields stay nil when absent from the hash.
	if v, ok := data["success_rate"]; ok {
		f := parseFloat64(v)
		info.SuccessRate = &f
	}
	if v, ok := data["https_available"]; ok {
		b := parseInt64(v) == 1
		info.HTTPSAvailable = &b
	}

	// Backfill the cache, but re-check under the write lock first. A
	// concurrent SetSiteInfo (or SiteCacheRefresh) may have cached a newer
	// value while we were reading Redis, and our read may predate it.
	// Overwriting it would leave the cache stale.
	r.siteCacheMu.Lock()
	defer r.siteCacheMu.Unlock()
	if current, ok := r.siteCache[host]; ok {
		return current, nil
	}
	r.siteCache[host] = info
	return info, nil
}
// SetSiteInfo writes info to the hash "site:<host>" and, on success, caches
// the info pointer itself for later GetSiteInfo calls. The boolean
// HTTPSAvailable is stored as 1/0. A nil info is rejected with an error.
//
// NOTE(review): optional fields that are nil are simply not written. A value
// already present in Redis for that field is therefore kept, while the cache
// holds nil for it until the next reload. Confirm whether nil should clear the
// Redis field instead.
func (r *RedisStoreV2) SetSiteInfo(host string, info *SiteInfo) error {
	if info == nil {
		return fmt.Errorf("set site info %q: nil info", host)
	}
	ctx := context.Background()
	fields := map[string]interface{}{
		"visit_count":     info.VisitCount,
		"last_visit_time": info.LastVisitTime,
	}
	if info.SuccessRate != nil {
		fields["success_rate"] = *info.SuccessRate
	}
	if info.HTTPSAvailable != nil {
		if *info.HTTPSAvailable {
			fields["https_available"] = 1
		} else {
			fields["https_available"] = 0
		}
	}
	// HSet accepts a field map; HMSet is deprecated in go-redis v9.
	if err := r.client.HSet(ctx, "site:"+host, fields).Err(); err != nil {
		return fmt.Errorf("set site info %q: %w", host, err)
	}
	r.siteCacheMu.Lock()
	r.siteCache[host] = info
	r.siteCacheMu.Unlock()
	return nil
}
// UpdateSiteInfo applies fn to the record for host and persists the result
// with SetSiteInfo. If no record exists, fn receives a zero-valued SiteInfo.
//
// fn operates on a private copy, not on the object held in the site cache.
// The cached pointer is shared with concurrent GetSiteInfo readers, so
// mutating it in place would race with them. It would also leave the cache
// modified even when the Redis write fails.
//
// NOTE(review): concurrent UpdateSiteInfo calls for the same host are not
// serialized, so one update can overwrite another (last writer wins).
func (r *RedisStoreV2) UpdateSiteInfo(host string, fn func(*SiteInfo)) error {
	current, err := r.GetSiteInfo(host)
	if err != nil {
		return err
	}
	var next SiteInfo
	if current != nil {
		next = *current
		// Detach the optional pointer fields too, so fn cannot write through
		// them into the cached record. NOTE(review): any other reference-typed
		// fields SiteInfo may declare are still shared — confirm against its
		// definition.
		if current.SuccessRate != nil {
			rate := *current.SuccessRate
			next.SuccessRate = &rate
		}
		if current.HTTPSAvailable != nil {
			https := *current.HTTPSAvailable
			next.HTTPSAvailable = &https
		}
	}
	fn(&next)
	return r.SetSiteInfo(host, &next)
}
// ForEachSite scans every "site:*" key and calls fn with the host (the key
// without its "site:" prefix) and its record as returned by GetSiteInfo, which
// may come from the in-memory cache. Hosts whose record cannot be read or is
// empty are skipped. Iteration stops at the first error from SCAN or from fn.
// Redis SCAN may return a key more than once.
func (r *RedisStoreV2) ForEachSite(fn func(host string, info *SiteInfo) error) error {
	ctx := context.Background()
	for cursor := uint64(0); ; {
		keys, next, err := r.client.Scan(ctx, cursor, "site:*", 1000).Result()
		if err != nil {
			return err
		}
		for _, key := range keys {
			host := key[len("site:"):]
			info, lookupErr := r.GetSiteInfo(host)
			if lookupErr != nil || info == nil {
				continue
			}
			if cbErr := fn(host, info); cbErr != nil {
				return cbErr
			}
		}
		if next == 0 {
			return nil
		}
		cursor = next
	}
}
// ---- Priority URL operations ----

// GetPriorityURLs returns every priority entry whose "visited" field is not
// "1". The order of the result is unspecified. Hashes that cannot be read or
// are empty are skipped.
//
// Keys are enumerated with an incremental SCAN rather than KEYS. KEYS walks
// the entire keyspace in one blocking call, and this database also holds the
// full inverted index and snippet store. SCAN may report a key more than once,
// so keys are de-duplicated.
func (r *RedisStoreV2) GetPriorityURLs() ([]PriorityEntry, error) {
	ctx := context.Background()
	seen := make(map[string]struct{})
	var entries []PriorityEntry
	for cursor := uint64(0); ; {
		keys, next, err := r.client.Scan(ctx, cursor, "priority:*", 1000).Result()
		if err != nil {
			return nil, fmt.Errorf("scan priority urls: %w", err)
		}
		for _, key := range keys {
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}
			data, err := r.client.HGetAll(ctx, key).Result()
			if err != nil || len(data) == 0 {
				continue
			}
			if data["visited"] == "1" {
				continue
			}
			// Visited is false for every entry returned here.
			entries = append(entries, PriorityEntry{
				URL:      data["url"],
				IsDomain: data["is_domain"] == "1",
				AddedAt:  parseInt64(data["added_at"]),
			})
		}
		if next == 0 {
			return entries, nil
		}
		cursor = next
	}
}
// AddPriorityURL stores entry as the hash "priority:<URL>" with fields url,
// is_domain ("1"/"0"), added_at and visited ("1"/"0"). Adding a URL that
// already exists overwrites those fields, including the visited flag.
func (r *RedisStoreV2) AddPriorityURL(entry PriorityEntry) error {
	ctx := context.Background()
	fields := map[string]interface{}{
		"url":       entry.URL,
		"is_domain": boolToStr(entry.IsDomain),
		"added_at":  entry.AddedAt,
		"visited":   boolToStr(entry.Visited),
	}
	// HSet accepts a field map; HMSet is deprecated in go-redis v9.
	return r.client.HSet(ctx, "priority:"+entry.URL, fields).Err()
}
// RemovePriorityURL deletes the hash "priority:<url>". Removing a URL that is
// not queued is not an error.
func (r *RedisStoreV2) RemovePriorityURL(url string) error {
	key := "priority:" + url
	cmd := r.client.Del(context.Background(), key)
	return cmd.Err()
}
// MarkPriorityURLVisited sets the "visited" field of "priority:<url>" to "1".
// This excludes the entry from GetPriorityURLs and makes it eligible for
// ClearVisitedPriorityURLs. HSET creates the hash if it does not exist, so
// marking an unknown URL leaves a hash that holds only visited=1.
func (r *RedisStoreV2) MarkPriorityURLVisited(url string) error {
	key := "priority:" + url
	cmd := r.client.HSet(context.Background(), key, "visited", "1")
	return cmd.Err()
}
// ClearVisitedPriorityURLs deletes every "priority:*" hash whose "visited"
// field is "1".
//
// Keys are enumerated with an incremental SCAN instead of a blocking KEYS. A
// failure on one key does not stop the sweep. The first such failure is
// returned once the scan completes, and a SCAN error aborts immediately. A
// missing "visited" field (redis.Nil) is not an error.
//
// NOTE(review): the check and the delete are separate commands. An entry
// re-added as unvisited between them can still be deleted.
func (r *RedisStoreV2) ClearVisitedPriorityURLs() error {
	ctx := context.Background()
	var firstErr error
	for cursor := uint64(0); ; {
		keys, next, err := r.client.Scan(ctx, cursor, "priority:*", 1000).Result()
		if err != nil {
			return fmt.Errorf("scan priority urls: %w", err)
		}
		for _, key := range keys {
			visited, err := r.client.HGet(ctx, key, "visited").Result()
			if err != nil {
				if err != goredis.Nil && firstErr == nil {
					firstErr = fmt.Errorf("read %q: %w", key, err)
				}
				continue
			}
			if visited != "1" {
				continue
			}
			if err := r.client.Del(ctx, key).Err(); err != nil && firstErr == nil {
				firstErr = fmt.Errorf("delete %q: %w", key, err)
			}
		}
		if next == 0 {
			return firstErr
		}
		cursor = next
	}
}
// ---- In-memory index operations (used for writes) ----

// GetMemIndex returns the entries buffered in memory for keyword, or nil if
// there are none. The returned slice is the buffer's own slice, not a copy.
func (r *RedisStoreV2) GetMemIndex(keyword string) []IndexEntry {
	r.memMu.RLock()
	buffered := r.mem[keyword]
	r.memMu.RUnlock()
	return buffered
}
// SetMemIndex replaces the entries buffered in memory for keyword. The slice
// is stored as-is (not copied), so the caller should not modify it afterwards.
func (r *RedisStoreV2) SetMemIndex(keyword string, entries []IndexEntry) {
	mu := &r.memMu
	mu.Lock()
	r.mem[keyword] = entries
	mu.Unlock()
}
// GetAllMemIndexes returns a shallow copy of the in-memory index buffer. The
// map is new, but each slice value shares its backing array with the buffer.
func (r *RedisStoreV2) GetAllMemIndexes() map[string][]IndexEntry {
	r.memMu.RLock()
	snapshot := make(map[string][]IndexEntry, len(r.mem))
	for keyword, entries := range r.mem {
		snapshot[keyword] = entries
	}
	r.memMu.RUnlock()
	return snapshot
}
// GetRowCount reports the number of index entries not yet flushed to Redis,
// as maintained through AddRowCount/SetRowCount and reset by FlushMemToRedis.
func (r *RedisStoreV2) GetRowCount() int64 {
	counter := &r.rowCount
	return atomic.LoadInt64(counter)
}

// AddRowCount atomically adds delta to the unflushed-entry counter.
func (r *RedisStoreV2) AddRowCount(delta int64) {
	counter := &r.rowCount
	atomic.AddInt64(counter, delta)
}

// SetRowCount atomically overwrites the unflushed-entry counter with v.
func (r *RedisStoreV2) SetRowCount(v int64) {
	counter := &r.rowCount
	atomic.StoreInt64(counter, v)
}
// FlushMemToRedis merges every index entry buffered in memory into Redis.
//
// The buffer is swapped out under memMu and the row counter is reset, so
// writers can keep buffering while the flush runs. For each keyword, the
// stored posting list is read and merged with the buffered entries. Buffered
// entries win on duplicate URLs (see mergeEntries). The result is written back
// via setIndexEntries.
//
// On the first failure the flush stops. Everything not yet written, including
// the merged list of the keyword that failed, goes back into the in-memory
// buffer so a later flush retries it instead of losing it. A failed read of
// the stored list is never ignored: writing only the buffered entries would
// replace the stored list and drop the older postings.
func (r *RedisStoreV2) FlushMemToRedis() error {
	r.memMu.Lock()
	pending := r.mem
	r.mem = make(map[string][]IndexEntry)
	atomic.StoreInt64(&r.rowCount, 0)
	r.memMu.Unlock()

	if len(pending) == 0 {
		return nil
	}

	for keyword, memEntries := range pending {
		diskEntries, err := r.GetIndex(keyword)
		if err != nil {
			r.restoreUnflushed(pending)
			return fmt.Errorf("flush index %q: read stored entries: %w", keyword, err)
		}
		merged := mergeEntries(memEntries, diskEntries)
		if err := r.setIndexEntries(keyword, merged); err != nil {
			// Keep the merged list. If the write partially applied, the stored
			// postings it contained would otherwise be lost.
			pending[keyword] = merged
			r.restoreUnflushed(pending)
			return fmt.Errorf("flush index %q: write merged entries: %w", keyword, err)
		}
		// Deleting the current key during range is permitted in Go.
		delete(pending, keyword)
	}
	return nil
}

// restoreUnflushed puts entries that could not be flushed back into the
// in-memory buffer so the next flush retries them. Entries buffered after the
// snapshot was taken take precedence on duplicate URLs. The row counter is
// increased by the number of restored entries, which only approximates the
// count that was reset.
func (r *RedisStoreV2) restoreUnflushed(pending map[string][]IndexEntry) {
	r.memMu.Lock()
	defer r.memMu.Unlock()
	var restored int64
	for keyword, entries := range pending {
		r.mem[keyword] = mergeEntries(r.mem[keyword], entries)
		restored += int64(len(entries))
	}
	atomic.AddInt64(&r.rowCount, restored)
}
// setIndexEntries replaces the sorted set "idx:<keyword>" with entries. Each
// entry's Weight becomes the score and its URL the member. An empty entries
// slice simply deletes the key.
//
// The delete and the re-add run inside one MULTI/EXEC transaction. Readers
// therefore never see the key in its deleted-but-not-yet-refilled state, and a
// failure before EXEC leaves the previous index untouched. Errors from either
// command are returned rather than ignored.
func (r *RedisStoreV2) setIndexEntries(keyword string, entries []IndexEntry) error {
	ctx := context.Background()
	key := "idx:" + keyword
	if len(entries) == 0 {
		return r.client.Del(ctx, key).Err()
	}
	members := make([]goredis.Z, len(entries))
	for i, e := range entries {
		members[i] = goredis.Z{Score: float64(e.Weight), Member: e.URL}
	}
	_, err := r.client.TxPipelined(ctx, func(pipe goredis.Pipeliner) error {
		pipe.Del(ctx, key)
		pipe.ZAdd(ctx, key, members...)
		return nil
	})
	return err
}
// mergeEntries concatenates newEntries and existingEntries, keeping only the
// first occurrence of each URL. Because newEntries are visited first, they
// override stored entries with the same URL. Duplicates within either list are
// also dropped. Relative order is preserved. The result is nil when both
// inputs are empty.
func mergeEntries(newEntries, existingEntries []IndexEntry) []IndexEntry {
	var merged []IndexEntry
	taken := make(map[string]struct{}, len(newEntries)+len(existingEntries))
	for _, source := range [...][]IndexEntry{newEntries, existingEntries} {
		for _, entry := range source {
			if _, dup := taken[entry.URL]; dup {
				continue
			}
			taken[entry.URL] = struct{}{}
			merged = append(merged, entry)
		}
	}
	return merged
}
// SiteCacheRefresh reloads every "site:*" hash from Redis into the in-memory
// site cache. Existing cache entries are overwritten, but hosts that no longer
// exist in Redis are not removed. Hashes that fail to load or are empty are
// skipped. If SCAN fails midway, the entries loaded so far stay applied and
// the error is returned.
//
// The cache write lock is held for the whole scan, blocking all cache access
// in the meantime. NOTE(review): this looks intentional. A SetSiteInfo that
// writes Redis during the scan can only update the cache after the refresh
// releases the lock, so its newer value is not overwritten by a stale read
// from this scan.
func (r *RedisStoreV2) SiteCacheRefresh() error {
	ctx := context.Background()
	r.siteCacheMu.Lock()
	defer r.siteCacheMu.Unlock()
	for cursor := uint64(0); ; {
		keys, next, err := r.client.Scan(ctx, cursor, "site:*", 1000).Result()
		if err != nil {
			return err
		}
		for _, key := range keys {
			fields, readErr := r.client.HGetAll(ctx, key).Result()
			if readErr != nil || len(fields) == 0 {
				continue
			}
			record := &SiteInfo{
				VisitCount:    int(parseInt64(fields["visit_count"])),
				LastVisitTime: parseInt64(fields["last_visit_time"]),
			}
			if raw, present := fields["success_rate"]; present {
				rate := parseFloat64(raw)
				record.SuccessRate = &rate
			}
			if raw, present := fields["https_available"]; present {
				https := parseInt64(raw) == 1
				record.HTTPSAvailable = &https
			}
			r.siteCache[key[len("site:"):]] = record
		}
		if next == 0 {
			return nil
		}
		cursor = next
	}
}
// ---- Helpers ----

// urlHash returns the lowercase hex MD5 digest of url. It is used to build
// short, fixed-length Redis keys, not for any security purpose.
func urlHash(url string) string {
	digest := md5.Sum([]byte(url))
	return fmt.Sprintf("%x", digest)
}
// parseInt64 reads a leading base-10 integer from s using fmt.Sscanf("%d").
// Input with no leading integer (including "") yields 0, and parse errors are
// ignored, presumably so that fields missing from a Redis hash, which read
// back as "", decode as zero.
func parseInt64(s string) (n int64) {
	_, _ = fmt.Sscanf(s, "%d", &n)
	return n
}
// parseFloat64 reads a leading decimal number from s using fmt.Sscanf("%f").
// Input with no leading number (including "") yields 0, and parse errors are
// ignored.
func parseFloat64(s string) (f float64) {
	_, _ = fmt.Sscanf(s, "%f", &f)
	return f
}
// boolToStr encodes b as "1" (true) or "0" (false), the form used for boolean
// fields in Redis hashes.
func boolToStr(b bool) string {
	s := "0"
	if b {
		s = "1"
	}
	return s
}