修改工具调用机制
This commit is contained in:
@@ -75,11 +75,38 @@ type openaiChatRequest struct {
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
Tools []openaiTool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
type openaiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type openaiTool struct {
|
||||
Type string `json:"type"`
|
||||
Function openaiFunctionDefinition `json:"function"`
|
||||
}
|
||||
|
||||
type openaiFunctionDefinition struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters map[string]interface{} `json:"parameters"`
|
||||
}
|
||||
|
||||
type openaiToolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function openaiFunctionCall `json:"function"`
|
||||
}
|
||||
|
||||
type openaiFunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
type openaiContentPart struct {
|
||||
@@ -94,8 +121,9 @@ type openaiImageURL struct {
|
||||
}
|
||||
|
||||
type openaiResponseMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
// openaiStreamChunk is one SSE data line from the upstream
|
||||
@@ -142,11 +170,12 @@ type openaiChoice struct {
|
||||
}
|
||||
|
||||
type openaiDelta struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Content string `json:"content,omitempty"`
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
Reasoning string `json:"reasoning,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
ToolCalls []openaiToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
|
||||
type openaiUsage struct {
|
||||
@@ -235,7 +264,6 @@ func handleChat(ctx *gin.Context) {
|
||||
sendSSEError(ctx, "AI 聊天未配置,请在后台配置 API Key 和模型")
|
||||
return
|
||||
}
|
||||
toolRouterProfile, hasToolRouterProfile := selectOpenAIProfile(cfg, cfg.ToolRouter.OpenAIName)
|
||||
chatMsgs := convertToChatMessages(req.Messages)
|
||||
|
||||
// Set up SSE headers before routing/tools so progress can stream immediately.
|
||||
@@ -272,32 +300,8 @@ func handleChat(ctx *gin.Context) {
|
||||
toolConfigs := []agents.ToolConfig{}
|
||||
if cfg.ToolRouter.Enabled {
|
||||
toolConfigs = buildToolConfigs(cfg.ToolRouter.Tools)
|
||||
if hasToolRouterProfile && toolRouterProfile.Model != "" && toolRouterProfile.ApiKey != "" {
|
||||
emitTrace("tool_router", "route", "running", "正在进行工具路由", nil)
|
||||
routeResult, routeErr := routeTools(ctx.Request.Context(), toolRouterProfile, cfg.ToolRouter, chatMsgs)
|
||||
if routeErr != nil {
|
||||
emitTrace("tool_router", "route", "error", "工具路由失败,将继续普通回答", map[string]interface{}{"error": routeErr.Error()})
|
||||
toolConfigs = []agents.ToolConfig{}
|
||||
} else if routeResult != nil {
|
||||
tracker.addToolUsage(routeResult.Usage, estimateOpenAIMessagesTokens(routeResult.Messages), estimateTokenCount(routeResult.Response))
|
||||
data := map[string]interface{}{
|
||||
"tools": routeResult.Selected,
|
||||
"selections": routeResult.Decision.Tools,
|
||||
"reason": routeResult.Decision.Reason,
|
||||
}
|
||||
message := "工具路由结果:无需调用工具"
|
||||
if len(routeResult.Selected) > 0 {
|
||||
message = "工具路由结果:将调用 " + strings.Join(routeResult.Selected, ", ")
|
||||
}
|
||||
emitTrace("tool_router", "route", "success", message, data)
|
||||
toolConfigs = filterToolConfigs(toolConfigs, routeResult.Selected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Enrich messages with tools (pre-process)
|
||||
chatMsgs = agents.EnrichMessages(ctx.Request.Context(), chatMsgs, toolConfigs, emitTrace)
|
||||
|
||||
// Build OpenAI-compatible request
|
||||
openaiMsgs, err := convertToOpenAIMessages(chatMsgs)
|
||||
if err != nil {
|
||||
@@ -305,6 +309,25 @@ func handleChat(ctx *gin.Context) {
|
||||
sendSSEDone(ctx, flusher)
|
||||
return
|
||||
}
|
||||
functionTools := buildFunctionTools(toolConfigs)
|
||||
if profile.SystemPrompt != "" {
|
||||
openaiMsgs = append([]openaiMessage{{Role: "system", Content: profile.SystemPrompt}}, openaiMsgs...)
|
||||
}
|
||||
if len(functionTools) > 0 {
|
||||
toolNames := make([]string, 0, len(functionTools))
|
||||
for _, tool := range functionTools {
|
||||
toolNames = append(toolNames, tool.Function.Name)
|
||||
}
|
||||
emitTrace("function_tools", "prepare", "success", "已启用 Function Calling 工具", map[string]interface{}{"tools": toolNames})
|
||||
openaiMsgs = append([]openaiMessage{{Role: "system", Content: "可用工具使用规则:当用户询问本月、今天、本周、下周等相对日期的日程时,先调用 time 获取明确 start_date/end_date,再调用 ops_ai_assistant_schedule_query 查询日程。不要臆造工具结果中不存在的日程。"}}, openaiMsgs...)
|
||||
var toolExecuted bool
|
||||
openaiMsgs, toolExecuted, err = runOpenAIToolLoop(ctx.Request.Context(), profile, openaiMsgs, functionTools, currentUser, tracker, emitTrace)
|
||||
if err != nil {
|
||||
emitTrace("model", "tool_call", "error", "工具调用失败,将继续普通回答", map[string]interface{}{"error": err.Error()})
|
||||
} else if toolExecuted {
|
||||
emitTrace("model", "tool_call", "success", "工具调用完成,准备生成最终回答", nil)
|
||||
}
|
||||
}
|
||||
apiReq := openaiChatRequest{
|
||||
Model: profile.Model,
|
||||
Messages: openaiMsgs,
|
||||
@@ -313,11 +336,6 @@ func handleChat(ctx *gin.Context) {
|
||||
Temperature: 0.7,
|
||||
}
|
||||
|
||||
// Add system prompt if configured
|
||||
if profile.SystemPrompt != "" {
|
||||
apiReq.Messages = append([]openaiMessage{{Role: "system", Content: profile.SystemPrompt}}, apiReq.Messages...)
|
||||
}
|
||||
|
||||
trimmedMessages, trimStats := trimOpenAIMessagesToContextWindow(apiReq.Messages, profile.ContextWindowTokens)
|
||||
apiReq.Messages = trimmedMessages
|
||||
if trimStats.RemovedMessages > 0 {
|
||||
@@ -413,6 +431,39 @@ func handleChat(ctx *gin.Context) {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
func callOpenAIChat(ctx context.Context, cfg models.ConfigsAIChatOpenAI_, req openaiChatRequest) (*openaiChatResponse, error) {
|
||||
bodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(cfg.BaseUrl, "/") + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
|
||||
|
||||
client := &http.Client{Timeout: time.Duration(cfg.Timeout) * time.Second}
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接上游服务失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("上游返回 %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result openaiChatResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func streamOpenAI(ctx context.Context, cfg models.ConfigsAIChatOpenAI_, req openaiChatRequest, onData func(openaiStreamChunk)) error {
|
||||
bodyBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
@@ -836,6 +887,204 @@ func buildToolConfigs(configs []models.ConfigsAIChatTool_) []agents.ToolConfig {
|
||||
return result
|
||||
}
|
||||
|
||||
func buildFunctionTools(configs []agents.ToolConfig) []openaiTool {
|
||||
tools := make([]openaiTool, 0)
|
||||
for _, config := range configs {
|
||||
if !config.Enabled {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(config.Name)) {
|
||||
case "time":
|
||||
tools = append(tools, openaiTool{
|
||||
Type: "function",
|
||||
Function: openaiFunctionDefinition{
|
||||
Name: "time",
|
||||
Description: "解析当前时间、相对日期和日期范围。遇到本月、今天、本周等相对日期时先调用本工具获得明确日期。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"range": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"today", "yesterday", "tomorrow", "this_week", "last_week", "next_week", "this_month", "last_month", "next_month", "this_year", "custom"},
|
||||
"description": "要解析的日期范围。",
|
||||
},
|
||||
"timezone": map[string]interface{}{"type": "string", "description": "可选时区,例如 Asia/Shanghai。"},
|
||||
"start_date": map[string]interface{}{"type": "string", "description": "custom 范围开始日期,格式 YYYY-MM-DD。"},
|
||||
"end_date": map[string]interface{}{"type": "string", "description": "custom 范围结束日期,格式 YYYY-MM-DD。"},
|
||||
},
|
||||
"required": []string{"range"},
|
||||
},
|
||||
},
|
||||
})
|
||||
case "ops_ai_assistant_schedule_query", "ops_ai_assistant":
|
||||
tools = append(tools, openaiTool{
|
||||
Type: "function",
|
||||
Function: openaiFunctionDefinition{
|
||||
Name: "ops_ai_assistant_schedule_query",
|
||||
Description: "按明确日期范围查询当前用户可见的 OPS 日历/日程。相对日期需先调用 time 获取 start_date/end_date。",
|
||||
Parameters: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"start_date": map[string]interface{}{"type": "string", "description": "开始日期,格式 YYYY-MM-DD。"},
|
||||
"end_date": map[string]interface{}{"type": "string", "description": "结束日期,格式 YYYY-MM-DD。"},
|
||||
"calendar_id": map[string]interface{}{"type": "integer", "description": "可选日历 ID;不传则查询全部可见日程。"},
|
||||
"limit": map[string]interface{}{"type": "integer", "description": "可选返回上限,默认 100,最大 200。"},
|
||||
},
|
||||
"required": []string{"start_date", "end_date"},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
func parseJSONTraceValue(raw string) interface{} {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return ""
|
||||
}
|
||||
var parsed interface{}
|
||||
if err := json.Unmarshal([]byte(value), &parsed); err != nil {
|
||||
return value
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func runOpenAIToolLoop(ctx context.Context, profile models.ConfigsAIChatOpenAI_, messages []openaiMessage, tools []openaiTool, currentUser *TabUser, tracker *tokenUsageTracker, trace agents.TraceFunc) ([]openaiMessage, bool, error) {
|
||||
toolExecuted := false
|
||||
for round := 0; round < 5; round++ {
|
||||
if trace != nil {
|
||||
trace("model", "tool_call", "running", "正在请求模型决定是否调用工具", map[string]interface{}{"round": round + 1})
|
||||
}
|
||||
req := openaiChatRequest{
|
||||
Model: profile.Model,
|
||||
Messages: messages,
|
||||
Stream: false,
|
||||
MaxTokens: profile.MaxTokens,
|
||||
Temperature: 0.1,
|
||||
Tools: tools,
|
||||
ToolChoice: "auto",
|
||||
}
|
||||
resp, err := callOpenAIChat(ctx, profile, req)
|
||||
if err != nil {
|
||||
return messages, toolExecuted, err
|
||||
}
|
||||
responseText := ""
|
||||
if len(resp.Choices) == 0 {
|
||||
return messages, toolExecuted, nil
|
||||
}
|
||||
message := resp.Choices[0].Message
|
||||
responseText = message.Content
|
||||
tracker.addToolUsage(resp.Usage, estimateOpenAIMessagesTokens(messages), estimateTokenCount(responseText))
|
||||
if len(message.ToolCalls) == 0 {
|
||||
return messages, toolExecuted, nil
|
||||
}
|
||||
|
||||
toolExecuted = true
|
||||
messages = append(messages, openaiMessage{Role: "assistant", Content: message.Content, ToolCalls: message.ToolCalls})
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
toolName := strings.TrimSpace(toolCall.Function.Name)
|
||||
parsedArgs := parseJSONTraceValue(toolCall.Function.Arguments)
|
||||
if trace != nil {
|
||||
trace(toolName, "call", "running", "模型调用工具:"+toolName, map[string]interface{}{
|
||||
"tool": toolName,
|
||||
"arguments": parsedArgs,
|
||||
})
|
||||
}
|
||||
resultJSON, err := executeAIFunctionTool(ctx, toolName, []byte(toolCall.Function.Arguments), currentUser)
|
||||
status := "success"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
resultJSON, _ = json.Marshal(map[string]interface{}{"ok": false, "error": err.Error()})
|
||||
}
|
||||
if trace != nil {
|
||||
data := map[string]interface{}{
|
||||
"tool": toolName,
|
||||
"result": parseJSONTraceValue(string(resultJSON)),
|
||||
}
|
||||
if len(resultJSON) > 1200 {
|
||||
data["result"] = string(resultJSON[:1200]) + "..."
|
||||
data["truncated"] = true
|
||||
}
|
||||
trace(toolName, "execute", status, "工具执行完成:"+toolName, data)
|
||||
}
|
||||
messages = append(messages, openaiMessage{Role: "tool", ToolCallID: toolCall.ID, Name: toolName, Content: string(resultJSON)})
|
||||
}
|
||||
}
|
||||
return messages, toolExecuted, fmt.Errorf("工具调用超过最大轮数")
|
||||
}
|
||||
|
||||
type scheduleQueryArgs struct {
|
||||
StartDate string `json:"start_date"`
|
||||
EndDate string `json:"end_date"`
|
||||
CalendarID uint `json:"calendar_id,omitempty"`
|
||||
Limit int `json:"limit,omitempty"`
|
||||
}
|
||||
|
||||
func executeAIFunctionTool(ctx context.Context, name string, rawArgs []byte, currentUser *TabUser) ([]byte, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
switch name {
|
||||
case "time":
|
||||
var args agents.TimeRangeArgs
|
||||
if len(rawArgs) > 0 {
|
||||
if err := json.Unmarshal(rawArgs, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
result, err := agents.ResolveTimeRange(args, time.Now())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(result)
|
||||
case "ops_ai_assistant_schedule_query", "ops_ai_assistant":
|
||||
var args scheduleQueryArgs
|
||||
if err := json.Unmarshal(rawArgs, &args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
startDate, err := time.Parse("2006-01-02", args.StartDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid start_date: %w", err)
|
||||
}
|
||||
endDate, err := time.Parse("2006-01-02", args.EndDate)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid end_date: %w", err)
|
||||
}
|
||||
limit := args.Limit
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 200 {
|
||||
limit = 200
|
||||
}
|
||||
events, err := QueryCalendarSchedulesForAI(CalendarScheduleQuery{
|
||||
CalendarID: args.CalendarID,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
User: currentUser,
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(map[string]interface{}{
|
||||
"ok": true,
|
||||
"start_date": args.StartDate,
|
||||
"end_date": args.EndDate,
|
||||
"count": len(events),
|
||||
"limit": limit,
|
||||
"events": events,
|
||||
})
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown tool: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
func selectOpenAIProfile(cfg models.ConfigsAIChat_, name string) (models.ConfigsAIChatOpenAI_, bool) {
|
||||
if name != "" {
|
||||
for _, p := range cfg.OpenAI {
|
||||
|
||||
Reference in New Issue
Block a user