Migrate LLM client to OpenAI SDK and implement WebUI-specific fileID handling

This commit is contained in:
2026-03-10 17:54:50 +08:00
parent 49f6297631
commit 0e1a800646
23 changed files with 1162 additions and 8201 deletions

View File

@@ -5,28 +5,83 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"strings"
"time"
"laodingbot/internal/config"
"laodingbot/internal/logger"
openai "github.com/openai/openai-go" // imported as openai
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/shared"
)
type Client interface {
Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error)
}
type PromptMessage struct {
Role string `json:"role"`
Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Name string `json:"name,omitempty"`
}
type MessageChatClient interface {
GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error)
}
type FileChatClient interface {
GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error)
GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error)
}
type FileMessageChatClient interface {
GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error)
}
type FileUploader interface {
UploadFile(ctx context.Context, file InputFile, purpose string) (string, error)
}
// ToolCallChatClient 支持原生 function calling 的 LLM 客户端接口。
type ToolCallChatClient interface {
GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error)
}
// ToolDefinition 描述一个可供 LLM 调用的工具函数定义。
type ToolDefinition struct {
Type string `json:"type"`
Function ToolFunctionDef `json:"function"`
}
// ToolFunctionDef 是工具函数的名称、描述和参数 JSON Schema。
type ToolFunctionDef struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
// ToolCall 是 LLM 在响应中返回的工具调用请求。
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function ToolCallFunction `json:"function"`
}
// ToolCallFunction 包含工具调用的函数名和参数。
type ToolCallFunction struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// ChatCompletion 是 LLM 响应的结构化表示,包含文本内容和可选的工具调用。
type ChatCompletion struct {
Content string
ToolCalls []ToolCall
}
type InputFile struct {
FileName string
MimeType string
@@ -34,206 +89,312 @@ type InputFile struct {
}
type OpenAICompatibleClient struct {
baseURL string
apiKey string
client openai.Client
model string
fileModel string
filePromptMode string
http *http.Client
log *logger.Logger
}
func NewOpenAICompatibleClient(cfg config.LLMConfig, log *logger.Logger) *OpenAICompatibleClient {
opts := []option.RequestOption{
option.WithAPIKey(cfg.APIKey),
option.WithRequestTimeout(60 * time.Second),
}
if strings.TrimSpace(cfg.BaseURL) != "" {
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
}
return &OpenAICompatibleClient{
baseURL: cfg.BaseURL,
apiKey: cfg.APIKey,
client: openai.NewClient(opts...),
model: cfg.Model,
fileModel: cfg.FileModel,
filePromptMode: cfg.FilePromptMode,
http: &http.Client{Timeout: 60 * time.Second},
log: log,
}
}
type chatRequest struct {
Model string `json:"model"`
Messages []chatMessage `json:"messages"`
}
type chatMessage struct {
Role string `json:"role"`
Content any `json:"content"`
}
type chatContentPart struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
FileID string `json:"file_id,omitempty"`
}
type chatResponse struct {
Choices []struct {
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
} `json:"choices"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
type fileUploadResponse struct {
ID string `json:"id"`
Bytes int64 `json:"bytes,omitempty"`
CreatedAt int64 `json:"created_at,omitempty"`
Filename string `json:"filename,omitempty"`
Object string `json:"object,omitempty"`
Purpose string `json:"purpose,omitempty"`
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Status any `json:"status,omitempty"`
StatusDetails any `json:"status_details,omitempty"`
Data *struct {
ID string `json:"id"`
} `json:"data,omitempty"`
Error *struct {
Message string `json:"message"`
} `json:"error,omitempty"`
}
func (c *OpenAICompatibleClient) Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
return c.generateInternal(ctx, systemPrompt, userPrompt, nil)
messages := []PromptMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
}
return c.generateWithMessagesInternal(ctx, messages, nil, false)
}
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
return c.generateInternal(ctx, systemPrompt, userPrompt, fileIDs)
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error) {
messages := []PromptMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userPrompt},
}
return c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
}
func (c *OpenAICompatibleClient) generateInternal(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string) (string, error) {
func (c *OpenAICompatibleClient) GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error) {
return c.generateWithMessagesInternal(ctx, messages, nil, false)
}
func (c *OpenAICompatibleClient) GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
return c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
}
// GenerateWithTools 使用原生 function calling 发送请求,返回结构化的 ChatCompletion。
func (c *OpenAICompatibleClient) GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error) {
model := c.model
ids := nonEmptyIDs(fileIDs)
if len(ids) > 0 {
if strings.TrimSpace(c.fileModel) != "" {
model = c.fileModel
}
if len(ids) > 0 && strings.TrimSpace(c.fileModel) != "" {
model = c.fileModel
}
sdkMessages := buildSDKMessages(messages, ids, c.normalizedFilePromptMode(), appendFileIDText)
sdkTools := toSDKTools(tools)
if c.log != nil {
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d file_prompt_mode=%s", model, len(systemPrompt), len(userPrompt), len(ids), c.normalizedFilePromptMode())
c.log.Debugf("llm tool-call request start model=%s messages=%d tools=%d files=%d", model, len(sdkMessages), len(sdkTools), len(ids))
}
messages := buildMessages(systemPrompt, userPrompt, ids, c.normalizedFilePromptMode())
body := chatRequest{
Model: model,
Messages: messages,
params := openai.ChatCompletionNewParams{
Model: shared.ChatModel(model),
Messages: sdkMessages,
}
b, err := json.Marshal(body)
if len(sdkTools) > 0 {
params.Tools = sdkTools
}
if c.log != nil {
if b, err := json.Marshal(params); err == nil {
c.log.Debugf("llm tool-call request params: %s", string(b))
}
}
resp, err := c.client.Chat.Completions.New(ctx, params)
if err != nil {
return nil, fmt.Errorf("llm tool-call request failed: %w", err)
}
if len(resp.Choices) == 0 {
return nil, fmt.Errorf("llm returned empty choices")
}
choice := resp.Choices[0]
resultToolCalls := fromSDKToolCalls(choice.Message.ToolCalls)
if c.log != nil {
c.log.Infof("llm tool-call response success model=%s content_len=%d tool_calls=%d finish=%s",
model, len(choice.Message.Content), len(resultToolCalls), choice.FinishReason)
}
return &ChatCompletion{
Content: choice.Message.Content,
ToolCalls: resultToolCalls,
}, nil
}
func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
model := c.model
ids := nonEmptyIDs(fileIDs)
if len(ids) > 0 && strings.TrimSpace(c.fileModel) != "" {
model = c.fileModel
}
baseMessages := normalizePromptMessages(messages)
if len(baseMessages) == 0 {
baseMessages = []PromptMessage{{Role: "user", Content: ""}}
}
systemLen, userLen := promptMessageLengths(baseMessages)
if c.log != nil {
c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d file_prompt_mode=%s", model, systemLen, userLen, len(ids), c.normalizedFilePromptMode())
}
sdkMessages := buildSDKMessages(baseMessages, ids, c.normalizedFilePromptMode(), appendFileIDText)
params := openai.ChatCompletionNewParams{
Model: shared.ChatModel(model),
Messages: sdkMessages,
}
resp, err := c.client.Chat.Completions.New(ctx, params)
if err != nil {
if c.log != nil {
c.log.Errorf("marshal llm request failed err=%v", err)
c.log.Errorf("llm request failed err=%v", err)
}
return "", err
return "", fmt.Errorf("llm request failed: %w", err)
}
url := strings.TrimRight(c.baseURL, "/") + "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
if len(resp.Choices) == 0 {
if c.log != nil {
c.log.Errorf("build llm request failed err=%v", err)
}
return "", err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.apiKey)
resp, err := c.http.Do(req)
if err != nil {
if c.log != nil {
c.log.Errorf("llm http request failed err=%v", err)
}
return "", err
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
if c.log != nil {
c.log.Errorf("llm read response failed err=%v", err)
}
return "", err
}
var out chatResponse
if err := json.Unmarshal(raw, &out); err != nil {
if c.log != nil {
c.log.Errorf("llm response unmarshal failed status=%d err=%v", resp.StatusCode, err)
}
return "", err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if c.log != nil {
c.log.Errorf("llm bad status=%d", resp.StatusCode)
}
if out.Error != nil && out.Error.Message != "" {
return "", fmt.Errorf("llm error: %s", out.Error.Message)
}
return "", fmt.Errorf("llm error status: %d", resp.StatusCode)
}
if len(out.Choices) == 0 {
if c.log != nil {
c.log.Errorf("llm returned empty choices status=%d", resp.StatusCode)
c.log.Errorf("llm returned empty choices")
}
return "", fmt.Errorf("llm returned empty choices")
}
content := resp.Choices[0].Message.Content
if c.log != nil {
c.log.Infof("llm response success model=%s output_len=%d", model, len(out.Choices[0].Message.Content))
c.log.Infof("llm response success model=%s output_len=%d", model, len(content))
}
return out.Choices[0].Message.Content, nil
return content, nil
}
func buildMessages(systemPrompt, userPrompt string, fileIDs []string, mode string) []chatMessage {
// buildSDKMessages 将 PromptMessage 列表转换为 openai SDK 的消息格式,并注入 file_id如需要
func buildSDKMessages(base []PromptMessage, fileIDs []string, mode string, appendFileIDText bool) []openai.ChatCompletionMessageParamUnion {
mode = strings.ToLower(strings.TrimSpace(mode))
if mode == "system_fileid_uri" {
msgs := []chatMessage{{Role: "system", Content: systemPrompt}}
for _, id := range fileIDs {
if strings.TrimSpace(id) == "" {
continue
}
msgs = append(msgs, chatMessage{Role: "system", Content: "fileid://" + strings.TrimSpace(id)})
out := make([]openai.ChatCompletionMessageParamUnion, 0, len(base)+2)
for _, m := range base {
role := normalizeRole(m.Role)
if role == "" {
continue
}
msgs = append(msgs, chatMessage{Role: "user", Content: userPrompt})
return msgs
out = append(out, toSDKMessage(m, role))
}
userContent := buildUserContent(userPrompt, fileIDs)
return []chatMessage{
{Role: "system", Content: systemPrompt},
{Role: "user", Content: userContent},
if len(fileIDs) == 0 {
return out
}
if appendFileIDText {
// WebUI 场景:将首个 fileID 作为 text part 追加到最后一个 user 消息。
firstFileID := strings.TrimSpace(fileIDs[0])
if firstFileID == "" {
return out
}
for i := len(out) - 1; i >= 0; i-- {
if r := out[i].GetRole(); r != nil && *r == "user" {
out[i] = buildUserMessageWithFileIDText(out[i], firstFileID)
return out
}
}
out = append(out, buildUserMessageWithFileIDText(openai.UserMessage(""), firstFileID))
return out
}
// 非 WebUI 场景:保持原有 file content part 方式。
for i := len(out) - 1; i >= 0; i-- {
if r := out[i].GetRole(); r != nil && *r == "user" {
out[i] = buildUserMessageWithFiles(out[i], fileIDs)
return out
}
}
out = append(out, buildUserMessageWithFiles(openai.UserMessage(""), fileIDs))
return out
}
// toSDKMessage 将单个 PromptMessage 转换为 openai SDK 消息类型。
func toSDKMessage(m PromptMessage, role string) openai.ChatCompletionMessageParamUnion {
switch role {
case "system":
return openai.SystemMessage(m.Content)
case "user":
return openai.UserMessage(m.Content)
case "assistant":
if len(m.ToolCalls) > 0 {
sdkToolCalls := make([]openai.ChatCompletionMessageToolCallParam, 0, len(m.ToolCalls))
for _, tc := range m.ToolCalls {
sdkToolCalls = append(sdkToolCalls, openai.ChatCompletionMessageToolCallParam{
ID: tc.ID,
Function: openai.ChatCompletionMessageToolCallFunctionParam{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
msg := openai.AssistantMessage(m.Content)
msg.OfAssistant.ToolCalls = sdkToolCalls
return msg
}
return openai.AssistantMessage(m.Content)
case "tool":
return openai.ToolMessage(m.Content, m.ToolCallID)
default:
return openai.UserMessage(m.Content)
}
}
func buildUserContent(userPrompt string, fileIDs []string) any {
trimmedPrompt := strings.TrimSpace(userPrompt)
if len(fileIDs) == 0 {
return userPrompt
// buildUserMessageWithFileIDText 为 user 消息追加一个 text part内容为 fileID。
func buildUserMessageWithFileIDText(msg openai.ChatCompletionMessageParamUnion, fileID string) openai.ChatCompletionMessageParamUnion {
// 提取已有的文本内容
text := ""
if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
text = *s
}
fileID = strings.TrimSpace(fileID)
if fileID == "" {
return msg
}
parts := make([]chatContentPart, 0, len(fileIDs)+1)
if trimmedPrompt != "" {
parts = append(parts, chatContentPart{Type: "text", Text: userPrompt})
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, 2)
if strings.TrimSpace(text) != "" {
parts = append(parts, openai.TextContentPart(text))
}
parts = append(parts, openai.TextContentPart(fileID))
if len(parts) == 0 {
return msg
}
return openai.UserMessage(parts)
}
// buildUserMessageWithFiles 为 user 消息追加 file content parts。
func buildUserMessageWithFiles(msg openai.ChatCompletionMessageParamUnion, fileIDs []string) openai.ChatCompletionMessageParamUnion {
text := ""
if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
text = *s
}
parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(fileIDs)+1)
if strings.TrimSpace(text) != "" {
parts = append(parts, openai.TextContentPart(text))
}
for _, id := range fileIDs {
id = strings.TrimSpace(id)
if id == "" {
continue
}
parts = append(parts, chatContentPart{Type: "file", FileID: id})
parts = append(parts, openai.FileContentPart(openai.ChatCompletionContentPartFileFileParam{FileID: param.NewOpt(id)}))
}
if len(parts) == 0 {
return userPrompt
return msg
}
return parts
return openai.UserMessage(parts)
}
// toSDKTools 将内部 ToolDefinition 列表转换为 openai SDK 的 ChatCompletionToolParam 列表。
func toSDKTools(tools []ToolDefinition) []openai.ChatCompletionToolParam {
if len(tools) == 0 {
return nil
}
out := make([]openai.ChatCompletionToolParam, 0, len(tools))
for _, t := range tools {
var params shared.FunctionParameters
if len(t.Function.Parameters) > 0 {
_ = json.Unmarshal(t.Function.Parameters, &params)
}
out = append(out, openai.ChatCompletionToolParam{
Function: shared.FunctionDefinitionParam{
Name: t.Function.Name,
Description: param.NewOpt(t.Function.Description),
Parameters: params,
},
})
}
return out
}
// fromSDKToolCalls 将 openai SDK 响应中的 tool calls 转换为内部 ToolCall 类型。
func fromSDKToolCalls(sdkCalls []openai.ChatCompletionMessageToolCall) []ToolCall {
if len(sdkCalls) == 0 {
return nil
}
out := make([]ToolCall, 0, len(sdkCalls))
for _, tc := range sdkCalls {
out = append(out, ToolCall{
ID: tc.ID,
Type: "function",
Function: ToolCallFunction{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return out
}
func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile, purpose string) (string, error) {
@@ -248,7 +409,6 @@ func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile,
if purpose != "" {
purposes = append(purposes, purpose)
}
// Provider compatibility fallback order.
purposes = appendIfMissing(purposes, "file-extract")
purposes = appendIfMissing(purposes, "batch")
@@ -270,77 +430,24 @@ func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile,
}
func (c *OpenAICompatibleClient) uploadFileOnce(ctx context.Context, file InputFile, purpose string) (string, error) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
if err := writer.WriteField("purpose", purpose); err != nil {
return "", err
}
part, err := writer.CreateFormFile("file", file.FileName)
resp, err := c.client.Files.New(ctx, openai.FileNewParams{
File: bytes.NewReader(file.Content),
Purpose: openai.FilePurpose(purpose),
})
if err != nil {
return "", err
}
if _, err := part.Write(file.Content); err != nil {
return "", err
}
if err := writer.Close(); err != nil {
return "", err
return "", fmt.Errorf("llm file upload failed: %w", err)
}
url := strings.TrimRight(c.baseURL, "/") + "/files"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return "", err
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+c.apiKey)
resp, err := c.http.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
var out fileUploadResponse
if err := json.Unmarshal(raw, &out); err != nil {
return "", fmt.Errorf("llm file upload response decode failed: %w body=%s", err, clipForError(raw))
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if strings.TrimSpace(out.Message) != "" {
return "", fmt.Errorf("llm file upload error: %s", out.Message)
}
if out.Error != nil && out.Error.Message != "" {
return "", fmt.Errorf("llm file upload error: %s", out.Error.Message)
}
return "", fmt.Errorf("llm file upload status: %d body=%s", resp.StatusCode, clipForError(raw))
}
fileID := strings.TrimSpace(out.ID)
if fileID == "" && out.Data != nil {
fileID = strings.TrimSpace(out.Data.ID)
}
fileID := strings.TrimSpace(resp.ID)
if fileID == "" {
return "", fmt.Errorf("llm file upload returned empty file id body=%s", clipForError(raw))
return "", fmt.Errorf("llm file upload returned empty file id")
}
if c.log != nil {
c.log.Infof("llm file uploaded name=%s size=%d file_id=%s purpose=%s status=%v", file.FileName, len(file.Content), fileID, purpose, out.Status)
c.log.Infof("llm file uploaded name=%s size=%d file_id=%s purpose=%s", file.FileName, len(file.Content), fileID, purpose)
}
return fileID, nil
}
func clipForError(raw []byte) string {
s := strings.TrimSpace(string(raw))
const max = 400
if len(s) <= max {
return s
}
return s[:max] + "...(truncated)"
}
func appendIfMissing(items []string, value string) []string {
value = strings.TrimSpace(value)
if value == "" {
@@ -374,6 +481,46 @@ func nonEmptyIDs(ids []string) []string {
return out
}
func normalizePromptMessages(messages []PromptMessage) []PromptMessage {
out := make([]PromptMessage, 0, len(messages))
for _, m := range messages {
role := normalizeRole(m.Role)
if role == "" {
continue
}
out = append(out, PromptMessage{
Role: role,
Content: m.Content,
ToolCalls: m.ToolCalls,
ToolCallID: m.ToolCallID,
Name: m.Name,
})
}
return out
}
func normalizeRole(role string) string {
r := strings.ToLower(strings.TrimSpace(role))
if r != "system" && r != "user" && r != "assistant" && r != "tool" {
return ""
}
return r
}
func promptMessageLengths(messages []PromptMessage) (int, int) {
systemLen := 0
userLen := 0
for _, m := range messages {
switch normalizeRole(m.Role) {
case "system":
systemLen += len(m.Content)
case "user":
userLen += len(m.Content)
}
}
return systemLen, userLen
}
func (c *OpenAICompatibleClient) normalizedFilePromptMode() string {
mode := strings.ToLower(strings.TrimSpace(c.filePromptMode))
if mode == "system_fileid" || mode == "system_fileid_url" || mode == "system_fileid_uri" {