Files
LaodingBot/internal/llm/client.go

531 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"strings"
"time"
"laodingbot/internal/config"
"laodingbot/internal/logger"
openai "github.com/openai/openai-go" // imported as openai
"github.com/openai/openai-go/option"
"github.com/openai/openai-go/packages/param"
"github.com/openai/openai-go/shared"
)
// Client is the minimal text-generation interface of this package.
type Client interface {
// Generate takes a system prompt and a user prompt and returns a string
// result or an error.
Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error)
}
// PromptMessage is one chat message, serialised to JSON with the keys given
// by the field tags.
type PromptMessage struct {
// Role is the sender role. normalizeRole accepts "system", "user",
// "assistant" and "tool" (case-insensitive, surrounding whitespace ignored).
Role string `json:"role"`
// Content is the message text.
Content string `json:"content"`
// ToolCalls lists tool invocations attached to the message; toSDKMessage
// forwards them for assistant messages only.
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
// ToolCallID is forwarded by toSDKMessage for "tool" messages.
ToolCallID string `json:"tool_call_id,omitempty"`
// Name is preserved by normalizePromptMessages; toSDKMessage does not send it.
Name string `json:"name,omitempty"`
}
// MessageChatClient is implemented by clients that accept a caller-built list
// of messages instead of a single system/user prompt pair.
type MessageChatClient interface {
// GenerateMessages takes the given messages and returns a string result or
// an error.
GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error)
}
// FileChatClient is implemented by clients that can reference previously
// uploaded files in a system/user prompt request.
type FileChatClient interface {
// GenerateWithFiles takes a system/user prompt pair plus file IDs.
// In OpenAICompatibleClient, appendFileIDText=true appends only the first
// file ID as a text part of the last user message, while false attaches
// every ID as a file content part.
GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error)
}
// FileMessageChatClient is the message-list counterpart of FileChatClient.
type FileMessageChatClient interface {
// GenerateMessagesWithFiles takes a caller-built message list plus file IDs.
// In OpenAICompatibleClient, appendFileIDText=true appends only the first
// file ID as a text part of the last user message, while false attaches
// every ID as a file content part.
GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error)
}
// FileUploader is implemented by clients that can upload a file to the LLM
// provider.
type FileUploader interface {
// UploadFile takes a file and a purpose string and returns a string or an
// error; OpenAICompatibleClient returns the ID of the uploaded file.
UploadFile(ctx context.Context, file InputFile, purpose string) (string, error)
}
// ToolCallChatClient is the interface for LLM clients that support native
// function calling.
type ToolCallChatClient interface {
// GenerateWithTools takes the messages together with the tool definitions
// and optional file IDs, and returns a ChatCompletion or an error.
GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error)
}
// ToolDefinition describes one tool function that the LLM may call.
type ToolDefinition struct {
// Type is the tool kind as serialised under "type". toSDKTools does not
// read it when converting to the SDK representation.
Type string `json:"type"`
// Function holds the callable function's name, description and parameters.
Function ToolFunctionDef `json:"function"`
}
// ToolFunctionDef carries a tool function's name, description and the JSON
// Schema of its parameters.
type ToolFunctionDef struct {
Name string `json:"name"`
Description string `json:"description"`
// Parameters is the raw JSON Schema; it is omitted from JSON when empty and
// decoded by toSDKTools when non-empty.
Parameters json.RawMessage `json:"parameters,omitempty"`
}
// ToolCall is a tool invocation request returned by the LLM in its response.
type ToolCall struct {
// ID identifies the call.
ID string `json:"id"`
// Type is the call kind; fromSDKToolCalls always sets it to "function".
Type string `json:"type"`
// Function names the function to invoke and carries its arguments.
Function ToolCallFunction `json:"function"`
}
// ToolCallFunction holds the function name and arguments of a tool call.
type ToolCallFunction struct {
Name string `json:"name"`
// Arguments is kept as the string received from the SDK; this file does
// not parse it.
Arguments string `json:"arguments"`
}
// ChatCompletion is the structured form of an LLM response: the text content
// plus any tool calls the model requested.
type ChatCompletion struct {
// Content is the message text of the first returned choice.
Content string
// ToolCalls is nil when the response contained no tool calls.
ToolCalls []ToolCall
}
// InputFile is an in-memory file handed to FileUploader.UploadFile.
type InputFile struct {
// FileName must be non-blank; OpenAICompatibleClient.UploadFile rejects an
// empty name.
FileName string
// MimeType is the file's media type.
MimeType string
// Content holds the file bytes; OpenAICompatibleClient.UploadFile rejects
// empty content.
Content []byte
}
// OpenAICompatibleClient talks to an OpenAI-compatible API through the
// openai-go SDK and provides the Generate*, GenerateWithTools and UploadFile
// methods defined in this file.
type OpenAICompatibleClient struct {
// client is the underlying SDK client.
client openai.Client
// model is the model used by default.
model string
// fileModel, when non-blank, replaces model for requests carrying file IDs.
fileModel string
// filePromptMode is the raw configured mode; see normalizedFilePromptMode.
filePromptMode string
// log may be nil; every use in this file is guarded by a nil check.
log *logger.Logger
}
// NewOpenAICompatibleClient builds an OpenAICompatibleClient from cfg.
//
// The SDK client is configured with cfg.APIKey and a 60-second request
// timeout option. A non-blank cfg.BaseURL is passed to the SDK as the base
// URL; surrounding whitespace is stripped first, because a padded value
// passes the blank check but is not a parsable URL. log may be nil, in which
// case the client performs no logging.
func NewOpenAICompatibleClient(cfg config.LLMConfig, log *logger.Logger) *OpenAICompatibleClient {
	opts := []option.RequestOption{
		option.WithAPIKey(cfg.APIKey),
		option.WithRequestTimeout(60 * time.Second),
	}
	// Use the trimmed value for both the check and the option so that the two
	// can never disagree.
	if baseURL := strings.TrimSpace(cfg.BaseURL); baseURL != "" {
		opts = append(opts, option.WithBaseURL(baseURL))
	}
	return &OpenAICompatibleClient{
		client:         openai.NewClient(opts...),
		model:          cfg.Model,
		fileModel:      cfg.FileModel,
		filePromptMode: cfg.FilePromptMode,
		log:            log,
	}
}
// Generate sends a two-message conversation (system prompt followed by user
// prompt) without any files and returns the text of the model's reply.
func (c *OpenAICompatibleClient) Generate(ctx context.Context, systemPrompt, userPrompt string) (string, error) {
	conversation := make([]PromptMessage, 0, 2)
	conversation = append(conversation,
		PromptMessage{Role: "system", Content: systemPrompt},
		PromptMessage{Role: "user", Content: userPrompt},
	)
	return c.generateWithMessagesInternal(ctx, conversation, nil, false)
}
// GenerateWithFiles sends a system prompt and a user prompt together with the
// given file IDs and returns the text of the model's reply. appendFileIDText
// is forwarded unchanged and selects how the IDs are attached to the request.
func (c *OpenAICompatibleClient) GenerateWithFiles(ctx context.Context, systemPrompt, userPrompt string, fileIDs []string, appendFileIDText bool) (string, error) {
	system := PromptMessage{Role: "system", Content: systemPrompt}
	user := PromptMessage{Role: "user", Content: userPrompt}
	return c.generateWithMessagesInternal(ctx, []PromptMessage{system, user}, fileIDs, appendFileIDText)
}
// GenerateMessages sends the caller-supplied messages without any files and
// returns the text of the model's reply.
func (c *OpenAICompatibleClient) GenerateMessages(ctx context.Context, messages []PromptMessage) (string, error) {
	var noFiles []string // nil: this entry point never attaches files
	return c.generateWithMessagesInternal(ctx, messages, noFiles, false)
}
// GenerateMessagesWithFiles sends the caller-supplied messages together with
// the given file IDs and returns the text of the model's reply.
// appendFileIDText is forwarded unchanged and selects how the IDs are
// attached to the request.
func (c *OpenAICompatibleClient) GenerateMessagesWithFiles(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
	reply, err := c.generateWithMessagesInternal(ctx, messages, fileIDs, appendFileIDText)
	return reply, err
}
// GenerateWithTools sends a chat completion request that uses native function
// calling and returns the first choice as a structured ChatCompletion.
//
// The file model replaces the default model when at least one non-blank file
// ID is supplied and a file model is configured. Tools are attached only when
// the converted tool list is non-empty. An error is returned when the request
// fails or when the response contains no choices.
func (c *OpenAICompatibleClient) GenerateWithTools(ctx context.Context, messages []PromptMessage, tools []ToolDefinition, fileIDs []string, appendFileIDText bool) (*ChatCompletion, error) {
	ids := nonEmptyIDs(fileIDs)

	model := c.model
	if strings.TrimSpace(c.fileModel) != "" && len(ids) > 0 {
		model = c.fileModel
	}

	sdkMessages := buildSDKMessages(messages, ids, c.normalizedFilePromptMode(), appendFileIDText)
	sdkTools := toSDKTools(tools)

	params := openai.ChatCompletionNewParams{
		Model:    shared.ChatModel(model),
		Messages: sdkMessages,
	}
	if len(sdkTools) > 0 {
		params.Tools = sdkTools
	}

	if c.log != nil {
		c.log.Debugf("llm tool-call request start model=%s messages=%d tools=%d files=%d", model, len(sdkMessages), len(sdkTools), len(ids))
		// The parameter dump is best effort: a marshalling failure only
		// suppresses this debug line.
		if raw, err := json.Marshal(params); err == nil {
			c.log.Debugf("llm tool-call request params: %s", string(raw))
		}
	}

	resp, err := c.client.Chat.Completions.New(ctx, params)
	if err != nil {
		return nil, fmt.Errorf("llm tool-call request failed: %w", err)
	}
	if len(resp.Choices) == 0 {
		return nil, fmt.Errorf("llm returned empty choices")
	}

	// Only the first choice is used.
	first := resp.Choices[0]
	calls := fromSDKToolCalls(first.Message.ToolCalls)
	if c.log != nil {
		c.log.Infof("llm tool-call response success model=%s content_len=%d tool_calls=%d finish=%s",
			model, len(first.Message.Content), len(calls), first.FinishReason)
	}

	result := &ChatCompletion{
		Content:   first.Message.Content,
		ToolCalls: calls,
	}
	return result, nil
}
// generateWithMessagesInternal is the shared implementation behind the
// text-returning Generate* methods.
//
// It drops messages whose role is not recognised, substitutes a single empty
// user message when nothing remains, attaches the de-duplicated file IDs, and
// returns the content of the first choice. The file model replaces the
// default model when file IDs are present and a file model is configured.
// Request failures and empty responses are logged (when a logger is set) and
// returned as errors.
func (c *OpenAICompatibleClient) generateWithMessagesInternal(ctx context.Context, messages []PromptMessage, fileIDs []string, appendFileIDText bool) (string, error) {
	ids := nonEmptyIDs(fileIDs)

	model := c.model
	if strings.TrimSpace(c.fileModel) != "" && len(ids) > 0 {
		model = c.fileModel
	}

	prompt := normalizePromptMessages(messages)
	if len(prompt) == 0 {
		prompt = []PromptMessage{{Role: "user", Content: ""}}
	}

	promptMode := c.normalizedFilePromptMode()
	if c.log != nil {
		systemLen, userLen := promptMessageLengths(prompt)
		c.log.Debugf("llm request start model=%s system_len=%d user_len=%d file_count=%d file_prompt_mode=%s", model, systemLen, userLen, len(ids), promptMode)
	}

	resp, err := c.client.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
		Model:    shared.ChatModel(model),
		Messages: buildSDKMessages(prompt, ids, promptMode, appendFileIDText),
	})
	if err != nil {
		if c.log != nil {
			c.log.Errorf("llm request failed err=%v", err)
		}
		return "", fmt.Errorf("llm request failed: %w", err)
	}
	if len(resp.Choices) == 0 {
		if c.log != nil {
			c.log.Errorf("llm returned empty choices")
		}
		return "", fmt.Errorf("llm returned empty choices")
	}

	// Only the first choice is used.
	reply := resp.Choices[0].Message.Content
	if c.log != nil {
		c.log.Infof("llm response success model=%s output_len=%d", model, len(reply))
	}
	return reply, nil
}
// buildSDKMessages converts a PromptMessage list into openai SDK messages
// and, when file IDs are given, attaches them to the conversation.
//
// Messages with an unrecognised role are skipped. With file IDs present the
// last user message is rewritten (or an empty user message is appended when
// there is none):
//   - appendFileIDText=true (WebUI case): the first file ID is appended as a
//     text part; if that ID is blank the messages are returned unchanged.
//   - appendFileIDText=false: every file ID is attached as a file content part.
//
// mode is accepted but not consulted by this function.
func buildSDKMessages(base []PromptMessage, fileIDs []string, mode string, appendFileIDText bool) []openai.ChatCompletionMessageParamUnion {
	_ = mode

	out := make([]openai.ChatCompletionMessageParamUnion, 0, len(base)+2)
	for i := range base {
		if role := normalizeRole(base[i].Role); role != "" {
			out = append(out, toSDKMessage(base[i], role))
		}
	}
	if len(fileIDs) == 0 {
		return out
	}

	// attach is the strategy applied to the chosen user message.
	var attach func(openai.ChatCompletionMessageParamUnion) openai.ChatCompletionMessageParamUnion
	if appendFileIDText {
		first := strings.TrimSpace(fileIDs[0])
		if first == "" {
			return out
		}
		attach = func(msg openai.ChatCompletionMessageParamUnion) openai.ChatCompletionMessageParamUnion {
			return buildUserMessageWithFileIDText(msg, first)
		}
	} else {
		attach = func(msg openai.ChatCompletionMessageParamUnion) openai.ChatCompletionMessageParamUnion {
			return buildUserMessageWithFiles(msg, fileIDs)
		}
	}

	// Walk backwards so the most recent user message receives the files.
	for i := len(out) - 1; i >= 0; i-- {
		role := out[i].GetRole()
		if role == nil || *role != "user" {
			continue
		}
		out[i] = attach(out[i])
		return out
	}
	return append(out, attach(openai.UserMessage("")))
}
// toSDKMessage converts a single PromptMessage into the openai SDK message
// type selected by role. "system" and "tool" map to their SDK counterparts
// (the tool message carries ToolCallID), "assistant" additionally forwards
// any tool calls, and every other role is sent as a user message.
func toSDKMessage(m PromptMessage, role string) openai.ChatCompletionMessageParamUnion {
	if role == "system" {
		return openai.SystemMessage(m.Content)
	}
	if role == "tool" {
		return openai.ToolMessage(m.Content, m.ToolCallID)
	}
	if role != "assistant" {
		// "user" and anything unrecognised.
		return openai.UserMessage(m.Content)
	}

	msg := openai.AssistantMessage(m.Content)
	if len(m.ToolCalls) == 0 {
		return msg
	}
	calls := make([]openai.ChatCompletionMessageToolCallParam, len(m.ToolCalls))
	for i, tc := range m.ToolCalls {
		calls[i] = openai.ChatCompletionMessageToolCallParam{
			ID: tc.ID,
			Function: openai.ChatCompletionMessageToolCallFunctionParam{
				Name:      tc.Function.Name,
				Arguments: tc.Function.Arguments,
			},
		}
	}
	msg.OfAssistant.ToolCalls = calls
	return msg
}
// buildUserMessageWithFileIDText returns a user message whose content is the
// original text (when it is a non-blank string) followed by a text part
// containing fileID. A blank fileID leaves msg unchanged.
func buildUserMessageWithFileIDText(msg openai.ChatCompletionMessageParamUnion, fileID string) openai.ChatCompletionMessageParamUnion {
	id := strings.TrimSpace(fileID)
	if id == "" {
		return msg
	}

	// Only plain string content is carried over.
	var existing string
	if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
		existing = *s
	}

	idPart := openai.TextContentPart(id)
	if strings.TrimSpace(existing) == "" {
		return openai.UserMessage([]openai.ChatCompletionContentPartUnionParam{idPart})
	}
	return openai.UserMessage([]openai.ChatCompletionContentPartUnionParam{
		openai.TextContentPart(existing),
		idPart,
	})
}
// buildUserMessageWithFiles returns a user message whose content is the
// original text (when it is a non-blank string) followed by one file content
// part per non-blank file ID. When that yields no parts at all, msg is
// returned unchanged.
func buildUserMessageWithFiles(msg openai.ChatCompletionMessageParamUnion, fileIDs []string) openai.ChatCompletionMessageParamUnion {
	// Only plain string content is carried over.
	var existing string
	if s, ok := msg.GetContent().AsAny().(*string); ok && s != nil {
		existing = *s
	}

	parts := make([]openai.ChatCompletionContentPartUnionParam, 0, len(fileIDs)+1)
	if strings.TrimSpace(existing) != "" {
		parts = append(parts, openai.TextContentPart(existing))
	}
	for i := range fileIDs {
		if id := strings.TrimSpace(fileIDs[i]); id != "" {
			parts = append(parts, openai.FileContentPart(openai.ChatCompletionContentPartFileFileParam{FileID: param.NewOpt(id)}))
		}
	}

	if len(parts) > 0 {
		return openai.UserMessage(parts)
	}
	return msg
}
// toSDKTools converts the internal ToolDefinition list into the openai SDK's
// ChatCompletionToolParam list. An empty input yields nil. Only the Function
// part of each definition is used.
func toSDKTools(tools []ToolDefinition) []openai.ChatCompletionToolParam {
	if len(tools) == 0 {
		return nil
	}

	converted := make([]openai.ChatCompletionToolParam, len(tools))
	for i := range tools {
		fn := tools[i].Function

		var schema shared.FunctionParameters
		if len(fn.Parameters) > 0 {
			// Best effort: this function has no error return, so a decode
			// failure is deliberately ignored and the tool is still sent.
			_ = json.Unmarshal(fn.Parameters, &schema)
		}

		converted[i] = openai.ChatCompletionToolParam{
			Function: shared.FunctionDefinitionParam{
				Name:        fn.Name,
				Description: param.NewOpt(fn.Description),
				Parameters:  schema,
			},
		}
	}
	return converted
}
// fromSDKToolCalls converts the tool calls of an openai SDK response into the
// internal ToolCall type. An empty input yields nil. Every converted call is
// given the type "function".
func fromSDKToolCalls(sdkCalls []openai.ChatCompletionMessageToolCall) []ToolCall {
	if len(sdkCalls) == 0 {
		return nil
	}

	calls := make([]ToolCall, len(sdkCalls))
	for i := range sdkCalls {
		src := sdkCalls[i]
		calls[i] = ToolCall{
			ID:   src.ID,
			Type: "function",
			Function: ToolCallFunction{
				Name:      src.Function.Name,
				Arguments: src.Function.Arguments,
			},
		}
	}
	return calls
}
// UploadFile uploads file and returns the ID of the uploaded file.
//
// The file must have a non-blank name and non-empty content. The upload is
// attempted with each purpose in turn until one succeeds: the caller's
// purpose first (when non-blank), then "file-extract", then "batch", with
// case-insensitive duplicates skipped. Each failed attempt is logged as a
// warning when a logger is set. The error of the last attempt is returned
// when none succeeds.
//
// Once ctx is cancelled or past its deadline no further purposes are tried,
// since they would only fail for the same reason.
func (c *OpenAICompatibleClient) UploadFile(ctx context.Context, file InputFile, purpose string) (string, error) {
	if strings.TrimSpace(file.FileName) == "" {
		return "", fmt.Errorf("empty file name")
	}
	if len(file.Content) == 0 {
		return "", fmt.Errorf("empty file content")
	}

	// appendIfMissing trims its argument and ignores blank values, so a blank
	// caller purpose simply contributes nothing.
	purposes := make([]string, 0, 3)
	purposes = appendIfMissing(purposes, purpose)
	purposes = appendIfMissing(purposes, "file-extract")
	purposes = appendIfMissing(purposes, "batch")

	var lastErr error
	for _, p := range purposes {
		fileID, err := c.uploadFileOnce(ctx, file, p)
		if err == nil {
			return fileID, nil
		}
		lastErr = err
		if c.log != nil {
			c.log.Warnf("llm file upload failed purpose=%s err=%v", p, err)
		}
		if ctx != nil && ctx.Err() != nil {
			// The caller gave up; do not burn the remaining fallbacks.
			break
		}
	}
	if lastErr == nil {
		// Defensive: the purpose list always contains at least one entry.
		lastErr = fmt.Errorf("llm file upload failed: no purpose tried")
	}
	return "", lastErr
}
// uploadFileOnce performs a single upload of file with the given purpose and
// returns the trimmed ID reported by the API.
//
// The content is wrapped with the SDK's File helper so that the file name and
// MIME type are sent with the upload; a bare bytes.Reader carries neither.
// A blank MIME type falls back to "application/octet-stream" and a blank
// name to "anonymous_file", so both fields are always non-empty.
// An error is returned when the request fails or the API returns an empty ID.
func (c *OpenAICompatibleClient) uploadFileOnce(ctx context.Context, file InputFile, purpose string) (string, error) {
	name := strings.TrimSpace(file.FileName)
	if name == "" {
		name = "anonymous_file"
	}
	contentType := strings.TrimSpace(file.MimeType)
	if contentType == "" {
		contentType = "application/octet-stream"
	}

	resp, err := c.client.Files.New(ctx, openai.FileNewParams{
		File:    openai.File(bytes.NewReader(file.Content), name, contentType),
		Purpose: openai.FilePurpose(purpose),
	})
	if err != nil {
		return "", fmt.Errorf("llm file upload failed: %w", err)
	}

	fileID := strings.TrimSpace(resp.ID)
	if fileID == "" {
		return "", fmt.Errorf("llm file upload returned empty file id")
	}
	if c.log != nil {
		c.log.Infof("llm file uploaded name=%s size=%d file_id=%s purpose=%s", file.FileName, len(file.Content), fileID, purpose)
	}
	return fileID, nil
}
// appendIfMissing appends the trimmed value to items unless it is blank or
// items already holds an entry equal to it, compared case-insensitively after
// trimming. When nothing is appended, items is returned as is.
func appendIfMissing(items []string, value string) []string {
	candidate := strings.TrimSpace(value)
	if candidate == "" {
		return items
	}
	for i := 0; i < len(items); i++ {
		if strings.EqualFold(candidate, strings.TrimSpace(items[i])) {
			return items
		}
	}
	return append(items, candidate)
}
// nonEmptyIDs returns the trimmed, non-blank entries of ids with duplicates
// removed, keeping first-occurrence order. An empty input yields nil; a
// non-empty input made only of blank entries yields an empty, non-nil slice.
func nonEmptyIDs(ids []string) []string {
	if len(ids) == 0 {
		return nil
	}

	seen := make(map[string]struct{})
	unique := make([]string, 0, len(ids))
	for i := range ids {
		trimmed := strings.TrimSpace(ids[i])
		if trimmed == "" {
			continue
		}
		if _, dup := seen[trimmed]; !dup {
			seen[trimmed] = struct{}{}
			unique = append(unique, trimmed)
		}
	}
	return unique
}
// normalizePromptMessages returns a copy of messages in which every role is
// replaced by its normalised form and messages with an unrecognised role are
// dropped. All other fields are carried over unchanged. The result is never
// nil.
func normalizePromptMessages(messages []PromptMessage) []PromptMessage {
	kept := make([]PromptMessage, 0, len(messages))
	for i := range messages {
		src := messages[i]
		role := normalizeRole(src.Role)
		if role == "" {
			continue
		}
		kept = append(kept, PromptMessage{
			Role:       role,
			Content:    src.Content,
			ToolCalls:  src.ToolCalls,
			ToolCallID: src.ToolCallID,
			Name:       src.Name,
		})
	}
	return kept
}
// normalizeRole lower-cases and trims role and returns it when it is one of
// "system", "user", "assistant" or "tool"; any other value yields "".
func normalizeRole(role string) string {
	switch r := strings.ToLower(strings.TrimSpace(role)); r {
	case "system", "user", "assistant", "tool":
		return r
	default:
		return ""
	}
}
// promptMessageLengths returns the total byte length of the content of all
// system messages and of all user messages, in that order. Messages with any
// other role are not counted.
func promptMessageLengths(messages []PromptMessage) (int, int) {
	var systemTotal, userTotal int
	for i := range messages {
		role := normalizeRole(messages[i].Role)
		if role == "system" {
			systemTotal += len(messages[i].Content)
		} else if role == "user" {
			userTotal += len(messages[i].Content)
		}
	}
	return systemTotal, userTotal
}
// normalizedFilePromptMode maps the configured file prompt mode onto one of
// two canonical values: "system_fileid_uri" for the system_fileid aliases
// (case-insensitive, surrounding whitespace ignored) and
// "user_content_file_parts" for everything else.
func (c *OpenAICompatibleClient) normalizedFilePromptMode() string {
	switch strings.ToLower(strings.TrimSpace(c.filePromptMode)) {
	case "system_fileid", "system_fileid_url", "system_fileid_uri":
		return "system_fileid_uri"
	default:
		return "user_content_file_parts"
	}
}