Migrate LLM client to OpenAI SDK and implement WebUI-specific fileID handling
This commit is contained in:
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
@@ -43,14 +44,6 @@ type pendingFileRef struct {
|
||||
MimeType string
|
||||
}
|
||||
|
||||
type capabilityRoutingResult struct {
|
||||
NeedSkills bool
|
||||
SelectedToolNames []string
|
||||
SelectedSkills []knowledge.Skill
|
||||
Reason string
|
||||
UsedFallback bool
|
||||
}
|
||||
|
||||
type filePromptContext struct {
|
||||
Summary string
|
||||
FatalReason string
|
||||
@@ -110,11 +103,25 @@ func NewOrchestrator(
|
||||
// - 是否需要调用工具(action + action_input)
|
||||
// 循环持续进行,直到 LLM 返回 is_final_answer=true。
|
||||
func (o *Orchestrator) HandleMessage(ctx context.Context, chatID, userID, text string) (string, error) {
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, nil)
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, nil, false)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) HandleMessageWithFiles(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, files)
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, files, false)
|
||||
}
|
||||
|
||||
// HandleMessageWithFileIDs 接收用户文本与外部 file_id 列表,复用统一 ReAct 链路。
|
||||
// 该方法会先把 file_id 注入当前会话上下文,然后调用常规 HandleMessage 流程。
|
||||
func (o *Orchestrator) HandleMessageWithFileIDs(ctx context.Context, chatID, userID, text string, fileIDs []string) (string, error) {
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) > 0 {
|
||||
refs := make([]pendingFileRef, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
refs = append(refs, pendingFileRef{ID: id})
|
||||
}
|
||||
o.appendPendingFiles(chatID, userID, refs)
|
||||
}
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, nil, true)
|
||||
}
|
||||
|
||||
// UploadAndCacheFiles 上传文件到 LLM 并缓存 file_id,供后续同会话文本问答复用。
|
||||
@@ -135,7 +142,7 @@ func (o *Orchestrator) UploadAndCacheFiles(ctx context.Context, chatID, userID s
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
|
||||
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile, appendFileIDText bool) (string, error) {
|
||||
// 为链路追踪设置唯一的 TraceID
|
||||
traceID := logger.NewTraceID()
|
||||
ctx = logger.WithTraceID(ctx, traceID)
|
||||
@@ -228,9 +235,7 @@ func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
routeInput := composeRouteInput(text, fileCtx.Summary)
|
||||
route := o.routeCapabilities(ctx, routeInput)
|
||||
response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text, fileCtx, routeInput, route)
|
||||
response, err := o.runUnifiedReAct(ctx, chatID, userID, compressed, text, fileCtx, appendFileIDText)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Errorf("%s message generation failed chat_id=%s err=%v", traceLogPrefix, chatID, err)
|
||||
@@ -256,128 +261,198 @@ func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID
|
||||
}
|
||||
|
||||
// buildUnifiedSystemPrompt 构建统一 ReAct 循环的 system prompt。
|
||||
// 工具始终可用;技能仅按当前问题挑选相关项作为增强上下文。
|
||||
func (o *Orchestrator) buildUnifiedSystemPrompt(userInput string, route capabilityRoutingResult) string {
|
||||
// 工具定义通过 API 的 tools 字段传递;此处只需包含人格、技能、运行环境和思考指引。
|
||||
func (o *Orchestrator) buildUnifiedSystemPrompt(userInput string) string {
|
||||
skillMetaDoc := o.formatSkillSummariesForPrompt()
|
||||
relevantSkillsDoc := o.formatSelectedSkillsForPrompt(userInput, route.SelectedSkills)
|
||||
toolDoc := o.formatToolDoc()
|
||||
relevantSkillsDoc := o.formatSelectedSkillsForPrompt(userInput, nil)
|
||||
runtimeDoc := formatRuntimeContextForPrompt()
|
||||
routeDoc := formatRouteForPrompt(route)
|
||||
|
||||
return strings.Join([]string{
|
||||
"你是一个个人自动化助手,必须遵循如下人格设定并保持一致:",
|
||||
o.soul,
|
||||
"",
|
||||
"===== ReAct 思考指引 =====",
|
||||
"你采用 ReAct(Reasoning + Acting)模式进行任务处理。",
|
||||
"1. 思考优先:在做出任何行动之前,先在回复中阐述你的推理过程(Thought)。",
|
||||
"2. 工具调用:如果需要获取信息或执行操作,使用提供的工具函数(function calling)进行调用。",
|
||||
"3. 观察反馈:检查工具返回的结果,据此决定下一步行动。",
|
||||
"4. 最终回答:当你有足够信息时,直接给出面向用户的最终文本回复,不要调用工具。",
|
||||
"",
|
||||
"注意事项:",
|
||||
"- 每次要么调用工具,要么给出最终回答,不要两者都做。",
|
||||
"- 如果工具调用失败,根据错误信息(Traceback)调整策略后重试或给出替代方案。",
|
||||
"- 涉及文件、目录、命令时,优先调用工具获取真实结果,不要猜测。",
|
||||
"- 你的思考过程(Thought)应写在回复内容中,帮助追踪推理逻辑。",
|
||||
"",
|
||||
"===== 运行环境 =====",
|
||||
runtimeDoc,
|
||||
"",
|
||||
"===== 可用技能概览 =====",
|
||||
skillMetaDoc,
|
||||
"",
|
||||
"===== 能力路由结果 =====",
|
||||
routeDoc,
|
||||
"",
|
||||
"===== 本轮相关技能(按用户问题筛选) =====",
|
||||
relevantSkillsDoc,
|
||||
"",
|
||||
"===== 可用工具 =====",
|
||||
toolDoc,
|
||||
"",
|
||||
"===== 输出格式约束 =====",
|
||||
"你必须使用 ReAct(Reasoning + Acting)模式进行决策。",
|
||||
"每次回复必须是且仅是一个 JSON 对象,字段如下:",
|
||||
"",
|
||||
"{",
|
||||
" \"thought\": \"你的推理过程(必填)\",",
|
||||
" \"action\": \"要调用的工具名称,如 file/shell/web_search(不调工具时填 none)\",",
|
||||
" \"action_input\": \"传给工具的输入(字符串或对象),不调工具时填空字符串或 null\",",
|
||||
" \"is_final_answer\": true 或 false,",
|
||||
" \"final_answer\": \"当 is_final_answer=true 时填写给用户的最终回复,否则填 null\"",
|
||||
"}",
|
||||
"",
|
||||
"决策规则:",
|
||||
"1) 如果你可以直接回答用户问题(不需要任何工具):",
|
||||
" 设 is_final_answer=true,action=\"none\",final_answer 填写完整回复。",
|
||||
"2) 优先判断是否可通过原子工具能力完成任务;若可完成,直接进行工具调用链路。",
|
||||
"3) 当纯工具调用无法满足时,再结合已加载的技能详细说明进行决策。",
|
||||
"4) 如果你需要调用工具获取信息后才能回答:",
|
||||
" 设 is_final_answer=false,action 填工具名,action_input 填工具所需输入,final_answer=null。",
|
||||
"5) 不要在 JSON 之外输出任何内容。",
|
||||
"6) 根据技能说明中的指引决定何时以及如何使用工具。",
|
||||
"7) 工具能力是全局可用的,不依赖技能命中;当技能不匹配时,仍可直接选择合适工具。",
|
||||
"8) 若技能中存在与当前运行环境不匹配的章节(如 Windows 专章),应降低优先级,除非用户明确要求该环境。",
|
||||
"9) 每轮工具调用结果会以 Observation 的形式追加到推理记录中,供你下一轮决策参考。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
// runUnifiedReAct 执行统一的 ReAct 循环。
|
||||
// LLM 每次都看到完整的技能集+工具集,自行决定是否调用工具或直接回答。
|
||||
// 循环持续到 is_final_answer=true 或达到安全上限。
|
||||
func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, fileCtx filePromptContext, routeInput string, route capabilityRoutingResult) (string, error) {
|
||||
// runUnifiedReAct 执行统一的 ReAct 循环,使用原生 function calling API。
|
||||
// messages 数组随交互动态增长:system → history → user → assistant(tool_calls) → tool → ...
|
||||
// 循环持续到 LLM 返回无 tool_calls 的纯文本回复(即最终回答)或达到安全上限。
|
||||
func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, fileCtx filePromptContext, appendFileIDText bool) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
|
||||
if strings.TrimSpace(routeInput) == "" {
|
||||
routeInput = composeRouteInput(userInput, fileCtx.Summary)
|
||||
}
|
||||
systemPrompt := o.buildUnifiedSystemPrompt(routeInput, route)
|
||||
systemPrompt := o.buildUnifiedSystemPrompt(userInput)
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s unified react start route_need_skills=%v route_tools=%v route_skills=%d fallback=%v", traceLogPrefix, route.NeedSkills, route.SelectedToolNames, len(route.SelectedSkills), route.UsedFallback)
|
||||
o.log.Infof("%s unified react start", traceLogPrefix)
|
||||
}
|
||||
|
||||
// 安全上限:防止无限循环(当前暂不使用 reactMaxStep 配置约束,使用固定硬上限)
|
||||
// 检查 LLM 客户端是否支持原生 tool_calls
|
||||
toolCallClient, supportsToolCalls := o.llm.(llm.ToolCallChatClient)
|
||||
if !supportsToolCalls {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s llm client does not support ToolCallChatClient, falling back to legacy ReAct", traceLogPrefix)
|
||||
}
|
||||
return o.runLegacyReAct(ctx, chatID, userID, compressedContext, userInput, fileCtx, appendFileIDText)
|
||||
}
|
||||
|
||||
// 构建初始 messages 数组
|
||||
messages := make([]llm.PromptMessage, 0, 32)
|
||||
messages = append(messages, llm.PromptMessage{Role: "system", Content: systemPrompt})
|
||||
|
||||
// 加入历史会话上下文
|
||||
//messages = append(messages, parseCompressedHistoryMessages(compressedContext)...)
|
||||
|
||||
// 加入当前用户消息
|
||||
messages = append(messages, llm.PromptMessage{Role: "user", Content: userInput})
|
||||
|
||||
// 构建工具定义列表(通过 API tools 字段传递)
|
||||
toolDefs := o.buildToolDefinitions()
|
||||
|
||||
const maxSteps = 20
|
||||
for step := 1; step <= maxSteps; step++ {
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d start messages_count=%d", traceLogPrefix, step, len(messages))
|
||||
}
|
||||
|
||||
// 调用 LLM(传入完整 messages + tools 定义)
|
||||
completion, err := toolCallClient.GenerateWithTools(ctx, messages, toolDefs, fileCtx.FileIDs, appendFileIDText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d content_len=%d tool_calls=%d",
|
||||
traceLogPrefix, step, len(completion.Content), len(completion.ToolCalls))
|
||||
if completion.Content != "" {
|
||||
o.log.Debugf("%s react step=%d thought=%q", traceLogPrefix, step, completion.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 无 tool_calls → 最终回答 ==========
|
||||
if len(completion.ToolCalls) == 0 {
|
||||
finalText := strings.TrimSpace(completion.Content)
|
||||
if finalText == "" {
|
||||
finalText = "已完成处理。"
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react final at step=%d answer_len=%d", traceLogPrefix, step, len(finalText))
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
|
||||
// ========== 有 tool_calls → 将 assistant 消息加入历史,然后执行工具 ==========
|
||||
assistantMsg := llm.PromptMessage{
|
||||
Role: "assistant",
|
||||
Content: completion.Content,
|
||||
ToolCalls: completion.ToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 逐个执行工具调用,并将结果作为 tool 角色消息加入
|
||||
for _, tc := range completion.ToolCalls {
|
||||
toolName := strings.ToLower(strings.TrimSpace(tc.Function.Name))
|
||||
toolInput := extractToolInput(tc.Function.Arguments)
|
||||
|
||||
tool, ok := o.tools.Get(toolName)
|
||||
if !ok {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s react step=%d tool_not_found=%s", traceLogPrefix, step, toolName)
|
||||
}
|
||||
messages = append(messages, llm.PromptMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Content: formatToolErrorObservation("TOOL_NOT_FOUND", toolName, "该工具不存在,请检查工具名称后重试"),
|
||||
})
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_not_found:"+toolName)
|
||||
continue
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d tool_call tool=%s input=%q", traceLogPrefix, step, toolName, toolInput)
|
||||
}
|
||||
|
||||
toolOut, toolErr := tool.Call(ctx, toolInput)
|
||||
obs := strings.TrimSpace(toolOut)
|
||||
if obs == "" {
|
||||
obs = "(empty output)"
|
||||
}
|
||||
if toolErr != nil {
|
||||
obs = formatToolErrorObservation("TOOL_EXEC_ERROR", toolName, toolErr.Error()) + "\nOUTPUT:\n" + obs
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_call_failed:"+toolName)
|
||||
}
|
||||
// 限制观察值长度防止超出 LLM 上下文窗口
|
||||
if len(obs) > 4000 {
|
||||
obs = obs[:4000] + "\n...(truncated)"
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d tool=%s observation_len=%d", traceLogPrefix, step, toolName, len(obs))
|
||||
}
|
||||
|
||||
messages = append(messages, llm.PromptMessage{
|
||||
Role: "tool",
|
||||
ToolCallID: tc.ID,
|
||||
Name: tc.Function.Name,
|
||||
Content: obs,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 达到安全上限仍未得到最终回答
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "react_step_exhausted")
|
||||
return "我尝试了多轮推理与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil
|
||||
}
|
||||
|
||||
// runLegacyReAct 是旧版基于 JSON 决策解析的 ReAct 循环,作为不支持 tool_calls 的 LLM 的降级方案。
|
||||
func (o *Orchestrator) runLegacyReAct(ctx context.Context, chatID, userID, compressedContext, userInput string, fileCtx filePromptContext, appendFileIDText bool) (string, error) {
|
||||
traceID := logger.TraceIDFromContext(ctx)
|
||||
traceLogPrefix := "trace_id=" + traceID
|
||||
|
||||
systemPrompt := o.buildLegacySystemPrompt(userInput)
|
||||
|
||||
const maxSteps = 20
|
||||
scratchpad := ""
|
||||
|
||||
for step := 1; step <= maxSteps; step++ {
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d start", traceLogPrefix, step)
|
||||
o.log.Debugf("%s react step=%d scratchpad=%q", traceLogPrefix, step, scratchpad)
|
||||
o.log.Infof("%s legacy react step=%d start", traceLogPrefix, step)
|
||||
}
|
||||
|
||||
// 构造本轮 user prompt:历史上下文 + 用户问题 + 推理记录
|
||||
prompt := strings.Join([]string{
|
||||
"历史上下文:",
|
||||
compressedContext,
|
||||
"",
|
||||
"用户问题:",
|
||||
userInput,
|
||||
"",
|
||||
"文件上下文:",
|
||||
defaultIfEmpty(fileCtx.Summary, "(none)"),
|
||||
"",
|
||||
"当前推理记录(按时间顺序):",
|
||||
scratchpad,
|
||||
"",
|
||||
"请输出你的 JSON 决策。",
|
||||
}, "\n")
|
||||
|
||||
raw, err := o.generateWithOptionalFiles(ctx, systemPrompt, prompt, fileCtx.FileIDs)
|
||||
messages := buildReActMessages(systemPrompt, compressedContext, userInput, fileCtx.Summary, scratchpad)
|
||||
raw, err := o.generateWithOptionalFilesMessages(ctx, messages, fileCtx.FileIDs, appendFileIDText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d llm_raw=%q", traceLogPrefix, step, raw)
|
||||
}
|
||||
|
||||
// 解析 LLM 返回的 JSON 决策
|
||||
decision, err := parseDecision(raw)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s react step=%d parse failed err=%v, using raw as final answer", traceLogPrefix, step, err)
|
||||
}
|
||||
// 解析失败时,尝试将原始输出当作直接回答返回
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "react_parse_failed")
|
||||
return strings.TrimSpace(raw), nil
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d thought=%q action=%q is_final=%v",
|
||||
traceLogPrefix, step, decision.Thought, decision.Action, decision.IsFinalAnswer)
|
||||
}
|
||||
|
||||
// ========== 判定:是否为最终回答 ==========
|
||||
if decision.IsFinalAnswer {
|
||||
finalText := ""
|
||||
if decision.FinalAnswer != nil {
|
||||
@@ -389,40 +464,26 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
if finalText == "" {
|
||||
finalText = "已完成处理。"
|
||||
}
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react final at step=%d answer=%q", traceLogPrefix, step, finalText)
|
||||
}
|
||||
return finalText, nil
|
||||
}
|
||||
|
||||
// ========== 非最终回答:执行工具调用 ==========
|
||||
action := strings.ToLower(strings.TrimSpace(decision.Action))
|
||||
if action == "" || action == "none" {
|
||||
// LLM 说不是最终回答但也不指定工具,记录后让它再想一轮
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: 你没有指定要调用的工具,请重新决策:要么调用工具,要么给出最终回答。\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: 你没有指定要调用的工具,请重新决策。\n"
|
||||
continue
|
||||
}
|
||||
|
||||
actionInput := decision.GetActionInputString()
|
||||
|
||||
// 检查工具是否存在
|
||||
tool, ok := o.tools.Get(action)
|
||||
if !ok {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("%s react step=%d tool_not_found=%s", traceLogPrefix, step, action)
|
||||
}
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Action: " + action + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + formatToolErrorObservation("TOOL_NOT_FOUND", action, "该工具不存在,可用工具请参阅 system prompt") + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + formatToolErrorObservation("TOOL_NOT_FOUND", action, "该工具不存在") + "\n"
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_not_found:"+action)
|
||||
continue
|
||||
}
|
||||
|
||||
// 调用工具
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d tool_call tool=%s input=%q", traceLogPrefix, step, action, actionInput)
|
||||
}
|
||||
toolOut, toolErr := tool.Call(ctx, actionInput)
|
||||
obs := strings.TrimSpace(toolOut)
|
||||
if obs == "" {
|
||||
@@ -432,37 +493,95 @@ func (o *Orchestrator) runUnifiedReAct(ctx context.Context, chatID, userID, comp
|
||||
obs = formatToolErrorObservation("TOOL_EXEC_ERROR", action, toolErr.Error()) + "\nOUTPUT:\n" + obs
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "tool_call_failed:"+action)
|
||||
}
|
||||
// 限制观察值长度防止超出 LLM 上下文窗口
|
||||
if len(obs) > 4000 {
|
||||
obs = obs[:4000] + "\n...(truncated)"
|
||||
}
|
||||
|
||||
if o.log != nil {
|
||||
o.log.Infof("%s react step=%d observation_len=%d", traceLogPrefix, step, len(obs))
|
||||
}
|
||||
|
||||
// 将本轮的思考、行动、观察追加到 scratchpad
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Thought: " + decision.Thought + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Action: " + action + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " ActionInput: " + actionInput + "\n"
|
||||
scratchpad += "Step " + strconv.Itoa(step) + " Observation: " + obs + "\n"
|
||||
}
|
||||
|
||||
// 达到安全上限仍未得到最终回答
|
||||
o.emitCapabilityGap(chatID, userID, userInput, "react_step_exhausted")
|
||||
return "我尝试了多轮推理与工具调用,但仍未得到稳定结论。请给我更具体的约束或允许我继续尝试。", nil
|
||||
}
|
||||
|
||||
func composeRouteInput(userInput, fileSummary string) string {
|
||||
userInput = strings.TrimSpace(userInput)
|
||||
fileSummary = strings.TrimSpace(fileSummary)
|
||||
if userInput == "" {
|
||||
return fileSummary
|
||||
// buildLegacySystemPrompt 为不支持 tool_calls 的旧版 ReAct 链路构建 system prompt(含 JSON 输出格式约束)。
|
||||
func (o *Orchestrator) buildLegacySystemPrompt(userInput string) string {
|
||||
skillMetaDoc := o.formatSkillSummariesForPrompt()
|
||||
relevantSkillsDoc := o.formatSelectedSkillsForPrompt(userInput, nil)
|
||||
toolDoc := o.formatToolDoc()
|
||||
runtimeDoc := formatRuntimeContextForPrompt()
|
||||
|
||||
return strings.Join([]string{
|
||||
"你是一个个人自动化助手,必须遵循如下人格设定并保持一致:",
|
||||
o.soul,
|
||||
"",
|
||||
"===== 运行环境 =====",
|
||||
runtimeDoc,
|
||||
"",
|
||||
"===== 可用技能概览 =====",
|
||||
skillMetaDoc,
|
||||
"",
|
||||
"===== 本轮相关技能 =====",
|
||||
relevantSkillsDoc,
|
||||
"",
|
||||
"===== 可用工具 =====",
|
||||
toolDoc,
|
||||
"",
|
||||
"===== 输出格式约束 =====",
|
||||
"你必须使用 ReAct 模式进行决策。每次回复必须是且仅是一个 JSON 对象:",
|
||||
"{",
|
||||
" \"thought\": \"你的推理过程(必填)\",",
|
||||
" \"action\": \"要调用的工具名称(不调工具时填 none)\",",
|
||||
" \"action_input\": \"传给工具的输入\",",
|
||||
" \"is_final_answer\": true 或 false,",
|
||||
" \"final_answer\": \"当 is_final_answer=true 时填写给用户的最终回复\"",
|
||||
"}",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
// buildToolDefinitions 将工具注册表转换为 OpenAI function calling 所需的 ToolDefinition 列表。
|
||||
func (o *Orchestrator) buildToolDefinitions() []llm.ToolDefinition {
|
||||
list := o.tools.List()
|
||||
defs := make([]llm.ToolDefinition, 0, len(list))
|
||||
defaultParams := json.RawMessage(`{"type":"object","properties":{"input":{"type":"string","description":"工具的输入命令或查询内容"}},"required":["input"]}`)
|
||||
|
||||
sort.Slice(list, func(i, j int) bool {
|
||||
return list[i].Name() < list[j].Name()
|
||||
})
|
||||
|
||||
for _, t := range list {
|
||||
defs = append(defs, llm.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: llm.ToolFunctionDef{
|
||||
Name: t.Name(),
|
||||
Description: t.Description(),
|
||||
Parameters: defaultParams,
|
||||
},
|
||||
})
|
||||
}
|
||||
if fileSummary == "" {
|
||||
return userInput
|
||||
return defs
|
||||
}
|
||||
|
||||
// extractToolInput derives the tool input string from the JSON arguments of a
// function call.
//
// Rules, applied to the whitespace-trimmed arguments:
//   - empty arguments yield "";
//   - a JSON object with a string "input" field yields that field's value,
//     including an explicitly empty string;
//   - anything else (invalid JSON, a non-string "input", a null or missing
//     "input") yields the trimmed arguments unchanged, so tools can still
//     interpret free-form payloads themselves.
func extractToolInput(arguments string) string {
	arguments = strings.TrimSpace(arguments)
	if arguments == "" {
		return ""
	}

	// A pointer distinguishes "input present but empty" from "input absent".
	var args struct {
		Input *string `json:"input"`
	}
	if err := json.Unmarshal([]byte(arguments), &args); err != nil {
		// Best-effort fallback: hand the raw arguments to the tool.
		return arguments
	}
	if args.Input != nil {
		return *args.Input
	}
	return arguments
}
|
||||
|
||||
func (o *Orchestrator) prepareFilePromptContext(ctx context.Context, files []llm.InputFile, pending []pendingFileRef) filePromptContext {
|
||||
@@ -535,16 +654,85 @@ func buildFileSummary(pending, uploaded []pendingFileRef) string {
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) generateWithOptionalFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
|
||||
func (o *Orchestrator) generateWithOptionalFilesMessages(ctx context.Context, messages []llm.PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) == 0 {
|
||||
if client, ok := o.llm.(llm.MessageChatClient); ok {
|
||||
return client.GenerateMessages(ctx, messages)
|
||||
}
|
||||
systemPrompt, userPrompt := fallbackPromptsFromMessages(messages)
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
}
|
||||
if client, ok := o.llm.(llm.FileMessageChatClient); ok {
|
||||
return client.GenerateMessagesWithFiles(ctx, messages, ids, appendFileIDText)
|
||||
}
|
||||
client, ok := o.llm.(llm.FileChatClient)
|
||||
if !ok {
|
||||
systemPrompt, userPrompt := fallbackPromptsFromMessages(messages)
|
||||
return o.llm.Generate(ctx, systemPrompt, userPrompt)
|
||||
}
|
||||
return client.GenerateWithFiles(ctx, systemPrompt, userPrompt, ids)
|
||||
systemPrompt, userPrompt := fallbackPromptsFromMessages(messages)
|
||||
return client.GenerateWithFiles(ctx, systemPrompt, userPrompt, ids, appendFileIDText)
|
||||
}
|
||||
|
||||
func buildReActMessages(systemPrompt, compressedContext, userInput, fileSummary, scratchpad string) []llm.PromptMessage {
|
||||
msgs := make([]llm.PromptMessage, 0, 16)
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "system", Content: systemPrompt})
|
||||
msgs = append(msgs, parseCompressedHistoryMessages(compressedContext)...)
|
||||
|
||||
if strings.TrimSpace(fileSummary) != "" {
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "assistant", Content: "文件上下文摘要:\n" + strings.TrimSpace(fileSummary)})
|
||||
}
|
||||
if strings.TrimSpace(scratchpad) != "" {
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "assistant", Content: "推理记录:\n" + strings.TrimSpace(scratchpad)})
|
||||
}
|
||||
msgs = append(msgs, llm.PromptMessage{Role: "user", Content: userInput})
|
||||
return msgs
|
||||
}
|
||||
|
||||
func parseCompressedHistoryMessages(compressed string) []llm.PromptMessage {
|
||||
compressed = strings.TrimSpace(compressed)
|
||||
if compressed == "" {
|
||||
return nil
|
||||
}
|
||||
lines := strings.Split(compressed, "\n")
|
||||
out := make([]llm.PromptMessage, 0, len(lines))
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
idx := strings.Index(line, ":")
|
||||
if idx <= 0 {
|
||||
out = append(out, llm.PromptMessage{Role: "assistant", Content: line})
|
||||
continue
|
||||
}
|
||||
role := strings.ToLower(strings.TrimSpace(line[:idx]))
|
||||
content := strings.TrimSpace(line[idx+1:])
|
||||
if role != "system" && role != "user" && role != "assistant" {
|
||||
role = "assistant"
|
||||
}
|
||||
out = append(out, llm.PromptMessage{Role: role, Content: content})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fallbackPromptsFromMessages(messages []llm.PromptMessage) (string, string) {
|
||||
sysParts := make([]string, 0, 2)
|
||||
userParts := make([]string, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role := strings.ToLower(strings.TrimSpace(m.Role))
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if role == "system" {
|
||||
sysParts = append(sysParts, content)
|
||||
continue
|
||||
}
|
||||
userParts = append(userParts, role+": "+content)
|
||||
}
|
||||
return strings.Join(sysParts, "\n\n"), strings.Join(userParts, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildFileUploadAck(ctx filePromptContext) string {
|
||||
@@ -670,180 +858,6 @@ func (o *Orchestrator) formatSelectedSkillsForPrompt(userInput string, selected
|
||||
return formatSkills(skills)
|
||||
}
|
||||
|
||||
func (o *Orchestrator) routeCapabilities(ctx context.Context, userInput string) capabilityRoutingResult {
|
||||
fallback := capabilityRoutingResult{
|
||||
NeedSkills: true,
|
||||
SelectedSkills: o.selectRelevantSkills(userInput, 4),
|
||||
Reason: "router fallback: keyword matching",
|
||||
UsedFallback: true,
|
||||
}
|
||||
|
||||
raw, err := o.llm.Generate(ctx, o.buildRouteSystemPrompt(), o.buildRouteUserPrompt(userInput))
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("capability router llm call failed err=%v", err)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
decision, err := parseCapabilityRoute(raw)
|
||||
if err != nil {
|
||||
if o.log != nil {
|
||||
o.log.Warnf("capability router parse failed err=%v raw=%q", err, raw)
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
resolvedTools := o.normalizeToolSelection(decision.SelectedTools)
|
||||
resolved := capabilityRoutingResult{
|
||||
NeedSkills: decision.NeedSkills,
|
||||
SelectedToolNames: resolvedTools,
|
||||
Reason: strings.TrimSpace(decision.Reason),
|
||||
}
|
||||
|
||||
if resolved.NeedSkills {
|
||||
skills := o.resolveSkillsByNames(decision.SelectedSkills, 4)
|
||||
if len(skills) == 0 {
|
||||
skills = o.selectRelevantSkills(userInput, 4)
|
||||
resolved.UsedFallback = true
|
||||
}
|
||||
resolved.SelectedSkills = skills
|
||||
}
|
||||
|
||||
return resolved
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildRouteSystemPrompt() string {
|
||||
return strings.Join([]string{
|
||||
"你是能力路由器(Router Agent)。",
|
||||
"你的任务是:在不加载技能全文的前提下,仅根据工具摘要和技能摘要,判断本请求是否可以仅靠原子工具能力完成,还是需要加载技能详细说明。",
|
||||
"输出必须且仅能是 JSON:",
|
||||
"{",
|
||||
" \"need_skills\": true 或 false,",
|
||||
" \"selected_tools\": [\"tool_name\", ...],",
|
||||
" \"selected_skills\": [\"skill_name\", ...],",
|
||||
" \"reason\": \"简短路由理由\"",
|
||||
"}",
|
||||
"规则:",
|
||||
"1) 优先原子工具能力。若可通过工具链路完成,need_skills=false。",
|
||||
"2) 只有当工具能力不足以覆盖业务约束时,need_skills=true 并选择少量最相关技能。",
|
||||
"3) selected_skills 仅填写技能名称(来自技能摘要)。",
|
||||
"4) selected_tools 仅填写可用工具名。",
|
||||
"5) 不要输出 JSON 之外内容。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) buildRouteUserPrompt(userInput string) string {
|
||||
return strings.Join([]string{
|
||||
"当前运行环境:",
|
||||
formatRuntimeContextForPrompt(),
|
||||
"",
|
||||
"用户问题:",
|
||||
userInput,
|
||||
"",
|
||||
"可用工具摘要:",
|
||||
o.formatToolDoc(),
|
||||
"",
|
||||
"可用技能摘要:",
|
||||
o.formatSkillSummariesForPrompt(),
|
||||
"",
|
||||
"请给出路由 JSON。",
|
||||
}, "\n")
|
||||
}
|
||||
|
||||
func (o *Orchestrator) normalizeToolSelection(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
allowed := map[string]struct{}{}
|
||||
for _, t := range o.tools.List() {
|
||||
allowed[strings.ToLower(strings.TrimSpace(t.Name()))] = struct{}{}
|
||||
}
|
||||
out := make([]string, 0, len(in))
|
||||
set := map[string]struct{}{}
|
||||
for _, name := range in {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := allowed[n]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := set[n]; exists {
|
||||
continue
|
||||
}
|
||||
set[n] = struct{}{}
|
||||
out = append(out, n)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
func (o *Orchestrator) resolveSkillsByNames(names []string, maxCount int) []knowledge.Skill {
|
||||
if len(names) == 0 {
|
||||
return nil
|
||||
}
|
||||
if maxCount <= 0 {
|
||||
maxCount = 4
|
||||
}
|
||||
all := o.getSkillsSnapshot()
|
||||
idx := make(map[string]knowledge.Skill, len(all))
|
||||
for _, sk := range all {
|
||||
key := strings.ToLower(strings.TrimSpace(sk.Name))
|
||||
if key != "" {
|
||||
idx[key] = sk
|
||||
}
|
||||
}
|
||||
out := make([]knowledge.Skill, 0, maxCount)
|
||||
used := map[string]struct{}{}
|
||||
for _, name := range names {
|
||||
key := strings.ToLower(strings.TrimSpace(name))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
sk, ok := idx[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, exists := used[key]; exists {
|
||||
continue
|
||||
}
|
||||
used[key] = struct{}{}
|
||||
out = append(out, sk)
|
||||
if len(out) >= maxCount {
|
||||
break
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func formatRouteForPrompt(route capabilityRoutingResult) string {
|
||||
b := strings.Builder{}
|
||||
if route.UsedFallback {
|
||||
b.WriteString("router_status: fallback\n")
|
||||
} else {
|
||||
b.WriteString("router_status: ok\n")
|
||||
}
|
||||
b.WriteString("need_skills: ")
|
||||
b.WriteString(strconv.FormatBool(route.NeedSkills))
|
||||
b.WriteString("\n")
|
||||
b.WriteString("selected_tools: ")
|
||||
if len(route.SelectedToolNames) == 0 {
|
||||
b.WriteString("(none)")
|
||||
} else {
|
||||
b.WriteString(strings.Join(route.SelectedToolNames, ", "))
|
||||
}
|
||||
b.WriteString("\n")
|
||||
b.WriteString("selected_skill_count: ")
|
||||
b.WriteString(strconv.Itoa(len(route.SelectedSkills)))
|
||||
b.WriteString("\n")
|
||||
if strings.TrimSpace(route.Reason) != "" {
|
||||
b.WriteString("reason: ")
|
||||
b.WriteString(strings.TrimSpace(route.Reason))
|
||||
}
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
func (o *Orchestrator) selectRelevantSkills(userInput string, maxCount int) []knowledge.Skill {
|
||||
if maxCount <= 0 {
|
||||
maxCount = 4
|
||||
|
||||
@@ -5,28 +5,83 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/logger"
|
||||
|
||||
openai "github.com/openai/openai-go" // imported as openai
|
||||
"github.com/openai/openai-go/option"
|
||||
"github.com/openai/openai-go/packages/param"
|
||||
"github.com/openai/openai-go/shared"
|
||||
)
|
||||
|
||||
// Client is the minimal LLM client contract: one system prompt and one user
// prompt in, the generated text out.
type Client interface {
	// Generate returns the model's reply for the given prompt pair.
	Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error)
}
|
||||
|
||||
type PromptMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type MessageChatClient interface {
|
||||
GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error)
|
||||
}
|
||||
|
||||
// FileChatClient is implemented by LLM clients that can answer a system/user
// prompt pair while referencing previously uploaded files.
type FileChatClient interface {
	// GenerateWithFiles returns the model's reply for the prompt pair with the
	// given file IDs attached. appendFileIDText is passed through from the
	// caller; its exact effect is defined by the implementation.
	GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error)
}
|
||||
|
||||
type FileMessageChatClient interface {
|
||||
GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error)
|
||||
}
|
||||
|
||||
type FileUploader interface {
|
||||
UploadFile(ctx context.Context, file InputFile, purpose string) (string, error)
|
||||
}
|
||||
|
||||
// ToolCallChatClient 支持原生 function calling 的 LLM 客户端接口。
|
||||
type ToolCallChatClient interface {
|
||||
GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error)
|
||||
}
|
||||
|
||||
// ToolDefinition 描述一个可供 LLM 调用的工具函数定义。
|
||||
type ToolDefinition struct {
|
||||
Type string `json:"type"`
|
||||
Function ToolFunctionDef `json:"function"`
|
||||
}
|
||||
|
||||
// ToolFunctionDef holds a tool function's name, description and the JSON
// Schema of its parameters (kept as raw JSON).
type ToolFunctionDef struct {
	Name        string          `json:"name"`
	Description string          `json:"description"`
	Parameters  json.RawMessage `json:"parameters,omitempty"`
}
|
||||
|
||||
// ToolCall 是 LLM 在响应中返回的工具调用请求。
|
||||
type ToolCall struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Function ToolCallFunction `json:"function"`
|
||||
}
|
||||
|
||||
// ToolCallFunction carries the function name and the argument string of a
// tool call.
type ToolCallFunction struct {
	Name      string `json:"name"`
	Arguments string `json:"arguments"`
}
|
||||
|
||||
// ChatCompletion 是 LLM 响应的结构化表示,包含文本内容和可选的工具调用。
|
||||
type ChatCompletion struct {
|
||||
Content string
|
||||
ToolCalls []ToolCall
|
||||
}
|
||||
|
||||
type InputFile struct {
|
||||
FileName string
|
||||
MimeType string
|
||||
@@ -34,206 +89,312 @@ type InputFile struct {
|
||||
}
|
||||
|
||||
type OpenAICompatibleClient struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
client openai.Client
|
||||
model string
|
||||
fileModel string
|
||||
filePromptMode string
|
||||
http *http.Client
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewOpenAICompatibleClient(cfg config.LLMConfig, log *logger.Logger) *OpenAICompatibleClient {
|
||||
opts := []option.RequestOption{
|
||||
option.WithAPIKey(cfg.APIKey),
|
||||
option.WithRequestTimeout(60 * time.Second),
|
||||
}
|
||||
if strings.TrimSpace(cfg.BaseURL) != "" {
|
||||
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
|
||||
}
|
||||
return &OpenAICompatibleClient{
|
||||
baseURL: cfg.BaseURL,
|
||||
apiKey: cfg.APIKey,
|
||||
client: openai.NewClient(opts...),
|
||||
model: cfg.Model,
|
||||
fileModel: cfg.FileModel,
|
||||
filePromptMode: cfg.FilePromptMode,
|
||||
http: &http.Client{Timeout: 60 * time.Second},
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []chatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type chatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type chatContentPart struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
FileID string `json:"file_id,omitempty"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type fileUploadResponse struct {
|
||||
ID string `json:"id"`
|
||||
Bytes int64 `json:"bytes,omitempty"`
|
||||
CreatedAt int64 `json:"created_at,omitempty"`
|
||||
Filename string `json:"filename,omitempty"`
|
||||
Object string `json:"object,omitempty"`
|
||||
Purpose string `json:"purpose,omitempty"`
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Status any `json:"status,omitempty"`
|
||||
StatusDetails any `json:"status_details,omitempty"`
|
||||
Data *struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
|
||||
return c.generateInternal(ctx, systemPrompt, userPrompt, nil)
|
||||
messages := []PromptMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
}
|
||||
return c.generateWithMessagesInternal(ctx, messages, nil, false)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
|
||||
return c.generateInternal(ctx, systemPrompt, userPrompt, fileIDs)
|
||||
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
messages := []PromptMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userPrompt},
|
||||
}
|
||||
return c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) generateInternal(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
|
||||
func (c *OpenAICompatibleClient) GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error) {
|
||||
return c.generateWithMessagesInternal(ctx, messages, nil, false)
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
return c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
|
||||
}
|
||||
|
||||
// GenerateWithTools 使用原生 function calling 发送请求,返回结构化的 ChatCompletion。
|
||||
func (c *OpenAICompatibleClient) GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error) {
|
||||
model := c.model
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) > 0 {
|
||||
if strings.TrimSpace(c.fileModel) != "" {
|
||||
model = c.fileModel
|
||||
}
|
||||
if len(ids) > 0 && strings.TrimSpace(c.fileModel) != "" {
|
||||
model = c.fileModel
|
||||
}
|
||||
|
||||
sdkMessages := buildSDKMessages(messages, ids, c.normalizedFilePromptMode(), appendFileIDText)
|
||||
sdkTools := toSDKTools(tools)
|
||||
|
||||
if c.log != nil {
|
||||
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d file_prompt_mode=%s", model, len(systemPrompt), len(userPrompt), len(ids), c.normalizedFilePromptMode())
|
||||
c.log.Debugf("llm tool-call request start model=%s messages=%d tools=%d files=%d", model, len(sdkMessages), len(sdkTools), len(ids))
|
||||
}
|
||||
messages := buildMessages(systemPrompt, userPrompt, ids, c.normalizedFilePromptMode())
|
||||
body := chatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
Messages: sdkMessages,
|
||||
}
|
||||
b, err := json.Marshal(body)
|
||||
if len(sdkTools) > 0 {
|
||||
params.Tools = sdkTools
|
||||
}
|
||||
|
||||
if c.log != nil {
|
||||
if b, err := json.Marshal(params); err == nil {
|
||||
c.log.Debugf("llm tool-call request params: %s", string(b))
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := c.client.Chat.Completions.New(ctx, params)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("llm tool-call request failed: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("llm returned empty choices")
|
||||
}
|
||||
|
||||
choice := resp.Choices[0]
|
||||
resultToolCalls := fromSDKToolCalls(choice.Message.ToolCalls)
|
||||
if c.log != nil {
|
||||
c.log.Infof("llm tool-call response success model=%s content_len=%d tool_calls=%d finish=%s",
|
||||
model, len(choice.Message.Content), len(resultToolCalls), choice.FinishReason)
|
||||
}
|
||||
|
||||
return &ChatCompletion{
|
||||
Content: choice.Message.Content,
|
||||
ToolCalls: resultToolCalls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
|
||||
model := c.model
|
||||
ids := nonEmptyIDs(fileIDs)
|
||||
if len(ids) > 0 && strings.TrimSpace(c.fileModel) != "" {
|
||||
model = c.fileModel
|
||||
}
|
||||
|
||||
baseMessages := normalizePromptMessages(messages)
|
||||
if len(baseMessages) == 0 {
|
||||
baseMessages = []PromptMessage{{Role: "user", Content: ""}}
|
||||
}
|
||||
|
||||
systemLen, userLen := promptMessageLengths(baseMessages)
|
||||
if c.log != nil {
|
||||
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d file_prompt_mode=%s", model, systemLen, userLen, len(ids), c.normalizedFilePromptMode())
|
||||
}
|
||||
|
||||
sdkMessages := buildSDKMessages(baseMessages, ids, c.normalizedFilePromptMode(), appendFileIDText)
|
||||
|
||||
params := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
Messages: sdkMessages,
|
||||
}
|
||||
|
||||
resp, err := c.client.Chat.Completions.New(ctx, params)
|
||||
if err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("marshal llm request failed err=%v", err)
|
||||
c.log.Errorf("llm request failed err=%v", err)
|
||||
}
|
||||
return "", err
|
||||
return "", fmt.Errorf("llm request failed: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(c.baseURL, "/") + "/chat/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
if len(resp.Choices) == 0 {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("build llm request failed err=%v", err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("llm http request failed err=%v", err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("llm read response failed err=%v", err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
var out chatResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("llm response unmarshal failed status=%d err=%v", resp.StatusCode, err)
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("llm bad status=%d", resp.StatusCode)
|
||||
}
|
||||
if out.Error != nil && out.Error.Message != "" {
|
||||
return "", fmt.Errorf("llm error: %s", out.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("llm error status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
if len(out.Choices) == 0 {
|
||||
if c.log != nil {
|
||||
c.log.Errorf("llm returned empty choices status=%d", resp.StatusCode)
|
||||
c.log.Errorf("llm returned empty choices")
|
||||
}
|
||||
return "", fmt.Errorf("llm returned empty choices")
|
||||
}
|
||||
|
||||
content := resp.Choices[0].Message.Content
|
||||
if c.log != nil {
|
||||
c.log.Infof("llm response success model=%s output_len=%d", model, len(out.Choices[0].Message.Content))
|
||||
c.log.Infof("llm response success model=%s output_len=%d", model, len(content))
|
||||
}
|
||||
|
||||
return out.Choices[0].Message.Content, nil
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func buildMessages(systemPrompt, userPrompt string, fileIDs []string, mode string) []chatMessage {
|
||||
// buildSDKMessages 将 PromptMessage 列表转换为 openai SDK 的消息格式,并注入 file_id(如需要)。
|
||||
func buildSDKMessages(base []PromptMessage, fileIDs []string, mode string, appendFileIDText bool) []openai.ChatCompletionMessageParamUnion {
|
||||
mode = strings.ToLower(strings.TrimSpace(mode))
|
||||
if mode == "system_fileid_uri" {
|
||||
msgs := []chatMessage{{Role: "system", Content: systemPrompt}}
|
||||
for _, id := range fileIDs {
|
||||
if strings.TrimSpace(id) == "" {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, chatMessage{Role: "system", Content: "fileid://" + strings.TrimSpace(id)})
|
||||
out := make([]openai.ChatCompletionMessageParamUnion, 0, len(base)+2)
|
||||
|
||||
for _, m := range base {
|
||||
role := normalizeRole(m.Role)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
msgs = append(msgs, chatMessage{Role: "user", Content: userPrompt})
|
||||
return msgs
|
||||
out = append(out, toSDKMessage(m, role))
|
||||
}
|
||||
userContent := buildUserContent(userPrompt, fileIDs)
|
||||
return []chatMessage{
|
||||
{Role: "system", Content: systemPrompt},
|
||||
{Role: "user", Content: userContent},
|
||||
|
||||
if len(fileIDs) == 0 {
|
||||
return out
|
||||
}
|
||||
|
||||
if appendFileIDText {
|
||||
// WebUI 场景:将首个 fileID 作为 text part 追加到最后一个 user 消息。
|
||||
firstFileID := strings.TrimSpace(fileIDs[0])
|
||||
if firstFileID == "" {
|
||||
return out
|
||||
}
|
||||
for i := len(out) - 1; i >= 0; i-- {
|
||||
if r := out[i].GetRole(); r != nil && *r == "user" {
|
||||
out[i] = buildUserMessageWithFileIDText(out[i], firstFileID)
|
||||
return out
|
||||
}
|
||||
}
|
||||
out = append(out, buildUserMessageWithFileIDText(openai.UserMessage(""), firstFileID))
|
||||
return out
|
||||
}
|
||||
|
||||
// 非 WebUI 场景:保持原有 file content part 方式。
|
||||
for i := len(out) - 1; i >= 0; i-- {
|
||||
if r := out[i].GetRole(); r != nil && *r == "user" {
|
||||
out[i] = buildUserMessageWithFiles(out[i], fileIDs)
|
||||
return out
|
||||
}
|
||||
}
|
||||
out = append(out, buildUserMessageWithFiles(openai.UserMessage(""), fileIDs))
|
||||
return out
|
||||
}
|
||||
|
||||
// toSDKMessage 将单个 PromptMessage 转换为 openai SDK 消息类型。
|
||||
func toSDKMessage(m PromptMessage, role string) openai.ChatCompletionMessageParamUnion {
|
||||
switch role {
|
||||
case "system":
|
||||
return openai.SystemMessage(m.Content)
|
||||
case "user":
|
||||
return openai.UserMessage(m.Content)
|
||||
case "assistant":
|
||||
if len(m.ToolCalls) > 0 {
|
||||
sdkToolCalls := make([]openai.ChatCompletionMessageToolCallParam, 0, len(m.ToolCalls))
|
||||
for _, tc := range m.ToolCalls {
|
||||
sdkToolCalls = append(sdkToolCalls, openai.ChatCompletionMessageToolCallParam{
|
||||
ID: tc.ID,
|
||||
Function: openai.ChatCompletionMessageToolCallFunctionParam{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
msg := openai.AssistantMessage(m.Content)
|
||||
msg.OfAssistant.ToolCalls = sdkToolCalls
|
||||
return msg
|
||||
}
|
||||
return openai.AssistantMessage(m.Content)
|
||||
case "tool":
|
||||
return openai.ToolMessage(m.Content, m.ToolCallID)
|
||||
default:
|
||||
return openai.UserMessage(m.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func buildUserContent(userPrompt string, fileIDs []string) any {
|
||||
trimmedPrompt := strings.TrimSpace(userPrompt)
|
||||
if len(fileIDs) == 0 {
|
||||
return userPrompt
|
||||
// buildUserMessageWithFileIDText 为 user 消息追加一个 text part,内容为 fileID。
|
||||
func buildUserMessageWithFileIDText(msg openai.ChatCompletionMessageParamUnion, fileID string) openai.ChatCompletionMessageParamUnion {
|
||||
// 提取已有的文本内容
|
||||
text := ""
|
||||
if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
|
||||
text = *s
|
||||
}
|
||||
fileID = strings.TrimSpace(fileID)
|
||||
if fileID == "" {
|
||||
return msg
|
||||
}
|
||||
|
||||
parts := make([]chatContentPart, 0, len(fileIDs)+1)
|
||||
if trimmedPrompt != "" {
|
||||
parts = append(parts, chatContentPart{Type: "text", Text: userPrompt})
|
||||
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, 2)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, openai.TextContentPart(text))
|
||||
}
|
||||
parts = append(parts, openai.TextContentPart(fileID))
|
||||
if len(parts) == 0 {
|
||||
return msg
|
||||
}
|
||||
return openai.UserMessage(parts)
|
||||
}
|
||||
|
||||
// buildUserMessageWithFiles 为 user 消息追加 file content parts。
|
||||
func buildUserMessageWithFiles(msg openai.ChatCompletionMessageParamUnion, fileIDs []string) openai.ChatCompletionMessageParamUnion {
|
||||
text := ""
|
||||
if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
|
||||
text = *s
|
||||
}
|
||||
|
||||
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(fileIDs)+1)
|
||||
if strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, openai.TextContentPart(text))
|
||||
}
|
||||
for _, id := range fileIDs {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, chatContentPart{Type: "file", FileID: id})
|
||||
parts = append(parts, openai.FileContentPart(openai.ChatCompletionContentPartFileFileParam{FileID: param.NewOpt(id)}))
|
||||
}
|
||||
if len(parts) == 0 {
|
||||
return userPrompt
|
||||
return msg
|
||||
}
|
||||
return parts
|
||||
return openai.UserMessage(parts)
|
||||
}
|
||||
|
||||
// toSDKTools 将内部 ToolDefinition 列表转换为 openai SDK 的 ChatCompletionToolParam 列表。
|
||||
func toSDKTools(tools []ToolDefinition) []openai.ChatCompletionToolParam {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]openai.ChatCompletionToolParam, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
var params shared.FunctionParameters
|
||||
if len(t.Function.Parameters) > 0 {
|
||||
_ = json.Unmarshal(t.Function.Parameters, ¶ms)
|
||||
}
|
||||
out = append(out, openai.ChatCompletionToolParam{
|
||||
Function: shared.FunctionDefinitionParam{
|
||||
Name: t.Function.Name,
|
||||
Description: param.NewOpt(t.Function.Description),
|
||||
Parameters: params,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// fromSDKToolCalls 将 openai SDK 响应中的 tool calls 转换为内部 ToolCall 类型。
|
||||
func fromSDKToolCalls(sdkCalls []openai.ChatCompletionMessageToolCall) []ToolCall {
|
||||
if len(sdkCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]ToolCall, 0, len(sdkCalls))
|
||||
for _, tc := range sdkCalls {
|
||||
out = append(out, ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: ToolCallFunction{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile, purpose string) (string, error) {
|
||||
@@ -248,7 +409,6 @@ func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile,
|
||||
if purpose != "" {
|
||||
purposes = append(purposes, purpose)
|
||||
}
|
||||
// Provider compatibility fallback order.
|
||||
purposes = appendIfMissing(purposes, "file-extract")
|
||||
purposes = appendIfMissing(purposes, "batch")
|
||||
|
||||
@@ -270,77 +430,24 @@ func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile,
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) uploadFileOnce(ctx context.Context, file InputFile, purpose string) (string, error) {
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
if err := writer.WriteField("purpose", purpose); err != nil {
|
||||
return "", err
|
||||
}
|
||||
part, err := writer.CreateFormFile("file", file.FileName)
|
||||
resp, err := c.client.Files.New(ctx, openai.FileNewParams{
|
||||
File: bytes.NewReader(file.Content),
|
||||
Purpose: openai.FilePurpose(purpose),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(file.Content); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("llm file upload failed: %w", err)
|
||||
}
|
||||
|
||||
url := strings.TrimRight(c.baseURL, "/") + "/files"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var out fileUploadResponse
|
||||
if err := json.Unmarshal(raw, &out); err != nil {
|
||||
return "", fmt.Errorf("llm file upload response decode failed: %w body=%s", err, clipForError(raw))
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
if strings.TrimSpace(out.Message) != "" {
|
||||
return "", fmt.Errorf("llm file upload error: %s", out.Message)
|
||||
}
|
||||
if out.Error != nil && out.Error.Message != "" {
|
||||
return "", fmt.Errorf("llm file upload error: %s", out.Error.Message)
|
||||
}
|
||||
return "", fmt.Errorf("llm file upload status: %d body=%s", resp.StatusCode, clipForError(raw))
|
||||
}
|
||||
fileID := strings.TrimSpace(out.ID)
|
||||
if fileID == "" && out.Data != nil {
|
||||
fileID = strings.TrimSpace(out.Data.ID)
|
||||
}
|
||||
fileID := strings.TrimSpace(resp.ID)
|
||||
if fileID == "" {
|
||||
return "", fmt.Errorf("llm file upload returned empty file id body=%s", clipForError(raw))
|
||||
return "", fmt.Errorf("llm file upload returned empty file id")
|
||||
}
|
||||
if c.log != nil {
|
||||
c.log.Infof("llm file uploaded name=%s size=%d file_id=%s purpose=%s status=%v", file.FileName, len(file.Content), fileID, purpose, out.Status)
|
||||
c.log.Infof("llm file uploaded name=%s size=%d file_id=%s purpose=%s", file.FileName, len(file.Content), fileID, purpose)
|
||||
}
|
||||
return fileID, nil
|
||||
}
|
||||
|
||||
// clipForError renders raw as a whitespace-trimmed string suitable for
// embedding in an error message. Text longer than 400 bytes is cut and
// suffixed with "...(truncated)"; the cut never splits a multi-byte UTF-8
// character.
func clipForError(raw []byte) string {
	const limit = 400
	s := strings.TrimSpace(string(raw))
	if len(s) <= limit {
		return s
	}
	cut := limit
	// If the byte right after the cut is a UTF-8 continuation byte (10xxxxxx),
	// the cut lands inside a character; step back to that character's first
	// byte. A UTF-8 sequence has at most 3 continuation bytes, so the backoff
	// is bounded even for malformed input.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "...(truncated)"
}
|
||||
|
||||
func appendIfMissing(items []string, value string) []string {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
@@ -374,6 +481,46 @@ func nonEmptyIDs(ids []string) []string {
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizePromptMessages(messages []PromptMessage) []PromptMessage {
|
||||
out := make([]PromptMessage, 0, len(messages))
|
||||
for _, m := range messages {
|
||||
role := normalizeRole(m.Role)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, PromptMessage{
|
||||
Role: role,
|
||||
Content: m.Content,
|
||||
ToolCalls: m.ToolCalls,
|
||||
ToolCallID: m.ToolCallID,
|
||||
Name: m.Name,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeRole lower-cases and trims role and returns it when it is one of
// "system", "user", "assistant" or "tool". Any other value yields "".
func normalizeRole(role string) string {
	switch r := strings.ToLower(strings.TrimSpace(role)); r {
	case "system", "user", "assistant", "tool":
		return r
	default:
		return ""
	}
}
|
||||
|
||||
func promptMessageLengths(messages []PromptMessage) (int, int) {
|
||||
systemLen := 0
|
||||
userLen := 0
|
||||
for _, m := range messages {
|
||||
switch normalizeRole(m.Role) {
|
||||
case "system":
|
||||
systemLen += len(m.Content)
|
||||
case "user":
|
||||
userLen += len(m.Content)
|
||||
}
|
||||
}
|
||||
return systemLen, userLen
|
||||
}
|
||||
|
||||
func (c *OpenAICompatibleClient) normalizedFilePromptMode() string {
|
||||
mode := strings.ToLower(strings.TrimSpace(c.filePromptMode))
|
||||
if mode == "system_fileid" || mode == "system_fileid_url" || mode == "system_fileid_uri" {
|
||||
|
||||
@@ -18,9 +18,10 @@ import (
|
||||
)
|
||||
|
||||
type IncomingMessage struct {
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
FileIDs []string
|
||||
}
|
||||
|
||||
type ChatHandler func(context.Context, IncomingMessage) (string, error)
|
||||
@@ -37,9 +38,41 @@ type Bot struct {
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
FileIDs []string `json:"file_ids"`
|
||||
}
|
||||
|
||||
func (r *chatRequest) UnmarshalJSON(data []byte) error {
|
||||
type rawChatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionIDCamel string `json:"sessionId"`
|
||||
UserID string `json:"user_id"`
|
||||
UserIDCamel string `json:"userId"`
|
||||
FileIDs json.RawMessage `json:"file_ids"`
|
||||
FileIDsCamel json.RawMessage `json:"fileIds"`
|
||||
FileIDsFlat json.RawMessage `json:"fileids"`
|
||||
FileID json.RawMessage `json:"file_id"`
|
||||
}
|
||||
|
||||
var raw rawChatRequest
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Text = raw.Text
|
||||
r.SessionID = firstNonEmpty(raw.SessionID, raw.SessionIDCamel)
|
||||
r.UserID = firstNonEmpty(raw.UserID, raw.UserIDCamel)
|
||||
|
||||
rawIDs := firstNonEmptyRaw(raw.FileIDs, raw.FileIDsCamel, raw.FileIDsFlat, raw.FileID)
|
||||
ids, err := decodeStringList(rawIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.FileIDs = ids
|
||||
return nil
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
@@ -158,9 +191,10 @@ func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
userID := b.resolveID(req.UserID, "user")
|
||||
|
||||
reply, err := b.chatHandler(r.Context(), IncomingMessage{
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
FileIDs: req.FileIDs,
|
||||
})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
@@ -176,6 +210,65 @@ func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func decodeStringList(raw json.RawMessage) ([]string, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var list []string
|
||||
if err := json.Unmarshal(raw, &list); err == nil {
|
||||
return nonEmptyIDs(list), nil
|
||||
}
|
||||
|
||||
var single string
|
||||
if err := json.Unmarshal(raw, &single); err == nil {
|
||||
if strings.TrimSpace(single) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return nonEmptyIDs(strings.Split(single, ",")), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid file ids format")
|
||||
}
|
||||
|
||||
// firstNonEmptyRaw returns the first value that carries actual JSON content,
// or nil when there is none.
//
// A value is skipped when it is empty, whitespace only, or the JSON literal
// null. encoding/json stores an explicit null in a json.RawMessage field as
// the bytes "null", so without this check {"file_ids": null, "fileIds": [...]}
// would select the null and ignore the populated alias.
func firstNonEmptyRaw(vals ...json.RawMessage) json.RawMessage {
	for _, v := range vals {
		if s := strings.TrimSpace(string(v)); s == "" || s == "null" {
			continue
		}
		return v
	}
	return nil
}
|
||||
|
||||
// firstNonEmpty returns the first value that is not blank (empty or
// whitespace only), exactly as given and without trimming. It returns ""
// when every value is blank.
func firstNonEmpty(vals ...string) string {
	for i := range vals {
		if strings.TrimSpace(vals[i]) == "" {
			continue
		}
		return vals[i]
	}
	return ""
}
|
||||
|
||||
// nonEmptyIDs trims every ID, drops blank ones and removes duplicates while
// keeping first-seen order. It returns nil for empty input and a non-nil
// (possibly empty) slice otherwise.
func nonEmptyIDs(ids []string) []string {
	if len(ids) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(ids))
	out := make([]string, 0, len(ids))
	for _, raw := range ids {
		id := strings.TrimSpace(raw)
		if id == "" {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		out = append(out, id)
	}
	return out
}
|
||||
|
||||
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
|
||||
@@ -51,6 +51,66 @@ func TestHandleChatSuccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatWithFileIDs(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if len(msg.FileIDs) != 2 || msg.FileIDs[0] != "file_a" || msg.FileIDs[1] != "file_b" {
|
||||
t.Fatalf("unexpected file ids: %+v", msg.FileIDs)
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1","file_ids":["file_a","file_b"]}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatWithFileIDsAliases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{name: "camel array", body: `{"text":"hello","sessionId":"s1","userId":"u1","fileIds":["file_a","file_b"]}`},
|
||||
{name: "flat array", body: `{"text":"hello","session_id":"s1","user_id":"u1","fileids":["file_a","file_b"]}`},
|
||||
{name: "single key", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_id":"file_a"}`},
|
||||
{name: "csv string", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_ids":"file_a, file_b"}`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if len(msg.FileIDs) == 0 {
|
||||
t.Fatalf("expected file ids from alias payload, got empty")
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(tt.body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatMissingText(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil }
|
||||
|
||||
Reference in New Issue
Block a user