124 lines
4.5 KiB
Go
124 lines
4.5 KiB
Go
package search
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestNormalizeConfigDefaults(t *testing.T) {
|
|
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "duckduckgo"}}}
|
|
if err := normalizeConfig(cfg); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if strings.TrimSpace(cfg.ActivationPrompt) == "" {
|
|
t.Fatal("activation prompt should be defaulted")
|
|
}
|
|
profile := cfg.Profiles[0]
|
|
if profile.Name != "duckduckgo" || !profile.Active || profile.BaseURL != defaultBaseURL || profile.Count != defaultCount || profile.Timeout != defaultTimeout {
|
|
t.Fatalf("unexpected profile: %#v", profile)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeConfigDuplicateProfile(t *testing.T) {
|
|
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Name: "duck", Provider: "duckduckgo"}, {Name: " duck ", Provider: "duckduckgo"}}}
|
|
if err := normalizeConfig(cfg); err == nil {
|
|
t.Fatal("expected duplicate profile error")
|
|
}
|
|
}
|
|
|
|
func TestNormalizeConfigUnsupportedProvider(t *testing.T) {
|
|
cfg := &Config{Enabled: true, Profiles: ProfileConfigs{{Provider: "unknown"}}}
|
|
if err := normalizeConfig(cfg); err == nil {
|
|
t.Fatal("expected unsupported provider error")
|
|
}
|
|
}
|
|
|
|
func TestApplyEnvBraveAPIKey(t *testing.T) {
|
|
t.Setenv("BRAVE_SEARCH_API_KEY", "from-env")
|
|
cfg := &Config{Profiles: ProfileConfigs{{Name: "brave", Provider: "brave", APIKey: "from-config"}}}
|
|
applyEnv(cfg)
|
|
if cfg.Profiles[0].APIKey != "from-env" {
|
|
t.Fatalf("api key = %q", cfg.Profiles[0].APIKey)
|
|
}
|
|
}
|
|
|
|
func TestListProfilesRedactsAPIKey(t *testing.T) {
|
|
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "brave", Active: true, Enabled: true, Provider: "brave", APIKey: "secret", BaseURL: "https://example.com", Count: 1, Timeout: 1}}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
list := state.ListProfiles()
|
|
if len(list.Profiles) != 1 || list.Profiles[0].APIKey != "" || !list.Profiles[0].Active {
|
|
t.Fatalf("unexpected list: %#v", list)
|
|
}
|
|
}
|
|
|
|
func TestBuildResultContext(t *testing.T) {
|
|
text := BuildResultContext(ProfileConfig{Name: "duckduckgo", Provider: "duckduckgo"}, "最新消息", []Result{{Title: "标题", URL: "https://example.com", Description: "摘要"}}, "需要最新信息")
|
|
for _, want := range []string{"联网搜索", "最新消息", "标题", "https://example.com", "需要最新信息", "不要编造"} {
|
|
if !strings.Contains(text, want) {
|
|
t.Fatalf("context missing %q:\n%s", want, text)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestBuildFallbackContext(t *testing.T) {
|
|
text := BuildFallbackContext(ProfileConfig{Name: "duckduckgo", Provider: "duckduckgo"}, "历史上的今天都发生了什么?", "需要查询当天历史事件", errors.New("未搜索到相关网页结果"))
|
|
for _, want := range []string{"没有可用的搜索结果", "历史上的今天", "需要查询当天历史事件", "模型训练数据/内置知识", "不要伪造网页链接"} {
|
|
if !strings.Contains(text, want) {
|
|
t.Fatalf("fallback context missing %q:\n%s", want, text)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDuckDuckGoSearchParsesResults(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("q") != "golang" {
|
|
t.Fatalf("query = %s", r.URL.RawQuery)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{
|
|
"Heading":"Go",
|
|
"Abstract":"Go is a language",
|
|
"AbstractURL":"https://go.dev",
|
|
"RelatedTopics":[{"Text":"Gopher - mascot","FirstURL":"https://go.dev/blog/gopher"}]
|
|
}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
results, err := duckDuckGoSearch(context.Background(), ProfileConfig{Provider: "duckduckgo", BaseURL: server.URL, Count: 2, Timeout: 1}, "golang")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(results) != 2 || results[0].Title != "Go" || results[1].Title != "Gopher" {
|
|
t.Fatalf("unexpected results: %#v", results)
|
|
}
|
|
}
|
|
|
|
func TestBraveMissingAPIKey(t *testing.T) {
|
|
_, err := braveWebSearch(context.Background(), ProfileConfig{Provider: "brave", BaseURL: "https://example.com", Timeout: 1, Count: 1}, "query")
|
|
if err == nil || !strings.Contains(err.Error(), "API Key") {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadConfigWritesLegacyProfiles(t *testing.T) {
|
|
dir := t.TempDir()
|
|
path := dir + string(os.PathSeparator) + "config.yaml"
|
|
cfg, err := LoadConfig(path, []ProfileConfig{{Name: "legacy", Active: true, Enabled: true, Provider: "duckduckgo", Count: 3, Timeout: 2}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(cfg.Profiles) != 1 || cfg.Profiles[0].Name != "legacy" {
|
|
t.Fatalf("unexpected cfg: %#v", cfg)
|
|
}
|
|
if _, err := os.Stat(path); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|