This commit is contained in:
2026-06-09 19:37:26 +08:00
parent 5150a23256
commit 721caccc58
5 changed files with 1025 additions and 6 deletions
+163 -1
View File
@@ -19,6 +19,8 @@ import (
"sync"
"time"
sqlquery "aichat/agents/SQL_query"
"github.com/gin-gonic/gin"
ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
@@ -624,6 +626,7 @@ var (
cfg *Config
aiState *OpenAIState
searchState *SearchState
sqlState *sqlquery.State
store *ConvStore
)
@@ -737,8 +740,16 @@ func chatHandler(c *gin.Context) {
}
chatMessages := req.Messages
if sqlState != nil && sqlState.Enabled() {
withSQL, err := enrichMessagesWithSQL(c.Request.Context(), profile, chatMessages)
if err != nil {
fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err)
} else {
chatMessages = withSQL
}
}
if req.WebSearch {
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), req.Messages)
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), chatMessages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -882,6 +893,146 @@ func latestUserQuery(messages []ChatMessage) string {
return ""
}
type sqlActivationDecision struct {
Activate bool `json:"activate"`
Reason string `json:"reason"`
}
type sqlGenerationResult struct {
Database string `json:"database"`
SQL string `json:"sql"`
Reason string `json:"reason"`
}
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) ([]ChatMessage, error) {
query := latestUserQuery(messages)
if query == "" {
return messages, nil
}
activate, reason, err := classifySQLActivation(ctx, profile, messages)
if err != nil {
return messages, err
}
if !activate {
return messages, nil
}
schemaContext, err := sqlState.SchemaContext(ctx)
if err != nil {
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
if err != nil {
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
generated.Database = strings.TrimSpace(generated.Database)
generated.SQL = strings.TrimSpace(generated.SQL)
if generated.SQL == "" {
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
if err != nil {
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
if strings.TrimSpace(reason) != "" {
contextText += "\n激活原因:" + reason
}
return prependSQLContext(messages, contextText), nil
}
func prependSQLContext(messages []ChatMessage, content string) []ChatMessage {
withSQL := make([]ChatMessage, 0, len(messages)+1)
withSQL = append(withSQL, ChatMessage{Role: "system", Content: content, Hidden: true})
withSQL = append(withSQL, messages...)
return withSQL
}
func classifySQLActivation(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) (bool, string, error) {
query := latestUserQuery(messages)
prompt := fmt.Sprintf("%s\n\n最新用户问题:%s", sqlState.ActivationPrompt(), query)
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 512)
if err != nil {
return false, "", err
}
var decision sqlActivationDecision
if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil {
return false, "", fmt.Errorf("解析 SQL 查询激活结果失败: %w", err)
}
return decision.Activate, decision.Reason, nil
}
func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) {
prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题和数据库 schema 生成一条只读 SQL。
要求:
- 只能返回 JSON,不要使用 Markdown。
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
- 必须只使用 schema 中出现的数据库、表和字段。
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
%s
用户问题:%s`, schemaContext, userQuery)
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024)
if err != nil {
return nil, err
}
var generated sqlGenerationResult
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
}
return &generated, nil
}
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
messages, err := buildArkMessages(chatMessages)
if err != nil {
return "", err
}
timeout := time.Duration(profile.Config.Timeout) * time.Second
completionCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
stream, err := profile.Client.CreateChatCompletionStream(completionCtx, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
Messages: messages,
MaxTokens: intPtr(maxTokens),
}.WithStream(true))
if err != nil {
return "", err
}
defer stream.Close()
var b strings.Builder
for {
resp, err := stream.Recv()
if errors.Is(err, io.EOF) {
return b.String(), nil
}
if err != nil {
return "", err
}
if len(resp.Choices) > 0 {
b.WriteString(resp.Choices[0].Delta.Content)
}
}
}
func extractJSONObject(text string) string {
text = strings.TrimSpace(text)
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start >= 0 && end > start {
return text[start : end+1]
}
return text
}
func webSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
switch strings.ToLower(config.Provider) {
case "duckduckgo", "ddg":
@@ -1363,6 +1514,17 @@ func main() {
fmt.Fprintln(os.Stderr, "搜索配置初始化失败:", err)
os.Exit(1)
}
sqlConfig, err := sqlquery.LoadConfig("agents/SQL_query/config.yaml")
if err != nil {
fmt.Fprintln(os.Stderr, "SQL 查询插件配置加载失败:", err)
os.Exit(1)
}
sqlState, err = sqlquery.NewState(sqlConfig)
if err != nil {
fmt.Fprintln(os.Stderr, "SQL 查询插件初始化失败:", err)
os.Exit(1)
}
defer sqlState.Close()
store = NewConvStore("conversations")
// Gin 路由