From 49f62976312d344a45be60116b39f6870062ba01 Mon Sep 17 00:00:00 2001 From: "Ding, Shuo" Date: Tue, 10 Mar 2026 10:23:53 +0800 Subject: [PATCH] feat: add webui http channel for chat and file upload --- cmd/bot/main.go | 16 ++ configs/env.sample | 2 + internal/agent/orchestrator.go | 18 ++ internal/config/config.go | 23 +- internal/transport/webui/bot.go | 313 +++++++++++++++++++++++++++ internal/transport/webui/bot_test.go | 157 ++++++++++++++ 6 files changed, 527 insertions(+), 2 deletions(-) create mode 100644 internal/transport/webui/bot.go create mode 100644 internal/transport/webui/bot_test.go diff --git a/cmd/bot/main.go b/cmd/bot/main.go index a905a4c..18149a1 100644 --- a/cmd/bot/main.go +++ b/cmd/bot/main.go @@ -19,6 +19,7 @@ import ( "laodingbot/internal/tools" "laodingbot/internal/transport/feishu" "laodingbot/internal/transport/telegram" + "laodingbot/internal/transport/webui" ) // main 是程序的入口点。它负责初始化环境、加载配置、注册工具并启动消息通道。 @@ -197,6 +198,21 @@ func runMessageChannel(ctx context.Context, cfg config.Config, engine *agent.Orc } return engine.HandleMessage(ctx, msg.ChatID, msg.UserID, msg.Text) }) + case "webui": + wb, err := webui.NewBot(cfg.WebUI, lg.WithComponent("transport.webui")) + if err != nil { + return fmt.Errorf("init webui bot failed: %w", err) + } + lg.Infof("starting webui transport listen_addr=%s", cfg.WebUI.ListenAddr) + return wb.Run( + ctx, + func(ctx context.Context, msg webui.IncomingMessage) (string, error) { + return engine.HandleMessage(ctx, msg.ChatID, msg.UserID, msg.Text) + }, + func(ctx context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) { + return engine.UploadAndCacheFiles(ctx, chatID, userID, files) + }, + ) default: return fmt.Errorf("unsupported message channel: %s", cfg.MessageChannel) } diff --git a/configs/env.sample b/configs/env.sample index f962301..e8f24ce 100644 --- a/configs/env.sample +++ b/configs/env.sample @@ -17,6 +17,8 @@ TELEGRAM_POLL_TIMEOUT_SECONDS=30 FEISHU_APP_ID= FEISHU_APP_SECRET= FEISHU_VERIFY_TOKEN= +WEBUI_LISTEN_ADDR=:8090 +WEBUI_MAX_UPLOAD_MB=20 LLM_BASE_URL=https://api.openai.com/v1 LLM_API_KEY= diff --git a/internal/agent/orchestrator.go b/internal/agent/orchestrator.go index 074a31e..0277043 100644 --- a/internal/agent/orchestrator.go +++ b/internal/agent/orchestrator.go @@ -117,6 +117,24 @@ func (o *Orchestrator) HandleMessageWithFiles(ctx context.Context, chatID, userI return o.handleMessageInternal(ctx, chatID, userID, text, files) } +// UploadAndCacheFiles 上传文件到 LLM 并缓存 file_id,供后续同会话文本问答复用。 +// 该方法不会写入 messages 表,仅更新内存中的 pending file 上下文。 +func (o *Orchestrator) UploadAndCacheFiles(ctx context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) { + if len(files) == 0 { + return nil, fmt.Errorf("no files provided") + } + uploadCtx := o.prepareFilePromptContext(ctx, files, nil) + if strings.TrimSpace(uploadCtx.FatalReason) != "" { + return nil, fmt.Errorf(uploadCtx.FatalReason) + } + ids := nonEmptyIDs(uploadCtx.FileIDs) + if len(ids) == 0 { + return nil, fmt.Errorf("file upload completed but no valid file_id returned") + } + o.appendPendingFiles(chatID, userID, uploadCtx.toPendingRefs()) + return ids, nil +} + func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) { // 为链路追踪设置唯一的 TraceID traceID := logger.NewTraceID() diff --git a/internal/config/config.go b/internal/config/config.go index d5b7a7d..25b1768 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,6 +25,7 @@ type Config struct { Telegram TelegramConfig Feishu FeishuConfig + WebUI WebUIConfig LLM LLMConfig Security SecurityConfig WebSearch WebSearchConfig @@ -45,6 +46,11 @@ type FeishuConfig struct { EventPath string } +type WebUIConfig struct { + ListenAddr string + MaxUploadBytes int64 +} + type LLMConfig struct { BaseURL string APIKey string @@ -95,6 +101,10 @@ func Load() (Config, error) { ListenAddr: defaultIfEmpty(os.Getenv("FEISHU_LISTEN_ADDR"), ":8080"), EventPath: defaultIfEmpty(os.Getenv("FEISHU_EVENT_PATH"), "/feishu/events"), }, + WebUI: WebUIConfig{ + ListenAddr: defaultIfEmpty(os.Getenv("WEBUI_LISTEN_ADDR"), ":8090"), + MaxUploadBytes: int64(intFromEnv("WEBUI_MAX_UPLOAD_MB", 20)) * 1024 * 1024, + }, LLM: LLMConfig{ BaseURL: strings.TrimRight(defaultIfEmpty(os.Getenv("LLM_BASE_URL"), "https://api.openai.com/v1"), "/"), APIKey: strings.TrimSpace(os.Getenv("LLM_API_KEY")), @@ -116,8 +126,8 @@ func Load() (Config, error) { cfg.MessageChannel = strings.ToLower(strings.TrimSpace(cfg.MessageChannel)) cfg.LogLevel = strings.ToLower(strings.TrimSpace(cfg.LogLevel)) - if cfg.MessageChannel != "telegram" && cfg.MessageChannel != "feishu" { - return Config{}, fmt.Errorf("MESSAGE_CHANNEL must be telegram or feishu") + if cfg.MessageChannel != "telegram" && cfg.MessageChannel != "feishu" && cfg.MessageChannel != "webui" { + return Config{}, fmt.Errorf("MESSAGE_CHANNEL must be telegram, feishu, or webui") } if cfg.LogLevel != "debug" && cfg.LogLevel != "info" && cfg.LogLevel != "warn" && cfg.LogLevel != "error" { return Config{}, fmt.Errorf("LOG_LEVEL must be debug, info, warn, or error") @@ -137,6 +147,9 @@ func Load() (Config, error) { if cfg.GapClusterLookbackHours < 1 || cfg.GapClusterLookbackHours > 24*365 { return Config{}, fmt.Errorf("GAP_CLUSTER_LOOKBACK_HOURS must be between 1 and 8760") } + if cfg.WebUI.MaxUploadBytes < 1024 || cfg.WebUI.MaxUploadBytes > 200*1024*1024 { + return Config{}, fmt.Errorf("WEBUI_MAX_UPLOAD_MB must be between 1 and 200") + } if cfg.MessageChannel == "telegram" { if cfg.Telegram.Token == "" { @@ -156,6 +169,12 @@ func Load() (Config, error) { } } + if cfg.MessageChannel == "webui" { + if strings.TrimSpace(cfg.WebUI.ListenAddr) == "" { + return Config{}, fmt.Errorf("WEBUI_LISTEN_ADDR is required when MESSAGE_CHANNEL=webui") + } + } + if cfg.LLM.APIKey == "" { return Config{}, fmt.Errorf("LLM_API_KEY is required") } diff --git a/internal/transport/webui/bot.go b/internal/transport/webui/bot.go new file mode 100644 index 0000000..6bfe87d --- /dev/null +++ b/internal/transport/webui/bot.go @@ -0,0 +1,313 @@ +package webui + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "laodingbot/internal/config" + "laodingbot/internal/llm" + "laodingbot/internal/logger" +) + +type IncomingMessage struct { + ChatID string + UserID string + Text string +} + +type ChatHandler func(context.Context, IncomingMessage) (string, error) +type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error) + +type Bot struct { + listenAddr string + maxUploadBytes int64 + log *logger.Logger + + chatHandler ChatHandler + uploadHandler UploadHandler + counter uint64 +} + +type chatRequest struct { + Text string `json:"text"` + SessionID string `json:"session_id"` + UserID string `json:"user_id"` +} + +type chatResponse struct { + Reply string `json:"reply"` + SessionID string `json:"session_id"` + UserID string `json:"user_id"` +} + +type uploadResponse struct { + FileID string `json:"file_id"` + FileIDs []string `json:"file_ids"` + FileName string `json:"file_name"` + MimeType string `json:"mime_type"` + SizeBytes int `json:"size_bytes"` + SessionID string `json:"session_id"` + UserID string `json:"user_id"` +} + +type errorResponse struct { + Error string `json:"error"` +} + +func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) { + if strings.TrimSpace(cfg.ListenAddr) == "" { + return nil, fmt.Errorf("empty webui listen address") + } + if cfg.MaxUploadBytes <= 0 { + return nil, fmt.Errorf("invalid webui max upload bytes") + } + return &Bot{ + listenAddr: strings.TrimSpace(cfg.ListenAddr), + maxUploadBytes: cfg.MaxUploadBytes, + log: log, + }, nil +} + +func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error { + if chatHandler == nil { + return fmt.Errorf("nil webui chat handler") + } + if uploadHandler == nil { + return fmt.Errorf("nil webui upload handler") + } + b.chatHandler = chatHandler + b.uploadHandler = uploadHandler + + mux := http.NewServeMux() + mux.HandleFunc("/api/chat", b.handleChat) + mux.HandleFunc("/api/upload", b.handleUpload) + + srv := &http.Server{ + Addr: b.listenAddr, + Handler: mux, + } + + errCh := make(chan error, 1) + go func() { + err := srv.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + errCh <- err + return + } + errCh <- nil + }() + + if b.log != nil { + b.log.Infof("webui http transport started addr=%s", b.listenAddr) + } + + select { + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(shutdownCtx) + err := <-errCh + if b.log != nil { + b.log.Infof("webui http transport stopped: %v", ctx.Err()) + } + if err != nil { + return err + } + return ctx.Err() + case err := <-errCh: + if err != nil && b.log != nil { + b.log.Errorf("webui http transport failed err=%v", err) + } + return err + } +} + +func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) + return + } + if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"}) + return + } + if b.chatHandler == nil { + writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"}) + return + } + + var req chatRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"}) + return + } + req.Text = strings.TrimSpace(req.Text) + if req.Text == "" { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"}) + return + } + sessionID := b.resolveID(req.SessionID, "sess") + userID := b.resolveID(req.UserID, "user") + + reply, err := b.chatHandler(r.Context(), IncomingMessage{ + ChatID: sessionID, + UserID: userID, + Text: req.Text, + }) + if err != nil { + if b.log != nil { + b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err) + } + writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"}) + return + } + writeJSON(w, http.StatusOK, chatResponse{ + Reply: reply, + SessionID: sessionID, + UserID: userID, + }) +} + +func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) + return + } + if b.uploadHandler == nil { + writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"}) + return + } + + r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes) + if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil { + if strings.Contains(strings.ToLower(err.Error()), "request body too large") { + writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"}) + return + } + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"}) + return + } + + sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess") + userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user") + + file, header, err := r.FormFile("file") + if err != nil { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"}) + return + } + defer file.Close() + + fileName := sanitizeFileName(header.Filename) + if fileName == "" { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"}) + return + } + + content, err := io.ReadAll(file) + if err != nil { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"}) + return + } + if len(content) == 0 { + writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"}) + return + } + + mimeType := strings.TrimSpace(header.Header.Get("Content-Type")) + if mimeType == "" { + mimeType = detectMimeByName(fileName) + } + + ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{ + FileName: fileName, + MimeType: mimeType, + Content: content, + }}) + if err != nil { + if b.log != nil { + b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err) + } + writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"}) + return + } + if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" { + writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"}) + return + } + + writeJSON(w, http.StatusOK, uploadResponse{ + FileID: strings.TrimSpace(ids[0]), + FileIDs: ids, + FileName: fileName, + MimeType: mimeType, + SizeBytes: len(content), + SessionID: sessionID, + UserID: userID, + }) +} + +func (b *Bot) resolveID(raw, prefix string) string { + raw = strings.TrimSpace(raw) + if raw != "" { + return raw + } + n := atomic.AddUint64(&b.counter, 1) + return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n) +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(payload) +} + +func minInt64(a, b int64) int64 { + if a < b { + return a + } + return b +} + +func detectMimeByName(fileName string) string { + ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName))) + if ext == "" { + return "application/octet-stream" + } + m := strings.TrimSpace(mime.TypeByExtension(ext)) + if m == "" { + return "application/octet-stream" + } + return m +} + +func sanitizeFileName(fileName string) string { + name := strings.TrimSpace(filepath.Base(fileName)) + if name == "" || name == "." || name == ".." { + return "" + } + var b strings.Builder + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' { + b.WriteRune(r) + continue + } + b.WriteByte('_') + } + out := strings.TrimSpace(b.String()) + if out == "" || out == "." || out == ".." { + return "" + } + if strings.HasPrefix(out, ".") { + out = "file" + out + } + return out +} diff --git a/internal/transport/webui/bot_test.go b/internal/transport/webui/bot_test.go new file mode 100644 index 0000000..bd01469 --- /dev/null +++ b/internal/transport/webui/bot_test.go @@ -0,0 +1,157 @@ +package webui + +import ( + "bytes" + "context" + "encoding/json" + "mime/multipart" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "laodingbot/internal/config" + "laodingbot/internal/llm" +) + +func newTestBot(t *testing.T, maxUploadBytes int64) *Bot { + t.Helper() + b, err := NewBot(config.WebUIConfig{ListenAddr: ":8090", MaxUploadBytes: maxUploadBytes}, nil) + if err != nil { + t.Fatalf("NewBot failed: %v", err) + } + return b +} + +func TestHandleChatSuccess(t *testing.T) { + b := newTestBot(t, 1024*1024) + b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) { + if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" { + t.Fatalf("unexpected message: %+v", msg) + } + return "ok", nil + } + + body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1"}`) + req := httptest.NewRequest(http.MethodPost, "/api/chat", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + b.handleChat(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) + } + + var out chatResponse + if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + if out.Reply != "ok" || out.SessionID != "s1" || out.UserID != "u1" { + t.Fatalf("unexpected response: %+v", out) + } +} + +func TestHandleChatMissingText(t *testing.T) { + b := newTestBot(t, 1024*1024) + b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil } + + body := strings.NewReader(`{"text":" "}`) + req := httptest.NewRequest(http.MethodPost, "/api/chat", body) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + b.handleChat(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } +} + +func TestHandleUploadSuccess(t *testing.T) { + b := newTestBot(t, 1024*1024) + b.uploadHandler = func(_ context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) { + if chatID != "s1" || userID != "u1" { + t.Fatalf("unexpected ids chat=%s user=%s", chatID, userID) + } + if len(files) != 1 { + t.Fatalf("unexpected files len=%d", len(files)) + } + if files[0].FileName != "doc.pdf" || len(files[0].Content) == 0 { + t.Fatalf("unexpected file payload: %+v", files[0]) + } + return []string{"file_123"}, nil + } + + var payload bytes.Buffer + writer := multipart.NewWriter(&payload) + _ = writer.WriteField("session_id", "s1") + _ = writer.WriteField("user_id", "u1") + fw, err := writer.CreateFormFile("file", "doc.pdf") + if err != nil { + t.Fatalf("CreateFormFile failed: %v", err) + } + _, _ = fw.Write([]byte("pdf-content")) + _ = writer.Close() + + req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload) + req.Header.Set("Content-Type", writer.FormDataContentType()) + w := httptest.NewRecorder() + + b.handleUpload(w, req) + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String()) + } + + var out uploadResponse + if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil { + t.Fatalf("decode response failed: %v", err) + } + if out.FileID != "file_123" || out.SessionID != "s1" || out.UserID != "u1" { + t.Fatalf("unexpected response: %+v", out) + } +} + +func TestHandleUploadTooLarge(t *testing.T) { + b := newTestBot(t, 3) + b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) { + return []string{"file_should_not_reach"}, nil + } + + var payload bytes.Buffer + writer := multipart.NewWriter(&payload) + fw, err := writer.CreateFormFile("file", "a.txt") + if err != nil { + t.Fatalf("CreateFormFile failed: %v", err) + } + _, _ = fw.Write([]byte("12345")) + _ = writer.Close() + + req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload) + req.Header.Set("Content-Type", writer.FormDataContentType()) + w := httptest.NewRecorder() + + b.handleUpload(w, req) + if w.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("expected 413, got %d body=%s", w.Code, w.Body.String()) + } +} + +func TestHandleUploadMissingFile(t *testing.T) { + b := newTestBot(t, 1024*1024) + b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) { + return []string{"file_should_not_reach"}, nil + } + + var payload bytes.Buffer + writer := multipart.NewWriter(&payload) + _ = writer.WriteField("session_id", "s1") + _ = writer.Close() + + req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload) + req.Header.Set("Content-Type", writer.FormDataContentType()) + w := httptest.NewRecorder() + + b.handleUpload(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", w.Code) + } +}