245 lines
6.9 KiB
Go
245 lines
6.9 KiB
Go
|
|
package agent
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"fmt"
|
|||
|
|
"strings"
|
|||
|
|
|
|||
|
|
"laodingbot/internal/llm"
|
|||
|
|
"laodingbot/internal/logger"
|
|||
|
|
"laodingbot/internal/memory"
|
|||
|
|
"laodingbot/internal/tools"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
type Orchestrator struct {
|
|||
|
|
llm llm.Client
|
|||
|
|
store *memory.SQLiteStore
|
|||
|
|
tools *tools.Registry
|
|||
|
|
soul string
|
|||
|
|
skillsDoc string
|
|||
|
|
reactMaxStep int
|
|||
|
|
log *logger.Logger
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func NewOrchestrator(
|
|||
|
|
llmClient llm.Client,
|
|||
|
|
store *memory.SQLiteStore,
|
|||
|
|
registry *tools.Registry,
|
|||
|
|
soul string,
|
|||
|
|
skillsDoc string,
|
|||
|
|
reactMaxStep int,
|
|||
|
|
log *logger.Logger,
|
|||
|
|
) *Orchestrator {
|
|||
|
|
if reactMaxStep <= 0 {
|
|||
|
|
reactMaxStep = 4
|
|||
|
|
}
|
|||
|
|
return &Orchestrator{
|
|||
|
|
llm: llmClient,
|
|||
|
|
store: store,
|
|||
|
|
tools: registry,
|
|||
|
|
soul: soul,
|
|||
|
|
skillsDoc: skillsDoc,
|
|||
|
|
reactMaxStep: reactMaxStep,
|
|||
|
|
log: log,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text string) (string, error) {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Infof("handle message chat_id=%s user_id=%s text_len=%d", chatID, userID, len(text))
|
|||
|
|
}
|
|||
|
|
if err := o.store.SaveMessage(chatID, userID, "user", text); err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Errorf("save user message failed chat_id=%s err=%v", chatID, err)
|
|||
|
|
}
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if strings.HasPrefix(strings.TrimSpace(text), "/tool ") {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Debugf("detected tool command chat_id=%s", chatID)
|
|||
|
|
}
|
|||
|
|
response, err := o.handleToolCommand(ctx, strings.TrimSpace(strings.TrimPrefix(text, "/tool ")))
|
|||
|
|
if err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Errorf("tool command failed chat_id=%s err=%v", chatID, err)
|
|||
|
|
}
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Errorf("save assistant tool response failed chat_id=%s err=%v", chatID, err)
|
|||
|
|
}
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Infof("tool command success chat_id=%s response_len=%d", chatID, len(response))
|
|||
|
|
}
|
|||
|
|
return response, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
recent, err := o.store.LoadRecent(chatID, 16)
|
|||
|
|
if err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Errorf("load recent failed chat_id=%s err=%v", chatID, err)
|
|||
|
|
}
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
compressed := memory.CompressForPrompt(recent, 6000)
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Debugf("prompt context prepared chat_id=%s recent_count=%d compressed_len=%d", chatID, len(recent), len(compressed))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
response, err := o.runReAct(ctx, compressed, text)
|
|||
|
|
if err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Errorf("llm generate failed chat_id=%s err=%v", chatID, err)
|
|||
|
|
}
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Errorf("save assistant response failed chat_id=%s err=%v", chatID, err)
|
|||
|
|
}
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Infof("message handled chat_id=%s response_len=%d", chatID, len(response))
|
|||
|
|
}
|
|||
|
|
return response, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// reactDecision is the JSON object the model is asked to emit on every ReAct
// step. All four fields are plain strings decoded from the keys named in the
// struct tags.
type reactDecision struct {
	// Thought is the model's reasoning for this step.
	Thought string `json:"thought"`
	// Action names the tool to call; the ReAct loop treats "none" (or an
	// empty value) as "answer now".
	Action string `json:"action"`
	// ActionInput is passed verbatim to the selected tool.
	ActionInput string `json:"action_input"`
	// Final is the reply text used when no tool is requested.
	Final string `json:"final"`
}
|
|||
|
|
|
|||
|
|
func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInput string) (string, error) {
|
|||
|
|
systemPrompt := strings.Join([]string{
|
|||
|
|
"你是一个个人自动化助手,必须遵循如下人格设定并保持一致:",
|
|||
|
|
o.soul,
|
|||
|
|
"",
|
|||
|
|
"当前可用 skills 文档:",
|
|||
|
|
o.skillsDoc,
|
|||
|
|
"",
|
|||
|
|
"你必须使用 ReAct 模式做决策。",
|
|||
|
|
"如果问题需要外部信息(如文件系统、目录内容、命令执行),优先通过工具获取证据再回答。",
|
|||
|
|
"当用户询问目录中文件时,应优先使用 shell 工具(例如 ls/find)。",
|
|||
|
|
"你的输出必须是 JSON,对象字段为 thought, action, action_input, final。",
|
|||
|
|
"规则:",
|
|||
|
|
"1) 当需要调工具时:final 置空,action 为 shell 或 file,action_input 为工具输入。",
|
|||
|
|
"2) 当可以最终回答时:action 置 none,action_input 置空,final 填最终回复。",
|
|||
|
|
"3) 不要输出 JSON 之外内容。",
|
|||
|
|
}, "\n")
|
|||
|
|
|
|||
|
|
scratchpad := ""
|
|||
|
|
for step := 1; step <= o.reactMaxStep; step++ {
|
|||
|
|
prompt := strings.Join([]string{
|
|||
|
|
"历史上下文:",
|
|||
|
|
compressedContext,
|
|||
|
|
"",
|
|||
|
|
"用户问题:",
|
|||
|
|
userInput,
|
|||
|
|
"",
|
|||
|
|
"当前推理记录(按时间顺序):",
|
|||
|
|
scratchpad,
|
|||
|
|
"",
|
|||
|
|
fmt.Sprintf("请输出下一步 JSON 决策。当前步骤: %d/%d", step, o.reactMaxStep),
|
|||
|
|
}, "\n")
|
|||
|
|
|
|||
|
|
raw, err := o.llm.Generate(ctx, systemPrompt, prompt)
|
|||
|
|
if err != nil {
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
decision, err := parseDecision(raw)
|
|||
|
|
if err != nil {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Warnf("react parse failed, use raw as final err=%v", err)
|
|||
|
|
}
|
|||
|
|
return strings.TrimSpace(raw), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
action := strings.ToLower(strings.TrimSpace(decision.Action))
|
|||
|
|
if action == "" {
|
|||
|
|
action = "none"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if action == "none" {
|
|||
|
|
finalText := strings.TrimSpace(decision.Final)
|
|||
|
|
if finalText == "" {
|
|||
|
|
finalText = "我已完成思考,但当前没有足够信息给出稳定结论。"
|
|||
|
|
}
|
|||
|
|
return finalText, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
tool, ok := o.tools.Get(action)
|
|||
|
|
if !ok {
|
|||
|
|
scratchpad += fmt.Sprintf("Step %d Thought: %s\nStep %d Observation: tool %s 不存在\n", step, decision.Thought, step, action)
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
toolOut, toolErr := tool.Call(ctx, decision.ActionInput)
|
|||
|
|
obs := strings.TrimSpace(toolOut)
|
|||
|
|
if obs == "" {
|
|||
|
|
obs = "(empty output)"
|
|||
|
|
}
|
|||
|
|
if toolErr != nil {
|
|||
|
|
obs = obs + "\nERROR: " + toolErr.Error()
|
|||
|
|
}
|
|||
|
|
if len(obs) > 2000 {
|
|||
|
|
obs = obs[:2000]
|
|||
|
|
}
|
|||
|
|
scratchpad += fmt.Sprintf("Step %d Thought: %s\nStep %d Action: %s\nStep %d ActionInput: %s\nStep %d Observation: %s\n", step, decision.Thought, step, action, step, decision.ActionInput, step, obs)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return "我尝试了多轮思考与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func parseDecision(raw string) (reactDecision, error) {
|
|||
|
|
raw = strings.TrimSpace(raw)
|
|||
|
|
raw = strings.TrimPrefix(raw, "```json")
|
|||
|
|
raw = strings.TrimPrefix(raw, "```")
|
|||
|
|
raw = strings.TrimSuffix(raw, "```")
|
|||
|
|
raw = strings.TrimSpace(raw)
|
|||
|
|
|
|||
|
|
start := strings.Index(raw, "{")
|
|||
|
|
end := strings.LastIndex(raw, "}")
|
|||
|
|
if start < 0 || end < start {
|
|||
|
|
return reactDecision{}, fmt.Errorf("no json object found")
|
|||
|
|
}
|
|||
|
|
raw = raw[start : end+1]
|
|||
|
|
|
|||
|
|
var out reactDecision
|
|||
|
|
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
|||
|
|
return reactDecision{}, err
|
|||
|
|
}
|
|||
|
|
return out, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (o *Orchestrator) handleToolCommand(ctx context.Context, payload string) (string, error) {
|
|||
|
|
parts := strings.SplitN(payload, " ", 2)
|
|||
|
|
if len(parts) < 2 {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Warnf("invalid tool command payload=%q", payload)
|
|||
|
|
}
|
|||
|
|
return "", fmt.Errorf("tool command format: /tool <name> <input>")
|
|||
|
|
}
|
|||
|
|
name := strings.TrimSpace(parts[0])
|
|||
|
|
input := parts[1]
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Debugf("dispatch tool name=%s input_len=%d", name, len(input))
|
|||
|
|
}
|
|||
|
|
t, ok := o.tools.Get(name)
|
|||
|
|
if !ok {
|
|||
|
|
if o.log != nil {
|
|||
|
|
o.log.Warnf("unknown tool requested name=%s", name)
|
|||
|
|
}
|
|||
|
|
return "", fmt.Errorf("unknown tool: %s", name)
|
|||
|
|
}
|
|||
|
|
return t.Call(ctx, input)
|
|||
|
|
}
|