209 lines
5.5 KiB
Go
209 lines
5.5 KiB
Go
|
|
package filedoc
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"laodingbot/internal/logger"
|
||
|
|
|
||
|
|
openai "github.com/openai/openai-go"
|
||
|
|
"github.com/openai/openai-go/option"
|
||
|
|
"github.com/openai/openai-go/shared"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Config carries the connection settings used by New to build the OpenAI
// client behind a Tool.
type Config struct {
	// APIKey is the credential passed to the client; New trims surrounding
	// whitespace before using it.
	APIKey string
	// BaseURL optionally overrides the client's endpoint. A blank value
	// leaves the client's own default in place.
	BaseURL string
	// Model is the chat model name. New falls back to "gpt-4o-mini" when it
	// is blank.
	Model string
	// Timeout is the per-request timeout. New falls back to 60 seconds when
	// it is zero or negative.
	Timeout time.Duration
}
|
||
|
|
|
||
|
|
type Tool struct {
|
||
|
|
client openai.Client
|
||
|
|
model string
|
||
|
|
maxOutputChars int
|
||
|
|
log *logger.Logger
|
||
|
|
}
|
||
|
|
|
||
|
|
func New(cfg Config, maxOutputChars int, log *logger.Logger) *Tool {
|
||
|
|
if strings.TrimSpace(cfg.Model) == "" {
|
||
|
|
cfg.Model = "gpt-4o-mini"
|
||
|
|
}
|
||
|
|
if cfg.Timeout <= 0 {
|
||
|
|
cfg.Timeout = 60 * time.Second
|
||
|
|
}
|
||
|
|
if maxOutputChars <= 0 {
|
||
|
|
maxOutputChars = 12000
|
||
|
|
}
|
||
|
|
|
||
|
|
opts := []option.RequestOption{
|
||
|
|
option.WithAPIKey(strings.TrimSpace(cfg.APIKey)),
|
||
|
|
option.WithRequestTimeout(cfg.Timeout),
|
||
|
|
}
|
||
|
|
if strings.TrimSpace(cfg.BaseURL) != "" {
|
||
|
|
opts = append(opts, option.WithBaseURL(strings.TrimSpace(cfg.BaseURL)))
|
||
|
|
}
|
||
|
|
|
||
|
|
return &Tool{
|
||
|
|
client: openai.NewClient(opts...),
|
||
|
|
model: cfg.Model,
|
||
|
|
maxOutputChars: maxOutputChars,
|
||
|
|
log: log,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Name returns the fixed identifier of this tool.
func (t *Tool) Name() string {
	const toolName = "extract_file_document"
	return toolName
}
|
||
|
|
|
||
|
|
func (t *Tool) Description() string {
|
||
|
|
return "Extract full document details from a file ID via OpenAI. Input: file_id (supports plain ID, fileid://ID, or JSON {\"file_id\":\"...\"})."
|
||
|
|
}
|
||
|
|
|
||
|
|
func (t *Tool) Call(ctx context.Context, input string) (string, error) {
|
||
|
|
fileID, userFocus, err := parseInput(input)
|
||
|
|
if err != nil {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
|
||
|
|
prompt := buildExtractionPrompt(fileID, userFocus)
|
||
|
|
messages := []openai.ChatCompletionMessageParamUnion{
|
||
|
|
openai.SystemMessage("fileid://" + fileID),
|
||
|
|
openai.UserMessage([]openai.ChatCompletionContentPartUnionParam{
|
||
|
|
openai.TextContentPart(prompt),
|
||
|
|
}),
|
||
|
|
}
|
||
|
|
|
||
|
|
params := openai.ChatCompletionNewParams{
|
||
|
|
Model: shared.ChatModel(t.model),
|
||
|
|
Messages: messages,
|
||
|
|
}
|
||
|
|
|
||
|
|
if t.log != nil {
|
||
|
|
t.log.Infof("filedoc tool request model=%s file_id=%s", t.model, fileID)
|
||
|
|
}
|
||
|
|
|
||
|
|
resp, err := t.client.Chat.Completions.New(ctx, params)
|
||
|
|
if err != nil {
|
||
|
|
return "", fmt.Errorf("filedoc request failed: %w", err)
|
||
|
|
}
|
||
|
|
if len(resp.Choices) == 0 {
|
||
|
|
return "", fmt.Errorf("filedoc returned empty choices")
|
||
|
|
}
|
||
|
|
|
||
|
|
out := strings.TrimSpace(resp.Choices[0].Message.Content)
|
||
|
|
if out == "" {
|
||
|
|
out = "未提取到可读的文档内容。请确认 file_id 是否有效以及模型是否支持文件解析。"
|
||
|
|
}
|
||
|
|
if len(out) > t.maxOutputChars {
|
||
|
|
out = out[:t.maxOutputChars]
|
||
|
|
}
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// buildExtractionPrompt assembles the user prompt sent to the model.
//
// The prompt is newline-separated: an instruction line, a "file_id: <id>"
// line, a blank line, five numbered output requirements, another blank line
// and finally the focus section. userFocus is whitespace-trimmed; when it is
// empty a built-in default focus is used instead.
func buildExtractionPrompt(fileID, userFocus string) string {
	const defaultFocus = "请输出完整文档信息,包括标题、主题、核心观点、结构大纲、关键术语、重要结论、风险点与后续建议。"

	extra := strings.TrimSpace(userFocus)
	if extra == "" {
		extra = defaultFocus
	}

	requirements := [...]string{
		"1) 文档基本信息:标题、文档类型、语言、可能作者/组织(若可判断)、时间线索(若可判断)。",
		"2) 结构化摘要:按章节或逻辑段落给出要点,尽量保持原文顺序。",
		"3) 关键数据与事实:列出关键数字、术语、专有名词、约束条件。",
		"4) 风险与不确定性:明确哪些信息来源于文档,哪些是无法确认。",
		"5) 面向执行的建议:给出可落地的后续行动项。",
	}

	var b strings.Builder
	b.WriteString("请基于所附文件输出完整文档信息。\n")
	b.WriteString("file_id: " + fileID + "\n")
	b.WriteString("\n")
	b.WriteString("输出要求:\n")
	for _, req := range requirements {
		b.WriteString(req)
		b.WriteString("\n")
	}
	b.WriteString("\n")
	b.WriteString("补充关注点:\n")
	// The focus text is the last element and carries no trailing newline.
	b.WriteString(extra)
	return b.String()
}
|
||
|
|
|
||
|
|
func parseInput(input string) (fileID string, userFocus string, err error) {
|
||
|
|
raw := strings.TrimSpace(input)
|
||
|
|
if raw == "" {
|
||
|
|
return "", "", fmt.Errorf("empty input: expected file_id")
|
||
|
|
}
|
||
|
|
|
||
|
|
if strings.HasPrefix(raw, "{") {
|
||
|
|
var payload map[string]any
|
||
|
|
if jsonErr := json.Unmarshal([]byte(raw), &payload); jsonErr == nil {
|
||
|
|
if id := firstNonEmptyString(payload, "file_id", "fileid", "id", "fileID"); id != "" {
|
||
|
|
return normalizeFileID(id), firstNonEmptyString(payload, "focus", "query", "instruction", "prompt"), nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
lines := strings.Split(raw, "\n")
|
||
|
|
for _, line := range lines {
|
||
|
|
candidate := extractFileIDToken(line)
|
||
|
|
if candidate != "" {
|
||
|
|
focus := strings.TrimSpace(strings.ReplaceAll(raw, line, ""))
|
||
|
|
return normalizeFileID(candidate), focus, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
candidate := extractFileIDToken(raw)
|
||
|
|
if candidate == "" {
|
||
|
|
return "", "", fmt.Errorf("no file_id found in input")
|
||
|
|
}
|
||
|
|
return normalizeFileID(candidate), "", nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// extractFileIDToken returns the first file ID found in s, or "" when there
// is none.
//
// s is split on spaces, tabs, newlines, carriage returns, commas, semicolons
// and pipes; each token is stripped of surrounding quotes and brackets.
// Tokens are examined in order, case-insensitively:
//
//   - "fileid://<id>" yields <id>;
//   - "file_id=<id>", "file_id:<id>", "fileid=<id>", "fileid:<id>" yield
//     <id> with surrounding quotes removed;
//   - a bare key such as "file_id", "file_id:" or "file_id=" is a label, not
//     an ID, and is skipped so that the value in a following token is found;
//   - any other token starting with "file_" or "file-" is returned as is.
func extractFileIDToken(s string) string {
	fields := strings.FieldsFunc(s, func(r rune) bool {
		switch r {
		case ' ', '\t', '\n', '\r', ',', ';', '|':
			return true
		}
		return false
	})

	for _, f := range fields {
		tok := strings.TrimSpace(strings.Trim(f, "\"'()[]{}"))
		if tok == "" {
			continue
		}

		// Must be tested before the key forms: "fileid://x" would
		// otherwise read as key "fileid" followed by a ':' separator.
		if id, ok := cutPrefixFold(tok, "fileid://"); ok {
			if id != "" {
				return id
			}
			continue
		}

		if id, isKey := fileIDKeyValue(tok); isKey {
			if id != "" {
				return id
			}
			// Label without a value; the ID may be the next token.
			continue
		}

		if _, ok := cutPrefixFold(tok, "file_"); ok {
			return tok
		}
		if _, ok := cutPrefixFold(tok, "file-"); ok {
			return tok
		}
	}

	return ""
}

// fileIDKeyValue reports whether tok is a "file_id"/"fileid" label, either
// bare or followed by '=' or ':' (quotes between the key and the separator
// are ignored), and returns the unquoted value attached to it, if any.
func fileIDKeyValue(tok string) (value string, isKey bool) {
	for _, key := range []string{"file_id", "fileid"} {
		rest, ok := cutPrefixFold(tok, key)
		if !ok {
			continue
		}
		rest = strings.TrimLeft(rest, "\"'")
		if rest == "" {
			return "", true
		}
		if rest[0] != '=' && rest[0] != ':' {
			// Something like "file_idx123": an ID, not a label.
			return "", false
		}
		return strings.Trim(rest[1:], "\"'"), true
	}
	return "", false
}

// cutPrefixFold removes prefix from the start of s, ignoring case, and
// reports whether it was present. Comparing the leading bytes of s itself
// keeps the cut position correct even when case conversion would change the
// byte length of s.
func cutPrefixFold(s, prefix string) (string, bool) {
	if len(s) < len(prefix) || !strings.EqualFold(s[:len(prefix)], prefix) {
		return s, false
	}
	return s[len(prefix):], true
}
|
||
|
|
|
||
|
|
// normalizeFileID reduces id to the bare file identifier: surrounding
// whitespace and quote characters are removed, and a leading "fileid://"
// scheme (matched case-insensitively) is stripped.
func normalizeFileID(id string) string {
	const scheme = "fileid://"

	// Trim whitespace first so that quotes sitting just inside it are
	// reachable, then trim again for whitespace that was inside the quotes.
	id = strings.TrimSpace(id)
	id = strings.TrimSpace(strings.Trim(id, "\"'"))

	// Compare against the leading bytes of id itself rather than a
	// lower-cased copy, whose byte length may differ from id's.
	if len(id) >= len(scheme) && strings.EqualFold(id[:len(scheme)], scheme) {
		return strings.TrimSpace(id[len(scheme):])
	}
	return id
}
|
||
|
|
|
||
|
|
// firstNonEmptyString looks up keys in m in the given order and returns the
// first value that is a string with non-whitespace content, trimmed of
// surrounding whitespace. Missing keys and non-string values are skipped.
// It returns "" when no key qualifies.
func firstNonEmptyString(m map[string]any, keys ...string) string {
	for _, key := range keys {
		// A missing key yields a nil value, which fails the assertion just
		// like any other non-string value.
		text, isString := m[key].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(text); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|