feat: add webui http channel for chat and file upload
This commit is contained in:
313
internal/transport/webui/bot.go
Normal file
313
internal/transport/webui/bot.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package webui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/llm"
|
||||
"laodingbot/internal/logger"
|
||||
)
|
||||
|
||||
type IncomingMessage struct {
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
}
|
||||
|
||||
type ChatHandler func(context.Context, IncomingMessage) (string, error)
|
||||
type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error)
|
||||
|
||||
type Bot struct {
|
||||
listenAddr string
|
||||
maxUploadBytes int64
|
||||
log *logger.Logger
|
||||
|
||||
chatHandler ChatHandler
|
||||
uploadHandler UploadHandler
|
||||
counter uint64
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
Reply string `json:"reply"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type uploadResponse struct {
|
||||
FileID string `json:"file_id"`
|
||||
FileIDs []string `json:"file_ids"`
|
||||
FileName string `json:"file_name"`
|
||||
MimeType string `json:"mime_type"`
|
||||
SizeBytes int `json:"size_bytes"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) {
|
||||
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
||||
return nil, fmt.Errorf("empty webui listen address")
|
||||
}
|
||||
if cfg.MaxUploadBytes <= 0 {
|
||||
return nil, fmt.Errorf("invalid webui max upload bytes")
|
||||
}
|
||||
return &Bot{
|
||||
listenAddr: strings.TrimSpace(cfg.ListenAddr),
|
||||
maxUploadBytes: cfg.MaxUploadBytes,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error {
|
||||
if chatHandler == nil {
|
||||
return fmt.Errorf("nil webui chat handler")
|
||||
}
|
||||
if uploadHandler == nil {
|
||||
return fmt.Errorf("nil webui upload handler")
|
||||
}
|
||||
b.chatHandler = chatHandler
|
||||
b.uploadHandler = uploadHandler
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/api/chat", b.handleChat)
|
||||
mux.HandleFunc("/api/upload", b.handleUpload)
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: b.listenAddr,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
}()
|
||||
|
||||
if b.log != nil {
|
||||
b.log.Infof("webui http transport started addr=%s", b.listenAddr)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = srv.Shutdown(shutdownCtx)
|
||||
err := <-errCh
|
||||
if b.log != nil {
|
||||
b.log.Infof("webui http transport stopped: %v", ctx.Err())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ctx.Err()
|
||||
case err := <-errCh:
|
||||
if err != nil && b.log != nil {
|
||||
b.log.Errorf("webui http transport failed err=%v", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
return
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"})
|
||||
return
|
||||
}
|
||||
if b.chatHandler == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"})
|
||||
return
|
||||
}
|
||||
|
||||
var req chatRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"})
|
||||
return
|
||||
}
|
||||
req.Text = strings.TrimSpace(req.Text)
|
||||
if req.Text == "" {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"})
|
||||
return
|
||||
}
|
||||
sessionID := b.resolveID(req.SessionID, "sess")
|
||||
userID := b.resolveID(req.UserID, "user")
|
||||
|
||||
reply, err := b.chatHandler(r.Context(), IncomingMessage{
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err)
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, chatResponse{
|
||||
Reply: reply,
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
return
|
||||
}
|
||||
if b.uploadHandler == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"})
|
||||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes)
|
||||
if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil {
|
||||
if strings.Contains(strings.ToLower(err.Error()), "request body too large") {
|
||||
writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"})
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess")
|
||||
userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user")
|
||||
|
||||
file, header, err := r.FormFile("file")
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
fileName := sanitizeFileName(header.Filename)
|
||||
if fileName == "" {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"})
|
||||
return
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"})
|
||||
return
|
||||
}
|
||||
if len(content) == 0 {
|
||||
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"})
|
||||
return
|
||||
}
|
||||
|
||||
mimeType := strings.TrimSpace(header.Header.Get("Content-Type"))
|
||||
if mimeType == "" {
|
||||
mimeType = detectMimeByName(fileName)
|
||||
}
|
||||
|
||||
ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{
|
||||
FileName: fileName,
|
||||
MimeType: mimeType,
|
||||
Content: content,
|
||||
}})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err)
|
||||
}
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"})
|
||||
return
|
||||
}
|
||||
if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" {
|
||||
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, uploadResponse{
|
||||
FileID: strings.TrimSpace(ids[0]),
|
||||
FileIDs: ids,
|
||||
FileName: fileName,
|
||||
MimeType: mimeType,
|
||||
SizeBytes: len(content),
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Bot) resolveID(raw, prefix string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw != "" {
|
||||
return raw
|
||||
}
|
||||
n := atomic.AddUint64(&b.counter, 1)
|
||||
return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n)
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, payload any) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func minInt64(a, b int64) int64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func detectMimeByName(fileName string) string {
|
||||
ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName)))
|
||||
if ext == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
m := strings.TrimSpace(mime.TypeByExtension(ext))
|
||||
if m == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func sanitizeFileName(fileName string) string {
|
||||
name := strings.TrimSpace(filepath.Base(fileName))
|
||||
if name == "" || name == "." || name == ".." {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' {
|
||||
b.WriteRune(r)
|
||||
continue
|
||||
}
|
||||
b.WriteByte('_')
|
||||
}
|
||||
out := strings.TrimSpace(b.String())
|
||||
if out == "" || out == "." || out == ".." {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(out, ".") {
|
||||
out = "file" + out
|
||||
}
|
||||
return out
|
||||
}
|
||||
157
internal/transport/webui/bot_test.go
Normal file
157
internal/transport/webui/bot_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package webui
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"laodingbot/internal/config"
|
||||
"laodingbot/internal/llm"
|
||||
)
|
||||
|
||||
func newTestBot(t *testing.T, maxUploadBytes int64) *Bot {
|
||||
t.Helper()
|
||||
b, err := NewBot(config.WebUIConfig{ListenAddr: ":8090", MaxUploadBytes: maxUploadBytes}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBot failed: %v", err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestHandleChatSuccess(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var out chatResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if out.Reply != "ok" || out.SessionID != "s1" || out.UserID != "u1" {
|
||||
t.Fatalf("unexpected response: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatMissingText(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil }
|
||||
|
||||
body := strings.NewReader(`{"text":" "}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUploadSuccess(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.uploadHandler = func(_ context.Context, chatID, userID string, files []llm.InputFile) ([]string, error) {
|
||||
if chatID != "s1" || userID != "u1" {
|
||||
t.Fatalf("unexpected ids chat=%s user=%s", chatID, userID)
|
||||
}
|
||||
if len(files) != 1 {
|
||||
t.Fatalf("unexpected files len=%d", len(files))
|
||||
}
|
||||
if files[0].FileName != "doc.pdf" || len(files[0].Content) == 0 {
|
||||
t.Fatalf("unexpected file payload: %+v", files[0])
|
||||
}
|
||||
return []string{"file_123"}, nil
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
writer := multipart.NewWriter(&payload)
|
||||
_ = writer.WriteField("session_id", "s1")
|
||||
_ = writer.WriteField("user_id", "u1")
|
||||
fw, err := writer.CreateFormFile("file", "doc.pdf")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile failed: %v", err)
|
||||
}
|
||||
_, _ = fw.Write([]byte("pdf-content"))
|
||||
_ = writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleUpload(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var out uploadResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &out); err != nil {
|
||||
t.Fatalf("decode response failed: %v", err)
|
||||
}
|
||||
if out.FileID != "file_123" || out.SessionID != "s1" || out.UserID != "u1" {
|
||||
t.Fatalf("unexpected response: %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUploadTooLarge(t *testing.T) {
|
||||
b := newTestBot(t, 3)
|
||||
b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
|
||||
return []string{"file_should_not_reach"}, nil
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
writer := multipart.NewWriter(&payload)
|
||||
fw, err := writer.CreateFormFile("file", "a.txt")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateFormFile failed: %v", err)
|
||||
}
|
||||
_, _ = fw.Write([]byte("12345"))
|
||||
_ = writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleUpload(w, req)
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("expected 413, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleUploadMissingFile(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.uploadHandler = func(_ context.Context, _ string, _ string, _ []llm.InputFile) ([]string, error) {
|
||||
return []string{"file_should_not_reach"}, nil
|
||||
}
|
||||
|
||||
var payload bytes.Buffer
|
||||
writer := multipart.NewWriter(&payload)
|
||||
_ = writer.WriteField("session_id", "s1")
|
||||
_ = writer.Close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/upload", &payload)
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleUpload(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user