86 lines
2.8 KiB
Go
86 lines
2.8 KiB
Go
package sqlquery
|
|
|
|
import (
|
|
"context"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestDefaultActivationPromptMentionsCalendarEvents(t *testing.T) {
|
|
if !strings.Contains(defaultActivationPrompt, "tab_calendar_events") {
|
|
t.Fatal("default activation prompt should mention tab_calendar_events for calendar queries")
|
|
}
|
|
}
|
|
|
|
func TestValidateReadOnlySQLAllowsSelectAndWith(t *testing.T) {
|
|
queries := []string{
|
|
"SELECT * FROM events LIMIT 10",
|
|
"select id, created_at from events where content = 'delete keyword in text' limit 5;",
|
|
"WITH recent AS (SELECT * FROM events LIMIT 10) SELECT * FROM recent",
|
|
}
|
|
for _, query := range queries {
|
|
if err := ValidateReadOnlySQL(query); err != nil {
|
|
t.Fatalf("ValidateReadOnlySQL(%q) returned error: %v", query, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) {
|
|
queries := []string{
|
|
"",
|
|
"DELETE FROM events",
|
|
"UPDATE events SET content='x'",
|
|
"DROP TABLE events",
|
|
"SELECT * FROM events; DELETE FROM events",
|
|
"SELECT * INTO OUTFILE '/tmp/x' FROM events",
|
|
"SELECT SLEEP(10)",
|
|
"ATTACH DATABASE 'x' AS y",
|
|
"VACUUM",
|
|
"SELECT * FROM events -- comment",
|
|
}
|
|
for _, query := range queries {
|
|
if err := ValidateReadOnlySQL(query); err == nil {
|
|
t.Fatalf("ValidateReadOnlySQL(%q) returned nil, want error", query)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestToolDefinition(t *testing.T) {
|
|
state := &State{}
|
|
definition := state.ToolDefinition("custom sql")
|
|
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" {
|
|
t.Fatalf("unexpected definition: %#v", definition)
|
|
}
|
|
}
|
|
|
|
func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) {
|
|
prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)")
|
|
for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} {
|
|
if !strings.Contains(prompt, want) {
|
|
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) {
|
|
generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) {
|
|
if !strings.Contains(prompt, "schema") || maxTokens != 1024 {
|
|
t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens)
|
|
}
|
|
return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil
|
|
}, "查事件", "schema")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" {
|
|
t.Fatalf("unexpected generated SQL: %#v", generated)
|
|
}
|
|
|
|
_, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) {
|
|
return "not json", nil
|
|
}, "查事件", "schema")
|
|
if err == nil {
|
|
t.Fatal("expected malformed JSON error")
|
|
}
|
|
}
|