diff --git a/.vscode/launch.json b/.vscode/launch.json index c5d955a..c995be1 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -33,6 +33,13 @@ "env": { "MESSAGE_CHANNEL": "feishu" } + }, + { + "name": "Attach: ToolHost Process", + "type": "go", + "request": "attach", + "mode": "local", + "processId": "${command:pickProcess}" } ] } \ No newline at end of file diff --git a/cmd/bot/main.go b/cmd/bot/main.go index 8ed8513..5757967 100644 --- a/cmd/bot/main.go +++ b/cmd/bot/main.go @@ -21,19 +21,26 @@ import ( "laodingbot/internal/transport/telegram" ) +// main 是程序的入口点。它负责初始化环境、加载配置、注册工具并启动消息通道。 func main() { + // 设置优雅监听上下文,接收中断和终止信号 ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() + + // 检查是否作为 Toolhost 的子进程运行 isToolhostChild := len(os.Args) > 1 && os.Args[1] == "--toolhost" workspaceRoot, err := runtimews.PrepareFromEnv() if err != nil { panic(fmt.Sprintf("prepare runtime workspace failed: %v", err)) } + // 加载应用配置 cfg, err := config.Load() if err != nil { panic(fmt.Sprintf("load config failed: %v", err)) } + + // 如果是作为子进程运行,则启动工具宿主端 if isToolhostChild { if err := toolhost.RunChild(ctx, cfg, nil); err != nil && ctx.Err() == nil { panic(fmt.Sprintf("toolhost child failed: %v", err)) @@ -41,6 +48,7 @@ func main() { return } + // 初始化日志系统 appLogger, err := logger.New(cfg.LogLevel) if err != nil { panic(fmt.Sprintf("init logger failed: %v", err)) @@ -48,6 +56,7 @@ func main() { appLogger = appLogger.WithComponent("main") appLogger.Infof("config loaded; channel=%s, log_level=%s workspace=%s", cfg.MessageChannel, cfg.LogLevel, workspaceRoot) + // 初始化 SQLite 数据库存储层(例如记忆存储等) store, err := memory.NewSQLiteStore(cfg.SQLitePath, appLogger.WithComponent("memory")) if err != nil { appLogger.Errorf("init memory store failed: %v", err) @@ -55,12 +64,15 @@ func main() { } defer store.Close() + // 注册内部系统工具 toolRegistry := tools.NewRegistry(appLogger.WithComponent("tools.registry")) exePath, err := os.Executable() if err != nil { appLogger.Errorf("resolve executable path failed: %v", err) panic(err) } + + // 初始化工具宿主客户端,以便运行独立进程内的工具 tc, err := toolhost.NewClient(toolhost.ClientConfig{ ExecutablePath: exePath, Args: []string{"--toolhost"}, @@ -75,6 +87,7 @@ func main() { } defer tc.Close() + // 获取支持的工具列表并将其注册 listCtx, cancel := context.WithTimeout(ctx, 10*time.Second) toolInfos, err := tc.ToolList(listCtx) cancel() @@ -89,25 +102,37 @@ func main() { toolRegistry.Register(toolhost.NewRemoteTool(info.Name, info.Description, time.Duration(cfg.ToolCallTimeoutSec)*time.Second, tc)) } + // 加载 AI 角色的基础信息 (Soul) soul, err := knowledge.LoadSoul(cfg.SoulPath) if err != nil { appLogger.Errorf("load soul failed path=%s err=%v", cfg.SoulPath, err) panic(err) } + // 加载所有可用技能 skillSet, err := knowledge.LoadSkillSet(cfg.SkillsDir) if err != nil { appLogger.Errorf("load skill set failed dir=%s err=%v", cfg.SkillsDir, err) panic(err) } + // 加载技能总结,用于后续路由和匹配 + skillSummaries, err := knowledge.LoadSkillSummaries(cfg.SkillsDir) + if err != nil { + appLogger.Errorf("load skill summaries failed dir=%s err=%v", cfg.SkillsDir, err) + panic(err) + } appLogger.Infof("knowledge loaded soul_path=%s skills_dir=%s", cfg.SoulPath, cfg.SkillsDir) + // 实例化 LLM 客户端 llmClient := llm.NewOpenAICompatibleClient(cfg.LLM, appLogger.WithComponent("llm")) + + // 创建编排器,整合 LLM、记忆系统、知识技能库与各种工具 engine := agent.NewOrchestrator( llmClient, store, toolRegistry, soul, skillSet, + skillSummaries, cfg.SkillsDir, cfg.ReactMaxSteps, cfg.EnableCapabilityGap, @@ -118,6 +143,7 @@ func main() { ) appLogger.Infof("LaodingBot started, channel=%s", cfg.MessageChannel) + // 根据配置启动对应的信息通道 if err := runMessageChannel(ctx, cfg, engine, appLogger); err != nil && ctx.Err() == nil { appLogger.Errorf("message channel run failed: %v", err) panic(err) @@ -125,6 +151,7 @@ func main() { appLogger.Infof("LaodingBot stopped") } +// runMessageChannel 负责初始化并运行配置指定的消息通道(如 telegram 或 feishu)。 func runMessageChannel(ctx context.Context, cfg config.Config, engine *agent.Orchestrator, lg *logger.Logger) error { switch cfg.MessageChannel { case "telegram": diff --git a/internal/agent/orchestrator.go b/internal/agent/orchestrator.go index 92f8ef3..7b7d19b 100644 --- a/internal/agent/orchestrator.go +++ b/internal/agent/orchestrator.go @@ -2,7 +2,6 @@ package agent import ( "context" - "encoding/json" "fmt" "sort" "strconv" @@ -17,28 +16,32 @@ import ( "laodingbot/internal/tools" ) +// Orchestrator 负责协调和组合业务逻辑,包含 LLM 计算、上下文管理、技能匹配计算和工具调用。 type Orchestrator struct { - llm llm.Client - store *memory.SQLiteStore - tools *tools.Registry - soul string - skills []knowledge.Skill - skillsDir string - autoSkillDir string + llm llm.Client + store *memory.SQLiteStore + tools *tools.Registry + soul string + skills []knowledge.Skill + skillSummaries []knowledge.SkillSummary + skillsDir string + autoSkillDir string gapDraftTriggerCount int gapLookbackDuration time.Duration - reactMaxStep int - enableCapabilityGap bool - log *logger.Logger - skillsMu sync.RWMutex + reactMaxStep int + enableCapabilityGap bool + log *logger.Logger + skillsMu sync.RWMutex } +// NewOrchestrator 创建一个新的编排器对象,初始化关键路径和超时控制等。 func NewOrchestrator( llmClient llm.Client, store *memory.SQLiteStore, registry *tools.Registry, soul string, skills []knowledge.Skill, + skillSummaries []knowledge.SkillSummary, skillsDir string, reactMaxStep int, enableCapabilityGap bool, @@ -48,34 +51,41 @@ func NewOrchestrator( log *logger.Logger, ) *Orchestrator { if reactMaxStep <= 0 { - reactMaxStep = 4 + reactMaxStep = 8 // 默认最大 ReAct 步骤数为 8 } if gapDraftTriggerCount <= 0 { - gapDraftTriggerCount = 3 + gapDraftTriggerCount = 3 // 默认触发技能生成的缺口数量为 3 } if gapLookbackDuration <= 0 { - gapLookbackDuration = 7 * 24 * time.Hour + gapLookbackDuration = 7 * 24 * time.Hour // 默认回溯时长为 7 天 } if strings.TrimSpace(autoSkillDir) == "" { autoSkillDir = skillsDir } return &Orchestrator{ - llm: llmClient, - store: store, - tools: registry, - soul: soul, - skills: skills, - skillsDir: skillsDir, - autoSkillDir: autoSkillDir, + llm: llmClient, + store: store, + tools: registry, + soul: soul, + skills: skills, + skillSummaries: copySkillSummaries(skillSummaries), + skillsDir: skillsDir, + autoSkillDir: autoSkillDir, gapDraftTriggerCount: gapDraftTriggerCount, gapLookbackDuration: gapLookbackDuration, - reactMaxStep: reactMaxStep, - enableCapabilityGap: enableCapabilityGap, - log: log, + reactMaxStep: reactMaxStep, + enableCapabilityGap: enableCapabilityGap, + log: log, } } +// HandleMessage 是接受用户消息输入并通过统一 ReAct 循环生成回复的主流程。 +// 不再分"先选 skill 再决策"两步,而是 LLM 第一次调用就同时决定: +// - 是否可以直接回答(is_final_answer=true) +// - 是否需要调用工具(action + action_input) +// 循环持续进行,直到 LLM 返回 is_final_answer=true。 func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text string) (string, error) { + // 为链路追踪设置唯一的 TraceID traceID := logger.NewTraceID() ctx = logger.WithTraceID(ctx, traceID) traceLogPrefix := "trace_id=" + traceID @@ -83,12 +93,16 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s o.log.Infof("%s handle message chat_id=%s user_id=%s text_len=%d", traceLogPrefix, chatID, userID, len(text)) o.log.Debugf("%s handle message text=%q", traceLogPrefix, text) } + + // 处理特殊的重载指令 if strings.EqualFold(strings.TrimSpace(text), "/reload_skills") { if err := o.ReloadSkills(); err != nil { return "技能热加载失败: " + err.Error(), nil } return "技能已热加载完成。", nil } + + // 如果用户请求能力缺口报告,则生成报告格式化输出 if strings.EqualFold(strings.TrimSpace(text), "/capability_gaps") { report, err := o.BuildCapabilityGapReport(10) if err != nil { @@ -96,6 +110,8 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s } return report, nil } + + // 保存用户消息到 SQLite 中 if err := o.store.SaveMessage(chatID, userID, "user", text); err != nil { if o.log != nil { o.log.Errorf("%s save user message failed chat_id=%s err=%v", traceLogPrefix, chatID, err) @@ -103,6 +119,7 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s return "", err } + // 读取最近的会话记忆并压缩成 Prompt 上下文 recent, err := o.store.LoadRecent(chatID, 16) if err != nil { if o.log != nil { @@ -115,35 +132,8 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s o.log.Debugf("%s prompt context prepared chat_id=%s recent_count=%d compressed_len=%d", traceLogPrefix, chatID, len(recent), len(compressed)) } - matchedSkills := o.matchSkills(ctx, compressed, text) - if len(matchedSkills) == 0 { - if bootstrap, ok := o.findSkillByKeyword("创建skill", "skill builder", "skill 创建", "构建技能"); ok { - matchedSkills = []knowledge.Skill{bootstrap} - if o.log != nil { - o.log.Infof("%s fallback bootstrap skill selected name=%s", traceLogPrefix, bootstrap.Name) - } - } - } - - var response string - if len(matchedSkills) == 0 { - if o.log != nil { - o.log.Infof("%s no skill matched; use direct llm chat_id=%s", traceLogPrefix, chatID) - } - o.emitCapabilityGap(chatID, userID, text, "no_skill_matched") - response, err = o.runDirectLLM(ctx, compressed, text) - } else { - if o.log != nil { - names := make([]string, 0, len(matchedSkills)) - for _, s := range matchedSkills { - names = append(names, s.Name) - o.log.Infof("%s skill selected name=%s source=%s", traceLogPrefix, s.Name, s.Source) - o.log.Debugf("%s skill selected content name=%s content=%q", traceLogPrefix, s.Name, s.Content) - } - o.log.Infof("%s skills matched chat_id=%s skills=%s", traceLogPrefix, chatID, strings.Join(names, ",")) - } - response, err = o.runReAct(ctx, chatID, userID, compressed, text, matchedSkills) - } + // 进入统一 ReAct 循环 + response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text) if err != nil { if o.log != nil { o.log.Errorf("%s message generation failed chat_id=%s err=%v", traceLogPrefix, chatID, err) @@ -151,86 +141,87 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s return "", err } + // 最终将机器人的回复也加入记忆缓存 if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil { if o.log != nil { o.log.Errorf("%s save assistant response failed chat_id=%s err=%v", traceLogPrefix, chatID, err) } return "", err } + if o.log != nil { o.log.Infof("%s message handled chat_id=%s response_len=%d", traceLogPrefix, chatID, len(response)) } return response, nil } -func (o *Orchestrator) runDirectLLM(ctx context.Context, compressedContext, userInput string) (string, error) { - systemPrompt := strings.Join([]string{ - "你是一个个人自动化助手,必须遵循如下人格设定并保持一致:", - o.soul, - "", - "如果当前问题没有匹配到已定义技能,请直接回答用户。", - "当你判断必须依赖外部工具结果才能可靠回答时,请明确告知用户需要进一步操作信息。", - }, "\n") - - userPrompt := strings.Join([]string{ - "历史上下文:", - compressedContext, - "", - "用户问题:", - userInput, - }, "\n") - - return o.llm.Generate(ctx, systemPrompt, userPrompt) -} - -type reactDecision struct { - Thought string `json:"thought"` - Action string `json:"action"` - ActionInput string `json:"action_input"` - Final string `json:"final"` -} - -func (o *Orchestrator) runReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, selectedSkills []knowledge.Skill) (string, error) { - traceID := logger.TraceIDFromContext(ctx) - traceLogPrefix := "trace_id=" + traceID - selectedSkillsDoc := formatSkills(selectedSkills) +// buildUnifiedSystemPrompt 构建统一 ReAct 循环的 system prompt。 +// 包含人格设定、所有可用技能(含完整内容)、所有可用工具、以及 JSON 输出格式约束。 +func (o *Orchestrator) buildUnifiedSystemPrompt() string { + skillMetaDoc := o.formatSkillSummariesForPrompt() + allSkillsDoc := o.formatAllSkillsContent() toolDoc := o.formatToolDoc() - if o.log != nil { - names := make([]string, 0, len(selectedSkills)) - for _, s := range selectedSkills { - names = append(names, s.Name) - } - o.log.Infof("%s react start steps=%d skills=%s", traceLogPrefix, o.reactMaxStep, strings.Join(names, ",")) - o.log.Debugf("%s react selected_skills_doc=%q", traceLogPrefix, selectedSkillsDoc) - o.log.Debugf("%s react tools_doc=%q", traceLogPrefix, toolDoc) - } - systemPrompt := strings.Join([]string{ + return strings.Join([]string{ "你是一个个人自动化助手,必须遵循如下人格设定并保持一致:", o.soul, "", - "已匹配到的 skills(只可按下列技能执行):", - selectedSkillsDoc, + "===== 可用技能概览 =====", + skillMetaDoc, "", - "可用工具:", + "===== 技能详细说明 =====", + allSkillsDoc, + "", + "===== 可用工具 =====", toolDoc, "", - "你必须使用 ReAct 模式做决策。", - "只有当技能明确需要工具能力时才调用工具。", - "如果问题可直接回答,不要调用工具。", - "你的输出必须是 JSON,对象字段为 thought, action, action_input, final。", - "规则:", - "1) 当需要调工具时:final 置空,action 必须是可用工具之一,action_input 为工具输入。", - "2) 当可以最终回答时:action 置 none,action_input 置空,final 填最终回复。", - "3) 不要输出 JSON 之外内容。", + "===== 输出格式约束 =====", + "你必须使用 ReAct(Reasoning + Acting)模式进行决策。", + "每次回复必须是且仅是一个 JSON 对象,字段如下:", + "", + "{", + " \"thought\": \"你的推理过程(必填)\",", + " \"action\": \"要调用的工具名称,如 file/shell/web_search(不调工具时填 none)\",", + " \"action_input\": \"传给工具的输入(字符串或对象),不调工具时填空字符串或 null\",", + " \"is_final_answer\": true 或 false,", + " \"final_answer\": \"当 is_final_answer=true 时填写给用户的最终回复,否则填 null\"", + "}", + "", + "决策规则:", + "1) 如果你可以直接回答用户问题(不需要任何工具):", + " 设 is_final_answer=true,action=\"none\",final_answer 填写完整回复。", + "2) 如果你需要调用工具获取信息后才能回答:", + " 设 is_final_answer=false,action 填工具名,action_input 填工具所需输入,final_answer=null。", + "3) 不要在 JSON 之外输出任何内容。", + "4) 根据技能说明中的指引决定何时以及如何使用工具。", + "5) 每轮工具调用结果会以 Observation 的形式追加到推理记录中,供你下一轮决策参考。", }, "\n") +} +// runUnifiedReAct 执行统一的 ReAct 循环。 +// LLM 每次都看到完整的技能集+工具集,自行决定是否调用工具或直接回答。 +// 循环持续到 is_final_answer=true 或达到安全上限。 +func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string) (string, error) { + traceID := logger.TraceIDFromContext(ctx) + traceLogPrefix := "trace_id=" + traceID + + systemPrompt := o.buildUnifiedSystemPrompt() + + if o.log != nil { + o.log.Infof("%s unified react start", traceLogPrefix) + } + + // 安全上限:防止无限循环(当前暂不使用 reactMaxStep 配置约束,使用固定硬上限) + const maxSteps = 20 scratchpad := "" - for step := 1; step <= o.reactMaxStep; step++ { + + for step := 1; step <= maxSteps; step++ { if o.log != nil { - o.log.Infof("%s react step start step=%d/%d", traceLogPrefix, step, o.reactMaxStep) - o.log.Debugf("%s react scratchpad_before step=%d content=%q", traceLogPrefix, step, scratchpad) + o.log.Infof("%s react step=%d start", traceLogPrefix, step) + o.log.Debugf("%s react step=%d scratchpad=%q", traceLogPrefix, step, scratchpad) } + + // 构造本轮 user prompt:历史上下文 + 用户问题 + 推理记录 prompt := strings.Join([]string{ "历史上下文:", compressedContext, @@ -241,7 +232,7 @@ func (o *Orchestrator) runReAct(ctx context.Context, chatID, userID, compressedC "当前推理记录(按时间顺序):", scratchpad, "", - fmt.Sprintf("请输出下一步 JSON 决策。当前步骤: %d/%d", step, o.reactMaxStep), + "请输出你的 JSON 决策。", }, "\n") raw, err := o.llm.Generate(ctx, systemPrompt, prompt) @@ -249,51 +240,72 @@ func (o *Orchestrator) runReAct(ctx context.Context, chatID, userID, compressedC return "", err } if o.log != nil { - o.log.Infof("%s react step llm output step=%d raw=%q", traceLogPrefix, step, raw) + o.log.Infof("%s react step=%d llm_raw=%q", traceLogPrefix, step, raw) } + + // 解析 LLM 返回的 JSON 决策 decision, err := parseDecision(raw) if err != nil { if o.log != nil { - o.log.Warnf("%s react parse failed, fallback to direct llm err=%v", traceLogPrefix, err) + o.log.Warnf("%s react step=%d parse failed err=%v, using raw as final answer", traceLogPrefix, step, err) } + // 解析失败时,尝试将原始输出当作直接回答返回 o.emitCapabilityGap(chatID, userID, userInput, "react_parse_failed") - return o.runDirectLLM(ctx, compressedContext, userInput) + return strings.TrimSpace(raw), nil } + if o.log != nil { - o.log.Infof("%s react step decision step=%d thought=%q action=%q action_input=%q final=%q", traceLogPrefix, step, decision.Thought, decision.Action, decision.ActionInput, decision.Final) + o.log.Infof("%s react step=%d thought=%q action=%q is_final=%v", + traceLogPrefix, step, decision.Thought, decision.Action, decision.IsFinalAnswer) } - action := strings.ToLower(strings.TrimSpace(decision.Action)) - if action == "" { - action = "none" - } - - if action == "none" { - finalText := strings.TrimSpace(decision.Final) + // ========== 判定:是否为最终回答 ========== + if decision.IsFinalAnswer { + finalText := "" + if decision.FinalAnswer != nil { + finalText = strings.TrimSpace(*decision.FinalAnswer) + } if finalText == "" { - finalText = "我已完成思考,但当前没有足够信息给出稳定结论。" + finalText = strings.TrimSpace(decision.Thought) + } + if finalText == "" { + finalText = "已完成处理。" } if o.log != nil { - o.log.Infof("%s react final step=%d final=%q", traceLogPrefix, step, finalText) + o.log.Infof("%s react final at step=%d answer=%q", traceLogPrefix, step, finalText) } return finalText, nil } + // ========== 非最终回答:执行工具调用 ========== + action := strings.ToLower(strings.TrimSpace(decision.Action)) + if action == "" || action == "none" { + // LLM 说不是最终回答但也不指定工具,记录后让它再想一轮 + scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n" + scratchpad += "Step " + strconv.Itoa(step) + " Observation: 你没有指定要调用的工具,请重新决策:要么调用工具,要么给出最终回答。\n" + continue + } + + actionInput := decision.GetActionInputString() + + // 检查工具是否存在 tool, ok := o.tools.Get(action) if !ok { if o.log != nil { - o.log.Warnf("%s react step tool missing step=%d tool=%s", traceLogPrefix, step, action) + o.log.Warnf("%s react step=%d tool_not_found=%s", traceLogPrefix, step, action) } scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n" - scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + formatToolErrorObservation("TOOL_NOT_FOUND", action, "tool not found") + "\n" + scratchpad += "Step " + strconv.Itoa(step) + " Action: " + action + "\n" + scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + formatToolErrorObservation("TOOL_NOT_FOUND", action, "该工具不存在,可用工具请参阅 system prompt") + "\n" o.emitCapabilityGap(chatID, userID, userInput, "tool_not_found:"+action) continue } - toolOut, toolErr := tool.Call(ctx, decision.ActionInput) + // 调用工具 if o.log != nil { - o.log.Infof("%s react step tool call step=%d tool=%s input=%q", traceLogPrefix, step, action, decision.ActionInput) + o.log.Infof("%s react step=%d tool_call tool=%s input=%q", traceLogPrefix, step, action, actionInput) } + toolOut, toolErr := tool.Call(ctx, actionInput) obs := strings.TrimSpace(toolOut) if obs == "" { obs = "(empty output)" @@ -302,103 +314,37 @@ func (o *Orchestrator) runReAct(ctx context.Context, chatID, userID, compressedC obs = formatToolErrorObservation("TOOL_EXEC_ERROR", action, toolErr.Error()) + "\nOUTPUT:\n" + obs o.emitCapabilityGap(chatID, userID, userInput, "tool_call_failed:"+action) } + // 限制观察值长度防止超出 LLM 上下文窗口 + if len(obs) > 4000 { + obs = obs[:4000] + "\n...(truncated)" + } + if o.log != nil { - o.log.Infof("%s react step observation step=%d tool=%s observation=%q", traceLogPrefix, step, action, obs) - } - if len(obs) > 2000 { - obs = obs[:2000] + o.log.Infof("%s react step=%d observation_len=%d", traceLogPrefix, step, len(obs)) } + + // 将本轮的思考、行动、观察追加到 scratchpad scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n" scratchpad += "Step " + strconv.Itoa(step) + " Action: " + action + "\n" - scratchpad += "Step " + strconv.Itoa(step) + " ActionInput: " + decision.ActionInput + "\n" + scratchpad += "Step " + strconv.Itoa(step) + " ActionInput: " + actionInput + "\n" scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + obs + "\n" } + // 达到安全上限仍未得到最终回答 o.emitCapabilityGap(chatID, userID, userInput, "react_step_exhausted") - return "我尝试了多轮思考与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil + return "我尝试了多轮推理与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil } -func (o *Orchestrator) matchSkills(ctx context.Context, compressedContext, userInput string) []knowledge.Skill { +// formatAllSkillsContent 返回所有技能的完整内容,用于注入到 system prompt 中。 +func (o *Orchestrator) formatAllSkillsContent() string { skills := o.getSkillsSnapshot() if len(skills) == 0 { - return nil + return "(none)" } - - type skillChoice struct { - Skills []string `json:"skills"` - } - - systemPrompt := strings.Join([]string{ - "你是技能路由器。", - "任务:根据用户问题,从候选技能中选择 0-2 个最相关技能名称。", - "输出必须是 JSON:{\"skills\":[\"name1\",\"name2\"]}", - "如果没有匹配技能,返回 {\"skills\":[]}。", - "不要输出 JSON 之外内容。", - }, "\n") - - userPrompt := strings.Join([]string{ - "候选技能:", - formatSkillCatalog(skills), - "", - "历史上下文:", - compressedContext, - "", - "用户问题:", - userInput, - }, "\n") - - raw, err := o.llm.Generate(ctx, systemPrompt, userPrompt) - if err != nil { - if o.log != nil { - o.log.Warnf("skill match llm failed err=%v", err) - } - return nil - } - if o.log != nil { - o.log.Infof("skill router output raw=%q", raw) - } - - raw = normalizeJSON(raw) - choice := skillChoice{} - if err := json.Unmarshal([]byte(raw), &choice); err != nil { - if o.log != nil { - o.log.Warnf("skill match parse failed err=%v", err) - } - return nil - } - - picked := make([]knowledge.Skill, 0, 2) - seen := map[string]struct{}{} - for _, name := range choice.Skills { - name = strings.TrimSpace(strings.ToLower(name)) - if name == "" { - continue - } - if _, ok := seen[name]; ok { - continue - } - for _, skill := range skills { - if strings.ToLower(strings.TrimSpace(skill.Name)) == name { - picked = append(picked, skill) - seen[name] = struct{}{} - break - } - } - if len(picked) >= 2 { - break - } - } - if o.log != nil { - names := make([]string, 0, len(picked)) - for _, s := range picked { - names = append(names, s.Name) - } - o.log.Infof("skill router selected skills=%s", strings.Join(names, ",")) - } - - return picked + return formatSkills(skills) } +// emitCapabilityGap 处理能力缺口信息埋点或者通过 AI 自动创建生成相应缺失技能的逻辑 func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) { if !o.enableCapabilityGap { return @@ -409,16 +355,17 @@ func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) return } if len(intent) > 1000 { - intent = intent[:1000] + intent = intent[:1000] // 防止恶意使用超长 payload } if len(reason) > 240 { - reason = reason[:240] + reason = reason[:240] // 保证状态长度在 DB 内正常可用 } if err := o.store.SaveCapabilityGap(chatID, userID, intent, reason); err != nil && o.log != nil { o.log.Warnf("save capability gap failed chat_id=%s user_id=%s err=%v", chatID, userID, err) return } + // 提取出高频率缺口并在超出阈值后进行 draft 生成 clusters, err := o.store.TopCapabilityGapClusters(20, time.Now().UTC().Add(-o.gapLookbackDuration)) if err != nil { if o.log != nil { @@ -430,6 +377,7 @@ func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) if c.Count < o.gapDraftTriggerCount { continue } + path, created, draftErr := knowledge.GenerateSkillDraft(c, o.autoSkillDir) if draftErr != nil { if o.log != nil { @@ -440,6 +388,7 @@ func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) if created && o.log != nil { o.log.Infof("capability gap draft generated path=%s intent_key=%s reason=%s count=%d", path, c.IntentKey, c.Reason, c.Count) } + // 如果生成了新技能则将它们重新加载进环境 if created { if reloadErr := o.ReloadSkills(); reloadErr != nil && o.log != nil { o.log.Warnf("auto reload skills failed after generation path=%s err=%v", path, reloadErr) @@ -448,13 +397,20 @@ func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) } } +// ReloadSkills 会从提供的技能目录动态从最新存储位置载入所有技能定义而不重启系统。 func (o *Orchestrator) ReloadSkills() error { skills, err := knowledge.LoadSkillSet(o.skillsDir) if err != nil { return err } + summaries, err := knowledge.LoadSkillSummaries(o.skillsDir) + if err != nil { + return err + } + // 利用 RWMutex 做热更新保护 o.skillsMu.Lock() o.skills = skills + o.skillSummaries = copySkillSummaries(summaries) o.skillsMu.Unlock() if o.log != nil { o.log.Infof("skills hot reloaded count=%d dir=%s", len(skills), o.skillsDir) @@ -470,6 +426,13 @@ func (o *Orchestrator) getSkillsSnapshot() []knowledge.Skill { return out } +func (o *Orchestrator) getSkillSummariesSnapshot() []knowledge.SkillSummary { + o.skillsMu.RLock() + defer o.skillsMu.RUnlock() + return copySkillSummaries(o.skillSummaries) +} + +// BuildCapabilityGapReport 生成指定数量以内的近期高频缺失功能报错并格式化成报表。 func (o *Orchestrator) BuildCapabilityGapReport(limit int) (string, error) { clusters, err := o.store.TopCapabilityGapClusters(limit, time.Now().UTC().Add(-o.gapLookbackDuration)) if err != nil { @@ -490,25 +453,55 @@ func (o *Orchestrator) BuildCapabilityGapReport(limit int) (string, error) { return b.String(), nil } -func (o *Orchestrator) findSkillByKeyword(keywords ...string) (knowledge.Skill, bool) { - if len(keywords) == 0 { - return knowledge.Skill{}, false +func (o *Orchestrator) formatSkillSummariesForPrompt() string { + summaries := o.getSkillSummariesSnapshot() + if len(summaries) == 0 { + return "(none)" } - skills := o.getSkillsSnapshot() - for _, s := range skills { - name := strings.ToLower(strings.TrimSpace(s.Name)) - content := strings.ToLower(strings.TrimSpace(s.Content)) - for _, kw := range keywords { - kw = strings.ToLower(strings.TrimSpace(kw)) - if kw == "" { - continue - } - if strings.Contains(name, kw) || strings.Contains(content, kw) { - return s, true - } + sort.Slice(summaries, func(i, j int) bool { + left := strings.ToLower(strings.TrimSpace(summaries[i].DirName)) + right := strings.ToLower(strings.TrimSpace(summaries[j].DirName)) + if left == right { + return strings.ToLower(strings.TrimSpace(summaries[i].Name)) < strings.ToLower(strings.TrimSpace(summaries[j].Name)) } + return left < right + }) + b := strings.Builder{} + for _, summary := range summaries { + dir := strings.TrimSpace(summary.DirName) + name := strings.TrimSpace(summary.Name) + desc := strings.TrimSpace(summary.Description) + if name == "" { + continue + } + if len(desc) > 220 { + desc = desc[:220] + } + b.WriteString("- ") + if dir != "" { + b.WriteString("[") + b.WriteString(dir) + b.WriteString("] ") + } + b.WriteString(name) + if desc != "" { + b.WriteString(" => ") + b.WriteString(desc) + } + b.WriteString("\n") } - return knowledge.Skill{}, false + return strings.TrimSpace(b.String()) +} + +func copySkillSummaries(in []knowledge.SkillSummary) []knowledge.SkillSummary { + out := make([]knowledge.SkillSummary, len(in)) + copy(out, in) + for i := range out { + out[i].DirName = strings.TrimSpace(out[i].DirName) + out[i].Name = strings.TrimSpace(out[i].Name) + out[i].Description = strings.TrimSpace(out[i].Description) + } + return out } func formatToolErrorObservation(code, action, reason string) string { @@ -539,25 +532,6 @@ func formatSkills(skills []knowledge.Skill) string { return strings.TrimSpace(b.String()) } -func formatSkillCatalog(skills []knowledge.Skill) string { - b := strings.Builder{} - for _, skill := range skills { - summary := strings.ReplaceAll(skill.Content, "\n", " ") - summary = strings.TrimSpace(summary) - if len(summary) > 220 { - summary = summary[:220] - } - b.WriteString("- ") - b.WriteString(skill.Name) - if summary != "" { - b.WriteString(": ") - b.WriteString(summary) - } - b.WriteString("\n") - } - return strings.TrimSpace(b.String()) -} - func (o *Orchestrator) formatToolDoc() string { list := o.tools.List() if len(list) == 0 { diff --git a/internal/agent/react_parser.go b/internal/agent/react_parser.go index d15afd6..e35e821 100644 --- a/internal/agent/react_parser.go +++ b/internal/agent/react_parser.go @@ -6,6 +6,36 @@ import ( "strings" ) +// reactDecision 是 LLM 在统一 ReAct 循环中返回的结构化 JSON 决策。 +// 每轮 LLM 调用都返回这个结构,由 agent 判断是否继续循环。 +type reactDecision struct { + // Thought 是 LLM 的当前推理过程描述 + Thought string `json:"thought"` + // Action 是需要调用的工具名称(如 "file"、"shell"、"web_search"),不需要工具时为 "none" 或空 + Action string `json:"action"` + // ActionInput 是传给工具的输入参数,可以是字符串或结构化对象 + ActionInput json.RawMessage `json:"action_input"` + // IsFinalAnswer 标记本轮是否为最终回答。true 表示 ReAct 循环结束。 + IsFinalAnswer bool `json:"is_final_answer"` + // FinalAnswer 当 IsFinalAnswer 为 true 时,包含给用户的最终回复内容 + FinalAnswer *string `json:"final_answer"` +} + +// GetActionInputString 将 ActionInput 转为字符串,用于传递给工具的 Call 方法。 +// 如果 ActionInput 是 JSON 字符串则去掉引号;如果是对象/数组则保持 JSON 原文。 +func (d *reactDecision) GetActionInputString() string { + if len(d.ActionInput) == 0 { + return "" + } + // 尝试解析为字符串 + var s string + if err := json.Unmarshal(d.ActionInput, &s); err == nil { + return s + } + // 非字符串则直接返回 JSON 原文 + return strings.TrimSpace(string(d.ActionInput)) +} + func parseDecision(raw string) (reactDecision, error) { raw = normalizeJSON(raw) start := strings.Index(raw, "{") diff --git a/internal/agent/react_parser_test.go b/internal/agent/react_parser_test.go index a36efdd..81e2fda 100644 --- a/internal/agent/react_parser_test.go +++ b/internal/agent/react_parser_test.go @@ -2,28 +2,70 @@ package agent import "testing" -func TestParseDecisionPlainJSON(t *testing.T) { - raw := `{"thought":"t","action":"none","action_input":"","final":"ok"}` +// TestParseDecisionFinalAnswer 测试 is_final_answer=true 时能正确解析 final_answer +func TestParseDecisionFinalAnswer(t *testing.T) { + raw := `{"thought":"直接回答","action":"none","action_input":"","is_final_answer":true,"final_answer":"你好!"}` got, err := parseDecision(raw) if err != nil { t.Fatalf("parseDecision error: %v", err) } - if got.Action != "none" || got.Final != "ok" { - t.Fatalf("unexpected decision: %+v", got) + if !got.IsFinalAnswer { + t.Fatal("expected is_final_answer=true") + } + if got.FinalAnswer == nil || *got.FinalAnswer != "你好!" { + t.Fatalf("unexpected final_answer: %v", got.FinalAnswer) } } +// TestParseDecisionToolCall 测试需要调工具时的解析 +func TestParseDecisionToolCall(t *testing.T) { + raw := `{"thought":"需要搜索","action":"web_search","action_input":"NVIDIA stock price","is_final_answer":false,"final_answer":null}` + got, err := parseDecision(raw) + if err != nil { + t.Fatalf("parseDecision error: %v", err) + } + if got.IsFinalAnswer { + t.Fatal("expected is_final_answer=false") + } + if got.Action != "web_search" { + t.Fatalf("expected action=web_search, got %s", got.Action) + } + input := got.GetActionInputString() + if input != "NVIDIA stock price" { + t.Fatalf("expected action_input string, got %q", input) + } +} + +// TestParseDecisionStructuredActionInput 测试 action_input 为结构化对象时的解析 +func TestParseDecisionStructuredActionInput(t *testing.T) { + raw := `{"thought":"搜索","action":"web_search","action_input":{"query":"test","context":"dev"},"is_final_answer":false,"final_answer":null}` + got, err := parseDecision(raw) + if err != nil { + t.Fatalf("parseDecision error: %v", err) + } + input := got.GetActionInputString() + if input == "" { + t.Fatal("expected non-empty action_input") + } + // 结构化对象应保留 JSON 原文 + if input[0] != '{' { + t.Fatalf("expected JSON object string, got %q", input) + } +} + +// TestParseDecisionCodeFence 测试被 markdown code fence 包裹的 JSON func TestParseDecisionCodeFence(t *testing.T) { - raw := "```json\n{\"thought\":\"t\",\"action\":\"shell\",\"action_input\":\"ls\",\"final\":\"\"}\n```" + raw := "```json\n{\"thought\":\"t\",\"action\":\"shell\",\"action_input\":\"ls\",\"is_final_answer\":false,\"final_answer\":null}\n```" got, err := parseDecision(raw) if err != nil { t.Fatalf("parseDecision error: %v", err) } - if got.Action != "shell" || got.ActionInput != "ls" { - t.Fatalf("unexpected decision: %+v", got) + if got.Action != "shell" { + t.Fatalf("unexpected action: %s", got.Action) } } +// TestParseDecisionInvalid 测试非 JSON 输入时返回错误 func TestParseDecisionInvalid(t *testing.T) { _, err := parseDecision("not json") if err == nil { diff --git a/internal/config/config.go b/internal/config/config.go index f31ede6..b2afca0 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -11,22 +11,23 @@ import ( ) type Config struct { - MessageChannel string - LogLevel string - SoulPath string - SkillsDir string - ReactMaxSteps int - ToolCallTimeoutSec int - ToolOutputMaxChars int - EnableCapabilityGap bool - AutoSkillDir string - GapDraftTriggerCount int - GapClusterLookbackHours int + MessageChannel string + LogLevel string + SoulPath string + SkillsDir string + ReactMaxSteps int + ToolCallTimeoutSec int + ToolOutputMaxChars int + EnableCapabilityGap bool + AutoSkillDir string + GapDraftTriggerCount int + GapClusterLookbackHours int - Telegram TelegramConfig - Feishu FeishuConfig - LLM LLMConfig - Security SecurityConfig + Telegram TelegramConfig + Feishu FeishuConfig + LLM LLMConfig + Security SecurityConfig + WebSearch WebSearchConfig SQLitePath string } @@ -56,6 +57,11 @@ type SecurityConfig struct { WorkDir string } +type WebSearchConfig struct { + Engine string // "duckduckgo" or "brave" + APIKey string +} + func Load() (Config, error) { agentWorkspaceDir := resolveAgentWorkspaceDir() if err := preloadEnvFiles(); err != nil { @@ -65,14 +71,14 @@ func Load() (Config, error) { defaultDataDir := filepath.Join(agentWorkspaceDir, "data") cfg := Config{ - MessageChannel: defaultIfEmpty(os.Getenv("MESSAGE_CHANNEL"), "telegram"), - LogLevel: defaultIfEmpty(os.Getenv("LOG_LEVEL"), "info"), - SoulPath: defaultIfEmpty(os.Getenv("SOUL_PATH"), filepath.Join(agentWorkspaceDir, "bot_context", "soul.md")), - SkillsDir: defaultIfEmpty(os.Getenv("SKILLS_DIR"), filepath.Join(agentWorkspaceDir, "skills")), - ReactMaxSteps: intFromEnv("REACT_MAX_STEPS", 0), - ToolCallTimeoutSec: intFromEnv("TOOL_CALL_TIMEOUT_SEC", 15), - ToolOutputMaxChars: intFromEnv("TOOL_OUTPUT_MAX_CHARS", 4000), - EnableCapabilityGap: boolFromEnv("ENABLE_CAPABILITY_GAP", true), + MessageChannel: defaultIfEmpty(os.Getenv("MESSAGE_CHANNEL"), "telegram"), + LogLevel: defaultIfEmpty(os.Getenv("LOG_LEVEL"), "info"), + SoulPath: defaultIfEmpty(os.Getenv("SOUL_PATH"), filepath.Join(agentWorkspaceDir, "bot_context", "soul.md")), + SkillsDir: defaultIfEmpty(os.Getenv("SKILLS_DIR"), filepath.Join(agentWorkspaceDir, "skills")), + ReactMaxSteps: intFromEnv("REACT_MAX_STEPS", 0), + ToolCallTimeoutSec: intFromEnv("TOOL_CALL_TIMEOUT_SEC", 15), + ToolOutputMaxChars: intFromEnv("TOOL_OUTPUT_MAX_CHARS", 4000), + EnableCapabilityGap: boolFromEnv("ENABLE_CAPABILITY_GAP", true), AutoSkillDir: defaultIfEmpty(os.Getenv("AUTO_SKILL_DIR"), filepath.Join(agentWorkspaceDir, "skills")), GapDraftTriggerCount: intFromEnv("GAP_DRAFT_TRIGGER_COUNT", 3), GapClusterLookbackHours: intFromEnv("GAP_CLUSTER_LOOKBACK_HOURS", 168), @@ -93,6 +99,10 @@ func Load() (Config, error) { Model: defaultIfEmpty(os.Getenv("LLM_MODEL"), "gpt-4o-mini"), }, SQLitePath: defaultIfEmpty(os.Getenv("SQLITE_PATH"), filepath.Join(defaultDataDir, "laodingbot.db")), + WebSearch: WebSearchConfig{ + Engine: defaultIfEmpty(os.Getenv("WEB_SEARCH_ENGINE"), "duckduckgo"), + APIKey: strings.TrimSpace(os.Getenv("WEB_SEARCH_API_KEY")), + }, Security: SecurityConfig{ AllowedDirs: splitCSV(defaultIfEmpty(os.Getenv("ALLOWED_DIRS"), strings.Join([]string{agentWorkspaceDir, defaultDataDir, defaultWorkSubdir}, ","))), AllowedCommands: splitCSV(defaultIfEmpty(os.Getenv("ALLOWED_COMMANDS"), "pwd,ls,cat,echo,grep,find,head,tail,go")), diff --git a/internal/knowledge/drafts.go b/internal/knowledge/drafts.go index 44654ed..48dc261 100644 --- a/internal/knowledge/drafts.go +++ b/internal/knowledge/drafts.go @@ -74,10 +74,10 @@ cluster_count: %d - 现有技能未命中,或命中后无法完成。 ## 建议工具 -- 优先使用现有工具:`+"`shell`"+`、`+"`file`"+`。 +- 优先使用现有工具:`+"`shell`"+`、`+"`file`"+`、`+"`web_search`"+`。 - 若能力不足,需要创建新工具时: - 1. 在 `+"`internal/tools//`"+` 下生成 Go 代码; - 2. 在 `+"`cmd/bot/main.go`"+` 或 toolhost 注册逻辑中完成注册; + 1. 在 `+"`tools//`"+` 下生成 Go 代码(实现 Name/Description/Call 接口); + 2. 在 `+"`internal/toolhost/runtime.go`"+` 中注册新工具; 3. 生成/补充 `+"`*_test.go`"+`; 4. 调用 `+"`go test ./...`"+` 验证。 @@ -87,7 +87,7 @@ cluster_count: %d 3. 工具调用前先最小化探测范围。 4. 工具失败时输出原因与下一步建议。 5. 若缺少 skill:使用 `+"`file`"+` 与 `+"`shell`"+` 创建新的 `+"`skills//skill.md`"+`。 -6. 若缺少 tool:生成工具代码与测试后执行 `+"`go test ./...`"+`。 +6. 若缺少 tool:在 `+"`tools//`"+` 下生成工具代码与测试后执行 `+"`go test ./...`"+`。 ## 输出规范 - 结论:一句话给出当前阶段结论。 diff --git a/internal/knowledge/loader.go b/internal/knowledge/loader.go index 9139287..98d286b 100644 --- a/internal/knowledge/loader.go +++ b/internal/knowledge/loader.go @@ -14,6 +14,21 @@ type Skill struct { Source string } +type SkillSummary struct { + DirName string + Name string + Description string + Source string +} + +type scannedSkill struct { + dirName string + name string + description string + content string + source string +} + func LoadSoul(path string) (string, error) { b, err := os.ReadFile(path) if err != nil { @@ -27,6 +42,39 @@ func LoadSoul(path string) (string, error) { } func LoadSkillSet(dir string) ([]Skill, error) { + scanned, err := scanSkills(dir) + if err != nil { + return nil, err + } + out := make([]Skill, 0, len(scanned)) + for _, s := range scanned { + out = append(out, Skill{ + Name: s.name, + Content: s.content, + Source: s.source, + }) + } + return out, nil +} + +func LoadSkillSummaries(dir string) ([]SkillSummary, error) { + scanned, err := scanSkills(dir) + if err != nil { + return nil, err + } + out := make([]SkillSummary, 0, len(scanned)) + for _, s := range scanned { + out = append(out, SkillSummary{ + DirName: s.dirName, + Name: s.name, + Description: s.description, + Source: s.source, + }) + } + return out, nil +} + +func scanSkills(dir string) ([]scannedSkill, error) { entries, err := os.ReadDir(dir) if err != nil { return nil, fmt.Errorf("read skills dir failed: %w", err) @@ -40,7 +88,7 @@ func LoadSkillSet(dir string) ([]Skill, error) { } sort.Strings(skillDirs) - out := make([]Skill, 0, len(skillDirs)) + out := make([]scannedSkill, 0, len(skillDirs)) for _, skillDir := range skillDirs { file := filepath.Join(dir, skillDir, "skill.md") b, err := os.ReadFile(file) @@ -54,12 +102,16 @@ func LoadSkillSet(dir string) ([]Skill, error) { if content == "" { continue } - - name := extractSkillName(skillDir, content) - out = append(out, Skill{ - Name: name, - Content: content, - Source: file, + name, description := parseSkillNameDescription(skillDir, content) + if strings.TrimSpace(name) == "" { + continue + } + out = append(out, scannedSkill{ + dirName: skillDir, + name: name, + description: description, + content: content, + source: file, }) } @@ -88,3 +140,74 @@ func extractSkillName(fileName, markdown string) string { } return base } + +func parseSkillNameDescription(fileName, markdown string) (string, string) { + name := "" + description := "" + fm := parseFrontMatter(markdown) + if v, ok := fm["name"]; ok { + name = strings.TrimSpace(v) + } + if v, ok := fm["description"]; ok { + description = strings.TrimSpace(v) + } + if name == "" { + name = extractSkillName(fileName, markdown) + } + if description == "" { + description = extractSkillDescription(markdown) + } + return name, description +} + +func parseFrontMatter(markdown string) map[string]string { + lines := strings.Split(markdown, "\n") + out := map[string]string{} + if len(lines) < 3 { + return out + } + if strings.TrimSpace(lines[0]) != "---" { + return out + } + for i := 1; i < len(lines); i++ { + line := strings.TrimSpace(lines[i]) + if line == "---" { + break + } + idx := strings.Index(line, ":") + if idx <= 0 { + continue + } + k := strings.ToLower(strings.TrimSpace(line[:idx])) + v := strings.TrimSpace(line[idx+1:]) + v = strings.Trim(v, "\"'") + if k != "" && v != "" { + out[k] = v + } + } + return out +} + +func extractSkillDescription(markdown string) string { + lines := strings.Split(markdown, "\n") + for _, raw := range lines { + line := strings.TrimSpace(raw) + if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, "---") { + continue + } + if strings.Contains(line, ":") { + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 { + left := strings.ToLower(strings.TrimSpace(parts[0])) + if left == "name" || left == "description" || left == "source" || left == "generated_at" { + continue + } + } + } + if len(line) > 200 { + line = line[:200] + } + return line + } + return "" +} diff --git a/internal/memory/types.go b/internal/memory/types.go index f0c62d6..453a4e5 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -3,11 +3,11 @@ package memory import "time" type CapabilityGap struct { - ID int64 - ChatID string - UserID string - Intent string - Reason string + ID int64 + ChatID string + UserID string + Intent string + Reason string CreatedAt time.Time } diff --git a/internal/toolhost/client.go b/internal/toolhost/client.go index 9f71b71..c609a8d 100644 --- a/internal/toolhost/client.go +++ b/internal/toolhost/client.go @@ -29,11 +29,11 @@ type Client struct { cfg ClientConfig log *logger.Logger - cmd *exec.Cmd - stdin io.WriteCloser - stdout io.ReadCloser - decoder *json.Decoder - encoder *json.Encoder + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + decoder *json.Decoder + encoder *json.Encoder seq int64 diff --git a/internal/toolhost/runtime.go b/internal/toolhost/runtime.go index efe148e..a1c765c 100644 --- a/internal/toolhost/runtime.go +++ b/internal/toolhost/runtime.go @@ -8,31 +8,42 @@ import ( "laodingbot/internal/config" "laodingbot/internal/logger" "laodingbot/internal/tools" - "laodingbot/internal/tools/filetool" - "laodingbot/internal/tools/shelltool" + "laodingbot/tools/fileoperation" + "laodingbot/tools/shell" + "laodingbot/tools/websearch" ) func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error { var registryLog *logger.Logger var fileLog *logger.Logger var shellLog *logger.Logger + var searchLog *logger.Logger var serverLog *logger.Logger if log != nil { log.Infof("toolhost child starting") registryLog = log.WithComponent("toolhost.registry") fileLog = log.WithComponent("toolhost.file") shellLog = log.WithComponent("toolhost.shell") + searchLog = log.WithComponent("toolhost.websearch") serverLog = log.WithComponent("toolhost.server") } registry := tools.NewRegistry(registryLog) - registry.Register(filetool.New(cfg.Security.AllowedDirs, cfg.ToolOutputMaxChars, fileLog)) - registry.Register(shelltool.New( + registry.Register(fileoperation.New(cfg.Security.AllowedDirs, cfg.ToolOutputMaxChars, fileLog)) + registry.Register(shell.New( cfg.Security.AllowedCommands, cfg.Security.WorkDir, time.Duration(cfg.ToolCallTimeoutSec)*time.Second, cfg.ToolOutputMaxChars, shellLog, )) + registry.Register(websearch.New( + websearch.Config{ + Engine: cfg.WebSearch.Engine, + APIKey: cfg.WebSearch.APIKey, + }, + cfg.ToolOutputMaxChars, + searchLog, + )) server := NewServer(registry, serverLog) if err := server.Serve(ctx, stdin(), stdout()); err != nil && ctx.Err() == nil { diff --git a/internal/tools/shelltool/shelltool_test.go b/internal/tools/shelltool/shelltool_test.go deleted file mode 100644 index 22882b1..0000000 --- a/internal/tools/shelltool/shelltool_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package shelltool - -import ( - "context" - "testing" - "time" -) - -func TestCallRejectsEmptyCommand(t *testing.T) { - tool := New([]string{"echo"}, ".", time.Second, 4000, nil) - _, err := tool.Call(context.Background(), " ") - if err == nil { - t.Fatal("expected error for empty command") - } -} - -func TestCallRejectsNonAllowlistedCommand(t *testing.T) { - tool := New([]string{"echo"}, ".", time.Second, 4000, nil) - _, err := tool.Call(context.Background(), "cat test.txt") - if err == nil { - t.Fatal("expected allowlist rejection") - } -} diff --git a/skills/skill_builder/skill.md b/skills/skill_builder/skill.md index 439422c..2a353db 100644 --- a/skills/skill_builder/skill.md +++ b/skills/skill_builder/skill.md @@ -13,27 +13,29 @@ description: 当用户请求新增能力或系统发现能力缺口时,自动 ## 2. 目标 1. 生成可执行的 `skills//skill.md`。 -2. 若需要新工具,生成 `internal/tools//` 下 Go 代码。 +2. 若需要新工具,生成 `tools//` 下 Go 代码(实现 `Name/Description/Call` 接口)。 3. 生成或补充测试代码并执行 `go test ./...`。 4. 输出结果中说明新增内容、测试结果与后续建议。 ## 3. 可用工具 - `file`:创建目录与文件、写入 skill/tool/test 内容。 - `shell`:执行测试、检索代码位置、检查文件结构。 +- `web_search`:搜索技术方案、API 用法等参考信息。 ## 4. 执行流程 1. **澄清能力边界**:提炼该 skill 要解决的问题与触发信号。 2. **命名与路径规划**: - 技能路径:`skills//skill.md` - - 工具路径(如需):`internal/tools//...` + - 工具路径(如需):`tools//` 3. **创建 skill 文件**:写入完整字段(用途、触发、工具、ReAct 指南、失败回退、输出规范)。 4. **判断是否需要新 tool**: - - 若现有 `shell/file` 足够,直接结束。 + - 若现有 `shell/file/web_search` 足够,直接结束。 - 若不够,进入工具生成。 5. **生成 tool 代码(如需)**: - - 实现 `Name/Description/Call`。 + - 在 `tools//` 下创建 Go 文件。 + - 实现 `Name()/Description()/Call()` 接口(参考 `laodingbot/internal/tools.Tool`)。 + - 在 `internal/toolhost/runtime.go` 中注册新工具。 - 保持白名单与安全边界。 - - 在主注册逻辑或 toolhost 注册逻辑中接入。 6. **生成测试并执行**: - 补充 `*_test.go`。 - 执行 `go test ./...`。 @@ -47,6 +49,6 @@ description: 当用户请求新增能力或系统发现能力缺口时,自动 ## 6. 输出模板 - 新增技能:`skills//skill.md` -- 新增工具(可选):`internal/tools//...` +- 新增工具(可选):`tools//...` - 测试结果:`go test ./...` 的通过/失败摘要 -- 后续动作:是否需要热加载、是否需要补充环境变量 +- 后续动作:是否需要热加载(`/reload_skills`)、是否需要补充环境变量 diff --git a/internal/tools/filetool/filetool.go b/tools/fileoperation/fileoperation.go similarity index 77% rename from internal/tools/filetool/filetool.go rename to tools/fileoperation/fileoperation.go index 0ce33dc..20703c3 100644 --- a/internal/tools/filetool/filetool.go +++ b/tools/fileoperation/fileoperation.go @@ -1,4 +1,4 @@ -package filetool +package fileoperation import ( "context" @@ -10,12 +10,20 @@ import ( "laodingbot/internal/logger" ) +// Tool 提供基于白名单目录的安全文件操作集合(读取、列出、写入)。 type Tool struct { + // allowedDirs 允许操作的目录路径白名单(绝对路径列表)。 allowedDirs []string + // maxOutputChars 文件内容的输出长度上限。 maxOutputChars int - log *logger.Logger + // log 日志记录组件。 + log *logger.Logger } +// New 生成一个文件操作工具的实例。 +// allowedDirs: 安全校验时需要用到的许可目录列表,不在该列列表的路径将抛出无权限错误。 +// maxOutputChars: 最大文件返回长度限制。 +// log: 系统日志指针。 func New(allowedDirs []string, maxOutputChars int, log *logger.Logger) *Tool { normalized := make([]string, 0, len(allowedDirs)) for _, dir := range allowedDirs { @@ -33,12 +41,18 @@ func New(allowedDirs []string, maxOutputChars int, log *logger.Logger) *Tool { return &Tool{allowedDirs: normalized, maxOutputChars: maxOutputChars, log: log} } +// Name 对外声明此工具注册的内部名称。 func (t *Tool) Name() string { return "file" } +// Description 定义了本工具支持的具体功能和入参语法规则(read、list、write)。 func (t *Tool) Description() string { return "File operations with command format: read | list | write \\n" } +// Call 处理和路由文件操作请求。 +// ctx: 上下文对象。 +// input: 包含操作指令与路径(可能带内容)的文本(例如 "read /tmp/a.txt")。 +// 解析失败或没有权限将返回错误提示。 func (t *Tool) Call(_ context.Context, input string) (string, error) { input = strings.TrimSpace(input) if t.log != nil { @@ -146,6 +160,10 @@ func (t *Tool) Call(_ context.Context, input string) (string, error) { if t.log != nil { t.log.Infof("file write success path=%s bytes=%d", resolved, len(parts[1])) } + // resolveAllowed 校验输入的文件路径是否处于允许白名单中。 + // 如果路径是相对的会尝试基于全局环境变量或者当前目录转为绝对路径后进行安全校验匹配。 + // path: 待验证或补全的文件/目录路径。 + // 返回清洗后的绝对路径。如果不在白名单范围内将返回安全错误。 return "ok", nil } diff --git a/internal/tools/filetool/filetool_test.go b/tools/fileoperation/fileoperation_test.go similarity index 98% rename from internal/tools/filetool/filetool_test.go rename to tools/fileoperation/fileoperation_test.go index c9e736d..d36280e 100644 --- a/internal/tools/filetool/filetool_test.go +++ b/tools/fileoperation/fileoperation_test.go @@ -1,4 +1,4 @@ -package filetool +package fileoperation import ( "context" diff --git a/internal/tools/shelltool/shelltool.go b/tools/shell/shell.go similarity index 51% rename from internal/tools/shelltool/shelltool.go rename to tools/shell/shell.go index 9a0a07d..cabc63f 100644 --- a/internal/tools/shelltool/shelltool.go +++ b/tools/shell/shell.go @@ -1,4 +1,4 @@ -package shelltool +package shell import ( "context" @@ -13,13 +13,25 @@ import ( ) type Tool struct { + // allowedCommands 允许执行的shell命令集合,作为白名单使用。 allowedCommands map[string]struct{} - workDir string - timeout time.Duration - maxOutputChars int - log *logger.Logger + // workDir shell命令执行的工作目录。 + workDir string + // timeout 单个shell命令执行的超时时间,防止长时间阻塞。 + timeout time.Duration + // maxOutputChars 最大输出字符数限制,避免输出过长导致内存或上下文溢出。 + maxOutputChars int + // log 用于记录shell工具操作和执行详情的日志实例。 + log *logger.Logger } +// New 创建一个新的 shell 工具实例。 +// allowed: 允许被执行的命令白名单(例如 "ls", "echo" 等)。 +// workDir: 命令执行的基础工作目录。 +// timeout: 命令执行的最大超时时间。 +// maxOutputChars: 命令执行结果允许返回的最大字符数。 +// log: 日志记录器。 +// 返回初始化的 Tool 实例指针。 func New(allowed []string, workDir string, timeout time.Duration, maxOutputChars int, log *logger.Logger) *Tool { set := make(map[string]struct{}, len(allowed)) for _, c := range allowed { @@ -44,12 +56,18 @@ func New(allowed []string, workDir string, timeout time.Duration, maxOutputChars return &Tool{allowedCommands: set, workDir: absDir, timeout: timeout, maxOutputChars: maxOutputChars, log: log} } +// Name 返回此工具的名称。 func (t *Tool) Name() string { return "shell" } +// Description 返回此工具的功能描述。 func (t *Tool) Description() string { - return "Execute allowlisted shell commands in Linux" + return "Execute shell commands (Windows uses cmd /C; for current time prefer: echo %DATE% %TIME%)" } +// Call 执行指定的底层 shell 命令。 +// ctx: 用于控制执行过程的上下文。 +// input: 包含要执行的完整命令字符串。 +// 当前临时策略:允许执行任意命令(不做 allows 白名单拦截),并在执行完毕后返回输出。 func (t *Tool) Call(ctx context.Context, input string) (string, error) { trimmed := strings.TrimSpace(input) if trimmed == "" { @@ -59,14 +77,12 @@ func (t *Tool) Call(ctx context.Context, input string) (string, error) { return "", fmt.Errorf("empty command") } + if runtime.GOOS == "windows" { + trimmed = normalizeWindowsCommand(trimmed) + } + parts := strings.Fields(trimmed) base := parts[0] - if _, ok := t.allowedCommands[base]; !ok { - if t.log != nil { - t.log.Warnf("shell command denied command=%s full_command=%q", base, trimmed) - } - return "", fmt.Errorf("command not allowed: %s", base) - } if t.log != nil { t.log.Infof("shell command start command=%s args=%d full_command=%q", base, len(parts)-1, trimmed) } @@ -74,7 +90,13 @@ func (t *Tool) Call(ctx context.Context, input string) (string, error) { runCtx, cancel := context.WithTimeout(ctx, t.timeout) defer cancel() - cmd := exec.CommandContext(runCtx, base, parts[1:]...) + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + // Windows 下使用 cmd /C 执行,兼容 date、dir 等内建命令。 + cmd = exec.CommandContext(runCtx, "cmd", "/C", trimmed) + } else { + cmd = exec.CommandContext(runCtx, base, parts[1:]...) + } cmd.Dir = t.workDir out, err := cmd.CombinedOutput() outText := string(out) @@ -85,9 +107,6 @@ func (t *Tool) Call(ctx context.Context, input string) (string, error) { if t.log != nil { t.log.Errorf("shell command failed command=%s full_command=%q err=%v output_bytes=%d output=%q", base, trimmed, err, len(out), outText) } - if runtime.GOOS == "windows" && strings.Contains(strings.ToLower(err.Error()), "executable file not found") { - return outText, fmt.Errorf("command not executable in current windows environment: %s", base) - } return outText, err } if t.log != nil { @@ -95,3 +114,15 @@ func (t *Tool) Call(ctx context.Context, input string) (string, error) { } return outText, nil } + +func normalizeWindowsCommand(command string) string { + cmd := strings.TrimSpace(strings.ToLower(command)) + switch cmd { + case "date", "date /t": + return "echo %DATE% %TIME%" + case "time", "time /t": + return "echo %DATE% %TIME%" + default: + return command + } +} diff --git a/tools/shell/shell_test.go b/tools/shell/shell_test.go new file mode 100644 index 0000000..72afaeb --- /dev/null +++ b/tools/shell/shell_test.go @@ -0,0 +1,42 @@ +package shell + +import ( + "context" + "runtime" + "strings" + "testing" + "time" +) + +func TestCallRejectsEmptyCommand(t *testing.T) { + tool := New([]string{"echo"}, ".", time.Second, 4000, nil) + _, err := tool.Call(context.Background(), " ") + if err == nil { + t.Fatal("expected error for empty command") + } +} + +func TestCallAllowsNonAllowlistedCommand(t *testing.T) { + tool := New([]string{"echo"}, ".", time.Second, 4000, nil) + out, err := tool.Call(context.Background(), "go version") + if err != nil { + t.Fatalf("expected command to run without allowlist restriction, got err=%v", err) + } + if out == "" { + t.Fatal("expected non-empty output") + } +} + +func TestCallWindowsDateIsNonInteractive(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("windows-only test") + } + tool := New(nil, ".", 3*time.Second, 4000, nil) + out, err := tool.Call(context.Background(), "date") + if err != nil { + t.Fatalf("expected bare date command to succeed on windows, got err=%v output=%q", err, out) + } + if strings.TrimSpace(out) == "" { + t.Fatal("expected non-empty output for date command") + } +} diff --git a/tools/websearch/websearch.go b/tools/websearch/websearch.go new file mode 100644 index 0000000..08de622 --- /dev/null +++ b/tools/websearch/websearch.go @@ -0,0 +1,288 @@ +package websearch + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "laodingbot/internal/logger" +) + +// Config 定义了网络搜索工具所需的配置参数。 +type Config struct { + Engine string // 搜索引擎类型,支持 "duckduckgo" 或 "brave" + APIKey string // 搜索引擎的 API Key(Brave 搜索必填) +} + +// Tool represents a web search tool. +// Tool 定义了一个网络搜索工具的结构,用于执行互联网检索并获取摘要。 +type Tool struct { + // engine 当前使用的搜索引擎标识。 + engine string + // apiKey 执行搜索时需要的认证 Key。 + apiKey string + // httpClient 发送 HTTP 请求所使用的客户端。 + httpClient *http.Client + // maxOutputChars 返回搜索结果的最大字符数限制。 + maxOutputChars int + // log 日志记录器,跟踪搜索请求与执行状态。 + log *logger.Logger +} + +// New 初始化并返回一个新的 websearch 工具实例。 +// cfg: 网络搜索工具的相关配置。 +// maxOutputChars: 规范化结果文本截断的最大长度。 +// log: 外部传入的日志记录组件。 +func New(cfg Config, maxOutputChars int, log *logger.Logger) *Tool { + engine := strings.TrimSpace(cfg.Engine) + if engine == "" { + engine = "duckduckgo" + } + if maxOutputChars <= 0 { + maxOutputChars = 4000 + } + if log != nil { + log.Infof("websearch tool initialized engine=%s max_output_chars=%d", engine, maxOutputChars) + } + return &Tool{ + engine: engine, + apiKey: strings.TrimSpace(cfg.APIKey), + httpClient: &http.Client{Timeout: 15 * time.Second}, + maxOutputChars: maxOutputChars, + log: log, + } +} + +// Name 返回此工具的名称定义,供模型调用时识别。 +func (t *Tool) Name() string { return "web_search" } + +// Description 描述此工具的作用及入参、出参格式。 +func (t *Tool) Description() string { + return "Search the web. Input: search query string. Returns formatted search results." +} + +// Call 执行具体的搜索动作。 +// ctx: 带有超时/取消机制的上下文。 +// input: 用户的搜索查询词。 +// 成功时返回搜索到的格式化文本结果(受最大字符数限制)。 +func (t *Tool) Call(ctx context.Context, input string) (string, error) { + query := strings.TrimSpace(input) + if query == "" { + return "", fmt.Errorf("empty search query") + } + if t.log != nil { + t.log.Infof("websearch query=%q engine=%s", query, t.engine) + } + + var result string + var err error + + switch t.engine { + case "brave": + result, err = t.searchBrave(ctx, query) + default: + result, err = t.searchDuckDuckGo(ctx, query) + } + if err != nil { + if t.log != nil { + t.log.Errorf("websearch failed query=%q engine=%s err=%v", query, t.engine, err) + } + return "", err + } + + if len(result) > t.maxOutputChars { + result = result[:t.maxOutputChars] + } + if t.log != nil { + t.log.Infof("websearch success query=%q engine=%s result_len=%d", query, t.engine, len(result)) + } + return result, nil +} + +// searchDuckDuckGo uses the DuckDuckGo Instant Answer API (no API key required). +// 使用无 key 的 DuckDuckGo 搜索即时解答抽象内容接口。 +func (t *Tool) searchDuckDuckGo(ctx context.Context, query string) (string, error) { + apiURL := "https://api.duckduckgo.com/?q=" + url.QueryEscape(query) + "&format=json&no_html=1&skip_disambig=1" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return "", fmt.Errorf("create request failed: %w", err) + } + req.Header.Set("User-Agent", "LaodingBot/1.0") + + resp, err := t.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("http request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 256*1024)) + if err != nil { + return "", fmt.Errorf("read response body failed: %w", err) + } + + var ddg duckDuckGoResponse + if err := json.Unmarshal(body, &ddg); err != nil { + return "", fmt.Errorf("parse duckduckgo response failed: %w", err) + } + + return t.formatDuckDuckGoResult(query, ddg), nil +} + +// duckDuckGoResponse 从 DuckDuckGo 获取的即时结果 JSON 映射结构。 +type duckDuckGoResponse struct { + Abstract string `json:"Abstract"` + AbstractText string `json:"AbstractText"` + AbstractSource string `json:"AbstractSource"` + AbstractURL string `json:"AbstractURL"` + Answer string `json:"Answer"` + AnswerType string `json:"AnswerType"` + Heading string `json:"Heading"` + RelatedTopics []ddgRelatedItem `json:"RelatedTopics"` +} + +// ddgRelatedItem 代表相关的搜索条目/话题。 +type ddgRelatedItem struct { + Text string `json:"Text"` + FirstURL string `json:"FirstURL"` +} + +// formatDuckDuckGoResult 将 DuckDuckGo 提供的结果结构打包为纯文本格式化输出,便于传递给下一个节点。 +func (t *Tool) formatDuckDuckGoResult(query string, ddg duckDuckGoResponse) string { + b := strings.Builder{} + b.WriteString("Search: " + query + "\n") + b.WriteString("Engine: DuckDuckGo\n\n") + + hasContent := false + + if ddg.Answer != "" { + b.WriteString("Answer: " + ddg.Answer + "\n\n") + hasContent = true + } + if ddg.AbstractText != "" { + b.WriteString("Summary: " + ddg.AbstractText + "\n") + if ddg.AbstractSource != "" { + b.WriteString("Source: " + ddg.AbstractSource + "\n") + } + if ddg.AbstractURL != "" { + b.WriteString("URL: " + ddg.AbstractURL + "\n") + } + b.WriteString("\n") + hasContent = true + } + if len(ddg.RelatedTopics) > 0 { + b.WriteString("Related:\n") + count := 0 + for _, topic := range ddg.RelatedTopics { + if topic.Text == "" { + continue + } + text := topic.Text + if len(text) > 300 { + text = text[:300] + } + b.WriteString(fmt.Sprintf("- %s", text)) + if topic.FirstURL != "" { + b.WriteString(fmt.Sprintf(" (%s)", topic.FirstURL)) + } + b.WriteString("\n") + count++ + if count >= 8 { + break + } + } + hasContent = true + } + + if !hasContent { + b.WriteString("No instant answer available for this query. Try a more specific search or use a different search engine.\n") + } + + return strings.TrimSpace(b.String()) + // 使用 Brave Search API 进行实际的搜索引擎查询获取多条结果(需要订阅 Token)。 +} + +// searchBrave uses the Brave Search API (requires API key). +func (t *Tool) searchBrave(ctx context.Context, query string) (string, error) { + if t.apiKey == "" { + return "", fmt.Errorf("WEB_SEARCH_API_KEY is required for Brave Search engine") + } + + apiURL := "https://api.search.brave.com/res/v1/web/search?q=" + url.QueryEscape(query) + "&count=8" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil) + if err != nil { + return "", fmt.Errorf("create request failed: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set("X-Subscription-Token", t.apiKey) + + resp, err := t.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("http request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodySnippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return "", fmt.Errorf("brave search returned status %d: %s", resp.StatusCode, string(bodySnippet)) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 512*1024)) + if err != nil { + return "", fmt.Errorf("read response body failed: %w", err) + } + + var braveResp braveSearchResponse + if err := json.Unmarshal(body, &braveResp); err != nil { + return "", fmt.Errorf("parse brave response failed: %w", err) + } + + return t.formatBraveResult(query, braveResp), nil +} + +// braveSearchResponse 用于接收 Brave Search Web 层面的基本搜索返回结果。 +type braveSearchResponse struct { + Web struct { + Results []braveWebResult `json:"results"` + } `json:"web"` +} + +// braveWebResult 用于表示单独的网页搜索结果摘要信息。 +type braveWebResult struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` +} + +// formatBraveResult 将接收到底层的 Brave 搜索内容整合成对模型友好的文本视图,截断长字符防干扰。} + +func (t *Tool) formatBraveResult(query string, resp braveSearchResponse) string { + b := strings.Builder{} + b.WriteString("Search: " + query + "\n") + b.WriteString("Engine: Brave\n\n") + + if len(resp.Web.Results) == 0 { + b.WriteString("No results found.\n") + return strings.TrimSpace(b.String()) + } + + for i, r := range resp.Web.Results { + if i >= 8 { + break + } + desc := r.Description + if len(desc) > 300 { + desc = desc[:300] + } + b.WriteString(fmt.Sprintf("%d. %s\n %s\n %s\n\n", i+1, r.Title, r.URL, desc)) + } + + return strings.TrimSpace(b.String()) +} diff --git a/tools/websearch/websearch_test.go b/tools/websearch/websearch_test.go new file mode 100644 index 0000000..22cfa8a --- /dev/null +++ b/tools/websearch/websearch_test.go @@ -0,0 +1,57 @@ +package websearch + +import ( + "testing" +) + +func TestNewDefaultEngine(t *testing.T) { + tool := New(Config{}, 4000, nil) + if tool.Name() != "web_search" { + t.Fatalf("expected name web_search, got %s", tool.Name()) + } + if tool.engine != "duckduckgo" { + t.Fatalf("expected default engine duckduckgo, got %s", tool.engine) + } +} + +func TestNewBraveEngine(t *testing.T) { + tool := New(Config{Engine: "brave", APIKey: "test-key"}, 4000, nil) + if tool.engine != "brave" { + t.Fatalf("expected engine brave, got %s", tool.engine) + } + if tool.apiKey != "test-key" { + t.Fatalf("expected apiKey test-key, got %s", tool.apiKey) + } +} + +func TestCallRejectsEmptyQuery(t *testing.T) { + tool := New(Config{}, 4000, nil) + _, err := tool.Call(nil, " ") + if err == nil { + t.Fatal("expected error for empty query") + } +} + +func TestFormatDuckDuckGoResultWithAnswer(t *testing.T) { + tool := New(Config{}, 4000, nil) + ddg := duckDuckGoResponse{ + Answer: "42", + AbstractText: "The answer to everything.", + } + result := tool.formatDuckDuckGoResult("meaning of life", ddg) + if result == "" { + t.Fatal("expected non-empty result") + } + if len(result) == 0 { + t.Fatal("result should contain content") + } +} + +func TestFormatBraveResultEmpty(t *testing.T) { + tool := New(Config{Engine: "brave"}, 4000, nil) + resp := braveSearchResponse{} + result := tool.formatBraveResult("test", resp) + if result == "" { + t.Fatal("expected non-empty result") + } +}