package sqlquery import ( "context" "strings" "testing" ) func TestDefaultActivationPromptMentionsCalendarEvents(t *testing.T) { if !strings.Contains(defaultActivationPrompt, "tab_calendar_events") { t.Fatal("default activation prompt should mention tab_calendar_events for calendar queries") } } func TestValidateReadOnlySQLAllowsSelectAndWith(t *testing.T) { queries := []string{ "SELECT * FROM events LIMIT 10", "select id, created_at from events where content = 'delete keyword in text' limit 5;", "WITH recent AS (SELECT * FROM events LIMIT 10) SELECT * FROM recent", } for _, query := range queries { if err := ValidateReadOnlySQL(query); err != nil { t.Fatalf("ValidateReadOnlySQL(%q) returned error: %v", query, err) } } } func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) { queries := []string{ "", "DELETE FROM events", "UPDATE events SET content='x'", "DROP TABLE events", "SELECT * FROM events; DELETE FROM events", "SELECT * INTO OUTFILE '/tmp/x' FROM events", "SELECT SLEEP(10)", "ATTACH DATABASE 'x' AS y", "VACUUM", "SELECT * FROM events -- comment", } for _, query := range queries { if err := ValidateReadOnlySQL(query); err == nil { t.Fatalf("ValidateReadOnlySQL(%q) returned nil, want error", query) } } } func TestToolDefinition(t *testing.T) { state := &State{} definition := state.ToolDefinition("custom sql") if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" { t.Fatalf("unexpected definition: %#v", definition) } } func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) { prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)") for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} { if !strings.Contains(prompt, want) { t.Fatalf("prompt missing %q:\n%s", want, prompt) } } } func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) { generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) { if !strings.Contains(prompt, "schema") || maxTokens != 1024 { t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens) } return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil }, "查事件", "schema") if err != nil { t.Fatal(err) } if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" { t.Fatalf("unexpected generated SQL: %#v", generated) } _, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) { return "not json", nil }, "查事件", "schema") if err == nil { t.Fatal("expected malformed JSON error") } }