feat: implement streaming chat, skill routing, and SAFe PI planning tools
- Add /api/chat/stream endpoint with Server-Sent Events (SSE) for real-time message streaming
  * Implement StreamEvent types (thought, tool_call, tool_result, final, error)
  * Add StreamEventCallback mechanism for event propagation
  * Create StreamChatHandler in webui/bot with proper HTTP headers and flushing
- Implement LLM-based skill router for intelligent capability selection
  * Add optional routerLLM client for semantic routing
  * Implement routeSkillsWithLLM() to match user intent to available skills
  * Add matchSkillsByName() for fuzzy skill matching
  * Update buildUnifiedSystemPrompt() to use routed skills
- Add streaming support to ReAct pipeline
  * Implement runUnifiedReActStream() for streaming thought/action/observation
  * Emit StreamEvent at each ReAct step
  * Support callback error handling in streaming mode
- Integrate three new DevOps tools
  * tools/filedoc: Extract document content from file_id via OpenAI
  * tools/giteaticket: Create Gitea issues from PI plan items with SAFe metadata
  * tools/piplan: Publish PI planning blueprints with dependency tracking
- Add SAFe PI Planning skill
  * Implement PM/SA/RTE (iron triangle) workflow
  * Support Feature, Enabler, and Dependency definition
  * Automatic task decomposition and Gitea integration
- Create frontend integration documentation
  * Complete SSE protocol specification
  * TypeScript fetch + ReadableStream example
  * LLM-ready refactoring template for other projects
- Simplify file handling
  * Remove legacy file context structures and dual-mode processing
  * Consolidate file operations into UploadAndCacheFiles()
  * Remove FilePromptMode configuration and related complexity
- Update configuration
  * Add Router model support (LLM_ROUTER_MODEL)
  * Add Gitea configuration (BaseURL, Token, Owner, Repo)
  * Add WebSearch and additional tool infrastructure

Tests: All 22 test packages passing; 8/8 webui tests passing, including 3 new stream tests.
This commit is contained in:
@@ -18,9 +18,32 @@ import (
|
||||
"laodingbot/internal/tools"
|
||||
)
|
||||
|
||||
// StreamEventType identifies the kind of event pushed to the client while a reply is being streamed.
|
||||
type StreamEventType string
|
||||
|
||||
const (
|
||||
StreamEventTypeThought StreamEventType = "thought" // text content returned by the LLM at a step
|
||||
StreamEventTypeToolCall StreamEventType = "tool_call" // a tool invocation is about to run
|
||||
StreamEventTypeToolResult StreamEventType = "tool_result" // observation produced by a tool
|
||||
StreamEventTypeFinal StreamEventType = "final" // final answer
|
||||
StreamEventTypeError StreamEventType = "error" // error message
|
||||
)
|
||||
|
||||
// StreamEvent is a single event in the streamed output; it is serialized as JSON using the tags below.
|
||||
type StreamEvent struct {
|
||||
Type StreamEventType `json:"type"` // event kind, one of the StreamEventType constants
|
||||
Content string `json:"content"` // event payload text; always present in the JSON
|
||||
Step int `json:"step,omitempty"` // step number of the event; omitted from JSON when zero
|
||||
ToolName string `json:"tool_name,omitempty"` // tool the event refers to; omitted from JSON when empty
|
||||
}
|
||||
|
||||
|
||||
// StreamEventCallback is the callback type used to push a streaming event to the client.
|
||||
// It returns an error when the event could not be delivered.
|
||||
type StreamEventCallback func(event StreamEvent) error
|
||||
|
||||
// Orchestrator 负责协调和组合业务逻辑,包含 LLM 计算、上下文管理、技能匹配计算和工具调用。
|
||||
type Orchestrator struct {
|
||||
llm llm.Client
|
||||
routerLLM llm.Client // 可选:轻量路由模型,用于技能意图路由;为 nil 则仅用关键词匹配
|
||||
store *memory.SQLiteStore
|
||||
tools *tools.Registry
|
||||
soul string
|
||||
@@ -44,16 +67,10 @@ type pendingFileRef struct {
|
||||
MimeType string
|
||||
}
|
||||
|
||||
type filePromptContext struct {
|
||||
Summary string
|
||||
FatalReason string
|
||||
FileIDs []string
|
||||
Uploaded []pendingFileRef
|
||||
}
|
||||
|
||||
// NewOrchestrator 创建一个新的编排器对象,初始化关键路径和超时控制等。
|
||||
func NewOrchestrator(
|
||||
llmClient llm.Client,
|
||||
routerLLM llm.Client,
|
||||
store *memory.SQLiteStore,
|
||||
registry *tools.Registry,
|
||||
soul string,
|
||||
@@ -81,6 +98,7 @@ func NewOrchestrator(
|
||||
}
|
||||
return &Orchestrator{
|
||||
llm: llmClient,
|
||||
routerLLM: routerLLM,
|
||||
store: store,
|
||||
tools: registry,
|
||||
soul: soul,
|
||||
@@ -103,52 +121,88 @@ func NewOrchestrator(
|
||||
// - 是否需要调用工具(action + action_input)
|
||||
// 循环持续进行,直到 LLM 返回 is_final_answer=true。
|
||||
func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text string) (string, error) {
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, nil, false)
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text)
|
||||
}
|
||||
|
||||
// HandleMessageWithFiles 接收用户消息和文件,上传文件获取 file_id 并缓存,然后进入普通消息处理流程。
|
||||
func (o *Orchestrator) HandleMessageWithFiles(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, files, false)
|
||||
if len(files) > 0 {
|
||||
ids, err := o.UploadAndCacheFiles(ctx, chatID, userID, files)
|
||||
if err != nil && o.log != nil {
|
||||
o.log.Warnf("upload files failed chat_id=%s err=%v", chatID, err)
|
||||
}
|
||||
_ = ids
|
||||
}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "文件已接收。请继续发送你的问题。", nil
|
||||
}
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text)
|
||||
}
|
||||
|
||||
// HandleMessageWithFileIDs 接收用户文本与外部 file_id 列表,复用统一 ReAct 链路。
|
||||
// 该方法会先把 file_id 注入当前会话上下文,然后调用常规 HandleMessage 流程。
|
||||
func (o *Orchestrator) HandleMessageWithFileIDs(ctx context.Context, chatID, userID, text string, fileIDs []string) (string, error) {
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) > 0 {
|
||||
refs := make([]pendingFileRef, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
refs = append(refs, pendingFileRef{ID: id})
|
||||
}
|
||||
o.appendPendingFiles(chatID, userID, refs)
|
||||
// HandleMessageStream 接收用户消息并通过流式方式返回回复。
|
||||
// 通过 callback 推送实时事件,包括思考过程、工具调用、工具结果和最终答案。
|
||||
func (o *Orchestrator) HandleMessageStream(ctx context.Context, chatID, userID, text string, callback StreamEventCallback) (string, error) {
|
||||
if callback == nil {
|
||||
return "", fmt.Errorf("stream callback is required")
|
||||
}
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, nil, true)
|
||||
return o.handleMessageStreamInternal(ctx, chatID, userID, text, callback)
|
||||
}
|
||||
|
||||
// HandleMessageStreamWithFiles 接收用户消息和文件,上传文件后进入流式处理流程。
|
||||
func (o *Orchestrator) HandleMessageStreamWithFiles(ctx context.Context, chatID, userID, text string, files []llm.InputFile, callback StreamEventCallback) (string, error) {
|
||||
if callback == nil {
|
||||
return "", fmt.Errorf("stream callback is required")
|
||||
}
|
||||
if len(files) > 0 {
|
||||
ids, err := o.UploadAndCacheFiles(ctx, chatID, userID, files)
|
||||
if err != nil && o.log != nil {
|
||||
o.log.Warnf("upload files failed chat_id=%s err=%v", chatID, err)
|
||||
}
|
||||
_ = ids
|
||||
}
|
||||
if strings.TrimSpace(text) == "" {
|
||||
return "文件已接收。请继续发送你的问题。", nil
|
||||
}
|
||||
return o.handleMessageStreamInternal(ctx, chatID, userID, text, callback)
|
||||
}
|
||||
|
||||
// UploadAndCacheFiles 上传文件到 LLM 并缓存 file_id,供后续同会话文本问答复用。
|
||||
// 该方法不会写入 messages 表,仅更新内存中的 pending file 上下文。
|
||||
func (o *Orchestrator) UploadAndCacheFiles(ctx context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
|
||||
if len(files) == 0 {
|
||||
return nil, fmt.Errorf("no files provided")
|
||||
}
|
||||
uploadCtx := o.prepareFilePromptContext(ctx, files, nil)
|
||||
if strings.TrimSpace(uploadCtx.FatalReason) != "" {
|
||||
return nil, fmt.Errorf(uploadCtx.FatalReason)
|
||||
uploader, ok := o.llm.(llm.FileUploader)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("当前 LLM 客户端不支持文件上传接口")
|
||||
}
|
||||
ids := nonEmptyIDs(uploadCtx.FileIDs)
|
||||
if len(ids) == 0 {
|
||||
return nil, fmt.Errorf("file upload completed but no valid file_id returned")
|
||||
var ids []string
|
||||
var refs []pendingFileRef
|
||||
for i, f := range files {
|
||||
if strings.TrimSpace(f.FileName) == "" || len(f.Content) == 0 {
|
||||
return nil, fmt.Errorf("file[%d] 缺少文件名或内容", i+1)
|
||||
}
|
||||
fileID, err := uploader.UploadFile(ctx, f, "file-extract")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("file[%d] name=%s 上传失败: %w", i+1, f.FileName, err)
|
||||
}
|
||||
ids = append(ids, fileID)
|
||||
refs = append(refs, pendingFileRef{
|
||||
ID: fileID,
|
||||
Name: strings.TrimSpace(f.FileName),
|
||||
MimeType: defaultIfEmpty(strings.TrimSpace(f.MimeType), "application/octet-stream"),
|
||||
})
|
||||
}
|
||||
o.appendPendingFiles(chatID, userID, uploadCtx.toPendingRefs())
|
||||
o.appendPendingFiles(chatID, userID, refs)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile, appendFileIDText bool) (string, error) {
|
||||
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string) (string, error) {
|
||||
// 为链路追踪设置唯一的 TraceID
|
||||
traceID := logger.NewTraceID()
|
||||
ctx = logger.WithTraceID(ctx, traceID)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s handle message chat_id=%s user_id=%s text_len=%d files=%d", traceLogPrefix, chatID, userID, len(text), len(files))
|
||||
o.log.Infof("%s handle message chat_id=%s user_id=%s text_len=%d", traceLogPrefix, chatID, userID, len(text))
|
||||
o.log.Debugf("%s handle message text=%q", traceLogPrefix, text)
|
||||
}
|
||||
|
||||
@@ -169,38 +223,6 @@ func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID
|
||||
return report, nil
|
||||
}
|
||||
|
||||
trimmedText := strings.TrimSpace(text)
|
||||
isFileOnly := len(files) > 0 && trimmedText == ""
|
||||
|
||||
if isFileOnly {
|
||||
if err := o.store.SaveMessage(chatID, userID, "user", "[FILE_UPLOAD]"); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s save file-only user marker failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
uploadCtx := o.prepareFilePromptContext(ctx, files, nil)
|
||||
if strings.TrimSpace(uploadCtx.FatalReason) != "" {
|
||||
finalText := "文件上传失败,无法建立文档上下文。" + "\n" + uploadCtx.FatalReason
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", finalText); err != nil && o.log != nil {
|
||||
o.log.Warnf("%s save upload failure message failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
o.appendPendingFiles(chatID, userID, uploadCtx.toPendingRefs())
|
||||
finalText := o.buildFileUploadAck(uploadCtx)
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", finalText); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s save file upload ack failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s file-only message handled chat_id=%s cached_files=%d", traceLogPrefix, chatID, len(uploadCtx.FileIDs))
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
|
||||
// 保存用户消息到 SQLite 中
|
||||
if err := o.store.SaveMessage(chatID, userID, "user", text); err != nil {
|
||||
if o.log != nil {
|
||||
@@ -223,28 +245,13 @@ func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID
|
||||
}
|
||||
|
||||
// 进入统一 ReAct 循环
|
||||
pendingRefs := o.getPendingFiles(chatID, userID)
|
||||
fileCtx := o.prepareFilePromptContext(ctx, files, pendingRefs)
|
||||
if strings.TrimSpace(fileCtx.FatalReason) != "" {
|
||||
finalText := "文件上传失败,无法继续进行文档解析。" + "\n" + fileCtx.FatalReason
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", finalText); err != nil && o.log != nil {
|
||||
o.log.Warnf("%s save assistant failure message failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s stop before react due to file upload failure reason=%s", traceLogPrefix, fileCtx.FatalReason)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text, fileCtx, appendFileIDText)
|
||||
response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s message generation failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
if len(pendingRefs) > 0 {
|
||||
o.clearPendingFiles(chatID, userID)
|
||||
}
|
||||
|
||||
// 最终将机器人的回复也加入记忆缓存
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil {
|
||||
@@ -260,11 +267,89 @@ func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// handleMessageStreamInternal 处理流式消息的内部逻辑,类似于handleMessageInternal但支持流式回调
|
||||
func (o *Orchestrator) handleMessageStreamInternal(ctx context.Context, chatID, userID, text string, callback StreamEventCallback) (string, error) {
|
||||
// 为链路追踪设置唯一的 TraceID
|
||||
traceID := logger.NewTraceID()
|
||||
ctx = logger.WithTraceID(ctx, traceID)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s handle message stream chat_id=%s user_id=%s text_len=%d", traceLogPrefix, chatID, userID, len(text))
|
||||
o.log.Debugf("%s handle message stream text=%q", traceLogPrefix, text)
|
||||
}
|
||||
|
||||
// 处理特殊的重载指令
|
||||
if strings.EqualFold(strings.TrimSpace(text), "/reload_skills") {
|
||||
if err := o.ReloadSkills(); err != nil {
|
||||
return "技能热加载失败: " + err.Error(), nil
|
||||
}
|
||||
return "技能已热加载完成。", nil
|
||||
}
|
||||
|
||||
// 如果用户请求能力缺口报告,则生成报告格式化输出
|
||||
if strings.EqualFold(strings.TrimSpace(text), "/capability_gaps") {
|
||||
report, err := o.BuildCapabilityGapReport(10)
|
||||
if err != nil {
|
||||
return "缺口报告生成失败: " + err.Error(), nil
|
||||
}
|
||||
return report, nil
|
||||
}
|
||||
|
||||
// 保存用户消息到 SQLite 中
|
||||
if err := o.store.SaveMessage(chatID, userID, "user", text); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s save user message failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 读取最近的会话记忆并压缩成 Prompt 上下文
|
||||
recent, err := o.store.LoadRecent(chatID, 16)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s load recent failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
compressed := memory.CompressForPrompt(recent, 6000)
|
||||
if o.log != nil {
|
||||
o.log.Debugf("%s stream prompt context prepared chat_id=%s recent_count=%d compressed_len=%d", traceLogPrefix, chatID, len(recent), len(compressed))
|
||||
}
|
||||
|
||||
// 进入流式统一 ReAct 循环
|
||||
response, err := o.runUnifiedReActStream(ctx, chatID, userID, compressed, text, callback)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s stream message generation failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 最终将机器人的回复也加入记忆缓存
|
||||
if err := o.store.SaveMessage(chatID, userID, "assistant", response); err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s save assistant response failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s stream message handled chat_id=%s response_len=%d", traceLogPrefix, chatID, len(response))
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// buildUnifiedSystemPrompt 构建统一 ReAct 循环的 system prompt。
|
||||
// 工具定义通过 API 的 tools 字段传递;此处只需包含人格、技能、运行环境和思考指引。
|
||||
func (o *Orchestrator) buildUnifiedSystemPrompt(userInput string) string {
|
||||
// routedSkills 为 LLM 路由预选的技能列表;如果为 nil,则回退到关键词匹配。
|
||||
func (o *Orchestrator) buildUnifiedSystemPrompt(userInput string, routedSkills []knowledge.Skill) string {
|
||||
skillMetaDoc := o.formatSkillSummariesForPrompt()
|
||||
relevantSkillsDoc := o.formatSelectedSkillsForPrompt(userInput, nil)
|
||||
var relevantSkillsDoc string
|
||||
if routedSkills != nil {
|
||||
relevantSkillsDoc = o.formatSelectedSkillsForPrompt(userInput, routedSkills)
|
||||
} else {
|
||||
relevantSkillsDoc = o.formatSelectedSkillsForPrompt(userInput, nil)
|
||||
}
|
||||
runtimeDoc := formatRuntimeContextForPrompt()
|
||||
|
||||
return strings.Join([]string{
|
||||
@@ -292,20 +377,151 @@ func (o *Orchestrator) buildUnifiedSystemPrompt(userInput string) string {
|
||||
"",
|
||||
"===== 本轮相关技能(按用户问题筛选) =====",
|
||||
relevantSkillsDoc,
|
||||
"",
|
||||
"===== 关键约束 =====",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
// routeSkillsWithLLM 使用轻量 LLM 模型对用户输入进行语义路由,判断是否需要加载技能以及选择哪些技能。
|
||||
// 返回匹配到的技能列表(可能为空切片表示不需要技能,nil 表示调用失败应回退)。
|
||||
func (o *Orchestrator) routeSkillsWithLLM(ctx context.Context, userInput string) ([]knowledge.Skill, error) {
|
||||
traceLogPrefix := "trace_id=" + logger.TraceIDFromContext(ctx)
|
||||
|
||||
summaries := o.getSkillSummariesSnapshot()
|
||||
if len(summaries) == 0 {
|
||||
if o.log != nil {
|
||||
o.log.Debugf("%s skill router: no skills available, skip", traceLogPrefix)
|
||||
}
|
||||
return []knowledge.Skill{}, nil
|
||||
}
|
||||
|
||||
// 构建技能池描述
|
||||
skillPool := strings.Builder{}
|
||||
for _, s := range summaries {
|
||||
name := strings.TrimSpace(s.Name)
|
||||
desc := strings.TrimSpace(s.Description)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
skillPool.WriteString("- ")
|
||||
skillPool.WriteString(name)
|
||||
if desc != "" {
|
||||
skillPool.WriteString(": ")
|
||||
skillPool.WriteString(desc)
|
||||
}
|
||||
skillPool.WriteString("\n")
|
||||
}
|
||||
|
||||
routerSystemPrompt := strings.Join([]string{
|
||||
"你是一个意图路由器。根据用户输入,从技能池中挑选最合适的技能。",
|
||||
"",
|
||||
"规则:",
|
||||
"1. 如果用户的问题可以直接回答(闲聊、简单问答)或只需简单工具调用,设置 need_skills=false,selected_skills 为空数组。",
|
||||
"2. 如果用户的问题涉及专业流程、复杂任务或与某个技能高度相关,设置 need_skills=true 并选择最相关的技能名称。",
|
||||
"3. 最多选择 3 个技能。",
|
||||
"4. 仅返回 JSON,不要附加任何其他文字。",
|
||||
"",
|
||||
"可用技能池:",
|
||||
strings.TrimSpace(skillPool.String()),
|
||||
"",
|
||||
"输出格式(严格 JSON):",
|
||||
`{"need_skills": true, "selected_skills": ["技能名称1"], "reason": "简要说明"}`,
|
||||
}, "\n")
|
||||
|
||||
routerUserPrompt := "用户输入:" + userInput
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Debugf("%s skill router request: skills_count=%d input_len=%d", traceLogPrefix, len(summaries), len(userInput))
|
||||
}
|
||||
|
||||
raw, err := o.routerLLM.Generate(ctx, routerSystemPrompt, routerUserPrompt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("router llm call failed: %w", err)
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Debugf("%s skill router response: %s", traceLogPrefix, truncateForLog(raw, 500))
|
||||
}
|
||||
|
||||
decision, err := parseCapabilityRoute(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("router response parse failed: %w", err)
|
||||
}
|
||||
|
||||
if !decision.NeedSkills || len(decision.SelectedSkills) == 0 {
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s skill router: no skills needed, reason=%s", traceLogPrefix, decision.Reason)
|
||||
}
|
||||
return []knowledge.Skill{}, nil
|
||||
}
|
||||
|
||||
// 根据路由结果匹配完整技能内容
|
||||
allSkills := o.getSkillsSnapshot()
|
||||
selected := matchSkillsByName(allSkills, decision.SelectedSkills)
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s skill router: need_skills=true requested=%v matched=%d reason=%s",
|
||||
traceLogPrefix, decision.SelectedSkills, len(selected), decision.Reason)
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
// matchSkillsByName 根据名称列表从全量技能中模糊匹配。
|
||||
func matchSkillsByName(allSkills []knowledge.Skill, names []string) []knowledge.Skill {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
matched := make([]knowledge.Skill, 0, len(names))
|
||||
for _, wantName := range names {
|
||||
want := strings.ToLower(strings.TrimSpace(wantName))
|
||||
if want == "" {
|
||||
continue
|
||||
}
|
||||
for _, sk := range allSkills {
|
||||
skName := strings.ToLower(strings.TrimSpace(sk.Name))
|
||||
if skName == want || strings.Contains(skName, want) || strings.Contains(want, skName) {
|
||||
matched = append(matched, sk)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
// runUnifiedReAct 执行统一的 ReAct 循环,使用原生 function calling API。
|
||||
// messages 数组随交互动态增长:system → history → user → assistant(tool_calls) → tool → ...
|
||||
// 循环持续到 LLM 返回无 tool_calls 的纯文本回复(即最终回答)或达到安全上限。
|
||||
func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, fileCtx filePromptContext, appendFileIDText bool) (string, error) {
|
||||
func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
|
||||
systemPrompt := o.buildUnifiedSystemPrompt(userInput)
|
||||
// ===== LLM 意图路由:使用轻量模型判断是否需要加载技能 =====
|
||||
var routedSkills []knowledge.Skill
|
||||
if o.routerLLM != nil {
|
||||
routed, routeErr := o.routeSkillsWithLLM(ctx, userInput)
|
||||
if routeErr != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s skill router failed, fallback to keyword matching err=%v", traceLogPrefix, routeErr)
|
||||
}
|
||||
// 路由失败时 routedSkills 保持 nil,buildUnifiedSystemPrompt 回退到关键词匹配
|
||||
} else {
|
||||
routedSkills = routed
|
||||
if o.log != nil {
|
||||
names := make([]string, 0, len(routedSkills))
|
||||
for _, sk := range routedSkills {
|
||||
names = append(names, sk.Name)
|
||||
}
|
||||
o.log.Infof("%s skill router selected %d skills: %v", traceLogPrefix, len(routedSkills), names)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
systemPrompt := o.buildUnifiedSystemPrompt(userInput, routedSkills)
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s unified react start", traceLogPrefix)
|
||||
o.log.Debugf("%s system_prompt_len=%d", traceLogPrefix, len(systemPrompt))
|
||||
}
|
||||
|
||||
// 检查 LLM 客户端是否支持原生 tool_calls
|
||||
@@ -314,7 +530,7 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s llm client does not support ToolCallChatClient, falling back to legacy ReAct", traceLogPrefix)
|
||||
}
|
||||
return o.runLegacyReAct(ctx, chatID, userID, compressedContext, userInput, fileCtx, appendFileIDText)
|
||||
return o.runLegacyReAct(ctx, chatID, userID, compressedContext, userInput)
|
||||
}
|
||||
|
||||
// 构建初始 messages 数组
|
||||
@@ -322,13 +538,20 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
messages = append(messages, llm.PromptMessage{Role: "system", Content: systemPrompt})
|
||||
|
||||
// 加入历史会话上下文
|
||||
//messages = append(messages, parseCompressedHistoryMessages(compressedContext)...)
|
||||
messages = append(messages, parseCompressedHistoryMessages(compressedContext)...)
|
||||
|
||||
// 加入当前用户消息
|
||||
messages = append(messages, llm.PromptMessage{Role: "user", Content: userInput})
|
||||
|
||||
// 构建工具定义列表(通过 API tools 字段传递)
|
||||
toolDefs := o.buildToolDefinitions()
|
||||
if o.log != nil {
|
||||
toolNames := make([]string, 0, len(toolDefs))
|
||||
for _, td := range toolDefs {
|
||||
toolNames = append(toolNames, td.Function.Name)
|
||||
}
|
||||
o.log.Debugf("%s tool_defs_count=%d names=%v", traceLogPrefix, len(toolDefs), toolNames)
|
||||
}
|
||||
|
||||
const maxSteps = 20
|
||||
for step := 1; step <= maxSteps; step++ {
|
||||
@@ -337,7 +560,7 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
}
|
||||
|
||||
// 调用 LLM(传入完整 messages + tools 定义)
|
||||
completion, err := toolCallClient.GenerateWithTools(ctx, messages, toolDefs, fileCtx.FileIDs, appendFileIDText)
|
||||
completion, err := toolCallClient.GenerateWithTools(ctx, messages, toolDefs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -391,7 +614,8 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d tool_call tool=%s input=%q", traceLogPrefix, step, toolName, toolInput)
|
||||
o.log.Infof("%s react step=%d tool_call tool=%s input_len=%d", traceLogPrefix, step, toolName, len(toolInput))
|
||||
o.log.Debugf("%s react step=%d tool=%s input=%q", traceLogPrefix, step, toolName, toolInput)
|
||||
}
|
||||
|
||||
toolOut, toolErr := tool.Call(ctx, toolInput)
|
||||
@@ -410,6 +634,7 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d tool=%s observation_len=%d", traceLogPrefix, step, toolName, len(obs))
|
||||
o.log.Debugf("%s react step=%d tool=%s observation=%q", traceLogPrefix, step, toolName, truncateForLog(obs, 500))
|
||||
}
|
||||
|
||||
messages = append(messages, llm.PromptMessage{
|
||||
@@ -426,8 +651,217 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
return "我尝试了多轮推理与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil
|
||||
}
|
||||
|
||||
// runUnifiedReActStream 执行统一的 ReAct 循环并通过回调推送流式事件。
|
||||
func (o *Orchestrator) runUnifiedReActStream(ctx context.Context, chatID, userID, compressedContext, userInput string, callback StreamEventCallback) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
|
||||
// ===== LLM 意图路由:使用轻量模型判断是否需要加载技能 =====
|
||||
var routedSkills []knowledge.Skill
|
||||
if o.routerLLM != nil {
|
||||
routed, routeErr := o.routeSkillsWithLLM(ctx, userInput)
|
||||
if routeErr != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s skill router failed, fallback to keyword matching err=%v", traceLogPrefix, routeErr)
|
||||
}
|
||||
} else {
|
||||
routedSkills = routed
|
||||
if o.log != nil {
|
||||
names := make([]string, 0, len(routedSkills))
|
||||
for _, sk := range routedSkills {
|
||||
names = append(names, sk.Name)
|
||||
}
|
||||
o.log.Infof("%s skill router selected %d skills: %v", traceLogPrefix, len(routedSkills), names)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
systemPrompt := o.buildUnifiedSystemPrompt(userInput, routedSkills)
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s unified react stream start", traceLogPrefix)
|
||||
o.log.Debugf("%s system_prompt_len=%d", traceLogPrefix, len(systemPrompt))
|
||||
}
|
||||
|
||||
// 检查 LLM 客户端是否支持原生 tool_calls
|
||||
toolCallClient, supportsToolCalls := o.llm.(llm.ToolCallChatClient)
|
||||
if !supportsToolCalls {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s llm client does not support ToolCallChatClient, stream mode not available", traceLogPrefix)
|
||||
}
|
||||
return "", fmt.Errorf("stream mode requires ToolCallChatClient support")
|
||||
}
|
||||
|
||||
// 构建初始 messages 数组
|
||||
messages := make([]llm.PromptMessage, 0, 32)
|
||||
messages = append(messages, llm.PromptMessage{Role: "system", Content: systemPrompt})
|
||||
messages = append(messages, parseCompressedHistoryMessages(compressedContext)...)
|
||||
messages = append(messages, llm.PromptMessage{Role: "user", Content: userInput})
|
||||
|
||||
// 构建工具定义列表
|
||||
toolDefs := o.buildToolDefinitions()
|
||||
if o.log != nil {
|
||||
toolNames := make([]string, 0, len(toolDefs))
|
||||
for _, td := range toolDefs {
|
||||
toolNames = append(toolNames, td.Function.Name)
|
||||
}
|
||||
o.log.Debugf("%s tool_defs_count=%d names=%v", traceLogPrefix, len(toolDefs), toolNames)
|
||||
}
|
||||
|
||||
const maxSteps = 20
|
||||
for step := 1; step <= maxSteps; step++ {
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react stream step=%d start messages_count=%d", traceLogPrefix, step, len(messages))
|
||||
}
|
||||
|
||||
// 调用 LLM
|
||||
completion, err := toolCallClient.GenerateWithTools(ctx, messages, toolDefs)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react stream step=%d content_len=%d tool_calls=%d",
|
||||
traceLogPrefix, step, len(completion.Content), len(completion.ToolCalls))
|
||||
if completion.Content != "" {
|
||||
o.log.Debugf("%s react stream step=%d thought=%q", traceLogPrefix, step, completion.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// 推送思考过程事件
|
||||
if completion.Content != "" {
|
||||
if err := callback(StreamEvent{
|
||||
Type: StreamEventTypeThought,
|
||||
Content: completion.Content,
|
||||
Step: step,
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("callback error: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 无 tool_calls → 最终回答 ==========
|
||||
if len(completion.ToolCalls) == 0 {
|
||||
finalText := strings.TrimSpace(completion.Content)
|
||||
if finalText == "" {
|
||||
finalText = "已完成处理。"
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react stream final at step=%d answer_len=%d", traceLogPrefix, step, len(finalText))
|
||||
}
|
||||
// 推送最终答案事件
|
||||
if err := callback(StreamEvent{
|
||||
Type: StreamEventTypeFinal,
|
||||
Content: finalText,
|
||||
Step: step,
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("callback error: %w", err)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
|
||||
// ========== 有 tool_calls → 执行工具 ==========
|
||||
assistantMsg := llm.PromptMessage{
|
||||
Role: "assistant",
|
||||
Content: completion.Content,
|
||||
ToolCalls: completion.ToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 逐个执行工具调用
|
||||
for _, tc := range completion.ToolCalls {
|
||||
toolName := strings.ToLower(strings.TrimSpace(tc.Function.Name))
|
||||
toolInput := extractToolInput(tc.Function.Arguments)
|
||||
|
||||
// 推送工具调用事件
|
||||
if err := callback(StreamEvent{
|
||||
Type: StreamEventTypeToolCall,
|
||||
Content: toolInput,
|
||||
Step: step,
|
||||
ToolName: toolName,
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("callback error: %w", err)
|
||||
}
|
||||
|
||||
tool, ok := o.tools.Get(toolName)
|
||||
if !ok {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s react stream step=%d tool_not_found=%s", traceLogPrefix, step, toolName)
|
||||
}
|
||||
// 推送错误事件
|
||||
errMsg := "工具不存在:" + toolName
|
||||
if err := callback(StreamEvent{
|
||||
Type: StreamEventTypeError,
|
||||
Content: errMsg,
|
||||
Step: step,
|
||||
ToolName: toolName,
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("callback error: %w", err)
|
||||
}
|
||||
messages = append(messages, llm.PromptMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Content: formatToolErrorObservation("TOOL_NOT_FOUND", toolName, "该工具不存在,请检查工具名称后重试"),
|
||||
})
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_not_found:"+toolName)
|
||||
continue
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react stream step=%d tool_call tool=%s input_len=%d", traceLogPrefix, step, toolName, len(toolInput))
|
||||
o.log.Debugf("%s react stream step=%d tool=%s input=%q", traceLogPrefix, step, toolName, toolInput)
|
||||
}
|
||||
|
||||
toolOut, toolErr := tool.Call(ctx, toolInput)
|
||||
obs := strings.TrimSpace(toolOut)
|
||||
if obs == "" {
|
||||
obs = "(empty output)"
|
||||
}
|
||||
if toolErr != nil {
|
||||
obs = formatToolErrorObservation("TOOL_EXEC_ERROR", toolName, toolErr.Error()) + "\nOUTPUT:\n" + obs
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_call_failed:"+toolName)
|
||||
}
|
||||
// 限制观察值长度防止超出 LLM 上下文窗口
|
||||
if len(obs) > 4000 {
|
||||
obs = obs[:4000] + "\n...(truncated)"
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react stream step=%d tool=%s observation_len=%d", traceLogPrefix, step, toolName, len(obs))
|
||||
o.log.Debugf("%s react stream step=%d tool=%s observation=%q", traceLogPrefix, step, toolName, truncateForLog(obs, 500))
|
||||
}
|
||||
|
||||
// 推送工具结果事件
|
||||
if err := callback(StreamEvent{
|
||||
Type: StreamEventTypeToolResult,
|
||||
Content: obs,
|
||||
Step: step,
|
||||
ToolName: toolName,
|
||||
}); err != nil {
|
||||
return "", fmt.Errorf("callback error: %w", err)
|
||||
}
|
||||
|
||||
messages = append(messages, llm.PromptMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Content: obs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 达到安全上限
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "react_step_exhausted")
|
||||
errMsg := "我尝试了多轮推理与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。"
|
||||
_ = callback(StreamEvent{
|
||||
Type: StreamEventTypeError,
|
||||
Content: errMsg,
|
||||
})
|
||||
return errMsg, nil
|
||||
}
|
||||
|
||||
// runLegacyReAct 是旧版基于 JSON 决策解析的 ReAct 循环,作为不支持 tool_calls 的 LLM 的降级方案。
|
||||
func (o *Orchestrator) runLegacyReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, fileCtx filePromptContext, appendFileIDText bool) (string, error) {
|
||||
func (o *Orchestrator) runLegacyReAct(ctx context.Context, chatID, userID, compressedContext, userInput string) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
|
||||
@@ -441,8 +875,8 @@ func (o *Orchestrator) runLegacyReAct(ctx context.Context, chatID, userID, compr
|
||||
o.log.Infof("%s legacy react step=%d start", traceLogPrefix, step)
|
||||
}
|
||||
|
||||
messages := buildReActMessages(systemPrompt, compressedContext, userInput, fileCtx.Summary, scratchpad)
|
||||
raw, err := o.generateWithOptionalFilesMessages(ctx, messages, fileCtx.FileIDs, appendFileIDText)
|
||||
messages := buildReActMessages(systemPrompt, compressedContext, userInput, scratchpad)
|
||||
raw, err := o.generateMessages(ctx, messages)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -584,105 +1018,19 @@ func extractToolInput(arguments string) string {
|
||||
return arguments
|
||||
}
|
||||
|
||||
func (o *Orchestrator) prepareFilePromptContext(ctx context.Context, files []llm.InputFile, pending []pendingFileRef) filePromptContext {
|
||||
ctxOut := filePromptContext{}
|
||||
if len(pending) > 0 {
|
||||
for _, p := range pending {
|
||||
id := strings.TrimSpace(p.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
ctxOut.FileIDs = append(ctxOut.FileIDs, id)
|
||||
}
|
||||
}
|
||||
if len(files) == 0 {
|
||||
ctxOut.Summary = buildFileSummary(pending, nil)
|
||||
return ctxOut
|
||||
}
|
||||
uploader, ok := o.llm.(llm.FileUploader)
|
||||
if !ok {
|
||||
return filePromptContext{FatalReason: "检测到文件输入,但当前 LLM 客户端不支持文件上传接口。"}
|
||||
}
|
||||
|
||||
uploaded := make([]pendingFileRef, 0, len(files))
|
||||
for i, f := range files {
|
||||
if strings.TrimSpace(f.FileName) == "" || len(f.Content) == 0 {
|
||||
return filePromptContext{FatalReason: fmt.Sprintf("file[%d] 缺少文件名或内容,无法上传。", i+1)}
|
||||
}
|
||||
fileID, err := uploader.UploadFile(ctx, f, "file-extract")
|
||||
if err != nil {
|
||||
return filePromptContext{FatalReason: fmt.Sprintf("file[%d] name=%s 上传失败: %v", i+1, f.FileName, err)}
|
||||
}
|
||||
ctxOut.FileIDs = append(ctxOut.FileIDs, fileID)
|
||||
uploaded = append(uploaded, pendingFileRef{
|
||||
ID: fileID,
|
||||
Name: strings.TrimSpace(f.FileName),
|
||||
MimeType: defaultIfEmpty(strings.TrimSpace(f.MimeType), "application/octet-stream"),
|
||||
})
|
||||
}
|
||||
ctxOut.Uploaded = uploaded
|
||||
ctxOut.Summary = buildFileSummary(pending, uploaded)
|
||||
return ctxOut
|
||||
}
|
||||
|
||||
func buildFileSummary(pending, uploaded []pendingFileRef) string {
|
||||
if len(pending) == 0 && len(uploaded) == 0 {
|
||||
return ""
|
||||
}
|
||||
lines := make([]string, 0, len(pending)+len(uploaded)+2)
|
||||
lines = append(lines, "以下文件 file_id 可用于本轮问答:")
|
||||
idx := 1
|
||||
for _, p := range pending {
|
||||
id := strings.TrimSpace(p.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- cached_file[%d] name=%s mime=%s file_id=%s", idx, defaultIfEmpty(strings.TrimSpace(p.Name), "(unknown)"), defaultIfEmpty(strings.TrimSpace(p.MimeType), "application/octet-stream"), id))
|
||||
idx++
|
||||
}
|
||||
for _, p := range uploaded {
|
||||
id := strings.TrimSpace(p.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("- uploaded_file[%d] name=%s mime=%s file_id=%s", idx, defaultIfEmpty(strings.TrimSpace(p.Name), "(unknown)"), defaultIfEmpty(strings.TrimSpace(p.MimeType), "application/octet-stream"), id))
|
||||
idx++
|
||||
}
|
||||
if len(lines) == 1 {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) generateWithOptionalFilesMessages(ctx context.Context, messages []llm.PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) == 0 {
|
||||
if client, ok := o.llm.(llm.MessageChatClient); ok {
|
||||
return client.GenerateMessages(ctx, messages)
|
||||
}
|
||||
systemPrompt, userPrompt := fallbackPromptsFromMessages(messages)
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
}
|
||||
if client, ok := o.llm.(llm.FileMessageChatClient); ok {
|
||||
return client.GenerateMessagesWithFiles(ctx, messages, ids, appendFileIDText)
|
||||
}
|
||||
client, ok := o.llm.(llm.FileChatClient)
|
||||
if !ok {
|
||||
systemPrompt, userPrompt := fallbackPromptsFromMessages(messages)
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
func (o *Orchestrator) generateMessages(ctx context.Context, messages []llm.PromptMessage) (string, error) {
|
||||
if client, ok := o.llm.(llm.MessageChatClient); ok {
|
||||
return client.GenerateMessages(ctx, messages)
|
||||
}
|
||||
systemPrompt, userPrompt := fallbackPromptsFromMessages(messages)
|
||||
return client.GenerateWithFiles(ctx, systemPrompt, userPrompt, ids, appendFileIDText)
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
}
|
||||
|
||||
func buildReActMessages(systemPrompt, compressedContext, userInput, fileSummary, scratchpad string) []llm.PromptMessage {
|
||||
func buildReActMessages(systemPrompt, compressedContext, userInput, scratchpad string) []llm.PromptMessage {
|
||||
msgs := make([]llm.PromptMessage, 0, 16)
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "system", Content: systemPrompt})
|
||||
msgs = append(msgs, parseCompressedHistoryMessages(compressedContext)...)
|
||||
|
||||
if strings.TrimSpace(fileSummary) != "" {
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "assistant", Content: "文件上下文摘要:\n" + strings.TrimSpace(fileSummary)})
|
||||
}
|
||||
if strings.TrimSpace(scratchpad) != "" {
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "assistant", Content: "推理记录:\n" + strings.TrimSpace(scratchpad)})
|
||||
}
|
||||
@@ -735,20 +1083,6 @@ func fallbackPromptsFromMessages(messages []llm.PromptMessage) (string, string)
|
||||
return strings.Join(sysParts, "\n\n"), strings.Join(userParts, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildFileUploadAck(ctx filePromptContext) string {
|
||||
if len(ctx.FileIDs) == 0 {
|
||||
return "文件已接收,但未拿到有效 file_id。请重新上传一次。"
|
||||
}
|
||||
lines := []string{
|
||||
fmt.Sprintf("文件上传完成,已缓存 %d 个 file_id。", len(ctx.FileIDs)),
|
||||
"请继续发送你的问题,我会结合这些文件内容和历史对话一起回答。",
|
||||
}
|
||||
if strings.TrimSpace(ctx.Summary) != "" {
|
||||
lines = append(lines, "", ctx.Summary)
|
||||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func nonEmptyIDs(ids []string) []string {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
@@ -769,20 +1103,6 @@ func nonEmptyIDs(ids []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func (c filePromptContext) toPendingRefs() []pendingFileRef {
|
||||
if len(c.Uploaded) > 0 {
|
||||
copied := make([]pendingFileRef, len(c.Uploaded))
|
||||
copy(copied, c.Uploaded)
|
||||
return sanitizePendingRefs(copied)
|
||||
}
|
||||
ids := nonEmptyIDs(c.FileIDs)
|
||||
out := make([]pendingFileRef, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
out = append(out, pendingFileRef{ID: id})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *Orchestrator) appendPendingFiles(chatID, userID string, refs []pendingFileRef) {
|
||||
refs = sanitizePendingRefs(refs)
|
||||
if len(refs) == 0 {
|
||||
@@ -890,6 +1210,9 @@ func (o *Orchestrator) selectRelevantSkills(userInput string, maxCount int) []kn
|
||||
continue
|
||||
}
|
||||
ranked = append(ranked, item{skill: sk, score: score})
|
||||
if o.log != nil {
|
||||
o.log.Debugf("selectRelevantSkills skill=%q score=%d", sk.Name, score)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ranked) == 0 {
|
||||
@@ -910,6 +1233,13 @@ func (o *Orchestrator) selectRelevantSkills(userInput string, maxCount int) []kn
|
||||
for _, r := range ranked {
|
||||
out = append(out, r.skill)
|
||||
}
|
||||
if o.log != nil {
|
||||
selectedNames := make([]string, 0, len(out))
|
||||
for _, sk := range out {
|
||||
selectedNames = append(selectedNames, sk.Name)
|
||||
}
|
||||
o.log.Debugf("selectRelevantSkills query=%q matched=%d selected=%v", query, len(ranked), selectedNames)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -1181,3 +1511,10 @@ func (o *Orchestrator) formatToolDoc() string {
|
||||
}
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
// truncateForLog shortens s to at most maxLen bytes for logging and appends
// "...(truncated)" when anything was cut off.
//
// The cut never lands inside a multi-byte UTF-8 sequence: if byte maxLen is a
// continuation byte, the cut moves back to the start of that character, so the
// logged text stays valid UTF-8. A negative maxLen is treated as 0 instead of
// panicking on the slice expression.
func truncateForLog(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	cut := maxLen
	// UTF-8 continuation bytes look like 10xxxxxx. A character has at most
	// three of them, so bounding the back-off keeps malformed input cheap.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "...(truncated)"
}
|
||||
|
||||
@@ -46,3 +46,50 @@ func TestFormatRuntimeContextForPromptIncludesGOOS(t *testing.T) {
|
||||
t.Fatalf("expected runtime context contains GOOS=%s, got: %s", runtime.GOOS, doc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchSkillsByNameExact(t *testing.T) {
|
||||
all := []knowledge.Skill{
|
||||
{Name: "SAFe PI Planning", Content: "PI规划技能"},
|
||||
{Name: "文件系统查询专家", Content: "文件查询"},
|
||||
{Name: "代码生成", Content: "代码生成技能"},
|
||||
}
|
||||
matched := matchSkillsByName(all, []string{"SAFe PI Planning"})
|
||||
if len(matched) != 1 {
|
||||
t.Fatalf("expected 1 match, got %d", len(matched))
|
||||
}
|
||||
if matched[0].Name != "SAFe PI Planning" {
|
||||
t.Fatalf("expected SAFe PI Planning, got %s", matched[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchSkillsByNameFuzzy(t *testing.T) {
|
||||
all := []knowledge.Skill{
|
||||
{Name: "SAFe PI Planning", Content: "PI规划技能"},
|
||||
{Name: "文件系统查询专家", Content: "文件查询"},
|
||||
}
|
||||
matched := matchSkillsByName(all, []string{"pi planning", "文件"})
|
||||
if len(matched) != 2 {
|
||||
t.Fatalf("expected 2 matches, got %d", len(matched))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchSkillsByNameNoMatch(t *testing.T) {
|
||||
all := []knowledge.Skill{
|
||||
{Name: "文件系统查询专家", Content: "文件查询"},
|
||||
}
|
||||
matched := matchSkillsByName(all, []string{"不存在的技能"})
|
||||
if len(matched) != 0 {
|
||||
t.Fatalf("expected 0 matches, got %d", len(matched))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchSkillsByNameEmpty(t *testing.T) {
|
||||
matched := matchSkillsByName(nil, []string{"any"})
|
||||
if len(matched) != 0 {
|
||||
t.Fatalf("expected 0 matches, got %d", len(matched))
|
||||
}
|
||||
matched = matchSkillsByName([]knowledge.Skill{{Name: "test"}}, nil)
|
||||
if len(matched) != 0 {
|
||||
t.Fatalf("expected 0 matches, got %d", len(matched))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,6 +29,7 @@ type Config struct {
|
||||
LLM LLMConfig
|
||||
Security SecurityConfig
|
||||
WebSearch WebSearchConfig
|
||||
Gitea GiteaConfig
|
||||
|
||||
SQLitePath string
|
||||
}
|
||||
@@ -52,11 +53,11 @@ type WebUIConfig struct {
|
||||
}
|
||||
|
||||
type LLMConfig struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
FileModel string
|
||||
FilePromptMode string
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
FileModel string
|
||||
RouterModel string // 轻量路由模型,用于技能意图路由;为空则仅用关键词匹配
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
@@ -70,6 +71,13 @@ type WebSearchConfig struct {
|
||||
APIKey string
|
||||
}
|
||||
|
||||
// GiteaConfig holds the settings that identify a Gitea instance and the
// target repository.
type GiteaConfig struct {
|
||||
BaseURL string // Gitea instance URL
|
||||
Token string // Personal Access Token
|
||||
Owner string // repository owner
|
||||
Repo string // repository name
|
||||
}
|
||||
|
||||
func Load() (Config, error) {
|
||||
agentWorkspaceDir := resolveAgentWorkspaceDir()
|
||||
if err := preloadEnvFiles(); err != nil {
|
||||
@@ -106,11 +114,11 @@ func Load() (Config, error) {
|
||||
MaxUploadBytes: int64(intFromEnv("WEBUI_MAX_UPLOAD_MB", 20)) * 1024 * 1024,
|
||||
},
|
||||
LLM: LLMConfig{
|
||||
BaseURL: strings.TrimRight(defaultIfEmpty(os.Getenv("LLM_BASE_URL"), "https://api.openai.com/v1"), "/"),
|
||||
APIKey: strings.TrimSpace(os.Getenv("LLM_API_KEY")),
|
||||
Model: defaultIfEmpty(os.Getenv("LLM_MODEL"), "gpt-4o-mini"),
|
||||
FileModel: defaultIfEmpty(os.Getenv("LLM_FILE_MODEL"), defaultIfEmpty(os.Getenv("LLM_MODEL"), "gpt-4o-mini")),
|
||||
FilePromptMode: normalizeFilePromptMode(defaultIfEmpty(os.Getenv("LLM_FILE_PROMPT_MODE"), "user_content_file_parts")),
|
||||
BaseURL: strings.TrimRight(defaultIfEmpty(os.Getenv("LLM_BASE_URL"), "https://api.openai.com/v1"), "/"),
|
||||
APIKey: strings.TrimSpace(os.Getenv("LLM_API_KEY")),
|
||||
Model: defaultIfEmpty(os.Getenv("LLM_MODEL"), "gpt-4o-mini"),
|
||||
FileModel: defaultIfEmpty(os.Getenv("LLM_FILE_MODEL"), defaultIfEmpty(os.Getenv("LLM_MODEL"), "gpt-4o-mini")),
|
||||
RouterModel: strings.TrimSpace(os.Getenv("LLM_ROUTER_MODEL")),
|
||||
},
|
||||
SQLitePath: defaultIfEmpty(os.Getenv("SQLITE_PATH"), filepath.Join(defaultDataDir, "laodingbot.db")),
|
||||
WebSearch: WebSearchConfig{
|
||||
@@ -122,6 +130,12 @@ func Load() (Config, error) {
|
||||
AllowedCommands: splitCSV(defaultIfEmpty(os.Getenv("ALLOWED_COMMANDS"), "pwd,ls,cat,echo,grep,find,head,tail,go")),
|
||||
WorkDir: defaultIfEmpty(os.Getenv("WORK_DIR"), defaultWorkSubdir),
|
||||
},
|
||||
Gitea: GiteaConfig{
|
||||
BaseURL: strings.TrimRight(strings.TrimSpace(os.Getenv("GITEA_BASE_URL")), "/"),
|
||||
Token: strings.TrimSpace(os.Getenv("GITEA_TOKEN")),
|
||||
Owner: strings.TrimSpace(os.Getenv("GITEA_OWNER")),
|
||||
Repo: strings.TrimSpace(os.Getenv("GITEA_REPO")),
|
||||
},
|
||||
}
|
||||
|
||||
cfg.MessageChannel = strings.ToLower(strings.TrimSpace(cfg.MessageChannel))
|
||||
@@ -178,9 +192,6 @@ func Load() (Config, error) {
|
||||
if cfg.LLM.APIKey == "" {
|
||||
return Config{}, fmt.Errorf("LLM_API_KEY is required")
|
||||
}
|
||||
if cfg.LLM.FilePromptMode != "user_content_file_parts" && cfg.LLM.FilePromptMode != "system_fileid_uri" {
|
||||
return Config{}, fmt.Errorf("LLM_FILE_PROMPT_MODE must be one of: user_content_file_parts, system_fileid_uri")
|
||||
}
|
||||
|
||||
cfg.SoulPath = resolvePathInWorkspace(cfg.SoulPath, agentWorkspaceDir)
|
||||
cfg.SkillsDir = resolvePathInWorkspace(cfg.SkillsDir, agentWorkspaceDir)
|
||||
@@ -417,14 +428,3 @@ func splitCSV(raw string) []string {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeFilePromptMode canonicalizes a file-prompt-mode setting.
//
// The value is trimmed and lower-cased. An empty value becomes
// "user_content_file_parts", and the aliases "system_fileid",
// "system_fileid_url" and "system_fileid_uri" all become "system_fileid_uri".
// Any other value is returned in its trimmed, lower-cased form.
func normalizeFilePromptMode(v string) string {
	switch mode := strings.ToLower(strings.TrimSpace(v)); mode {
	case "":
		return "user_content_file_parts"
	case "system_fileid", "system_fileid_url", "system_fileid_uri":
		return "system_fileid_uri"
	default:
		return mode
	}
}
|
||||
|
||||
@@ -33,21 +33,13 @@ type MessageChatClient interface {
|
||||
GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error)
|
||||
}
|
||||
|
||||
// FileChatClient is implemented by LLM clients that can generate a reply from
// a system prompt and a user prompt together with a list of file IDs.
type FileChatClient interface {
|
||||
GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error)
|
||||
}
|
||||
|
||||
// FileMessageChatClient is implemented by LLM clients that can generate a
// reply from a full message list together with a list of file IDs.
type FileMessageChatClient interface {
|
||||
GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error)
|
||||
}
|
||||
|
||||
// FileUploader is implemented by LLM clients that can upload a file for the
// given purpose. The returned string is used by callers as the file ID.
type FileUploader interface {
|
||||
UploadFile(ctx context.Context, file InputFile, purpose string) (string, error)
|
||||
}
|
||||
|
||||
// ToolCallChatClient 支持原生 function calling 的 LLM 客户端接口。
|
||||
type ToolCallChatClient interface {
|
||||
GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error)
|
||||
GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition) (*ChatCompletion, error)
|
||||
}
|
||||
|
||||
// ToolDefinition 描述一个可供 LLM 调用的工具函数定义。
|
||||
@@ -89,11 +81,9 @@ type InputFile struct {
|
||||
}
|
||||
|
||||
type OpenAICompatibleClient struct {
|
||||
client openai.Client
|
||||
model string
|
||||
fileModel string
|
||||
filePromptMode string
|
||||
log *logger.Logger
|
||||
client openai.Client
|
||||
model string
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewOpenAICompatibleClient(cfg config.LLMConfig, log *logger.Logger) *OpenAICompatibleClient {
|
||||
@@ -105,11 +95,9 @@ func NewOpenAICompatibleClient(cfg config.LLMConfig, log *logger.Logger) *OpenAI
|
||||
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
return &OpenAICompatibleClient{
|
||||
client: openai.NewClient(opts...),
|
||||
model: cfg.Model,
|
||||
fileModel: cfg.FileModel,
|
||||
filePromptMode: cfg.FilePromptMode,
|
||||
log: log,
|
||||
client: openai.NewClient(opts...),
|
||||
model: cfg.Model,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,38 +106,22 @@ func (c *OpenAICompatibleClient) Generate(ctx context.Context, systemPrompt, use
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
}
|
||||
return c.generateWithMessagesInternal(ctx, messages, nil, false)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
messages := []PromptMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
}
|
||||
return c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
|
||||
return c.generateWithMessagesInternal(ctx, messages)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error) {
|
||||
return c.generateWithMessagesInternal(ctx, messages, nil, false)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
return c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
|
||||
return c.generateWithMessagesInternal(ctx, messages)
|
||||
}
|
||||
|
||||
// GenerateWithTools 使用原生 function calling 发送请求,返回结构化的 ChatCompletion。
|
||||
func (c *OpenAICompatibleClient) GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error) {
|
||||
func (c *OpenAICompatibleClient) GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition) (*ChatCompletion, error) {
|
||||
model := c.model
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) > 0 && strings.TrimSpace(c.fileModel) != "" {
|
||||
model = c.fileModel
|
||||
}
|
||||
|
||||
sdkMessages := buildSDKMessages(messages, ids, c.normalizedFilePromptMode(), appendFileIDText)
|
||||
sdkMessages := buildSDKMessages(messages)
|
||||
sdkTools := toSDKTools(tools)
|
||||
|
||||
if c.log != nil {
|
||||
c.log.Debugf("llm tool-call request start model=%s messages=%d tools=%d files=%d", model, len(sdkMessages), len(sdkTools), len(ids))
|
||||
c.log.Debugf("llm tool-call request start model=%s messages=%d tools=%d", model, len(sdkMessages), len(sdkTools))
|
||||
}
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
@@ -188,12 +160,8 @@ func (c *OpenAICompatibleClient) GenerateWithTools(ctx context.Context, messages
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Context, messages []PromptMessage) (string, error) {
|
||||
model := c.model
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) > 0 && strings.TrimSpace(c.fileModel) != "" {
|
||||
model = c.fileModel
|
||||
}
|
||||
|
||||
baseMessages := normalizePromptMessages(messages)
|
||||
if len(baseMessages) == 0 {
|
||||
@@ -202,10 +170,10 @@ func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Contex
|
||||
|
||||
systemLen, userLen := promptMessageLengths(baseMessages)
|
||||
if c.log != nil {
|
||||
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d file_prompt_mode=%s", model, systemLen, userLen, len(ids), c.normalizedFilePromptMode())
|
||||
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d", model, systemLen, userLen)
|
||||
}
|
||||
|
||||
sdkMessages := buildSDKMessages(baseMessages, ids, c.normalizedFilePromptMode(), appendFileIDText)
|
||||
sdkMessages := buildSDKMessages(baseMessages)
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
@@ -234,10 +202,9 @@ func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Contex
|
||||
return content, nil
|
||||
}
|
||||
|
||||
// buildSDKMessages 将 PromptMessage 列表转换为 openai SDK 的消息格式,并注入 file_id(如需要)。
|
||||
func buildSDKMessages(base []PromptMessage, fileIDs []string, mode string, appendFileIDText bool) []openai.ChatCompletionMessageParamUnion {
|
||||
mode = strings.ToLower(strings.TrimSpace(mode))
|
||||
out := make([]openai.ChatCompletionMessageParamUnion, 0, len(base)+2)
|
||||
// buildSDKMessages 将 PromptMessage 列表转换为 openai SDK 的消息格式。
|
||||
func buildSDKMessages(base []PromptMessage) []openai.ChatCompletionMessageParamUnion {
|
||||
out := make([]openai.ChatCompletionMessageParamUnion, 0, len(base))
|
||||
|
||||
for _, m := range base {
|
||||
role := normalizeRole(m.Role)
|
||||
@@ -247,34 +214,6 @@ func buildSDKMessages(base []PromptMessage, fileIDs []string, mode string, appen
|
||||
out = append(out, toSDKMessage(m, role))
|
||||
}
|
||||
|
||||
if len(fileIDs) == 0 {
|
||||
return out
|
||||
}
|
||||
|
||||
if appendFileIDText {
|
||||
// WebUI 场景:将首个 fileID 作为 text part 追加到最后一个 user 消息。
|
||||
firstFileID := strings.TrimSpace(fileIDs[0])
|
||||
if firstFileID == "" {
|
||||
return out
|
||||
}
|
||||
for i := len(out) - 1; i >= 0; i-- {
|
||||
if r := out[i].GetRole(); r != nil && *r == "user" {
|
||||
out[i] = buildUserMessageWithFileIDText(out[i], firstFileID)
|
||||
return out
|
||||
}
|
||||
}
|
||||
out = append(out, buildUserMessageWithFileIDText(openai.UserMessage(""), firstFileID))
|
||||
return out
|
||||
}
|
||||
|
||||
// 非 WebUI 场景:保持原有 file content part 方式。
|
||||
for i := len(out) - 1; i >= 0; i-- {
|
||||
if r := out[i].GetRole(); r != nil && *r == "user" {
|
||||
out[i] = buildUserMessageWithFiles(out[i], fileIDs)
|
||||
return out
|
||||
}
|
||||
}
|
||||
out = append(out, buildUserMessageWithFiles(openai.UserMessage(""), fileIDs))
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -309,53 +248,6 @@ func toSDKMessage(m PromptMessage, role string) openai.ChatCompletionMessagePara
|
||||
}
|
||||
}
|
||||
|
||||
// buildUserMessageWithFileIDText 为 user 消息追加一个 text part,内容为 fileID。
|
||||
func buildUserMessageWithFileIDText(msg openai.ChatCompletionMessageParamUnion, fileID string) openai.ChatCompletionMessageParamUnion {
|
||||
// 提取已有的文本内容
|
||||
text := ""
|
||||
if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
|
||||
text = *s
|
||||
}
|
||||
fileID = strings.TrimSpace(fileID)
|
||||
if fileID == "" {
|
||||
return msg
|
||||
}
|
||||
|
||||
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, 2)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, openai.TextContentPart(text))
|
||||
}
|
||||
parts = append(parts, openai.TextContentPart(fileID))
|
||||
if len(parts) == 0 {
|
||||
return msg
|
||||
}
|
||||
return openai.UserMessage(parts)
|
||||
}
|
||||
|
||||
// buildUserMessageWithFiles 为 user 消息追加 file content parts。
|
||||
func buildUserMessageWithFiles(msg openai.ChatCompletionMessageParamUnion, fileIDs []string) openai.ChatCompletionMessageParamUnion {
|
||||
text := ""
|
||||
if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
|
||||
text = *s
|
||||
}
|
||||
|
||||
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(fileIDs)+1)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, openai.TextContentPart(text))
|
||||
}
|
||||
for _, id := range fileIDs {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, openai.FileContentPart(openai.ChatCompletionContentPartFileFileParam{FileID: param.NewOpt(id)}))
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return msg
|
||||
}
|
||||
return openai.UserMessage(parts)
|
||||
}
|
||||
|
||||
// toSDKTools 将内部 ToolDefinition 列表转换为 openai SDK 的 ChatCompletionToolParam 列表。
|
||||
func toSDKTools(tools []ToolDefinition) []openai.ChatCompletionToolParam {
|
||||
if len(tools) == 0 {
|
||||
@@ -397,6 +289,46 @@ func fromSDKToolCalls(sdkCalls []openai.ChatCompletionMessageToolCall) []ToolCal
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizePromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
out := make([]PromptMessage, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role := normalizeRole(m.Role)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, PromptMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
Name: m.Name,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeRole trims and lower-cases role and returns it when it is one of
// "system", "user", "assistant" or "tool". Any other value yields "".
func normalizeRole(role string) string {
	switch canonical := strings.ToLower(strings.TrimSpace(role)); canonical {
	case "system", "user", "assistant", "tool":
		return canonical
	default:
		return ""
	}
}
|
||||
|
||||
func promptMessageLengths(messages []PromptMessage) (int, int) {
|
||||
systemLen := 0
|
||||
userLen := 0
|
||||
for _, m := range messages {
|
||||
switch normalizeRole(m.Role) {
|
||||
case "system":
|
||||
systemLen += len(m.Content)
|
||||
case "user":
|
||||
userLen += len(m.Content)
|
||||
}
|
||||
}
|
||||
return systemLen, userLen
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile, purpose string) (string, error) {
|
||||
if strings.TrimSpace(file.FileName) == "" {
|
||||
return "", fmt.Errorf("empty file name")
|
||||
@@ -460,71 +392,3 @@ func appendIfMissing(items []string, value string) []string {
|
||||
}
|
||||
return append(items, value)
|
||||
}
|
||||
|
||||
// nonEmptyIDs trims every ID, drops blank ones and removes duplicates while
// keeping first-seen order. An empty input yields nil; a non-empty input
// yields a non-nil slice, which may itself be empty.
func nonEmptyIDs(ids []string) []string {
	if len(ids) == 0 {
		return nil
	}

	unique := make([]string, 0, len(ids))
	visited := map[string]bool{}
	for i := range ids {
		trimmed := strings.TrimSpace(ids[i])
		if trimmed == "" || visited[trimmed] {
			continue
		}
		visited[trimmed] = true
		unique = append(unique, trimmed)
	}
	return unique
}
|
||||
|
||||
func normalizePromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
out := make([]PromptMessage, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role := normalizeRole(m.Role)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, PromptMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
Name: m.Name,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeRole returns the trimmed, lower-cased role when it is "system",
// "user", "assistant" or "tool", and "" for anything else.
func normalizeRole(role string) string {
	cleaned := strings.TrimSpace(role)
	cleaned = strings.ToLower(cleaned)
	switch cleaned {
	case "system", "user", "assistant", "tool":
		return cleaned
	}
	return ""
}
|
||||
|
||||
func promptMessageLengths(messages []PromptMessage) (int, int) {
|
||||
systemLen := 0
|
||||
userLen := 0
|
||||
for _, m := range messages {
|
||||
switch normalizeRole(m.Role) {
|
||||
case "system":
|
||||
systemLen += len(m.Content)
|
||||
case "user":
|
||||
userLen += len(m.Content)
|
||||
}
|
||||
}
|
||||
return systemLen, userLen
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) normalizedFilePromptMode() string {
|
||||
mode := strings.ToLower(strings.TrimSpace(c.filePromptMode))
|
||||
if mode == "system_fileid" || mode == "system_fileid_url" || mode == "system_fileid_uri" {
|
||||
return "system_fileid_uri"
|
||||
}
|
||||
return "user_content_file_parts"
|
||||
}
|
||||
|
||||
@@ -8,8 +8,11 @@ import (
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/logger"
|
||||
"laodingbot/internal/tools"
|
||||
"laodingbot/tools/filedoc"
|
||||
"laodingbot/tools/fileoperation"
|
||||
"laodingbot/tools/git"
|
||||
"laodingbot/tools/giteaticket"
|
||||
"laodingbot/tools/piplan"
|
||||
"laodingbot/tools/shell"
|
||||
"laodingbot/tools/websearch"
|
||||
)
|
||||
@@ -20,6 +23,9 @@ func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error
|
||||
var gitLog *logger.Logger
|
||||
var shellLog *logger.Logger
|
||||
var searchLog *logger.Logger
|
||||
var fileDocLog *logger.Logger
|
||||
var piPlanLog *logger.Logger
|
||||
var giteaTicketLog *logger.Logger
|
||||
var serverLog *logger.Logger
|
||||
if log != nil {
|
||||
log.Infof("toolhost child starting")
|
||||
@@ -28,6 +34,9 @@ func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error
|
||||
gitLog = log.WithComponent("toolhost.git")
|
||||
shellLog = log.WithComponent("toolhost.shell")
|
||||
searchLog = log.WithComponent("toolhost.websearch")
|
||||
fileDocLog = log.WithComponent("toolhost.filedoc")
|
||||
piPlanLog = log.WithComponent("toolhost.piplan")
|
||||
giteaTicketLog = log.WithComponent("toolhost.giteaticket")
|
||||
serverLog = log.WithComponent("toolhost.server")
|
||||
}
|
||||
registry := tools.NewRegistry(registryLog)
|
||||
@@ -53,6 +62,27 @@ func RunChild(ctx context.Context, cfg config.Config, log *logger.Logger) error
|
||||
cfg.ToolOutputMaxChars,
|
||||
searchLog,
|
||||
))
|
||||
registry.Register(filedoc.New(
|
||||
filedoc.Config{
|
||||
APIKey: cfg.LLM.APIKey,
|
||||
BaseURL: cfg.LLM.BaseURL,
|
||||
Model: cfg.LLM.FileModel,
|
||||
Timeout: time.Duration(cfg.ToolCallTimeoutSec) * time.Second,
|
||||
},
|
||||
cfg.ToolOutputMaxChars,
|
||||
fileDocLog,
|
||||
))
|
||||
registry.Register(piplan.New(cfg.ToolOutputMaxChars, piPlanLog))
|
||||
registry.Register(giteaticket.New(
|
||||
giteaticket.Config{
|
||||
BaseURL: cfg.Gitea.BaseURL,
|
||||
Token: cfg.Gitea.Token,
|
||||
Owner: cfg.Gitea.Owner,
|
||||
Repo: cfg.Gitea.Repo,
|
||||
Timeout: time.Duration(cfg.ToolCallTimeoutSec) * time.Second,
|
||||
},
|
||||
giteaTicketLog,
|
||||
))
|
||||
|
||||
server := NewServer(registry, serverLog)
|
||||
if err := server.Serve(ctx, stdin(), stdout()); err != nil && ctx.Err() == nil {
|
||||
|
||||
@@ -18,13 +18,33 @@ import (
|
||||
)
|
||||
|
||||
type IncomingMessage struct {
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
FileIDs []string
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
}
|
||||
|
||||
// StreamEventType identifies the kind of event carried in the streaming output.
|
||||
type StreamEventType string
|
||||
|
||||
// Event kinds emitted on the stream.
const (
|
||||
StreamEventTypeThought StreamEventType = "thought" // LLM reasoning step
|
||||
StreamEventTypeToolCall StreamEventType = "tool_call" // tool invocation request
|
||||
StreamEventTypeToolResult StreamEventType = "tool_result" // tool execution result
|
||||
StreamEventTypeFinal StreamEventType = "final" // final answer
|
||||
StreamEventTypeError StreamEventType = "error" // error message
|
||||
)
|
||||
|
||||
// StreamEvent is a single event in the streaming output. It is serialized to
// JSON using the field tags below; Step and ToolName are omitted when they
// hold their zero value.
|
||||
type StreamEvent struct {
|
||||
Type StreamEventType `json:"type"`
|
||||
Content string `json:"content"`
|
||||
Step int `json:"step,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
}
|
||||
|
||||
// ChatHandler processes one incoming message and returns the reply text.
type ChatHandler func(context.Context, IncomingMessage) (string, error)
|
||||
// StreamChatHandler processes one incoming message like ChatHandler, and is
// additionally given a callback through which it can emit StreamEvents.
type StreamChatHandler func(context.Context, IncomingMessage, StreamEventCallback) (string, error)
|
||||
// StreamEventCallback receives a single StreamEvent; it returns an error when
// the event could not be handled.
type StreamEventCallback func(event StreamEvent) error
|
||||
// UploadHandler handles a batch of uploaded files and returns a list of
// strings. NOTE(review): the two string parameters are presumably the chat and
// user IDs and the result the file IDs — confirm against the implementation.
type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error)
|
||||
|
||||
type Bot struct {
|
||||
@@ -32,29 +52,25 @@ type Bot struct {
|
||||
maxUploadBytes int64
|
||||
log *logger.Logger
|
||||
|
||||
chatHandler ChatHandler
|
||||
uploadHandler UploadHandler
|
||||
counter uint64
|
||||
chatHandler ChatHandler
|
||||
streamChatHandler StreamChatHandler
|
||||
uploadHandler UploadHandler
|
||||
counter uint64
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
FileIDs []string `json:"file_ids"`
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
func (r *chatRequest) UnmarshalJSON(data []byte) error {
|
||||
type rawChatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionIDCamel string `json:"sessionId"`
|
||||
UserID string `json:"user_id"`
|
||||
UserIDCamel string `json:"userId"`
|
||||
FileIDs json.RawMessage `json:"file_ids"`
|
||||
FileIDsCamel json.RawMessage `json:"fileIds"`
|
||||
FileIDsFlat json.RawMessage `json:"fileids"`
|
||||
FileID json.RawMessage `json:"file_id"`
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionIDCamel string `json:"sessionId"`
|
||||
UserID string `json:"user_id"`
|
||||
UserIDCamel string `json:"userId"`
|
||||
}
|
||||
|
||||
var raw rawChatRequest
|
||||
@@ -65,13 +81,6 @@ func (r *chatRequest) UnmarshalJSON(data []byte) error {
|
||||
r.Text = raw.Text
|
||||
r.SessionID = firstNonEmpty(raw.SessionID, raw.SessionIDCamel)
|
||||
r.UserID = firstNonEmpty(raw.UserID, raw.UserIDCamel)
|
||||
|
||||
rawIDs := firstNonEmptyRaw(raw.FileIDs, raw.FileIDsCamel, raw.FileIDsFlat, raw.FileID)
|
||||
ids, err := decodeStringList(rawIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.FileIDs = ids
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -109,7 +118,7 @@ func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error {
|
||||
func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, streamChatHandler StreamChatHandler, uploadHandler UploadHandler) error {
|
||||
if chatHandler == nil {
|
||||
return fmt.Errorf("nil webui chat handler")
|
||||
}
|
||||
@@ -117,10 +126,12 @@ func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler Up
|
||||
return fmt.Errorf("nil webui upload handler")
|
||||
}
|
||||
b.chatHandler = chatHandler
|
||||
b.streamChatHandler = streamChatHandler
|
||||
b.uploadHandler = uploadHandler
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/chat", b.handleChat)
|
||||
mux.HandleFunc("/api/chat/stream", b.handleChatStream)
|
||||
mux.HandleFunc("/api/upload", b.handleUpload)
|
||||
|
||||
srv := &http.Server{
|
||||
@@ -191,10 +202,9 @@ func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
userID := b.resolveID(req.UserID, "user")
|
||||
|
||||
reply, err := b.chatHandler(r.Context(), IncomingMessage{
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
FileIDs: req.FileIDs,
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
@@ -210,37 +220,8 @@ func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func decodeStringList(raw json.RawMessage) ([]string, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var list []string
|
||||
if err := json.Unmarshal(raw, &list); err == nil {
|
||||
return nonEmptyIDs(list), nil
|
||||
}
|
||||
|
||||
var single string
|
||||
if err := json.Unmarshal(raw, &single); err == nil {
|
||||
if strings.TrimSpace(single) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return nonEmptyIDs(strings.Split(single, ",")), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid file ids format")
|
||||
}
|
||||
|
||||
func firstNonEmptyRaw(vals ...json.RawMessage) json.RawMessage {
|
||||
for _, v := range vals {
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func firstNonEmpty(vals ...string) string {
|
||||
|
||||
for _, v := range vals {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
@@ -249,24 +230,82 @@ func firstNonEmpty(vals ...string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func nonEmptyIDs(ids []string) []string {
|
||||
if len(ids) == 0 {
|
||||
func (b *Bot) handleChatStream(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
return
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"})
|
||||
return
|
||||
}
|
||||
if b.streamChatHandler == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "stream chat handler not ready"})
|
||||
return
|
||||
}
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"})
|
||||
return
|
||||
}
|
||||
req.Text = strings.TrimSpace(req.Text)
|
||||
if req.Text == "" {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"})
|
||||
return
|
||||
}
|
||||
sessionID := b.resolveID(req.SessionID, "sess")
|
||||
userID := b.resolveID(req.UserID, "user")
|
||||
|
||||
// 设置 SSE 响应头
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
http.Error(w, "streaming not supported", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建回调函数来推送 SSE 事件
|
||||
callback := func(event StreamEvent) error {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(data))
|
||||
flusher.Flush()
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(ids))
|
||||
seen := map[string]struct{}{}
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
|
||||
// 调用流式处理器
|
||||
reply, err := b.streamChatHandler(r.Context(), IncomingMessage{
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
}, callback)
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
b.log.Errorf("webui stream chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err)
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
// 推送错误事件
|
||||
errEvent := StreamEvent{
|
||||
Type: StreamEventTypeError,
|
||||
Content: "stream error: " + err.Error(),
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
data, _ := json.Marshal(errEvent)
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(data))
|
||||
flusher.Flush()
|
||||
return
|
||||
}
|
||||
|
||||
if b.log != nil {
|
||||
b.log.Infof("webui stream chat completed session_id=%s user_id=%s reply_len=%d", sessionID, userID, len(reply))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -51,66 +52,6 @@ func TestHandleChatSuccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatWithFileIDs(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if len(msg.FileIDs) != 2 || msg.FileIDs[0] != "file_a" || msg.FileIDs[1] != "file_b" {
|
||||
t.Fatalf("unexpected file ids: %+v", msg.FileIDs)
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1","file_ids":["file_a","file_b"]}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatWithFileIDsAliases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{name: "camel array", body: `{"text":"hello","sessionId":"s1","userId":"u1","fileIds":["file_a","file_b"]}`},
|
||||
{name: "flat array", body: `{"text":"hello","session_id":"s1","user_id":"u1","fileids":["file_a","file_b"]}`},
|
||||
{name: "single key", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_id":"file_a"}`},
|
||||
{name: "csv string", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_ids":"file_a, file_b"}`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if len(msg.FileIDs) == 0 {
|
||||
t.Fatalf("expected file ids from alias payload, got empty")
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(tt.body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatMissingText(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil }
|
||||
@@ -215,3 +156,149 @@ func TestHandleUploadMissingFile(t *testing.T) {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatStreamSuccess(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.streamChatHandler = func(_ context.Context, msg IncomingMessage, cb StreamEventCallback) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if err := cb(StreamEvent{Type: StreamEventTypeThought, Content: "thinking", Step: 1}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := cb(StreamEvent{Type: StreamEventTypeToolCall, Content: "{\"input\":\"pwd\"}", Step: 1, ToolName: "shell"}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := cb(StreamEvent{Type: StreamEventTypeToolResult, Content: "C:/Project", Step: 1, ToolName: "shell"}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := cb(StreamEvent{Type: StreamEventTypeFinal, Content: "done", Step: 2}); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "done", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat/stream", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChatStream(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
if got := w.Header().Get("Content-Type"); got != "text/event-stream" {
|
||||
t.Fatalf("expected text/event-stream, got %q", got)
|
||||
}
|
||||
|
||||
var events []StreamEvent
|
||||
chunks := strings.Split(strings.TrimSpace(w.Body.String()), "\n\n")
|
||||
for _, chunk := range chunks {
|
||||
line := strings.TrimSpace(chunk)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
t.Fatalf("invalid sse line: %q", line)
|
||||
}
|
||||
payload := strings.TrimPrefix(line, "data: ")
|
||||
var ev StreamEvent
|
||||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||||
t.Fatalf("unmarshal stream event failed: %v payload=%s", err, payload)
|
||||
}
|
||||
events = append(events, ev)
|
||||
}
|
||||
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("expected 4 events, got %d", len(events))
|
||||
}
|
||||
if events[0].Type != StreamEventTypeThought {
|
||||
t.Fatalf("event[0] type mismatch: %s", events[0].Type)
|
||||
}
|
||||
if events[1].Type != StreamEventTypeToolCall || events[1].ToolName != "shell" {
|
||||
t.Fatalf("event[1] mismatch: %+v", events[1])
|
||||
}
|
||||
if events[2].Type != StreamEventTypeToolResult || events[2].ToolName != "shell" {
|
||||
t.Fatalf("event[2] mismatch: %+v", events[2])
|
||||
}
|
||||
if events[3].Type != StreamEventTypeFinal || events[3].Content != "done" {
|
||||
t.Fatalf("event[3] mismatch: %+v", events[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatStreamHandlerError(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.streamChatHandler = func(_ context.Context, _ IncomingMessage, _ StreamEventCallback) (string, error) {
|
||||
return "", errors.New("boom")
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat/stream", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChatStream(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", w.Code)
|
||||
}
|
||||
respBody := w.Body.String()
|
||||
if !strings.Contains(respBody, `"type":"error"`) {
|
||||
t.Fatalf("expected error event in stream, body=%q", respBody)
|
||||
}
|
||||
if !strings.Contains(respBody, "stream error: boom") {
|
||||
t.Fatalf("expected error detail in stream, body=%q", respBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatStreamValidation(t *testing.T) {
|
||||
t.Run("method not allowed", func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/chat/stream", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChatStream(w, req)
|
||||
if w.Code != http.StatusMethodNotAllowed {
|
||||
t.Fatalf("expected 405, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("content type must be json", func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat/stream", strings.NewReader(`{"text":"hello"}`))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChatStream(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handler not ready", func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat/stream", strings.NewReader(`{"text":"hello"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChatStream(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("expected 500, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("text required", func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.streamChatHandler = func(_ context.Context, _ IncomingMessage, _ StreamEventCallback) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat/stream", strings.NewReader(`{"text":" "}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChatStream(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user