更新工具链

This commit is contained in:
2026-06-11 18:04:47 +08:00
parent 440f83f6a7
commit d1324dc2f2
9 changed files with 718 additions and 570 deletions
+247 -402
View File
@@ -37,16 +37,12 @@ const (
defaultOpenAITimeout = 120
defaultToolRouterTimeout = 30
defaultToolRouterMaxTokens = 512
defaultToolRouterSystemText = `是工具路由器。根据用户最新问题和可用工具列表,判断本轮是否需要调用一个或多个工具
只能返回 JSON,不要使用 Markdown
JSON 格式:{"tools":[{"name":"工具名称","reason":"..."}],"reason":"..."}
工具名称必须来自“可用工具”列表
可以选择多个工具,工具会按配置顺序依次执行;后面的工具可以使用前面工具写入的上下文
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且还需要调用 search 或 sql,必须同时选择 time,并让 time 排在这些工具之前。
例如“历史上的今天都发生了什么”应选择 time 和 search:先获取今天的绝对日期,再搜索当天历史事件;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
例如“本月有什么日程安排”应选择 time 和 sql:先获取本月绝对日期范围,再查询日程表。
如果无需工具,返回 {"tools":[],"reason":"..."}。
只选择确实必要的工具。`
defaultToolRouterSystemText = `可以按需直接调用可用工具来回答用户问题
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且后续需要搜索或查询数据库,应先调用 time 获取绝对日期范围
需要实时网页资料、新闻、当前版本、近期事件、网页核验或用户明确要求联网时,调用 search。
需要查询本地业务数据、日程、会议、待办、记录、统计或时间范围内数据时,调用 sql
工具结果优先于模型内置知识;工具失败时必须如实说明,不要编造结果
只调用确实必要的工具。`
)
type OpenAIConfig struct {
@@ -286,6 +282,11 @@ func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
return changed, nil
}
func isLegacyToolRouterPrompt(prompt string) bool {
prompt = strings.TrimSpace(prompt)
return strings.Contains(prompt, "工具路由器") || strings.Contains(prompt, "route_tools") || strings.Contains(prompt, `"tools":[`)
}
func normalizeToolRouterConfig(cfg *Config) (bool, error) {
changed := false
defaults := defaultToolRouterConfig()
@@ -298,11 +299,12 @@ func normalizeToolRouterConfig(cfg *Config) (bool, error) {
cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens
changed = true
}
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
systemPrompt := strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
if systemPrompt == "" || isLegacyToolRouterPrompt(systemPrompt) {
cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText
changed = true
} else if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) != cfg.ToolRouter.SystemPrompt {
cfg.ToolRouter.SystemPrompt = strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
} else if systemPrompt != cfg.ToolRouter.SystemPrompt {
cfg.ToolRouter.SystemPrompt = systemPrompt
changed = true
}
if len(cfg.ToolRouter.Tools) == 0 {
@@ -432,12 +434,12 @@ type openAIListResponse struct {
Profiles []OpenAIConfig `json:"profiles"`
}
type toolTextCompleter func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error)
type chatCompleter func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error)
type ToolRouterState struct {
cfg *ToolRouterConfig
ai *OpenAIState
complete toolTextCompleter
complete chatCompleter
}
func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) {
@@ -453,7 +455,7 @@ func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterS
return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err)
}
}
return &ToolRouterState{cfg: config, ai: ai, complete: completeTextWithTimeout}, nil
return &ToolRouterState{cfg: config, ai: ai, complete: completeChatWithTimeout}, nil
}
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
@@ -796,20 +798,17 @@ func chatHandler(c *gin.Context) {
usage := newTokenUsageTracker()
ctx = contextWithTokenUsage(ctx, usage)
chatMessages := req.Messages
withTools, err := enrichMessagesWithRoutedTools(ctx, profile, chatMessages, emit)
// 用 Function Calling 工具循环替代旧的路由+隐藏上下文机制
messages, err := runAgentToolLoop(ctx, profile, req.Messages, emit)
if err != nil {
fmt.Fprintln(os.Stderr, "工具路由调用失败:", err)
} else {
chatMessages = withTools
fmt.Fprintln(os.Stderr, "Agent 工具循环失败:", err)
messages, err = buildArkMessages(req.Messages)
if err != nil {
emitError(err)
return
}
}
// 构建 ark 消息列表
messages, err := buildArkMessages(chatMessages)
if err != nil {
emitError(err)
return
}
promptTokens := estimateChatMessagesTokens(chatMessages)
promptTokens := estimateChatMessagesTokens(req.Messages)
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
@@ -885,21 +884,16 @@ func chatHandler(c *gin.Context) {
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
}
// 思考过程 reasoning_content 单独事件推送
if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" {
emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent})
}
}
}
}
// ─── 辅助函数 ─────────────────────────────────────────────
func latestUserQuery(messages []ChatMessage) string {
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
return strings.TrimSpace(messages[i].Content)
}
}
return ""
}
func estimateChatMessagesTokens(messages []ChatMessage) int {
total := 0
for _, msg := range messages {
@@ -951,380 +945,214 @@ func tokensPerSecond(tokens int, start time.Time) float64 {
return float64(tokens) / elapsed
}
type ToolSelection struct {
Name string `json:"name"`
Reason string `json:"reason"`
type agentTool struct {
name string
definition *model.Tool
execute func(context.Context, string) (string, error)
}
type ToolRoutingDecision struct {
Tools []ToolSelection `json:"tools"`
Reason string `json:"reason"`
}
func (t agentTool) Name() string { return t.name }
type ChatTool interface {
Name() string
Description() string
Enabled() bool
Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error)
}
const maxAgentToolIterations = 6
type TimeChatTool struct{}
func (t TimeChatTool) Name() string { return "time" }
func (t TimeChatTool) Description() string {
return timeagent.ActivationPrompt
}
func (t TimeChatTool) Enabled() bool { return true }
func (t TimeChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
return runTimeTool(ctx, messages, routeReason, emit)
}
type SQLChatTool struct {
state *sqlquery.State
}
func (t SQLChatTool) Name() string { return "sql" }
func (t SQLChatTool) Description() string {
if t.state == nil {
return ""
}
return t.state.ActivationPrompt()
}
func (t SQLChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
func (t SQLChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
return runSQLTool(ctx, t.state, profile, messages, routeReason, emit)
}
type SearchChatTool struct {
state *searchagent.State
}
func (t SearchChatTool) Name() string { return "search" }
func (t SearchChatTool) Description() string {
if t.state == nil {
return ""
}
return t.state.ActivationPrompt()
}
func (t SearchChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
func (t SearchChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
return runSearchTool(ctx, t.state, messages, routeReason, emit)
}
type sqlGenerationResult struct {
Database string `json:"database"`
SQL string `json:"sql"`
Reason string `json:"reason"`
}
func runSQLTool(ctx context.Context, state *sqlquery.State, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
query := latestUserQuery(messages)
if query == "" {
return messages, nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
schemaContext, err := state.SchemaContext(ctx)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
generated.Database = strings.TrimSpace(generated.Database)
generated.SQL = strings.TrimSpace(generated.SQL)
if generated.SQL == "" {
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
result, err := state.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
if strings.TrimSpace(routeReason) != "" {
contextText += "\n激活原因:" + routeReason
}
return prependHiddenContext(messages, contextText), nil
}
func prependHiddenContext(messages []ChatMessage, content string) []ChatMessage {
withContext := make([]ChatMessage, 0, len(messages)+1)
withContext = append(withContext, ChatMessage{Role: "system", Content: content, Hidden: true})
withContext = append(withContext, messages...)
return withContext
}
func runTimeTool(ctx context.Context, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
_ = ctx
resolved := timeagent.Resolve(time.Now())
emit(chatSSEFrame{Type: "trace", Tool: "time", Stage: "resolve", Status: "success", Message: "已获取当前时间上下文", Data: map[string]any{
"today": timeagent.FormatDate(resolved.Now),
"this_month": fmt.Sprintf("%s 至 %s", timeagent.FormatDate(resolved.ThisMonth.Start), timeagent.FormatDate(resolved.ThisMonth.End.AddDate(0, 0, -1))),
}})
return prependHiddenContext(messages, timeagent.BuildContext(resolved, routeReason)), nil
}
func runSearchTool(ctx context.Context, state *searchagent.State, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
query := latestUserQuery(messages)
if query == "" {
return messages, nil
}
if state == nil || !state.Enabled() {
err := errors.New("联网搜索未启用")
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索未启用", Data: map[string]any{"error": err.Error()}})
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
}
active := state.ActiveProfile()
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": active.Provider}})
results, profile, err := state.Search(ctx, query)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
}
if len(results) == 0 {
err := errors.New("未搜索到相关网页结果")
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "warning", Message: "未搜索到相关网页结果,将使用模型知识库回答"})
return prependHiddenContext(messages, searchagent.BuildFallbackContext(profile, query, routeReason, err)), nil
}
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": profile.Provider, "count": len(results)}})
return prependHiddenContext(messages, searchagent.BuildResultContext(profile, query, results, routeReason)), nil
}
func enrichMessagesWithRoutedTools(ctx context.Context, chatProfile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agentTool {
if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled {
return messages, nil
return nil
}
if latestUserQuery(messages) == "" {
return messages, nil
}
tools := availableChatTools(toolRouterState.cfg)
if len(tools) == 0 {
return messages, nil
}
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "running", Message: "正在进行工具路由"})
decision, err := routeTools(ctx, toolRouterState, chatProfile, messages, tools)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "error", Message: "工具路由失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
return messages, err
}
selected := filterToolSelections(decision, tools, toolRouterState.cfg.Tools)
selected = ensureTimeSelectionForRelativeQuery(selected, tools, toolRouterState.cfg.Tools, latestUserQuery(messages))
if len(selected) == 0 {
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:无需调用工具", Data: map[string]any{"reason": decision.Reason}})
return messages, nil
}
names := make([]string, 0, len(selected))
for _, item := range selected {
names = append(names, item.Name)
}
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:将调用 " + strings.Join(names, ", "), Data: map[string]any{"tools": names, "reason": decision.Reason}})
current := messages
for _, item := range selected {
tool := tools[item.Name]
next, err := tool.Enrich(ctx, chatProfile, current, firstNonEmpty(item.Reason, decision.Reason), emit)
if err != nil {
emit(chatSSEFrame{Type: "trace", Tool: item.Name, Stage: "error", Status: "error", Message: "工具调用失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
continue
}
current = next
}
return current, nil
}
func availableChatTools(config *ToolRouterConfig) map[string]ChatTool {
configured := map[string]ToolRouteConfig{}
for _, item := range config.Tools {
configured[item.Name] = item
}
registered := []ChatTool{
TimeChatTool{},
SearchChatTool{state: searchState},
SQLChatTool{state: sqlState},
}
available := map[string]ChatTool{}
for _, tool := range registered {
name := tool.Name()
item, ok := configured[name]
if !ok || !item.Enabled || !tool.Enabled() {
continue
}
available[name] = tool
}
return available
}
func routeTools(ctx context.Context, state *ToolRouterState, chatProfile *OpenAIProfile, messages []ChatMessage, tools map[string]ChatTool) (ToolRoutingDecision, error) {
routerProfile := chatProfile
if strings.TrimSpace(state.cfg.OpenAIName) != "" {
profile, err := state.ai.GetProfile(state.cfg.OpenAIName)
if err != nil {
return ToolRoutingDecision{}, err
}
routerProfile = profile
}
prompt := buildToolRouterPrompt(state.cfg, messages, tools)
text, err := state.complete(ctx, routerProfile, []ChatMessage{{Role: "system", Content: prompt}}, state.cfg.MaxTokens, time.Duration(state.cfg.Timeout)*time.Second)
if err != nil {
return ToolRoutingDecision{}, err
}
return parseToolRoutingDecision(text)
}
func buildToolRouterPrompt(config *ToolRouterConfig, messages []ChatMessage, tools map[string]ChatTool) string {
query := latestUserQuery(messages)
var b strings.Builder
b.WriteString(strings.TrimSpace(config.SystemPrompt))
b.WriteString("\n\n可用工具:\n")
for _, item := range config.Tools {
tool, ok := tools[item.Name]
if !ok {
tools := make([]agentTool, 0, len(toolRouterState.cfg.Tools))
for _, item := range toolRouterState.cfg.Tools {
if !item.Enabled {
continue
}
description := strings.TrimSpace(item.Description)
if description == "" {
description = tool.Description()
}
fmt.Fprintf(&b, "- name: %s\n description: %s\n", item.Name, description)
}
fmt.Fprintf(&b, "\n最新用户问题:%s", query)
return b.String()
}
func parseToolRoutingDecision(text string) (ToolRoutingDecision, error) {
var decision ToolRoutingDecision
if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil {
return decision, fmt.Errorf("解析工具路由结果失败: %w", err)
}
for i := range decision.Tools {
decision.Tools[i].Name = strings.ToLower(strings.TrimSpace(decision.Tools[i].Name))
decision.Tools[i].Reason = strings.TrimSpace(decision.Tools[i].Reason)
}
decision.Reason = strings.TrimSpace(decision.Reason)
return decision, nil
}
func filterToolSelections(decision ToolRoutingDecision, tools map[string]ChatTool, order []ToolRouteConfig) []ToolSelection {
selected := map[string]ToolSelection{}
for _, item := range decision.Tools {
if item.Name == "" {
continue
}
if _, ok := tools[item.Name]; !ok {
continue
}
if _, ok := selected[item.Name]; !ok {
selected[item.Name] = item
switch item.Name {
case timeagent.ToolName:
tools = append(tools, agentTool{
name: timeagent.ToolName,
definition: timeagent.ToolDefinition(description),
execute: func(ctx context.Context, args string) (string, error) {
result, err := timeagent.ExecuteTool(args, time.Now())
if err == nil && emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: timeagent.ToolName, Stage: "resolve", Status: "success", Message: "已获取当前时间上下文"})
}
return result, err
},
})
case searchagent.ToolName:
if searchState == nil || !searchState.Enabled() {
continue
}
tools = append(tools, agentTool{
name: searchagent.ToolName,
definition: searchState.ToolDefinition(description),
execute: func(ctx context.Context, args string) (string, error) {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "request", Status: "running", Message: "正在联网搜索"})
}
result, err := searchState.ExecuteTool(ctx, args)
if emit != nil {
status := "success"
message := "联网搜索完成"
if err != nil {
status = "error"
message = "联网搜索失败"
}
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "results", Status: status, Message: message})
}
return result, err
},
})
case sqlquery.ToolName:
if sqlState == nil || !sqlState.Enabled() {
continue
}
tools = append(tools, agentTool{
name: sqlquery.ToolName,
definition: sqlState.ToolDefinition(description),
execute: func(ctx context.Context, args string) (string, error) {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: "running", Message: "正在查询数据库"})
}
generator := func(ctx context.Context, prompt string, maxTokens int) (string, error) {
return completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, maxTokens)
}
result, err := sqlState.ExecuteTool(ctx, args, generator)
if emit != nil {
status := "success"
message := "数据库查询完成"
if err != nil {
status = "error"
message = "数据库查询失败"
}
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: status, Message: message})
}
return result, err
},
})
}
}
return orderToolSelections(selected, order)
return tools
}
func ensureTimeSelectionForRelativeQuery(selected []ToolSelection, tools map[string]ChatTool, order []ToolRouteConfig, query string) []ToolSelection {
if !containsRelativeTime(query) || hasToolSelection(selected, "time") || (!hasToolSelection(selected, "search") && !hasToolSelection(selected, "sql")) {
return selected
}
if _, ok := tools["time"]; !ok {
return selected
}
withTime := make(map[string]ToolSelection, len(selected)+1)
for _, item := range selected {
withTime[item.Name] = item
}
withTime["time"] = ToolSelection{Name: "time", Reason: "问题包含相对日期,需要先获取当前日期"}
return orderToolSelections(withTime, order)
}
func containsRelativeTime(query string) bool {
query = strings.TrimSpace(query)
if query == "" {
return false
}
for _, keyword := range []string{"今天", "今日", "明天", "昨天", "本周", "这周", "本月", "这个月", "本年", "今年", "最近", "历史上的今天"} {
if strings.Contains(query, keyword) {
return true
}
}
return false
}
func hasToolSelection(selected []ToolSelection, name string) bool {
for _, item := range selected {
if item.Name == name {
return true
}
}
return false
}
func orderToolSelections(selected map[string]ToolSelection, order []ToolRouteConfig) []ToolSelection {
result := make([]ToolSelection, 0, len(selected))
for _, item := range order {
if selection, ok := selected[item.Name]; ok {
result = append(result, selection)
}
}
return result
}
func firstNonEmpty(items ...string) string {
for _, item := range items {
if strings.TrimSpace(item) != "" {
return strings.TrimSpace(item)
}
}
return ""
}
func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) {
prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、隐藏上下文和数据库 schema 生成一条只读 SQL。
要求:
- 只能返回 JSON,不要使用 Markdown。
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
- 必须只使用 schema 中出现的数据库、表和字段。
- 如果隐藏上下文中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
%s
用户问题:%s`, schemaContext, userQuery)
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024)
func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) {
messages, err := buildArkMessages(chatMessages)
if err != nil {
return nil, err
}
var generated sqlGenerationResult
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
tools := availableAgentTools(profile, emit)
if len(tools) == 0 {
return messages, nil
}
return &generated, nil
toolByName := make(map[string]agentTool, len(tools))
definitions := make([]*model.Tool, 0, len(tools))
availableNames := make([]string, 0, len(tools))
toolDescriptions := make([]string, 0, len(tools))
for _, tool := range tools {
toolByName[tool.name] = tool
definitions = append(definitions, tool.definition)
availableNames = append(availableNames, tool.name)
if tool.definition != nil && tool.definition.Function != nil {
toolDescriptions = append(toolDescriptions, fmt.Sprintf("%s: %s", tool.name, tool.definition.Function.Description))
}
}
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}})
}
if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" {
messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...)
}
for i := 0; i < maxAgentToolIterations; i++ {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}})
}
resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{
Model: profile.Config.Model,
Messages: messages,
MaxTokens: intPtr(toolRouterState.cfg.MaxTokens),
Tools: definitions,
ToolChoice: model.ToolChoiceStringTypeAuto,
ParallelToolCalls: boolPtr(false),
}, time.Duration(toolRouterState.cfg.Timeout)*time.Second)
if err != nil {
return messages, err
}
if tracker := tokenUsageFromContext(ctx); tracker != nil {
tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens)
}
if len(resp.Choices) == 0 {
return messages, nil
}
choice := resp.Choices[0]
decisionPreview := chatMessageContentString(choice.Message.Content)
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "decision", Status: "success", Message: "工具判断响应已返回", Data: map[string]any{"iteration": i + 1, "finish_reason": string(choice.FinishReason), "content_preview": truncateString(decisionPreview, 800)}})
}
calls := choice.Message.ToolCalls
if len(calls) == 0 && choice.Message.FunctionCall != nil {
calls = []*model.ToolCall{{ID: "legacy_function_call", Type: model.ToolTypeFunction, Function: *choice.Message.FunctionCall}}
}
if len(calls) == 0 {
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"})
}
return messages, nil
}
callNames := make([]string, 0, len(calls))
for _, call := range calls {
if call != nil {
callNames = append(callNames, call.Function.Name)
}
}
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}})
}
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content})
for _, call := range calls {
result := executeAgentToolCall(ctx, call, toolByName, emit)
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)})
}
}
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")})
return messages, nil
}
func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string {
if call == nil || call.Type != model.ToolTypeFunction {
result := "工具调用无效:仅支持 function 类型工具。"
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "execute", Status: "error", Message: result})
}
return result
}
toolName := call.Function.Name
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "arguments", Status: "running", Message: "准备执行工具", Data: map[string]any{"tool_call_id": call.ID, "arguments": call.Function.Arguments}})
}
tool, ok := tools[toolName]
if !ok {
result := fmt.Sprintf("工具调用失败:未知工具 %s。", toolName)
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "execute", Status: "error", Message: result})
}
return result
}
started := time.Now()
result, err := tool.execute(ctx, call.Function.Arguments)
durationMs := time.Since(started).Milliseconds()
if err != nil {
message := fmt.Sprintf("工具 %s 执行失败:%v", tool.name, err)
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "execute", Status: "error", Message: "工具执行失败", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "error": err.Error()}})
}
return message
}
if strings.TrimSpace(result) == "" {
result = fmt.Sprintf("工具 %s 执行完成,但没有返回内容。", tool.name)
}
if emit != nil {
emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "result", Status: "success", Message: "工具执行完成", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "result_preview": truncateString(result, 1200)}})
}
return result
}
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
@@ -1370,14 +1198,10 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
}
}
func extractJSONObject(text string) string {
text = strings.TrimSpace(text)
start := strings.Index(text, "{")
end := strings.LastIndex(text, "}")
if start >= 0 && end > start {
return text[start : end+1]
}
return text
func completeChatWithTimeout(ctx context.Context, profile *OpenAIProfile, request model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
completionCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return profile.Client.CreateChatCompletion(completionCtx, request.WithStream(false))
}
func newUUID() string {
@@ -1590,6 +1414,17 @@ func textPart(text string) *model.ChatCompletionMessageContentPart {
}
}
func stringContent(text string) *model.ChatCompletionMessageContent {
return &model.ChatCompletionMessageContent{StringValue: &text}
}
func chatMessageContentString(content *model.ChatCompletionMessageContent) string {
if content == nil || content.StringValue == nil {
return ""
}
return *content.StringValue
}
func normalizeImageURL(raw string) (string, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
@@ -1651,6 +1486,16 @@ func contains(items []string, target string) bool {
func intPtr(i int) *int { return &i }
func boolPtr(v bool) *bool { return &v }
func truncateString(text string, maxRunes int) string {
runes := []rune(strings.TrimSpace(text))
if maxRunes <= 0 || len(runes) <= maxRunes {
return string(runes)
}
return string(runes[:maxRunes]) + "..."
}
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
data, err := json.Marshal(frame)
if err != nil {