Refactored orchestrator for staged file handling, added structured prompt support, adjusted Feishu file handling
This commit is contained in:
@@ -3,6 +3,7 @@ package agent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -32,6 +33,29 @@ type Orchestrator struct {
|
||||
enableCapabilityGap bool
|
||||
log *logger.Logger
|
||||
skillsMu sync.RWMutex
|
||||
pendingFilesMu sync.Mutex
|
||||
pendingFiles map[string][]pendingFileRef
|
||||
}
|
||||
|
||||
type pendingFileRef struct {
|
||||
ID string
|
||||
Name string
|
||||
MimeType string
|
||||
}
|
||||
|
||||
type capabilityRoutingResult struct {
|
||||
NeedSkills bool
|
||||
SelectedToolNames []string
|
||||
SelectedSkills []knowledge.Skill
|
||||
Reason string
|
||||
UsedFallback bool
|
||||
}
|
||||
|
||||
type filePromptContext struct {
|
||||
Summary string
|
||||
FatalReason string
|
||||
FileIDs []string
|
||||
Uploaded []pendingFileRef
|
||||
}
|
||||
|
||||
// NewOrchestrator 创建一个新的编排器对象,初始化关键路径和超时控制等。
|
||||
@@ -76,6 +100,7 @@ func NewOrchestrator(
|
||||
reactMaxStep: reactMaxStep,
|
||||
enableCapabilityGap: enableCapabilityGap,
|
||||
log: log,
|
||||
pendingFiles: make(map[string][]pendingFileRef),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,12 +110,20 @@ func NewOrchestrator(
|
||||
// - 是否需要调用工具(action + action_input)
|
||||
// 循环持续进行,直到 LLM 返回 is_final_answer=true。
|
||||
func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text string) (string, error) {
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, nil)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) HandleMessageWithFiles(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, files)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
|
||||
// 为链路追踪设置唯一的 TraceID
|
||||
traceID := logger.NewTraceID()
|
||||
ctx = logger.WithTraceID(ctx, traceID)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s handle message chat_id=%s user_id=%s text_len=%d", traceLogPrefix, chatID, userID, len(text))
|
||||
o.log.Infof("%s handle message chat_id=%s user_id=%s text_len=%d files=%d", traceLogPrefix, chatID, userID, len(text), len(files))
|
||||
o.log.Debugf("%s handle message text=%q", traceLogPrefix, text)
|
||||
}
|
||||
|
||||
@@ -111,6 +144,38 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s
|
||||
return report, nil
|
||||
}
|
||||
|
||||
trimmedText := strings.TrimSpace(text)
|
||||
isFileOnly := len(files) > 0 && trimmedText == ""
|
||||
|
||||
if isFileOnly {
|
||||
if err := o.store.SaveMessage(chatID, userID, "user", "[FILE_UPLOAD]"); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s save file-only user marker failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
uploadCtx := o.prepareFilePromptContext(ctx, files, nil)
|
||||
if strings.TrimSpace(uploadCtx.FatalReason) != "" {
|
||||
finalText := "文件上传失败,无法建立文档上下文。" + "\n" + uploadCtx.FatalReason
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", finalText); err != nil && o.log != nil {
|
||||
o.log.Warnf("%s save upload failure message failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
o.appendPendingFiles(chatID, userID, uploadCtx.toPendingRefs())
|
||||
finalText := o.buildFileUploadAck(uploadCtx)
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", finalText); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s save file upload ack failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s file-only message handled chat_id=%s cached_files=%d", traceLogPrefix, chatID, len(uploadCtx.FileIDs))
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
|
||||
// 保存用户消息到 SQLite 中
|
||||
if err := o.store.SaveMessage(chatID, userID, "user", text); err != nil {
|
||||
if o.log != nil {
|
||||
@@ -133,13 +198,30 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s
|
||||
}
|
||||
|
||||
// 进入统一 ReAct 循环
|
||||
response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text)
|
||||
pendingRefs := o.getPendingFiles(chatID, userID)
|
||||
fileCtx := o.prepareFilePromptContext(ctx, files, pendingRefs)
|
||||
if strings.TrimSpace(fileCtx.FatalReason) != "" {
|
||||
finalText := "文件上传失败,无法继续进行文档解析。" + "\n" + fileCtx.FatalReason
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", finalText); err != nil && o.log != nil {
|
||||
o.log.Warnf("%s save assistant failure message failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s stop before react due to file upload failure reason=%s", traceLogPrefix, fileCtx.FatalReason)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
routeInput := composeRouteInput(text, fileCtx.Summary)
|
||||
route := o.routeCapabilities(ctx, routeInput)
|
||||
response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text, fileCtx, routeInput, route)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s message generation failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if len(pendingRefs) > 0 {
|
||||
o.clearPendingFiles(chatID, userID)
|
||||
}
|
||||
|
||||
// 最终将机器人的回复也加入记忆缓存
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil {
|
||||
@@ -156,21 +238,29 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s
|
||||
}
|
||||
|
||||
// buildUnifiedSystemPrompt 构建统一 ReAct 循环的 system prompt。
|
||||
// 包含人格设定、所有可用技能(含完整内容)、所有可用工具、以及 JSON 输出格式约束。
|
||||
func (o *Orchestrator) buildUnifiedSystemPrompt() string {
|
||||
// 工具始终可用;技能仅按当前问题挑选相关项作为增强上下文。
|
||||
func (o *Orchestrator) buildUnifiedSystemPrompt(userInput string, route capabilityRoutingResult) string {
|
||||
skillMetaDoc := o.formatSkillSummariesForPrompt()
|
||||
allSkillsDoc := o.formatAllSkillsContent()
|
||||
relevantSkillsDoc := o.formatSelectedSkillsForPrompt(userInput, route.SelectedSkills)
|
||||
toolDoc := o.formatToolDoc()
|
||||
runtimeDoc := formatRuntimeContextForPrompt()
|
||||
routeDoc := formatRouteForPrompt(route)
|
||||
|
||||
return strings.Join([]string{
|
||||
"你是一个个人自动化助手,必须遵循如下人格设定并保持一致:",
|
||||
o.soul,
|
||||
"",
|
||||
"===== 运行环境 =====",
|
||||
runtimeDoc,
|
||||
"",
|
||||
"===== 可用技能概览 =====",
|
||||
skillMetaDoc,
|
||||
"",
|
||||
"===== 技能详细说明 =====",
|
||||
allSkillsDoc,
|
||||
"===== 能力路由结果 =====",
|
||||
routeDoc,
|
||||
"",
|
||||
"===== 本轮相关技能(按用户问题筛选) =====",
|
||||
relevantSkillsDoc,
|
||||
"",
|
||||
"===== 可用工具 =====",
|
||||
toolDoc,
|
||||
@@ -190,25 +280,32 @@ func (o *Orchestrator) buildUnifiedSystemPrompt() string {
|
||||
"决策规则:",
|
||||
"1) 如果你可以直接回答用户问题(不需要任何工具):",
|
||||
" 设 is_final_answer=true,action=\"none\",final_answer 填写完整回复。",
|
||||
"2) 如果你需要调用工具获取信息后才能回答:",
|
||||
"2) 优先判断是否可通过原子工具能力完成任务;若可完成,直接进行工具调用链路。",
|
||||
"3) 当纯工具调用无法满足时,再结合已加载的技能详细说明进行决策。",
|
||||
"4) 如果你需要调用工具获取信息后才能回答:",
|
||||
" 设 is_final_answer=false,action 填工具名,action_input 填工具所需输入,final_answer=null。",
|
||||
"3) 不要在 JSON 之外输出任何内容。",
|
||||
"4) 根据技能说明中的指引决定何时以及如何使用工具。",
|
||||
"5) 每轮工具调用结果会以 Observation 的形式追加到推理记录中,供你下一轮决策参考。",
|
||||
"5) 不要在 JSON 之外输出任何内容。",
|
||||
"6) 根据技能说明中的指引决定何时以及如何使用工具。",
|
||||
"7) 工具能力是全局可用的,不依赖技能命中;当技能不匹配时,仍可直接选择合适工具。",
|
||||
"8) 若技能中存在与当前运行环境不匹配的章节(如 Windows 专章),应降低优先级,除非用户明确要求该环境。",
|
||||
"9) 每轮工具调用结果会以 Observation 的形式追加到推理记录中,供你下一轮决策参考。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
// runUnifiedReAct 执行统一的 ReAct 循环。
|
||||
// LLM 每次都看到完整的技能集+工具集,自行决定是否调用工具或直接回答。
|
||||
// 循环持续到 is_final_answer=true 或达到安全上限。
|
||||
func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string) (string, error) {
|
||||
func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, fileCtx filePromptContext, routeInput string, route capabilityRoutingResult) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
|
||||
systemPrompt := o.buildUnifiedSystemPrompt()
|
||||
if strings.TrimSpace(routeInput) == "" {
|
||||
routeInput = composeRouteInput(userInput, fileCtx.Summary)
|
||||
}
|
||||
systemPrompt := o.buildUnifiedSystemPrompt(routeInput, route)
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s unified react start", traceLogPrefix)
|
||||
o.log.Infof("%s unified react start route_need_skills=%v route_tools=%v route_skills=%d fallback=%v", traceLogPrefix, route.NeedSkills, route.SelectedToolNames, len(route.SelectedSkills), route.UsedFallback)
|
||||
}
|
||||
|
||||
// 安全上限:防止无限循环(当前暂不使用 reactMaxStep 配置约束,使用固定硬上限)
|
||||
@@ -229,13 +326,16 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
"用户问题:",
|
||||
userInput,
|
||||
"",
|
||||
"文件上下文:",
|
||||
defaultIfEmpty(fileCtx.Summary, "(none)"),
|
||||
"",
|
||||
"当前推理记录(按时间顺序):",
|
||||
scratchpad,
|
||||
"",
|
||||
"请输出你的 JSON 决策。",
|
||||
}, "\n")
|
||||
|
||||
raw, err := o.llm.Generate(ctx, systemPrompt, prompt)
|
||||
raw, err := o.generateWithOptionalFiles(ctx, systemPrompt, prompt, fileCtx.FileIDs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -335,15 +435,514 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
return "我尝试了多轮推理与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil
|
||||
}
|
||||
|
||||
// formatAllSkillsContent 返回所有技能的完整内容,用于注入到 system prompt 中。
|
||||
func (o *Orchestrator) formatAllSkillsContent() string {
|
||||
skills := o.getSkillsSnapshot()
|
||||
func composeRouteInput(userInput, fileSummary string) string {
|
||||
userInput = strings.TrimSpace(userInput)
|
||||
fileSummary = strings.TrimSpace(fileSummary)
|
||||
if userInput == "" {
|
||||
return fileSummary
|
||||
}
|
||||
if fileSummary == "" {
|
||||
return userInput
|
||||
}
|
||||
return userInput + "\n\n" + fileSummary
|
||||
}
|
||||
|
||||
func (o *Orchestrator) prepareFilePromptContext(ctx context.Context, files []llm.InputFile, pending []pendingFileRef) filePromptContext {
|
||||
ctxOut := filePromptContext{}
|
||||
if len(pending) > 0 {
|
||||
for _, p := range pending {
|
||||
id := strings.TrimSpace(p.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
ctxOut.FileIDs = append(ctxOut.FileIDs, id)
|
||||
}
|
||||
}
|
||||
if len(files) == 0 {
|
||||
ctxOut.Summary = buildFileSummary(pending, nil)
|
||||
return ctxOut
|
||||
}
|
||||
uploader, ok := o.llm.(llm.FileUploader)
|
||||
if !ok {
|
||||
return filePromptContext{FatalReason: "检测到文件输入,但当前 LLM 客户端不支持文件上传接口。"}
|
||||
}
|
||||
|
||||
uploaded := make([]pendingFileRef, 0, len(files))
|
||||
for i, f := range files {
|
||||
if strings.TrimSpace(f.FileName) == "" || len(f.Content) == 0 {
|
||||
return filePromptContext{FatalReason: fmt.Sprintf("file[%d] 缺少文件名或内容,无法上传。", i+1)}
|
||||
}
|
||||
fileID, err := uploader.UploadFile(ctx, f, "file-extract")
|
||||
if err != nil {
|
||||
return filePromptContext{FatalReason: fmt.Sprintf("file[%d] name=%s 上传失败: %v", i+1, f.FileName, err)}
|
||||
}
|
||||
ctxOut.FileIDs = append(ctxOut.FileIDs, fileID)
|
||||
uploaded = append(uploaded, pendingFileRef{
|
||||
ID: fileID,
|
||||
Name: strings.TrimSpace(f.FileName),
|
||||
MimeType: defaultIfEmpty(strings.TrimSpace(f.MimeType), "application/octet-stream"),
|
||||
})
|
||||
}
|
||||
ctxOut.Uploaded = uploaded
|
||||
ctxOut.Summary = buildFileSummary(pending, uploaded)
|
||||
return ctxOut
|
||||
}
|
||||
|
||||
func buildFileSummary(pending, uploaded []pendingFileRef) string {
|
||||
if len(pending) == 0 && len(uploaded) == 0 {
|
||||
return ""
|
||||
}
|
||||
lines := make([]string, 0, len(pending)+len(uploaded)+2)
|
||||
lines = append(lines, "以下文件 file_id 可用于本轮问答:")
|
||||
idx := 1
|
||||
for _, p := range pending {
|
||||
id := strings.TrimSpace(p.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- cached_file[%d] name=%s mime=%s file_id=%s", idx, defaultIfEmpty(strings.TrimSpace(p.Name), "(unknown)"), defaultIfEmpty(strings.TrimSpace(p.MimeType), "application/octet-stream"), id))
|
||||
idx++
|
||||
}
|
||||
for _, p := range uploaded {
|
||||
id := strings.TrimSpace(p.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- uploaded_file[%d] name=%s mime=%s file_id=%s", idx, defaultIfEmpty(strings.TrimSpace(p.Name), "(unknown)"), defaultIfEmpty(strings.TrimSpace(p.MimeType), "application/octet-stream"), id))
|
||||
idx++
|
||||
}
|
||||
if len(lines) == 1 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) generateWithOptionalFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) == 0 {
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
}
|
||||
client, ok := o.llm.(llm.FileChatClient)
|
||||
if !ok {
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
}
|
||||
return client.GenerateWithFiles(ctx, systemPrompt, userPrompt, ids)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildFileUploadAck(ctx filePromptContext) string {
|
||||
if len(ctx.FileIDs) == 0 {
|
||||
return "文件已接收,但未拿到有效 file_id。请重新上传一次。"
|
||||
}
|
||||
lines := []string{
|
||||
fmt.Sprintf("文件上传完成,已缓存 %d 个 file_id。", len(ctx.FileIDs)),
|
||||
"请继续发送你的问题,我会结合这些文件内容和历史对话一起回答。",
|
||||
}
|
||||
if strings.TrimSpace(ctx.Summary) != "" {
|
||||
lines = append(lines, "", ctx.Summary)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func nonEmptyIDs(ids []string) []string {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(ids))
|
||||
seen := map[string]struct{}{}
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c filePromptContext) toPendingRefs() []pendingFileRef {
|
||||
if len(c.Uploaded) > 0 {
|
||||
copied := make([]pendingFileRef, len(c.Uploaded))
|
||||
copy(copied, c.Uploaded)
|
||||
return sanitizePendingRefs(copied)
|
||||
}
|
||||
ids := nonEmptyIDs(c.FileIDs)
|
||||
out := make([]pendingFileRef, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
out = append(out, pendingFileRef{ID: id})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *Orchestrator) appendPendingFiles(chatID, userID string, refs []pendingFileRef) {
|
||||
refs = sanitizePendingRefs(refs)
|
||||
if len(refs) == 0 {
|
||||
return
|
||||
}
|
||||
key := pendingFileKey(chatID, userID)
|
||||
o.pendingFilesMu.Lock()
|
||||
defer o.pendingFilesMu.Unlock()
|
||||
merged := append(o.pendingFiles[key], refs...)
|
||||
o.pendingFiles[key] = sanitizePendingRefs(merged)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) getPendingFiles(chatID, userID string) []pendingFileRef {
|
||||
key := pendingFileKey(chatID, userID)
|
||||
o.pendingFilesMu.Lock()
|
||||
defer o.pendingFilesMu.Unlock()
|
||||
snapshot := o.pendingFiles[key]
|
||||
out := make([]pendingFileRef, len(snapshot))
|
||||
copy(out, snapshot)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *Orchestrator) clearPendingFiles(chatID, userID string) {
|
||||
key := pendingFileKey(chatID, userID)
|
||||
o.pendingFilesMu.Lock()
|
||||
defer o.pendingFilesMu.Unlock()
|
||||
delete(o.pendingFiles, key)
|
||||
}
|
||||
|
||||
func pendingFileKey(chatID, userID string) string {
|
||||
return strings.TrimSpace(chatID) + "::" + strings.TrimSpace(userID)
|
||||
}
|
||||
|
||||
func sanitizePendingRefs(refs []pendingFileRef) []pendingFileRef {
|
||||
if len(refs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]pendingFileRef, 0, len(refs))
|
||||
seen := map[string]struct{}{}
|
||||
for _, r := range refs {
|
||||
id := strings.TrimSpace(r.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
r.ID = id
|
||||
r.Name = strings.TrimSpace(r.Name)
|
||||
r.MimeType = strings.TrimSpace(r.MimeType)
|
||||
out = append(out, r)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func defaultIfEmpty(v, fallback string) string {
|
||||
if strings.TrimSpace(v) == "" {
|
||||
return fallback
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// formatRelevantSkillsForPrompt 返回与当前用户问题最相关的技能内容。
|
||||
func (o *Orchestrator) formatSelectedSkillsForPrompt(userInput string, selected []knowledge.Skill) string {
|
||||
skills := selected
|
||||
if len(skills) == 0 {
|
||||
return "(none)"
|
||||
skills = o.selectRelevantSkills(userInput, 4)
|
||||
}
|
||||
if len(skills) == 0 {
|
||||
return "(none matched, tools are still globally available)"
|
||||
}
|
||||
return formatSkills(skills)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) routeCapabilities(ctx context.Context, userInput string) capabilityRoutingResult {
|
||||
fallback := capabilityRoutingResult{
|
||||
NeedSkills: true,
|
||||
SelectedSkills: o.selectRelevantSkills(userInput, 4),
|
||||
Reason: "router fallback: keyword matching",
|
||||
UsedFallback: true,
|
||||
}
|
||||
|
||||
raw, err := o.llm.Generate(ctx, o.buildRouteSystemPrompt(), o.buildRouteUserPrompt(userInput))
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("capability router llm call failed err=%v", err)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
decision, err := parseCapabilityRoute(raw)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("capability router parse failed err=%v raw=%q", err, raw)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
resolvedTools := o.normalizeToolSelection(decision.SelectedTools)
|
||||
resolved := capabilityRoutingResult{
|
||||
NeedSkills: decision.NeedSkills,
|
||||
SelectedToolNames: resolvedTools,
|
||||
Reason: strings.TrimSpace(decision.Reason),
|
||||
}
|
||||
|
||||
if resolved.NeedSkills {
|
||||
skills := o.resolveSkillsByNames(decision.SelectedSkills, 4)
|
||||
if len(skills) == 0 {
|
||||
skills = o.selectRelevantSkills(userInput, 4)
|
||||
resolved.UsedFallback = true
|
||||
}
|
||||
resolved.SelectedSkills = skills
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildRouteSystemPrompt() string {
|
||||
return strings.Join([]string{
|
||||
"你是能力路由器(Router Agent)。",
|
||||
"你的任务是:在不加载技能全文的前提下,仅根据工具摘要和技能摘要,判断本请求是否可以仅靠原子工具能力完成,还是需要加载技能详细说明。",
|
||||
"输出必须且仅能是 JSON:",
|
||||
"{",
|
||||
" \"need_skills\": true 或 false,",
|
||||
" \"selected_tools\": [\"tool_name\", ...],",
|
||||
" \"selected_skills\": [\"skill_name\", ...],",
|
||||
" \"reason\": \"简短路由理由\"",
|
||||
"}",
|
||||
"规则:",
|
||||
"1) 优先原子工具能力。若可通过工具链路完成,need_skills=false。",
|
||||
"2) 只有当工具能力不足以覆盖业务约束时,need_skills=true 并选择少量最相关技能。",
|
||||
"3) selected_skills 仅填写技能名称(来自技能摘要)。",
|
||||
"4) selected_tools 仅填写可用工具名。",
|
||||
"5) 不要输出 JSON 之外内容。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildRouteUserPrompt(userInput string) string {
|
||||
return strings.Join([]string{
|
||||
"当前运行环境:",
|
||||
formatRuntimeContextForPrompt(),
|
||||
"",
|
||||
"用户问题:",
|
||||
userInput,
|
||||
"",
|
||||
"可用工具摘要:",
|
||||
o.formatToolDoc(),
|
||||
"",
|
||||
"可用技能摘要:",
|
||||
o.formatSkillSummariesForPrompt(),
|
||||
"",
|
||||
"请给出路由 JSON。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) normalizeToolSelection(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := map[string]struct{}{}
|
||||
for _, t := range o.tools.List() {
|
||||
allowed[strings.ToLower(strings.TrimSpace(t.Name()))] = struct{}{}
|
||||
}
|
||||
out := make([]string, 0, len(in))
|
||||
set := map[string]struct{}{}
|
||||
for _, name := range in {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := allowed[n]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := set[n]; exists {
|
||||
continue
|
||||
}
|
||||
set[n] = struct{}{}
|
||||
out = append(out, n)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *Orchestrator) resolveSkillsByNames(names []string, maxCount int) []knowledge.Skill {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
if maxCount <= 0 {
|
||||
maxCount = 4
|
||||
}
|
||||
all := o.getSkillsSnapshot()
|
||||
idx := make(map[string]knowledge.Skill, len(all))
|
||||
for _, sk := range all {
|
||||
key := strings.ToLower(strings.TrimSpace(sk.Name))
|
||||
if key != "" {
|
||||
idx[key] = sk
|
||||
}
|
||||
}
|
||||
out := make([]knowledge.Skill, 0, maxCount)
|
||||
used := map[string]struct{}{}
|
||||
for _, name := range names {
|
||||
key := strings.ToLower(strings.TrimSpace(name))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
sk, ok := idx[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := used[key]; exists {
|
||||
continue
|
||||
}
|
||||
used[key] = struct{}{}
|
||||
out = append(out, sk)
|
||||
if len(out) >= maxCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatRouteForPrompt(route capabilityRoutingResult) string {
|
||||
b := strings.Builder{}
|
||||
if route.UsedFallback {
|
||||
b.WriteString("router_status: fallback\n")
|
||||
} else {
|
||||
b.WriteString("router_status: ok\n")
|
||||
}
|
||||
b.WriteString("need_skills: ")
|
||||
b.WriteString(strconv.FormatBool(route.NeedSkills))
|
||||
b.WriteString("\n")
|
||||
b.WriteString("selected_tools: ")
|
||||
if len(route.SelectedToolNames) == 0 {
|
||||
b.WriteString("(none)")
|
||||
} else {
|
||||
b.WriteString(strings.Join(route.SelectedToolNames, ", "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
b.WriteString("selected_skill_count: ")
|
||||
b.WriteString(strconv.Itoa(len(route.SelectedSkills)))
|
||||
b.WriteString("\n")
|
||||
if strings.TrimSpace(route.Reason) != "" {
|
||||
b.WriteString("reason: ")
|
||||
b.WriteString(strings.TrimSpace(route.Reason))
|
||||
}
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
func (o *Orchestrator) selectRelevantSkills(userInput string, maxCount int) []knowledge.Skill {
|
||||
if maxCount <= 0 {
|
||||
maxCount = 4
|
||||
}
|
||||
query := strings.TrimSpace(strings.ToLower(userInput))
|
||||
all := o.getSkillsSnapshot()
|
||||
if query == "" || len(all) <= maxCount {
|
||||
return all
|
||||
}
|
||||
|
||||
queryTokens := buildQueryTokens(query)
|
||||
type item struct {
|
||||
skill knowledge.Skill
|
||||
score int
|
||||
}
|
||||
ranked := make([]item, 0, len(all))
|
||||
|
||||
for _, sk := range all {
|
||||
hay := strings.ToLower(sk.Name + "\n" + clipForScoring(sk.Content, 1800))
|
||||
score := 0
|
||||
if strings.Contains(hay, query) {
|
||||
score += 8
|
||||
}
|
||||
for _, tk := range queryTokens {
|
||||
if strings.Contains(hay, tk) {
|
||||
score++
|
||||
}
|
||||
}
|
||||
if score == 0 {
|
||||
continue
|
||||
}
|
||||
ranked = append(ranked, item{skill: sk, score: score})
|
||||
}
|
||||
|
||||
if len(ranked) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(ranked, func(i, j int) bool {
|
||||
if ranked[i].score == ranked[j].score {
|
||||
return strings.ToLower(strings.TrimSpace(ranked[i].skill.Name)) < strings.ToLower(strings.TrimSpace(ranked[j].skill.Name))
|
||||
}
|
||||
return ranked[i].score > ranked[j].score
|
||||
})
|
||||
|
||||
if len(ranked) > maxCount {
|
||||
ranked = ranked[:maxCount]
|
||||
}
|
||||
out := make([]knowledge.Skill, 0, len(ranked))
|
||||
for _, r := range ranked {
|
||||
out = append(out, r.skill)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildQueryTokens(query string) []string {
|
||||
set := map[string]struct{}{}
|
||||
collectToken := func(t string) {
|
||||
t = strings.TrimSpace(t)
|
||||
if len([]rune(t)) < 2 {
|
||||
return
|
||||
}
|
||||
set[t] = struct{}{}
|
||||
}
|
||||
|
||||
for _, part := range strings.FieldsFunc(query, func(r rune) bool {
|
||||
if r >= 'a' && r <= 'z' {
|
||||
return false
|
||||
}
|
||||
if r >= '0' && r <= '9' {
|
||||
return false
|
||||
}
|
||||
if r >= 0x4e00 && r <= 0x9fff {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}) {
|
||||
collectToken(part)
|
||||
}
|
||||
|
||||
// 针对中文无空格输入,补充 2-gram 提升匹配命中率。
|
||||
runes := []rune(query)
|
||||
for i := 0; i+1 < len(runes); i++ {
|
||||
r1 := runes[i]
|
||||
r2 := runes[i+1]
|
||||
if (r1 >= 0x4e00 && r1 <= 0x9fff) && (r2 >= 0x4e00 && r2 <= 0x9fff) {
|
||||
collectToken(string([]rune{r1, r2}))
|
||||
}
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(set))
|
||||
for tk := range set {
|
||||
out = append(out, tk)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func clipForScoring(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
maxRunes = 1800
|
||||
}
|
||||
r := []rune(s)
|
||||
if len(r) <= maxRunes {
|
||||
return s
|
||||
}
|
||||
return string(r[:maxRunes])
|
||||
}
|
||||
|
||||
func formatRuntimeContextForPrompt() string {
|
||||
goos := strings.TrimSpace(strings.ToLower(runtime.GOOS))
|
||||
if goos == "" {
|
||||
goos = "unknown"
|
||||
}
|
||||
return "当前运行系统 GOOS=" + goos + "。请优先使用与该系统一致的策略。仅当用户明确要求时,才采用其他系统(如 Windows)的专用流程。"
|
||||
}
|
||||
|
||||
// emitCapabilityGap 处理能力缺口信息埋点或者通过 AI 自动创建生成相应缺失技能的逻辑
|
||||
func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) {
|
||||
if !o.enableCapabilityGap {
|
||||
|
||||
48
internal/agent/orchestrator_skill_selection_test.go
Normal file
48
internal/agent/orchestrator_skill_selection_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"laodingbot/internal/knowledge"
|
||||
)
|
||||
|
||||
func TestBuildQueryTokensIncludesChineseBigrams(t *testing.T) {
|
||||
tokens := buildQueryTokens("请执行命令并查看文件")
|
||||
joined := strings.Join(tokens, ",")
|
||||
if !strings.Contains(joined, "命令") {
|
||||
t.Fatalf("expected token contains 命令, got: %v", tokens)
|
||||
}
|
||||
if !strings.Contains(joined, "文件") {
|
||||
t.Fatalf("expected token contains 文件, got: %v", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectRelevantSkillsPrefersMatchingSkill(t *testing.T) {
|
||||
o := &Orchestrator{
|
||||
skills: []knowledge.Skill{
|
||||
{Name: "文件系统查询专家", Content: "适用于目录、文件、路径、命令执行等场景"},
|
||||
{Name: "天气查询", Content: "用于天气和空气质量查询"},
|
||||
{Name: "日程助手", Content: "用于日程管理"},
|
||||
},
|
||||
}
|
||||
|
||||
selected := o.selectRelevantSkills("帮我执行命令查看某个文件", 2)
|
||||
if len(selected) == 0 {
|
||||
t.Fatal("expected non-empty selected skills")
|
||||
}
|
||||
if selected[0].Name != "文件系统查询专家" {
|
||||
t.Fatalf("expected top skill 文件系统查询专家, got: %s", selected[0].Name)
|
||||
}
|
||||
if len(selected) > 2 {
|
||||
t.Fatalf("expected at most 2 skills, got: %d", len(selected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatRuntimeContextForPromptIncludesGOOS(t *testing.T) {
|
||||
doc := formatRuntimeContextForPrompt()
|
||||
if !strings.Contains(strings.ToLower(doc), strings.ToLower(runtime.GOOS)) {
|
||||
t.Fatalf("expected runtime context contains GOOS=%s, got: %s", runtime.GOOS, doc)
|
||||
}
|
||||
}
|
||||
31
internal/agent/router_parser.go
Normal file
31
internal/agent/router_parser.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type capabilityRouteDecision struct {
|
||||
NeedSkills bool `json:"need_skills"`
|
||||
SelectedTools []string `json:"selected_tools"`
|
||||
SelectedSkills []string `json:"selected_skills"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
func parseCapabilityRoute(raw string) (capabilityRouteDecision, error) {
|
||||
raw = normalizeJSON(raw)
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start < 0 || end < start {
|
||||
return capabilityRouteDecision{}, fmt.Errorf("no json object found")
|
||||
}
|
||||
raw = raw[start : end+1]
|
||||
|
||||
var out capabilityRouteDecision
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
return capabilityRouteDecision{}, err
|
||||
}
|
||||
out.Reason = strings.TrimSpace(out.Reason)
|
||||
return out, nil
|
||||
}
|
||||
34
internal/agent/router_parser_test.go
Normal file
34
internal/agent/router_parser_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseCapabilityRoute(t *testing.T) {
|
||||
raw := `{"need_skills":true,"selected_tools":["shell"],"selected_skills":["文件系统查询专家"],"reason":"需要技能约束"}`
|
||||
out, err := parseCapabilityRoute(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parseCapabilityRoute error: %v", err)
|
||||
}
|
||||
if !out.NeedSkills {
|
||||
t.Fatal("expected need_skills=true")
|
||||
}
|
||||
if len(out.SelectedTools) != 1 || out.SelectedTools[0] != "shell" {
|
||||
t.Fatalf("unexpected selected_tools: %#v", out.SelectedTools)
|
||||
}
|
||||
if len(out.SelectedSkills) != 1 || out.SelectedSkills[0] != "文件系统查询专家" {
|
||||
t.Fatalf("unexpected selected_skills: %#v", out.SelectedSkills)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCapabilityRouteCodeFence(t *testing.T) {
|
||||
raw := "```json\n{\"need_skills\":false,\"selected_tools\":[\"file\",\"shell\"],\"selected_skills\":[],\"reason\":\"工具足够\"}\n```"
|
||||
out, err := parseCapabilityRoute(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parseCapabilityRoute error: %v", err)
|
||||
}
|
||||
if out.NeedSkills {
|
||||
t.Fatal("expected need_skills=false")
|
||||
}
|
||||
if len(out.SelectedTools) != 2 {
|
||||
t.Fatalf("unexpected selected_tools len: %d", len(out.SelectedTools))
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -18,6 +19,20 @@ type Client interface {
|
||||
Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error)
|
||||
}
|
||||
|
||||
type FileChatClient interface {
|
||||
GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error)
|
||||
}
|
||||
|
||||
type FileUploader interface {
|
||||
UploadFile(ctx context.Context, file InputFile, purpose string) (string, error)
|
||||
}
|
||||
|
||||
type InputFile struct {
|
||||
FileName string
|
||||
MimeType string
|
||||
Content []byte
|
||||
}
|
||||
|
||||
type OpenAICompatibleClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
@@ -43,27 +58,64 @@ type chatRequest struct {
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type chatContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
FileID string `json:"file_id,omitempty"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Choices []struct {
|
||||
Message chatMessage `json:"message"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type fileUploadResponse struct {
|
||||
ID string `json:"id"`
|
||||
Bytes int64 `json:"bytes,omitempty"`
|
||||
CreatedAt int64 `json:"created_at,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Purpose string `json:"purpose,omitempty"`
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Status any `json:"status,omitempty"`
|
||||
StatusDetails any `json:"status_details,omitempty"`
|
||||
Data *struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
return c.generateInternal(ctx, systemPrompt, userPrompt, nil)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
|
||||
return c.generateInternal(ctx, systemPrompt, userPrompt, fileIDs)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) generateInternal(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
|
||||
if c.log != nil {
|
||||
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d", c.model, len(systemPrompt), len(userPrompt))
|
||||
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d", c.model, len(systemPrompt), len(userPrompt), len(fileIDs))
|
||||
}
|
||||
userContent := buildUserContent(userPrompt, fileIDs)
|
||||
body := chatRequest{
|
||||
Model: c.model,
|
||||
Messages: []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
{Role: "user", Content: userContent},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
@@ -132,3 +184,144 @@ func (c *OpenAICompatibleClient) Generate(ctx context.Context, systemPrompt, use
|
||||
|
||||
return out.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func buildUserContent(userPrompt string, fileIDs []string) any {
|
||||
trimmedPrompt := strings.TrimSpace(userPrompt)
|
||||
if len(fileIDs) == 0 {
|
||||
return userPrompt
|
||||
}
|
||||
|
||||
parts := make([]chatContentPart, 0, len(fileIDs)+1)
|
||||
if trimmedPrompt != "" {
|
||||
parts = append(parts, chatContentPart{Type: "text", Text: userPrompt})
|
||||
}
|
||||
for _, id := range fileIDs {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, chatContentPart{Type: "file", FileID: id})
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return userPrompt
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile, purpose string) (string, error) {
|
||||
if strings.TrimSpace(file.FileName) == "" {
|
||||
return "", fmt.Errorf("empty file name")
|
||||
}
|
||||
if len(file.Content) == 0 {
|
||||
return "", fmt.Errorf("empty file content")
|
||||
}
|
||||
purpose = strings.TrimSpace(purpose)
|
||||
purposes := []string{}
|
||||
if purpose != "" {
|
||||
purposes = append(purposes, purpose)
|
||||
}
|
||||
// Provider compatibility fallback order.
|
||||
purposes = appendIfMissing(purposes, "file-extract")
|
||||
purposes = appendIfMissing(purposes, "batch")
|
||||
|
||||
var lastErr error
|
||||
for _, p := range purposes {
|
||||
fileID, err := c.uploadFileOnce(ctx, file, p)
|
||||
if err == nil {
|
||||
return fileID, nil
|
||||
}
|
||||
lastErr = err
|
||||
if c.log != nil {
|
||||
c.log.Warnf("llm file upload failed purpose=%s err=%v", p, err)
|
||||
}
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = fmt.Errorf("llm file upload failed: no purpose tried")
|
||||
}
|
||||
return "", lastErr
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) uploadFileOnce(ctx context.Context, file InputFile, purpose string) (string, error) {
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
if err := writer.WriteField("purpose", purpose); err != nil {
|
||||
return "", err
|
||||
}
|
||||
part, err := writer.CreateFormFile("file", file.FileName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(file.Content); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url := strings.TrimRight(c.baseURL, "/") + "/files"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var out fileUploadResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return "", fmt.Errorf("llm file upload response decode failed: %w body=%s", err, clipForError(raw))
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
if strings.TrimSpace(out.Message) != "" {
|
||||
return "", fmt.Errorf("llm file upload error: %s", out.Message)
|
||||
}
|
||||
if out.Error != nil && out.Error.Message != "" {
|
||||
return "", fmt.Errorf("llm file upload error: %s", out.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("llm file upload status: %d body=%s", resp.StatusCode, clipForError(raw))
|
||||
}
|
||||
fileID := strings.TrimSpace(out.ID)
|
||||
if fileID == "" && out.Data != nil {
|
||||
fileID = strings.TrimSpace(out.Data.ID)
|
||||
}
|
||||
if fileID == "" {
|
||||
return "", fmt.Errorf("llm file upload returned empty file id body=%s", clipForError(raw))
|
||||
}
|
||||
if c.log != nil {
|
||||
c.log.Infof("llm file uploaded name=%s size=%d file_id=%s purpose=%s status=%v", file.FileName, len(file.Content), fileID, purpose, out.Status)
|
||||
}
|
||||
return fileID, nil
|
||||
}
|
||||
|
||||
func clipForError(raw []byte) string {
|
||||
s := strings.TrimSpace(string(raw))
|
||||
const max = 400
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max] + "...(truncated)"
|
||||
}
|
||||
|
||||
func appendIfMissing(items []string, value string) []string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return items
|
||||
}
|
||||
for _, it := range items {
|
||||
if strings.EqualFold(strings.TrimSpace(it), value) {
|
||||
return items
|
||||
}
|
||||
}
|
||||
return append(items, value)
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"laodingbot/internal/logger"
|
||||
"laodingbot/internal/tools"
|
||||
"laodingbot/tools/fileoperation"
|
||||
"laodingbot/tools/git"
|
||||
"laodingbot/tools/shell"
|
||||
"laodingbot/tools/websearch"
|
||||
)
|
||||
@@ -16,6 +17,7 @@ import (
|
||||
func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error {
|
||||
var registryLog *logger.Logger
|
||||
var fileLog *logger.Logger
|
||||
var gitLog *logger.Logger
|
||||
var shellLog *logger.Logger
|
||||
var searchLog *logger.Logger
|
||||
var serverLog *logger.Logger
|
||||
@@ -23,12 +25,19 @@ func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error
|
||||
log.Infof("toolhost child starting")
|
||||
registryLog = log.WithComponent("toolhost.registry")
|
||||
fileLog = log.WithComponent("toolhost.file")
|
||||
gitLog = log.WithComponent("toolhost.git")
|
||||
shellLog = log.WithComponent("toolhost.shell")
|
||||
searchLog = log.WithComponent("toolhost.websearch")
|
||||
serverLog = log.WithComponent("toolhost.server")
|
||||
}
|
||||
registry := tools.NewRegistry(registryLog)
|
||||
registry.Register(fileoperation.New(cfg.Security.AllowedDirs, cfg.ToolOutputMaxChars, fileLog))
|
||||
registry.Register(git.New(
|
||||
cfg.Security.WorkDir,
|
||||
time.Duration(cfg.ToolCallTimeoutSec)*time.Second,
|
||||
cfg.ToolOutputMaxChars,
|
||||
gitLog,
|
||||
))
|
||||
registry.Register(shell.New(
|
||||
cfg.Security.AllowedCommands,
|
||||
cfg.Security.WorkDir,
|
||||
|
||||
@@ -4,6 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -34,9 +38,17 @@ type IncomingMessage struct {
|
||||
MessageID string
|
||||
ChatID string
|
||||
UserID string
|
||||
MsgType string
|
||||
Text string
|
||||
FileName string
|
||||
FileKey string
|
||||
FileMime string
|
||||
FileBytes []byte
|
||||
FilePath string
|
||||
}
|
||||
|
||||
const maxFeishuFileBytes = 20 * 1024 * 1024
|
||||
|
||||
func NewBot(appID, appSecret, verifyToken, _ string, _ string, log *logger.Logger) (*Bot, error) {
|
||||
if appID == "" || appSecret == "" {
|
||||
return nil, fmt.Errorf("empty feishu app credentials")
|
||||
@@ -66,7 +78,7 @@ func (b *Bot) Run(ctx context.Context, handler func(context.Context, IncomingMes
|
||||
incoming, ok := parseIncoming(event)
|
||||
if !ok {
|
||||
if b.log != nil {
|
||||
b.log.Debugf("skip non-text or invalid feishu event")
|
||||
b.log.Debugf("skip unsupported or invalid feishu event")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -76,8 +88,11 @@ func (b *Bot) Run(ctx context.Context, handler func(context.Context, IncomingMes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if incoming.MsgType == "file" {
|
||||
b.enrichFileIncoming(evtCtx, &incoming)
|
||||
}
|
||||
if b.log != nil {
|
||||
b.log.Infof("feishu message received message_id=%s chat_id=%s user_id=%s text=%s", incoming.MessageID, incoming.ChatID, incoming.UserID, incoming.Text)
|
||||
b.log.Infof("feishu message received message_id=%s chat_id=%s user_id=%s msg_type=%s text=%s", incoming.MessageID, incoming.ChatID, incoming.UserID, incoming.MsgType, incoming.Text)
|
||||
}
|
||||
reply, err := handler(evtCtx, incoming)
|
||||
if err != nil {
|
||||
@@ -159,6 +174,175 @@ func extractText(content string) (string, error) {
|
||||
return parsed.Text, nil
|
||||
}
|
||||
|
||||
func extractFileMeta(content string) (fileName string, fileKey string, err error) {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal([]byte(content), &parsed); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
readString := func(keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if v, ok := parsed[key]; ok {
|
||||
s, ok := v.(string)
|
||||
if ok {
|
||||
trimmed := strings.TrimSpace(s)
|
||||
if trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
fileName = readString("file_name", "fileName", "name", "filename")
|
||||
fileKey = readString("file_key", "fileKey", "key")
|
||||
return fileName, fileKey, nil
|
||||
}
|
||||
|
||||
func buildFileRecognitionText(fileName, fileKey string) string {
|
||||
if strings.TrimSpace(fileName) == "" {
|
||||
fileName = "(unknown)"
|
||||
}
|
||||
if strings.TrimSpace(fileKey) == "" {
|
||||
fileKey = "(unknown)"
|
||||
}
|
||||
|
||||
return strings.Join([]string{
|
||||
"用户发送了一条飞书文件消息。",
|
||||
"文件名: " + fileName,
|
||||
"文件Key: " + fileKey,
|
||||
"系统将先上传该文件到 LLM Provider,再由模型完成文档解析。若上传失败,本次请求将直接中止。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (b *Bot) enrichFileIncoming(ctx context.Context, incoming *IncomingMessage) {
|
||||
if incoming == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(incoming.MessageID) == "" || strings.TrimSpace(incoming.FileKey) == "" {
|
||||
incoming.Text = buildFileRecognitionText(incoming.FileName, incoming.FileKey)
|
||||
incoming.Text += "\n\n未找到完整 file_key 或 message_id,暂时无法下载文件内容。"
|
||||
return
|
||||
}
|
||||
|
||||
content, fileName, err := b.downloadFileContent(ctx, incoming.MessageID, incoming.FileKey)
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
b.log.Warnf("feishu download file content failed message_id=%s file_key=%s err=%v", incoming.MessageID, incoming.FileKey, err)
|
||||
}
|
||||
incoming.Text = buildFileRecognitionText(incoming.FileName, incoming.FileKey)
|
||||
incoming.Text += "\n\n文件下载失败: " + err.Error()
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(fileName) != "" {
|
||||
incoming.FileName = fileName
|
||||
}
|
||||
incoming.FileBytes = content
|
||||
incoming.FileMime = detectMimeByName(incoming.FileName)
|
||||
localPath, saveErr := saveIncomingFile("files", incoming.FileName, incoming.FileBytes)
|
||||
if saveErr != nil {
|
||||
if b.log != nil {
|
||||
b.log.Warnf("save incoming feishu file failed name=%s err=%v", incoming.FileName, saveErr)
|
||||
}
|
||||
incoming.Text = buildFileRecognitionText(incoming.FileName, incoming.FileKey)
|
||||
incoming.Text += "\n\n文件已下载但本地保存失败: " + saveErr.Error()
|
||||
return
|
||||
}
|
||||
incoming.FilePath = localPath
|
||||
incoming.Text = buildFileRecognitionText(incoming.FileName, incoming.FileKey)
|
||||
incoming.Text += fmt.Sprintf("\n\n文件已下载并保存到本地,路径=%s,大小=%d bytes,mime=%s。", incoming.FilePath, len(content), incoming.FileMime)
|
||||
}
|
||||
|
||||
func (b *Bot) downloadFileContent(ctx context.Context, messageID, fileKey string) ([]byte, string, error) {
|
||||
req := larkim.NewGetMessageResourceReqBuilder().
|
||||
MessageId(messageID).
|
||||
FileKey(fileKey).
|
||||
Type("file").
|
||||
Build()
|
||||
|
||||
resp, err := b.apiClient.Im.MessageResource.Get(ctx, req)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if resp == nil || resp.File == nil {
|
||||
if resp != nil {
|
||||
return nil, "", fmt.Errorf("empty file stream code=%d msg=%s", resp.Code, resp.Msg)
|
||||
}
|
||||
return nil, "", fmt.Errorf("empty file stream")
|
||||
}
|
||||
|
||||
bts, err := io.ReadAll(resp.File)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
if len(bts) > maxFeishuFileBytes {
|
||||
return nil, "", fmt.Errorf("file too large: %d bytes, max=%d", len(bts), maxFeishuFileBytes)
|
||||
}
|
||||
return bts, strings.TrimSpace(resp.FileName), nil
|
||||
}
|
||||
|
||||
func detectMimeByName(fileName string) string {
|
||||
ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName)))
|
||||
if ext == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
m := strings.TrimSpace(mime.TypeByExtension(ext))
|
||||
if m == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func saveIncomingFile(baseDir, fileName string, content []byte) (string, error) {
|
||||
if len(content) == 0 {
|
||||
return "", fmt.Errorf("empty file content")
|
||||
}
|
||||
if strings.TrimSpace(baseDir) == "" {
|
||||
baseDir = "files"
|
||||
}
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
safeName := sanitizeFileName(fileName)
|
||||
if safeName == "" {
|
||||
safeName = "upload.bin"
|
||||
}
|
||||
finalName := fmt.Sprintf("%d_%s", time.Now().UnixNano(), safeName)
|
||||
target := filepath.Join(baseDir, finalName)
|
||||
if err := os.WriteFile(target, content, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
abs, err := filepath.Abs(target)
|
||||
if err != nil {
|
||||
return target, nil
|
||||
}
|
||||
return abs, nil
|
||||
}
|
||||
|
||||
func sanitizeFileName(fileName string) string {
|
||||
name := strings.TrimSpace(filepath.Base(fileName))
|
||||
if name == "" || name == "." || name == ".." {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
|
||||
b.WriteRune(r)
|
||||
continue
|
||||
}
|
||||
b.WriteByte('_')
|
||||
}
|
||||
out := strings.TrimSpace(b.String())
|
||||
if out == "" || out == "." || out == ".." {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(out, ".") {
|
||||
out = "file" + out
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseIncoming(event *larkim.P2MessageReceiveV1) (IncomingMessage, bool) {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
|
||||
return IncomingMessage{}, false
|
||||
@@ -168,12 +352,11 @@ func parseIncoming(event *larkim.P2MessageReceiveV1) (IncomingMessage, bool) {
|
||||
}
|
||||
|
||||
msg := event.Event.Message
|
||||
if msg.MessageType == nil || *msg.MessageType != "text" || msg.ChatId == nil || msg.Content == nil || msg.MessageId == nil {
|
||||
if msg.MessageType == nil || msg.ChatId == nil || msg.Content == nil || msg.MessageId == nil {
|
||||
return IncomingMessage{}, false
|
||||
}
|
||||
|
||||
text, err := extractText(*msg.Content)
|
||||
if err != nil {
|
||||
msgType := strings.TrimSpace(*msg.MessageType)
|
||||
if msgType == "" {
|
||||
return IncomingMessage{}, false
|
||||
}
|
||||
|
||||
@@ -186,12 +369,33 @@ func parseIncoming(event *larkim.P2MessageReceiveV1) (IncomingMessage, bool) {
|
||||
userID = *event.Event.Sender.SenderId.UnionId
|
||||
}
|
||||
|
||||
return IncomingMessage{
|
||||
incoming := IncomingMessage{
|
||||
MessageID: *msg.MessageId,
|
||||
ChatID: *msg.ChatId,
|
||||
UserID: userID,
|
||||
Text: text,
|
||||
}, true
|
||||
MsgType: msgType,
|
||||
}
|
||||
|
||||
switch msgType {
|
||||
case "text":
|
||||
text, err := extractText(*msg.Content)
|
||||
if err != nil {
|
||||
return IncomingMessage{}, false
|
||||
}
|
||||
incoming.Text = text
|
||||
return incoming, true
|
||||
case "file":
|
||||
fileName, fileKey, err := extractFileMeta(*msg.Content)
|
||||
if err != nil {
|
||||
return IncomingMessage{}, false
|
||||
}
|
||||
incoming.FileName = fileName
|
||||
incoming.FileKey = fileKey
|
||||
incoming.Text = buildFileRecognitionText(fileName, fileKey)
|
||||
return incoming, true
|
||||
default:
|
||||
return IncomingMessage{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) shouldProcessMessage(messageID string) bool {
|
||||
|
||||
138
internal/transport/feishu/bot_test.go
Normal file
138
internal/transport/feishu/bot_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package feishu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1"
|
||||
)
|
||||
|
||||
func mustEventFromJSON(t *testing.T, raw string) *larkim.P2MessageReceiveV1 {
|
||||
t.Helper()
|
||||
var evt larkim.P2MessageReceiveV1
|
||||
if err := json.Unmarshal([]byte(raw), &evt); err != nil {
|
||||
t.Fatalf("unmarshal event json failed: %v", err)
|
||||
}
|
||||
return &evt
|
||||
}
|
||||
|
||||
func TestParseIncomingText(t *testing.T) {
|
||||
evt := mustEventFromJSON(t, `{
|
||||
"event": {
|
||||
"message": {
|
||||
"message_id": "msg_text_1",
|
||||
"chat_id": "chat_1",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"你好\"}"
|
||||
},
|
||||
"sender": {
|
||||
"sender_type": "user",
|
||||
"sender_id": {"open_id": "u_open_1"}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
in, ok := parseIncoming(evt)
|
||||
if !ok {
|
||||
t.Fatal("expected text message parse success")
|
||||
}
|
||||
if in.MsgType != "text" {
|
||||
t.Fatalf("expected msg type text, got %s", in.MsgType)
|
||||
}
|
||||
if in.Text != "你好" {
|
||||
t.Fatalf("unexpected text: %q", in.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIncomingFile(t *testing.T) {
|
||||
evt := mustEventFromJSON(t, `{
|
||||
"event": {
|
||||
"message": {
|
||||
"message_id": "msg_file_1",
|
||||
"chat_id": "chat_1",
|
||||
"message_type": "file",
|
||||
"content": "{\"file_key\":\"file_key_123\",\"file_name\":\"report.pdf\"}"
|
||||
},
|
||||
"sender": {
|
||||
"sender_type": "user",
|
||||
"sender_id": {"user_id": "u_id_1"}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
in, ok := parseIncoming(evt)
|
||||
if !ok {
|
||||
t.Fatal("expected file message parse success")
|
||||
}
|
||||
if in.MsgType != "file" {
|
||||
t.Fatalf("expected msg type file, got %s", in.MsgType)
|
||||
}
|
||||
if in.FileName != "report.pdf" || in.FileKey != "file_key_123" {
|
||||
t.Fatalf("unexpected file meta: name=%q key=%q", in.FileName, in.FileKey)
|
||||
}
|
||||
if !strings.Contains(in.Text, "飞书文件消息") {
|
||||
t.Fatalf("expected synthesized text mentions file message, got: %q", in.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseIncomingUnsupportedType(t *testing.T) {
|
||||
evt := mustEventFromJSON(t, `{
|
||||
"event": {
|
||||
"message": {
|
||||
"message_id": "msg_image_1",
|
||||
"chat_id": "chat_1",
|
||||
"message_type": "image",
|
||||
"content": "{\"image_key\":\"img_1\"}"
|
||||
},
|
||||
"sender": {
|
||||
"sender_type": "user",
|
||||
"sender_id": {"open_id": "u_open_1"}
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
_, ok := parseIncoming(evt)
|
||||
if ok {
|
||||
t.Fatal("expected unsupported message type rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectMimeByName(t *testing.T) {
|
||||
if got := detectMimeByName("report.pdf"); !strings.Contains(got, "pdf") {
|
||||
t.Fatalf("expected pdf mime, got: %s", got)
|
||||
}
|
||||
if got := detectMimeByName("unknown.custom"); got != "application/octet-stream" {
|
||||
t.Fatalf("expected octet-stream fallback, got: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveIncomingFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path, err := saveIncomingFile(dir, "report.pdf", []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatalf("saveIncomingFile error: %v", err)
|
||||
}
|
||||
b, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("read saved file failed: %v", err)
|
||||
}
|
||||
if string(b) != "hello" {
|
||||
t.Fatalf("unexpected saved content: %q", string(b))
|
||||
}
|
||||
if filepath.Ext(path) != ".pdf" {
|
||||
t.Fatalf("expected .pdf extension, got: %s", path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFileName(t *testing.T) {
|
||||
got := sanitizeFileName("../bad path/测 试?.pdf")
|
||||
if strings.Contains(got, "/") || strings.Contains(got, "\\") {
|
||||
t.Fatalf("expected sanitized basename only, got: %q", got)
|
||||
}
|
||||
if !strings.HasSuffix(got, ".pdf") {
|
||||
t.Fatalf("expected .pdf suffix, got: %q", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user