更新工具链
This commit is contained in:
@@ -14,10 +14,12 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
ToolName = "search"
|
||||||
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
|
defaultActivationPrompt = `判断用户问题是否需要联网搜索。
|
||||||
当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
|
当问题涉及实时信息、新闻、价格、当前版本、近期事件、政策、网页资料核验,或用户明确要求“查一下/搜索/联网/最新”时调用 search。
|
||||||
当用户询问“历史上的今天”、某日期历史事件、需要按当前日期动态确定查询词的常识资料时,也应调用 search;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
|
当用户询问“历史上的今天”、某日期历史事件、需要按当前日期动态确定查询词的常识资料时,也应调用 search;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
|
||||||
@@ -91,6 +93,11 @@ type ListResponse struct {
|
|||||||
Profiles []ProfileConfig `json:"profiles"`
|
Profiles []ProfileConfig `json:"profiles"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolArgs struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
type braveSearchResponse struct {
|
type braveSearchResponse struct {
|
||||||
Web struct {
|
Web struct {
|
||||||
Results []Result `json:"results"`
|
Results []Result `json:"results"`
|
||||||
@@ -208,6 +215,53 @@ func (s *State) ActivationPrompt() string {
|
|||||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *State) ToolDefinition(description string) *model.Tool {
|
||||||
|
description = strings.TrimSpace(description)
|
||||||
|
if description == "" {
|
||||||
|
description = s.ActivationPrompt()
|
||||||
|
}
|
||||||
|
return &model.Tool{
|
||||||
|
Type: model.ToolTypeFunction,
|
||||||
|
Function: &model.FunctionDefinition{
|
||||||
|
Name: ToolName,
|
||||||
|
Description: description,
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"query": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "要联网搜索的关键词。若问题包含相对日期,应先调用 time 工具后使用绝对日期改写查询词。",
|
||||||
|
},
|
||||||
|
"reason": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "调用联网搜索的原因。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"query"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) ExecuteTool(ctx context.Context, args string) (string, error) {
|
||||||
|
var parsed ToolArgs
|
||||||
|
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
|
||||||
|
return "", fmt.Errorf("解析搜索工具参数失败: %w", err)
|
||||||
|
}
|
||||||
|
query := strings.TrimSpace(parsed.Query)
|
||||||
|
if query == "" {
|
||||||
|
return "", errors.New("搜索关键词不能为空")
|
||||||
|
}
|
||||||
|
results, profile, err := s.Search(ctx, query)
|
||||||
|
if err != nil {
|
||||||
|
return BuildErrorContext(query, err), nil
|
||||||
|
}
|
||||||
|
if len(results) == 0 {
|
||||||
|
return BuildFallbackContext(profile, query, parsed.Reason, errors.New("未搜索到相关网页结果")), nil
|
||||||
|
}
|
||||||
|
return BuildResultContext(profile, query, results, parsed.Reason), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *State) ActiveProfile() ProfileConfig {
|
func (s *State) ActiveProfile() ProfileConfig {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|||||||
@@ -121,3 +121,42 @@ func TestLoadConfigWritesLegacyProfiles(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolDefinitionAndExecuteTool(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("q") != "golang" {
|
||||||
|
t.Fatalf("query = %s", r.URL.RawQuery)
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{"Heading":"Go","Abstract":"Go language","AbstractURL":"https://go.dev"}`))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: server.URL, Count: 1, Timeout: 1}}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
definition := state.ToolDefinition("custom search")
|
||||||
|
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom search" {
|
||||||
|
t.Fatalf("unexpected definition: %#v", definition)
|
||||||
|
}
|
||||||
|
text, err := state.ExecuteTool(context.Background(), `{"query":"golang","reason":"测试搜索"}`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, want := range []string{"联网搜索", "golang", "Go", "https://go.dev", "测试搜索"} {
|
||||||
|
if !strings.Contains(text, want) {
|
||||||
|
t.Fatalf("tool result missing %q:\n%s", want, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteToolRejectsEmptyQuery(t *testing.T) {
|
||||||
|
state, err := NewState(&Config{Enabled: true, Profiles: ProfileConfigs{{Name: "ddg", Active: true, Enabled: true, Provider: "duckduckgo", BaseURL: defaultBaseURL, Count: 1, Timeout: 1}}})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = state.ExecuteTool(context.Background(), `{"query":" "}`)
|
||||||
|
if err == nil || !strings.Contains(err.Error(), "不能为空") {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package sqlquery
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -15,11 +16,13 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
_ "modernc.org/sqlite"
|
_ "modernc.org/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
ToolName = "sql"
|
||||||
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
|
defaultActivationPrompt = `判断用户问题是否需要查询业务数据库。
|
||||||
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
|
仅当用户询问数据库表、记录、字段、时间、状态、内容、统计、最近/最早/某时间范围内的数据时返回 activate=true。
|
||||||
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
|
当用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,必须返回 activate=true,并说明应查询 tab_calendar_events 表。
|
||||||
@@ -80,6 +83,19 @@ type QueryResult struct {
|
|||||||
MaxRows int `json:"max_rows"`
|
MaxRows int `json:"max_rows"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolArgs struct {
|
||||||
|
Question string `json:"question"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SQLGenerator func(ctx context.Context, prompt string, maxTokens int) (string, error)
|
||||||
|
|
||||||
|
type GenerationResult struct {
|
||||||
|
Database string `json:"database"`
|
||||||
|
SQL string `json:"sql"`
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (*Config, error) {
|
||||||
if _, err := os.Stat(path); err != nil {
|
if _, err := os.Stat(path); err != nil {
|
||||||
if !os.IsNotExist(err) {
|
if !os.IsNotExist(err) {
|
||||||
@@ -180,6 +196,73 @@ func (s *State) ActivationPrompt() string {
|
|||||||
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
return strings.TrimSpace(s.cfg.ActivationPrompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *State) ToolDefinition(description string) *model.Tool {
|
||||||
|
description = strings.TrimSpace(description)
|
||||||
|
if description == "" {
|
||||||
|
description = s.ActivationPrompt()
|
||||||
|
}
|
||||||
|
return &model.Tool{
|
||||||
|
Type: model.ToolTypeFunction,
|
||||||
|
Function: &model.FunctionDefinition{
|
||||||
|
Name: ToolName,
|
||||||
|
Description: description,
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"question": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "需要查询数据库的问题。若已有 time 工具结果,应使用其中的绝对日期范围解释相对时间。",
|
||||||
|
},
|
||||||
|
"reason": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "调用数据库查询工具的原因。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"question"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *State) ExecuteTool(ctx context.Context, args string, generator SQLGenerator) (string, error) {
|
||||||
|
if generator == nil {
|
||||||
|
return "", errors.New("SQL 生成器未配置")
|
||||||
|
}
|
||||||
|
var parsed ToolArgs
|
||||||
|
if err := json.Unmarshal([]byte(strings.TrimSpace(args)), &parsed); err != nil {
|
||||||
|
return "", fmt.Errorf("解析 SQL 工具参数失败: %w", err)
|
||||||
|
}
|
||||||
|
question := strings.TrimSpace(parsed.Question)
|
||||||
|
if question == "" {
|
||||||
|
return "", errors.New("数据库查询问题不能为空")
|
||||||
|
}
|
||||||
|
schemaContext, err := s.SchemaContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return BuildErrorContext(question, err), nil
|
||||||
|
}
|
||||||
|
generated, err := GenerateSQL(ctx, generator, question, schemaContext)
|
||||||
|
if err != nil {
|
||||||
|
return BuildErrorContext(question, err), nil
|
||||||
|
}
|
||||||
|
generated.Database = strings.TrimSpace(generated.Database)
|
||||||
|
generated.SQL = strings.TrimSpace(generated.SQL)
|
||||||
|
if generated.SQL == "" {
|
||||||
|
return BuildErrorContext(question, fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)), nil
|
||||||
|
}
|
||||||
|
result, err := s.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
||||||
|
if err != nil {
|
||||||
|
return BuildErrorContext(question, err), nil
|
||||||
|
}
|
||||||
|
contextText := BuildResultContext(question, generated.SQL, result)
|
||||||
|
if strings.TrimSpace(parsed.Reason) != "" {
|
||||||
|
contextText += "\n调用原因:" + strings.TrimSpace(parsed.Reason)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(generated.Reason) != "" {
|
||||||
|
contextText += "\nSQL 生成原因:" + strings.TrimSpace(generated.Reason)
|
||||||
|
}
|
||||||
|
return contextText, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *State) DefaultDatabase() string {
|
func (s *State) DefaultDatabase() string {
|
||||||
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" {
|
if s == nil || s.cfg == nil || strings.TrimSpace(s.cfg.DefaultDatabase) == "" {
|
||||||
return defaultDatabaseName
|
return defaultDatabaseName
|
||||||
@@ -220,6 +303,37 @@ func (s *State) SchemaContext(ctx context.Context) (string, error) {
|
|||||||
return text, nil
|
return text, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GenerateSQL(ctx context.Context, generator SQLGenerator, userQuery string, schemaContext string) (*GenerationResult, error) {
|
||||||
|
prompt := BuildSQLGenerationPrompt(userQuery, schemaContext)
|
||||||
|
text, err := generator(ctx, prompt, 1024)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var generated GenerationResult
|
||||||
|
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
||||||
|
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
||||||
|
}
|
||||||
|
return &generated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildSQLGenerationPrompt(userQuery string, schemaContext string) string {
|
||||||
|
return fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、工具结果上下文和数据库 schema 生成一条只读 SQL。
|
||||||
|
要求:
|
||||||
|
- 只能返回 JSON,不要使用 Markdown。
|
||||||
|
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
||||||
|
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
||||||
|
- 必须只使用 schema 中出现的数据库、表和字段。
|
||||||
|
- 如果工具结果中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
|
||||||
|
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
|
||||||
|
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
|
||||||
|
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
||||||
|
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
||||||
|
|
||||||
|
%s
|
||||||
|
|
||||||
|
用户问题:%s`, schemaContext, userQuery)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
|
func (s *State) ExecuteReadOnly(ctx context.Context, databaseName string, query string) (*QueryResult, error) {
|
||||||
if !s.Enabled() {
|
if !s.Enabled() {
|
||||||
return nil, errors.New("SQL 查询插件未启用")
|
return nil, errors.New("SQL 查询插件未启用")
|
||||||
@@ -588,6 +702,16 @@ func (d *database) rejectExcludedTables(query string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractJSONObject(text string) string {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
start := strings.Index(text, "{")
|
||||||
|
end := strings.LastIndex(text, "}")
|
||||||
|
if start >= 0 && end > start {
|
||||||
|
return text[start : end+1]
|
||||||
|
}
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
|
func scanRows(rows *sql.Rows, cfg DatabaseConfig, query string) (*QueryResult, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
columns, err := rows.Columns()
|
columns, err := rows.Columns()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package sqlquery
|
package sqlquery
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -43,3 +44,42 @@ func TestValidateReadOnlySQLRejectsUnsafeStatements(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolDefinition(t *testing.T) {
|
||||||
|
state := &State{}
|
||||||
|
definition := state.ToolDefinition("custom sql")
|
||||||
|
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom sql" {
|
||||||
|
t.Fatalf("unexpected definition: %#v", definition)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildSQLGenerationPromptIncludesSafetyRules(t *testing.T) {
|
||||||
|
prompt := BuildSQLGenerationPrompt("本月有什么日程安排", "schema: tab_calendar_events(start_time)")
|
||||||
|
for _, want := range []string{"只读 SQL", "SELECT", "WITH", "半开区间", "tab_calendar_events", "max_rows", "本月有什么日程安排"} {
|
||||||
|
if !strings.Contains(prompt, want) {
|
||||||
|
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSQLParsesJSONAndRejectsMalformed(t *testing.T) {
|
||||||
|
generated, err := GenerateSQL(context.Background(), func(ctx context.Context, prompt string, maxTokens int) (string, error) {
|
||||||
|
if !strings.Contains(prompt, "schema") || maxTokens != 1024 {
|
||||||
|
t.Fatalf("unexpected prompt/maxTokens: %s / %d", prompt, maxTokens)
|
||||||
|
}
|
||||||
|
return `{"database":"default","sql":"SELECT * FROM events LIMIT 1","reason":"ok"}`, nil
|
||||||
|
}, "查事件", "schema")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if generated.Database != "default" || !strings.Contains(generated.SQL, "SELECT") || generated.Reason != "ok" {
|
||||||
|
t.Fatalf("unexpected generated SQL: %#v", generated)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = GenerateSQL(context.Background(), func(context.Context, string, int) (string, error) {
|
||||||
|
return "not json", nil
|
||||||
|
}, "查事件", "schema")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected malformed JSON error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+45
-2
@@ -1,12 +1,22 @@
|
|||||||
package timeagent
|
package timeagent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近、历史上的今天、日程安排等相对时间表达时,应先调用此工具;如果后续还需要联网搜索或查数据库,可继续调用 search 或 sql。"
|
const (
|
||||||
|
ToolName = "time"
|
||||||
|
ActivationPrompt = "提供当前日期、时间和常用时间范围。当用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近、历史上的今天、日程安排等相对时间表达时,应先调用此工具;如果后续还需要联网搜索或查数据库,可继续调用 search 或 sql。"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ToolArgs struct {
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
}
|
||||||
|
|
||||||
type Range struct {
|
type Range struct {
|
||||||
Start time.Time
|
Start time.Time
|
||||||
@@ -23,6 +33,39 @@ type Context struct {
|
|||||||
ThisYear Range
|
ThisYear Range
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToolDefinition(description string) *model.Tool {
|
||||||
|
description = strings.TrimSpace(description)
|
||||||
|
if description == "" {
|
||||||
|
description = ActivationPrompt
|
||||||
|
}
|
||||||
|
return &model.Tool{
|
||||||
|
Type: model.ToolTypeFunction,
|
||||||
|
Function: &model.FunctionDefinition{
|
||||||
|
Name: ToolName,
|
||||||
|
Description: description,
|
||||||
|
Parameters: map[string]any{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]any{
|
||||||
|
"reason": map[string]any{
|
||||||
|
"type": "string",
|
||||||
|
"description": "调用时间工具的原因。",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExecuteTool(args string, now time.Time) (string, error) {
|
||||||
|
var parsed ToolArgs
|
||||||
|
if strings.TrimSpace(args) != "" {
|
||||||
|
if err := json.Unmarshal([]byte(args), &parsed); err != nil {
|
||||||
|
return "", fmt.Errorf("解析时间工具参数失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return BuildContext(Resolve(now), parsed.Reason), nil
|
||||||
|
}
|
||||||
|
|
||||||
func Resolve(now time.Time) Context {
|
func Resolve(now time.Time) Context {
|
||||||
loc := now.Location()
|
loc := now.Location()
|
||||||
today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)}
|
today := Range{Start: startOfDay(now), End: startOfDay(now).AddDate(0, 0, 1)}
|
||||||
@@ -47,7 +90,7 @@ func Resolve(now time.Time) Context {
|
|||||||
|
|
||||||
func BuildContext(ctx Context, routeReason string) string {
|
func BuildContext(ctx Context, routeReason string) string {
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
b.WriteString("时间工具结果。后续工具必须优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n")
|
b.WriteString("时间工具结果。请优先使用这里的绝对日期解释用户问题中的相对时间,不要自行猜测当前日期。\n")
|
||||||
fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST"))
|
fmt.Fprintf(&b, "当前本地日期时间:%s\n", ctx.Now.Format("2006-01-02 15:04:05 MST"))
|
||||||
fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today))
|
fmt.Fprintf(&b, "今天:%s\n", FormatSQLRange(ctx.Today))
|
||||||
fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow))
|
fmt.Fprintf(&b, "明天:%s\n", FormatSQLRange(ctx.Tomorrow))
|
||||||
|
|||||||
@@ -29,3 +29,19 @@ func TestBuildContextIncludesSQLHints(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolDefinitionAndExecuteTool(t *testing.T) {
|
||||||
|
definition := ToolDefinition("custom description")
|
||||||
|
if definition.Function == nil || definition.Function.Name != ToolName || definition.Function.Description != "custom description" {
|
||||||
|
t.Fatalf("unexpected definition: %#v", definition)
|
||||||
|
}
|
||||||
|
text, err := ExecuteTool(`{"reason":"测试原因"}`, time.Date(2026, 6, 10, 13, 14, 15, 0, time.UTC))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, want := range []string{"时间工具结果", "2026-06-10", "本月", "start=", "end_exclusive=", "测试原因"} {
|
||||||
|
if !strings.Contains(text, want) {
|
||||||
|
t.Fatalf("tool result missing %q:\n%s", want, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -37,16 +37,12 @@ const (
|
|||||||
defaultOpenAITimeout = 120
|
defaultOpenAITimeout = 120
|
||||||
defaultToolRouterTimeout = 30
|
defaultToolRouterTimeout = 30
|
||||||
defaultToolRouterMaxTokens = 512
|
defaultToolRouterMaxTokens = 512
|
||||||
defaultToolRouterSystemText = `你是工具路由器。根据用户最新问题和可用工具列表,判断本轮是否需要调用一个或多个工具。
|
defaultToolRouterSystemText = `你可以按需直接调用可用工具来回答用户问题。
|
||||||
只能返回 JSON,不要使用 Markdown。
|
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且后续需要搜索或查询数据库,应先调用 time 获取绝对日期范围。
|
||||||
JSON 格式:{"tools":[{"name":"工具名称","reason":"..."}],"reason":"..."}
|
需要实时网页资料、新闻、当前版本、近期事件、网页核验或用户明确要求联网时,调用 search。
|
||||||
工具名称必须来自“可用工具”列表。
|
需要查询本地业务数据、日程、会议、待办、记录、统计或时间范围内数据时,调用 sql。
|
||||||
可以选择多个工具,工具会按配置顺序依次执行;后面的工具可以使用前面工具写入的上下文。
|
工具结果优先于模型内置知识;工具失败时必须如实说明,不要编造结果。
|
||||||
如果用户问题包含今天、今日、明天、昨天、本周、本月、本年、最近等相对时间,且还需要调用 search 或 sql,必须同时选择 time,并让 time 排在这些工具之前。
|
只调用确实必要的工具。`
|
||||||
例如“历史上的今天都发生了什么”应选择 time 和 search:先获取今天的绝对日期,再搜索当天历史事件;如果联网无结果,主模型会回退到自身知识库回答并说明来源。
|
|
||||||
例如“本月有什么日程安排”应选择 time 和 sql:先获取本月绝对日期范围,再查询日程表。
|
|
||||||
如果无需工具,返回 {"tools":[],"reason":"..."}。
|
|
||||||
只选择确实必要的工具。`
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type OpenAIConfig struct {
|
type OpenAIConfig struct {
|
||||||
@@ -286,6 +282,11 @@ func normalizeOpenAIConfigs(cfg *Config) (bool, error) {
|
|||||||
return changed, nil
|
return changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isLegacyToolRouterPrompt(prompt string) bool {
|
||||||
|
prompt = strings.TrimSpace(prompt)
|
||||||
|
return strings.Contains(prompt, "工具路由器") || strings.Contains(prompt, "route_tools") || strings.Contains(prompt, `"tools":[`)
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeToolRouterConfig(cfg *Config) (bool, error) {
|
func normalizeToolRouterConfig(cfg *Config) (bool, error) {
|
||||||
changed := false
|
changed := false
|
||||||
defaults := defaultToolRouterConfig()
|
defaults := defaultToolRouterConfig()
|
||||||
@@ -298,11 +299,12 @@ func normalizeToolRouterConfig(cfg *Config) (bool, error) {
|
|||||||
cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens
|
cfg.ToolRouter.MaxTokens = defaultToolRouterMaxTokens
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) == "" {
|
systemPrompt := strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
|
||||||
|
if systemPrompt == "" || isLegacyToolRouterPrompt(systemPrompt) {
|
||||||
cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText
|
cfg.ToolRouter.SystemPrompt = defaultToolRouterSystemText
|
||||||
changed = true
|
changed = true
|
||||||
} else if strings.TrimSpace(cfg.ToolRouter.SystemPrompt) != cfg.ToolRouter.SystemPrompt {
|
} else if systemPrompt != cfg.ToolRouter.SystemPrompt {
|
||||||
cfg.ToolRouter.SystemPrompt = strings.TrimSpace(cfg.ToolRouter.SystemPrompt)
|
cfg.ToolRouter.SystemPrompt = systemPrompt
|
||||||
changed = true
|
changed = true
|
||||||
}
|
}
|
||||||
if len(cfg.ToolRouter.Tools) == 0 {
|
if len(cfg.ToolRouter.Tools) == 0 {
|
||||||
@@ -432,12 +434,12 @@ type openAIListResponse struct {
|
|||||||
Profiles []OpenAIConfig `json:"profiles"`
|
Profiles []OpenAIConfig `json:"profiles"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type toolTextCompleter func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error)
|
type chatCompleter func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error)
|
||||||
|
|
||||||
type ToolRouterState struct {
|
type ToolRouterState struct {
|
||||||
cfg *ToolRouterConfig
|
cfg *ToolRouterConfig
|
||||||
ai *OpenAIState
|
ai *OpenAIState
|
||||||
complete toolTextCompleter
|
complete chatCompleter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) {
|
func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterState, error) {
|
||||||
@@ -453,7 +455,7 @@ func NewToolRouterState(config *ToolRouterConfig, ai *OpenAIState) (*ToolRouterS
|
|||||||
return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err)
|
return nil, fmt.Errorf("tool_router.openai_name 配置无效: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &ToolRouterState{cfg: config, ai: ai, complete: completeTextWithTimeout}, nil
|
return &ToolRouterState{cfg: config, ai: ai, complete: completeChatWithTimeout}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
|
func NewOpenAIState(configs []OpenAIConfig) (*OpenAIState, error) {
|
||||||
@@ -796,20 +798,17 @@ func chatHandler(c *gin.Context) {
|
|||||||
usage := newTokenUsageTracker()
|
usage := newTokenUsageTracker()
|
||||||
ctx = contextWithTokenUsage(ctx, usage)
|
ctx = contextWithTokenUsage(ctx, usage)
|
||||||
|
|
||||||
chatMessages := req.Messages
|
// 用 Function Calling 工具循环替代旧的路由+隐藏上下文机制
|
||||||
withTools, err := enrichMessagesWithRoutedTools(ctx, profile, chatMessages, emit)
|
messages, err := runAgentToolLoop(ctx, profile, req.Messages, emit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(os.Stderr, "工具路由调用失败:", err)
|
fmt.Fprintln(os.Stderr, "Agent 工具循环失败:", err)
|
||||||
} else {
|
messages, err = buildArkMessages(req.Messages)
|
||||||
chatMessages = withTools
|
if err != nil {
|
||||||
|
emitError(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 构建 ark 消息列表
|
promptTokens := estimateChatMessagesTokens(req.Messages)
|
||||||
messages, err := buildArkMessages(chatMessages)
|
|
||||||
if err != nil {
|
|
||||||
emitError(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
promptTokens := estimateChatMessagesTokens(chatMessages)
|
|
||||||
|
|
||||||
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
|
emitTrace("model", "request", "running", "正在调用模型生成回答", nil)
|
||||||
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
|
stream, err := profile.Client.CreateChatCompletionStream(ctx, model.CreateChatCompletionRequest{
|
||||||
@@ -885,21 +884,16 @@ func chatHandler(c *gin.Context) {
|
|||||||
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
|
stats := usage.snapshot(tokensPerSecond(completionTokens, streamStarted), peakTokensPerSecond)
|
||||||
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
|
emit(chatSSEFrame{Type: "delta", Text: delta, Stats: &stats})
|
||||||
}
|
}
|
||||||
|
// 思考过程 reasoning_content 单独事件推送
|
||||||
|
if resp.Choices[0].Delta.ReasoningContent != nil && *resp.Choices[0].Delta.ReasoningContent != "" {
|
||||||
|
emit(chatSSEFrame{Type: "reasoning", Text: *resp.Choices[0].Delta.ReasoningContent})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── 辅助函数 ─────────────────────────────────────────────
|
// ─── 辅助函数 ─────────────────────────────────────────────
|
||||||
|
|
||||||
func latestUserQuery(messages []ChatMessage) string {
|
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
|
||||||
if messages[i].Role == "user" {
|
|
||||||
return strings.TrimSpace(messages[i].Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func estimateChatMessagesTokens(messages []ChatMessage) int {
|
func estimateChatMessagesTokens(messages []ChatMessage) int {
|
||||||
total := 0
|
total := 0
|
||||||
for _, msg := range messages {
|
for _, msg := range messages {
|
||||||
@@ -951,380 +945,214 @@ func tokensPerSecond(tokens int, start time.Time) float64 {
|
|||||||
return float64(tokens) / elapsed
|
return float64(tokens) / elapsed
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolSelection struct {
|
type agentTool struct {
|
||||||
Name string `json:"name"`
|
name string
|
||||||
Reason string `json:"reason"`
|
definition *model.Tool
|
||||||
|
execute func(context.Context, string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolRoutingDecision struct {
|
func (t agentTool) Name() string { return t.name }
|
||||||
Tools []ToolSelection `json:"tools"`
|
|
||||||
Reason string `json:"reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChatTool interface {
|
const maxAgentToolIterations = 6
|
||||||
Name() string
|
|
||||||
Description() string
|
|
||||||
Enabled() bool
|
|
||||||
Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type TimeChatTool struct{}
|
func availableAgentTools(profile *OpenAIProfile, emit func(chatSSEFrame)) []agentTool {
|
||||||
|
|
||||||
func (t TimeChatTool) Name() string { return "time" }
|
|
||||||
|
|
||||||
func (t TimeChatTool) Description() string {
|
|
||||||
return timeagent.ActivationPrompt
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t TimeChatTool) Enabled() bool { return true }
|
|
||||||
|
|
||||||
func (t TimeChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
return runTimeTool(ctx, messages, routeReason, emit)
|
|
||||||
}
|
|
||||||
|
|
||||||
type SQLChatTool struct {
|
|
||||||
state *sqlquery.State
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t SQLChatTool) Name() string { return "sql" }
|
|
||||||
|
|
||||||
func (t SQLChatTool) Description() string {
|
|
||||||
if t.state == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return t.state.ActivationPrompt()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t SQLChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
|
|
||||||
|
|
||||||
func (t SQLChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
return runSQLTool(ctx, t.state, profile, messages, routeReason, emit)
|
|
||||||
}
|
|
||||||
|
|
||||||
type SearchChatTool struct {
|
|
||||||
state *searchagent.State
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t SearchChatTool) Name() string { return "search" }
|
|
||||||
|
|
||||||
func (t SearchChatTool) Description() string {
|
|
||||||
if t.state == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return t.state.ActivationPrompt()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t SearchChatTool) Enabled() bool { return t.state != nil && t.state.Enabled() }
|
|
||||||
|
|
||||||
func (t SearchChatTool) Enrich(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
return runSearchTool(ctx, t.state, messages, routeReason, emit)
|
|
||||||
}
|
|
||||||
|
|
||||||
type sqlGenerationResult struct {
|
|
||||||
Database string `json:"database"`
|
|
||||||
SQL string `json:"sql"`
|
|
||||||
Reason string `json:"reason"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func runSQLTool(ctx context.Context, state *sqlquery.State, profile *OpenAIProfile, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
query := latestUserQuery(messages)
|
|
||||||
if query == "" {
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "running", Message: "正在读取数据库结构"})
|
|
||||||
schemaContext, err := state.SchemaContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "error", Message: "数据库结构读取失败", Data: map[string]any{"error": err.Error()}})
|
|
||||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
||||||
}
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "schema", Status: "success", Message: "数据库结构读取完成"})
|
|
||||||
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "running", Message: "正在生成只读 SQL"})
|
|
||||||
generated, err := generateSQLForUserQuery(ctx, profile, query, schemaContext)
|
|
||||||
if err != nil {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "SQL 生成失败", Data: map[string]any{"error": err.Error()}})
|
|
||||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
||||||
}
|
|
||||||
generated.Database = strings.TrimSpace(generated.Database)
|
|
||||||
generated.SQL = strings.TrimSpace(generated.SQL)
|
|
||||||
if generated.SQL == "" {
|
|
||||||
err := fmt.Errorf("模型未生成可执行 SQL: %s", generated.Reason)
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "error", Message: "模型未生成可执行 SQL", Data: map[string]any{"reason": generated.Reason}})
|
|
||||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "running", Message: "正在执行数据库查询", Data: map[string]any{"database": generated.Database}})
|
|
||||||
result, err := state.ExecuteReadOnly(ctx, generated.Database, generated.SQL)
|
|
||||||
if err != nil {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "error", Message: "数据库查询失败", Data: map[string]any{"error": err.Error()}})
|
|
||||||
return prependHiddenContext(messages, sqlquery.BuildErrorContext(query, err)), nil
|
|
||||||
}
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "generate", Status: "success", Message: "已生成只读 SQL", Data: map[string]any{"database": generated.Database, "sql": generated.SQL, "reason": generated.Reason}})
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "sql", Stage: "execute", Status: "success", Message: fmt.Sprintf("数据库查询完成,返回 %d 行", len(result.Rows)), Data: map[string]any{"database": result.Database, "rows": len(result.Rows), "columns": len(result.Columns), "truncated": result.Truncated, "max_rows": result.MaxRows}})
|
|
||||||
contextText := sqlquery.BuildResultContext(query, generated.SQL, result)
|
|
||||||
if strings.TrimSpace(routeReason) != "" {
|
|
||||||
contextText += "\n激活原因:" + routeReason
|
|
||||||
}
|
|
||||||
return prependHiddenContext(messages, contextText), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func prependHiddenContext(messages []ChatMessage, content string) []ChatMessage {
|
|
||||||
withContext := make([]ChatMessage, 0, len(messages)+1)
|
|
||||||
withContext = append(withContext, ChatMessage{Role: "system", Content: content, Hidden: true})
|
|
||||||
withContext = append(withContext, messages...)
|
|
||||||
return withContext
|
|
||||||
}
|
|
||||||
|
|
||||||
func runTimeTool(ctx context.Context, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
_ = ctx
|
|
||||||
resolved := timeagent.Resolve(time.Now())
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "time", Stage: "resolve", Status: "success", Message: "已获取当前时间上下文", Data: map[string]any{
|
|
||||||
"today": timeagent.FormatDate(resolved.Now),
|
|
||||||
"this_month": fmt.Sprintf("%s 至 %s", timeagent.FormatDate(resolved.ThisMonth.Start), timeagent.FormatDate(resolved.ThisMonth.End.AddDate(0, 0, -1))),
|
|
||||||
}})
|
|
||||||
return prependHiddenContext(messages, timeagent.BuildContext(resolved, routeReason)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func runSearchTool(ctx context.Context, state *searchagent.State, messages []ChatMessage, routeReason string, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
query := latestUserQuery(messages)
|
|
||||||
if query == "" {
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
if state == nil || !state.Enabled() {
|
|
||||||
err := errors.New("联网搜索未启用")
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索未启用", Data: map[string]any{"error": err.Error()}})
|
|
||||||
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
|
|
||||||
}
|
|
||||||
active := state.ActiveProfile()
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "running", Message: "正在联网搜索", Data: map[string]any{"provider": active.Provider}})
|
|
||||||
results, profile, err := state.Search(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "request", Status: "error", Message: "联网搜索失败", Data: map[string]any{"error": err.Error()}})
|
|
||||||
return prependHiddenContext(messages, searchagent.BuildErrorContext(query, err)), nil
|
|
||||||
}
|
|
||||||
if len(results) == 0 {
|
|
||||||
err := errors.New("未搜索到相关网页结果")
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "warning", Message: "未搜索到相关网页结果,将使用模型知识库回答"})
|
|
||||||
return prependHiddenContext(messages, searchagent.BuildFallbackContext(profile, query, routeReason, err)), nil
|
|
||||||
}
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "search", Stage: "results", Status: "success", Message: fmt.Sprintf("联网搜索完成,找到 %d 条结果", len(results)), Data: map[string]any{"provider": profile.Provider, "count": len(results)}})
|
|
||||||
return prependHiddenContext(messages, searchagent.BuildResultContext(profile, query, results, routeReason)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func enrichMessagesWithRoutedTools(ctx context.Context, chatProfile *OpenAIProfile, messages []ChatMessage, emit func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled {
|
if toolRouterState == nil || toolRouterState.cfg == nil || !toolRouterState.cfg.Enabled {
|
||||||
return messages, nil
|
return nil
|
||||||
}
|
}
|
||||||
if latestUserQuery(messages) == "" {
|
tools := make([]agentTool, 0, len(toolRouterState.cfg.Tools))
|
||||||
return messages, nil
|
for _, item := range toolRouterState.cfg.Tools {
|
||||||
}
|
if !item.Enabled {
|
||||||
tools := availableChatTools(toolRouterState.cfg)
|
|
||||||
if len(tools) == 0 {
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "running", Message: "正在进行工具路由"})
|
|
||||||
decision, err := routeTools(ctx, toolRouterState, chatProfile, messages, tools)
|
|
||||||
if err != nil {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "error", Message: "工具路由失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
|
|
||||||
return messages, err
|
|
||||||
}
|
|
||||||
selected := filterToolSelections(decision, tools, toolRouterState.cfg.Tools)
|
|
||||||
selected = ensureTimeSelectionForRelativeQuery(selected, tools, toolRouterState.cfg.Tools, latestUserQuery(messages))
|
|
||||||
if len(selected) == 0 {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:无需调用工具", Data: map[string]any{"reason": decision.Reason}})
|
|
||||||
return messages, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
names := make([]string, 0, len(selected))
|
|
||||||
for _, item := range selected {
|
|
||||||
names = append(names, item.Name)
|
|
||||||
}
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: "tool_router", Stage: "route", Status: "success", Message: "工具路由结果:将调用 " + strings.Join(names, ", "), Data: map[string]any{"tools": names, "reason": decision.Reason}})
|
|
||||||
|
|
||||||
current := messages
|
|
||||||
for _, item := range selected {
|
|
||||||
tool := tools[item.Name]
|
|
||||||
next, err := tool.Enrich(ctx, chatProfile, current, firstNonEmpty(item.Reason, decision.Reason), emit)
|
|
||||||
if err != nil {
|
|
||||||
emit(chatSSEFrame{Type: "trace", Tool: item.Name, Stage: "error", Status: "error", Message: "工具调用失败,将继续普通回答", Data: map[string]any{"error": err.Error()}})
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
current = next
|
|
||||||
}
|
|
||||||
return current, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func availableChatTools(config *ToolRouterConfig) map[string]ChatTool {
|
|
||||||
configured := map[string]ToolRouteConfig{}
|
|
||||||
for _, item := range config.Tools {
|
|
||||||
configured[item.Name] = item
|
|
||||||
}
|
|
||||||
registered := []ChatTool{
|
|
||||||
TimeChatTool{},
|
|
||||||
SearchChatTool{state: searchState},
|
|
||||||
SQLChatTool{state: sqlState},
|
|
||||||
}
|
|
||||||
available := map[string]ChatTool{}
|
|
||||||
for _, tool := range registered {
|
|
||||||
name := tool.Name()
|
|
||||||
item, ok := configured[name]
|
|
||||||
if !ok || !item.Enabled || !tool.Enabled() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
available[name] = tool
|
|
||||||
}
|
|
||||||
return available
|
|
||||||
}
|
|
||||||
|
|
||||||
func routeTools(ctx context.Context, state *ToolRouterState, chatProfile *OpenAIProfile, messages []ChatMessage, tools map[string]ChatTool) (ToolRoutingDecision, error) {
|
|
||||||
routerProfile := chatProfile
|
|
||||||
if strings.TrimSpace(state.cfg.OpenAIName) != "" {
|
|
||||||
profile, err := state.ai.GetProfile(state.cfg.OpenAIName)
|
|
||||||
if err != nil {
|
|
||||||
return ToolRoutingDecision{}, err
|
|
||||||
}
|
|
||||||
routerProfile = profile
|
|
||||||
}
|
|
||||||
prompt := buildToolRouterPrompt(state.cfg, messages, tools)
|
|
||||||
text, err := state.complete(ctx, routerProfile, []ChatMessage{{Role: "system", Content: prompt}}, state.cfg.MaxTokens, time.Duration(state.cfg.Timeout)*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return ToolRoutingDecision{}, err
|
|
||||||
}
|
|
||||||
return parseToolRoutingDecision(text)
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildToolRouterPrompt(config *ToolRouterConfig, messages []ChatMessage, tools map[string]ChatTool) string {
|
|
||||||
query := latestUserQuery(messages)
|
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString(strings.TrimSpace(config.SystemPrompt))
|
|
||||||
b.WriteString("\n\n可用工具:\n")
|
|
||||||
for _, item := range config.Tools {
|
|
||||||
tool, ok := tools[item.Name]
|
|
||||||
if !ok {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
description := strings.TrimSpace(item.Description)
|
description := strings.TrimSpace(item.Description)
|
||||||
if description == "" {
|
switch item.Name {
|
||||||
description = tool.Description()
|
case timeagent.ToolName:
|
||||||
}
|
tools = append(tools, agentTool{
|
||||||
fmt.Fprintf(&b, "- name: %s\n description: %s\n", item.Name, description)
|
name: timeagent.ToolName,
|
||||||
}
|
definition: timeagent.ToolDefinition(description),
|
||||||
fmt.Fprintf(&b, "\n最新用户问题:%s", query)
|
execute: func(ctx context.Context, args string) (string, error) {
|
||||||
return b.String()
|
result, err := timeagent.ExecuteTool(args, time.Now())
|
||||||
}
|
if err == nil && emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: timeagent.ToolName, Stage: "resolve", Status: "success", Message: "已获取当前时间上下文"})
|
||||||
func parseToolRoutingDecision(text string) (ToolRoutingDecision, error) {
|
}
|
||||||
var decision ToolRoutingDecision
|
return result, err
|
||||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &decision); err != nil {
|
},
|
||||||
return decision, fmt.Errorf("解析工具路由结果失败: %w", err)
|
})
|
||||||
}
|
case searchagent.ToolName:
|
||||||
for i := range decision.Tools {
|
if searchState == nil || !searchState.Enabled() {
|
||||||
decision.Tools[i].Name = strings.ToLower(strings.TrimSpace(decision.Tools[i].Name))
|
continue
|
||||||
decision.Tools[i].Reason = strings.TrimSpace(decision.Tools[i].Reason)
|
}
|
||||||
}
|
tools = append(tools, agentTool{
|
||||||
decision.Reason = strings.TrimSpace(decision.Reason)
|
name: searchagent.ToolName,
|
||||||
return decision, nil
|
definition: searchState.ToolDefinition(description),
|
||||||
}
|
execute: func(ctx context.Context, args string) (string, error) {
|
||||||
|
if emit != nil {
|
||||||
func filterToolSelections(decision ToolRoutingDecision, tools map[string]ChatTool, order []ToolRouteConfig) []ToolSelection {
|
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "request", Status: "running", Message: "正在联网搜索"})
|
||||||
selected := map[string]ToolSelection{}
|
}
|
||||||
for _, item := range decision.Tools {
|
result, err := searchState.ExecuteTool(ctx, args)
|
||||||
if item.Name == "" {
|
if emit != nil {
|
||||||
continue
|
status := "success"
|
||||||
}
|
message := "联网搜索完成"
|
||||||
if _, ok := tools[item.Name]; !ok {
|
if err != nil {
|
||||||
continue
|
status = "error"
|
||||||
}
|
message = "联网搜索失败"
|
||||||
if _, ok := selected[item.Name]; !ok {
|
}
|
||||||
selected[item.Name] = item
|
emit(chatSSEFrame{Type: "trace", Tool: searchagent.ToolName, Stage: "results", Status: status, Message: message})
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
},
|
||||||
|
})
|
||||||
|
case sqlquery.ToolName:
|
||||||
|
if sqlState == nil || !sqlState.Enabled() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tools = append(tools, agentTool{
|
||||||
|
name: sqlquery.ToolName,
|
||||||
|
definition: sqlState.ToolDefinition(description),
|
||||||
|
execute: func(ctx context.Context, args string) (string, error) {
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: "running", Message: "正在查询数据库"})
|
||||||
|
}
|
||||||
|
generator := func(ctx context.Context, prompt string, maxTokens int) (string, error) {
|
||||||
|
return completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, maxTokens)
|
||||||
|
}
|
||||||
|
result, err := sqlState.ExecuteTool(ctx, args, generator)
|
||||||
|
if emit != nil {
|
||||||
|
status := "success"
|
||||||
|
message := "数据库查询完成"
|
||||||
|
if err != nil {
|
||||||
|
status = "error"
|
||||||
|
message = "数据库查询失败"
|
||||||
|
}
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: sqlquery.ToolName, Stage: "execute", Status: status, Message: message})
|
||||||
|
}
|
||||||
|
return result, err
|
||||||
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return orderToolSelections(selected, order)
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureTimeSelectionForRelativeQuery(selected []ToolSelection, tools map[string]ChatTool, order []ToolRouteConfig, query string) []ToolSelection {
|
func runAgentToolLoop(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, emit func(chatSSEFrame)) ([]*model.ChatCompletionMessage, error) {
|
||||||
if !containsRelativeTime(query) || hasToolSelection(selected, "time") || (!hasToolSelection(selected, "search") && !hasToolSelection(selected, "sql")) {
|
messages, err := buildArkMessages(chatMessages)
|
||||||
return selected
|
|
||||||
}
|
|
||||||
if _, ok := tools["time"]; !ok {
|
|
||||||
return selected
|
|
||||||
}
|
|
||||||
withTime := make(map[string]ToolSelection, len(selected)+1)
|
|
||||||
for _, item := range selected {
|
|
||||||
withTime[item.Name] = item
|
|
||||||
}
|
|
||||||
withTime["time"] = ToolSelection{Name: "time", Reason: "问题包含相对日期,需要先获取当前日期"}
|
|
||||||
return orderToolSelections(withTime, order)
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsRelativeTime(query string) bool {
|
|
||||||
query = strings.TrimSpace(query)
|
|
||||||
if query == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, keyword := range []string{"今天", "今日", "明天", "昨天", "本周", "这周", "本月", "这个月", "本年", "今年", "最近", "历史上的今天"} {
|
|
||||||
if strings.Contains(query, keyword) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasToolSelection(selected []ToolSelection, name string) bool {
|
|
||||||
for _, item := range selected {
|
|
||||||
if item.Name == name {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func orderToolSelections(selected map[string]ToolSelection, order []ToolRouteConfig) []ToolSelection {
|
|
||||||
result := make([]ToolSelection, 0, len(selected))
|
|
||||||
for _, item := range order {
|
|
||||||
if selection, ok := selected[item.Name]; ok {
|
|
||||||
result = append(result, selection)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func firstNonEmpty(items ...string) string {
|
|
||||||
for _, item := range items {
|
|
||||||
if strings.TrimSpace(item) != "" {
|
|
||||||
return strings.TrimSpace(item)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateSQLForUserQuery(ctx context.Context, profile *OpenAIProfile, userQuery string, schemaContext string) (*sqlGenerationResult, error) {
|
|
||||||
prompt := fmt.Sprintf(`你是只读 SQL 生成器。请根据用户问题、隐藏上下文和数据库 schema 生成一条只读 SQL。
|
|
||||||
要求:
|
|
||||||
- 只能返回 JSON,不要使用 Markdown。
|
|
||||||
- JSON 格式:{"database":"数据库名称","sql":"SELECT ... LIMIT N","reason":"生成原因"}
|
|
||||||
- 只能生成 SELECT 或 WITH 查询,禁止 INSERT/UPDATE/DELETE/DROP/ALTER/CREATE 等任何修改语句。
|
|
||||||
- 必须只使用 schema 中出现的数据库、表和字段。
|
|
||||||
- 如果隐藏上下文中包含“时间工具结果”,必须使用其中的绝对日期范围解释用户问题里的今天、明天、昨天、本周、本月、本年、最近等相对时间。
|
|
||||||
- 用户询问日程、日程安排、行程、会议、待办、今天/明天/本周有什么安排时,优先查询 tab_calendar_events 表;如果 schema 中没有该表,再返回无法根据已知表结构生成查询。
|
|
||||||
- 查询日程表时,涉及日期范围必须使用半开区间:时间字段 >= start AND 时间字段 < end_exclusive;时间字段必须从 schema 中选择真实存在的字段。
|
|
||||||
- 必须添加 LIMIT,且 LIMIT 不超过插件配置的 max_rows。
|
|
||||||
- 如果无法根据 schema 回答,返回 {"database":"","sql":"","reason":"无法根据已知表结构生成查询"}。
|
|
||||||
|
|
||||||
%s
|
|
||||||
|
|
||||||
用户问题:%s`, schemaContext, userQuery)
|
|
||||||
text, err := completeText(ctx, profile, []ChatMessage{{Role: "system", Content: prompt}}, 1024)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var generated sqlGenerationResult
|
tools := availableAgentTools(profile, emit)
|
||||||
if err := json.Unmarshal([]byte(extractJSONObject(text)), &generated); err != nil {
|
if len(tools) == 0 {
|
||||||
return nil, fmt.Errorf("解析 SQL 生成结果失败: %w", err)
|
return messages, nil
|
||||||
}
|
}
|
||||||
return &generated, nil
|
toolByName := make(map[string]agentTool, len(tools))
|
||||||
|
definitions := make([]*model.Tool, 0, len(tools))
|
||||||
|
availableNames := make([]string, 0, len(tools))
|
||||||
|
toolDescriptions := make([]string, 0, len(tools))
|
||||||
|
for _, tool := range tools {
|
||||||
|
toolByName[tool.name] = tool
|
||||||
|
definitions = append(definitions, tool.definition)
|
||||||
|
availableNames = append(availableNames, tool.name)
|
||||||
|
if tool.definition != nil && tool.definition.Function != nil {
|
||||||
|
toolDescriptions = append(toolDescriptions, fmt.Sprintf("%s: %s", tool.name, tool.definition.Function.Description))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "prepare", Status: "success", Message: "已准备可用工具", Data: map[string]any{"tools": availableNames, "tool_descriptions": toolDescriptions}})
|
||||||
|
}
|
||||||
|
if prompt := strings.TrimSpace(toolRouterState.cfg.SystemPrompt); prompt != "" {
|
||||||
|
messages = append([]*model.ChatCompletionMessage{{Role: model.ChatMessageRoleSystem, Content: stringContent(prompt)}}, messages...)
|
||||||
|
}
|
||||||
|
for i := 0; i < maxAgentToolIterations; i++ {
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "running", Message: fmt.Sprintf("正在进行第 %d 轮工具判断", i+1), Data: map[string]any{"iteration": i + 1, "max_iterations": maxAgentToolIterations, "tools": availableNames}})
|
||||||
|
}
|
||||||
|
resp, err := toolRouterState.complete(ctx, profile, model.CreateChatCompletionRequest{
|
||||||
|
Model: profile.Config.Model,
|
||||||
|
Messages: messages,
|
||||||
|
MaxTokens: intPtr(toolRouterState.cfg.MaxTokens),
|
||||||
|
Tools: definitions,
|
||||||
|
ToolChoice: model.ToolChoiceStringTypeAuto,
|
||||||
|
ParallelToolCalls: boolPtr(false),
|
||||||
|
}, time.Duration(toolRouterState.cfg.Timeout)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return messages, err
|
||||||
|
}
|
||||||
|
if tracker := tokenUsageFromContext(ctx); tracker != nil {
|
||||||
|
tracker.addTool(resp.Usage.PromptTokens, resp.Usage.CompletionTokens)
|
||||||
|
}
|
||||||
|
if len(resp.Choices) == 0 {
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
choice := resp.Choices[0]
|
||||||
|
decisionPreview := chatMessageContentString(choice.Message.Content)
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "decision", Status: "success", Message: "工具判断响应已返回", Data: map[string]any{"iteration": i + 1, "finish_reason": string(choice.FinishReason), "content_preview": truncateString(decisionPreview, 800)}})
|
||||||
|
}
|
||||||
|
calls := choice.Message.ToolCalls
|
||||||
|
if len(calls) == 0 && choice.Message.FunctionCall != nil {
|
||||||
|
calls = []*model.ToolCall{{ID: "legacy_function_call", Type: model.ToolTypeFunction, Function: *choice.Message.FunctionCall}}
|
||||||
|
}
|
||||||
|
if len(calls) == 0 {
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "request", Status: "success", Message: "模型未请求工具,进入回答生成"})
|
||||||
|
}
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
callNames := make([]string, 0, len(calls))
|
||||||
|
for _, call := range calls {
|
||||||
|
if call != nil {
|
||||||
|
callNames = append(callNames, call.Function.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "tool_calls", Status: "running", Message: fmt.Sprintf("模型请求调用 %d 个工具", len(calls)), Data: map[string]any{"tools": callNames, "iteration": i + 1}})
|
||||||
|
}
|
||||||
|
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleAssistant, ToolCalls: calls, Content: choice.Message.Content})
|
||||||
|
for _, call := range calls {
|
||||||
|
result := executeAgentToolCall(ctx, call, toolByName, emit)
|
||||||
|
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleTool, ToolCallID: call.ID, Content: stringContent(result)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages = append(messages, &model.ChatCompletionMessage{Role: model.ChatMessageRoleSystem, Content: stringContent("工具调用轮数已达到上限。请基于已有工具结果回答,并说明可能未完成全部工具调用。")})
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeAgentToolCall(ctx context.Context, call *model.ToolCall, tools map[string]agentTool, emit func(chatSSEFrame)) string {
|
||||||
|
if call == nil || call.Type != model.ToolTypeFunction {
|
||||||
|
result := "工具调用无效:仅支持 function 类型工具。"
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: "agent_tools", Stage: "execute", Status: "error", Message: result})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
toolName := call.Function.Name
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "arguments", Status: "running", Message: "准备执行工具", Data: map[string]any{"tool_call_id": call.ID, "arguments": call.Function.Arguments}})
|
||||||
|
}
|
||||||
|
tool, ok := tools[toolName]
|
||||||
|
if !ok {
|
||||||
|
result := fmt.Sprintf("工具调用失败:未知工具 %s。", toolName)
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: toolName, Stage: "execute", Status: "error", Message: result})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
started := time.Now()
|
||||||
|
result, err := tool.execute(ctx, call.Function.Arguments)
|
||||||
|
durationMs := time.Since(started).Milliseconds()
|
||||||
|
if err != nil {
|
||||||
|
message := fmt.Sprintf("工具 %s 执行失败:%v", tool.name, err)
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "execute", Status: "error", Message: "工具执行失败", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "error": err.Error()}})
|
||||||
|
}
|
||||||
|
return message
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(result) == "" {
|
||||||
|
result = fmt.Sprintf("工具 %s 执行完成,但没有返回内容。", tool.name)
|
||||||
|
}
|
||||||
|
if emit != nil {
|
||||||
|
emit(chatSSEFrame{Type: "trace", Tool: tool.name, Stage: "result", Status: "success", Message: "工具执行完成", Data: map[string]any{"tool_call_id": call.ID, "duration_ms": durationMs, "result_preview": truncateString(result, 1200)}})
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
|
func completeText(ctx context.Context, profile *OpenAIProfile, chatMessages []ChatMessage, maxTokens int) (string, error) {
|
||||||
@@ -1370,14 +1198,10 @@ func completeTextWithTimeout(ctx context.Context, profile *OpenAIProfile, chatMe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractJSONObject(text string) string {
|
func completeChatWithTimeout(ctx context.Context, profile *OpenAIProfile, request model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
||||||
text = strings.TrimSpace(text)
|
completionCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
start := strings.Index(text, "{")
|
defer cancel()
|
||||||
end := strings.LastIndex(text, "}")
|
return profile.Client.CreateChatCompletion(completionCtx, request.WithStream(false))
|
||||||
if start >= 0 && end > start {
|
|
||||||
return text[start : end+1]
|
|
||||||
}
|
|
||||||
return text
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUUID() string {
|
func newUUID() string {
|
||||||
@@ -1590,6 +1414,17 @@ func textPart(text string) *model.ChatCompletionMessageContentPart {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stringContent(text string) *model.ChatCompletionMessageContent {
|
||||||
|
return &model.ChatCompletionMessageContent{StringValue: &text}
|
||||||
|
}
|
||||||
|
|
||||||
|
func chatMessageContentString(content *model.ChatCompletionMessageContent) string {
|
||||||
|
if content == nil || content.StringValue == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *content.StringValue
|
||||||
|
}
|
||||||
|
|
||||||
func normalizeImageURL(raw string) (string, error) {
|
func normalizeImageURL(raw string) (string, error) {
|
||||||
raw = strings.TrimSpace(raw)
|
raw = strings.TrimSpace(raw)
|
||||||
if raw == "" {
|
if raw == "" {
|
||||||
@@ -1651,6 +1486,16 @@ func contains(items []string, target string) bool {
|
|||||||
|
|
||||||
func intPtr(i int) *int { return &i }
|
func intPtr(i int) *int { return &i }
|
||||||
|
|
||||||
|
func boolPtr(v bool) *bool { return &v }
|
||||||
|
|
||||||
|
func truncateString(text string, maxRunes int) string {
|
||||||
|
runes := []rune(strings.TrimSpace(text))
|
||||||
|
if maxRunes <= 0 || len(runes) <= maxRunes {
|
||||||
|
return string(runes)
|
||||||
|
}
|
||||||
|
return string(runes[:maxRunes]) + "..."
|
||||||
|
}
|
||||||
|
|
||||||
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
|
func writeSSEJSON(w io.Writer, frame chatSSEFrame) {
|
||||||
data, err := json.Marshal(frame)
|
data, err := json.Marshal(frame)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
+101
-162
@@ -6,21 +6,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/volcengine/volcengine-go-sdk/service/arkruntime/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type fakeChatTool struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
enabled bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t fakeChatTool) Name() string { return t.name }
|
|
||||||
func (t fakeChatTool) Description() string { return t.description }
|
|
||||||
func (t fakeChatTool) Enabled() bool { return t.enabled }
|
|
||||||
func (t fakeChatTool) Enrich(context.Context, *OpenAIProfile, []ChatMessage, string, func(chatSSEFrame)) ([]ChatMessage, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
|
func TestNormalizeToolRouterConfigDefaults(t *testing.T) {
|
||||||
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
|
cfg := &Config{ToolRouter: ToolRouterConfig{Enabled: true}}
|
||||||
changed, err := normalizeToolRouterConfig(cfg)
|
changed, err := normalizeToolRouterConfig(cfg)
|
||||||
@@ -49,7 +38,7 @@ func TestNormalizeToolRouterConfigAddsTimeBeforeSQL(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Timeout: 1,
|
Timeout: 1,
|
||||||
MaxTokens: 1,
|
MaxTokens: 1,
|
||||||
SystemPrompt: "route",
|
SystemPrompt: "tools",
|
||||||
Tools: []ToolRouteConfig{
|
Tools: []ToolRouteConfig{
|
||||||
{Name: "search", Enabled: true},
|
{Name: "search", Enabled: true},
|
||||||
{Name: "sql", Enabled: true},
|
{Name: "sql", Enabled: true},
|
||||||
@@ -72,7 +61,7 @@ func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
|
|||||||
Enabled: true,
|
Enabled: true,
|
||||||
Timeout: 1,
|
Timeout: 1,
|
||||||
MaxTokens: 1,
|
MaxTokens: 1,
|
||||||
SystemPrompt: "route",
|
SystemPrompt: "tools",
|
||||||
Tools: []ToolRouteConfig{
|
Tools: []ToolRouteConfig{
|
||||||
{Name: "sql", Enabled: true},
|
{Name: "sql", Enabled: true},
|
||||||
{Name: " SQL ", Enabled: true},
|
{Name: " SQL ", Enabled: true},
|
||||||
@@ -84,169 +73,119 @@ func TestNormalizeToolRouterConfigDuplicateTools(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestParseToolRoutingDecision(t *testing.T) {
|
func TestAvailableAgentToolsUsesConfigOrderAndEnabled(t *testing.T) {
|
||||||
decision, err := parseToolRoutingDecision("```json\n{\"tools\":[{\"name\":\" SQL \",\"reason\":\" 需要查库 \"}],\"reason\":\" 总原因 \"}\n```")
|
oldRouter := toolRouterState
|
||||||
if err != nil {
|
oldSearch := searchState
|
||||||
t.Fatal(err)
|
oldSQL := sqlState
|
||||||
}
|
defer func() {
|
||||||
if len(decision.Tools) != 1 || decision.Tools[0].Name != "sql" || decision.Tools[0].Reason != "需要查库" {
|
toolRouterState = oldRouter
|
||||||
t.Fatalf("unexpected decision: %#v", decision)
|
searchState = oldSearch
|
||||||
}
|
sqlState = oldSQL
|
||||||
if decision.Reason != "总原因" {
|
}()
|
||||||
t.Fatalf("reason = %q", decision.Reason)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := parseToolRoutingDecision("not json"); err == nil {
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||||
t.Fatal("expected malformed JSON error")
|
Enabled: true,
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFilterToolSelections(t *testing.T) {
|
|
||||||
tools := map[string]ChatTool{
|
|
||||||
"time": fakeChatTool{name: "time", enabled: true},
|
|
||||||
"sql": fakeChatTool{name: "sql", enabled: true},
|
|
||||||
"search": fakeChatTool{name: "search", enabled: true},
|
|
||||||
}
|
|
||||||
decision := ToolRoutingDecision{Tools: []ToolSelection{
|
|
||||||
{Name: "unknown", Reason: "ignore"},
|
|
||||||
{Name: "search", Reason: "second in config"},
|
|
||||||
{Name: "sql", Reason: "third in config"},
|
|
||||||
{Name: "time", Reason: "first in config"},
|
|
||||||
{Name: "sql", Reason: "duplicate"},
|
|
||||||
}}
|
|
||||||
selected := filterToolSelections(decision, tools, []ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}})
|
|
||||||
if len(selected) != 3 {
|
|
||||||
t.Fatalf("selected length = %d", len(selected))
|
|
||||||
}
|
|
||||||
if selected[0].Name != "time" || selected[0].Reason != "first in config" {
|
|
||||||
t.Fatalf("first selection = %#v", selected[0])
|
|
||||||
}
|
|
||||||
if selected[1].Name != "search" {
|
|
||||||
t.Fatalf("second selection = %#v", selected[1])
|
|
||||||
}
|
|
||||||
if selected[2].Name != "sql" || selected[2].Reason != "third in config" {
|
|
||||||
t.Fatalf("third selection = %#v", selected[2])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnsureTimeSelectionForRelativeSearch(t *testing.T) {
|
|
||||||
tools := map[string]ChatTool{
|
|
||||||
"time": fakeChatTool{name: "time", enabled: true},
|
|
||||||
"search": fakeChatTool{name: "search", enabled: true},
|
|
||||||
}
|
|
||||||
selected := ensureTimeSelectionForRelativeQuery(
|
|
||||||
[]ToolSelection{{Name: "search", Reason: "查询历史事件"}},
|
|
||||||
tools,
|
|
||||||
[]ToolRouteConfig{{Name: "time"}, {Name: "search"}, {Name: "sql"}},
|
|
||||||
"历史上的今天都发生了什么?",
|
|
||||||
)
|
|
||||||
if len(selected) != 2 || selected[0].Name != "time" || selected[1].Name != "search" {
|
|
||||||
t.Fatalf("unexpected selected tools: %#v", selected)
|
|
||||||
}
|
|
||||||
if !strings.Contains(selected[0].Reason, "相对日期") {
|
|
||||||
t.Fatalf("unexpected time reason: %#v", selected[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnsureTimeSelectionSkipsOrdinarySearch(t *testing.T) {
|
|
||||||
tools := map[string]ChatTool{
|
|
||||||
"time": fakeChatTool{name: "time", enabled: true},
|
|
||||||
"search": fakeChatTool{name: "search", enabled: true},
|
|
||||||
}
|
|
||||||
selected := ensureTimeSelectionForRelativeQuery(
|
|
||||||
[]ToolSelection{{Name: "search", Reason: "查询资料"}},
|
|
||||||
tools,
|
|
||||||
[]ToolRouteConfig{{Name: "time"}, {Name: "search"}},
|
|
||||||
"查一下 Go 语言官网",
|
|
||||||
)
|
|
||||||
if len(selected) != 1 || selected[0].Name != "search" {
|
|
||||||
t.Fatalf("unexpected selected tools: %#v", selected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunTimeToolAddsHiddenDateRanges(t *testing.T) {
|
|
||||||
messages := []ChatMessage{{Role: "user", Content: "本月有什么日程安排"}}
|
|
||||||
withTime, err := runTimeTool(context.Background(), messages, "需要日期范围", func(chatSSEFrame) {})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if len(withTime) != 2 || !withTime[0].Hidden || withTime[0].Role != "system" {
|
|
||||||
t.Fatalf("unexpected messages: %#v", withTime)
|
|
||||||
}
|
|
||||||
for _, want := range []string{"时间工具结果", "本月", "start=", "end_exclusive=", "半开区间"} {
|
|
||||||
if !strings.Contains(withTime[0].Content, want) {
|
|
||||||
t.Fatalf("time context missing %q:\n%s", want, withTime[0].Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildToolRouterPrompt(t *testing.T) {
|
|
||||||
cfg := &ToolRouterConfig{
|
|
||||||
SystemPrompt: "router",
|
|
||||||
Tools: []ToolRouteConfig{
|
Tools: []ToolRouteConfig{
|
||||||
{Name: "time", Enabled: true},
|
|
||||||
{Name: "sql", Enabled: true, Description: "configured sql"},
|
|
||||||
{Name: "search", Enabled: true},
|
{Name: "search", Enabled: true},
|
||||||
|
{Name: "time", Enabled: true, Description: "custom time"},
|
||||||
|
{Name: "sql", Enabled: false},
|
||||||
},
|
},
|
||||||
|
}}
|
||||||
|
searchState = nil
|
||||||
|
sqlState = nil
|
||||||
|
|
||||||
|
tools := availableAgentTools(&OpenAIProfile{}, nil)
|
||||||
|
if len(tools) != 1 {
|
||||||
|
t.Fatalf("tools length = %d", len(tools))
|
||||||
}
|
}
|
||||||
tools := map[string]ChatTool{
|
if tools[0].name != "time" {
|
||||||
"time": fakeChatTool{name: "time", description: "fallback time", enabled: true},
|
t.Fatalf("tool name = %s", tools[0].name)
|
||||||
"sql": fakeChatTool{name: "sql", description: "fallback sql", enabled: true},
|
|
||||||
"search": fakeChatTool{name: "search", description: "fallback search", enabled: true},
|
|
||||||
}
|
}
|
||||||
prompt := buildToolRouterPrompt(cfg, []ChatMessage{{Role: "user", Content: "查一下订单"}}, tools)
|
if tools[0].definition.Function == nil || tools[0].definition.Function.Description != "custom time" {
|
||||||
for _, want := range []string{"router", "name: time", "fallback time", "name: sql", "configured sql", "name: search", "fallback search", "最新用户问题:查一下订单"} {
|
t.Fatalf("unexpected definition: %#v", tools[0].definition)
|
||||||
if !strings.Contains(prompt, want) {
|
|
||||||
t.Fatalf("prompt missing %q:\n%s", want, prompt)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteToolsUsesConfiguredRouterProfileAndTimeout(t *testing.T) {
|
func TestRunAgentToolLoopAppendsToolMessages(t *testing.T) {
|
||||||
ai := &OpenAIState{
|
oldRouter := toolRouterState
|
||||||
profiles: map[string]*OpenAIProfile{
|
defer func() { toolRouterState = oldRouter }()
|
||||||
"chat": {Config: OpenAIConfig{Name: "chat"}},
|
|
||||||
"router": {Config: OpenAIConfig{Name: "router"}},
|
calls := 0
|
||||||
},
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||||
order: []string{"chat", "router"},
|
Enabled: true,
|
||||||
activeName: "chat",
|
Timeout: 1,
|
||||||
|
MaxTokens: 128,
|
||||||
|
SystemPrompt: "use tools",
|
||||||
|
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
||||||
|
}}
|
||||||
|
toolRouterState.complete = func(ctx context.Context, profile *OpenAIProfile, req model.CreateChatCompletionRequest, timeout time.Duration) (model.ChatCompletionResponse, error) {
|
||||||
|
calls++
|
||||||
|
if req.ToolChoice != model.ToolChoiceStringTypeAuto {
|
||||||
|
t.Fatalf("tool choice = %#v", req.ToolChoice)
|
||||||
|
}
|
||||||
|
if len(req.Tools) != 1 || req.Tools[0].Function == nil || req.Tools[0].Function.Name != "time" {
|
||||||
|
t.Fatalf("unexpected tools: %#v", req.Tools)
|
||||||
|
}
|
||||||
|
if calls == 1 {
|
||||||
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "call_1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{"reason":"需要当前日期"}`}}}}}}}, nil
|
||||||
|
}
|
||||||
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{Content: stringContent("done")}}}}, nil
|
||||||
}
|
}
|
||||||
state := &ToolRouterState{cfg: &ToolRouterConfig{
|
|
||||||
OpenAIName: "router",
|
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天几号"}}, nil)
|
||||||
Timeout: 7,
|
|
||||||
MaxTokens: 123,
|
|
||||||
SystemPrompt: "router prompt",
|
|
||||||
Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}},
|
|
||||||
}, ai: ai}
|
|
||||||
state.complete = func(ctx context.Context, profile *OpenAIProfile, messages []ChatMessage, maxTokens int, timeout time.Duration) (string, error) {
|
|
||||||
if profile.Config.Name != "router" {
|
|
||||||
t.Fatalf("profile = %s", profile.Config.Name)
|
|
||||||
}
|
|
||||||
if maxTokens != 123 {
|
|
||||||
t.Fatalf("maxTokens = %d", maxTokens)
|
|
||||||
}
|
|
||||||
if timeout != 7*time.Second {
|
|
||||||
t.Fatalf("timeout = %s", timeout)
|
|
||||||
}
|
|
||||||
return `{"tools":[],"reason":"无需工具"}`, nil
|
|
||||||
}
|
|
||||||
decision, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
if len(decision.Tools) != 0 || decision.Reason != "无需工具" {
|
if calls != 2 {
|
||||||
t.Fatalf("unexpected decision: %#v", decision)
|
t.Fatalf("calls = %d", calls)
|
||||||
|
}
|
||||||
|
if len(messages) < 4 {
|
||||||
|
t.Fatalf("expected system/user/assistant/tool messages, got %d", len(messages))
|
||||||
|
}
|
||||||
|
last := messages[len(messages)-1]
|
||||||
|
if last.Role != model.ChatMessageRoleTool || last.ToolCallID != "call_1" {
|
||||||
|
t.Fatalf("unexpected last message: %#v", last)
|
||||||
|
}
|
||||||
|
if last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "时间工具结果") {
|
||||||
|
t.Fatalf("unexpected tool content: %#v", last.Content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouteToolsCompleterError(t *testing.T) {
|
func TestExecuteAgentToolCallUnknownAndError(t *testing.T) {
|
||||||
ai := &OpenAIState{profiles: map[string]*OpenAIProfile{"chat": {Config: OpenAIConfig{Name: "chat"}}}, activeName: "chat"}
|
unknown := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "1", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "missing"}}, map[string]agentTool{}, nil)
|
||||||
state := &ToolRouterState{cfg: &ToolRouterConfig{Timeout: 1, MaxTokens: 1, SystemPrompt: "router", Tools: []ToolRouteConfig{{Name: "sql", Enabled: true}}}, ai: ai}
|
if !strings.Contains(unknown, "未知工具") {
|
||||||
state.complete = func(context.Context, *OpenAIProfile, []ChatMessage, int, time.Duration) (string, error) {
|
t.Fatalf("unknown result = %q", unknown)
|
||||||
return "", errors.New("boom")
|
|
||||||
}
|
}
|
||||||
_, err := routeTools(context.Background(), state, ai.profiles["chat"], []ChatMessage{{Role: "user", Content: "你好"}}, map[string]ChatTool{"sql": fakeChatTool{name: "sql", enabled: true}})
|
|
||||||
if err == nil {
|
failed := executeAgentToolCall(context.Background(), &model.ToolCall{ID: "2", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "boom"}}, map[string]agentTool{
|
||||||
t.Fatal("expected completer error")
|
"boom": {name: "boom", execute: func(context.Context, string) (string, error) { return "", errors.New("bad args") }},
|
||||||
|
}, nil)
|
||||||
|
if !strings.Contains(failed, "bad args") {
|
||||||
|
t.Fatalf("failed result = %q", failed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunAgentToolLoopMaxIterations(t *testing.T) {
|
||||||
|
oldRouter := toolRouterState
|
||||||
|
defer func() { toolRouterState = oldRouter }()
|
||||||
|
|
||||||
|
toolRouterState = &ToolRouterState{cfg: &ToolRouterConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Timeout: 1,
|
||||||
|
MaxTokens: 128,
|
||||||
|
SystemPrompt: "use tools",
|
||||||
|
Tools: []ToolRouteConfig{{Name: "time", Enabled: true}},
|
||||||
|
}}
|
||||||
|
toolRouterState.complete = func(context.Context, *OpenAIProfile, model.CreateChatCompletionRequest, time.Duration) (model.ChatCompletionResponse, error) {
|
||||||
|
return model.ChatCompletionResponse{Choices: []*model.ChatCompletionChoice{{Message: model.ChatCompletionMessage{ToolCalls: []*model.ToolCall{{ID: "loop", Type: model.ToolTypeFunction, Function: model.FunctionCall{Name: "time", Arguments: `{}`}}}}}}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
messages, err := runAgentToolLoop(context.Background(), &OpenAIProfile{Config: OpenAIConfig{Model: "test"}}, []ChatMessage{{Role: "user", Content: "今天"}}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
last := messages[len(messages)-1]
|
||||||
|
if last.Role != model.ChatMessageRoleSystem || last.Content == nil || last.Content.StringValue == nil || !strings.Contains(*last.Content.StringValue, "工具调用轮数已达到上限") {
|
||||||
|
t.Fatalf("unexpected last message: %#v", last)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+52
-4
@@ -265,6 +265,30 @@
|
|||||||
white-space: normal;
|
white-space: normal;
|
||||||
}
|
}
|
||||||
.trace-panel:empty { display: none; }
|
.trace-panel:empty { display: none; }
|
||||||
|
.reasoning-panel {
|
||||||
|
display: none;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
border: 1px solid var(--border);
|
||||||
|
border-radius: 8px;
|
||||||
|
background: var(--surface2);
|
||||||
|
color: var(--text-dim);
|
||||||
|
font-size: 0.78rem;
|
||||||
|
white-space: normal;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.reasoning-panel.show { display: block; }
|
||||||
|
.reasoning-title {
|
||||||
|
padding: 6px 9px;
|
||||||
|
border-bottom: 1px solid var(--border);
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
.reasoning-content {
|
||||||
|
padding: 7px 9px;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
max-height: 220px;
|
||||||
|
overflow-y: auto;
|
||||||
|
font-family: Consolas, 'Fira Code', monospace;
|
||||||
|
}
|
||||||
.trace-item {
|
.trace-item {
|
||||||
border-left: 2px solid var(--accent-border);
|
border-left: 2px solid var(--accent-border);
|
||||||
padding-left: 8px;
|
padding-left: 8px;
|
||||||
@@ -859,18 +883,23 @@ function addAIBubble() {
|
|||||||
const trace = document.createElement('div');
|
const trace = document.createElement('div');
|
||||||
trace.className = 'trace-panel';
|
trace.className = 'trace-panel';
|
||||||
|
|
||||||
|
const reasoning = document.createElement('div');
|
||||||
|
reasoning.className = 'reasoning-panel';
|
||||||
|
reasoning.innerHTML = '<div class="reasoning-title">思考过程(模型返回)</div><div class="reasoning-content"></div>';
|
||||||
|
|
||||||
const txt = document.createElement('span');
|
const txt = document.createElement('span');
|
||||||
txt.className = 'answer-text';
|
txt.className = 'answer-text';
|
||||||
const stats = document.createElement('div');
|
const stats = document.createElement('div');
|
||||||
stats.className = 'token-stats';
|
stats.className = 'token-stats';
|
||||||
bub.appendChild(trace);
|
bub.appendChild(trace);
|
||||||
|
bub.appendChild(reasoning);
|
||||||
bub.appendChild(txt);
|
bub.appendChild(txt);
|
||||||
bub.appendChild(stats);
|
bub.appendChild(stats);
|
||||||
row.appendChild(av);
|
row.appendChild(av);
|
||||||
row.appendChild(bub);
|
row.appendChild(bub);
|
||||||
msgBox.appendChild(row);
|
msgBox.appendChild(row);
|
||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
return { bub, txt, trace, stats };
|
return { bub, txt, trace, reasoning, stats };
|
||||||
}
|
}
|
||||||
|
|
||||||
function formatTokenStats(stats) {
|
function formatTokenStats(stats) {
|
||||||
@@ -903,18 +932,24 @@ function appendTrace(aiBubble, frame) {
|
|||||||
if (!aiBubble.trace) return;
|
if (!aiBubble.trace) return;
|
||||||
const item = document.createElement('div');
|
const item = document.createElement('div');
|
||||||
item.className = `trace-item ${frame.status || ''}`;
|
item.className = `trace-item ${frame.status || ''}`;
|
||||||
|
const prefix = [frame.tool, frame.stage].filter(Boolean).join('/');
|
||||||
const label = frame.message || [frame.tool, frame.stage, frame.status].filter(Boolean).join(' ');
|
const label = frame.message || [frame.tool, frame.stage, frame.status].filter(Boolean).join(' ');
|
||||||
item.textContent = label;
|
item.textContent = prefix ? `${prefix}:${label}` : label;
|
||||||
|
|
||||||
const data = frame.data || {};
|
const data = frame.data || {};
|
||||||
const details = [];
|
const details = [];
|
||||||
if (data.sql) details.push(data.sql);
|
if (data.arguments) details.push(`参数:\n${data.arguments}`);
|
||||||
|
if (data.sql) details.push(`SQL:\n${data.sql}`);
|
||||||
|
if (data.result_preview) details.push(`结果预览:\n${data.result_preview}`);
|
||||||
const stats = [];
|
const stats = [];
|
||||||
|
if (typeof data.iteration === 'number') stats.push(`轮次: ${data.iteration}${data.max_iterations ? '/' + data.max_iterations : ''}`);
|
||||||
|
if (data.tool_call_id) stats.push(`调用 ID: ${data.tool_call_id}`);
|
||||||
if (data.database) stats.push(`数据库: ${data.database}`);
|
if (data.database) stats.push(`数据库: ${data.database}`);
|
||||||
if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`);
|
if (typeof data.rows === 'number') stats.push(`行数: ${data.rows}`);
|
||||||
if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`);
|
if (typeof data.columns === 'number') stats.push(`列数: ${data.columns}`);
|
||||||
if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`);
|
if (typeof data.count === 'number') stats.push(`结果数: ${data.count}`);
|
||||||
if (Array.isArray(data.tools) && data.tools.length) stats.push(`工具: ${data.tools.join(', ')}`);
|
if (Array.isArray(data.tools) && data.tools.length) stats.push(`工具: ${data.tools.join(', ')}`);
|
||||||
|
if (typeof data.duration_ms === 'number') stats.push(`耗时: ${data.duration_ms}ms`);
|
||||||
if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''} 行`);
|
if (data.truncated) stats.push(`已截断,最多 ${data.max_rows || ''} 行`);
|
||||||
if (data.reason) stats.push(`原因: ${data.reason}`);
|
if (data.reason) stats.push(`原因: ${data.reason}`);
|
||||||
if (data.error) stats.push(`错误: ${data.error}`);
|
if (data.error) stats.push(`错误: ${data.error}`);
|
||||||
@@ -923,7 +958,7 @@ function appendTrace(aiBubble, frame) {
|
|||||||
if (details.length) {
|
if (details.length) {
|
||||||
const detail = document.createElement('div');
|
const detail = document.createElement('div');
|
||||||
detail.className = 'trace-detail';
|
detail.className = 'trace-detail';
|
||||||
detail.textContent = details.join('\n');
|
detail.textContent = details.join('\n\n');
|
||||||
item.appendChild(detail);
|
item.appendChild(detail);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -931,6 +966,15 @@ function appendTrace(aiBubble, frame) {
|
|||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function appendReasoning(aiBubble, text) {
|
||||||
|
if (!aiBubble.reasoning || !text) return;
|
||||||
|
aiBubble.reasoning.classList.add('show');
|
||||||
|
const content = aiBubble.reasoning.querySelector('.reasoning-content');
|
||||||
|
content.textContent += text;
|
||||||
|
content.scrollTop = content.scrollHeight;
|
||||||
|
scrollToBottom();
|
||||||
|
}
|
||||||
|
|
||||||
async function streamChat(messages, aiBubble) {
|
async function streamChat(messages, aiBubble) {
|
||||||
const txtEl = aiBubble.txt;
|
const txtEl = aiBubble.txt;
|
||||||
let full = '';
|
let full = '';
|
||||||
@@ -996,6 +1040,10 @@ async function streamChat(messages, aiBubble) {
|
|||||||
scrollToBottom();
|
scrollToBottom();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (parsed.type === 'reasoning') {
|
||||||
|
appendReasoning(aiBubble, parsed.text || '');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (parsed.type === 'trace') {
|
if (parsed.type === 'trace') {
|
||||||
appendTrace(aiBubble, parsed);
|
appendTrace(aiBubble, parsed);
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
Reference in New Issue
Block a user