更新工具链

This commit is contained in:
2026-06-11 18:04:47 +08:00
parent 440f83f6a7
commit d1324dc2f2
9 changed files with 718 additions and 570 deletions
+40
View File
@@ -1,6 +1,7 @@
package sqlquery
import (
"context"
"strings"
"testing"
)
@@ -43,3 +44,42 @@ func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) {
}
}
}
func TestToolDefinition(t *testing.T) {
state := &State{}
definition := state.ToolDefinition("custom sql")
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" {
t.Fatalf("unexpected definition: %#v", definition)
}
}
func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) {
prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)")
for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} {
if !strings.Contains(prompt, want) {
t.Fatalf("prompt missing %q:\n%s", want, prompt)
}
}
}
func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) {
generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) {
if !strings.Contains(prompt, "schema") || maxTokens != 1024 {
t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens)
}
return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil
}, "查事件", "schema")
if err != nil {
t.Fatal(err)
}
if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" {
t.Fatalf("unexpected generated SQL: %#v", generated)
}
_, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) {
return "not json", nil
}, "查事件", "schema")
if err == nil {
t.Fatal("expected malformed JSON error")
}
}