更新工具链
This commit is contained in:
@@ -37,16 +37,12 @@ const (
|
||||
defaultOpenAITimeout = 120
|
||||
defaultToolRouterTimeout = 30
|
||||
defaultToolRouterMaxTokens = 512
|
||||
defaultToolRouterSystemText = `你是工具路由器。根据用户最新问题和可用工具列表,判断本轮是否需要调用一个或多个工具。
|
||||
只能返回 JSON,不要使用 Markdown。
|
||||
JSON 格式:{"tools":[{"name":"工具名称","reason":"..."}],"reason":"..."}
|
||||
工具名称必须来自“可用工具”列表。
|
||||
可以选择多个工具,工具会按配置顺序依次执行;后面的工具可以使用前面工具写入的上下文。
|
||||
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且还需要调用 search 或 sql,必须同时选择 time,并让 time 排在这些工具之前。
|
||||
例如“历史上的今天都发生了什么”应选择 time 和 search:先获取今天的绝对日期,再搜索当天历史事件;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
|
||||
例如“本月有什么日程安排”应选择 time 和 sql:先获取本月绝对日期范围,再查询日程表。
|
||||
如果无需工具,返回 {"tools":[],"reason":"..."}。
|
||||
只选择确实必要的工具。`
|
||||
defaultToolRouterSystemText = `你可以按需直接调用可用工具来回答用户问题。
|
||||
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且后续需要搜索或查询数据库,应先调用 time 获取绝对日期范围。
|
||||
需要实时网页资料、新闻、当前版本、近期事件、网页核验或用户明确要求联网时,调用 search。
|
||||
需要查询本地业务数据、日程、会议、待办、记录、统计或时间范围内数据时,调用 sql。
|
||||
工具结果优先于模型内置知识;工具失败时必须如实说明,不要编造结果。
|
||||
只调用确实必要的工具。`
|
||||
)
|
||||
|
||||
type OpenAIConfig struct {
|
||||
@@ -286,6 +282,11 @@ func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
func isLegacyToolRouterPrompt(prompt string) bool {
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
return strings.Contains(prompt, "工具路由器") || strings.Contains(prompt, "route_tools") || strings.Contains(prompt, `"tools":[`)
|
||||
}
|
||||
|
||||
func normalizeToolRouterConfig(cfg *Config) (bool, error) {
|
||||
changed := false
|
||||
defaults := defaultToolRouterConfig()
|
||||
@@ -298,11 +299,12 @@ func normalizeToolRouterConfig(cfg *Config) (bool, error) {
|
||||
cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens
|
||||
changed = true
|
||||
}
|
||||
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
|
||||
systemPrompt := strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
|
||||
if systemPrompt == "" || isLegacyToolRouterPrompt(systemPrompt) {
|
||||
cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText
|
||||
changed = true
|
||||
} else if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) != cfg.ToolRouter.SystemPrompt {
|
||||
cfg.ToolRouter.SystemPrompt = strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
|
||||
} else if systemPrompt != cfg.ToolRouter.SystemPrompt {
|
||||
cfg.ToolRouter.SystemPrompt = systemPrompt
|
||||
changed = true
|
||||
}
|
||||
if len(cfg.ToolRouter.Tools) == 0 {
|
||||
@@ -432,12 +434,12 @@ type openAIListResponse struct {
|
||||
Profiles []OpenAIConfig `json:"profiles"`
|
||||
}
|
||||
|
||||
type toolTextCompleter func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error)
|
||||
type chatCompleter func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error)
|
||||
|
||||
type ToolRouterState struct {
|
||||
cfg *ToolRouterConfig
|
||||
ai *OpenAIState
|
||||
complete toolTextCompleter
|
||||
complete chatCompleter
|
||||
}
|
||||
|
||||
func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) {
|
||||
@@ -453,7 +455,7 @@ func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterS
|
||||
return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err)
|
||||
}
|
||||
}
|
||||
return &ToolRouterState{cfg: config, ai: ai, complete: completeTextWithTimeout}, nil
|
||||
return &ToolRouterState{cfg: config, ai: ai, complete: completeChatWithTimeout}, nil
|
||||
}
|
||||
|
||||
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
|
||||
@@ -796,20 +798,17 @@ func chatHandler(c *gin.Context) {
|
||||
usage := newTokenUsageTracker()
|
||||
ctx = contextWithTokenUsage(ctx, usage)
|
||||
|
||||
chatMessages := req.Messages
|
||||
withTools, err := enrichMessagesWithRoutedTools(ctx, profile, chatMessages, emit)
|
||||
// 用 Function Calling 工具循环替代旧的路由+隐藏上下文机制
|
||||
messages, err := runAgentToolLoop(ctx, profile, req.Messages, emit)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "工具路由调用失败:", err)
|
||||
} else {
|
||||
chatMessages = withTools
|
||||
fmt.Fprintln(os.Stderr, "Agent 工具循环失败:", err)
|
||||
messages, err = buildArkMessages(req.Messages)
|
||||
if err != nil {
|
||||
emitError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
// 构建 ark 消息列表
|
||||
messages, err := buildArkMessages(chatMessages)
|
||||
if err != nil {
|
||||
emitError(err)
|
||||
return
|
||||
}
|
||||
promptTokens := estimateChatMessagesTokens(chatMessages)
|
||||
promptTokens := estimateChatMessagesTokens(req.Messages)
|
||||
|
||||
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
|
||||
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
|
||||
@@ -885,21 +884,16 @@ func chatHandler(c *gin.Context) {
|
||||
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
|
||||
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
|
||||
}
|
||||
// 思考过程 reasoning_content 单独事件推送
|
||||
if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" {
|
||||
emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 辅助函数 ─────────────────────────────────────────────
|
||||
|
||||
func latestUserQuery(messages []ChatMessage) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
return strings.TrimSpace(messages[i].Content)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func estimateChatMessagesTokens(messages []ChatMessage) int {
|
||||
total := 0
|
||||
for _, msg := range messages {
|
||||
@@ -951,380 +945,214 @@ func tokensPerSecond(tokens int, start time.Time) float64 {
|
||||
return float64(tokens) / elapsed
|
||||
}
|
||||
|
||||
type ToolSelection struct {
|
||||
Name string `json:"name"`
|
||||
Reason string `json:"reason"`
|
||||
type agentTool struct {
|
||||
name string
|
||||
definition *model.Tool
|
||||
execute func(context.Context, string) (string, error)
|
||||
}
|
||||
|
||||
type ToolRoutingDecision struct {
|
||||
Tools []ToolSelection `json:"tools"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
func (t agentTool) Name() string { return t.name }
|
||||
|
||||
type ChatTool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Enabled() bool
|
||||
Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error)
|
||||
}
|
||||
const maxAgentToolIterations = 6
|
||||
|
||||
type TimeChatTool struct{}
|
||||
|
||||
func (t TimeChatTool) Name() string { return "time" }
|
||||
|
||||
func (t TimeChatTool) Description() string {
|
||||
return timeagent.ActivationPrompt
|
||||
}
|
||||
|
||||
func (t TimeChatTool) Enabled() bool { return true }
|
||||
|
||||
func (t TimeChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
return runTimeTool(ctx, messages, routeReason, emit)
|
||||
}
|
||||
|
||||
type SQLChatTool struct {
|
||||
state *sqlquery.State
|
||||
}
|
||||
|
||||
func (t SQLChatTool) Name() string { return "sql" }
|
||||
|
||||
func (t SQLChatTool) Description() string {
|
||||
if t.state == nil {
|
||||
return ""
|
||||
}
|
||||
return t.state.ActivationPrompt()
|
||||
}
|
||||
|
||||
func (t SQLChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
|
||||
|
||||
func (t SQLChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
return runSQLTool(ctx, t.state, profile, messages, routeReason, emit)
|
||||
}
|
||||
|
||||
type SearchChatTool struct {
|
||||
state *searchagent.State
|
||||
}
|
||||
|
||||
func (t SearchChatTool) Name() string { return "search" }
|
||||
|
||||
func (t SearchChatTool) Description() string {
|
||||
if t.state == nil {
|
||||
return ""
|
||||
}
|
||||
return t.state.ActivationPrompt()
|
||||
}
|
||||
|
||||
func (t SearchChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
|
||||
|
||||
func (t SearchChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
return runSearchTool(ctx, t.state, messages, routeReason, emit)
|
||||
}
|
||||
|
||||
type sqlGenerationResult struct {
|
||||
Database string `json:"database"`
|
||||
SQL string `json:"sql"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func runSQLTool(ctx context.Context, state *sqlquery.State, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
query := latestUserQuery(messages)
|
||||
if query == "" {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
|
||||
schemaContext, err := state.SchemaContext(ctx)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
|
||||
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
generated.Database = strings.TrimSpace(generated.Database)
|
||||
generated.SQL = strings.TrimSpace(generated.SQL)
|
||||
if generated.SQL == "" {
|
||||
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
|
||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
|
||||
result, err := state.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
|
||||
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
|
||||
if strings.TrimSpace(routeReason) != "" {
|
||||
contextText += "\n激活原因:" + routeReason
|
||||
}
|
||||
return prependHiddenContext(messages, contextText), nil
|
||||
}
|
||||
|
||||
func prependHiddenContext(messages []ChatMessage, content string) []ChatMessage {
|
||||
withContext := make([]ChatMessage, 0, len(messages)+1)
|
||||
withContext = append(withContext, ChatMessage{Role: "system", Content: content, Hidden: true})
|
||||
withContext = append(withContext, messages...)
|
||||
return withContext
|
||||
}
|
||||
|
||||
func runTimeTool(ctx context.Context, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
_ = ctx
|
||||
resolved := timeagent.Resolve(time.Now())
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "time", Stage: "resolve", Status: "success", Message: "已获取当前时间上下文", Data: map[string]any{
|
||||
"today": timeagent.FormatDate(resolved.Now),
|
||||
"this_month": fmt.Sprintf("%s 至 %s", timeagent.FormatDate(resolved.ThisMonth.Start), timeagent.FormatDate(resolved.ThisMonth.End.AddDate(0, 0, -1))),
|
||||
}})
|
||||
return prependHiddenContext(messages, timeagent.BuildContext(resolved, routeReason)), nil
|
||||
}
|
||||
|
||||
func runSearchTool(ctx context.Context, state *searchagent.State, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
query := latestUserQuery(messages)
|
||||
if query == "" {
|
||||
return messages, nil
|
||||
}
|
||||
if state == nil || !state.Enabled() {
|
||||
err := errors.New("联网搜索未启用")
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索未启用", Data: map[string]any{"error": err.Error()}})
|
||||
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
active := state.ActiveProfile()
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": active.Provider}})
|
||||
results, profile, err := state.Search(ctx, query)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
|
||||
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
|
||||
}
|
||||
if len(results) == 0 {
|
||||
err := errors.New("未搜索到相关网页结果")
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "warning", Message: "未搜索到相关网页结果,将使用模型知识库回答"})
|
||||
return prependHiddenContext(messages, searchagent.BuildFallbackContext(profile, query, routeReason, err)), nil
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": profile.Provider, "count": len(results)}})
|
||||
return prependHiddenContext(messages, searchagent.BuildResultContext(profile, query, results, routeReason)), nil
|
||||
}
|
||||
|
||||
func enrichMessagesWithRoutedTools(ctx context.Context, chatProfile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
||||
func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agentTool {
|
||||
if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled {
|
||||
return messages, nil
|
||||
return nil
|
||||
}
|
||||
if latestUserQuery(messages) == "" {
|
||||
return messages, nil
|
||||
}
|
||||
tools := availableChatTools(toolRouterState.cfg)
|
||||
if len(tools) == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "running", Message: "正在进行工具路由"})
|
||||
decision, err := routeTools(ctx, toolRouterState, chatProfile, messages, tools)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "error", Message: "工具路由失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
|
||||
return messages, err
|
||||
}
|
||||
selected := filterToolSelections(decision, tools, toolRouterState.cfg.Tools)
|
||||
selected = ensureTimeSelectionForRelativeQuery(selected, tools, toolRouterState.cfg.Tools, latestUserQuery(messages))
|
||||
if len(selected) == 0 {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:无需调用工具", Data: map[string]any{"reason": decision.Reason}})
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(selected))
|
||||
for _, item := range selected {
|
||||
names = append(names, item.Name)
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:将调用 " + strings.Join(names, ", "), Data: map[string]any{"tools": names, "reason": decision.Reason}})
|
||||
|
||||
current := messages
|
||||
for _, item := range selected {
|
||||
tool := tools[item.Name]
|
||||
next, err := tool.Enrich(ctx, chatProfile, current, firstNonEmpty(item.Reason, decision.Reason), emit)
|
||||
if err != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: item.Name, Stage: "error", Status: "error", Message: "工具调用失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
|
||||
continue
|
||||
}
|
||||
current = next
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
func availableChatTools(config *ToolRouterConfig) map[string]ChatTool {
|
||||
configured := map[string]ToolRouteConfig{}
|
||||
for _, item := range config.Tools {
|
||||
configured[item.Name] = item
|
||||
}
|
||||
registered := []ChatTool{
|
||||
TimeChatTool{},
|
||||
SearchChatTool{state: searchState},
|
||||
SQLChatTool{state: sqlState},
|
||||
}
|
||||
available := map[string]ChatTool{}
|
||||
for _, tool := range registered {
|
||||
name := tool.Name()
|
||||
item, ok := configured[name]
|
||||
if !ok || !item.Enabled || !tool.Enabled() {
|
||||
continue
|
||||
}
|
||||
available[name] = tool
|
||||
}
|
||||
return available
|
||||
}
|
||||
|
||||
func routeTools(ctx context.Context, state *ToolRouterState, chatProfile *OpenAIProfile, messages []ChatMessage, tools map[string]ChatTool) (ToolRoutingDecision, error) {
|
||||
routerProfile := chatProfile
|
||||
if strings.TrimSpace(state.cfg.OpenAIName) != "" {
|
||||
profile, err := state.ai.GetProfile(state.cfg.OpenAIName)
|
||||
if err != nil {
|
||||
return ToolRoutingDecision{}, err
|
||||
}
|
||||
routerProfile = profile
|
||||
}
|
||||
prompt := buildToolRouterPrompt(state.cfg, messages, tools)
|
||||
text, err := state.complete(ctx, routerProfile, []ChatMessage{{Role: "system", Content: prompt}}, state.cfg.MaxTokens, time.Duration(state.cfg.Timeout)*time.Second)
|
||||
if err != nil {
|
||||
return ToolRoutingDecision{}, err
|
||||
}
|
||||
return parseToolRoutingDecision(text)
|
||||
}
|
||||
|
||||
func buildToolRouterPrompt(config *ToolRouterConfig, messages []ChatMessage, tools map[string]ChatTool) string {
|
||||
query := latestUserQuery(messages)
|
||||
var b strings.Builder
|
||||
b.WriteString(strings.TrimSpace(config.SystemPrompt))
|
||||
b.WriteString("\n\n可用工具:\n")
|
||||
for _, item := range config.Tools {
|
||||
tool, ok := tools[item.Name]
|
||||
if !ok {
|
||||
tools := make([]agentTool, 0, len(toolRouterState.cfg.Tools))
|
||||
for _, item := range toolRouterState.cfg.Tools {
|
||||
if !item.Enabled {
|
||||
continue
|
||||
}
|
||||
description := strings.TrimSpace(item.Description)
|
||||
if description == "" {
|
||||
description = tool.Description()
|
||||
}
|
||||
fmt.Fprintf(&b, "- name: %s\n description: %s\n", item.Name, description)
|
||||
}
|
||||
fmt.Fprintf(&b, "\n最新用户问题:%s", query)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func parseToolRoutingDecision(text string) (ToolRoutingDecision, error) {
|
||||
var decision ToolRoutingDecision
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil {
|
||||
return decision, fmt.Errorf("解析工具路由结果失败: %w", err)
|
||||
}
|
||||
for i := range decision.Tools {
|
||||
decision.Tools[i].Name = strings.ToLower(strings.TrimSpace(decision.Tools[i].Name))
|
||||
decision.Tools[i].Reason = strings.TrimSpace(decision.Tools[i].Reason)
|
||||
}
|
||||
decision.Reason = strings.TrimSpace(decision.Reason)
|
||||
return decision, nil
|
||||
}
|
||||
|
||||
func filterToolSelections(decision ToolRoutingDecision, tools map[string]ChatTool, order []ToolRouteConfig) []ToolSelection {
|
||||
selected := map[string]ToolSelection{}
|
||||
for _, item := range decision.Tools {
|
||||
if item.Name == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := tools[item.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := selected[item.Name]; !ok {
|
||||
selected[item.Name] = item
|
||||
switch item.Name {
|
||||
case timeagent.ToolName:
|
||||
tools = append(tools, agentTool{
|
||||
name: timeagent.ToolName,
|
||||
definition: timeagent.ToolDefinition(description),
|
||||
execute: func(ctx context.Context, args string) (string, error) {
|
||||
result, err := timeagent.ExecuteTool(args, time.Now())
|
||||
if err == nil && emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: timeagent.ToolName, Stage: "resolve", Status: "success", Message: "已获取当前时间上下文"})
|
||||
}
|
||||
return result, err
|
||||
},
|
||||
})
|
||||
case searchagent.ToolName:
|
||||
if searchState == nil || !searchState.Enabled() {
|
||||
continue
|
||||
}
|
||||
tools = append(tools, agentTool{
|
||||
name: searchagent.ToolName,
|
||||
definition: searchState.ToolDefinition(description),
|
||||
execute: func(ctx context.Context, args string) (string, error) {
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "request", Status: "running", Message: "正在联网搜索"})
|
||||
}
|
||||
result, err := searchState.ExecuteTool(ctx, args)
|
||||
if emit != nil {
|
||||
status := "success"
|
||||
message := "联网搜索完成"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
message = "联网搜索失败"
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "results", Status: status, Message: message})
|
||||
}
|
||||
return result, err
|
||||
},
|
||||
})
|
||||
case sqlquery.ToolName:
|
||||
if sqlState == nil || !sqlState.Enabled() {
|
||||
continue
|
||||
}
|
||||
tools = append(tools, agentTool{
|
||||
name: sqlquery.ToolName,
|
||||
definition: sqlState.ToolDefinition(description),
|
||||
execute: func(ctx context.Context, args string) (string, error) {
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: "running", Message: "正在查询数据库"})
|
||||
}
|
||||
generator := func(ctx context.Context, prompt string, maxTokens int) (string, error) {
|
||||
return completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, maxTokens)
|
||||
}
|
||||
result, err := sqlState.ExecuteTool(ctx, args, generator)
|
||||
if emit != nil {
|
||||
status := "success"
|
||||
message := "数据库查询完成"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
message = "数据库查询失败"
|
||||
}
|
||||
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: status, Message: message})
|
||||
}
|
||||
return result, err
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
return orderToolSelections(selected, order)
|
||||
return tools
|
||||
}
|
||||
|
||||
func ensureTimeSelectionForRelativeQuery(selected []ToolSelection, tools map[string]ChatTool, order []ToolRouteConfig, query string) []ToolSelection {
|
||||
if !containsRelativeTime(query) || hasToolSelection(selected, "time") || (!hasToolSelection(selected, "search") && !hasToolSelection(selected, "sql")) {
|
||||
return selected
|
||||
}
|
||||
if _, ok := tools["time"]; !ok {
|
||||
return selected
|
||||
}
|
||||
withTime := make(map[string]ToolSelection, len(selected)+1)
|
||||
for _, item := range selected {
|
||||
withTime[item.Name] = item
|
||||
}
|
||||
withTime["time"] = ToolSelection{Name: "time", Reason: "问题包含相对日期,需要先获取当前日期"}
|
||||
return orderToolSelections(withTime, order)
|
||||
}
|
||||
|
||||
func containsRelativeTime(query string) bool {
|
||||
query = strings.TrimSpace(query)
|
||||
if query == "" {
|
||||
return false
|
||||
}
|
||||
for _, keyword := range []string{"今天", "今日", "明天", "昨天", "本周", "这周", "本月", "这个月", "本年", "今年", "最近", "历史上的今天"} {
|
||||
if strings.Contains(query, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hasToolSelection(selected []ToolSelection, name string) bool {
|
||||
for _, item := range selected {
|
||||
if item.Name == name {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func orderToolSelections(selected map[string]ToolSelection, order []ToolRouteConfig) []ToolSelection {
|
||||
result := make([]ToolSelection, 0, len(selected))
|
||||
for _, item := range order {
|
||||
if selection, ok := selected[item.Name]; ok {
|
||||
result = append(result, selection)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func firstNonEmpty(items ...string) string {
|
||||
for _, item := range items {
|
||||
if strings.TrimSpace(item) != "" {
|
||||
return strings.TrimSpace(item)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) {
|
||||
prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、隐藏上下文和数据库 schema 生成一条只读 SQL。
|
||||
要求:
|
||||
- 只能返回 JSON,不要使用 Markdown。
|
||||
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
||||
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
||||
- 必须只使用 schema 中出现的数据库、表和字段。
|
||||
- 如果隐藏上下文中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
|
||||
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
|
||||
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
|
||||
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
||||
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
||||
|
||||
%s
|
||||
|
||||
用户问题:%s`, schemaContext, userQuery)
|
||||
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024)
|
||||
func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) {
|
||||
messages, err := buildArkMessages(chatMessages)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var generated sqlGenerationResult
|
||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
||||
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
||||
tools := availableAgentTools(profile, emit)
|
||||
if len(tools) == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
return &generated, nil
|
||||
toolByName := make(map[string]agentTool, len(tools))
|
||||
definitions := make([]*model.Tool, 0, len(tools))
|
||||
availableNames := make([]string, 0, len(tools))
|
||||
toolDescriptions := make([]string, 0, len(tools))
|
||||
for _, tool := range tools {
|
||||
toolByName[tool.name] = tool
|
||||
definitions = append(definitions, tool.definition)
|
||||
availableNames = append(availableNames, tool.name)
|
||||
if tool.definition != nil && tool.definition.Function != nil {
|
||||
toolDescriptions = append(toolDescriptions, fmt.Sprintf("%s: %s", tool.name, tool.definition.Function.Description))
|
||||
}
|
||||
}
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}})
|
||||
}
|
||||
if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" {
|
||||
messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...)
|
||||
}
|
||||
for i := 0; i < maxAgentToolIterations; i++ {
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}})
|
||||
}
|
||||
resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{
|
||||
Model: profile.Config.Model,
|
||||
Messages: messages,
|
||||
MaxTokens: intPtr(toolRouterState.cfg.MaxTokens),
|
||||
Tools: definitions,
|
||||
ToolChoice: model.ToolChoiceStringTypeAuto,
|
||||
ParallelToolCalls: boolPtr(false),
|
||||
}, time.Duration(toolRouterState.cfg.Timeout)*time.Second)
|
||||
if err != nil {
|
||||
return messages, err
|
||||
}
|
||||
if tracker := tokenUsageFromContext(ctx); tracker != nil {
|
||||
tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
choice := resp.Choices[0]
|
||||
decisionPreview := chatMessageContentString(choice.Message.Content)
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "decision", Status: "success", Message: "工具判断响应已返回", Data: map[string]any{"iteration": i + 1, "finish_reason": string(choice.FinishReason), "content_preview": truncateString(decisionPreview, 800)}})
|
||||
}
|
||||
calls := choice.Message.ToolCalls
|
||||
if len(calls) == 0 && choice.Message.FunctionCall != nil {
|
||||
calls = []*model.ToolCall{{ID: "legacy_function_call", Type: model.ToolTypeFunction, Function: *choice.Message.FunctionCall}}
|
||||
}
|
||||
if len(calls) == 0 {
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"})
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
callNames := make([]string, 0, len(calls))
|
||||
for _, call := range calls {
|
||||
if call != nil {
|
||||
callNames = append(callNames, call.Function.Name)
|
||||
}
|
||||
}
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}})
|
||||
}
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content})
|
||||
for _, call := range calls {
|
||||
result := executeAgentToolCall(ctx, call, toolByName, emit)
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)})
|
||||
}
|
||||
}
|
||||
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")})
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string {
|
||||
if call == nil || call.Type != model.ToolTypeFunction {
|
||||
result := "工具调用无效:仅支持 function 类型工具。"
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "execute", Status: "error", Message: result})
|
||||
}
|
||||
return result
|
||||
}
|
||||
toolName := call.Function.Name
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "arguments", Status: "running", Message: "准备执行工具", Data: map[string]any{"tool_call_id": call.ID, "arguments": call.Function.Arguments}})
|
||||
}
|
||||
tool, ok := tools[toolName]
|
||||
if !ok {
|
||||
result := fmt.Sprintf("工具调用失败:未知工具 %s。", toolName)
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "execute", Status: "error", Message: result})
|
||||
}
|
||||
return result
|
||||
}
|
||||
started := time.Now()
|
||||
result, err := tool.execute(ctx, call.Function.Arguments)
|
||||
durationMs := time.Since(started).Milliseconds()
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("工具 %s 执行失败:%v", tool.name, err)
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "execute", Status: "error", Message: "工具执行失败", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "error": err.Error()}})
|
||||
}
|
||||
return message
|
||||
}
|
||||
if strings.TrimSpace(result) == "" {
|
||||
result = fmt.Sprintf("工具 %s 执行完成,但没有返回内容。", tool.name)
|
||||
}
|
||||
if emit != nil {
|
||||
emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "result", Status: "success", Message: "工具执行完成", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "result_preview": truncateString(result, 1200)}})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
|
||||
@@ -1370,14 +1198,10 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
|
||||
}
|
||||
}
|
||||
|
||||
func extractJSONObject(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
start := strings.Index(text, "{")
|
||||
end := strings.LastIndex(text, "}")
|
||||
if start >= 0 && end > start {
|
||||
return text[start : end+1]
|
||||
}
|
||||
return text
|
||||
func completeChatWithTimeout(ctx context.Context, profile *OpenAIProfile, request model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
||||
completionCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
return profile.Client.CreateChatCompletion(completionCtx, request.WithStream(false))
|
||||
}
|
||||
|
||||
func newUUID() string {
|
||||
@@ -1590,6 +1414,17 @@ func textPart(text string) *model.ChatCompletionMessageContentPart {
|
||||
}
|
||||
}
|
||||
|
||||
func stringContent(text string) *model.ChatCompletionMessageContent {
|
||||
return &model.ChatCompletionMessageContent{StringValue: &text}
|
||||
}
|
||||
|
||||
func chatMessageContentString(content *model.ChatCompletionMessageContent) string {
|
||||
if content == nil || content.StringValue == nil {
|
||||
return ""
|
||||
}
|
||||
return *content.StringValue
|
||||
}
|
||||
|
||||
func normalizeImageURL(raw string) (string, error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
@@ -1651,6 +1486,16 @@ func contains(items []string, target string) bool {
|
||||
|
||||
func intPtr(i int) *int { return &i }
|
||||
|
||||
func boolPtr(v bool) *bool { return &v }
|
||||
|
||||
func truncateString(text string, maxRunes int) string {
|
||||
runes := []rune(strings.TrimSpace(text))
|
||||
if maxRunes <= 0 || len(runes) <= maxRunes {
|
||||
return string(runes)
|
||||
}
|
||||
return string(runes[:maxRunes]) + "..."
|
||||
}
|
||||
|
||||
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
|
||||
data, err := json.Marshal(frame)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user