Fix 法规对话
This commit is contained in:
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator, List, Optional
|
from typing import AsyncGenerator, List, Optional
|
||||||
|
|
||||||
@@ -20,6 +19,7 @@ from app.api.models import (
|
|||||||
SessionInfo,
|
SessionInfo,
|
||||||
)
|
)
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.shared.async_utils import iter_in_thread
|
||||||
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
|
from app.shared.bootstrap import get_agent_conversation_service, get_conversation_store
|
||||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||||
|
|
||||||
@@ -101,14 +101,13 @@ async def chat_stream_get(
|
|||||||
top_k=settings.rag_top_k,
|
top_k=settings.rag_top_k,
|
||||||
)
|
)
|
||||||
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
|
yield f"event: session\ndata: {json.dumps({'session_id': session_id_})}\n\n"
|
||||||
for event_data in event_stream:
|
async for event_data in iter_in_thread(event_stream):
|
||||||
event_type = event_data.get("event", "content")
|
event_type = event_data.get("event", "content")
|
||||||
data = event_data.get("data", "")
|
data = event_data.get("data", "")
|
||||||
if isinstance(data, (dict, list)):
|
if isinstance(data, (dict, list)):
|
||||||
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
yield f"event: {event_type}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
else:
|
else:
|
||||||
yield f"event: {event_type}\ndata: {data}\n\n"
|
yield f"event: {event_type}\ndata: {data}\n\n"
|
||||||
await asyncio.sleep(0)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
yield f"event: error\ndata: {str(exc)}\n\n"
|
yield f"event: error\ndata: {str(exc)}\n\n"
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
@@ -11,6 +10,7 @@ from fastapi.responses import StreamingResponse
|
|||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||||
|
from app.shared.async_utils import iter_in_thread
|
||||||
from app.shared.bootstrap import get_agent_conversation_service
|
from app.shared.bootstrap import get_agent_conversation_service
|
||||||
|
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ _DEFAULT_QUICK_QUESTIONS = [
|
|||||||
@router.post("/chat")
|
@router.post("/chat")
|
||||||
async def rag_chat(request: RagChatRequest):
|
async def rag_chat(request: RagChatRequest):
|
||||||
"""Stream RAG Q&A using the real agent service."""
|
"""Stream RAG Q&A using the real agent service."""
|
||||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
session_id, event_stream = get_agent_conversation_service().stream_chat(
|
||||||
query=request.query,
|
query=request.query,
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
filters=request.filters,
|
filters=request.filters,
|
||||||
@@ -38,7 +38,11 @@ async def rag_chat(request: RagChatRequest):
|
|||||||
|
|
||||||
async def generate() -> AsyncGenerator[str, None]:
|
async def generate() -> AsyncGenerator[str, None]:
|
||||||
"""Translate agent SSE events to rag format."""
|
"""Translate agent SSE events to rag format."""
|
||||||
for event in event_stream:
|
yield (
|
||||||
|
"event: message\n"
|
||||||
|
f"data: {json.dumps({'type': 'session', 'session_id': session_id}, ensure_ascii=False)}\n\n"
|
||||||
|
)
|
||||||
|
async for event in iter_in_thread(event_stream):
|
||||||
event_type = event.get("event", "")
|
event_type = event.get("event", "")
|
||||||
data = event.get("data", "")
|
data = event.get("data", "")
|
||||||
if event_type == "sources":
|
if event_type == "sources":
|
||||||
@@ -69,14 +73,18 @@ async def rag_chat(request: RagChatRequest):
|
|||||||
elif event_type == "done":
|
elif event_type == "done":
|
||||||
yield (
|
yield (
|
||||||
"event: message\n"
|
"event: message\n"
|
||||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
f"data: {json.dumps({'type': 'done', 'session_id': session_id}, ensure_ascii=False)}\n\n"
|
||||||
)
|
)
|
||||||
elif event_type == "status":
|
elif event_type == "status":
|
||||||
yield (
|
yield (
|
||||||
"event: message\n"
|
"event: message\n"
|
||||||
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
|
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
elif event_type == "error":
|
||||||
|
yield (
|
||||||
|
"event: message\n"
|
||||||
|
f"data: {json.dumps({'type': 'error', 'text': str(data)}, ensure_ascii=False)}\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate(),
|
generate(),
|
||||||
|
|||||||
@@ -7,20 +7,75 @@ from app.domain.retrieval.ports import Reranker
|
|||||||
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
||||||
|
|
||||||
|
|
||||||
|
def _reciprocal_rank_fusion(
|
||||||
|
ranked_lists: list[list[RetrievedChunk]], k: int = 60
|
||||||
|
) -> list[RetrievedChunk]:
|
||||||
|
"""Merge multiple ranked lists with Reciprocal Rank Fusion.
|
||||||
|
|
||||||
|
Score for chunk c = sum over lists of 1 / (k + rank(c)).
|
||||||
|
A chunk appearing in multiple lists gets a higher fused score.
|
||||||
|
"""
|
||||||
|
scores: dict[str, float] = {}
|
||||||
|
chunk_map: dict[str, RetrievedChunk] = {}
|
||||||
|
|
||||||
|
for ranked in ranked_lists:
|
||||||
|
for rank, chunk in enumerate(ranked):
|
||||||
|
key = chunk.chunk_id
|
||||||
|
scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank + 1)
|
||||||
|
chunk_map[key] = chunk
|
||||||
|
|
||||||
|
sorted_keys = sorted(scores, key=lambda ck: scores[ck], reverse=True)
|
||||||
|
return [
|
||||||
|
RetrievedChunk(
|
||||||
|
chunk_id=chunk_map[ck].chunk_id,
|
||||||
|
doc_id=chunk_map[ck].doc_id,
|
||||||
|
doc_name=chunk_map[ck].doc_name,
|
||||||
|
content=chunk_map[ck].content,
|
||||||
|
score=scores[ck],
|
||||||
|
section_title=chunk_map[ck].section_title,
|
||||||
|
page_number=chunk_map[ck].page_number,
|
||||||
|
metadata=chunk_map[ck].metadata,
|
||||||
|
)
|
||||||
|
for ck in sorted_keys
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalService:
|
class KnowledgeRetrievalService:
|
||||||
"""Provide the Knowledge Retrieval Service service."""
|
"""Provide the Knowledge Retrieval Service service."""
|
||||||
|
|
||||||
def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
retriever: Retriever,
|
||||||
|
bm25_retriever=None,
|
||||||
|
reranker: Reranker | None = None,
|
||||||
|
reranker_top_k: int = 5,
|
||||||
|
) -> None:
|
||||||
"""Initialize the Knowledge Retrieval Service instance."""
|
"""Initialize the Knowledge Retrieval Service instance."""
|
||||||
self.retriever = retriever
|
self.retriever = retriever
|
||||||
|
self.bm25_retriever = bm25_retriever
|
||||||
self.reranker = reranker
|
self.reranker = reranker
|
||||||
self.reranker_top_k = reranker_top_k
|
self.reranker_top_k = reranker_top_k
|
||||||
|
|
||||||
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||||
"""Retrieve and optionally rerank chunks for a query."""
|
"""Retrieve and optionally rerank chunks for a query.
|
||||||
candidate_k = top_k if self.reranker is None else max(top_k * 4, 20)
|
|
||||||
|
When a BM25 retriever is available, combines dense + sparse results
|
||||||
|
via Reciprocal Rank Fusion before optional reranking.
|
||||||
|
"""
|
||||||
|
use_hybrid = self.bm25_retriever is not None and getattr(self.bm25_retriever, "available", False)
|
||||||
|
candidate_k = max(top_k * 4, 20) if (self.reranker is not None or use_hybrid) else top_k
|
||||||
|
|
||||||
retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters)
|
retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters)
|
||||||
candidates = self.retriever.retrieve(retrieval_query)
|
dense_results = self.retriever.retrieve(retrieval_query)
|
||||||
|
|
||||||
|
if use_hybrid:
|
||||||
|
bm25_results = self.bm25_retriever.retrieve(query, top_k=candidate_k, filters=filters)
|
||||||
|
candidates = _reciprocal_rank_fusion([dense_results, bm25_results])
|
||||||
|
else:
|
||||||
|
candidates = dense_results
|
||||||
|
|
||||||
if self.reranker and candidates:
|
if self.reranker and candidates:
|
||||||
return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k)
|
return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k)
|
||||||
|
|
||||||
return candidates[:top_k]
|
return candidates[:top_k]
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from app.services.llm.llm_factory import get_llm_client
|
|||||||
# Keep adapter behavior explicit so integration details remain easy to audit.
|
# Keep adapter behavior explicit so integration details remain easy to audit.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
PROMPT_TEMPLATES = {
|
PROMPT_TEMPLATES = {
|
||||||
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
|
"default": "你是法规知识问答助手。请仅依据提供的上下文回答;如果上下文不足,明确说明。",
|
||||||
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
|
"compliance_qa": "你是法规合规问答助手。优先引用给定法规原文,回答要准确、克制,并注明依据来源。",
|
||||||
@@ -21,6 +20,17 @@ PROMPT_TEMPLATES = {
|
|||||||
|
|
||||||
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
||||||
"""Represent the Open A I Compatible Answer Generator type."""
|
"""Represent the Open A I Compatible Answer Generator type."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _estimate_tokens(text: str) -> int:
|
||||||
|
"""Estimate token count for mixed Chinese/English text.
|
||||||
|
|
||||||
|
Chinese chars are ~1.5 chars/token; ASCII is ~4 chars/token.
|
||||||
|
"""
|
||||||
|
chinese = sum(1 for c in text if "一" <= c <= "鿿")
|
||||||
|
other = len(text) - chinese
|
||||||
|
return int(chinese / 1.5 + other / 4) + 1
|
||||||
|
|
||||||
def _build_messages(
|
def _build_messages(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -40,9 +50,12 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
|||||||
f"页码: {chunk.page_number}\n"
|
f"页码: {chunk.page_number}\n"
|
||||||
f"内容: {chunk.content}"
|
f"内容: {chunk.content}"
|
||||||
)
|
)
|
||||||
context_tokens += len(block)
|
block_tokens = self._estimate_tokens(block)
|
||||||
|
if context_tokens + block_tokens > settings.rag_max_context_tokens:
|
||||||
|
break
|
||||||
|
context_tokens += block_tokens
|
||||||
context_blocks.append(block)
|
context_blocks.append(block)
|
||||||
context = "\n\n".join(context_blocks)[: settings.rag_max_context_tokens * 4]
|
context = "\n\n".join(context_blocks)
|
||||||
messages = [{"role": "system", "content": system_prompt}]
|
messages = [{"role": "system", "content": system_prompt}]
|
||||||
for item in history or []:
|
for item in history or []:
|
||||||
messages.append({"role": item["role"], "content": item["content"]})
|
messages.append({"role": item["role"], "content": item["content"]})
|
||||||
@@ -52,7 +65,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
|||||||
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
|
"content": f"问题:{query}\n\n参考上下文:\n{context}\n\n请在回答后给出简要引用编号。",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return messages, min(context_tokens, settings.rag_max_context_tokens)
|
return messages, context_tokens
|
||||||
|
|
||||||
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
|
def _sources(self, chunks: list[RetrievedChunk]) -> list[AnswerSource]:
|
||||||
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
|
"""Handle sources for this module for the Open A I Compatible Answer Generator instance."""
|
||||||
@@ -98,7 +111,7 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
|||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
retrieved_count=len(retrieved_chunks),
|
retrieved_count=len(retrieved_chunks),
|
||||||
context_tokens=context_tokens,
|
context_tokens=context_tokens,
|
||||||
truncated=False,
|
truncated=len(retrieved_chunks) > len(messages),
|
||||||
error=response.error,
|
error=response.error,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,15 +137,18 @@ class OpenAICompatibleAnswerGenerator(AnswerGenerator):
|
|||||||
yield {"event": "sources", "data": sources}
|
yield {"event": "sources", "data": sources}
|
||||||
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
|
client = get_llm_client(provider=provider or settings.llm_provider, model=model or settings.llm_model)
|
||||||
answer_parts: list[str] = []
|
answer_parts: list[str] = []
|
||||||
if hasattr(client, "stream_chat"):
|
try:
|
||||||
for chunk in client.stream_chat(messages):
|
if hasattr(client, "stream_chat"):
|
||||||
answer_parts.append(chunk)
|
for chunk in client.stream_chat(messages):
|
||||||
yield {"event": "content", "data": chunk}
|
answer_parts.append(chunk)
|
||||||
else:
|
yield {"event": "content", "data": chunk}
|
||||||
response = client.chat(messages)
|
else:
|
||||||
answer_parts.append(response.content)
|
response = client.chat(messages)
|
||||||
yield {"event": "content", "data": response.content}
|
answer_parts.append(response.content)
|
||||||
full_answer = "".join(answer_parts)
|
yield {"event": "content", "data": response.content}
|
||||||
|
except Exception as exc:
|
||||||
|
yield {"event": "error", "data": str(exc)}
|
||||||
|
return
|
||||||
yield {
|
yield {
|
||||||
"event": "done",
|
"event": "done",
|
||||||
"data": {
|
"data": {
|
||||||
|
|||||||
149
backend/app/infrastructure/vectorstore/bm25_retriever.py
Normal file
149
backend/app/infrastructure/vectorstore/bm25_retriever.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""BM25 sparse retriever backed by the existing Milvus collection.
|
||||||
|
|
||||||
|
Falls back to no-op if rank_bm25 or jieba is not installed.
|
||||||
|
The index is built lazily on first query and can be refreshed after
|
||||||
|
new documents are ingested by calling refresh().
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.domain.retrieval import RetrievedChunk
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from rank_bm25 import BM25Okapi
|
||||||
|
import jieba
|
||||||
|
|
||||||
|
jieba.setLogLevel(logging.WARNING)
|
||||||
|
_BM25_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
_BM25_AVAILABLE = False
|
||||||
|
logger.warning("rank_bm25 or jieba not installed – BM25 retrieval disabled. Run: pip install rank-bm25 jieba")
|
||||||
|
|
||||||
|
|
||||||
|
def _tokenize(text: str) -> list[str]:
|
||||||
|
"""Tokenize Chinese/mixed text with jieba."""
|
||||||
|
return list(jieba.cut(text))
|
||||||
|
|
||||||
|
|
||||||
|
class BM25Retriever:
|
||||||
|
"""Sparse BM25 retriever that indexes chunks stored in Milvus.
|
||||||
|
|
||||||
|
Thread-safe: the index is protected by a lock so concurrent requests
|
||||||
|
during the initial build do not race.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vector_index: "MilvusVectorIndex") -> None:
|
||||||
|
self._vector_index = vector_index
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._index: "BM25Okapi | None" = None
|
||||||
|
self._chunks: list[RetrievedChunk] = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available(self) -> bool:
|
||||||
|
return _BM25_AVAILABLE
|
||||||
|
|
||||||
|
def _load_all_chunks(self) -> list[RetrievedChunk]:
|
||||||
|
"""Fetch all chunk records from Milvus (no vectors needed)."""
|
||||||
|
try:
|
||||||
|
rows = self._vector_index.collection.query(
|
||||||
|
expr='doc_id != ""',
|
||||||
|
output_fields=["id", "doc_id", "doc_name", "content", "section_title", "page_number"],
|
||||||
|
limit=16384,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("BM25Retriever: failed to fetch chunks from Milvus")
|
||||||
|
return []
|
||||||
|
return [
|
||||||
|
RetrievedChunk(
|
||||||
|
chunk_id=str(row.get("id", "")),
|
||||||
|
doc_id=str(row.get("doc_id", "")),
|
||||||
|
doc_name=str(row.get("doc_name", "")),
|
||||||
|
content=str(row.get("content", "")),
|
||||||
|
score=0.0,
|
||||||
|
section_title=str(row.get("section_title", "")),
|
||||||
|
page_number=int(row.get("page_number") or 0),
|
||||||
|
metadata={},
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
if row.get("content")
|
||||||
|
]
|
||||||
|
|
||||||
|
def _ensure_built(self) -> None:
|
||||||
|
if self._index is not None:
|
||||||
|
return
|
||||||
|
with self._lock:
|
||||||
|
if self._index is not None:
|
||||||
|
return
|
||||||
|
self._build()
|
||||||
|
|
||||||
|
def _build(self) -> None:
|
||||||
|
logger.info("BM25Retriever: building index …")
|
||||||
|
chunks = self._load_all_chunks()
|
||||||
|
if not chunks:
|
||||||
|
logger.warning("BM25Retriever: no chunks found, index is empty")
|
||||||
|
self._chunks = []
|
||||||
|
self._index = BM25Okapi([[]])
|
||||||
|
return
|
||||||
|
tokenized = [_tokenize(c.content) for c in chunks]
|
||||||
|
self._chunks = chunks
|
||||||
|
self._index = BM25Okapi(tokenized)
|
||||||
|
logger.info("BM25Retriever: index built with %d chunks", len(chunks))
|
||||||
|
|
||||||
|
def refresh(self) -> None:
|
||||||
|
"""Rebuild the index (call after new documents are ingested)."""
|
||||||
|
with self._lock:
|
||||||
|
self._index = None
|
||||||
|
self._chunks = []
|
||||||
|
self._build()
|
||||||
|
|
||||||
|
def retrieve(self, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||||
|
"""Return top_k chunks ranked by BM25 score."""
|
||||||
|
if not _BM25_AVAILABLE:
|
||||||
|
return []
|
||||||
|
self._ensure_built()
|
||||||
|
if not self._chunks:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tokens = _tokenize(query)
|
||||||
|
scores = self._index.get_scores(tokens) # type: ignore[union-attr]
|
||||||
|
|
||||||
|
# Pair and sort descending by score
|
||||||
|
ranked = sorted(
|
||||||
|
((float(scores[i]), self._chunks[i]) for i in range(len(self._chunks))),
|
||||||
|
key=lambda x: x[0],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[RetrievedChunk] = []
|
||||||
|
for score, chunk in ranked[: top_k * 2]:
|
||||||
|
if score <= 0:
|
||||||
|
break
|
||||||
|
# Apply simple regulation_type filter if provided
|
||||||
|
if filters and chunk.metadata.get("regulation_type"):
|
||||||
|
types = [t.strip() for t in filters.split(",")]
|
||||||
|
if chunk.metadata.get("regulation_type") not in types:
|
||||||
|
continue
|
||||||
|
results.append(
|
||||||
|
RetrievedChunk(
|
||||||
|
chunk_id=chunk.chunk_id,
|
||||||
|
doc_id=chunk.doc_id,
|
||||||
|
doc_name=chunk.doc_name,
|
||||||
|
content=chunk.content,
|
||||||
|
score=score,
|
||||||
|
section_title=chunk.section_title,
|
||||||
|
page_number=chunk.page_number,
|
||||||
|
metadata=chunk.metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(results) >= top_k:
|
||||||
|
break
|
||||||
|
return results
|
||||||
38
backend/app/shared/async_utils.py
Normal file
38
backend/app/shared/async_utils.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""Async utility helpers for bridging sync generators to async FastAPI routes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
from typing import AsyncGenerator, Generator, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
|
async def iter_in_thread(sync_gen: Generator[T, None, None]) -> AsyncGenerator[T, None]:
|
||||||
|
"""Yield items from a synchronous generator without blocking the event loop.
|
||||||
|
|
||||||
|
Runs the generator in a daemon thread, forwarding items via an asyncio.Queue.
|
||||||
|
Use this to wrap blocking LLM streaming generators inside async FastAPI routes.
|
||||||
|
"""
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
queue: asyncio.Queue = asyncio.Queue(maxsize=32)
|
||||||
|
|
||||||
|
def _drain() -> None:
|
||||||
|
try:
|
||||||
|
for item in sync_gen:
|
||||||
|
future = asyncio.run_coroutine_threadsafe(queue.put(item), loop)
|
||||||
|
future.result()
|
||||||
|
finally:
|
||||||
|
asyncio.run_coroutine_threadsafe(queue.put(_SENTINEL), loop).result()
|
||||||
|
|
||||||
|
thread = threading.Thread(target=_drain, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = await queue.get()
|
||||||
|
if item is _SENTINEL:
|
||||||
|
break
|
||||||
|
yield item # type: ignore[misc]
|
||||||
@@ -19,6 +19,7 @@ from app.infrastructure.storage.json_document_repository import JsonDocumentRepo
|
|||||||
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
|
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
|
||||||
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
|
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
|
||||||
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
|
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
|
||||||
|
from app.infrastructure.vectorstore.bm25_retriever import BM25Retriever
|
||||||
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
|
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
|
||||||
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||||
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
|
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
|
||||||
@@ -87,6 +88,13 @@ def get_reranker():
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_bm25_retriever() -> BM25Retriever | None:
|
||||||
|
"""Return BM25 retriever if rank_bm25 + jieba are installed, else None."""
|
||||||
|
retriever = BM25Retriever(vector_index=get_vector_index())
|
||||||
|
return retriever if retriever.available else None
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_retrieval_service() -> KnowledgeRetrievalService:
|
def get_retrieval_service() -> KnowledgeRetrievalService:
|
||||||
"""Return retrieval service."""
|
"""Return retrieval service."""
|
||||||
@@ -96,6 +104,7 @@ def get_retrieval_service() -> KnowledgeRetrievalService:
|
|||||||
)
|
)
|
||||||
return KnowledgeRetrievalService(
|
return KnowledgeRetrievalService(
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
|
bm25_retriever=get_bm25_retriever(),
|
||||||
reranker=get_reranker(),
|
reranker=get_reranker(),
|
||||||
reranker_top_k=settings.reranker_top_k,
|
reranker_top_k=settings.reranker_top_k,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ export interface SSEMessage {
|
|||||||
type: string;
|
type: string;
|
||||||
text?: string;
|
text?: string;
|
||||||
docs?: RetrievedDoc[];
|
docs?: RetrievedDoc[];
|
||||||
|
session_id?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function streamSSE<TMessage extends SSEMessage>(
|
export async function streamSSE<TMessage extends SSEMessage>(
|
||||||
|
|||||||
@@ -38,7 +38,14 @@ function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) {
|
|||||||
const joined = dataLines.join('\n');
|
const joined = dataLines.join('\n');
|
||||||
if (!joined) continue;
|
if (!joined) continue;
|
||||||
|
|
||||||
if (eventName === 'sources') {
|
// /agent/chat/stream uses named events (sources, content, done, error, session, status)
|
||||||
|
// /rag/chat wraps everything in event:message with type in JSON body — handle both
|
||||||
|
if (eventName === 'session') {
|
||||||
|
try {
|
||||||
|
const payload = JSON.parse(joined) as Record<string, unknown>;
|
||||||
|
onMessage({ type: 'session', session_id: String(payload.session_id ?? '') });
|
||||||
|
} catch { /* ignore */ }
|
||||||
|
} else if (eventName === 'sources') {
|
||||||
try {
|
try {
|
||||||
const docs = JSON.parse(joined) as Array<Record<string, unknown>>;
|
const docs = JSON.parse(joined) as Array<Record<string, unknown>>;
|
||||||
onMessage({
|
onMessage({
|
||||||
@@ -53,17 +60,26 @@ function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) {
|
|||||||
download_url: doc.doc_id ? `${AGENT_API_BASE}/documents/download/${String(doc.doc_id)}` : undefined,
|
download_url: doc.doc_id ? `${AGENT_API_BASE}/documents/download/${String(doc.doc_id)}` : undefined,
|
||||||
})),
|
})),
|
||||||
});
|
});
|
||||||
} catch {
|
} catch { /* ignore */ }
|
||||||
// Ignore malformed source payloads.
|
|
||||||
}
|
|
||||||
} else if (eventName === 'content') {
|
} else if (eventName === 'content') {
|
||||||
onMessage({ type: 'chunk', text: joined });
|
onMessage({ type: 'chunk', text: joined });
|
||||||
} else if (eventName === 'done') {
|
} else if (eventName === 'done') {
|
||||||
onMessage({ type: 'done', text: joined });
|
try {
|
||||||
|
const payload = JSON.parse(joined) as Record<string, unknown>;
|
||||||
|
onMessage({ type: 'done', session_id: payload.session_id ? String(payload.session_id) : undefined });
|
||||||
|
} catch {
|
||||||
|
onMessage({ type: 'done' });
|
||||||
|
}
|
||||||
} else if (eventName === 'error') {
|
} else if (eventName === 'error') {
|
||||||
onMessage({ type: 'error', text: joined });
|
onMessage({ type: 'error', text: joined });
|
||||||
} else if (eventName === 'status') {
|
} else if (eventName === 'status') {
|
||||||
onMessage({ type: 'status', text: joined });
|
onMessage({ type: 'status', text: joined });
|
||||||
|
} else if (eventName === 'message') {
|
||||||
|
// /rag/chat format: event:message + JSON body with type field
|
||||||
|
try {
|
||||||
|
const payload = JSON.parse(joined) as SSEMessage;
|
||||||
|
onMessage(payload);
|
||||||
|
} catch { /* ignore */ }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -74,7 +90,9 @@ export async function ragChat(
|
|||||||
onMessage: (data: SSEMessage) => void,
|
onMessage: (data: SSEMessage) => void,
|
||||||
onError?: (error: Error) => void,
|
onError?: (error: Error) => void,
|
||||||
onComplete?: () => void,
|
onComplete?: () => void,
|
||||||
filters?: string
|
filters?: string,
|
||||||
|
sessionId?: string,
|
||||||
|
signal?: AbortSignal,
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, {
|
const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, {
|
||||||
@@ -83,7 +101,13 @@ export async function ragChat(
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
Accept: 'text/event-stream',
|
Accept: 'text/event-stream',
|
||||||
},
|
},
|
||||||
body: JSON.stringify({ query, top_k: topK, ...(filters ? { filters } : {}) }),
|
body: JSON.stringify({
|
||||||
|
query,
|
||||||
|
top_k: topK,
|
||||||
|
...(filters ? { filters } : {}),
|
||||||
|
...(sessionId ? { session_id: sessionId } : {}),
|
||||||
|
}),
|
||||||
|
signal,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!response.ok || !response.body) {
|
if (!response.ok || !response.body) {
|
||||||
@@ -112,6 +136,7 @@ export async function ragChat(
|
|||||||
onComplete();
|
onComplete();
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
if (error instanceof DOMException && error.name === 'AbortError') return;
|
||||||
if (onError) {
|
if (onError) {
|
||||||
onError(error instanceof Error ? error : new Error(String(error)));
|
onError(error instanceof Error ? error : new Error(String(error)));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import React, { useEffect, useRef, useState } from 'react';
|
import React, { useCallback, useEffect, useRef, useState } from 'react';
|
||||||
import { useTheme } from '../../contexts';
|
import { useTheme } from '../../contexts';
|
||||||
import type { ChatMessage, RetrievalData } from '../../types';
|
import type { ChatMessage, RetrievalData } from '../../types';
|
||||||
import { getQuickQuestions, ragChat } from '../../api/rag';
|
import { getQuickQuestions, ragChat } from '../../api/rag';
|
||||||
@@ -19,6 +19,7 @@ export const RagChatPage: React.FC = () => {
|
|||||||
const { theme } = useTheme();
|
const { theme } = useTheme();
|
||||||
const nextMessageIdRef = useRef(1);
|
const nextMessageIdRef = useRef(1);
|
||||||
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
const [messages, setMessages] = useState<ChatMessage[]>([]);
|
||||||
|
// retrievals: right-panel shows sources of the most recent assistant reply
|
||||||
const [retrievals, setRetrievals] = useState<RetrievalData[]>([]);
|
const [retrievals, setRetrievals] = useState<RetrievalData[]>([]);
|
||||||
const [input, setInput] = useState<string>('');
|
const [input, setInput] = useState<string>('');
|
||||||
const [loading, setLoading] = useState<boolean>(false);
|
const [loading, setLoading] = useState<boolean>(false);
|
||||||
@@ -27,62 +28,60 @@ export const RagChatPage: React.FC = () => {
|
|||||||
const [quickQuestions, setQuickQuestions] = useState<string[]>(ragQuickQuestionsDefault);
|
const [quickQuestions, setQuickQuestions] = useState<string[]>(ragQuickQuestionsDefault);
|
||||||
const [filterRegulationType, setFilterRegulationType] = useState<string>('');
|
const [filterRegulationType, setFilterRegulationType] = useState<string>('');
|
||||||
const [highlightedSourceIdx, setHighlightedSourceIdx] = useState<number | null>(null);
|
const [highlightedSourceIdx, setHighlightedSourceIdx] = useState<number | null>(null);
|
||||||
|
const [sessionId, setSessionId] = useState<string | undefined>();
|
||||||
|
|
||||||
|
// Auto-scroll ref
|
||||||
|
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||||
|
// AbortController for cancelling in-flight requests
|
||||||
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
|
|
||||||
function nextMessageId() {
|
function nextMessageId() {
|
||||||
const currentId = nextMessageIdRef.current;
|
const id = nextMessageIdRef.current;
|
||||||
nextMessageIdRef.current += 1;
|
nextMessageIdRef.current += 1;
|
||||||
return currentId;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scroll to bottom whenever messages change
|
||||||
|
useEffect(() => {
|
||||||
|
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
||||||
|
}, [messages]);
|
||||||
|
|
||||||
async function loadQuickQuestions() {
|
async function loadQuickQuestions() {
|
||||||
try {
|
try {
|
||||||
const response = await getQuickQuestions();
|
const response = await getQuickQuestions();
|
||||||
setQuickQuestions(response.questions.map(q => q.question));
|
setQuickQuestions(response.questions.map(q => q.question));
|
||||||
} catch (error) {
|
} catch {
|
||||||
console.error('Failed to load quick questions:', error);
|
// keep defaults
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const timerId = window.setTimeout(() => {
|
const timerId = window.setTimeout(() => { void loadQuickQuestions(); }, 0);
|
||||||
void loadQuickQuestions();
|
|
||||||
}, 0);
|
|
||||||
return () => window.clearTimeout(timerId);
|
return () => window.clearTimeout(timerId);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const sendMessage = (text: string) => {
|
/**
|
||||||
if (!text.trim()) return;
|
* Core query executor — shared by sendMessage and regenerateLastAnswer.
|
||||||
|
* Manages session_id, AbortController, SSE parsing, and state updates.
|
||||||
|
*/
|
||||||
|
const executeQuery = useCallback((text: string) => {
|
||||||
|
// Cancel any in-flight request
|
||||||
|
abortRef.current?.abort();
|
||||||
|
abortRef.current = new AbortController();
|
||||||
|
|
||||||
const userMsg = { id: nextMessageId(), role: 'user' as const, content: text };
|
|
||||||
setMessages((prev) => [...prev, userMsg]);
|
|
||||||
setInput('');
|
|
||||||
setLoading(true);
|
|
||||||
setRetrievals([]);
|
|
||||||
setHighlightedSourceIdx(null);
|
|
||||||
|
|
||||||
let currentResponse = '';
|
|
||||||
const activeFilters = filterRegulationType.trim() || undefined;
|
const activeFilters = filterRegulationType.trim() || undefined;
|
||||||
|
let currentResponse = '';
|
||||||
|
// Capture the assistant message id so we can attach sources later
|
||||||
|
let assistantMsgId: number | null = null;
|
||||||
|
|
||||||
void ragChat(
|
void ragChat(
|
||||||
text,
|
text,
|
||||||
5,
|
5,
|
||||||
(data: unknown) => {
|
(data) => {
|
||||||
const sseData = data as {
|
if (data.type === 'session' && data.session_id) {
|
||||||
type: string;
|
setSessionId(data.session_id);
|
||||||
text?: string;
|
} else if (data.type === 'retrieved' && data.docs) {
|
||||||
docs?: Array<{
|
const docs: RetrievalData[] = data.docs.map(d => ({
|
||||||
id: string;
|
|
||||||
score: number;
|
|
||||||
preview: string;
|
|
||||||
doc_name: string;
|
|
||||||
clause: string;
|
|
||||||
doc_id?: string;
|
|
||||||
download_url?: string;
|
|
||||||
}>;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (sseData.type === 'retrieved' && sseData.docs) {
|
|
||||||
const retrievedDocs: RetrievalData[] = sseData.docs.map(d => ({
|
|
||||||
id: parseInt(d.id.replace('chunk-', ''), 10) || 1,
|
id: parseInt(d.id.replace('chunk-', ''), 10) || 1,
|
||||||
file: d.doc_name,
|
file: d.doc_name,
|
||||||
clause: d.clause,
|
clause: d.clause,
|
||||||
@@ -91,125 +90,98 @@ export const RagChatPage: React.FC = () => {
|
|||||||
docId: d.doc_id,
|
docId: d.doc_id,
|
||||||
downloadUrl: d.download_url,
|
downloadUrl: d.download_url,
|
||||||
}));
|
}));
|
||||||
setRetrievals(retrievedDocs);
|
setRetrievals(docs);
|
||||||
} else if (sseData.type === 'chunk' && sseData.text) {
|
// Attach sources to the assistant message once we know its id
|
||||||
currentResponse += sseData.text;
|
if (assistantMsgId !== null) {
|
||||||
setMessages((prev) => {
|
setMessages(prev => prev.map(m =>
|
||||||
const lastMsg = prev[prev.length - 1];
|
m.id === assistantMsgId ? { ...m, sources: docs } : m
|
||||||
if (lastMsg?.role === 'assistant') {
|
));
|
||||||
return [...prev.slice(0, -1), { ...lastMsg, content: currentResponse }];
|
}
|
||||||
|
} else if (data.type === 'chunk' && data.text) {
|
||||||
|
currentResponse += data.text;
|
||||||
|
setMessages(prev => {
|
||||||
|
const last = prev[prev.length - 1];
|
||||||
|
if (last?.role === 'assistant' && last.id === assistantMsgId) {
|
||||||
|
return [...prev.slice(0, -1), { ...last, content: currentResponse }];
|
||||||
}
|
}
|
||||||
return [...prev, { id: nextMessageId(), role: 'assistant' as const, content: currentResponse }];
|
// First chunk: create assistant message
|
||||||
|
const newId = nextMessageId();
|
||||||
|
assistantMsgId = newId;
|
||||||
|
return [...prev, { id: newId, role: 'assistant' as const, content: currentResponse }];
|
||||||
});
|
});
|
||||||
} else if (sseData.type === 'done') {
|
} else if (data.type === 'done') {
|
||||||
|
if (data.session_id) setSessionId(data.session_id);
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
} else if (sseData.type === 'error') {
|
} else if (data.type === 'error') {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
|
setMessages(prev => [
|
||||||
|
...prev,
|
||||||
|
{ id: nextMessageId(), role: 'assistant' as const, content: '抱歉,生成回答时出错,请稍后再试。' },
|
||||||
|
]);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
(error: Error) => {
|
(error) => {
|
||||||
console.error('RAG chat error:', error);
|
console.error('RAG chat error:', error);
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
setMessages((prev) => [
|
setMessages(prev => [
|
||||||
...prev,
|
...prev,
|
||||||
{ id: nextMessageId(), role: 'assistant' as const, content: '抱歉,连接服务器时出错,请稍后再试。' }
|
{ id: nextMessageId(), role: 'assistant' as const, content: '抱歉,连接服务器时出错,请稍后再试。' },
|
||||||
]);
|
]);
|
||||||
},
|
},
|
||||||
() => {
|
() => { setLoading(false); },
|
||||||
setLoading(false);
|
activeFilters,
|
||||||
},
|
sessionId,
|
||||||
activeFilters
|
abortRef.current.signal,
|
||||||
);
|
);
|
||||||
};
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
|
}, [filterRegulationType, sessionId]);
|
||||||
|
|
||||||
const clearMessages = () => {
|
const sendMessage = (text: string) => {
|
||||||
setMessages([]);
|
if (!text.trim() || loading) return;
|
||||||
|
setMessages(prev => [...prev, { id: nextMessageId(), role: 'user' as const, content: text }]);
|
||||||
|
setInput('');
|
||||||
|
setLoading(true);
|
||||||
setRetrievals([]);
|
setRetrievals([]);
|
||||||
setShowClearConfirm(false);
|
setHighlightedSourceIdx(null);
|
||||||
|
executeQuery(text);
|
||||||
};
|
};
|
||||||
|
|
||||||
const regenerateLastAnswer = () => {
|
const regenerateLastAnswer = () => {
|
||||||
if (messages.length < 2) return;
|
if (loading) return;
|
||||||
const lastUserMsg = messages.filter((m) => m.role === 'user').pop();
|
const lastUserMsg = [...messages].reverse().find(m => m.role === 'user');
|
||||||
if (!lastUserMsg) return;
|
if (!lastUserMsg) return;
|
||||||
|
// Remove the last assistant message
|
||||||
|
setMessages(prev => {
|
||||||
|
const lastAssistantIdx = [...prev].reverse().findIndex(m => m.role === 'assistant');
|
||||||
|
if (lastAssistantIdx === -1) return prev;
|
||||||
|
const idx = prev.length - 1 - lastAssistantIdx;
|
||||||
|
return [...prev.slice(0, idx)];
|
||||||
|
});
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
setMessages((prev) => [...prev.slice(0, -1)]);
|
|
||||||
setRetrievals([]);
|
setRetrievals([]);
|
||||||
setHighlightedSourceIdx(null);
|
setHighlightedSourceIdx(null);
|
||||||
|
executeQuery(lastUserMsg.content);
|
||||||
|
};
|
||||||
|
|
||||||
let currentResponse = '';
|
const clearMessages = () => {
|
||||||
const activeFilters = filterRegulationType.trim() || undefined;
|
abortRef.current?.abort();
|
||||||
|
setMessages([]);
|
||||||
void ragChat(
|
setRetrievals([]);
|
||||||
lastUserMsg.content,
|
setSessionId(undefined);
|
||||||
5,
|
setShowClearConfirm(false);
|
||||||
(data: unknown) => {
|
setLoading(false);
|
||||||
const sseData = data as {
|
|
||||||
type: string;
|
|
||||||
text?: string;
|
|
||||||
docs?: Array<{
|
|
||||||
id: string;
|
|
||||||
score: number;
|
|
||||||
preview: string;
|
|
||||||
doc_name: string;
|
|
||||||
clause: string;
|
|
||||||
doc_id?: string;
|
|
||||||
download_url?: string;
|
|
||||||
}>;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (sseData.type === 'retrieved' && sseData.docs) {
|
|
||||||
const retrievedDocs: RetrievalData[] = sseData.docs.map(d => ({
|
|
||||||
id: parseInt(d.id.replace('chunk-', ''), 10) || 1,
|
|
||||||
file: d.doc_name,
|
|
||||||
clause: d.clause,
|
|
||||||
score: d.score,
|
|
||||||
content: d.preview,
|
|
||||||
docId: d.doc_id,
|
|
||||||
downloadUrl: d.download_url,
|
|
||||||
}));
|
|
||||||
setRetrievals(retrievedDocs);
|
|
||||||
} else if (sseData.type === 'chunk' && sseData.text) {
|
|
||||||
currentResponse += sseData.text;
|
|
||||||
setMessages((prev) => {
|
|
||||||
const lastMsg = prev[prev.length - 1];
|
|
||||||
if (lastMsg?.role === 'assistant') {
|
|
||||||
return [...prev.slice(0, -1), { ...lastMsg, content: currentResponse }];
|
|
||||||
}
|
|
||||||
return [...prev, { id: nextMessageId(), role: 'assistant' as const, content: currentResponse }];
|
|
||||||
});
|
|
||||||
} else if (sseData.type === 'done') {
|
|
||||||
setLoading(false);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
(error: Error) => {
|
|
||||||
console.error('RAG chat error:', error);
|
|
||||||
setLoading(false);
|
|
||||||
setMessages((prev) => [
|
|
||||||
...prev,
|
|
||||||
{ id: nextMessageId(), role: 'assistant' as const, content: '抱歉,连接服务器时出错,请稍后再试。' }
|
|
||||||
]);
|
|
||||||
},
|
|
||||||
() => {
|
|
||||||
setLoading(false);
|
|
||||||
},
|
|
||||||
activeFilters
|
|
||||||
);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div style={{
|
<div style={{ flex: 1, display: 'flex', height: 'calc(100vh - 128px)' }}>
|
||||||
flex: 1,
|
{/* ── Left: chat panel ─────────────────────────────────── */}
|
||||||
display: 'flex',
|
|
||||||
height: 'calc(100vh - 128px)',
|
|
||||||
}}>
|
|
||||||
<div style={{
|
<div style={{
|
||||||
flex: '0 0 60%',
|
flex: '0 0 60%',
|
||||||
display: 'flex',
|
display: 'flex',
|
||||||
flexDirection: 'column',
|
flexDirection: 'column',
|
||||||
borderRight: `1px solid ${theme.border}`,
|
borderRight: `1px solid ${theme.border}`,
|
||||||
}}>
|
}}>
|
||||||
|
{/* Message list */}
|
||||||
<div style={{
|
<div style={{
|
||||||
flex: 1,
|
flex: 1,
|
||||||
overflowY: 'auto',
|
overflowY: 'auto',
|
||||||
@@ -219,20 +191,11 @@ export const RagChatPage: React.FC = () => {
|
|||||||
gap: 20,
|
gap: 20,
|
||||||
}}>
|
}}>
|
||||||
{messages.length === 0 ? (
|
{messages.length === 0 ? (
|
||||||
<div style={{
|
<div style={{ textAlign: 'center', padding: 60, color: theme.text3 }}>
|
||||||
textAlign: 'center',
|
|
||||||
padding: 60,
|
|
||||||
color: theme.text3,
|
|
||||||
}}>
|
|
||||||
<div style={{
|
<div style={{
|
||||||
width: 72,
|
width: 72, height: 72, borderRadius: 16,
|
||||||
height: 72,
|
background: theme.bgCard, display: 'flex', alignItems: 'center',
|
||||||
borderRadius: 16,
|
justifyContent: 'center', margin: '0 auto 20px',
|
||||||
background: theme.bgCard,
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
margin: '0 auto 20px',
|
|
||||||
border: `1px solid ${theme.border}`,
|
border: `1px solid ${theme.border}`,
|
||||||
}}>
|
}}>
|
||||||
<svg width="28" height="28" viewBox="0 0 24 24" fill="none">
|
<svg width="28" height="28" viewBox="0 0 24 24" fill="none">
|
||||||
@@ -245,20 +208,14 @@ export const RagChatPage: React.FC = () => {
|
|||||||
) : (
|
) : (
|
||||||
messages.map(msg => (
|
messages.map(msg => (
|
||||||
<div key={msg.id} style={{
|
<div key={msg.id} style={{
|
||||||
display: 'flex',
|
display: 'flex', gap: 12,
|
||||||
gap: 12,
|
|
||||||
flexDirection: msg.role === 'user' ? 'row-reverse' : 'row',
|
flexDirection: msg.role === 'user' ? 'row-reverse' : 'row',
|
||||||
}}>
|
}}>
|
||||||
{msg.role === 'assistant' && (
|
{msg.role === 'assistant' && (
|
||||||
<div style={{
|
<div style={{
|
||||||
width: 32,
|
width: 32, height: 32, borderRadius: 8,
|
||||||
height: 32,
|
background: theme.gradientAccent, display: 'flex',
|
||||||
borderRadius: 8,
|
alignItems: 'center', justifyContent: 'center', flexShrink: 0,
|
||||||
background: theme.gradientAccent,
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
flexShrink: 0,
|
|
||||||
}}>
|
}}>
|
||||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none">
|
<svg width="16" height="16" viewBox="0 0 24 24" fill="none">
|
||||||
<circle cx="12" cy="12" r="6" fill="#fff"/>
|
<circle cx="12" cy="12" r="6" fill="#fff"/>
|
||||||
@@ -271,16 +228,16 @@ export const RagChatPage: React.FC = () => {
|
|||||||
background: msg.role === 'user' ? theme.gradientAccent : theme.bgCard,
|
background: msg.role === 'user' ? theme.gradientAccent : theme.bgCard,
|
||||||
borderRadius: 12,
|
borderRadius: 12,
|
||||||
color: msg.role === 'user' ? '#fff' : theme.text,
|
color: msg.role === 'user' ? '#fff' : theme.text,
|
||||||
fontSize: 14,
|
fontSize: 14, lineHeight: 1.6, whiteSpace: 'pre-wrap',
|
||||||
lineHeight: 1.6,
|
|
||||||
whiteSpace: 'pre-wrap',
|
|
||||||
border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none',
|
border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none',
|
||||||
}}>
|
}}>
|
||||||
{msg.role === 'assistant' ? (
|
{msg.role === 'assistant' ? (
|
||||||
<CitedAnswer
|
<CitedAnswer
|
||||||
text={msg.content}
|
text={msg.content}
|
||||||
sources={retrievals}
|
sources={msg.sources ?? retrievals}
|
||||||
onCiteClick={(idx) => {
|
onCiteClick={(idx) => {
|
||||||
|
const msgSources = msg.sources ?? retrievals;
|
||||||
|
setRetrievals(msgSources);
|
||||||
setHighlightedSourceIdx(idx);
|
setHighlightedSourceIdx(idx);
|
||||||
const el = document.getElementById(`source-${idx}`);
|
const el = document.getElementById(`source-${idx}`);
|
||||||
if (el) el.scrollIntoView({ behavior: 'smooth', block: 'center' });
|
if (el) el.scrollIntoView({ behavior: 'smooth', block: 'center' });
|
||||||
@@ -289,12 +246,9 @@ export const RagChatPage: React.FC = () => {
|
|||||||
) : msg.content}
|
) : msg.content}
|
||||||
{msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && (
|
{msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && (
|
||||||
<div style={{
|
<div style={{
|
||||||
marginTop: 10,
|
marginTop: 10, paddingTop: 10,
|
||||||
paddingTop: 10,
|
|
||||||
borderTop: `1px solid ${theme.border}`,
|
borderTop: `1px solid ${theme.border}`,
|
||||||
display: 'flex',
|
display: 'flex', alignItems: 'center', gap: 6,
|
||||||
alignItems: 'center',
|
|
||||||
gap: 6,
|
|
||||||
}}>
|
}}>
|
||||||
<svg width="12" height="12" viewBox="0 0 24 24" fill="none">
|
<svg width="12" height="12" viewBox="0 0 24 24" fill="none">
|
||||||
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke={theme.accent} strokeWidth="1.5"/>
|
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke={theme.accent} strokeWidth="1.5"/>
|
||||||
@@ -311,150 +265,115 @@ export const RagChatPage: React.FC = () => {
|
|||||||
{loading && (
|
{loading && (
|
||||||
<div style={{ display: 'flex', gap: 12 }}>
|
<div style={{ display: 'flex', gap: 12 }}>
|
||||||
<div style={{
|
<div style={{
|
||||||
width: 32,
|
width: 32, height: 32, borderRadius: 8,
|
||||||
height: 32,
|
background: theme.gradientAccent, display: 'flex',
|
||||||
borderRadius: 8,
|
alignItems: 'center', justifyContent: 'center',
|
||||||
background: theme.gradientAccent,
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
}}>
|
}}>
|
||||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none">
|
<svg width="16" height="16" viewBox="0 0 24 24" fill="none">
|
||||||
<circle cx="12" cy="12" r="6" fill="#fff"/>
|
<circle cx="12" cy="12" r="6" fill="#fff"/>
|
||||||
</svg>
|
</svg>
|
||||||
</div>
|
</div>
|
||||||
<div style={{
|
<div style={{
|
||||||
padding: '14px 18px',
|
padding: '14px 18px', background: theme.bgCard,
|
||||||
background: theme.bgCard,
|
borderRadius: 12, border: `1px solid ${theme.border}`,
|
||||||
borderRadius: 12,
|
display: 'flex', alignItems: 'center', gap: 8,
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
gap: 8,
|
|
||||||
}}>
|
}}>
|
||||||
<div style={{
|
<div style={{
|
||||||
width: 6,
|
width: 6, height: 6, borderRadius: '50%',
|
||||||
height: 6,
|
background: theme.accent, animation: 'pulse 1s infinite',
|
||||||
borderRadius: '50%',
|
|
||||||
background: theme.accent,
|
|
||||||
animation: 'pulse 1s infinite',
|
|
||||||
}} />
|
}} />
|
||||||
<span style={{ fontSize: 13, color: theme.text2 }}>检索中...</span>
|
<span style={{ fontSize: 13, color: theme.text2 }}>检索中...</span>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
{/* Scroll anchor */}
|
||||||
|
<div ref={messagesEndRef} />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div style={{
|
{/* Input area */}
|
||||||
padding: '16px 32px 20px',
|
<div style={{ padding: '16px 32px 20px', background: theme.bg, borderTop: `1px solid ${theme.border}` }}>
|
||||||
background: theme.bg,
|
{/* Filter row */}
|
||||||
borderTop: `1px solid ${theme.border}`,
|
<div style={{ display: 'flex', gap: 8, marginBottom: 10, alignItems: 'center' }}>
|
||||||
}}>
|
|
||||||
<div style={{
|
|
||||||
display: 'flex',
|
|
||||||
gap: 8,
|
|
||||||
marginBottom: 10,
|
|
||||||
alignItems: 'center',
|
|
||||||
}}>
|
|
||||||
<span className="mono" style={{ fontSize: 11, color: theme.text3, whiteSpace: 'nowrap' }}>法规类型</span>
|
<span className="mono" style={{ fontSize: 11, color: theme.text3, whiteSpace: 'nowrap' }}>法规类型</span>
|
||||||
<input
|
<input
|
||||||
value={filterRegulationType}
|
value={filterRegulationType}
|
||||||
onChange={(e) => setFilterRegulationType(e.target.value)}
|
onChange={e => setFilterRegulationType(e.target.value)}
|
||||||
placeholder="如: GB / UN-ECE / IATF(留空不过滤)"
|
placeholder="如: GB / UN-ECE / IATF(留空不过滤)"
|
||||||
style={{
|
style={{
|
||||||
flex: 1,
|
flex: 1, maxWidth: 280, padding: '5px 10px', fontSize: 12,
|
||||||
maxWidth: 280,
|
background: theme.bgHover, border: `1px solid ${theme.border}`,
|
||||||
padding: '5px 10px',
|
borderRadius: 6, color: theme.text, outline: 'none',
|
||||||
fontSize: 12,
|
|
||||||
background: theme.bgHover,
|
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
borderRadius: 6,
|
|
||||||
color: theme.text,
|
|
||||||
outline: 'none',
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div style={{
|
{/* Quick questions */}
|
||||||
display: 'flex',
|
<div style={{ display: 'flex', gap: 8, marginBottom: 12, flexWrap: 'wrap' }}>
|
||||||
gap: 8,
|
|
||||||
marginBottom: 12,
|
|
||||||
flexWrap: 'wrap',
|
|
||||||
}}>
|
|
||||||
{quickQuestions.map(q => (
|
{quickQuestions.map(q => (
|
||||||
<button
|
<button
|
||||||
key={q}
|
key={q}
|
||||||
onClick={() => sendMessage(q)}
|
onClick={() => sendMessage(q)}
|
||||||
|
disabled={loading}
|
||||||
style={{
|
style={{
|
||||||
padding: '6px 14px',
|
padding: '6px 14px', fontSize: 12, background: theme.bgCard,
|
||||||
fontSize: 12,
|
border: `1px solid ${theme.border}`, borderRadius: 6,
|
||||||
background: theme.bgCard,
|
color: theme.text2, cursor: loading ? 'not-allowed' : 'pointer',
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
borderRadius: 6,
|
|
||||||
color: theme.text2,
|
|
||||||
cursor: 'pointer',
|
|
||||||
}}
|
}}
|
||||||
>{q}</button>
|
>{q}</button>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* Send row */}
|
||||||
<div style={{ display: 'flex', gap: 10 }}>
|
<div style={{ display: 'flex', gap: 10 }}>
|
||||||
<input
|
<input
|
||||||
value={input}
|
value={input}
|
||||||
onChange={(e) => setInput(e.target.value)}
|
onChange={e => setInput(e.target.value)}
|
||||||
onKeyDown={(e) => e.key === 'Enter' && sendMessage(input)}
|
onKeyDown={e => e.key === 'Enter' && !e.shiftKey && sendMessage(input)}
|
||||||
placeholder="输入法规问题..."
|
placeholder="输入法规问题..."
|
||||||
style={{
|
style={{
|
||||||
flex: 1,
|
flex: 1, padding: 12, fontSize: 14,
|
||||||
padding: 12,
|
background: theme.bgCard, border: `1px solid ${theme.border}`,
|
||||||
fontSize: 14,
|
borderRadius: 8, color: theme.text, outline: 'none',
|
||||||
background: theme.bgCard,
|
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
borderRadius: 8,
|
|
||||||
color: theme.text,
|
|
||||||
outline: 'none',
|
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
<button
|
<button
|
||||||
onClick={() => sendMessage(input)}
|
onClick={() => sendMessage(input)}
|
||||||
disabled={loading || !input.trim()}
|
disabled={loading || !input.trim()}
|
||||||
style={{
|
style={{
|
||||||
padding: '12px 24px',
|
padding: '12px 24px', fontSize: 14, fontWeight: 600,
|
||||||
fontSize: 14,
|
|
||||||
fontWeight: 600,
|
|
||||||
background: loading || !input.trim() ? theme.bgHover : theme.gradientAccent,
|
background: loading || !input.trim() ? theme.bgHover : theme.gradientAccent,
|
||||||
color: loading || !input.trim() ? theme.text3 : '#fff',
|
color: loading || !input.trim() ? theme.text3 : '#fff',
|
||||||
border: 'none',
|
border: 'none', borderRadius: 8,
|
||||||
borderRadius: 8,
|
|
||||||
cursor: loading || !input.trim() ? 'not-allowed' : 'pointer',
|
cursor: loading || !input.trim() ? 'not-allowed' : 'pointer',
|
||||||
}}
|
}}
|
||||||
>发送</button>
|
>发送</button>
|
||||||
{messages.length > 0 && (
|
{loading && (
|
||||||
|
<button
|
||||||
|
onClick={() => { abortRef.current?.abort(); setLoading(false); }}
|
||||||
|
style={{
|
||||||
|
padding: '12px 16px', fontSize: 13, background: theme.bgCard,
|
||||||
|
border: `1px solid ${theme.border}`, borderRadius: 8,
|
||||||
|
color: theme.text2, cursor: 'pointer',
|
||||||
|
}}
|
||||||
|
>停止</button>
|
||||||
|
)}
|
||||||
|
{!loading && messages.length > 0 && (
|
||||||
<button
|
<button
|
||||||
onClick={() => setShowClearConfirm(true)}
|
onClick={() => setShowClearConfirm(true)}
|
||||||
style={{
|
style={{
|
||||||
padding: '12px 16px',
|
padding: '12px 16px', fontSize: 13, background: theme.bgCard,
|
||||||
fontSize: 13,
|
border: `1px solid ${theme.border}`, borderRadius: 8,
|
||||||
background: theme.bgCard,
|
color: theme.text2, cursor: 'pointer',
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
borderRadius: 8,
|
|
||||||
color: theme.text2,
|
|
||||||
cursor: 'pointer',
|
|
||||||
}}
|
}}
|
||||||
>清空</button>
|
>清空</button>
|
||||||
)}
|
)}
|
||||||
{messages.filter(m => m.role === 'assistant').length > 0 && (
|
{!loading && messages.filter(m => m.role === 'assistant').length > 0 && (
|
||||||
<button
|
<button
|
||||||
onClick={regenerateLastAnswer}
|
onClick={regenerateLastAnswer}
|
||||||
disabled={loading}
|
|
||||||
style={{
|
style={{
|
||||||
padding: '12px 16px',
|
padding: '12px 16px', fontSize: 13, background: theme.bgCard,
|
||||||
fontSize: 13,
|
border: `1px solid ${theme.border}`, borderRadius: 8,
|
||||||
background: theme.bgCard,
|
color: theme.text2, cursor: 'pointer',
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
borderRadius: 8,
|
|
||||||
color: theme.text2,
|
|
||||||
cursor: loading ? 'not-allowed' : 'pointer',
|
|
||||||
}}
|
}}
|
||||||
>重生成</button>
|
>重生成</button>
|
||||||
)}
|
)}
|
||||||
@@ -462,27 +381,16 @@ export const RagChatPage: React.FC = () => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div style={{
|
{/* ── Right: retrieved sources panel ───────────────────── */}
|
||||||
flex: '0 0 40%',
|
<div style={{ flex: '0 0 40%', display: 'flex', flexDirection: 'column', background: theme.bgCard }}>
|
||||||
display: 'flex',
|
|
||||||
flexDirection: 'column',
|
|
||||||
background: theme.bgCard,
|
|
||||||
}}>
|
|
||||||
<div style={{
|
<div style={{
|
||||||
padding: '20px 24px',
|
padding: '20px 24px', borderBottom: `1px solid ${theme.border}`,
|
||||||
borderBottom: `1px solid ${theme.border}`,
|
display: 'flex', alignItems: 'center', gap: 10,
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
gap: 10,
|
|
||||||
}}>
|
}}>
|
||||||
<div style={{
|
<div style={{
|
||||||
width: 28,
|
width: 28, height: 28, borderRadius: 6,
|
||||||
height: 28,
|
background: theme.gradientAccent, display: 'flex',
|
||||||
borderRadius: 6,
|
alignItems: 'center', justifyContent: 'center',
|
||||||
background: theme.gradientAccent,
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
}}>
|
}}>
|
||||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none">
|
<svg width="14" height="14" viewBox="0 0 24 24" fill="none">
|
||||||
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke="#fff" strokeWidth="1.5"/>
|
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke="#fff" strokeWidth="1.5"/>
|
||||||
@@ -493,20 +401,13 @@ export const RagChatPage: React.FC = () => {
|
|||||||
</span>
|
</span>
|
||||||
{retrievals.length > 0 && (
|
{retrievals.length > 0 && (
|
||||||
<span className="mono" style={{
|
<span className="mono" style={{
|
||||||
fontSize: 11,
|
fontSize: 11, padding: '4px 10px',
|
||||||
padding: '4px 10px',
|
background: theme.bgHover, borderRadius: 4, color: theme.text3,
|
||||||
background: theme.bgHover,
|
|
||||||
borderRadius: 4,
|
|
||||||
color: theme.text3,
|
|
||||||
}}>{retrievals.length}</span>
|
}}>{retrievals.length}</span>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div style={{
|
<div style={{ flex: 1, overflowY: 'auto', padding: '16px 24px' }}>
|
||||||
flex: 1,
|
|
||||||
overflowY: 'auto',
|
|
||||||
padding: '16px 24px',
|
|
||||||
}}>
|
|
||||||
{retrievals.length > 0 ? (
|
{retrievals.length > 0 ? (
|
||||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
||||||
{retrievals.map((r, i) => (
|
{retrievals.map((r, i) => (
|
||||||
@@ -515,30 +416,20 @@ export const RagChatPage: React.FC = () => {
|
|||||||
id={`source-${i + 1}`}
|
id={`source-${i + 1}`}
|
||||||
onClick={() => setSelectedRetrieval(r)}
|
onClick={() => setSelectedRetrieval(r)}
|
||||||
style={{
|
style={{
|
||||||
padding: 16,
|
padding: 16, background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover,
|
||||||
background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover,
|
borderRadius: 10, border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`,
|
||||||
borderRadius: 10,
|
cursor: 'pointer', position: 'relative',
|
||||||
border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`,
|
|
||||||
cursor: 'pointer',
|
|
||||||
position: 'relative',
|
|
||||||
transition: 'border-color 0.2s, background 0.2s',
|
transition: 'border-color 0.2s, background 0.2s',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<div style={{
|
<div style={{
|
||||||
position: 'absolute',
|
position: 'absolute', left: 0, top: 16, bottom: 16,
|
||||||
left: 0,
|
width: 3, background: theme.gradientAccent, borderRadius: 2,
|
||||||
top: 16,
|
|
||||||
bottom: 16,
|
|
||||||
width: 3,
|
|
||||||
background: theme.gradientAccent,
|
|
||||||
borderRadius: 2,
|
|
||||||
}} />
|
}} />
|
||||||
<div style={{ paddingLeft: 8 }}>
|
<div style={{ paddingLeft: 8 }}>
|
||||||
<div style={{
|
<div style={{
|
||||||
display: 'flex',
|
display: 'flex', alignItems: 'center',
|
||||||
alignItems: 'center',
|
justifyContent: 'space-between', marginBottom: 8,
|
||||||
justifyContent: 'space-between',
|
|
||||||
marginBottom: 8,
|
|
||||||
}}>
|
}}>
|
||||||
<span className="mono" style={{ fontSize: 11, fontWeight: 700, color: theme.accent }}>#{i + 1}</span>
|
<span className="mono" style={{ fontSize: 11, fontWeight: 700, color: theme.accent }}>#{i + 1}</span>
|
||||||
<div style={{ display: 'flex', alignItems: 'center', gap: 10 }}>
|
<div style={{ display: 'flex', alignItems: 'center', gap: 10 }}>
|
||||||
@@ -547,17 +438,13 @@ export const RagChatPage: React.FC = () => {
|
|||||||
href={r.downloadUrl}
|
href={r.downloadUrl}
|
||||||
target="_blank"
|
target="_blank"
|
||||||
rel="noreferrer"
|
rel="noreferrer"
|
||||||
onClick={(event) => event.stopPropagation()}
|
onClick={e => e.stopPropagation()}
|
||||||
style={{ fontSize: 11, color: theme.accent, textDecoration: 'none' }}
|
style={{ fontSize: 11, color: theme.accent, textDecoration: 'none' }}
|
||||||
>
|
>下载文档</a>
|
||||||
下载文档
|
|
||||||
</a>
|
|
||||||
)}
|
)}
|
||||||
<span className="mono" style={{
|
<span className="mono" style={{ fontSize: 11, fontWeight: 600, color: theme.accent }}>
|
||||||
fontSize: 11,
|
{(r.score * 100).toFixed(0)}%
|
||||||
fontWeight: 600,
|
</span>
|
||||||
color: theme.accent,
|
|
||||||
}}>{(r.score * 100).toFixed(0)}%</span>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div style={{ fontSize: 13, fontWeight: 500, marginBottom: 4, color: theme.text }}>{r.file}</div>
|
<div style={{ fontSize: 13, fontWeight: 500, marginBottom: 4, color: theme.text }}>{r.file}</div>
|
||||||
@@ -565,32 +452,20 @@ export const RagChatPage: React.FC = () => {
|
|||||||
{r.clause}{r.docId ? ` · ${r.docId}` : ''}
|
{r.clause}{r.docId ? ` · ${r.docId}` : ''}
|
||||||
</div>
|
</div>
|
||||||
<div style={{
|
<div style={{
|
||||||
fontSize: 12,
|
fontSize: 12, color: theme.text2, lineHeight: 1.5,
|
||||||
color: theme.text2,
|
display: '-webkit-box', WebkitLineClamp: 3,
|
||||||
lineHeight: 1.5,
|
WebkitBoxOrient: 'vertical', overflow: 'hidden',
|
||||||
overflow: 'hidden',
|
|
||||||
textOverflow: 'ellipsis',
|
|
||||||
whiteSpace: 'nowrap',
|
|
||||||
}}>{r.content}</div>
|
}}>{r.content}</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div style={{
|
<div style={{ textAlign: 'center', padding: 40, color: theme.text3 }}>
|
||||||
textAlign: 'center',
|
|
||||||
padding: 40,
|
|
||||||
color: theme.text3,
|
|
||||||
}}>
|
|
||||||
<div style={{
|
<div style={{
|
||||||
width: 48,
|
width: 48, height: 48, borderRadius: 10,
|
||||||
height: 48,
|
background: theme.bgHover, display: 'flex', alignItems: 'center',
|
||||||
borderRadius: 10,
|
justifyContent: 'center', margin: '0 auto 16px',
|
||||||
background: theme.bgHover,
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
margin: '0 auto 16px',
|
|
||||||
}}>
|
}}>
|
||||||
<svg width="20" height="20" viewBox="0 0 24 24" fill="none">
|
<svg width="20" height="20" viewBox="0 0 24 24" fill="none">
|
||||||
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke={theme.text3} strokeWidth="1.5"/>
|
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke={theme.text3} strokeWidth="1.5"/>
|
||||||
@@ -602,52 +477,33 @@ export const RagChatPage: React.FC = () => {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* ── Clear confirm modal ───────────────────────────────── */}
|
||||||
{showClearConfirm && (
|
{showClearConfirm && (
|
||||||
<div style={{
|
<div style={{
|
||||||
position: 'fixed',
|
position: 'fixed', top: 0, left: 0, right: 0, bottom: 0,
|
||||||
top: 0,
|
background: 'rgba(0,0,0,0.6)', display: 'flex',
|
||||||
left: 0,
|
alignItems: 'center', justifyContent: 'center', zIndex: 1000,
|
||||||
right: 0,
|
|
||||||
bottom: 0,
|
|
||||||
background: 'rgba(0,0,0,0.6)',
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
zIndex: 1000,
|
|
||||||
}}>
|
}}>
|
||||||
<div style={{
|
<div style={{
|
||||||
padding: 24,
|
padding: 24, background: theme.bgCard, borderRadius: 16,
|
||||||
background: theme.bgCard,
|
maxWidth: 400, border: `1px solid ${theme.border}`,
|
||||||
borderRadius: 16,
|
|
||||||
maxWidth: 400,
|
|
||||||
border: `1px solid ${theme.border}`,
|
|
||||||
}}>
|
}}>
|
||||||
<div style={{ fontSize: 15, fontWeight: 600, marginBottom: 12, color: theme.text }}>确定清空对话?</div>
|
<div style={{ fontSize: 15, fontWeight: 600, marginBottom: 12, color: theme.text }}>确定清空对话?</div>
|
||||||
<div style={{ fontSize: 13, color: theme.text2, marginBottom: 20 }}>此操作不可恢复</div>
|
<div style={{ fontSize: 13, color: theme.text2, marginBottom: 20 }}>此操作不可恢复,会话历史将被重置</div>
|
||||||
<div style={{ display: 'flex', gap: 10, justifyContent: 'flex-end' }}>
|
<div style={{ display: 'flex', gap: 10, justifyContent: 'flex-end' }}>
|
||||||
<button
|
<button
|
||||||
onClick={() => setShowClearConfirm(false)}
|
onClick={() => setShowClearConfirm(false)}
|
||||||
style={{
|
style={{
|
||||||
padding: '10px 18px',
|
padding: '10px 18px', fontSize: 13, background: theme.bgHover,
|
||||||
fontSize: 13,
|
border: 'none', borderRadius: 8, color: theme.text2, cursor: 'pointer',
|
||||||
background: theme.bgHover,
|
|
||||||
border: 'none',
|
|
||||||
borderRadius: 8,
|
|
||||||
color: theme.text2,
|
|
||||||
cursor: 'pointer',
|
|
||||||
}}
|
}}
|
||||||
>取消</button>
|
>取消</button>
|
||||||
<button
|
<button
|
||||||
onClick={clearMessages}
|
onClick={clearMessages}
|
||||||
style={{
|
style={{
|
||||||
padding: '10px 18px',
|
padding: '10px 18px', fontSize: 13, fontWeight: 600,
|
||||||
fontSize: 13,
|
background: theme.accent, border: 'none', borderRadius: 8,
|
||||||
fontWeight: 600,
|
color: '#fff', cursor: 'pointer',
|
||||||
background: theme.accent,
|
|
||||||
border: 'none',
|
|
||||||
borderRadius: 8,
|
|
||||||
color: '#fff',
|
|
||||||
cursor: 'pointer',
|
|
||||||
}}
|
}}
|
||||||
>确认</button>
|
>确认</button>
|
||||||
</div>
|
</div>
|
||||||
@@ -655,47 +511,31 @@ export const RagChatPage: React.FC = () => {
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{/* ── Source detail modal ───────────────────────────────── */}
|
||||||
{selectedRetrieval && (
|
{selectedRetrieval && (
|
||||||
<div
|
<div
|
||||||
onClick={() => setSelectedRetrieval(null)}
|
onClick={() => setSelectedRetrieval(null)}
|
||||||
style={{
|
style={{
|
||||||
position: 'fixed',
|
position: 'fixed', top: 0, left: 0, right: 0, bottom: 0,
|
||||||
top: 0,
|
background: 'rgba(0,0,0,0.6)', display: 'flex',
|
||||||
left: 0,
|
alignItems: 'center', justifyContent: 'center', zIndex: 1000,
|
||||||
right: 0,
|
|
||||||
bottom: 0,
|
|
||||||
background: 'rgba(0,0,0,0.6)',
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
zIndex: 1000,
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
onClick={(e) => e.stopPropagation()}
|
onClick={e => e.stopPropagation()}
|
||||||
style={{
|
style={{
|
||||||
width: 520,
|
width: 520, maxWidth: '90%', maxHeight: '80%',
|
||||||
maxWidth: '90%',
|
overflowY: 'auto', padding: 24, background: theme.bgCard,
|
||||||
maxHeight: '80%',
|
borderRadius: 16, border: `1px solid ${theme.accent}`,
|
||||||
padding: 24,
|
|
||||||
background: theme.bgCard,
|
|
||||||
borderRadius: 16,
|
|
||||||
border: `1px solid ${theme.accent}`,
|
|
||||||
boxShadow: '0 8px 32px rgba(0,0,0,0.3)',
|
boxShadow: '0 8px 32px rgba(0,0,0,0.3)',
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<div style={{
|
<div style={{
|
||||||
display: 'flex',
|
display: 'flex', alignItems: 'center',
|
||||||
alignItems: 'center',
|
justifyContent: 'space-between', marginBottom: 16,
|
||||||
justifyContent: 'space-between',
|
|
||||||
marginBottom: 16,
|
|
||||||
}}>
|
}}>
|
||||||
<div style={{ display: 'flex', alignItems: 'center', gap: 10, flexWrap: 'wrap' }}>
|
<div style={{ display: 'flex', alignItems: 'center', gap: 10, flexWrap: 'wrap' }}>
|
||||||
<div style={{
|
<div style={{ padding: '4px 10px', background: theme.gradientAccent, borderRadius: 6 }}>
|
||||||
padding: '4px 10px',
|
|
||||||
background: theme.gradientAccent,
|
|
||||||
borderRadius: 6,
|
|
||||||
}}>
|
|
||||||
<span className="mono" style={{ fontSize: 11, fontWeight: 600, color: '#fff' }}>
|
<span className="mono" style={{ fontSize: 11, fontWeight: 600, color: '#fff' }}>
|
||||||
{(selectedRetrieval.score * 100).toFixed(0)}%
|
{(selectedRetrieval.score * 100).toFixed(0)}%
|
||||||
</span>
|
</span>
|
||||||
@@ -707,23 +547,15 @@ export const RagChatPage: React.FC = () => {
|
|||||||
target="_blank"
|
target="_blank"
|
||||||
rel="noreferrer"
|
rel="noreferrer"
|
||||||
style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }}
|
style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }}
|
||||||
>
|
>下载关联文档</a>
|
||||||
下载关联文档
|
|
||||||
</a>
|
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<button
|
<button
|
||||||
onClick={() => setSelectedRetrieval(null)}
|
onClick={() => setSelectedRetrieval(null)}
|
||||||
style={{
|
style={{
|
||||||
width: 28,
|
width: 28, height: 28, background: theme.bgHover,
|
||||||
height: 28,
|
border: 'none', borderRadius: 6, cursor: 'pointer',
|
||||||
background: theme.bgHover,
|
display: 'flex', alignItems: 'center', justifyContent: 'center',
|
||||||
border: 'none',
|
|
||||||
borderRadius: 6,
|
|
||||||
cursor: 'pointer',
|
|
||||||
display: 'flex',
|
|
||||||
alignItems: 'center',
|
|
||||||
justifyContent: 'center',
|
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none">
|
<svg width="14" height="14" viewBox="0 0 24 24" fill="none">
|
||||||
@@ -731,23 +563,15 @@ export const RagChatPage: React.FC = () => {
|
|||||||
</svg>
|
</svg>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
<div style={{
|
<div style={{ padding: '10px 14px', background: theme.bgHover, borderRadius: 8, marginBottom: 16 }}>
|
||||||
padding: '10px 14px',
|
|
||||||
background: theme.bgHover,
|
|
||||||
borderRadius: 8,
|
|
||||||
marginBottom: 16,
|
|
||||||
}}>
|
|
||||||
<span className="mono" style={{ fontSize: 12, color: theme.accent }}>{selectedRetrieval.clause}</span>
|
<span className="mono" style={{ fontSize: 12, color: theme.accent }}>{selectedRetrieval.clause}</span>
|
||||||
{selectedRetrieval.docId && (
|
{selectedRetrieval.docId && (
|
||||||
<span className="mono" style={{ fontSize: 11, color: theme.text3, marginLeft: 8 }}>{selectedRetrieval.docId}</span>
|
<span className="mono" style={{ fontSize: 11, color: theme.text3, marginLeft: 8 }}>{selectedRetrieval.docId}</span>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div style={{
|
<div style={{ fontSize: 14, lineHeight: 1.7, color: theme.text2, whiteSpace: 'pre-wrap' }}>
|
||||||
fontSize: 14,
|
{selectedRetrieval.content}
|
||||||
lineHeight: 1.7,
|
</div>
|
||||||
color: theme.text2,
|
|
||||||
whiteSpace: 'pre-wrap',
|
|
||||||
}}>{selectedRetrieval.content}</div>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ export interface ChatMessage {
|
|||||||
role: 'user' | 'assistant';
|
role: 'user' | 'assistant';
|
||||||
content: string;
|
content: string;
|
||||||
retrievalIds?: number[];
|
retrievalIds?: number[];
|
||||||
|
sources?: RetrievalData[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface RetrievalData {
|
export interface RetrievalData {
|
||||||
|
|||||||
Reference in New Issue
Block a user