feat: add webui http channel for chat and file upload

This commit is contained in:
2026-03-10 10:23:53 +08:00
parent bd41f48971
commit 49f6297631
6 changed files with 527 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ import (
"laodingbot/internal/tools" "laodingbot/internal/tools"
"laodingbot/internal/transport/feishu" "laodingbot/internal/transport/feishu"
"laodingbot/internal/transport/telegram" "laodingbot/internal/transport/telegram"
"laodingbot/internal/transport/webui"
) )
// main 是程序的入口点。它负责初始化环境、加载配置、注册工具并启动消息通道。 // main 是程序的入口点。它负责初始化环境、加载配置、注册工具并启动消息通道。
@@ -197,6 +198,21 @@ func runMessageChannel(ctx context.Context, cfg config.Config, engine *agent.Orc
} }
return engine.HandleMessage(ctx, msg.ChatID, msg.UserID, msg.Text) return engine.HandleMessage(ctx, msg.ChatID, msg.UserID, msg.Text)
}) })
case "webui":
wb, err := webui.NewBot(cfg.WebUI, lg.WithComponent("transport.webui"))
if err != nil {
return fmt.Errorf("init webui bot failed: %w", err)
}
lg.Infof("starting webui transport listen_addr=%s", cfg.WebUI.ListenAddr)
return wb.Run(
ctx,
func(ctx context.Context, msg webui.IncomingMessage) (string, error) {
return engine.HandleMessage(ctx, msg.ChatID, msg.UserID, msg.Text)
},
func(ctx context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
return engine.UploadAndCacheFiles(ctx, chatID, userID, files)
},
)
default: default:
return fmt.Errorf("unsupported message channel: %s", cfg.MessageChannel) return fmt.Errorf("unsupported message channel: %s", cfg.MessageChannel)
} }

View File

@@ -17,6 +17,8 @@ TELEGRAM_POLL_TIMEOUT_SECONDS=30
FEISHU_APP_ID= FEISHU_APP_ID=
FEISHU_APP_SECRET= FEISHU_APP_SECRET=
FEISHU_VERIFY_TOKEN= FEISHU_VERIFY_TOKEN=
WEBUI_LISTEN_ADDR=:8090
WEBUI_MAX_UPLOAD_MB=20
LLM_BASE_URL=https://api.openai.com/v1 LLM_BASE_URL=https://api.openai.com/v1
LLM_API_KEY= LLM_API_KEY=

View File

@@ -117,6 +117,24 @@ func (o *Orchestrator) HandleMessageWithFiles(ctx context.Context, chatID, userI
return o.handleMessageInternal(ctx, chatID, userID, text, files) return o.handleMessageInternal(ctx, chatID, userID, text, files)
} }
// UploadAndCacheFiles 上传文件到 LLM 并缓存 file_id供后续同会话文本问答复用。
// 该方法不会写入 messages 表,仅更新内存中的 pending file 上下文。
func (o *Orchestrator) UploadAndCacheFiles(ctx context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
if len(files) == 0 {
return nil, fmt.Errorf("no files provided")
}
uploadCtx := o.prepareFilePromptContext(ctx, files, nil)
if strings.TrimSpace(uploadCtx.FatalReason) != "" {
return nil, fmt.Errorf(uploadCtx.FatalReason)
}
ids := nonEmptyIDs(uploadCtx.FileIDs)
if len(ids) == 0 {
return nil, fmt.Errorf("file upload completed but no valid file_id returned")
}
o.appendPendingFiles(chatID, userID, uploadCtx.toPendingRefs())
return ids, nil
}
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) { func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
// 为链路追踪设置唯一的 TraceID // 为链路追踪设置唯一的 TraceID
traceID := logger.NewTraceID() traceID := logger.NewTraceID()

View File

@@ -25,6 +25,7 @@ type Config struct {
Telegram TelegramConfig Telegram TelegramConfig
Feishu FeishuConfig Feishu FeishuConfig
WebUI WebUIConfig
LLM LLMConfig LLM LLMConfig
Security SecurityConfig Security SecurityConfig
WebSearch WebSearchConfig WebSearch WebSearchConfig
@@ -45,6 +46,11 @@ type FeishuConfig struct {
EventPath string EventPath string
} }
type WebUIConfig struct {
ListenAddr string
MaxUploadBytes int64
}
type LLMConfig struct { type LLMConfig struct {
BaseURL string BaseURL string
APIKey string APIKey string
@@ -95,6 +101,10 @@ func Load() (Config, error) {
ListenAddr: defaultIfEmpty(os.Getenv("FEISHU_LISTEN_ADDR"), ":8080"), ListenAddr: defaultIfEmpty(os.Getenv("FEISHU_LISTEN_ADDR"), ":8080"),
EventPath: defaultIfEmpty(os.Getenv("FEISHU_EVENT_PATH"), "/feishu/events"), EventPath: defaultIfEmpty(os.Getenv("FEISHU_EVENT_PATH"), "/feishu/events"),
}, },
WebUI: WebUIConfig{
ListenAddr: defaultIfEmpty(os.Getenv("WEBUI_LISTEN_ADDR"), ":8090"),
MaxUploadBytes: int64(intFromEnv("WEBUI_MAX_UPLOAD_MB", 20)) * 1024 * 1024,
},
LLM: LLMConfig{ LLM: LLMConfig{
BaseURL: strings.TrimRight(defaultIfEmpty(os.Getenv("LLM_BASE_URL"), "https://api.openai.com/v1"), "/"), BaseURL: strings.TrimRight(defaultIfEmpty(os.Getenv("LLM_BASE_URL"), "https://api.openai.com/v1"), "/"),
APIKey: strings.TrimSpace(os.Getenv("LLM_API_KEY")), APIKey: strings.TrimSpace(os.Getenv("LLM_API_KEY")),
@@ -116,8 +126,8 @@ func Load() (Config, error) {
cfg.MessageChannel = strings.ToLower(strings.TrimSpace(cfg.MessageChannel)) cfg.MessageChannel = strings.ToLower(strings.TrimSpace(cfg.MessageChannel))
cfg.LogLevel = strings.ToLower(strings.TrimSpace(cfg.LogLevel)) cfg.LogLevel = strings.ToLower(strings.TrimSpace(cfg.LogLevel))
if cfg.MessageChannel != "telegram" && cfg.MessageChannel != "feishu" { if cfg.MessageChannel != "telegram" && cfg.MessageChannel != "feishu" && cfg.MessageChannel != "webui" {
return Config{}, fmt.Errorf("MESSAGE_CHANNEL must be telegram or feishu") return Config{}, fmt.Errorf("MESSAGE_CHANNEL must be telegram, feishu, or webui")
} }
if cfg.LogLevel != "debug" && cfg.LogLevel != "info" && cfg.LogLevel != "warn" && cfg.LogLevel != "error" { if cfg.LogLevel != "debug" && cfg.LogLevel != "info" && cfg.LogLevel != "warn" && cfg.LogLevel != "error" {
return Config{}, fmt.Errorf("LOG_LEVEL must be debug, info, warn, or error") return Config{}, fmt.Errorf("LOG_LEVEL must be debug, info, warn, or error")
@@ -137,6 +147,9 @@ func Load() (Config, error) {
if cfg.GapClusterLookbackHours < 1 || cfg.GapClusterLookbackHours > 24*365 { if cfg.GapClusterLookbackHours < 1 || cfg.GapClusterLookbackHours > 24*365 {
return Config{}, fmt.Errorf("GAP_CLUSTER_LOOKBACK_HOURS must be between 1 and 8760") return Config{}, fmt.Errorf("GAP_CLUSTER_LOOKBACK_HOURS must be between 1 and 8760")
} }
if cfg.WebUI.MaxUploadBytes < 1024 || cfg.WebUI.MaxUploadBytes > 200*1024*1024 {
return Config{}, fmt.Errorf("WEBUI_MAX_UPLOAD_MB must be between 1 and 200")
}
if cfg.MessageChannel == "telegram" { if cfg.MessageChannel == "telegram" {
if cfg.Telegram.Token == "" { if cfg.Telegram.Token == "" {
@@ -156,6 +169,12 @@ func Load() (Config, error) {
} }
} }
if cfg.MessageChannel == "webui" {
if strings.TrimSpace(cfg.WebUI.ListenAddr) == "" {
return Config{}, fmt.Errorf("WEBUI_LISTEN_ADDR is required when MESSAGE_CHANNEL=webui")
}
}
if cfg.LLM.APIKey == "" { if cfg.LLM.APIKey == "" {
return Config{}, fmt.Errorf("LLM_API_KEY is required") return Config{}, fmt.Errorf("LLM_API_KEY is required")
} }

View File

@@ -0,0 +1,313 @@
package webui
import (
"context"
"encoding/json"
"fmt"
"io"
"mime"
"net/http"
"path/filepath"
"strings"
"sync/atomic"
"time"
"laodingbot/internal/config"
"laodingbot/internal/llm"
"laodingbot/internal/logger"
)
type IncomingMessage struct {
ChatID string
UserID string
Text string
}
type ChatHandler func(context.Context, IncomingMessage) (string, error)
type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error)
type Bot struct {
listenAddr string
maxUploadBytes int64
log *logger.Logger
chatHandler ChatHandler
uploadHandler UploadHandler
counter uint64
}
type chatRequest struct {
Text string `json:"text"`
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
}
type chatResponse struct {
Reply string `json:"reply"`
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
}
type uploadResponse struct {
FileID string `json:"file_id"`
FileIDs []string `json:"file_ids"`
FileName string `json:"file_name"`
MimeType string `json:"mime_type"`
SizeBytes int `json:"size_bytes"`
SessionID string `json:"session_id"`
UserID string `json:"user_id"`
}
type errorResponse struct {
Error string `json:"error"`
}
func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) {
if strings.TrimSpace(cfg.ListenAddr) == "" {
return nil, fmt.Errorf("empty webui listen address")
}
if cfg.MaxUploadBytes <= 0 {
return nil, fmt.Errorf("invalid webui max upload bytes")
}
return &Bot{
listenAddr: strings.TrimSpace(cfg.ListenAddr),
maxUploadBytes: cfg.MaxUploadBytes,
log: log,
}, nil
}
func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error {
if chatHandler == nil {
return fmt.Errorf("nil webui chat handler")
}
if uploadHandler == nil {
return fmt.Errorf("nil webui upload handler")
}
b.chatHandler = chatHandler
b.uploadHandler = uploadHandler
mux := http.NewServeMux()
mux.HandleFunc("/api/chat", b.handleChat)
mux.HandleFunc("/api/upload", b.handleUpload)
srv := &http.Server{
Addr: b.listenAddr,
Handler: mux,
}
errCh := make(chan error, 1)
go func() {
err := srv.ListenAndServe()
if err != nil && err != http.ErrServerClosed {
errCh <- err
return
}
errCh <- nil
}()
if b.log != nil {
b.log.Infof("webui http transport started addr=%s", b.listenAddr)
}
select {
case <-ctx.Done():
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(shutdownCtx)
err := <-errCh
if b.log != nil {
b.log.Infof("webui http transport stopped: %v", ctx.Err())
}
if err != nil {
return err
}
return ctx.Err()
case err := <-errCh:
if err != nil && b.log != nil {
b.log.Errorf("webui http transport failed err=%v", err)
}
return err
}
}
func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
return
}
if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"})
return
}
if b.chatHandler == nil {
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"})
return
}
var req chatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"})
return
}
req.Text = strings.TrimSpace(req.Text)
if req.Text == "" {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"})
return
}
sessionID := b.resolveID(req.SessionID, "sess")
userID := b.resolveID(req.UserID, "user")
reply, err := b.chatHandler(r.Context(), IncomingMessage{
ChatID: sessionID,
UserID: userID,
Text: req.Text,
})
if err != nil {
if b.log != nil {
b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err)
}
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"})
return
}
writeJSON(w, http.StatusOK, chatResponse{
Reply: reply,
SessionID: sessionID,
UserID: userID,
})
}
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
return
}
if b.uploadHandler == nil {
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"})
return
}
r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes)
if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil {
if strings.Contains(strings.ToLower(err.Error()), "request body too large") {
writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"})
return
}
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"})
return
}
sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess")
userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user")
file, header, err := r.FormFile("file")
if err != nil {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"})
return
}
defer file.Close()
fileName := sanitizeFileName(header.Filename)
if fileName == "" {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"})
return
}
content, err := io.ReadAll(file)
if err != nil {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"})
return
}
if len(content) == 0 {
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"})
return
}
mimeType := strings.TrimSpace(header.Header.Get("Content-Type"))
if mimeType == "" {
mimeType = detectMimeByName(fileName)
}
ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{
FileName: fileName,
MimeType: mimeType,
Content: content,
}})
if err != nil {
if b.log != nil {
b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err)
}
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"})
return
}
if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" {
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"})
return
}
writeJSON(w, http.StatusOK, uploadResponse{
FileID: strings.TrimSpace(ids[0]),
FileIDs: ids,
FileName: fileName,
MimeType: mimeType,
SizeBytes: len(content),
SessionID: sessionID,
UserID: userID,
})
}
func (b *Bot) resolveID(raw, prefix string) string {
raw = strings.TrimSpace(raw)
if raw != "" {
return raw
}
n := atomic.AddUint64(&b.counter, 1)
return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n)
}
func writeJSON(w http.ResponseWriter, status int, payload any) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(payload)
}
func minInt64(a, b int64) int64 {
if a < b {
return a
}
return b
}
func detectMimeByName(fileName string) string {
ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName)))
if ext == "" {
return "application/octet-stream"
}
m := strings.TrimSpace(mime.TypeByExtension(ext))
if m == "" {
return "application/octet-stream"
}
return m
}
func sanitizeFileName(fileName string) string {
name := strings.TrimSpace(filepath.Base(fileName))
if name == "" || name == "." || name == ".." {
return ""
}
var b strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
b.WriteRune(r)
continue
}
b.WriteByte('_')
}
out := strings.TrimSpace(b.String())
if out == "" || out == "." || out == ".." {
return ""
}
if strings.HasPrefix(out, ".") {
out = "file" + out
}
return out
}

View File

@@ -0,0 +1,157 @@
package webui
import (
"bytes"
"context"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"laodingbot/internal/config"
"laodingbot/internal/llm"
)
func newTestBot(t *testing.T, maxUploadBytes int64) *Bot {
t.Helper()
b, err := NewBot(config.WebUIConfig{ListenAddr: ":8090", MaxUploadBytes: maxUploadBytes}, nil)
if err != nil {
t.Fatalf("NewBot failed: %v", err)
}
return b
}
func TestHandleChatSuccess(t *testing.T) {
b := newTestBot(t, 1024*1024)
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
t.Fatalf("unexpected message: %+v", msg)
}
return "ok", nil
}
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1"}`)
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
b.handleChat(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
}
var out chatResponse
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
if out.Reply != "ok" || out.SessionID != "s1" || out.UserID != "u1" {
t.Fatalf("unexpected response: %+v", out)
}
}
func TestHandleChatMissingText(t *testing.T) {
b := newTestBot(t, 1024*1024)
b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil }
body := strings.NewReader(`{"text":" "}`)
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
b.handleChat(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}
func TestHandleUploadSuccess(t *testing.T) {
b := newTestBot(t, 1024*1024)
b.uploadHandler = func(_ context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
if chatID != "s1" || userID != "u1" {
t.Fatalf("unexpected ids chat=%s user=%s", chatID, userID)
}
if len(files) != 1 {
t.Fatalf("unexpected files len=%d", len(files))
}
if files[0].FileName != "doc.pdf" || len(files[0].Content) == 0 {
t.Fatalf("unexpected file payload: %+v", files[0])
}
return []string{"file_123"}, nil
}
var payload bytes.Buffer
writer := multipart.NewWriter(&payload)
_ = writer.WriteField("session_id", "s1")
_ = writer.WriteField("user_id", "u1")
fw, err := writer.CreateFormFile("file", "doc.pdf")
if err != nil {
t.Fatalf("CreateFormFile failed: %v", err)
}
_, _ = fw.Write([]byte("pdf-content"))
_ = writer.Close()
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
req.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder()
b.handleUpload(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
}
var out uploadResponse
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
t.Fatalf("decode response failed: %v", err)
}
if out.FileID != "file_123" || out.SessionID != "s1" || out.UserID != "u1" {
t.Fatalf("unexpected response: %+v", out)
}
}
func TestHandleUploadTooLarge(t *testing.T) {
b := newTestBot(t, 3)
b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
return []string{"file_should_not_reach"}, nil
}
var payload bytes.Buffer
writer := multipart.NewWriter(&payload)
fw, err := writer.CreateFormFile("file", "a.txt")
if err != nil {
t.Fatalf("CreateFormFile failed: %v", err)
}
_, _ = fw.Write([]byte("12345"))
_ = writer.Close()
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
req.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder()
b.handleUpload(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Fatalf("expected 413, got %d body=%s", w.Code, w.Body.String())
}
}
func TestHandleUploadMissingFile(t *testing.T) {
b := newTestBot(t, 1024*1024)
b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
return []string{"file_should_not_reach"}, nil
}
var payload bytes.Buffer
writer := multipart.NewWriter(&payload)
_ = writer.WriteField("session_id", "s1")
_ = writer.Close()
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
req.Header.Set("Content-Type", writer.FormDataContentType())
w := httptest.NewRecorder()
b.handleUpload(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", w.Code)
}
}