feat: add webui http channel for chat and file upload
This commit is contained in:
@@ -117,6 +117,24 @@ func (o *Orchestrator) HandleMessageWithFiles(ctx context.Context, chatID, userI
|
||||
return o.handleMessageInternal(ctx, chatID, userID, text, files)
|
||||
}
|
||||
|
||||
// UploadAndCacheFiles 上传文件到 LLM 并缓存 file_id,供后续同会话文本问答复用。
|
||||
// 该方法不会写入 messages 表,仅更新内存中的 pending file 上下文。
|
||||
func (o *Orchestrator) UploadAndCacheFiles(ctx context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
|
||||
if len(files) == 0 {
|
||||
return nil, fmt.Errorf("no files provided")
|
||||
}
|
||||
uploadCtx := o.prepareFilePromptContext(ctx, files, nil)
|
||||
if strings.TrimSpace(uploadCtx.FatalReason) != "" {
|
||||
return nil, fmt.Errorf(uploadCtx.FatalReason)
|
||||
}
|
||||
ids := nonEmptyIDs(uploadCtx.FileIDs)
|
||||
if len(ids) == 0 {
|
||||
return nil, fmt.Errorf("file upload completed but no valid file_id returned")
|
||||
}
|
||||
o.appendPendingFiles(chatID, userID, uploadCtx.toPendingRefs())
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (o *Orchestrator) handleMessageInternal(ctx context.Context, chatID, userID, text string, files []llm.InputFile) (string, error) {
|
||||
// 为链路追踪设置唯一的 TraceID
|
||||
traceID := logger.NewTraceID()
|
||||
|
||||
@@ -25,6 +25,7 @@ type Config struct {
|
||||
|
||||
Telegram TelegramConfig
|
||||
Feishu FeishuConfig
|
||||
WebUI WebUIConfig
|
||||
LLM LLMConfig
|
||||
Security SecurityConfig
|
||||
WebSearch WebSearchConfig
|
||||
@@ -45,6 +46,11 @@ type FeishuConfig struct {
|
||||
EventPath string
|
||||
}
|
||||
|
||||
type WebUIConfig struct {
|
||||
ListenAddr string
|
||||
MaxUploadBytes int64
|
||||
}
|
||||
|
||||
type LLMConfig struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
@@ -95,6 +101,10 @@ func Load() (Config, error) {
|
||||
ListenAddr: defaultIfEmpty(os.Getenv("FEISHU_LISTEN_ADDR"), ":8080"),
|
||||
EventPath: defaultIfEmpty(os.Getenv("FEISHU_EVENT_PATH"), "/feishu/events"),
|
||||
},
|
||||
WebUI: WebUIConfig{
|
||||
ListenAddr: defaultIfEmpty(os.Getenv("WEBUI_LISTEN_ADDR"), ":8090"),
|
||||
MaxUploadBytes: int64(intFromEnv("WEBUI_MAX_UPLOAD_MB", 20)) * 1024 * 1024,
|
||||
},
|
||||
LLM: LLMConfig{
|
||||
BaseURL: strings.TrimRight(defaultIfEmpty(os.Getenv("LLM_BASE_URL"), "https://api.openai.com/v1"), "/"),
|
||||
APIKey: strings.TrimSpace(os.Getenv("LLM_API_KEY")),
|
||||
@@ -116,8 +126,8 @@ func Load() (Config, error) {
|
||||
|
||||
cfg.MessageChannel = strings.ToLower(strings.TrimSpace(cfg.MessageChannel))
|
||||
cfg.LogLevel = strings.ToLower(strings.TrimSpace(cfg.LogLevel))
|
||||
if cfg.MessageChannel != "telegram" && cfg.MessageChannel != "feishu" {
|
||||
return Config{}, fmt.Errorf("MESSAGE_CHANNEL must be telegram or feishu")
|
||||
if cfg.MessageChannel != "telegram" && cfg.MessageChannel != "feishu" && cfg.MessageChannel != "webui" {
|
||||
return Config{}, fmt.Errorf("MESSAGE_CHANNEL must be telegram, feishu, or webui")
|
||||
}
|
||||
if cfg.LogLevel != "debug" && cfg.LogLevel != "info" && cfg.LogLevel != "warn" && cfg.LogLevel != "error" {
|
||||
return Config{}, fmt.Errorf("LOG_LEVEL must be debug, info, warn, or error")
|
||||
@@ -137,6 +147,9 @@ func Load() (Config, error) {
|
||||
if cfg.GapClusterLookbackHours < 1 || cfg.GapClusterLookbackHours > 24*365 {
|
||||
return Config{}, fmt.Errorf("GAP_CLUSTER_LOOKBACK_HOURS must be between 1 and 8760")
|
||||
}
|
||||
if cfg.WebUI.MaxUploadBytes < 1024 || cfg.WebUI.MaxUploadBytes > 200*1024*1024 {
|
||||
return Config{}, fmt.Errorf("WEBUI_MAX_UPLOAD_MB must be between 1 and 200")
|
||||
}
|
||||
|
||||
if cfg.MessageChannel == "telegram" {
|
||||
if cfg.Telegram.Token == "" {
|
||||
@@ -156,6 +169,12 @@ func Load() (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.MessageChannel == "webui" {
|
||||
if strings.TrimSpace(cfg.WebUI.ListenAddr) == "" {
|
||||
return Config{}, fmt.Errorf("WEBUI_LISTEN_ADDR is required when MESSAGE_CHANNEL=webui")
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.LLM.APIKey == "" {
|
||||
return Config{}, fmt.Errorf("LLM_API_KEY is required")
|
||||
}
|
||||
|
||||
313
internal/transport/webui/bot.go
Normal file
313
internal/transport/webui/bot.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package webui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/llm"
|
||||
"laodingbot/internal/logger"
|
||||
)
|
||||
|
||||
type IncomingMessage struct {
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
}
|
||||
|
||||
type ChatHandler func(context.Context, IncomingMessage) (string, error)
|
||||
type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error)
|
||||
|
||||
type Bot struct {
|
||||
listenAddr string
|
||||
maxUploadBytes int64
|
||||
log *logger.Logger
|
||||
|
||||
chatHandler ChatHandler
|
||||
uploadHandler UploadHandler
|
||||
counter uint64
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Reply string `json:"reply"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type uploadResponse struct {
|
||||
FileID string `json:"file_id"`
|
||||
FileIDs []string `json:"file_ids"`
|
||||
FileName string `json:"file_name"`
|
||||
MimeType string `json:"mime_type"`
|
||||
SizeBytes int `json:"size_bytes"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) {
|
||||
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
||||
return nil, fmt.Errorf("empty webui listen address")
|
||||
}
|
||||
if cfg.MaxUploadBytes <= 0 {
|
||||
return nil, fmt.Errorf("invalid webui max upload bytes")
|
||||
}
|
||||
return &Bot{
|
||||
listenAddr: strings.TrimSpace(cfg.ListenAddr),
|
||||
maxUploadBytes: cfg.MaxUploadBytes,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error {
|
||||
if chatHandler == nil {
|
||||
return fmt.Errorf("nil webui chat handler")
|
||||
}
|
||||
if uploadHandler == nil {
|
||||
return fmt.Errorf("nil webui upload handler")
|
||||
}
|
||||
b.chatHandler = chatHandler
|
||||
b.uploadHandler = uploadHandler
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/chat", b.handleChat)
|
||||
mux.HandleFunc("/api/upload", b.handleUpload)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: b.listenAddr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
if b.log != nil {
|
||||
b.log.Infof("webui http transport started addr=%s", b.listenAddr)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(shutdownCtx)
|
||||
err := <-errCh
|
||||
if b.log != nil {
|
||||
b.log.Infof("webui http transport stopped: %v", ctx.Err())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
if err != nil && b.log != nil {
|
||||
b.log.Errorf("webui http transport failed err=%v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
return
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"})
|
||||
return
|
||||
}
|
||||
if b.chatHandler == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"})
|
||||
return
|
||||
}
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"})
|
||||
return
|
||||
}
|
||||
req.Text = strings.TrimSpace(req.Text)
|
||||
if req.Text == "" {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"})
|
||||
return
|
||||
}
|
||||
sessionID := b.resolveID(req.SessionID, "sess")
|
||||
userID := b.resolveID(req.UserID, "user")
|
||||
|
||||
reply, err := b.chatHandler(r.Context(), IncomingMessage{
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err)
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, chatResponse{
|
||||
Reply: reply,
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
return
|
||||
}
|
||||
if b.uploadHandler == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"})
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes)
|
||||
if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "request body too large") {
|
||||
writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess")
|
||||
userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user")
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
fileName := sanitizeFileName(header.Filename)
|
||||
if fileName == "" {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"})
|
||||
return
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"})
|
||||
return
|
||||
}
|
||||
if len(content) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"})
|
||||
return
|
||||
}
|
||||
|
||||
mimeType := strings.TrimSpace(header.Header.Get("Content-Type"))
|
||||
if mimeType == "" {
|
||||
mimeType = detectMimeByName(fileName)
|
||||
}
|
||||
|
||||
ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{
|
||||
FileName: fileName,
|
||||
MimeType: mimeType,
|
||||
Content: content,
|
||||
}})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err)
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"})
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, uploadResponse{
|
||||
FileID: strings.TrimSpace(ids[0]),
|
||||
FileIDs: ids,
|
||||
FileName: fileName,
|
||||
MimeType: mimeType,
|
||||
SizeBytes: len(content),
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Bot) resolveID(raw, prefix string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw != "" {
|
||||
return raw
|
||||
}
|
||||
n := atomic.AddUint64(&b.counter, 1)
|
||||
return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func minInt64(a, b int64) int64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func detectMimeByName(fileName string) string {
|
||||
ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName)))
|
||||
if ext == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
m := strings.TrimSpace(mime.TypeByExtension(ext))
|
||||
if m == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func sanitizeFileName(fileName string) string {
|
||||
name := strings.TrimSpace(filepath.Base(fileName))
|
||||
if name == "" || name == "." || name == ".." {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
|
||||
b.WriteRune(r)
|
||||
continue
|
||||
}
|
||||
b.WriteByte('_')
|
||||
}
|
||||
out := strings.TrimSpace(b.String())
|
||||
if out == "" || out == "." || out == ".." {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(out, ".") {
|
||||
out = "file" + out
|
||||
}
|
||||
return out
|
||||
}
|
||||
157
internal/transport/webui/bot_test.go
Normal file
157
internal/transport/webui/bot_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package webui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/llm"
|
||||
)
|
||||
|
||||
func newTestBot(t *testing.T, maxUploadBytes int64) *Bot {
|
||||
t.Helper()
|
||||
b, err := NewBot(config.WebUIConfig{ListenAddr: ":8090", MaxUploadBytes: maxUploadBytes}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBot failed: %v", err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestHandleChatSuccess(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var out chatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if out.Reply != "ok" || out.SessionID != "s1" || out.UserID != "u1" {
|
||||
t.Fatalf("unexpected response: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatMissingText(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil }
|
||||
|
||||
body := strings.NewReader(`{"text":" "}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUploadSuccess(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.uploadHandler = func(_ context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
|
||||
if chatID != "s1" || userID != "u1" {
|
||||
t.Fatalf("unexpected ids chat=%s user=%s", chatID, userID)
|
||||
}
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("unexpected files len=%d", len(files))
|
||||
}
|
||||
if files[0].FileName != "doc.pdf" || len(files[0].Content) == 0 {
|
||||
t.Fatalf("unexpected file payload: %+v", files[0])
|
||||
}
|
||||
return []string{"file_123"}, nil
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
writer := multipart.NewWriter(&payload)
|
||||
_ = writer.WriteField("session_id", "s1")
|
||||
_ = writer.WriteField("user_id", "u1")
|
||||
fw, err := writer.CreateFormFile("file", "doc.pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile failed: %v", err)
|
||||
}
|
||||
_, _ = fw.Write([]byte("pdf-content"))
|
||||
_ = writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleUpload(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var out uploadResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if out.FileID != "file_123" || out.SessionID != "s1" || out.UserID != "u1" {
|
||||
t.Fatalf("unexpected response: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUploadTooLarge(t *testing.T) {
|
||||
b := newTestBot(t, 3)
|
||||
b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
|
||||
return []string{"file_should_not_reach"}, nil
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
writer := multipart.NewWriter(&payload)
|
||||
fw, err := writer.CreateFormFile("file", "a.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile failed: %v", err)
|
||||
}
|
||||
_, _ = fw.Write([]byte("12345"))
|
||||
_ = writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleUpload(w, req)
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("expected 413, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUploadMissingFile(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
|
||||
return []string{"file_should_not_reach"}, nil
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
writer := multipart.NewWriter(&payload)
|
||||
_ = writer.WriteField("session_id", "s1")
|
||||
_ = writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleUpload(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user