Files
LaodingBot/internal/transport/webui/bot_test.go

218 lines
6.6 KiB
Go
Raw Normal View History

package webui
import (
"bytes"
"context"
"encoding/json"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"laodingbot/internal/config"
"laodingbot/internal/llm"
)
// newTestBot constructs a web UI Bot configured to listen on ":8090" with the
// given upload size limit (in bytes). It aborts the calling test if NewBot
// reports an error.
func newTestBot(t *testing.T, maxUploadBytes int64) *Bot {
	t.Helper()
	cfg := config.WebUIConfig{
		ListenAddr:     ":8090",
		MaxUploadBytes: maxUploadBytes,
	}
	bot, err := NewBot(cfg, nil)
	if err != nil {
		t.Fatalf("NewBot failed: %v", err)
	}
	return bot
}
// TestHandleChatSuccess posts a well-formed JSON chat request and verifies
// that text/session/user reach the chat handler, and that the handler's reply
// is echoed back alongside the session and user IDs.
func TestHandleChatSuccess(t *testing.T) {
	bot := newTestBot(t, 1024*1024)

	// Record what the handler saw and assert after handleChat returns, so no
	// test-failing call happens inside the callback.
	var seen IncomingMessage
	bot.chatHandler = func(_ context.Context, in IncomingMessage) (string, error) {
		seen = in
		return "ok", nil
	}

	req := httptest.NewRequest(
		http.MethodPost,
		"/api/chat",
		strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1"}`),
	)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()

	bot.handleChat(rec, req)

	if seen.ChatID != "s1" || seen.UserID != "u1" || seen.Text != "hello" {
		t.Fatalf("unexpected message: %+v", seen)
	}
	if rec.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
	}

	var resp chatResponse
	if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	if resp.Reply != "ok" || resp.SessionID != "s1" || resp.UserID != "u1" {
		t.Fatalf("unexpected response: %+v", resp)
	}
}
// TestHandleChatWithFileIDs verifies that a snake_case "file_ids" array in the
// chat payload is forwarded to the chat handler in order, and that the
// handler's reply is returned to the client.
func TestHandleChatWithFileIDs(t *testing.T) {
	b := newTestBot(t, 1024*1024)

	// Capture the dispatched message instead of failing inside the callback:
	// t.Fatal* is only valid on the test goroutine, and the handler's calling
	// goroutine is an implementation detail of handleChat.
	var (
		called bool
		got    IncomingMessage
	)
	b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
		called = true
		got = msg
		return "ok", nil
	}

	body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1","file_ids":["file_a","file_b"]}`)
	req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	b.handleChat(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
	}
	// A 200 alone does not prove the message was dispatched.
	if !called {
		t.Fatal("chat handler was not invoked")
	}
	if got.ChatID != "s1" || got.UserID != "u1" || got.Text != "hello" {
		t.Fatalf("unexpected message: %+v", got)
	}
	if len(got.FileIDs) != 2 || got.FileIDs[0] != "file_a" || got.FileIDs[1] != "file_b" {
		t.Fatalf("unexpected file ids: %+v", got.FileIDs)
	}

	var out chatResponse
	if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	if out.Reply != "ok" || out.SessionID != "s1" || out.UserID != "u1" {
		t.Fatalf("unexpected response: %+v", out)
	}
}
// TestHandleChatWithFileIDsAliases verifies that the alternative spellings of
// the file-ID field accepted by the chat endpoint (camelCase, all-lowercase,
// singular key, comma-separated string) each yield the expected, ordered list
// of IDs on the dispatched message.
func TestHandleChatWithFileIDsAliases(t *testing.T) {
	tests := []struct {
		name string
		body string
		want []string
	}{
		{name: "camel array", body: `{"text":"hello","sessionId":"s1","userId":"u1","fileIds":["file_a","file_b"]}`, want: []string{"file_a", "file_b"}},
		{name: "flat array", body: `{"text":"hello","session_id":"s1","user_id":"u1","fileids":["file_a","file_b"]}`, want: []string{"file_a", "file_b"}},
		{name: "single key", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_id":"file_a"}`, want: []string{"file_a"}},
		// Whitespace around comma-separated entries is expected to be trimmed;
		// an ID with a leading space would never match a stored file.
		{name: "csv string", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_ids":"file_a, file_b"}`, want: []string{"file_a", "file_b"}},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			b := newTestBot(t, 1024*1024)

			// Capture the message and assert after handleChat returns, so that
			// t.Fatal* is only ever called on the test goroutine.
			var (
				called bool
				got    IncomingMessage
			)
			b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
				called = true
				got = msg
				return "ok", nil
			}

			req := httptest.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tt.body))
			req.Header.Set("Content-Type", "application/json")
			w := httptest.NewRecorder()

			b.handleChat(w, req)

			if w.Code != http.StatusOK {
				t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
			}
			if !called {
				t.Fatal("chat handler was not invoked")
			}
			if got.ChatID != "s1" || got.UserID != "u1" || got.Text != "hello" {
				t.Fatalf("unexpected message: %+v", got)
			}
			// Compare exactly: a non-empty check would also accept an
			// unsplit CSV string or untrimmed entries.
			if len(got.FileIDs) != len(tt.want) {
				t.Fatalf("file ids = %q, want %q", got.FileIDs, tt.want)
			}
			for i := range tt.want {
				if got.FileIDs[i] != tt.want[i] {
					t.Fatalf("file ids = %q, want %q", got.FileIDs, tt.want)
				}
			}
		})
	}
}
// TestHandleChatMissingText verifies that a chat request whose text is only
// whitespace is rejected with 400 and never reaches the chat handler.
func TestHandleChatMissingText(t *testing.T) {
	b := newTestBot(t, 1024*1024)
	called := false
	b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) {
		called = true
		return "", nil
	}

	body := strings.NewReader(`{"text":" "}`)
	req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	b.handleChat(w, req)

	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d body=%s", w.Code, w.Body.String())
	}
	// Validation must happen before dispatch, not after.
	if called {
		t.Fatal("chat handler must not be invoked for whitespace-only text")
	}
}
// TestHandleUploadSuccess posts a multipart upload with session/user fields
// and one file, and verifies that the upload handler receives the IDs and the
// file (name and exact bytes), and that the returned file ID is echoed back.
func TestHandleUploadSuccess(t *testing.T) {
	b := newTestBot(t, 1024*1024)

	// Capture the handler arguments and assert after handleUpload returns, so
	// t.Fatal* is only called on the test goroutine.
	var (
		called   bool
		gotChat  string
		gotUser  string
		gotFiles []llm.InputFile
	)
	b.uploadHandler = func(_ context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
		called = true
		gotChat, gotUser, gotFiles = chatID, userID, files
		return []string{"file_123"}, nil
	}

	// Build the multipart body, failing loudly if any write goes wrong rather
	// than sending a silently truncated request.
	var payload bytes.Buffer
	writer := multipart.NewWriter(&payload)
	if err := writer.WriteField("session_id", "s1"); err != nil {
		t.Fatalf("WriteField(session_id) failed: %v", err)
	}
	if err := writer.WriteField("user_id", "u1"); err != nil {
		t.Fatalf("WriteField(user_id) failed: %v", err)
	}
	fw, err := writer.CreateFormFile("file", "doc.pdf")
	if err != nil {
		t.Fatalf("CreateFormFile failed: %v", err)
	}
	if _, err := fw.Write([]byte("pdf-content")); err != nil {
		t.Fatalf("writing form file failed: %v", err)
	}
	if err := writer.Close(); err != nil {
		t.Fatalf("closing multipart writer failed: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	w := httptest.NewRecorder()

	b.handleUpload(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
	}
	if !called {
		t.Fatal("upload handler was not invoked")
	}
	if gotChat != "s1" || gotUser != "u1" {
		t.Fatalf("unexpected ids chat=%s user=%s", gotChat, gotUser)
	}
	if len(gotFiles) != 1 {
		t.Fatalf("unexpected files len=%d", len(gotFiles))
	}
	if gotFiles[0].FileName != "doc.pdf" || string(gotFiles[0].Content) != "pdf-content" {
		t.Fatalf("unexpected file payload: %+v", gotFiles[0])
	}

	var out uploadResponse
	if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
		t.Fatalf("decode response failed: %v", err)
	}
	if out.FileID != "file_123" || out.SessionID != "s1" || out.UserID != "u1" {
		t.Fatalf("unexpected response: %+v", out)
	}
}
// TestHandleUploadTooLarge configures a 3-byte upload limit, posts a 5-byte
// file, and verifies the request is rejected with 413 without ever reaching
// the upload handler.
func TestHandleUploadTooLarge(t *testing.T) {
	b := newTestBot(t, 3)
	called := false
	b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
		called = true
		return []string{"file_should_not_reach"}, nil
	}

	var payload bytes.Buffer
	writer := multipart.NewWriter(&payload)
	fw, err := writer.CreateFormFile("file", "a.txt")
	if err != nil {
		t.Fatalf("CreateFormFile failed: %v", err)
	}
	if _, err := fw.Write([]byte("12345")); err != nil {
		t.Fatalf("writing form file failed: %v", err)
	}
	if err := writer.Close(); err != nil {
		t.Fatalf("closing multipart writer failed: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	w := httptest.NewRecorder()

	b.handleUpload(w, req)

	if w.Code != http.StatusRequestEntityTooLarge {
		t.Fatalf("expected 413, got %d body=%s", w.Code, w.Body.String())
	}
	// The size limit must be enforced before any upload is forwarded.
	if called {
		t.Fatal("upload handler must not be invoked when the upload exceeds MaxUploadBytes")
	}
}
// TestHandleUploadMissingFile posts a multipart form with a session_id field
// but no "file" part, and verifies it is rejected with 400 without reaching
// the upload handler.
func TestHandleUploadMissingFile(t *testing.T) {
	b := newTestBot(t, 1024*1024)
	called := false
	b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
		called = true
		return []string{"file_should_not_reach"}, nil
	}

	var payload bytes.Buffer
	writer := multipart.NewWriter(&payload)
	if err := writer.WriteField("session_id", "s1"); err != nil {
		t.Fatalf("WriteField(session_id) failed: %v", err)
	}
	if err := writer.Close(); err != nil {
		t.Fatalf("closing multipart writer failed: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	w := httptest.NewRecorder()

	b.handleUpload(w, req)

	if w.Code != http.StatusBadRequest {
		t.Fatalf("expected 400, got %d body=%s", w.Code, w.Body.String())
	}
	if called {
		t.Fatal("upload handler must not be invoked when no file part is present")
	}
}