Files
aichat/agents/sql/sql_query_test.go
T
2026-06-11 18:04:47 +08:00

86 lines
2.8 KiB
Go

package sqlquery
import (
"context"
"strings"
"testing"
)
// TestDefaultActivationPromptMentionsCalendarEvents guards that the built-in
// activation prompt references the calendar events table by name, so calendar
// questions can be routed to this SQL agent.
func TestDefaultActivationPromptMentionsCalendarEvents(t *testing.T) {
	const calendarTable = "tab_calendar_events"
	if strings.Contains(defaultActivationPrompt, calendarTable) {
		return
	}
	t.Fatal("default activation prompt should mention tab_calendar_events for calendar queries")
}
// TestValidateReadOnlySQLAllowsSelectAndWith checks that ValidateReadOnlySQL
// accepts ordinary read-only queries: an upper-case SELECT, a lower-case
// SELECT with a trailing semicolon whose string literal contains the word
// "delete", and a WITH (CTE) query.
//
// Each query runs as its own subtest, so one failure does not hide the others.
func TestValidateReadOnlySQLAllowsSelectAndWith(t *testing.T) {
	queries := []string{
		"SELECT * FROM events LIMIT 10",
		// A write keyword inside a string literal must not cause rejection.
		"select id, created_at from events where content = 'delete keyword in text' limit 5;",
		"WITH recent AS (SELECT * FROM events LIMIT 10) SELECT * FROM recent",
	}
	for _, query := range queries {
		// t.Run executes synchronously here (no t.Parallel), so capturing
		// the loop variable is safe on every Go version.
		t.Run(query, func(t *testing.T) {
			if err := ValidateReadOnlySQL(query); err != nil {
				t.Errorf("ValidateReadOnlySQL(%q) returned error: %v", query, err)
			}
		})
	}
}
// TestValidateReadOnlySQLRejectsUnsafeStatements checks that ValidateReadOnlySQL
// returns a non-nil error for each of these inputs:
//   - empty input
//   - data-modifying or schema-changing statements
//   - a SELECT followed by a second stacked statement
//   - SELECT ... INTO OUTFILE
//   - a SLEEP call
//   - ATTACH and VACUUM
//   - a query carrying a SQL line comment
//
// Each query runs as its own subtest, so every accepted-but-unsafe input is
// reported, not just the first one.
func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) {
	queries := []string{
		// Empty input.
		"",
		// Write / DDL statements.
		"DELETE FROM events",
		"UPDATE events SET content='x'",
		"DROP TABLE events",
		// Statement stacking behind an otherwise valid SELECT.
		"SELECT * FROM events; DELETE FROM events",
		// File export and time-based function calls.
		"SELECT * INTO OUTFILE '/tmp/x' FROM events",
		"SELECT SLEEP(10)",
		// Database-level commands.
		"ATTACH DATABASE 'x' AS y",
		"VACUUM",
		// SQL comments are expected to be rejected outright.
		"SELECT * FROM events -- comment",
	}
	for _, query := range queries {
		// t.Run executes synchronously here (no t.Parallel), so capturing
		// the loop variable is safe on every Go version. An empty query
		// gets an auto-generated subtest name.
		t.Run(query, func(t *testing.T) {
			if err := ValidateReadOnlySQL(query); err == nil {
				t.Errorf("ValidateReadOnlySQL(%q) returned nil, want error", query)
			}
		})
	}
}
// TestToolDefinition verifies that State.ToolDefinition returns a definition
// with a non-nil Function whose Name is ToolName and whose Description is the
// caller-supplied text, here "custom sql".
func TestToolDefinition(t *testing.T) {
	const description = "custom sql"
	def := (&State{}).ToolDefinition(description)
	fn := def.Function
	matches := fn != nil && fn.Name == ToolName && fn.Description == description
	if !matches {
		t.Fatalf("unexpected definition: %#v", def)
	}
}
// TestBuildSQLGenerationPromptIncludesSafetyRules checks that the prompt built
// by BuildSQLGenerationPrompt contains every expected substring. These are the
// read-only SQL rule, the SELECT/WITH keywords, the half-open interval
// guidance, the calendar table name, max_rows, and the user's question.
//
// All missing substrings are gathered and reported together, with the prompt
// printed once, instead of stopping at the first one.
func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) {
	// The question means "What is on the schedule this month?".
	const question = "本月有什么日程安排"
	prompt := BuildSQLGenerationPrompt(question, "schema: tab_calendar_events(start_time)")

	want := []string{
		"只读 SQL", // "read-only SQL"
		"SELECT",
		"WITH",
		"半开区间", // "half-open interval" (time-range guidance)
		// NOTE(review): the schema argument above also contains this table
		// name, so this check passes whenever the schema is embedded
		// verbatim; it does not prove the prompt's own rules mention it.
		"tab_calendar_events",
		"max_rows",
		question,
	}

	var missing []string
	for _, substr := range want {
		if !strings.Contains(prompt, substr) {
			missing = append(missing, substr)
		}
	}
	if len(missing) > 0 {
		t.Fatalf("prompt missing %q:\n%s", missing, prompt)
	}
}
// TestGenerateSQLParsesJSONAndRejectsMalformed exercises GenerateSQL with a
// stub model callback in two independent subtests:
//   - "valid JSON": the callback receives a prompt containing "schema" and
//     a maxTokens of 1024; the JSON it returns is decoded into Database, SQL
//     and Reason.
//   - "malformed JSON": a non-JSON reply makes GenerateSQL return an error.
//
// "查事件" ("query events") is the user question passed to GenerateSQL.
func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) {
	t.Run("valid JSON", func(t *testing.T) {
		generated, err := GenerateSQL(context.Background(), func(_ context.Context, prompt string, maxTokens int) (string, error) {
			// Use Errorf, not Fatalf: FailNow is only valid on the test
			// goroutine, and this callback runs wherever GenerateSQL
			// invokes it (not visible from this file).
			if !strings.Contains(prompt, "schema") || maxTokens != 1024 {
				t.Errorf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens)
			}
			return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil
		}, "查事件", "schema")
		if err != nil {
			t.Fatal(err)
		}
		if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" {
			t.Fatalf("unexpected generated SQL: %#v", generated)
		}
	})

	t.Run("malformed JSON", func(t *testing.T) {
		_, err := GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) {
			return "not json", nil
		}, "查事件", "schema")
		if err == nil {
			t.Fatal("expected malformed JSON error")
		}
	})
}