package webui import ( "context" "encoding/json" "fmt" "io" "mime" "net/http" "path/filepath" "strings" "sync/atomic" "time" "laodingbot/internal/config" "laodingbot/internal/llm" "laodingbot/internal/logger" ) type IncomingMessage struct { ChatID string UserID string Text string FileIDs []string } type ChatHandler func(context.Context, IncomingMessage) (string, error) type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error) type Bot struct { listenAddr string maxUploadBytes int64 log *logger.Logger chatHandler ChatHandler uploadHandler UploadHandler counter uint64 } type chatRequest struct { Text string `json:"text"` SessionID string `json:"session_id"` UserID string `json:"user_id"` FileIDs []string `json:"file_ids"` } func (r *chatRequest) UnmarshalJSON(data []byte) error { type rawChatRequest struct { Text string `json:"text"` SessionID string `json:"session_id"` SessionIDCamel string `json:"sessionId"` UserID string `json:"user_id"` UserIDCamel string `json:"userId"` FileIDs json.RawMessage `json:"file_ids"` FileIDsCamel json.RawMessage `json:"fileIds"` FileIDsFlat json.RawMessage `json:"fileids"` FileID json.RawMessage `json:"file_id"` } var raw rawChatRequest if err := json.Unmarshal(data, &raw); err != nil { return err } r.Text = raw.Text r.SessionID = firstNonEmpty(raw.SessionID, raw.SessionIDCamel) r.UserID = firstNonEmpty(raw.UserID, raw.UserIDCamel) rawIDs := firstNonEmptyRaw(raw.FileIDs, raw.FileIDsCamel, raw.FileIDsFlat, raw.FileID) ids, err := decodeStringList(rawIDs) if err != nil { return err } r.FileIDs = ids return nil } type chatResponse struct { Reply string `json:"reply"` SessionID string `json:"session_id"` UserID string `json:"user_id"` } type uploadResponse struct { FileID string `json:"file_id"` FileIDs []string `json:"file_ids"` FileName string `json:"file_name"` MimeType string `json:"mime_type"` SizeBytes int `json:"size_bytes"` SessionID string `json:"session_id"` UserID string `json:"user_id"` } type errorResponse struct { Error string `json:"error"` } func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) { if strings.TrimSpace(cfg.ListenAddr) == "" { return nil, fmt.Errorf("empty webui listen address") } if cfg.MaxUploadBytes <= 0 { return nil, fmt.Errorf("invalid webui max upload bytes") } return &Bot{ listenAddr: strings.TrimSpace(cfg.ListenAddr), maxUploadBytes: cfg.MaxUploadBytes, log: log, }, nil } func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error { if chatHandler == nil { return fmt.Errorf("nil webui chat handler") } if uploadHandler == nil { return fmt.Errorf("nil webui upload handler") } b.chatHandler = chatHandler b.uploadHandler = uploadHandler mux := http.NewServeMux() mux.HandleFunc("/api/chat", b.handleChat) mux.HandleFunc("/api/upload", b.handleUpload) srv := &http.Server{ Addr: b.listenAddr, Handler: mux, } errCh := make(chan error, 1) go func() { err := srv.ListenAndServe() if err != nil && err != http.ErrServerClosed { errCh <- err return } errCh <- nil }() if b.log != nil { b.log.Infof("webui http transport started addr=%s", b.listenAddr) } select { case <-ctx.Done(): shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = srv.Shutdown(shutdownCtx) err := <-errCh if b.log != nil { b.log.Infof("webui http transport stopped: %v", ctx.Err()) } if err != nil { return err } return ctx.Err() case err := <-errCh: if err != nil && b.log != nil { b.log.Errorf("webui http transport failed err=%v", err) } return err } } func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"}) return } if b.chatHandler == nil { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"}) return } var req chatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"}) return } req.Text = strings.TrimSpace(req.Text) if req.Text == "" { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"}) return } sessionID := b.resolveID(req.SessionID, "sess") userID := b.resolveID(req.UserID, "user") reply, err := b.chatHandler(r.Context(), IncomingMessage{ ChatID: sessionID, UserID: userID, Text: req.Text, FileIDs: req.FileIDs, }) if err != nil { if b.log != nil { b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err) } writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"}) return } writeJSON(w, http.StatusOK, chatResponse{ Reply: reply, SessionID: sessionID, UserID: userID, }) } func decodeStringList(raw json.RawMessage) ([]string, error) { if len(raw) == 0 { return nil, nil } var list []string if err := json.Unmarshal(raw, &list); err == nil { return nonEmptyIDs(list), nil } var single string if err := json.Unmarshal(raw, &single); err == nil { if strings.TrimSpace(single) == "" { return nil, nil } return nonEmptyIDs(strings.Split(single, ",")), nil } return nil, fmt.Errorf("invalid file ids format") } func firstNonEmptyRaw(vals ...json.RawMessage) json.RawMessage { for _, v := range vals { if len(v) > 0 { return v } } return nil } func firstNonEmpty(vals ...string) string { for _, v := range vals { if strings.TrimSpace(v) != "" { return v } } return "" } func nonEmptyIDs(ids []string) []string { if len(ids) == 0 { return nil } out := make([]string, 0, len(ids)) seen := map[string]struct{}{} for _, id := range ids { id = strings.TrimSpace(id) if id == "" { continue } if _, ok := seen[id]; ok { continue } seen[id] = struct{}{} out = append(out, id) } return out } func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } if b.uploadHandler == nil { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"}) return } r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes) if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil { if strings.Contains(strings.ToLower(err.Error()), "request body too large") { writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"}) return } writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"}) return } sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess") userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user") file, header, err := r.FormFile("file") if err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"}) return } defer file.Close() fileName := sanitizeFileName(header.Filename) if fileName == "" { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"}) return } content, err := io.ReadAll(file) if err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"}) return } if len(content) == 0 { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"}) return } mimeType := strings.TrimSpace(header.Header.Get("Content-Type")) if mimeType == "" { mimeType = detectMimeByName(fileName) } ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{ FileName: fileName, MimeType: mimeType, Content: content, }}) if err != nil { if b.log != nil { b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err) } writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"}) return } if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"}) return } writeJSON(w, http.StatusOK, uploadResponse{ FileID: strings.TrimSpace(ids[0]), FileIDs: ids, FileName: fileName, MimeType: mimeType, SizeBytes: len(content), SessionID: sessionID, UserID: userID, }) } func (b *Bot) resolveID(raw, prefix string) string { raw = strings.TrimSpace(raw) if raw != "" { return raw } n := atomic.AddUint64(&b.counter, 1) return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n) } func writeJSON(w http.ResponseWriter, status int, payload any) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(payload) } func minInt64(a, b int64) int64 { if a < b { return a } return b } func detectMimeByName(fileName string) string { ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName))) if ext == "" { return "application/octet-stream" } m := strings.TrimSpace(mime.TypeByExtension(ext)) if m == "" { return "application/octet-stream" } return m } func sanitizeFileName(fileName string) string { name := strings.TrimSpace(filepath.Base(fileName)) if name == "" || name == "." || name == ".." { return "" } var b strings.Builder for _, r := range name { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' { b.WriteRune(r) continue } b.WriteByte('_') } out := strings.TrimSpace(b.String()) if out == "" || out == "." || out == ".." { return "" } if strings.HasPrefix(out, ".") { out = "file" + out } return out }