Migrate LLM client to OpenAI SDK and implement WebUI-specific fileID handling
This commit is contained in:
@@ -18,9 +18,10 @@ import (
|
||||
)
|
||||
|
||||
type IncomingMessage struct {
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
ChatID string
|
||||
UserID string
|
||||
Text string
|
||||
FileIDs []string
|
||||
}
|
||||
|
||||
type ChatHandler func(context.Context, IncomingMessage) (string, error)
|
||||
@@ -37,9 +38,41 @@ type Bot struct {
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
UserID string `json:"user_id"`
|
||||
FileIDs []string `json:"file_ids"`
|
||||
}
|
||||
|
||||
func (r *chatRequest) UnmarshalJSON(data []byte) error {
|
||||
type rawChatRequest struct {
|
||||
Text string `json:"text"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionIDCamel string `json:"sessionId"`
|
||||
UserID string `json:"user_id"`
|
||||
UserIDCamel string `json:"userId"`
|
||||
FileIDs json.RawMessage `json:"file_ids"`
|
||||
FileIDsCamel json.RawMessage `json:"fileIds"`
|
||||
FileIDsFlat json.RawMessage `json:"fileids"`
|
||||
FileID json.RawMessage `json:"file_id"`
|
||||
}
|
||||
|
||||
var raw rawChatRequest
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Text = raw.Text
|
||||
r.SessionID = firstNonEmpty(raw.SessionID, raw.SessionIDCamel)
|
||||
r.UserID = firstNonEmpty(raw.UserID, raw.UserIDCamel)
|
||||
|
||||
rawIDs := firstNonEmptyRaw(raw.FileIDs, raw.FileIDsCamel, raw.FileIDsFlat, raw.FileID)
|
||||
ids, err := decodeStringList(rawIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.FileIDs = ids
|
||||
return nil
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
@@ -158,9 +191,10 @@ func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
userID := b.resolveID(req.UserID, "user")
|
||||
|
||||
reply, err := b.chatHandler(r.Context(), IncomingMessage{
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
ChatID: sessionID,
|
||||
UserID: userID,
|
||||
Text: req.Text,
|
||||
FileIDs: req.FileIDs,
|
||||
})
|
||||
if err != nil {
|
||||
if b.log != nil {
|
||||
@@ -176,6 +210,65 @@ func (b *Bot) handleChat(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func decodeStringList(raw json.RawMessage) ([]string, error) {
|
||||
if len(raw) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var list []string
|
||||
if err := json.Unmarshal(raw, &list); err == nil {
|
||||
return nonEmptyIDs(list), nil
|
||||
}
|
||||
|
||||
var single string
|
||||
if err := json.Unmarshal(raw, &single); err == nil {
|
||||
if strings.TrimSpace(single) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return nonEmptyIDs(strings.Split(single, ",")), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid file ids format")
|
||||
}
|
||||
|
||||
func firstNonEmptyRaw(vals ...json.RawMessage) json.RawMessage {
|
||||
for _, v := range vals {
|
||||
if len(v) > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func firstNonEmpty(vals ...string) string {
|
||||
for _, v := range vals {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func nonEmptyIDs(ids []string) []string {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(ids))
|
||||
seen := map[string]struct{}{}
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (b *Bot) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, errorResponse{Error: "method not allowed"})
|
||||
|
||||
@@ -51,6 +51,66 @@ func TestHandleChatSuccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatWithFileIDs(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if len(msg.FileIDs) != 2 || msg.FileIDs[0] != "file_a" || msg.FileIDs[1] != "file_b" {
|
||||
t.Fatalf("unexpected file ids: %+v", msg.FileIDs)
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(`{"text":"hello","session_id":"s1","user_id":"u1","file_ids":["file_a","file_b"]}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatWithFileIDsAliases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{name: "camel array", body: `{"text":"hello","sessionId":"s1","userId":"u1","fileIds":["file_a","file_b"]}`},
|
||||
{name: "flat array", body: `{"text":"hello","session_id":"s1","user_id":"u1","fileids":["file_a","file_b"]}`},
|
||||
{name: "single key", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_id":"file_a"}`},
|
||||
{name: "csv string", body: `{"text":"hello","session_id":"s1","user_id":"u1","file_ids":"file_a, file_b"}`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, msg IncomingMessage) (string, error) {
|
||||
if msg.ChatID != "s1" || msg.UserID != "u1" || msg.Text != "hello" {
|
||||
t.Fatalf("unexpected message: %+v", msg)
|
||||
}
|
||||
if len(msg.FileIDs) == 0 {
|
||||
t.Fatalf("expected file ids from alias payload, got empty")
|
||||
}
|
||||
return "ok", nil
|
||||
}
|
||||
|
||||
body := strings.NewReader(tt.body)
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/chat", body)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
b.handleChat(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleChatMissingText(t *testing.T) {
|
||||
b := newTestBot(t, 1024*1024)
|
||||
b.chatHandler = func(_ context.Context, _ IncomingMessage) (string, error) { return "", nil }
|
||||
|
||||
Reference in New Issue
Block a user