up
This commit is contained in:
@@ -19,6 +19,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
sqlquery "aichat/agents/SQL_query"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
ark "github.com/volcengine/volcengine-go-sdk/service/arkruntime"
|
||||
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||
@@ -624,6 +626,7 @@ var (
|
||||
cfg *Config
|
||||
aiState *OpenAIState
|
||||
searchState *SearchState
|
||||
sqlState *sqlquery.State
|
||||
store *ConvStore
|
||||
)
|
||||
|
||||
@@ -737,8 +740,16 @@ func chatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
chatMessages := req.Messages
|
||||
if sqlState != nil && sqlState.Enabled() {
|
||||
withSQL, err := enrichMessagesWithSQL(c.Request.Context(), profile, chatMessages)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "SQL 查询插件调用失败:", err)
|
||||
} else {
|
||||
chatMessages = withSQL
|
||||
}
|
||||
}
|
||||
if req.WebSearch {
|
||||
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), req.Messages)
|
||||
withSearch, err := enrichMessagesWithSearch(c.Request.Context(), chatMessages)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -882,6 +893,146 @@ func latestUserQuery(messages []ChatMessage) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type sqlActivationDecision struct {
|
||||
Activate bool `json:"activate"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type sqlGenerationResult struct {
|
||||
Database string `json:"database"`
|
||||
SQL string `json:"sql"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func enrichMessagesWithSQL(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) ([]ChatMessage, error) {
|
||||
query := latestUserQuery(messages)
|
||||
if query == "" {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
activate, reason, err := classifySQLActivation(ctx, profile, messages)
|
||||
if err != nil {
|
||||
return messages, err
|
||||
}
|
||||
if !activate {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
schemaContext, err := sqlState.SchemaContext(ctx)
|
||||
if err != nil {
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
|
||||
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
|
||||
if err != nil {
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
generated.Database = strings.TrimSpace(generated.Database)
|
||||
generated.SQL = strings.TrimSpace(generated.SQL)
|
||||
if generated.SQL == "" {
|
||||
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
|
||||
result, err := sqlState.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
||||
if err != nil {
|
||||
return prependSQLContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
|
||||
if strings.TrimSpace(reason) != "" {
|
||||
contextText += "\n激活原因:" + reason
|
||||
}
|
||||
return prependSQLContext(messages, contextText), nil
|
||||
}
|
||||
|
||||
func prependSQLContext(messages []ChatMessage, content string) []ChatMessage {
|
||||
withSQL := make([]ChatMessage, 0, len(messages)+1)
|
||||
withSQL = append(withSQL, ChatMessage{Role: "system", Content: content, Hidden: true})
|
||||
withSQL = append(withSQL, messages...)
|
||||
return withSQL
|
||||
}
|
||||
|
||||
func classifySQLActivation(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage) (bool, string, error) {
|
||||
query := latestUserQuery(messages)
|
||||
prompt := fmt.Sprintf("%s\n\n最新用户问题:%s", sqlState.ActivationPrompt(), query)
|
||||
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 512)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
var decision sqlActivationDecision
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil {
|
||||
return false, "", fmt.Errorf("解析 SQL 查询激活结果失败: %w", err)
|
||||
}
|
||||
return decision.Activate, decision.Reason, nil
|
||||
}
|
||||
|
||||
func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) {
|
||||
prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题和数据库 schema 生成一条只读 SQL。
|
||||
要求:
|
||||
- 只能返回 JSON,不要使用 Markdown。
|
||||
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
||||
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
||||
- 必须只使用 schema 中出现的数据库、表和字段。
|
||||
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
||||
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
||||
|
||||
%s
|
||||
|
||||
用户问题:%s`, schemaContext, userQuery)
|
||||
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var generated sqlGenerationResult
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
||||
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
||||
}
|
||||
return &generated, nil
|
||||
}
|
||||
|
||||
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
|
||||
messages, err := buildArkMessages(chatMessages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
timeout := time.Duration(profile.Config.Timeout) * time.Second
|
||||
completionCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
stream, err := profile.Client.CreateChatCompletionStream(completionCtx, model.CreateChatCompletionRequest{
|
||||
Model: profile.Config.Model,
|
||||
Messages: messages,
|
||||
MaxTokens: intPtr(maxTokens),
|
||||
}.WithStream(true))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var b strings.Builder
|
||||
for {
|
||||
resp, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
return b.String(), nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(resp.Choices) > 0 {
|
||||
b.WriteString(resp.Choices[0].Delta.Content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSONObject(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start >= 0 && end > start {
|
||||
return text[start : end+1]
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
func webSearch(ctx context.Context, config SearchConfig, query string) ([]searchResult, error) {
|
||||
switch strings.ToLower(config.Provider) {
|
||||
case "duckduckgo", "ddg":
|
||||
@@ -1363,6 +1514,17 @@ func main() {
|
||||
fmt.Fprintln(os.Stderr, "搜索配置初始化失败:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
sqlConfig, err := sqlquery.LoadConfig("agents/SQL_query/config.yaml")
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "SQL 查询插件配置加载失败:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
sqlState, err = sqlquery.NewState(sqlConfig)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "SQL 查询插件初始化失败:", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer sqlState.Close()
|
||||
store = NewConvStore("conversations")
|
||||
|
||||
// Gin 路由
|
||||
|
||||
Reference in New Issue
Block a user