修改工具调用机制

This commit is contained in:
2026-06-10 18:54:17 +08:00
parent a838a812a0
commit 04485b6b0e
6 changed files with 589 additions and 78 deletions
+288 -39
View File
@@ -75,11 +75,38 @@ type openaiChatRequest struct {
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tools []openaiTool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}
type openaiMessage struct {
Role string `json:"role"`
Content any `json:"content"`
Role string `json:"role"`
Content any `json:"content,omitempty"`
Name string `json:"name,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
}
type openaiTool struct {
Type string `json:"type"`
Function openaiFunctionDefinition `json:"function"`
}
type openaiFunctionDefinition struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters map[string]interface{} `json:"parameters"`
}
type openaiToolCall struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function openaiFunctionCall `json:"function"`
}
type openaiFunctionCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
type openaiContentPart struct {
@@ -94,8 +121,9 @@ type openaiImageURL struct {
}
type openaiResponseMessage struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
}
// openaiStreamChunk is one SSE data line from the upstream
@@ -142,11 +170,12 @@ type openaiChoice struct {
}
type openaiDelta struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"`
Thinking string `json:"thinking,omitempty"`
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
ReasoningContent string `json:"reasoning_content,omitempty"`
Reasoning string `json:"reasoning,omitempty"`
Thinking string `json:"thinking,omitempty"`
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
}
type openaiUsage struct {
@@ -235,7 +264,6 @@ func handleChat(ctx *gin.Context) {
sendSSEError(ctx, "AI 聊天未配置,请在后台配置 API Key 和模型")
return
}
toolRouterProfile, hasToolRouterProfile := selectOpenAIProfile(cfg, cfg.ToolRouter.OpenAIName)
chatMsgs := convertToChatMessages(req.Messages)
// Set up SSE headers before routing/tools so progress can stream immediately.
@@ -272,32 +300,8 @@ func handleChat(ctx *gin.Context) {
toolConfigs := []agents.ToolConfig{}
if cfg.ToolRouter.Enabled {
toolConfigs = buildToolConfigs(cfg.ToolRouter.Tools)
if hasToolRouterProfile && toolRouterProfile.Model != "" && toolRouterProfile.ApiKey != "" {
emitTrace("tool_router", "route", "running", "正在进行工具路由", nil)
routeResult, routeErr := routeTools(ctx.Request.Context(), toolRouterProfile, cfg.ToolRouter, chatMsgs)
if routeErr != nil {
emitTrace("tool_router", "route", "error", "工具路由失败,将继续普通回答", map[string]interface{}{"error": routeErr.Error()})
toolConfigs = []agents.ToolConfig{}
} else if routeResult != nil {
tracker.addToolUsage(routeResult.Usage, estimateOpenAIMessagesTokens(routeResult.Messages), estimateTokenCount(routeResult.Response))
data := map[string]interface{}{
"tools": routeResult.Selected,
"selections": routeResult.Decision.Tools,
"reason": routeResult.Decision.Reason,
}
message := "工具路由结果:无需调用工具"
if len(routeResult.Selected) > 0 {
message = "工具路由结果:将调用 " + strings.Join(routeResult.Selected, ", ")
}
emitTrace("tool_router", "route", "success", message, data)
toolConfigs = filterToolConfigs(toolConfigs, routeResult.Selected)
}
}
}
// Enrich messages with tools (pre-process)
chatMsgs = agents.EnrichMessages(ctx.Request.Context(), chatMsgs, toolConfigs, emitTrace)
// Build OpenAI-compatible request
openaiMsgs, err := convertToOpenAIMessages(chatMsgs)
if err != nil {
@@ -305,6 +309,25 @@ func handleChat(ctx *gin.Context) {
sendSSEDone(ctx, flusher)
return
}
functionTools := buildFunctionTools(toolConfigs)
if profile.SystemPrompt != "" {
openaiMsgs = append([]openaiMessage{{Role: "system", Content: profile.SystemPrompt}}, openaiMsgs...)
}
if len(functionTools) > 0 {
toolNames := make([]string, 0, len(functionTools))
for _, tool := range functionTools {
toolNames = append(toolNames, tool.Function.Name)
}
emitTrace("function_tools", "prepare", "success", "已启用 Function Calling 工具", map[string]interface{}{"tools": toolNames})
openaiMsgs = append([]openaiMessage{{Role: "system", Content: "可用工具使用规则:当用户询问本月、今天、本周、下周等相对日期的日程时,先调用 time 获取明确 start_date/end_date,再调用 ops_ai_assistant_schedule_query 查询日程。不要臆造工具结果中不存在的日程。"}}, openaiMsgs...)
var toolExecuted bool
openaiMsgs, toolExecuted, err = runOpenAIToolLoop(ctx.Request.Context(), profile, openaiMsgs, functionTools, currentUser, tracker, emitTrace)
if err != nil {
emitTrace("model", "tool_call", "error", "工具调用失败,将继续普通回答", map[string]interface{}{"error": err.Error()})
} else if toolExecuted {
emitTrace("model", "tool_call", "success", "工具调用完成,准备生成最终回答", nil)
}
}
apiReq := openaiChatRequest{
Model: profile.Model,
Messages: openaiMsgs,
@@ -313,11 +336,6 @@ func handleChat(ctx *gin.Context) {
Temperature: 0.7,
}
// Add system prompt if configured
if profile.SystemPrompt != "" {
apiReq.Messages = append([]openaiMessage{{Role: "system", Content: profile.SystemPrompt}}, apiReq.Messages...)
}
trimmedMessages, trimStats := trimOpenAIMessagesToContextWindow(apiReq.Messages, profile.ContextWindowTokens)
apiReq.Messages = trimmedMessages
if trimStats.RemovedMessages > 0 {
@@ -413,6 +431,39 @@ func handleChat(ctx *gin.Context) {
flusher.Flush()
}
func callOpenAIChat(ctx context.Context, cfg models.ConfigsAIChatOpenAI_, req openaiChatRequest) (*openaiChatResponse, error) {
bodyBytes, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := strings.TrimRight(cfg.BaseUrl, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
client := &http.Client{Timeout: time.Duration(cfg.Timeout) * time.Second}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("连接上游服务失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("上游返回 %d: %s", resp.StatusCode, string(body))
}
var result openaiChatResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
return &result, nil
}
func streamOpenAI(ctx context.Context, cfg models.ConfigsAIChatOpenAI_, req openaiChatRequest, onData func(openaiStreamChunk)) error {
bodyBytes, err := json.Marshal(req)
if err != nil {
@@ -836,6 +887,204 @@ func buildToolConfigs(configs []models.ConfigsAIChatTool_) []agents.ToolConfig {
return result
}
func buildFunctionTools(configs []agents.ToolConfig) []openaiTool {
tools := make([]openaiTool, 0)
for _, config := range configs {
if !config.Enabled {
continue
}
switch strings.ToLower(strings.TrimSpace(config.Name)) {
case "time":
tools = append(tools, openaiTool{
Type: "function",
Function: openaiFunctionDefinition{
Name: "time",
Description: "解析当前时间、相对日期和日期范围。遇到本月、今天、本周等相对日期时先调用本工具获得明确日期。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"range": map[string]interface{}{
"type": "string",
"enum": []string{"today", "yesterday", "tomorrow", "this_week", "last_week", "next_week", "this_month", "last_month", "next_month", "this_year", "custom"},
"description": "要解析的日期范围。",
},
"timezone": map[string]interface{}{"type": "string", "description": "可选时区,例如 Asia/Shanghai。"},
"start_date": map[string]interface{}{"type": "string", "description": "custom 范围开始日期,格式 YYYY-MM-DD。"},
"end_date": map[string]interface{}{"type": "string", "description": "custom 范围结束日期,格式 YYYY-MM-DD。"},
},
"required": []string{"range"},
},
},
})
case "ops_ai_assistant_schedule_query", "ops_ai_assistant":
tools = append(tools, openaiTool{
Type: "function",
Function: openaiFunctionDefinition{
Name: "ops_ai_assistant_schedule_query",
Description: "按明确日期范围查询当前用户可见的 OPS 日历/日程。相对日期需先调用 time 获取 start_date/end_date。",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"start_date": map[string]interface{}{"type": "string", "description": "开始日期,格式 YYYY-MM-DD。"},
"end_date": map[string]interface{}{"type": "string", "description": "结束日期,格式 YYYY-MM-DD。"},
"calendar_id": map[string]interface{}{"type": "integer", "description": "可选日历 ID;不传则查询全部可见日程。"},
"limit": map[string]interface{}{"type": "integer", "description": "可选返回上限,默认 100,最大 200。"},
},
"required": []string{"start_date", "end_date"},
},
},
})
}
}
return tools
}
func parseJSONTraceValue(raw string) interface{} {
value := strings.TrimSpace(raw)
if value == "" {
return ""
}
var parsed interface{}
if err := json.Unmarshal([]byte(value), &parsed); err != nil {
return value
}
return parsed
}
func runOpenAIToolLoop(ctx context.Context, profile models.ConfigsAIChatOpenAI_, messages []openaiMessage, tools []openaiTool, currentUser *TabUser, tracker *tokenUsageTracker, trace agents.TraceFunc) ([]openaiMessage, bool, error) {
toolExecuted := false
for round := 0; round < 5; round++ {
if trace != nil {
trace("model", "tool_call", "running", "正在请求模型决定是否调用工具", map[string]interface{}{"round": round + 1})
}
req := openaiChatRequest{
Model: profile.Model,
Messages: messages,
Stream: false,
MaxTokens: profile.MaxTokens,
Temperature: 0.1,
Tools: tools,
ToolChoice: "auto",
}
resp, err := callOpenAIChat(ctx, profile, req)
if err != nil {
return messages, toolExecuted, err
}
responseText := ""
if len(resp.Choices) == 0 {
return messages, toolExecuted, nil
}
message := resp.Choices[0].Message
responseText = message.Content
tracker.addToolUsage(resp.Usage, estimateOpenAIMessagesTokens(messages), estimateTokenCount(responseText))
if len(message.ToolCalls) == 0 {
return messages, toolExecuted, nil
}
toolExecuted = true
messages = append(messages, openaiMessage{Role: "assistant", Content: message.Content, ToolCalls: message.ToolCalls})
for _, toolCall := range message.ToolCalls {
toolName := strings.TrimSpace(toolCall.Function.Name)
parsedArgs := parseJSONTraceValue(toolCall.Function.Arguments)
if trace != nil {
trace(toolName, "call", "running", "模型调用工具:"+toolName, map[string]interface{}{
"tool": toolName,
"arguments": parsedArgs,
})
}
resultJSON, err := executeAIFunctionTool(ctx, toolName, []byte(toolCall.Function.Arguments), currentUser)
status := "success"
if err != nil {
status = "error"
resultJSON, _ = json.Marshal(map[string]interface{}{"ok": false, "error": err.Error()})
}
if trace != nil {
data := map[string]interface{}{
"tool": toolName,
"result": parseJSONTraceValue(string(resultJSON)),
}
if len(resultJSON) > 1200 {
data["result"] = string(resultJSON[:1200]) + "..."
data["truncated"] = true
}
trace(toolName, "execute", status, "工具执行完成:"+toolName, data)
}
messages = append(messages, openaiMessage{Role: "tool", ToolCallID: toolCall.ID, Name: toolName, Content: string(resultJSON)})
}
}
return messages, toolExecuted, fmt.Errorf("工具调用超过最大轮数")
}
type scheduleQueryArgs struct {
StartDate string `json:"start_date"`
EndDate string `json:"end_date"`
CalendarID uint `json:"calendar_id,omitempty"`
Limit int `json:"limit,omitempty"`
}
func executeAIFunctionTool(ctx context.Context, name string, rawArgs []byte, currentUser *TabUser) ([]byte, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
switch name {
case "time":
var args agents.TimeRangeArgs
if len(rawArgs) > 0 {
if err := json.Unmarshal(rawArgs, &args); err != nil {
return nil, err
}
}
result, err := agents.ResolveTimeRange(args, time.Now())
if err != nil {
return nil, err
}
return json.Marshal(result)
case "ops_ai_assistant_schedule_query", "ops_ai_assistant":
var args scheduleQueryArgs
if err := json.Unmarshal(rawArgs, &args); err != nil {
return nil, err
}
startDate, err := time.Parse("2006-01-02", args.StartDate)
if err != nil {
return nil, fmt.Errorf("invalid start_date: %w", err)
}
endDate, err := time.Parse("2006-01-02", args.EndDate)
if err != nil {
return nil, fmt.Errorf("invalid end_date: %w", err)
}
limit := args.Limit
if limit <= 0 {
limit = 100
}
if limit > 200 {
limit = 200
}
events, err := QueryCalendarSchedulesForAI(CalendarScheduleQuery{
CalendarID: args.CalendarID,
StartDate: startDate,
EndDate: endDate,
User: currentUser,
Limit: limit,
})
if err != nil {
return nil, err
}
return json.Marshal(map[string]interface{}{
"ok": true,
"start_date": args.StartDate,
"end_date": args.EndDate,
"count": len(events),
"limit": limit,
"events": events,
})
default:
return nil, fmt.Errorf("unknown tool: %s", name)
}
}
func selectOpenAIProfile(cfg models.ConfigsAIChat_, name string) (models.ConfigsAIChatOpenAI_, bool) {
if name != "" {
for _, p := range cfg.OpenAI {