// Package webui exposes the bot over HTTP: a JSON chat endpoint at /api/chat
// and a multipart file-upload endpoint at /api/upload.
package webui
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"laodingbot/internal/config"
|
|
"laodingbot/internal/llm"
|
|
"laodingbot/internal/logger"
|
|
)
|
|
|
|
// IncomingMessage is one chat message received over the web UI transport and
// passed to the ChatHandler.
type IncomingMessage struct {
	// ChatID identifies the conversation (handleChat fills it with the
	// resolved session ID), UserID identifies the sender, and Text is the
	// message body.
	ChatID, UserID, Text string
	// FileIDs lists the IDs of previously uploaded files attached to the
	// message; it may be nil.
	FileIDs []string
}
|
|
|
|
// ChatHandler processes one incoming chat message and returns the reply text.
// A non-nil error makes handleChat answer with HTTP 500.
type ChatHandler func(ctx context.Context, msg IncomingMessage) (string, error)
|
|
// UploadHandler stores the uploaded files for the given session and user and
// returns the IDs assigned to them. handleUpload reports the first returned ID
// as the primary file ID and treats an empty result as a failure.
type UploadHandler func(ctx context.Context, sessionID, userID string, files []llm.InputFile) ([]string, error)
|
|
|
|
type Bot struct {
|
|
listenAddr string
|
|
maxUploadBytes int64
|
|
log *logger.Logger
|
|
|
|
chatHandler ChatHandler
|
|
uploadHandler UploadHandler
|
|
counter uint64
|
|
}
|
|
|
|
// chatRequest is the decoded body of a chat request. Decoding goes through
// the custom UnmarshalJSON on this type, which also accepts alternative key
// spellings; the tags below define the canonical snake_case names.
type chatRequest struct {
	// Text is the user message.
	Text string `json:"text"`
	// SessionID identifies the conversation; may be empty.
	SessionID string `json:"session_id"`
	// UserID identifies the sender; may be empty.
	UserID string `json:"user_id"`
	// FileIDs references previously uploaded files.
	FileIDs []string `json:"file_ids"`
}
|
|
|
|
func (r *chatRequest) UnmarshalJSON(data []byte) error {
|
|
type rawChatRequest struct {
|
|
Text string `json:"text"`
|
|
SessionID string `json:"session_id"`
|
|
SessionIDCamel string `json:"sessionId"`
|
|
UserID string `json:"user_id"`
|
|
UserIDCamel string `json:"userId"`
|
|
FileIDs json.RawMessage `json:"file_ids"`
|
|
FileIDsCamel json.RawMessage `json:"fileIds"`
|
|
FileIDsFlat json.RawMessage `json:"fileids"`
|
|
FileID json.RawMessage `json:"file_id"`
|
|
}
|
|
|
|
var raw rawChatRequest
|
|
if err := json.Unmarshal(data, &raw); err != nil {
|
|
return err
|
|
}
|
|
|
|
r.Text = raw.Text
|
|
r.SessionID = firstNonEmpty(raw.SessionID, raw.SessionIDCamel)
|
|
r.UserID = firstNonEmpty(raw.UserID, raw.UserIDCamel)
|
|
|
|
rawIDs := firstNonEmptyRaw(raw.FileIDs, raw.FileIDsCamel, raw.FileIDsFlat, raw.FileID)
|
|
ids, err := decodeStringList(rawIDs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
r.FileIDs = ids
|
|
return nil
|
|
}
|
|
|
|
// chatResponse is the JSON body returned by handleChat on success.
type chatResponse struct {
	// Reply is the text produced by the chat handler.
	Reply string `json:"reply"`
	// SessionID and UserID echo the IDs used for the request, including any
	// that were generated because the client sent none.
	SessionID string `json:"session_id"`
	UserID    string `json:"user_id"`
}
|
|
|
|
// uploadResponse is the JSON body returned by handleUpload on success.
type uploadResponse struct {
	// FileID is the first ID returned by the upload handler, trimmed.
	FileID string `json:"file_id"`
	// FileIDs is the full list returned by the upload handler.
	FileIDs []string `json:"file_ids"`
	// FileName is the sanitized name under which the file was stored.
	FileName string `json:"file_name"`
	// MimeType is the part's Content-Type, or a guess from the extension.
	MimeType string `json:"mime_type"`
	// SizeBytes is the number of bytes read from the uploaded file.
	SizeBytes int `json:"size_bytes"`
	// SessionID and UserID echo the IDs used for the request.
	SessionID string `json:"session_id"`
	UserID    string `json:"user_id"`
}
|
|
|
|
// errorResponse is the JSON body sent with every non-2xx reply.
type errorResponse struct {
	// Error is a short, client-safe description of what went wrong.
	Error string `json:"error"`
}
|
|
|
|
func NewBot(cfg config.WebUIConfig, log *logger.Logger) (*Bot, error) {
|
|
if strings.TrimSpace(cfg.ListenAddr) == "" {
|
|
return nil, fmt.Errorf("empty webui listen address")
|
|
}
|
|
if cfg.MaxUploadBytes <= 0 {
|
|
return nil, fmt.Errorf("invalid webui max upload bytes")
|
|
}
|
|
return &Bot{
|
|
listenAddr: strings.TrimSpace(cfg.ListenAddr),
|
|
maxUploadBytes: cfg.MaxUploadBytes,
|
|
log: log,
|
|
}, nil
|
|
}
|
|
|
|
func (b *Bot) Run(ctx context.Context, chatHandler ChatHandler, uploadHandler UploadHandler) error {
|
|
if chatHandler == nil {
|
|
return fmt.Errorf("nil webui chat handler")
|
|
}
|
|
if uploadHandler == nil {
|
|
return fmt.Errorf("nil webui upload handler")
|
|
}
|
|
b.chatHandler = chatHandler
|
|
b.uploadHandler = uploadHandler
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/api/chat", b.handleChat)
|
|
mux.HandleFunc("/api/upload", b.handleUpload)
|
|
|
|
srv := &http.Server{
|
|
Addr: b.listenAddr,
|
|
Handler: mux,
|
|
}
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
err := srv.ListenAndServe()
|
|
if err != nil && err != http.ErrServerClosed {
|
|
errCh <- err
|
|
return
|
|
}
|
|
errCh <- nil
|
|
}()
|
|
|
|
if b.log != nil {
|
|
b.log.Infof("webui http transport started addr=%s", b.listenAddr)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_ = srv.Shutdown(shutdownCtx)
|
|
err := <-errCh
|
|
if b.log != nil {
|
|
b.log.Infof("webui http transport stopped: %v", ctx.Err())
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return ctx.Err()
|
|
case err := <-errCh:
|
|
if err != nil && b.log != nil {
|
|
b.log.Errorf("webui http transport failed err=%v", err)
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
|
return
|
|
}
|
|
if !strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Content-Type"))), "application/json") {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "content-type must be application/json"})
|
|
return
|
|
}
|
|
if b.chatHandler == nil {
|
|
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat handler not ready"})
|
|
return
|
|
}
|
|
|
|
var req chatRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid json body"})
|
|
return
|
|
}
|
|
req.Text = strings.TrimSpace(req.Text)
|
|
if req.Text == "" {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "text is required"})
|
|
return
|
|
}
|
|
sessionID := b.resolveID(req.SessionID, "sess")
|
|
userID := b.resolveID(req.UserID, "user")
|
|
|
|
reply, err := b.chatHandler(r.Context(), IncomingMessage{
|
|
ChatID: sessionID,
|
|
UserID: userID,
|
|
Text: req.Text,
|
|
FileIDs: req.FileIDs,
|
|
})
|
|
if err != nil {
|
|
if b.log != nil {
|
|
b.log.Errorf("webui chat handler failed session_id=%s user_id=%s err=%v", sessionID, userID, err)
|
|
}
|
|
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "chat failed"})
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, chatResponse{
|
|
Reply: reply,
|
|
SessionID: sessionID,
|
|
UserID: userID,
|
|
})
|
|
}
|
|
|
|
func decodeStringList(raw json.RawMessage) ([]string, error) {
|
|
if len(raw) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
var list []string
|
|
if err := json.Unmarshal(raw, &list); err == nil {
|
|
return nonEmptyIDs(list), nil
|
|
}
|
|
|
|
var single string
|
|
if err := json.Unmarshal(raw, &single); err == nil {
|
|
if strings.TrimSpace(single) == "" {
|
|
return nil, nil
|
|
}
|
|
return nonEmptyIDs(strings.Split(single, ",")), nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("invalid file ids format")
|
|
}
|
|
|
|
// firstNonEmptyRaw returns the first value in vals that actually carries
// data, or nil when none does.
//
// A value counts as empty when it has zero length or is the JSON literal
// null. encoding/json stores an explicit null in a json.RawMessage field as
// the bytes "null", so without the second check a request such as
// {"file_ids": null, "fileIds": ["a"]} would select the null and silently
// drop the IDs supplied under the alias.
func firstNonEmptyRaw(vals ...json.RawMessage) json.RawMessage {
	for _, v := range vals {
		if len(v) == 0 {
			continue
		}
		if strings.TrimSpace(string(v)) == "null" {
			continue
		}
		return v
	}
	return nil
}
|
|
|
|
// firstNonEmpty returns the first argument that is not blank, where blank
// means empty or whitespace only. The winner is returned unmodified, i.e.
// without trimming. If every argument is blank it returns "".
func firstNonEmpty(vals ...string) string {
	for i := 0; i < len(vals); i++ {
		if candidate := vals[i]; strings.TrimSpace(candidate) != "" {
			return candidate
		}
	}
	return ""
}
|
|
|
|
// nonEmptyIDs normalizes a list of IDs: each entry is trimmed, blank entries
// are dropped, and repeated IDs keep only their first occurrence. Order is
// preserved. An empty or nil input yields nil.
func nonEmptyIDs(ids []string) []string {
	if len(ids) == 0 {
		return nil
	}

	kept := make([]string, 0, len(ids))
	already := make(map[string]bool, len(ids))
	for i := range ids {
		trimmed := strings.TrimSpace(ids[i])
		if trimmed == "" || already[trimmed] {
			continue
		}
		already[trimmed] = true
		kept = append(kept, trimmed)
	}
	return kept
}
|
|
|
|
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
|
return
|
|
}
|
|
if b.uploadHandler == nil {
|
|
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload handler not ready"})
|
|
return
|
|
}
|
|
|
|
r.Body = http.MaxBytesReader(w, r.Body, b.maxUploadBytes)
|
|
if err := r.ParseMultipartForm(minInt64(b.maxUploadBytes, 32*1024*1024)); err != nil {
|
|
if strings.Contains(strings.ToLower(err.Error()), "request body too large") {
|
|
writeJSON(w, http.StatusRequestEntityTooLarge, errorResponse{Error: "file too large"})
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid multipart form"})
|
|
return
|
|
}
|
|
|
|
sessionID := b.resolveID(strings.TrimSpace(r.FormValue("session_id")), "sess")
|
|
userID := b.resolveID(strings.TrimSpace(r.FormValue("user_id")), "user")
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "file is required"})
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
fileName := sanitizeFileName(header.Filename)
|
|
if fileName == "" {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "invalid file name"})
|
|
return
|
|
}
|
|
|
|
content, err := io.ReadAll(file)
|
|
if err != nil {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "read file failed"})
|
|
return
|
|
}
|
|
if len(content) == 0 {
|
|
writeJSON(w, http.StatusBadRequest, errorResponse{Error: "empty file"})
|
|
return
|
|
}
|
|
|
|
mimeType := strings.TrimSpace(header.Header.Get("Content-Type"))
|
|
if mimeType == "" {
|
|
mimeType = detectMimeByName(fileName)
|
|
}
|
|
|
|
ids, err := b.uploadHandler(r.Context(), sessionID, userID, []llm.InputFile{{
|
|
FileName: fileName,
|
|
MimeType: mimeType,
|
|
Content: content,
|
|
}})
|
|
if err != nil {
|
|
if b.log != nil {
|
|
b.log.Errorf("webui upload handler failed session_id=%s user_id=%s file=%s err=%v", sessionID, userID, fileName, err)
|
|
}
|
|
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload failed"})
|
|
return
|
|
}
|
|
if len(ids) == 0 || strings.TrimSpace(ids[0]) == "" {
|
|
writeJSON(w, http.StatusInternalServerError, errorResponse{Error: "upload succeeded but file_id is empty"})
|
|
return
|
|
}
|
|
|
|
writeJSON(w, http.StatusOK, uploadResponse{
|
|
FileID: strings.TrimSpace(ids[0]),
|
|
FileIDs: ids,
|
|
FileName: fileName,
|
|
MimeType: mimeType,
|
|
SizeBytes: len(content),
|
|
SessionID: sessionID,
|
|
UserID: userID,
|
|
})
|
|
}
|
|
|
|
func (b *Bot) resolveID(raw, prefix string) string {
|
|
raw = strings.TrimSpace(raw)
|
|
if raw != "" {
|
|
return raw
|
|
}
|
|
n := atomic.AddUint64(&b.counter, 1)
|
|
return fmt.Sprintf("%s_%d_%d", prefix, time.Now().UnixNano(), n)
|
|
}
|
|
|
|
// writeJSON sends payload as a JSON response with the given status code and a
// UTF-8 JSON content type.
func writeJSON(w http.ResponseWriter, status int, payload any) {
	headers := w.Header()
	headers.Set("Content-Type", "application/json; charset=utf-8")
	w.WriteHeader(status)

	// The status line is already sent, so an encoding or write failure can
	// no longer be reported to the client; it is deliberately ignored.
	encoder := json.NewEncoder(w)
	_ = encoder.Encode(payload)
}
|
|
|
|
// minInt64 returns the smaller of a and b.
func minInt64(a, b int64) int64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
|
|
|
|
// detectMimeByName guesses a MIME type from the file name's extension,
// compared case-insensitively. It returns "application/octet-stream" when the
// name has no extension or the extension is unknown.
func detectMimeByName(fileName string) string {
	const fallback = "application/octet-stream"

	if ext := strings.ToLower(strings.TrimSpace(filepath.Ext(fileName))); ext != "" {
		if guessed := strings.TrimSpace(mime.TypeByExtension(ext)); guessed != "" {
			return guessed
		}
	}
	return fallback
}
|
|
|
|
// sanitizeFileName reduces a client-supplied file name to a safe base name.
// Directory components are stripped, and every character other than ASCII
// letters, digits, '.', '_' and '-' is replaced with '_'. A result that would
// start with a dot is prefixed with "file". It returns "" when nothing usable
// remains (blank input, "." or "..").
func sanitizeFileName(fileName string) string {
	base := strings.TrimSpace(filepath.Base(fileName))
	switch base {
	case "", ".", "..":
		return ""
	}

	// Replace rune by rune so a multi-byte character becomes a single '_'.
	allowOrUnderscore := func(c rune) rune {
		switch {
		case c >= 'a' && c <= 'z',
			c >= 'A' && c <= 'Z',
			c >= '0' && c <= '9',
			c == '.', c == '_', c == '-':
			return c
		}
		return '_'
	}
	cleaned := strings.TrimSpace(strings.Map(allowOrUnderscore, base))

	switch cleaned {
	case "", ".", "..":
		return ""
	}
	if strings.HasPrefix(cleaned, ".") {
		// Avoid producing hidden files.
		return "file" + cleaned
	}
	return cleaned
}
|