From 3715b03fabce9deb5660c5b537f0638587b9e7f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=B4=E6=96=87=E5=B3=B0?= Date: Thu, 9 Apr 2026 17:03:10 +0800 Subject: [PATCH] =?UTF-8?q?=E9=98=B2=E5=BE=A1=E4=B8=80=E4=BA=9B=E7=88=AC?= =?UTF-8?q?=E8=99=AB=E9=99=B7=E9=98=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- crawler/crawler.go | 50 +++++++++++++++++++++---- crawler/fetcher.go | 75 +++++++++++++++++++++++++++++++++++++ parser/parser.go | 93 ++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 206 insertions(+), 12 deletions(-) diff --git a/crawler/crawler.go b/crawler/crawler.go index 85e70d6..b4bece8 100644 --- a/crawler/crawler.go +++ b/crawler/crawler.go @@ -50,6 +50,10 @@ type Crawler struct { prosperMap map[string]float64 // 域名 → 反向链接繁荣值(来自 info 模块,越大越"有价值") stats Stats // 原子计数器 + // visited 记录已访问的 URL 集合(跨 epoch 持久,启动时从 DB 预热) + visited map[string]bool + visitedMu sync.RWMutex // 保护 visited 的并发读写 + // 熔断器(全用 atomic,无 mutex,无慢 I/O 时持有锁的风险) circuitState int32 // circuitClosed | circuitOpen | circuitHalfOpen circuitFailures int32 // 连续失败计数(atomic) @@ -76,18 +80,48 @@ func GlobalActiveWorkers() int64 { // New 创建一个 Crawler 实例。 // prosperMap 由 info 模块加载,传入域名繁荣值用于调度优先级计算。 func New(db *storage.DB, a *analyzer.Analyzer, prosperMap map[string]float64) *Crawler { - return &Crawler{ + c := &Crawler{ fetcher: NewFetcher(config.SpiderName(), time.Duration(config.CrawlerCooldown())*time.Second), db: db, analyzer: a, prosperMap: prosperMap, + visited: make(map[string]bool), } + // 启动时从 gate bucket 预热已爬取的 URL 集合(程序重启后不会重复爬取) + c.warmVisited() + return c +} + +// warmVisited 从 DB 的 gate bucket 加载所有已缓存的 URL 到 visited set。 +func (c *Crawler) warmVisited() { + count := 0 + _ = c.db.ForEachSnippet(func(u string, entry *storage.SnippetEntry) error { + c.visited[u] = true + count++ + return nil + }) + log.Printf("[crawler] visited set warmed: %d URLs loaded", count) +} + +// markVisited 将 URL 标记为已访问(线程安全)。 +func (c *Crawler) markVisited(url string) { + c.visitedMu.Lock() + c.visited[url] = true + c.visitedMu.Unlock() +} + +// isVisited 检查 URL 是否已访问(线程安全)。 +func (c *Crawler) isVisited(url string) bool { + c.visitedMu.RLock() + v := c.visited[url] + c.visitedMu.RUnlock() + return v } // fetchAndApplyPriorityURLs 从数据库读取用户插入的 priority URLs, // 将未访问的插入队列前端(prepend),已爬取的条目从存储中清除。 // 返回本次插入队列的 URL 数量。 -func (c *Crawler) fetchAndApplyPriorityURLs(visited map[string]bool, queue *[]string) int { +func (c *Crawler) fetchAndApplyPriorityURLs(queue *[]string) int { entries, err := c.db.GetPriorityURLs() if err != nil || len(entries) == 0 { return 0 @@ -95,7 +129,7 @@ func (c *Crawler) fetchAndApplyPriorityURLs(visited map[string]bool, queue *[]st added := 0 for _, e := range entries { - if visited[e.URL] { + if c.isVisited(e.URL) { _ = c.db.RemovePriorityURL(e.URL) continue } @@ -117,15 +151,15 @@ type URLWeight struct { // 各轮之间是串行的,每轮内并发抓取,按调度算法选择下一轮 URL。 // 每轮开始前会检查 priority 队列,优先爬取用户插入的 URL。 func (c *Crawler) Run(entryURL string, maxEpoch int) { - visited := make(map[string]bool) // 已访问 URL 集合(防止重复抓取) - queue := []string{entryURL} // 当前轮次的待抓取队列 + c.markVisited(entryURL) + queue := []string{entryURL} for ep := 0; ep < maxEpoch; ep++ { // 每轮 epoch 从 config 读取最新 workers 值,支持运行时动态调整 workers := config.CrawlerWorkers() // 每轮开始前:拉取 priority URLs,插入队列前端 - priorityAdded := c.fetchAndApplyPriorityURLs(visited, &queue) + priorityAdded := c.fetchAndApplyPriorityURLs(&queue) if priorityAdded > 0 { log.Printf("[crawler] epoch %d/%d queue=%d (+%d priority) workers=%d", ep+1, maxEpoch, len(queue), priorityAdded, workers) } else { @@ -133,7 +167,7 @@ func (c *Crawler) Run(entryURL string, maxEpoch int) { } // 将本轮所有 URL 标记为已访问(防止下一轮重复入队) for _, u := range queue { - visited[u] = true + c.markVisited(u) } // 并发抓取本轮所有 URL @@ -164,7 +198,7 @@ func (c *Crawler) Run(entryURL string, maxEpoch int) { w := 1.0 / float64(n) mu.Lock() for _, h := range hrefs { - if !visited[h] { + if !c.isVisited(h) { newLinks = append(newLinks, URLWeight{URL: h, Weight: w}) } } diff --git a/crawler/fetcher.go b/crawler/fetcher.go index 26ad92a..d72a204 100644 --- a/crawler/fetcher.go +++ b/crawler/fetcher.go @@ -6,6 +6,7 @@ package crawler import ( "fmt" // 字符串格式化(构建 robots.txt URL、错误信息) "io" // IO 接口(读取响应体) + "net" // IP 地址解析(SSRF 防护) "net/http" // HTTP 客户端 "net/url" // URL 解析 "strings" // 字符串操作 @@ -120,6 +121,10 @@ func (f *Fetcher) fetchWithHistory(rawURL string, polite bool, timeout time.Dura if len(via) >= 10 { return fmt.Errorf("too many redirects") } + // SSRF 防护:拒绝重定向到内网 IP 或非标端口 + if err := isSafeRedirect(req.URL); err != nil { + return err + } // 记录永久重定向 if req.Response != nil && (req.Response.StatusCode == 301 || req.Response.StatusCode == 308) { from := via[len(via)-1].URL.String() @@ -130,6 +135,11 @@ func (f *Fetcher) fetchWithHistory(rawURL string, polite bool, timeout time.Dura }, } + // 对初始 URL 也做 SSRF 检查 + if err := isSafeRedirect(parsed); err != nil { + return nil, &ErrCrawl{Msg: err.Error()} + } + // 构造 GET 请求 req, _ := http.NewRequest("GET", rawURL, nil) req.Header.Set("User-Agent", f.userAgent) @@ -344,3 +354,68 @@ func decodeBody(r io.Reader, contentType string, sizeLimit int) (string, error) } return string(data), nil } + +// isPrivateIP 检查 IP 是否为私有/回环/链路本地地址。 +func isPrivateIP(ip net.IP) bool { + privateRanges := []struct { + network *net.IPNet + }{ + // 10.0.0.0/8 — RFC 1918 私有网络 + {mustParseCIDR("10.0.0.0/8")}, + // 172.16.0.0/12 — RFC 1918 私有网络 + {mustParseCIDR("172.16.0.0/12")}, + // 192.168.0.0/16 — RFC 1918 私有网络 + {mustParseCIDR("192.168.0.0/16")}, + // 127.0.0.0/8 — 回环地址 + {mustParseCIDR("127.0.0.0/8")}, + // 169.254.0.0/16 — 链路本地(AWS/GCP 元数据服务) + {mustParseCIDR("169.254.0.0/16")}, + // ::1/128 — IPv6 回环 + {mustParseCIDR("::1/128")}, + // fe80::/10 — IPv6 链路本地 + {mustParseCIDR("fe80::/10")}, + // fc00::/7 — IPv6 唯一本地地址 + {mustParseCIDR("fc00::/7")}, + } + for _, r := range privateRanges { + if r.network.Contains(ip) { + return true + } + } + return false +} + +// mustParseCIDR 解析 CIDR,失败时 panic(仅用于编译期常量)。 +func mustParseCIDR(s string) *net.IPNet { + _, network, err := net.ParseCIDR(s) + if err != nil { + panic("invalid CIDR: " + s) + } + return network +} + +// isSafeRedirect 检查重定向目标是否安全(非内网 IP、非非标端口)。 +// 用于防止 SSRF 攻击:恶意服务器将爬虫重定向到内网服务。 +func isSafeRedirect(u *url.URL) error { + host := u.Hostname() + port := u.Port() + // 解析 IP 地址 + ip := net.ParseIP(host) + if ip == nil { + // 域名(非 IP),允许(DNS 解析由系统处理) + // 但非标端口仍需检查 + if port != "" && port != "80" && port != "443" { + return fmt.Errorf("blocked: non-standard port %s", port) + } + return nil + } + // IP 直连:检查是否为私有地址 + if isPrivateIP(ip) { + return fmt.Errorf("blocked: private IP %s", ip) + } + // 非标端口检查 + if port != "" && port != "80" && port != "443" { + return fmt.Errorf("blocked: non-standard port %s", port) + } + return nil +} diff --git a/parser/parser.go b/parser/parser.go index 27e4aa3..743102b 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -3,8 +3,10 @@ package parser import ( + "net/url" // URL 解析(规范化) "path" // 路径处理(提取目录、规范化相对路径) "regexp" // 正则表达式(空白字符替换) + "sort" // query 参数排序(URL 规范化去重) "strings" // 字符串操作 "golang.org/x/net/html" // 标准 HTML 解析器(将 HTML 解析为 DOM 树) @@ -61,14 +63,15 @@ func ParseHTML(body, baseURL string) (title, description, text string, hrefs []s // 提取 链接 if tag == "a" { href := attrVal(n, "href") - if href != "" { - // 去除 URL 中的锚点(#fragment) + if href != "" && isSafeURL(href) { href = strings.SplitN(href, "#", 2)[0] if href != "" { - // 解析为绝对 URL(处理相对路径、协议相对路径等) href = resolveURL(base, basePath, href) if href != "" { - hrefs = append(hrefs, href) + href = NormalizeURL(href) + if href != "" { + hrefs = append(hrefs, href) + } } } } @@ -177,3 +180,85 @@ func resolveURL(base, basePath, href string) string { dir := path.Dir(basePath) // 提取当前页面的目录部分 return base + path.Clean(dir+"/"+href) // path.Clean 规范化,去除多余的 ../ 等 } + +// isSafeURL 检查 href 是否为安全的 HTTP(S) 链接。 +// 过滤 javascript:、data:、mailto:、tel:、ftp: 等伪协议, +// 以及空 href 和仅包含锚点的 href。 +func isSafeURL(href string) bool { + if href == "" || href == "#" { + return false + } + // 检查是否包含冒号(伪协议特征) + colon := strings.Index(href, ":") + if colon < 0 { + return true // 无冒号:相对路径、绝对路径、协议相对路径,都是安全的 + } + scheme := strings.ToLower(href[:colon]) + switch scheme { + case "http", "https": + return true + default: + // javascript:, data:, mailto:, tel:, ftp:, vbscript: 等全部拦截 + return false + } +} + +// NormalizeURL 将 URL 规范化为用于去重的标准形式。 +// 1. 统一为小写 scheme 和 host +// 2. path.Clean 规范化路径(去除 ./、../) +// 3. 按 key 字典序排列 query 参数(消除 ?a=1&b=2 vs ?b=2&a=1 的差异) +// 4. 去除 fragment +// 5. 去除末尾斜杠(根路径 / 除外) +// 返回空字符串表示 URL 无效。 +func NormalizeURL(rawURL string) string { + u, err := url.Parse(rawURL) + if err != nil { + return "" + } + if u.Scheme != "http" && u.Scheme != "https" { + return "" + } + if u.Host == "" { + return "" + } + + // 统一 scheme 和 host 为小写 + u.Scheme = strings.ToLower(u.Scheme) + u.Host = strings.ToLower(u.Host) + + // 规范化路径 + if u.Path == "" { + u.Path = "/" + } + u.Path = path.Clean(u.Path) + + // query 参数按 key 字典序排列 + if u.RawQuery != "" { + u.RawQuery = sortQuery(u.RawQuery) + } + + // 去除 fragment + u.Fragment = "" + + // 去除末尾斜杠(根路径 / 除外) + if len(u.Path) > 1 && strings.HasSuffix(u.Path, "/") { + u.Path = strings.TrimRight(u.Path, "/") + } + + return u.String() +} + +// sortQuery 将 query 字符串的参数按 key 字典序排列,用于 URL 去重。 +func sortQuery(query string) string { + params, err := url.ParseQuery(query) + if err != nil { + return query + } + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + sort.Strings(keys) + // url.Values 编码后参数已排序且值已去重 + return params.Encode() +}