package webui import ( "context" "encoding/json" "fmt" "io" "mime" "net/http" "path/filepath" "strings" "sync/atomic" "time" "laodingbot/internal/config" "laodingbot/internal/llm" "laodingbot/internal/logger" "laodingbot/internal/memory" "strconv" ) type IncomingMessage struct { ChatID string UserID string Text string } // StreamEventType 定义流式输出的事件类型 type StreamEventType string const ( StreamEventTypeThought StreamEventType = "thought" // LLM 思考过程 StreamEventTypeToolCall StreamEventType = "tool_call" // 工具调用请求 StreamEventTypeToolResult StreamEventType = "tool_result" // 工具执行结果 StreamEventTypeFinal StreamEventType = "final" // 最终答案 StreamEventTypeError StreamEventType = "error" // 错误信息 StreamEventTypeWorkspaceStart StreamEventType = "workspace_start" // 工具渲染开始 StreamEventTypeWorkspaceDelta StreamEventType = "workspace_delta" // 工具渲染增量内容 StreamEventTypeWorkspaceEnd StreamEventType = "workspace_end" // 工具渲染结束 ) // StreamEvent 代表流式输出中的一个事件 type StreamEvent struct { Type StreamEventType `json:"type"` Content string `json:"content"` Step int `json:"step,omitempty"` ToolName string `json:"tool_name,omitempty"` WorkspaceTitle string `json:"workspace_title,omitempty"` // 仅用于 workspace_start 类型 } type ChatHandler func(context.Context, IncomingMessage) (string, error) type StreamChatHandler func(context.Context, IncomingMessage, StreamEventCallback) (string, error) type StreamEventCallback func(event StreamEvent) error type UploadHandler func(context.Context, string, string, []llm.InputFile) ([]string, error) type HistoryHandler func(context.Context, string, int) ([]memory.Message, error) type Bot struct { listenAddr string maxUploadBytes int64 log *logger.Logger chatHandler ChatHandler streamChatHandler StreamChatHandler uploadHandler UploadHandler historyHandler HistoryHandler counter uint64 } type chatRequest struct { Text string `json:"text"` SessionID string `json:"session_id"` UserID string `json:"user_id"` } func (r *chatRequest) UnmarshalJSON(data []byte) error { type rawChatRequest struct { Text string `json:"text"` SessionID string `json:"session_id"` SessionIDCamel string `json:"sessionId"` UserID string `json:"user_id"` UserIDCamel string `json:"userId"` } var raw rawChatRequest if err := json.Unmarshal(data, &raw); err != nil { return err } r.Text = raw.Text r.SessionID = firstNonEmpty(raw.SessionID, raw.SessionIDCamel) r.UserID = firstNonEmpty(raw.UserID, raw.UserIDCamel) return nil } type chatResponse struct { Reply string `json:"reply"` SessionID string `json:"session_id"` UserID string `json:"user_id"` } type uploadResponse struct { FileID string `json:"file_id"` FileIDs []string `json:"file_ids"` FileName string `json:"file_name"` MimeType string `json:"mime_type"` SizeBytes int `json:"size_bytes"` SessionID string `json:"session_id"` UserID string `json:"user_id"` } type errorResponse struct { Error string `json:"error"` } func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) { if strings.TrimSpace(cfg.ListenAddr) == "" { return nil, fmt.Errorf("empty webui listen address") } if cfg.MaxUploadBytes <= 0 { return nil, fmt.Errorf("invalid webui max upload bytes") } return &Bot{ listenAddr: strings.TrimSpace(cfg.ListenAddr), maxUploadBytes: cfg.MaxUploadBytes, log: log, }, nil } func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, streamChatHandler StreamChatHandler, uploadHandler UploadHandler, historyHandler HistoryHandler) error { if chatHandler == nil { return fmt.Errorf("nil webui chat handler") } if uploadHandler == nil { return fmt.Errorf("nil webui upload handler") } b.chatHandler = chatHandler b.streamChatHandler = streamChatHandler b.uploadHandler = uploadHandler b.historyHandler = historyHandler mux := http.NewServeMux() mux.HandleFunc("/api/chat", b.handleChat) mux.HandleFunc("/api/chat/stream", b.handleChatStream) mux.HandleFunc("/api/upload", b.handleUpload) mux.HandleFunc("/api/history", b.handleHistory) srv := &http.Server{ Addr: b.listenAddr, Handler: mux, } errCh := make(chan error, 1) go func() { err := srv.ListenAndServe() if err != nil && err != http.ErrServerClosed { errCh <- err return } errCh <- nil }() if b.log != nil { b.log.Infof("webui http transport started addr=%s", b.listenAddr) } select { case <-ctx.Done(): shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = srv.Shutdown(shutdownCtx) err := <-errCh if b.log != nil { b.log.Infof("webui http transport stopped: %v", ctx.Err()) } if err != nil { return err } return ctx.Err() case err := <-errCh: if err != nil && b.log != nil { b.log.Errorf("webui http transport failed err=%v", err) } return err } } func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"}) return } if b.chatHandler == nil { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"}) return } var req chatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"}) return } req.Text = strings.TrimSpace(req.Text) if req.Text == "" { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"}) return } sessionID := b.resolveID(req.SessionID, "sess") userID := b.resolveID(req.UserID, "user") reply, err := b.chatHandler(r.Context(), IncomingMessage{ ChatID: sessionID, UserID: userID, Text: req.Text, }) if err != nil { if b.log != nil { b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err) } writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"}) return } writeJSON(w, http.StatusOK, chatResponse{ Reply: reply, SessionID: sessionID, UserID: userID, }) } func (b *Bot) handleHistory(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } if b.historyHandler == nil { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "history handler not ready"}) return } sessionID := strings.TrimSpace(r.URL.Query().Get("session_id")) if sessionID == "" { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "session_id is required"}) return } limitStr := strings.TrimSpace(r.URL.Query().Get("limit")) limit := 20 if limitStr != "" { if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { limit = l } } history, err := b.historyHandler(r.Context(), sessionID, limit) if err != nil { if b.log != nil { b.log.Errorf("webui history handler failed session_id=%s err=%v", sessionID, err) } writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "load history failed"}) return } writeJSON(w, http.StatusOK, history) } func firstNonEmpty(vals ...string) string { for _, v := range vals { if strings.TrimSpace(v) != "" { return v } } return "" } func (b *Bot) handleChatStream(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"}) return } if b.streamChatHandler == nil { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "stream chat handler not ready"}) return } var req chatRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"}) return } req.Text = strings.TrimSpace(req.Text) if req.Text == "" { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"}) return } sessionID := b.resolveID(req.SessionID, "sess") userID := b.resolveID(req.UserID, "user") // 设置 SSE 响应头 w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Headers", "Content-Type") w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "streaming not supported", http.StatusInternalServerError) return } // 创建回调函数来推送 SSE 事件 callback := func(event StreamEvent) error { data, err := json.Marshal(event) if err != nil { return err } fmt.Fprintf(w, "data: %s\n\n", string(data)) flusher.Flush() return nil } // 调用流式处理器 reply, err := b.streamChatHandler(r.Context(), IncomingMessage{ ChatID: sessionID, UserID: userID, Text: req.Text, }, callback) if err != nil { if b.log != nil { b.log.Errorf("webui stream chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err) } // 推送错误事件 errEvent := StreamEvent{ Type: StreamEventTypeError, Content: "stream error: " + err.Error(), } data, _ := json.Marshal(errEvent) fmt.Fprintf(w, "data: %s\n\n", string(data)) flusher.Flush() return } if b.log != nil { b.log.Infof("webui stream chat completed session_id=%s user_id=%s reply_len=%d", sessionID, userID, len(reply)) } } func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"}) return } if b.uploadHandler == nil { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"}) return } r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes) if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil { if strings.Contains(strings.ToLower(err.Error()), "request body too large") { writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"}) return } writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"}) return } sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess") userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user") file, header, err := r.FormFile("file") if err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"}) return } defer file.Close() fileName := sanitizeFileName(header.Filename) if fileName == "" { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"}) return } content, err := io.ReadAll(file) if err != nil { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"}) return } if len(content) == 0 { writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"}) return } mimeType := strings.TrimSpace(header.Header.Get("Content-Type")) if mimeType == "" { mimeType = detectMimeByName(fileName) } ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{ FileName: fileName, MimeType: mimeType, Content: content, }}) if err != nil { if b.log != nil { b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err) } writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"}) return } if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" { writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"}) return } writeJSON(w, http.StatusOK, uploadResponse{ FileID: strings.TrimSpace(ids[0]), FileIDs: ids, FileName: fileName, MimeType: mimeType, SizeBytes: len(content), SessionID: sessionID, UserID: userID, }) } func (b *Bot) resolveID(raw, prefix string) string { raw = strings.TrimSpace(raw) if raw != "" { return raw } n := atomic.AddUint64(&b.counter, 1) return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n) } func writeJSON(w http.ResponseWriter, status int, payload any) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(payload) } func minInt64(a, b int64) int64 { if a < b { return a } return b } func detectMimeByName(fileName string) string { ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName))) if ext == "" { return "application/octet-stream" } m := strings.TrimSpace(mime.TypeByExtension(ext)) if m == "" { return "application/octet-stream" } return m } func sanitizeFileName(fileName string) string { name := strings.TrimSpace(filepath.Base(fileName)) if name == "" || name == "." || name == ".." { return "" } var b strings.Builder for _, r := range name { if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '.' || r == '_' || r == '-' { b.WriteRune(r) continue } b.WriteByte('_') } out := strings.TrimSpace(b.String()) if out == "" || out == "." || out == ".." { return "" } if strings.HasPrefix(out, ".") { out = "file" + out } return out }