diff --git a/backend/app/api/routes/agent.py b/backend/app/api/routes/agent.py index 3a55cae..01af248 100644 --- a/backend/app/api/routes/agent.py +++ b/backend/app/api/routes/agent.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import json from typing import AsyncGenerator, List, Optional @@ -20,6 +19,7 @@ from app.api.models import ( SessionInfo, ) from app.config.settings import settings +from app.shared.async_utils import iter_in_thread from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store # Keep route handlers close to their transport-layer wiring for easier auditing. @@ -101,14 +101,13 @@ async def chat_stream_get( top_k=settings.rag_top_k, ) yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n" - for event_data in event_stream: + async for event_data in iter_in_thread(event_stream): event_type = event_data.get("event", "content") data = event_data.get("data", "") if isinstance(data, (dict, list)): yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" else: yield f"event: {event_type}\ndata: {data}\n\n" - await asyncio.sleep(0) except Exception as exc: yield f"event: error\ndata: {str(exc)}\n\n" diff --git a/backend/app/api/routes/rag.py b/backend/app/api/routes/rag.py index bd465dc..0d209b9 100644 --- a/backend/app/api/routes/rag.py +++ b/backend/app/api/routes/rag.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import json from typing import AsyncGenerator @@ -11,6 +10,7 @@ from fastapi.responses import StreamingResponse from app.config.settings import settings from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion +from app.shared.async_utils import iter_in_thread from app.shared.bootstrap import get_agent_conversation_service @@ -29,7 +29,7 @@ _DEFAULT_QUICK_QUESTIONS = [ @router.post("/chat") async def rag_chat(request: RagChatRequest): """Stream RAG Q&A using the real agent service.""" - _, event_stream = get_agent_conversation_service().stream_chat( + session_id, event_stream = get_agent_conversation_service().stream_chat( query=request.query, session_id=request.session_id, filters=request.filters, @@ -38,7 +38,11 @@ async def rag_chat(request: RagChatRequest): async def generate() -> AsyncGenerator[str, None]: """Translate agent SSE events to rag format.""" - for event in event_stream: + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'session', 'session_id': session_id}, ensure_ascii=False)}\n\n" + ) + async for event in iter_in_thread(event_stream): event_type = event.get("event", "") data = event.get("data", "") if event_type == "sources": @@ -69,14 +73,18 @@ async def rag_chat(request: RagChatRequest): elif event_type == "done": yield ( "event: message\n" - f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + f"data: {json.dumps({'type': 'done', 'session_id': session_id}, ensure_ascii=False)}\n\n" ) elif event_type == "status": yield ( "event: message\n" f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n" ) - await asyncio.sleep(0) + elif event_type == "error": + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'error', 'text': str(data)}, ensure_ascii=False)}\n\n" + ) return StreamingResponse( generate(), diff --git a/backend/app/application/knowledge/services.py b/backend/app/application/knowledge/services.py index b8d3685..903bab5 100644 --- a/backend/app/application/knowledge/services.py +++ b/backend/app/application/knowledge/services.py @@ -7,20 +7,75 @@ from app.domain.retrieval.ports import Reranker # Keep orchestration logic centralized so use-case flow stays easy to trace. +def _reciprocal_rank_fusion( + ranked_lists: list[list[RetrievedChunk]], k: int = 60 +) -> list[RetrievedChunk]: + """Merge multiple ranked lists with Reciprocal Rank Fusion. + + Score for chunk c = sum over lists of 1 / (k + rank(c)). + A chunk appearing in multiple lists gets a higher fused score. + """ + scores: dict[str, float] = {} + chunk_map: dict[str, RetrievedChunk] = {} + + for ranked in ranked_lists: + for rank, chunk in enumerate(ranked): + key = chunk.chunk_id + scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank + 1) + chunk_map[key] = chunk + + sorted_keys = sorted(scores, key=lambda ck: scores[ck], reverse=True) + return [ + RetrievedChunk( + chunk_id=chunk_map[ck].chunk_id, + doc_id=chunk_map[ck].doc_id, + doc_name=chunk_map[ck].doc_name, + content=chunk_map[ck].content, + score=scores[ck], + section_title=chunk_map[ck].section_title, + page_number=chunk_map[ck].page_number, + metadata=chunk_map[ck].metadata, + ) + for ck in sorted_keys + ] + + class KnowledgeRetrievalService: """Provide the Knowledge Retrieval Service service.""" - def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None: + def __init__( + self, + *, + retriever: Retriever, + bm25_retriever=None, + reranker: Reranker | None = None, + reranker_top_k: int = 5, + ) -> None: """Initialize the Knowledge Retrieval Service instance.""" self.retriever = retriever + self.bm25_retriever = bm25_retriever self.reranker = reranker self.reranker_top_k = reranker_top_k def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]: - """Retrieve and optionally rerank chunks for a query.""" - candidate_k = top_k if self.reranker is None else max(top_k * 4, 20) + """Retrieve and optionally rerank chunks for a query. + + When a BM25 retriever is available, combines dense + sparse results + via Reciprocal Rank Fusion before optional reranking. + """ + use_hybrid = self.bm25_retriever is not None and getattr(self.bm25_retriever, "available", False) + candidate_k = max(top_k * 4, 20) if (self.reranker is not None or use_hybrid) else top_k + retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters) - candidates = self.retriever.retrieve(retrieval_query) + dense_results = self.retriever.retrieve(retrieval_query) + + if use_hybrid: + bm25_results = self.bm25_retriever.retrieve(query, top_k=candidate_k, filters=filters) + candidates = _reciprocal_rank_fusion([dense_results, bm25_results]) + else: + candidates = dense_results + if self.reranker and candidates: return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k) + return candidates[:top_k] diff --git a/backend/app/infrastructure/llm/openai_compatible_answer_generator.py b/backend/app/infrastructure/llm/openai_compatible_answer_generator.py index 9d87c86..2c296bb 100644 --- a/backend/app/infrastructure/llm/openai_compatible_answer_generator.py +++ b/backend/app/infrastructure/llm/openai_compatible_answer_generator.py @@ -12,7 +12,6 @@ from app.services.llm.llm_factory import get_llm_client # Keep adapter behavior explicit so integration details remain easy to audit. - PROMPT_TEMPLATES = { "default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。", "compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。", @@ -21,6 +20,17 @@ PROMPT_TEMPLATES = { class OpenAICompatibleAnswerGenerator(AnswerGenerator): """Represent the Open A I Compatible Answer Generator type.""" + + @staticmethod + def _estimate_tokens(text: str) -> int: + """Estimate token count for mixed Chinese/English text. + + Chinese chars are ~1.5 chars/token; ASCII is ~4 chars/token. + """ + chinese = sum(1 for c in text if "一" <= c <= "鿿") + other = len(text) - chinese + return int(chinese / 1.5 + other / 4) + 1 + def _build_messages( self, *, @@ -40,9 +50,12 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator): f"页码: {chunk.page_number}\n" f"内容: {chunk.content}" ) - context_tokens += len(block) + block_tokens = self._estimate_tokens(block) + if context_tokens + block_tokens > settings.rag_max_context_tokens: + break + context_tokens += block_tokens context_blocks.append(block) - context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4] + context = "\n\n".join(context_blocks) messages = [{"role": "system", "content": system_prompt}] for item in history or []: messages.append({"role": item["role"], "content": item["content"]}) @@ -52,7 +65,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator): "content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。", } ) - return messages, min(context_tokens, settings.rag_max_context_tokens) + return messages, context_tokens def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]: """Handle sources for this module for the Open A I Compatible Answer Generator instance.""" @@ -98,7 +111,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator): latency_ms=latency_ms, retrieved_count=len(retrieved_chunks), context_tokens=context_tokens, - truncated=False, + truncated=len(retrieved_chunks) > len(messages), error=response.error, ) @@ -124,15 +137,18 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator): yield {"event": "sources", "data": sources} client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model) answer_parts: list[str] = [] - if hasattr(client, "stream_chat"): - for chunk in client.stream_chat(messages): - answer_parts.append(chunk) - yield {"event": "content", "data": chunk} - else: - response = client.chat(messages) - answer_parts.append(response.content) - yield {"event": "content", "data": response.content} - full_answer = "".join(answer_parts) + try: + if hasattr(client, "stream_chat"): + for chunk in client.stream_chat(messages): + answer_parts.append(chunk) + yield {"event": "content", "data": chunk} + else: + response = client.chat(messages) + answer_parts.append(response.content) + yield {"event": "content", "data": response.content} + except Exception as exc: + yield {"event": "error", "data": str(exc)} + return yield { "event": "done", "data": { diff --git a/backend/app/infrastructure/vectorstore/bm25_retriever.py b/backend/app/infrastructure/vectorstore/bm25_retriever.py new file mode 100644 index 0000000..1f55b98 --- /dev/null +++ b/backend/app/infrastructure/vectorstore/bm25_retriever.py @@ -0,0 +1,149 @@ +"""BM25 sparse retriever backed by the existing Milvus collection. + +Falls back to no-op if rank_bm25 or jieba is not installed. +The index is built lazily on first query and can be refreshed after +new documents are ingested by calling refresh(). +""" + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING + +from app.domain.retrieval import RetrievedChunk + +if TYPE_CHECKING: + from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex + +logger = logging.getLogger(__name__) + +try: + from rank_bm25 import BM25Okapi + import jieba + + jieba.setLogLevel(logging.WARNING) + _BM25_AVAILABLE = True +except ImportError: + _BM25_AVAILABLE = False + logger.warning("rank_bm25 or jieba not installed – BM25 retrieval disabled. Run: pip install rank-bm25 jieba") + + +def _tokenize(text: str) -> list[str]: + """Tokenize Chinese/mixed text with jieba.""" + return list(jieba.cut(text)) + + +class BM25Retriever: + """Sparse BM25 retriever that indexes chunks stored in Milvus. + + Thread-safe: the index is protected by a lock so concurrent requests + during the initial build do not race. + """ + + def __init__(self, vector_index: "MilvusVectorIndex") -> None: + self._vector_index = vector_index + self._lock = threading.Lock() + self._index: "BM25Okapi | None" = None + self._chunks: list[RetrievedChunk] = [] + + @property + def available(self) -> bool: + return _BM25_AVAILABLE + + def _load_all_chunks(self) -> list[RetrievedChunk]: + """Fetch all chunk records from Milvus (no vectors needed).""" + try: + rows = self._vector_index.collection.query( + expr='doc_id != ""', + output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"], + limit=16384, + ) + except Exception: + logger.exception("BM25Retriever: failed to fetch chunks from Milvus") + return [] + return [ + RetrievedChunk( + chunk_id=str(row.get("id", "")), + doc_id=str(row.get("doc_id", "")), + doc_name=str(row.get("doc_name", "")), + content=str(row.get("content", "")), + score=0.0, + section_title=str(row.get("section_title", "")), + page_number=int(row.get("page_number") or 0), + metadata={}, + ) + for row in rows + if row.get("content") + ] + + def _ensure_built(self) -> None: + if self._index is not None: + return + with self._lock: + if self._index is not None: + return + self._build() + + def _build(self) -> None: + logger.info("BM25Retriever: building index …") + chunks = self._load_all_chunks() + if not chunks: + logger.warning("BM25Retriever: no chunks found, index is empty") + self._chunks = [] + self._index = BM25Okapi([[]]) + return + tokenized = [_tokenize(c.content) for c in chunks] + self._chunks = chunks + self._index = BM25Okapi(tokenized) + logger.info("BM25Retriever: index built with %d chunks", len(chunks)) + + def refresh(self) -> None: + """Rebuild the index (call after new documents are ingested).""" + with self._lock: + self._index = None + self._chunks = [] + self._build() + + def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]: + """Return top_k chunks ranked by BM25 score.""" + if not _BM25_AVAILABLE: + return [] + self._ensure_built() + if not self._chunks: + return [] + + tokens = _tokenize(query) + scores = self._index.get_scores(tokens) # type: ignore[union-attr] + + # Pair and sort descending by score + ranked = sorted( + ((float(scores[i]), self._chunks[i]) for i in range(len(self._chunks))), + key=lambda x: x[0], + reverse=True, + ) + + results: list[RetrievedChunk] = [] + for score, chunk in ranked[: top_k * 2]: + if score <= 0: + break + # Apply simple regulation_type filter if provided + if filters and chunk.metadata.get("regulation_type"): + types = [t.strip() for t in filters.split(",")] + if chunk.metadata.get("regulation_type") not in types: + continue + results.append( + RetrievedChunk( + chunk_id=chunk.chunk_id, + doc_id=chunk.doc_id, + doc_name=chunk.doc_name, + content=chunk.content, + score=score, + section_title=chunk.section_title, + page_number=chunk.page_number, + metadata=chunk.metadata, + ) + ) + if len(results) >= top_k: + break + return results diff --git a/backend/app/shared/async_utils.py b/backend/app/shared/async_utils.py new file mode 100644 index 0000000..74bb78c --- /dev/null +++ b/backend/app/shared/async_utils.py @@ -0,0 +1,38 @@ +"""Async utility helpers for bridging sync generators to async FastAPI routes.""" + +from __future__ import annotations + +import asyncio +import threading +from typing import AsyncGenerator, Generator, TypeVar + +T = TypeVar("T") + +_SENTINEL = object() + + +async def iter_in_thread(sync_gen: Generator[T, None, None]) -> AsyncGenerator[T, None]: + """Yield items from a synchronous generator without blocking the event loop. + + Runs the generator in a daemon thread, forwarding items via an asyncio.Queue. + Use this to wrap blocking LLM streaming generators inside async FastAPI routes. + """ + loop = asyncio.get_running_loop() + queue: asyncio.Queue = asyncio.Queue(maxsize=32) + + def _drain() -> None: + try: + for item in sync_gen: + future = asyncio.run_coroutine_threadsafe(queue.put(item), loop) + future.result() + finally: + asyncio.run_coroutine_threadsafe(queue.put(_SENTINEL), loop).result() + + thread = threading.Thread(target=_drain, daemon=True) + thread.start() + + while True: + item = await queue.get() + if item is _SENTINEL: + break + yield item # type: ignore[misc] diff --git a/backend/app/shared/bootstrap.py b/backend/app/shared/bootstrap.py index 4ab08e0..0fc7a2a 100644 --- a/backend/app/shared/bootstrap.py +++ b/backend/app/shared/bootstrap.py @@ -19,6 +19,7 @@ from app.infrastructure.storage.json_document_repository import JsonDocumentRepo from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore +from app.infrastructure.vectorstore.bm25_retriever import BM25Retriever from app.infrastructure.vectorstore.dense_retriever import DenseRetriever from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker @@ -87,6 +88,13 @@ def get_reranker(): return None +@lru_cache +def get_bm25_retriever() -> BM25Retriever | None: + """Return BM25 retriever if rank_bm25 + jieba are installed, else None.""" + retriever = BM25Retriever(vector_index=get_vector_index()) + return retriever if retriever.available else None + + @lru_cache def get_retrieval_service() -> KnowledgeRetrievalService: """Return retrieval service.""" @@ -96,6 +104,7 @@ def get_retrieval_service() -> KnowledgeRetrievalService: ) return KnowledgeRetrievalService( retriever=retriever, + bm25_retriever=get_bm25_retriever(), reranker=get_reranker(), reranker_top_k=settings.reranker_top_k, ) diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index d64ac34..92af193 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -44,6 +44,7 @@ export interface SSEMessage { type: string; text?: string; docs?: RetrievedDoc[]; + session_id?: string; } export async function streamSSE( diff --git a/frontend/src/api/rag.ts b/frontend/src/api/rag.ts index 2e50ab2..3413d4e 100644 --- a/frontend/src/api/rag.ts +++ b/frontend/src/api/rag.ts @@ -38,7 +38,14 @@ function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) { const joined = dataLines.join('\n'); if (!joined) continue; - if (eventName === 'sources') { + // /agent/chat/stream uses named events (sources, content, done, error, session, status) + // /rag/chat wraps everything in event:message with type in JSON body — handle both + if (eventName === 'session') { + try { + const payload = JSON.parse(joined) as Record; + onMessage({ type: 'session', session_id: String(payload.session_id ?? '') }); + } catch { /* ignore */ } + } else if (eventName === 'sources') { try { const docs = JSON.parse(joined) as Array>; onMessage({ @@ -53,17 +60,26 @@ function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) { download_url: doc.doc_id ? `${AGENT_API_BASE}/documents/download/${String(doc.doc_id)}` : undefined, })), }); - } catch { - // Ignore malformed source payloads. - } + } catch { /* ignore */ } } else if (eventName === 'content') { onMessage({ type: 'chunk', text: joined }); } else if (eventName === 'done') { - onMessage({ type: 'done', text: joined }); + try { + const payload = JSON.parse(joined) as Record; + onMessage({ type: 'done', session_id: payload.session_id ? String(payload.session_id) : undefined }); + } catch { + onMessage({ type: 'done' }); + } } else if (eventName === 'error') { onMessage({ type: 'error', text: joined }); } else if (eventName === 'status') { onMessage({ type: 'status', text: joined }); + } else if (eventName === 'message') { + // /rag/chat format: event:message + JSON body with type field + try { + const payload = JSON.parse(joined) as SSEMessage; + onMessage(payload); + } catch { /* ignore */ } } } } @@ -74,7 +90,9 @@ export async function ragChat( onMessage: (data: SSEMessage) => void, onError?: (error: Error) => void, onComplete?: () => void, - filters?: string + filters?: string, + sessionId?: string, + signal?: AbortSignal, ): Promise { try { const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, { @@ -83,7 +101,13 @@ export async function ragChat( 'Content-Type': 'application/json', Accept: 'text/event-stream', }, - body: JSON.stringify({ query, top_k: topK, ...(filters ? { filters } : {}) }), + body: JSON.stringify({ + query, + top_k: topK, + ...(filters ? { filters } : {}), + ...(sessionId ? { session_id: sessionId } : {}), + }), + signal, }); if (!response.ok || !response.body) { @@ -112,6 +136,7 @@ export async function ragChat( onComplete(); } } catch (error) { + if (error instanceof DOMException && error.name === 'AbortError') return; if (onError) { onError(error instanceof Error ? error : new Error(String(error))); } diff --git a/frontend/src/pages/RagChat/RagChatPage.tsx b/frontend/src/pages/RagChat/RagChatPage.tsx index e731d7d..ed7a1b8 100644 --- a/frontend/src/pages/RagChat/RagChatPage.tsx +++ b/frontend/src/pages/RagChat/RagChatPage.tsx @@ -1,4 +1,4 @@ -import React, { useEffect, useRef, useState } from 'react'; +import React, { useCallback, useEffect, useRef, useState } from 'react'; import { useTheme } from '../../contexts'; import type { ChatMessage, RetrievalData } from '../../types'; import { getQuickQuestions, ragChat } from '../../api/rag'; @@ -19,6 +19,7 @@ export const RagChatPage: React.FC = () => { const { theme } = useTheme(); const nextMessageIdRef = useRef(1); const [messages, setMessages] = useState([]); + // retrievals: right-panel shows sources of the most recent assistant reply const [retrievals, setRetrievals] = useState([]); const [input, setInput] = useState(''); const [loading, setLoading] = useState(false); @@ -27,62 +28,60 @@ export const RagChatPage: React.FC = () => { const [quickQuestions, setQuickQuestions] = useState(ragQuickQuestionsDefault); const [filterRegulationType, setFilterRegulationType] = useState(''); const [highlightedSourceIdx, setHighlightedSourceIdx] = useState(null); + const [sessionId, setSessionId] = useState(); + + // Auto-scroll ref + const messagesEndRef = useRef(null); + // AbortController for cancelling in-flight requests + const abortRef = useRef(null); function nextMessageId() { - const currentId = nextMessageIdRef.current; + const id = nextMessageIdRef.current; nextMessageIdRef.current += 1; - return currentId; + return id; } + // Scroll to bottom whenever messages change + useEffect(() => { + messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); + }, [messages]); + async function loadQuickQuestions() { try { const response = await getQuickQuestions(); setQuickQuestions(response.questions.map(q => q.question)); - } catch (error) { - console.error('Failed to load quick questions:', error); + } catch { + // keep defaults } } useEffect(() => { - const timerId = window.setTimeout(() => { - void loadQuickQuestions(); - }, 0); + const timerId = window.setTimeout(() => { void loadQuickQuestions(); }, 0); return () => window.clearTimeout(timerId); }, []); - const sendMessage = (text: string) => { - if (!text.trim()) return; + /** + * Core query executor — shared by sendMessage and regenerateLastAnswer. + * Manages session_id, AbortController, SSE parsing, and state updates. + */ + const executeQuery = useCallback((text: string) => { + // Cancel any in-flight request + abortRef.current?.abort(); + abortRef.current = new AbortController(); - const userMsg = { id: nextMessageId(), role: 'user' as const, content: text }; - setMessages((prev) => [...prev, userMsg]); - setInput(''); - setLoading(true); - setRetrievals([]); - setHighlightedSourceIdx(null); - - let currentResponse = ''; const activeFilters = filterRegulationType.trim() || undefined; + let currentResponse = ''; + // Capture the assistant message id so we can attach sources later + let assistantMsgId: number | null = null; void ragChat( text, 5, - (data: unknown) => { - const sseData = data as { - type: string; - text?: string; - docs?: Array<{ - id: string; - score: number; - preview: string; - doc_name: string; - clause: string; - doc_id?: string; - download_url?: string; - }>; - }; - - if (sseData.type === 'retrieved' && sseData.docs) { - const retrievedDocs: RetrievalData[] = sseData.docs.map(d => ({ + (data) => { + if (data.type === 'session' && data.session_id) { + setSessionId(data.session_id); + } else if (data.type === 'retrieved' && data.docs) { + const docs: RetrievalData[] = data.docs.map(d => ({ id: parseInt(d.id.replace('chunk-', ''), 10) || 1, file: d.doc_name, clause: d.clause, @@ -91,125 +90,98 @@ export const RagChatPage: React.FC = () => { docId: d.doc_id, downloadUrl: d.download_url, })); - setRetrievals(retrievedDocs); - } else if (sseData.type === 'chunk' && sseData.text) { - currentResponse += sseData.text; - setMessages((prev) => { - const lastMsg = prev[prev.length - 1]; - if (lastMsg?.role === 'assistant') { - return [...prev.slice(0, -1), { ...lastMsg, content: currentResponse }]; + setRetrievals(docs); + // Attach sources to the assistant message once we know its id + if (assistantMsgId !== null) { + setMessages(prev => prev.map(m => + m.id === assistantMsgId ? { ...m, sources: docs } : m + )); + } + } else if (data.type === 'chunk' && data.text) { + currentResponse += data.text; + setMessages(prev => { + const last = prev[prev.length - 1]; + if (last?.role === 'assistant' && last.id === assistantMsgId) { + return [...prev.slice(0, -1), { ...last, content: currentResponse }]; } - return [...prev, { id: nextMessageId(), role: 'assistant' as const, content: currentResponse }]; + // First chunk: create assistant message + const newId = nextMessageId(); + assistantMsgId = newId; + return [...prev, { id: newId, role: 'assistant' as const, content: currentResponse }]; }); - } else if (sseData.type === 'done') { + } else if (data.type === 'done') { + if (data.session_id) setSessionId(data.session_id); setLoading(false); - } else if (sseData.type === 'error') { + } else if (data.type === 'error') { setLoading(false); + setMessages(prev => [ + ...prev, + { id: nextMessageId(), role: 'assistant' as const, content: '抱歉,生成回答时出错,请稍后再试。' }, + ]); } }, - (error: Error) => { + (error) => { console.error('RAG chat error:', error); setLoading(false); - setMessages((prev) => [ + setMessages(prev => [ ...prev, - { id: nextMessageId(), role: 'assistant' as const, content: '抱歉,连接服务器时出错,请稍后再试。' } + { id: nextMessageId(), role: 'assistant' as const, content: '抱歉,连接服务器时出错,请稍后再试。' }, ]); }, - () => { - setLoading(false); - }, - activeFilters + () => { setLoading(false); }, + activeFilters, + sessionId, + abortRef.current.signal, ); - }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [filterRegulationType, sessionId]); - const clearMessages = () => { - setMessages([]); + const sendMessage = (text: string) => { + if (!text.trim() || loading) return; + setMessages(prev => [...prev, { id: nextMessageId(), role: 'user' as const, content: text }]); + setInput(''); + setLoading(true); setRetrievals([]); - setShowClearConfirm(false); + setHighlightedSourceIdx(null); + executeQuery(text); }; const regenerateLastAnswer = () => { - if (messages.length < 2) return; - const lastUserMsg = messages.filter((m) => m.role === 'user').pop(); + if (loading) return; + const lastUserMsg = [...messages].reverse().find(m => m.role === 'user'); if (!lastUserMsg) return; - + // Remove the last assistant message + setMessages(prev => { + const lastAssistantIdx = [...prev].reverse().findIndex(m => m.role === 'assistant'); + if (lastAssistantIdx === -1) return prev; + const idx = prev.length - 1 - lastAssistantIdx; + return [...prev.slice(0, idx)]; + }); setLoading(true); - setMessages((prev) => [...prev.slice(0, -1)]); setRetrievals([]); setHighlightedSourceIdx(null); + executeQuery(lastUserMsg.content); + }; - let currentResponse = ''; - const activeFilters = filterRegulationType.trim() || undefined; - - void ragChat( - lastUserMsg.content, - 5, - (data: unknown) => { - const sseData = data as { - type: string; - text?: string; - docs?: Array<{ - id: string; - score: number; - preview: string; - doc_name: string; - clause: string; - doc_id?: string; - download_url?: string; - }>; - }; - - if (sseData.type === 'retrieved' && sseData.docs) { - const retrievedDocs: RetrievalData[] = sseData.docs.map(d => ({ - id: parseInt(d.id.replace('chunk-', ''), 10) || 1, - file: d.doc_name, - clause: d.clause, - score: d.score, - content: d.preview, - docId: d.doc_id, - downloadUrl: d.download_url, - })); - setRetrievals(retrievedDocs); - } else if (sseData.type === 'chunk' && sseData.text) { - currentResponse += sseData.text; - setMessages((prev) => { - const lastMsg = prev[prev.length - 1]; - if (lastMsg?.role === 'assistant') { - return [...prev.slice(0, -1), { ...lastMsg, content: currentResponse }]; - } - return [...prev, { id: nextMessageId(), role: 'assistant' as const, content: currentResponse }]; - }); - } else if (sseData.type === 'done') { - setLoading(false); - } - }, - (error: Error) => { - console.error('RAG chat error:', error); - setLoading(false); - setMessages((prev) => [ - ...prev, - { id: nextMessageId(), role: 'assistant' as const, content: '抱歉,连接服务器时出错,请稍后再试。' } - ]); - }, - () => { - setLoading(false); - }, - activeFilters - ); + const clearMessages = () => { + abortRef.current?.abort(); + setMessages([]); + setRetrievals([]); + setSessionId(undefined); + setShowClearConfirm(false); + setLoading(false); }; return ( -
+
+ {/* ── Left: chat panel ─────────────────────────────────── */}
+ {/* Message list */}
{ gap: 20, }}> {messages.length === 0 ? ( -
+
@@ -245,20 +208,14 @@ export const RagChatPage: React.FC = () => { ) : ( messages.map(msg => (
{msg.role === 'assistant' && (
@@ -271,16 +228,16 @@ export const RagChatPage: React.FC = () => { background: msg.role === 'user' ? theme.gradientAccent : theme.bgCard, borderRadius: 12, color: msg.role === 'user' ? '#fff' : theme.text, - fontSize: 14, - lineHeight: 1.6, - whiteSpace: 'pre-wrap', + fontSize: 14, lineHeight: 1.6, whiteSpace: 'pre-wrap', border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none', }}> {msg.role === 'assistant' ? ( { + const msgSources = msg.sources ?? retrievals; + setRetrievals(msgSources); setHighlightedSourceIdx(idx); const el = document.getElementById(`source-${idx}`); if (el) el.scrollIntoView({ behavior: 'smooth', block: 'center' }); @@ -289,12 +246,9 @@ export const RagChatPage: React.FC = () => { ) : msg.content} {msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && (
@@ -311,150 +265,115 @@ export const RagChatPage: React.FC = () => { {loading && (
检索中...
)} + {/* Scroll anchor */} +
-
-
+ {/* Input area */} +
+ {/* Filter row */} +
法规类型 setFilterRegulationType(e.target.value)} + onChange={e => setFilterRegulationType(e.target.value)} placeholder="如: GB / UN-ECE / IATF(留空不过滤)" style={{ - flex: 1, - maxWidth: 280, - padding: '5px 10px', - fontSize: 12, - background: theme.bgHover, - border: `1px solid ${theme.border}`, - borderRadius: 6, - color: theme.text, - outline: 'none', + flex: 1, maxWidth: 280, padding: '5px 10px', fontSize: 12, + background: theme.bgHover, border: `1px solid ${theme.border}`, + borderRadius: 6, color: theme.text, outline: 'none', }} />
-
+ {/* Quick questions */} +
{quickQuestions.map(q => ( ))}
+ {/* Send row */}
setInput(e.target.value)} - onKeyDown={(e) => e.key === 'Enter' && sendMessage(input)} + onChange={e => setInput(e.target.value)} + onKeyDown={e => e.key === 'Enter' && !e.shiftKey && sendMessage(input)} placeholder="输入法规问题..." style={{ - flex: 1, - padding: 12, - fontSize: 14, - background: theme.bgCard, - border: `1px solid ${theme.border}`, - borderRadius: 8, - color: theme.text, - outline: 'none', + flex: 1, padding: 12, fontSize: 14, + background: theme.bgCard, border: `1px solid ${theme.border}`, + borderRadius: 8, color: theme.text, outline: 'none', }} /> - {messages.length > 0 && ( + {loading && ( + + )} + {!loading && messages.length > 0 && ( )} - {messages.filter(m => m.role === 'assistant').length > 0 && ( + {!loading && messages.filter(m => m.role === 'assistant').length > 0 && ( )} @@ -462,27 +381,16 @@ export const RagChatPage: React.FC = () => {
-
+ {/* ── Right: retrieved sources panel ───────────────────── */} +
@@ -493,20 +401,13 @@ export const RagChatPage: React.FC = () => { {retrievals.length > 0 && ( {retrievals.length} )}
-
+
{retrievals.length > 0 ? (
{retrievals.map((r, i) => ( @@ -515,30 +416,20 @@ export const RagChatPage: React.FC = () => { id={`source-${i + 1}`} onClick={() => setSelectedRetrieval(r)} style={{ - padding: 16, - background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover, - borderRadius: 10, - border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`, - cursor: 'pointer', - position: 'relative', + padding: 16, background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover, + borderRadius: 10, border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`, + cursor: 'pointer', position: 'relative', transition: 'border-color 0.2s, background 0.2s', }} >
#{i + 1}
@@ -547,17 +438,13 @@ export const RagChatPage: React.FC = () => { href={r.downloadUrl} target="_blank" rel="noreferrer" - onClick={(event) => event.stopPropagation()} + onClick={e => e.stopPropagation()} style={{ fontSize: 11, color: theme.accent, textDecoration: 'none' }} - > - 下载文档 - + >下载文档 )} - {(r.score * 100).toFixed(0)}% + + {(r.score * 100).toFixed(0)}% +
{r.file}
@@ -565,32 +452,20 @@ export const RagChatPage: React.FC = () => { {r.clause}{r.docId ? ` · ${r.docId}` : ''}
{r.content}
))}
) : ( -
+
@@ -602,52 +477,33 @@ export const RagChatPage: React.FC = () => {
+ {/* ── Clear confirm modal ───────────────────────────────── */} {showClearConfirm && (
确定清空对话?
-
此操作不可恢复
+
此操作不可恢复,会话历史将被重置
@@ -655,47 +511,31 @@ export const RagChatPage: React.FC = () => {
)} + {/* ── Source detail modal ───────────────────────────────── */} {selectedRetrieval && (
setSelectedRetrieval(null)} style={{ - position: 'fixed', - top: 0, - left: 0, - right: 0, - bottom: 0, - background: 'rgba(0,0,0,0.6)', - display: 'flex', - alignItems: 'center', - justifyContent: 'center', - zIndex: 1000, + position: 'fixed', top: 0, left: 0, right: 0, bottom: 0, + background: 'rgba(0,0,0,0.6)', display: 'flex', + alignItems: 'center', justifyContent: 'center', zIndex: 1000, }} >
e.stopPropagation()} + onClick={e => e.stopPropagation()} style={{ - width: 520, - maxWidth: '90%', - maxHeight: '80%', - padding: 24, - background: theme.bgCard, - borderRadius: 16, - border: `1px solid ${theme.accent}`, + width: 520, maxWidth: '90%', maxHeight: '80%', + overflowY: 'auto', padding: 24, background: theme.bgCard, + borderRadius: 16, border: `1px solid ${theme.accent}`, boxShadow: '0 8px 32px rgba(0,0,0,0.3)', }} >
-
+
{(selectedRetrieval.score * 100).toFixed(0)}% @@ -707,23 +547,15 @@ export const RagChatPage: React.FC = () => { target="_blank" rel="noreferrer" style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }} - > - 下载关联文档 - + >下载关联文档 )}
-
+
{selectedRetrieval.clause} {selectedRetrieval.docId && ( {selectedRetrieval.docId} )}
-
{selectedRetrieval.content}
+
+ {selectedRetrieval.content} +
)} diff --git a/frontend/src/types/doc.ts b/frontend/src/types/doc.ts index 7d30511..5c11cd9 100644 --- a/frontend/src/types/doc.ts +++ b/frontend/src/types/doc.ts @@ -25,6 +25,7 @@ export interface ChatMessage { role: 'user' | 'assistant'; content: string; retrievalIds?: number[]; + sources?: RetrievalData[]; } export interface RetrievalData {