feat: add workspace-isolated toolhost runtime and capability-gap skill loop
This commit is contained in:
@@ -5,7 +5,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"laodingbot/internal/knowledge"
|
||||
"laodingbot/internal/llm"
|
||||
@@ -20,9 +23,14 @@ type Orchestrator struct {
|
||||
tools *tools.Registry
|
||||
soul string
|
||||
skills []knowledge.Skill
|
||||
skillsDoc string
|
||||
skillsDir string
|
||||
autoSkillDir string
|
||||
gapDraftTriggerCount int
|
||||
gapLookbackDuration time.Duration
|
||||
reactMaxStep int
|
||||
enableCapabilityGap bool
|
||||
log *logger.Logger
|
||||
skillsMu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewOrchestrator(
|
||||
@@ -31,33 +39,66 @@ func NewOrchestrator(
|
||||
registry *tools.Registry,
|
||||
soul string,
|
||||
skills []knowledge.Skill,
|
||||
skillsDoc string,
|
||||
skillsDir string,
|
||||
reactMaxStep int,
|
||||
enableCapabilityGap bool,
|
||||
autoSkillDir string,
|
||||
gapDraftTriggerCount int,
|
||||
gapLookbackDuration time.Duration,
|
||||
log *logger.Logger,
|
||||
) *Orchestrator {
|
||||
if reactMaxStep <= 0 {
|
||||
reactMaxStep = 4
|
||||
}
|
||||
if gapDraftTriggerCount <= 0 {
|
||||
gapDraftTriggerCount = 3
|
||||
}
|
||||
if gapLookbackDuration <= 0 {
|
||||
gapLookbackDuration = 7 * 24 * time.Hour
|
||||
}
|
||||
if strings.TrimSpace(autoSkillDir) == "" {
|
||||
autoSkillDir = skillsDir
|
||||
}
|
||||
return &Orchestrator{
|
||||
llm: llmClient,
|
||||
store: store,
|
||||
tools: registry,
|
||||
soul: soul,
|
||||
skills: skills,
|
||||
skillsDoc: skillsDoc,
|
||||
skillsDir: skillsDir,
|
||||
autoSkillDir: autoSkillDir,
|
||||
gapDraftTriggerCount: gapDraftTriggerCount,
|
||||
gapLookbackDuration: gapLookbackDuration,
|
||||
reactMaxStep: reactMaxStep,
|
||||
enableCapabilityGap: enableCapabilityGap,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text string) (string, error) {
|
||||
traceID := logger.NewTraceID()
|
||||
ctx = logger.WithTraceID(ctx, traceID)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
if o.log != nil {
|
||||
o.log.Infof("handle message chat_id=%s user_id=%s text_len=%d", chatID, userID, len(text))
|
||||
o.log.Debugf("handle message text=%q", text)
|
||||
o.log.Infof("%s handle message chat_id=%s user_id=%s text_len=%d", traceLogPrefix, chatID, userID, len(text))
|
||||
o.log.Debugf("%s handle message text=%q", traceLogPrefix, text)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(text), "/reload_skills") {
|
||||
if err := o.ReloadSkills(); err != nil {
|
||||
return "技能热加载失败: " + err.Error(), nil
|
||||
}
|
||||
return "技能已热加载完成。", nil
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(text), "/capability_gaps") {
|
||||
report, err := o.BuildCapabilityGapReport(10)
|
||||
if err != nil {
|
||||
return "缺口报告生成失败: " + err.Error(), nil
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
if err := o.store.SaveMessage(chatID, userID, "user", text); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("save user message failed chat_id=%s err=%v", chatID, err)
|
||||
o.log.Errorf("%s save user message failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
@@ -65,50 +106,59 @@ func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text s
|
||||
recent, err := o.store.LoadRecent(chatID, 16)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("load recent failed chat_id=%s err=%v", chatID, err)
|
||||
o.log.Errorf("%s load recent failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
compressed := memory.CompressForPrompt(recent, 6000)
|
||||
if o.log != nil {
|
||||
o.log.Debugf("prompt context prepared chat_id=%s recent_count=%d compressed_len=%d", chatID, len(recent), len(compressed))
|
||||
o.log.Debugf("%s prompt context prepared chat_id=%s recent_count=%d compressed_len=%d", traceLogPrefix, chatID, len(recent), len(compressed))
|
||||
}
|
||||
|
||||
matchedSkills := o.matchSkills(ctx, compressed, text)
|
||||
if len(matchedSkills) == 0 {
|
||||
if bootstrap, ok := o.findSkillByKeyword("创建skill", "skill builder", "skill 创建", "构建技能"); ok {
|
||||
matchedSkills = []knowledge.Skill{bootstrap}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s fallback bootstrap skill selected name=%s", traceLogPrefix, bootstrap.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var response string
|
||||
if len(matchedSkills) == 0 {
|
||||
if o.log != nil {
|
||||
o.log.Infof("no skill matched; use direct llm chat_id=%s", chatID)
|
||||
o.log.Infof("%s no skill matched; use direct llm chat_id=%s", traceLogPrefix, chatID)
|
||||
}
|
||||
o.emitCapabilityGap(chatID, userID, text, "no_skill_matched")
|
||||
response, err = o.runDirectLLM(ctx, compressed, text)
|
||||
} else {
|
||||
if o.log != nil {
|
||||
names := make([]string, 0, len(matchedSkills))
|
||||
for _, s := range matchedSkills {
|
||||
names = append(names, s.Name)
|
||||
o.log.Infof("skill selected name=%s source=%s", s.Name, s.Source)
|
||||
o.log.Debugf("skill selected content name=%s content=%q", s.Name, s.Content)
|
||||
o.log.Infof("%s skill selected name=%s source=%s", traceLogPrefix, s.Name, s.Source)
|
||||
o.log.Debugf("%s skill selected content name=%s content=%q", traceLogPrefix, s.Name, s.Content)
|
||||
}
|
||||
o.log.Infof("skills matched chat_id=%s skills=%s", chatID, strings.Join(names, ","))
|
||||
o.log.Infof("%s skills matched chat_id=%s skills=%s", traceLogPrefix, chatID, strings.Join(names, ","))
|
||||
}
|
||||
response, err = o.runReAct(ctx, compressed, text, matchedSkills)
|
||||
response, err = o.runReAct(ctx, chatID, userID, compressed, text, matchedSkills)
|
||||
}
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("message generation failed chat_id=%s err=%v", chatID, err)
|
||||
o.log.Errorf("%s message generation failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("save assistant response failed chat_id=%s err=%v", chatID, err)
|
||||
o.log.Errorf("%s save assistant response failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("message handled chat_id=%s response_len=%d", chatID, len(response))
|
||||
o.log.Infof("%s message handled chat_id=%s response_len=%d", traceLogPrefix, chatID, len(response))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
@@ -140,7 +190,9 @@ type reactDecision struct {
|
||||
Final string `json:"final"`
|
||||
}
|
||||
|
||||
func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInput string, selectedSkills []knowledge.Skill) (string, error) {
|
||||
func (o *Orchestrator) runReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, selectedSkills []knowledge.Skill) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
selectedSkillsDoc := formatSkills(selectedSkills)
|
||||
toolDoc := o.formatToolDoc()
|
||||
if o.log != nil {
|
||||
@@ -148,9 +200,9 @@ func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInpu
|
||||
for _, s := range selectedSkills {
|
||||
names = append(names, s.Name)
|
||||
}
|
||||
o.log.Infof("react start steps=%d skills=%s", o.reactMaxStep, strings.Join(names, ","))
|
||||
o.log.Debugf("react selected_skills_doc=%q", selectedSkillsDoc)
|
||||
o.log.Debugf("react tools_doc=%q", toolDoc)
|
||||
o.log.Infof("%s react start steps=%d skills=%s", traceLogPrefix, o.reactMaxStep, strings.Join(names, ","))
|
||||
o.log.Debugf("%s react selected_skills_doc=%q", traceLogPrefix, selectedSkillsDoc)
|
||||
o.log.Debugf("%s react tools_doc=%q", traceLogPrefix, toolDoc)
|
||||
}
|
||||
|
||||
systemPrompt := strings.Join([]string{
|
||||
@@ -176,8 +228,8 @@ func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInpu
|
||||
scratchpad := ""
|
||||
for step := 1; step <= o.reactMaxStep; step++ {
|
||||
if o.log != nil {
|
||||
o.log.Infof("react step start step=%d/%d", step, o.reactMaxStep)
|
||||
o.log.Debugf("react scratchpad_before step=%d content=%q", step, scratchpad)
|
||||
o.log.Infof("%s react step start step=%d/%d", traceLogPrefix, step, o.reactMaxStep)
|
||||
o.log.Debugf("%s react scratchpad_before step=%d content=%q", traceLogPrefix, step, scratchpad)
|
||||
}
|
||||
prompt := strings.Join([]string{
|
||||
"历史上下文:",
|
||||
@@ -197,17 +249,18 @@ func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInpu
|
||||
return "", err
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("react step llm output step=%d raw=%q", step, raw)
|
||||
o.log.Infof("%s react step llm output step=%d raw=%q", traceLogPrefix, step, raw)
|
||||
}
|
||||
decision, err := parseDecision(raw)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("react parse failed, use raw as final err=%v", err)
|
||||
o.log.Warnf("%s react parse failed, fallback to direct llm err=%v", traceLogPrefix, err)
|
||||
}
|
||||
return strings.TrimSpace(raw), nil
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "react_parse_failed")
|
||||
return o.runDirectLLM(ctx, compressedContext, userInput)
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("react step decision step=%d thought=%q action=%q action_input=%q final=%q", step, decision.Thought, decision.Action, decision.ActionInput, decision.Final)
|
||||
o.log.Infof("%s react step decision step=%d thought=%q action=%q action_input=%q final=%q", traceLogPrefix, step, decision.Thought, decision.Action, decision.ActionInput, decision.Final)
|
||||
}
|
||||
|
||||
action := strings.ToLower(strings.TrimSpace(decision.Action))
|
||||
@@ -221,7 +274,7 @@ func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInpu
|
||||
finalText = "我已完成思考,但当前没有足够信息给出稳定结论。"
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("react final step=%d final=%q", step, finalText)
|
||||
o.log.Infof("%s react final step=%d final=%q", traceLogPrefix, step, finalText)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
@@ -229,37 +282,45 @@ func (o *Orchestrator) runReAct(ctx context.Context, compressedContext, userInpu
|
||||
tool, ok := o.tools.Get(action)
|
||||
if !ok {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("react step tool missing step=%d tool=%s", step, action)
|
||||
o.log.Warnf("%s react step tool missing step=%d tool=%s", traceLogPrefix, step, action)
|
||||
}
|
||||
scratchpad += fmt.Sprintf("Step %d Thought: %s\nStep %d Observation: tool %s 不存在\n", step, decision.Thought, step, action)
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + formatToolErrorObservation("TOOL_NOT_FOUND", action, "tool not found") + "\n"
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_not_found:"+action)
|
||||
continue
|
||||
}
|
||||
|
||||
toolOut, toolErr := tool.Call(ctx, decision.ActionInput)
|
||||
if o.log != nil {
|
||||
o.log.Infof("react step tool call step=%d tool=%s input=%q", step, action, decision.ActionInput)
|
||||
o.log.Infof("%s react step tool call step=%d tool=%s input=%q", traceLogPrefix, step, action, decision.ActionInput)
|
||||
}
|
||||
obs := strings.TrimSpace(toolOut)
|
||||
if obs == "" {
|
||||
obs = "(empty output)"
|
||||
}
|
||||
if toolErr != nil {
|
||||
obs = obs + "\nERROR: " + toolErr.Error()
|
||||
obs = formatToolErrorObservation("TOOL_EXEC_ERROR", action, toolErr.Error()) + "\nOUTPUT:\n" + obs
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_call_failed:"+action)
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("react step observation step=%d tool=%s observation=%q", step, action, obs)
|
||||
o.log.Infof("%s react step observation step=%d tool=%s observation=%q", traceLogPrefix, step, action, obs)
|
||||
}
|
||||
if len(obs) > 2000 {
|
||||
obs = obs[:2000]
|
||||
}
|
||||
scratchpad += fmt.Sprintf("Step %d Thought: %s\nStep %d Action: %s\nStep %d ActionInput: %s\nStep %d Observation: %s\n", step, decision.Thought, step, action, step, decision.ActionInput, step, obs)
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Action: " + action + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " ActionInput: " + decision.ActionInput + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + obs + "\n"
|
||||
}
|
||||
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "react_step_exhausted")
|
||||
return "我尝试了多轮思考与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) matchSkills(ctx context.Context, compressedContext, userInput string) []knowledge.Skill {
|
||||
if len(o.skills) == 0 {
|
||||
skills := o.getSkillsSnapshot()
|
||||
if len(skills) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -277,7 +338,7 @@ func (o *Orchestrator) matchSkills(ctx context.Context, compressedContext, userI
|
||||
|
||||
userPrompt := strings.Join([]string{
|
||||
"候选技能:",
|
||||
formatSkillCatalog(o.skills),
|
||||
formatSkillCatalog(skills),
|
||||
"",
|
||||
"历史上下文:",
|
||||
compressedContext,
|
||||
@@ -316,7 +377,7 @@ func (o *Orchestrator) matchSkills(ctx context.Context, compressedContext, userI
|
||||
if _, ok := seen[name]; ok {
|
||||
continue
|
||||
}
|
||||
for _, skill := range o.skills {
|
||||
for _, skill := range skills {
|
||||
if strings.ToLower(strings.TrimSpace(skill.Name)) == name {
|
||||
picked = append(picked, skill)
|
||||
seen[name] = struct{}{}
|
||||
@@ -338,28 +399,132 @@ func (o *Orchestrator) matchSkills(ctx context.Context, compressedContext, userI
|
||||
return picked
|
||||
}
|
||||
|
||||
func parseDecision(raw string) (reactDecision, error) {
|
||||
raw = normalizeJSON(raw)
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start < 0 || end < start {
|
||||
return reactDecision{}, fmt.Errorf("no json object found")
|
||||
func (o *Orchestrator) emitCapabilityGap(chatID, userID, intent, reason string) {
|
||||
if !o.enableCapabilityGap {
|
||||
return
|
||||
}
|
||||
intent = strings.TrimSpace(intent)
|
||||
reason = strings.TrimSpace(reason)
|
||||
if intent == "" || reason == "" {
|
||||
return
|
||||
}
|
||||
if len(intent) > 1000 {
|
||||
intent = intent[:1000]
|
||||
}
|
||||
if len(reason) > 240 {
|
||||
reason = reason[:240]
|
||||
}
|
||||
if err := o.store.SaveCapabilityGap(chatID, userID, intent, reason); err != nil && o.log != nil {
|
||||
o.log.Warnf("save capability gap failed chat_id=%s user_id=%s err=%v", chatID, userID, err)
|
||||
return
|
||||
}
|
||||
raw = raw[start : end+1]
|
||||
|
||||
var out reactDecision
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
return reactDecision{}, err
|
||||
clusters, err := o.store.TopCapabilityGapClusters(20, time.Now().UTC().Add(-o.gapLookbackDuration))
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("query capability gap clusters failed err=%v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, c := range clusters {
|
||||
if c.Count < o.gapDraftTriggerCount {
|
||||
continue
|
||||
}
|
||||
path, created, draftErr := knowledge.GenerateSkillDraft(c, o.autoSkillDir)
|
||||
if draftErr != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("generate skill draft failed intent_key=%s reason=%s err=%v", c.IntentKey, c.Reason, draftErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if created && o.log != nil {
|
||||
o.log.Infof("capability gap draft generated path=%s intent_key=%s reason=%s count=%d", path, c.IntentKey, c.Reason, c.Count)
|
||||
}
|
||||
if created {
|
||||
if reloadErr := o.ReloadSkills(); reloadErr != nil && o.log != nil {
|
||||
o.log.Warnf("auto reload skills failed after generation path=%s err=%v", path, reloadErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func normalizeJSON(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, "```json")
|
||||
raw = strings.TrimPrefix(raw, "```")
|
||||
raw = strings.TrimSuffix(raw, "```")
|
||||
return strings.TrimSpace(raw)
|
||||
func (o *Orchestrator) ReloadSkills() error {
|
||||
skills, err := knowledge.LoadSkillSet(o.skillsDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
o.skillsMu.Lock()
|
||||
o.skills = skills
|
||||
o.skillsMu.Unlock()
|
||||
if o.log != nil {
|
||||
o.log.Infof("skills hot reloaded count=%d dir=%s", len(skills), o.skillsDir)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) getSkillsSnapshot() []knowledge.Skill {
|
||||
o.skillsMu.RLock()
|
||||
defer o.skillsMu.RUnlock()
|
||||
out := make([]knowledge.Skill, len(o.skills))
|
||||
copy(out, o.skills)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *Orchestrator) BuildCapabilityGapReport(limit int) (string, error) {
|
||||
clusters, err := o.store.TopCapabilityGapClusters(limit, time.Now().UTC().Add(-o.gapLookbackDuration))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(clusters) == 0 {
|
||||
return "最近没有采集到能力缺口记录。", nil
|
||||
}
|
||||
b := strings.Builder{}
|
||||
b.WriteString("高频能力缺口清单:\n")
|
||||
for i, c := range clusters {
|
||||
line := fmt.Sprintf("%d) intent=%s | reason=%s | count=%d | last_seen=%s\n", i+1, c.IntentKey, c.Reason, c.Count, c.LastSeenAt.Format("2006-01-02 15:04:05"))
|
||||
b.WriteString(line)
|
||||
}
|
||||
b.WriteString("\n草稿目录:")
|
||||
b.WriteString(o.autoSkillDir)
|
||||
b.WriteString("\n系统会在达到阈值后自动生成并热加载技能;你也可以手动发送 /reload_skills。")
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) findSkillByKeyword(keywords ...string) (knowledge.Skill, bool) {
|
||||
if len(keywords) == 0 {
|
||||
return knowledge.Skill{}, false
|
||||
}
|
||||
skills := o.getSkillsSnapshot()
|
||||
for _, s := range skills {
|
||||
name := strings.ToLower(strings.TrimSpace(s.Name))
|
||||
content := strings.ToLower(strings.TrimSpace(s.Content))
|
||||
for _, kw := range keywords {
|
||||
kw = strings.ToLower(strings.TrimSpace(kw))
|
||||
if kw == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(name, kw) || strings.Contains(content, kw) {
|
||||
return s, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return knowledge.Skill{}, false
|
||||
}
|
||||
|
||||
func formatToolErrorObservation(code, action, reason string) string {
|
||||
code = strings.TrimSpace(code)
|
||||
action = strings.TrimSpace(action)
|
||||
reason = strings.TrimSpace(reason)
|
||||
if code == "" {
|
||||
code = "TOOL_EXEC_ERROR"
|
||||
}
|
||||
if action == "" {
|
||||
action = "unknown"
|
||||
}
|
||||
if reason == "" {
|
||||
reason = "unknown error"
|
||||
}
|
||||
return "ERROR_CODE=" + code + "; TOOL=" + action + "; REASON=" + reason
|
||||
}
|
||||
|
||||
func formatSkills(skills []knowledge.Skill) string {
|
||||
|
||||
31
internal/agent/react_parser.go
Normal file
31
internal/agent/react_parser.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parseDecision(raw string) (reactDecision, error) {
|
||||
raw = normalizeJSON(raw)
|
||||
start := strings.Index(raw, "{")
|
||||
end := strings.LastIndex(raw, "}")
|
||||
if start < 0 || end < start {
|
||||
return reactDecision{}, fmt.Errorf("no json object found")
|
||||
}
|
||||
raw = raw[start : end+1]
|
||||
|
||||
var out reactDecision
|
||||
if err := json.Unmarshal([]byte(raw), &out); err != nil {
|
||||
return reactDecision{}, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func normalizeJSON(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
raw = strings.TrimPrefix(raw, "```json")
|
||||
raw = strings.TrimPrefix(raw, "```")
|
||||
raw = strings.TrimSuffix(raw, "```")
|
||||
return strings.TrimSpace(raw)
|
||||
}
|
||||
32
internal/agent/react_parser_test.go
Normal file
32
internal/agent/react_parser_test.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package agent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseDecisionPlainJSON(t *testing.T) {
|
||||
raw := `{"thought":"t","action":"none","action_input":"","final":"ok"}`
|
||||
got, err := parseDecision(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parseDecision error: %v", err)
|
||||
}
|
||||
if got.Action != "none" || got.Final != "ok" {
|
||||
t.Fatalf("unexpected decision: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecisionCodeFence(t *testing.T) {
|
||||
raw := "```json\n{\"thought\":\"t\",\"action\":\"shell\",\"action_input\":\"ls\",\"final\":\"\"}\n```"
|
||||
got, err := parseDecision(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parseDecision error: %v", err)
|
||||
}
|
||||
if got.Action != "shell" || got.ActionInput != "ls" {
|
||||
t.Fatalf("unexpected decision: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDecisionInvalid(t *testing.T) {
|
||||
_, err := parseDecision("not json")
|
||||
if err == nil {
|
||||
t.Fatal("expected parse error")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user