fix 文档管理模块 & 法规对话模块
This commit is contained in:
15
.env
15
.env
@@ -7,26 +7,29 @@ APP_VERSION=0.1.0
|
||||
DEBUG=false
|
||||
|
||||
# ===== Milvus向量数据库配置(已有)=====
|
||||
MILVUS_HOST=localhost
|
||||
MILVUS_HOST=6.86.80.8
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_COLLECTION=regulations_dense_1024_v1
|
||||
MILVUS_DB_NAME=default
|
||||
MILVUS_INDEX_TYPE=IVF_FLAT
|
||||
MILVUS_NLIST=128
|
||||
MILVUS_NPROBE=16
|
||||
|
||||
# ===== MinIO对象存储配置(已有)=====
|
||||
MINIO_ENDPOINT=localhost:9000
|
||||
MINIO_ENDPOINT=6.86.80.8:9000
|
||||
MINIO_ACCESS_KEY=minioadmin
|
||||
MINIO_SECRET_KEY=minioadmin
|
||||
MINIO_BUCKET=compliance-docs
|
||||
MINIO_SECURE=false
|
||||
|
||||
# ===== Redis配置(已有)=====
|
||||
REDIS_HOST=localhost
|
||||
REDIS_HOST=6.86.80.8
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=redis@123
|
||||
REDIS_DB=0
|
||||
|
||||
# ===== PostgreSQL配置(已有)=====
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_HOST=6.86.80.8
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USER=postgresql
|
||||
POSTGRES_PASSWORD=postgresql123456
|
||||
@@ -43,6 +46,10 @@ EMBEDDING_TIMEOUT_SECONDS=120
|
||||
CHUNK_SIZE=512
|
||||
CHUNK_OVERLAP=50
|
||||
MAX_FILE_SIZE_MB=100
|
||||
PARSER_BACKEND=aliyun
|
||||
CHUNK_BACKEND=aliyun
|
||||
# 文档元数据存储后端:json(默认)或 postgres
|
||||
DOCUMENT_REPOSITORY_BACKEND=json
|
||||
|
||||
# ===== API配置 =====
|
||||
API_HOST=0.0.0.0
|
||||
|
||||
@@ -6,6 +6,9 @@ MILVUS_HOST=6.86.80.8
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_COLLECTION=regulations_dense_1024_v1
|
||||
MILVUS_DB_NAME=default
|
||||
MILVUS_INDEX_TYPE=IVF_FLAT
|
||||
MILVUS_NLIST=128
|
||||
MILVUS_NPROBE=16
|
||||
|
||||
# ===== MinIO对象存储配置(已有)=====
|
||||
MINIO_ENDPOINT=6.86.80.8:9000
|
||||
@@ -26,3 +29,7 @@ POSTGRES_PORT=5432
|
||||
POSTGRES_USER=postgresql
|
||||
POSTGRES_PASSWORD=postgresql123456
|
||||
POSTGRES_DB=compliance_db
|
||||
|
||||
# ===== 文档元数据后端 =====
|
||||
# 改为 postgres 以启用 PG 持久化(structure_nodes + semantic_blocks 入库)
|
||||
DOCUMENT_REPOSITORY_BACKEND=json
|
||||
|
||||
22
.env.example
22
.env.example
@@ -7,10 +7,13 @@ APP_VERSION=0.1.0
|
||||
DEBUG=false
|
||||
|
||||
# ===== Milvus向量数据库配置 =====
|
||||
MILVUS_HOST=localhost
|
||||
MILVUS_HOST=6.86.80.8
|
||||
MILVUS_PORT=19530
|
||||
MILVUS_COLLECTION=regulations_dense_1024_v1
|
||||
MILVUS_DB_NAME=default
|
||||
MILVUS_INDEX_TYPE=IVF_FLAT
|
||||
MILVUS_NLIST=128
|
||||
MILVUS_NPROBE=16
|
||||
|
||||
# ===== 嵌入模型配置 =====
|
||||
EMBEDDING_MODEL=text-embedding-v3
|
||||
@@ -20,20 +23,20 @@ EMBEDDING_BASE_URL=http://6.86.80.4:30080/v1
|
||||
EMBEDDING_TIMEOUT_SECONDS=120
|
||||
|
||||
# ===== MinIO对象存储配置 =====
|
||||
MINIO_ENDPOINT=localhost:9000
|
||||
MINIO_ENDPOINT=6.86.80.8:9000
|
||||
MINIO_ACCESS_KEY=minioadmin
|
||||
MINIO_SECRET_KEY=minioadmin123
|
||||
MINIO_BUCKET=compliance-docs
|
||||
MINIO_SECURE=false
|
||||
|
||||
# ===== Redis配置 =====
|
||||
REDIS_HOST=localhost
|
||||
REDIS_HOST=6.86.80.8
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=
|
||||
REDIS_DB=0
|
||||
|
||||
# ===== PostgreSQL配置 =====
|
||||
POSTGRES_HOST=localhost
|
||||
POSTGRES_HOST=6.86.80.8
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_USER=compliance
|
||||
POSTGRES_PASSWORD=compliance123
|
||||
@@ -46,6 +49,8 @@ MAX_FILE_SIZE_MB=100
|
||||
DOCUMENT_METADATA_PATH=backend/data/documents.json
|
||||
PARSER_BACKEND=aliyun
|
||||
CHUNK_BACKEND=aliyun
|
||||
# 文档元数据存储后端:json(默认,无需数据库)或 postgres(启用 PG 持久化)
|
||||
DOCUMENT_REPOSITORY_BACKEND=json
|
||||
|
||||
# ===== 阿里云文档解析 =====
|
||||
ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id
|
||||
@@ -88,9 +93,18 @@ DEEPSEEK_MODEL=deepseek-v4-flash
|
||||
|
||||
# ===== RAG配置 =====
|
||||
RAG_TOP_K=10
|
||||
RAG_RETRIEVAL_TOP_K=20
|
||||
RAG_MAX_CONTEXT_TOKENS=4000
|
||||
RAG_SUMMARY_MAX_TOKENS=1024
|
||||
|
||||
# ===== Reranker配置(Cross-Encoder精排,默认关闭)=====
|
||||
# 设置 RERANKER_ENABLED=true 并配置 RERANKER_BASE_URL 以启用精排
|
||||
RERANKER_ENABLED=false
|
||||
RERANKER_BASE_URL=
|
||||
RERANKER_MODEL=BAAI/bge-reranker-v2-m3
|
||||
RERANKER_API_KEY=
|
||||
RERANKER_TOP_K=5
|
||||
|
||||
# ===== 会话配置 =====
|
||||
SESSION_MAX_SESSIONS=100
|
||||
SESSION_TIMEOUT_MINUTES=30
|
||||
|
||||
@@ -23,6 +23,8 @@ class DocumentUploadResponse(BaseModel):
|
||||
num_chunks: int = Field(default=0, description="分块数量")
|
||||
summary: str = Field(default="", description="LLM生成的文档摘要")
|
||||
summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)")
|
||||
regulation_type: str = Field(default="", description="法规类型")
|
||||
version: str = Field(default="", description="文档版本号")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||
|
||||
|
||||
|
||||
@@ -14,30 +14,22 @@ from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
)
|
||||
from app.services.mock_data import (
|
||||
generate_task_id,
|
||||
get_mock_compliance_result,
|
||||
get_mock_compliance_chat_response,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store: dict[str, dict] = {}
|
||||
|
||||
# Store uploaded compliance files inside the local backend data directory.
|
||||
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||
async def analyze_document(file: UploadFile = File(...)):
|
||||
"""Handle analyze document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
task_id = generate_task_id()
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
|
||||
|
||||
@@ -45,7 +37,6 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
with file_path.open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id] = {
|
||||
"task_id": task_id,
|
||||
"file_path": str(file_path),
|
||||
@@ -53,8 +44,6 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
"result": None,
|
||||
}
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id]["status"] = "completed"
|
||||
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
|
||||
|
||||
@@ -65,47 +54,44 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
async def get_result(task_id: str):
|
||||
"""Return result."""
|
||||
if task_id not in tasks_store:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
return get_mock_compliance_result(task_id)
|
||||
|
||||
task = tasks_store[task_id]
|
||||
|
||||
if task["status"] == "processing":
|
||||
return {"status": "processing", "message": "分析进行中"}
|
||||
|
||||
return task["result"]
|
||||
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""Handle compliance chat."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
intent_map = {
|
||||
1: "车身结构设计",
|
||||
2: "动力系统配置",
|
||||
3: "安全配置设计",
|
||||
}
|
||||
intent = intent_map.get(segment_id, "车身结构设计")
|
||||
"""Stream compliance Q&A grounded in real vector retrieval."""
|
||||
query = request.query
|
||||
if request.segment_context:
|
||||
query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
response = get_mock_compliance_chat_response(intent, request.query)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = response.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05)
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
top_k=5,
|
||||
prompt_template="compliance_qa",
|
||||
)
|
||||
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
"""Translate agent SSE events to compliance chunk/done format."""
|
||||
for event in event_stream:
|
||||
event_type = event.get("event", "")
|
||||
if event_type == "content":
|
||||
text = event.get("data", "")
|
||||
if text:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
|
||||
@@ -80,6 +80,8 @@ async def get_document_status(doc_id: str):
|
||||
num_chunks=document.chunk_count,
|
||||
summary=document.summary,
|
||||
summary_latency_ms=document.summary_latency_ms,
|
||||
regulation_type=document.regulation_type,
|
||||
version=document.version,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,7 +125,7 @@ async def list_documents():
|
||||
@router.get("/management-list")
|
||||
async def get_document_management_list():
|
||||
"""Return document management list."""
|
||||
documents = get_document_query_service().list_documents(limit=10)
|
||||
documents = get_document_query_service().list_documents()
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
@@ -131,10 +133,37 @@ async def get_document_management_list():
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"size_bytes": item.size_bytes,
|
||||
"summary": item.summary,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
"regulation_type": item.regulation_type,
|
||||
"version": item.version,
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{doc_id}")
async def delete_document(doc_id: str):
    """Remove the document identified by ``doc_id`` through the command service.

    Returns a small confirmation payload on success.

    Raises:
        HTTPException: 404 when the command service reports that nothing was
            deleted.
    """
    was_removed = get_document_command_service().delete(doc_id)
    if was_removed:
        return {"doc_id": doc_id, "deleted": True}
    # A falsy result from the service means there was nothing to delete.
    raise HTTPException(status_code=404, detail="文档不存在")
|
||||
|
||||
|
||||
@router.post("/{doc_id}/retry", response_model=DocumentUploadResponse)
async def retry_document(doc_id: str):
    """Re-run processing for a stored document and return its upload response.

    Raises:
        HTTPException: 404 when the query service has no document for
            ``doc_id``; 500 when the retry result has status ``"failed"`` or
            when any other exception escapes the retry.
    """
    try:
        # Answer 404 for unknown ids (as delete_document does) instead of
        # letting the service's "failed" result for a missing document be
        # reported as a server error.
        if get_document_query_service().get(doc_id) is None:
            raise HTTPException(status_code=404, detail="文档不存在")
        result = get_document_command_service().retry(doc_id)
        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)
    except HTTPException:
        # Deliberate HTTP errors raised above pass through untouched.
        raise
    except Exception as exc:
        logger.exception("文档重试失败")
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
@@ -9,73 +9,74 @@ from typing import AsyncGenerator
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.services.mock_data import (
|
||||
get_mock_quick_questions,
|
||||
get_mock_retrieval,
|
||||
get_mock_rag_answer,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||||
|
||||
_DEFAULT_QUICK_QUESTIONS = [
|
||||
{"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"},
|
||||
{"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"},
|
||||
{"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"},
|
||||
{"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"},
|
||||
{"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"},
|
||||
{"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"},
|
||||
]
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
"""Handle rag chat."""
|
||||
"""Stream RAG Q&A using the real agent service."""
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
)
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
docs = get_mock_retrieval(request.query, top_k=request.top_k)
|
||||
|
||||
retrieved_data = [
|
||||
"""Translate agent SSE events to rag format."""
|
||||
for event in event_stream:
|
||||
event_type = event.get("event", "")
|
||||
data = event.get("data", "")
|
||||
if event_type == "sources":
|
||||
docs = [
|
||||
{
|
||||
"id": d["id"],
|
||||
"score": d["score"],
|
||||
"preview": d["preview"],
|
||||
"doc_name": d.get("doc_name", ""),
|
||||
"clause": d.get("clause", ""),
|
||||
"id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1),
|
||||
"score": s.get("score", 0),
|
||||
"preview": s.get("content", "")[:200],
|
||||
"doc_name": s.get("doc_name", ""),
|
||||
"clause": s.get("section_title", "法规片段"),
|
||||
"doc_id": s.get("doc_id"),
|
||||
"download_url": (
|
||||
f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None
|
||||
),
|
||||
}
|
||||
for d in docs
|
||||
for idx, s in enumerate(data if isinstance(data, list) else [])
|
||||
]
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
f"event: message\ndata: "
|
||||
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
answer = get_mock_rag_answer(request.query)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = answer.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
elif event_type == "content":
|
||||
if data:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "status":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
@@ -86,9 +87,13 @@ async def rag_chat(request: RagChatRequest):
|
||||
|
||||
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """Return quick questions from ``settings.rag_quick_questions`` or the defaults.

    Each configured entry may be a plain string (the question text) or a dict
    with ``question`` and optional ``category`` keys. Entries of any other
    type, and entries whose question text is missing or blank, are skipped.
    When the setting is absent, not a list, or yields no usable entry, the
    built-in ``_DEFAULT_QUICK_QUESTIONS`` are returned instead.
    """
    default_category = "法规问答"
    configured = getattr(settings, "rag_quick_questions", None)

    questions: list[QuickQuestion] = []
    if isinstance(configured, list):
        for entry in configured:
            if isinstance(entry, str):
                text, category = entry, default_category
            elif isinstance(entry, dict):
                text = entry.get("question", "")
                category = entry.get("category", default_category)
            else:
                # Unsupported entry type: skip it rather than fail the request.
                continue
            if not isinstance(text, str) or not text.strip():
                continue
            # Ids are sequential over the entries that are actually returned.
            questions.append(
                QuickQuestion(id=str(len(questions) + 1), question=text, category=category)
            )

    if not questions:
        questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS]
    return QuickQuestionsResponse(questions=questions)
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.domain.documents import (
|
||||
DocumentParser,
|
||||
DocumentRepository,
|
||||
DocumentStatus,
|
||||
ParseArtifactStore,
|
||||
ParsedDocument,
|
||||
)
|
||||
from app.domain.retrieval import EmbeddingProvider, VectorIndex
|
||||
@@ -47,6 +48,7 @@ class DocumentCommandService:
|
||||
chunk_builder: ChunkBuilder,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
vector_index: VectorIndex,
|
||||
parse_artifact_store: ParseArtifactStore | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Document Command Service instance."""
|
||||
self.document_repository = document_repository
|
||||
@@ -55,6 +57,7 @@ class DocumentCommandService:
|
||||
self.chunk_builder = chunk_builder
|
||||
self.embedding_provider = embedding_provider
|
||||
self.vector_index = vector_index
|
||||
self.parse_artifact_store = parse_artifact_store
|
||||
|
||||
def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]:
|
||||
"""Persist parse artifacts so troubleshooting does not depend on provider retention windows."""
|
||||
@@ -143,6 +146,15 @@ class DocumentCommandService:
|
||||
"processing_stage": "parsed",
|
||||
},
|
||||
)
|
||||
if self.parse_artifact_store:
|
||||
try:
|
||||
self.parse_artifact_store.save(
|
||||
doc_id,
|
||||
parsed_document.structure_nodes,
|
||||
parsed_document.semantic_blocks,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
|
||||
|
||||
chunks = self.chunk_builder.build(
|
||||
parsed_document=parsed_document,
|
||||
@@ -205,20 +217,120 @@ class DocumentCommandService:
|
||||
logger.warning("临时文件清理失败: {}", temp_path)
|
||||
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
    """Delete a document's binary, vector chunks, parse artifacts and record.

    The binary store, vector index and (when configured) parse artifact store
    are cleaned up on a best-effort basis: a failure in any of them is logged
    together with its cause and does not stop the remaining steps. The
    repository record is deleted last.

    Args:
        doc_id: Identifier of the document to delete.

    Returns:
        False when the repository has no record for ``doc_id`` (nothing is
        touched in that case); True once the repository delete has been called.
    """
    document = self.document_repository.get(doc_id)
    if not document:
        return False

    try:
        self.binary_store.delete(document.object_name)
    except Exception as exc:
        # Include the cause so a failed cleanup can be diagnosed from the log.
        logger.warning("Binary delete failed for doc_id={}: {}", doc_id, exc)

    try:
        self.vector_index.delete_by_document(doc_id)
    except Exception as exc:
        logger.warning("Vector delete failed for doc_id={}: {}", doc_id, exc)

    if self.parse_artifact_store:
        try:
            self.parse_artifact_store.delete(doc_id)
        except Exception as exc:
            logger.warning("ParseArtifactStore delete failed for doc_id={}: {}", doc_id, exc)

    self.document_repository.delete(doc_id)
    return True
|
||||
|
||||
def retry(self, doc_id: str) -> DocumentProcessResult:
    """Run the upload-and-process pipeline again from the stored binary.

    Args:
        doc_id: Identifier of the document to re-process.

    Returns:
        The result of ``upload_and_process`` for the stored document, or a
        result with status ``"failed"`` when the repository has no record for
        ``doc_id``.
    """
    stored = self.document_repository.get(doc_id)
    if stored:
        # Re-read the original bytes, then replay the recorded upload parameters.
        payload = self.binary_store.read(stored.object_name)
        replay_args = {
            "doc_id": doc_id,
            "file_name": stored.file_name,
            "content": payload,
            "content_type": stored.content_type,
            "doc_name": stored.doc_name,
            "regulation_type": stored.regulation_type,
            "version": stored.version,
            "generate_summary": bool(stored.metadata.get("generate_summary", False)),
        }
        return self.upload_and_process(**replay_args)
    return DocumentProcessResult(doc_id=doc_id, doc_name="", status="failed", message="文档不存在")
|
||||
|
||||
|
||||
class DocumentQueryService:
|
||||
"""Provide the Document Query Service service."""
|
||||
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore) -> None:
|
||||
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore, vector_index: VectorIndex) -> None:
|
||||
"""Initialize the Document Query Service instance."""
|
||||
self.document_repository = document_repository
|
||||
self.binary_store = binary_store
|
||||
self.vector_index = vector_index
|
||||
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
"""Handle get for the Document Query Service instance."""
|
||||
return self.document_repository.get(doc_id)
|
||||
|
||||
def list_documents(self, limit: int | None = None) -> list[Document]:
    """Return documents reconciled against the vector index, newest first.

    Steps:
        1. Ask the vector index for its per-document metadata rows.
        2. Load every metadata record from the repository.
        3. A record that has a row gets status INDEXED and the row's
           ``chunk_count``; empty ``doc_name`` / ``regulation_type`` /
           ``version`` fields are backfilled from the row. A record marked
           INDEXED that has no row is reported as FAILED.
        4. Rows without any metadata record are surfaced as synthetic INDEXED
           documents.

    If the vector-index query itself fails, reconciliation is skipped and the
    repository records are returned with their stored status, so an index
    outage is not reported as every document having lost its vectors.

    Args:
        limit: Maximum number of documents to return, applied after merging
            and sorting. ``None`` returns everything.

    Returns:
        Documents sorted by ``updated_at`` in descending order.
    """
    try:
        index_rows = self.vector_index.list_document_metadata()
        index_available = True
    except Exception as exc:
        logger.warning("Vector index metadata query failed, skipping reconciliation: {}", exc)
        index_rows = []
        index_available = False

    rows_by_id: dict[str, dict] = {row["doc_id"]: row for row in index_rows}

    # Load the full metadata set. Applying `limit` here would hide records from
    # the membership check below, so documents beyond the limit would be
    # mistaken for index-only documents and returned a second time as
    # synthetic entries.
    meta_docs = self.document_repository.list(limit=None)
    known_ids = {doc.doc_id for doc in meta_docs}

    result: list[Document] = []

    # NOTE(review): records are adjusted in place on the objects returned by
    # the repository; nothing is written back here. Confirm the repository
    # does not hand out shared/cached instances.
    for doc in meta_docs:
        row = rows_by_id.get(doc.doc_id)
        if row is not None:
            doc.chunk_count = row["chunk_count"]
            doc.status = DocumentStatus.INDEXED
            # Backfill fields that may be missing from older metadata records.
            if not doc.doc_name and row.get("doc_name"):
                doc.doc_name = row["doc_name"]
            if not doc.regulation_type and row.get("regulation_type"):
                doc.regulation_type = row["regulation_type"]
            if not doc.version and row.get("version"):
                doc.version = row["version"]
        elif index_available and doc.status == DocumentStatus.INDEXED:
            # Metadata says indexed but the vector index holds no chunks.
            doc.status = DocumentStatus.FAILED
            doc.error_message = "向量数据库中未找到对应数据"
        result.append(doc)

    # Surface index-only documents that have no metadata record at all.
    for doc_id, row in rows_by_id.items():
        if doc_id in known_ids:
            continue
        result.append(
            Document(
                doc_id=doc_id,
                doc_name=row.get("doc_name", doc_id),
                file_name=row.get("doc_name", doc_id),
                object_name="",
                content_type="",
                size_bytes=0,
                status=DocumentStatus.INDEXED,
                regulation_type=row.get("regulation_type", ""),
                version=row.get("version", ""),
                chunk_count=row["chunk_count"],
            )
        )

    result.sort(key=lambda d: d.updated_at, reverse=True)
    return result[:limit] if limit is not None else result
|
||||
|
||||
def download(self, doc_id: str) -> tuple[Document, bytes]:
|
||||
"""Handle download for the Document Query Service instance."""
|
||||
|
||||
@@ -3,17 +3,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk
|
||||
from app.domain.retrieval.ports import Reranker
|
||||
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
||||
|
||||
|
||||
|
||||
class KnowledgeRetrievalService:
    """Retrieve chunks for a query and optionally rerank them."""

    def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None:
        """Store the retriever and the optional reranker configuration.

        Args:
            retriever: Component that returns candidate chunks for a query.
            reranker: Optional reranker applied to the candidates; ``None``
                disables reranking.
            reranker_top_k: Upper bound on the number of chunks kept after
                reranking.
        """
        self.retriever = retriever
        self.reranker = reranker
        self.reranker_top_k = reranker_top_k

    def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
        """Retrieve chunks for ``query`` and rerank them when a reranker is set.

        Without a reranker, ``top_k`` candidates are requested and at most
        ``top_k`` are returned. With a reranker, a wider candidate pool of
        ``max(top_k * 4, 20)`` is requested and the reranker keeps at most
        ``min(top_k, reranker_top_k)`` of them, so the caller never receives
        more results than it asked for.
        """
        use_reranker = self.reranker is not None
        candidate_k = max(top_k * 4, 20) if use_reranker else top_k
        retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters)
        candidates = self.retriever.retrieve(retrieval_query)
        if use_reranker and candidates:
            final_k = min(top_k, self.reranker_top_k)
            return self.reranker.rerank(query, candidates, top_k=final_k)
        return candidates[:top_k]
|
||||
|
||||
@@ -31,7 +31,7 @@ class Settings(BaseSettings):
|
||||
debug: bool = Field(default=False, description="调试模式")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
milvus_host: str = Field(default="localhost", description="Milvus服务地址")
|
||||
milvus_host: str = Field(default="6.86.80.8", description="Milvus服务地址")
|
||||
milvus_port: int = Field(default=19530, description="Milvus服务端口")
|
||||
milvus_collection: str = Field(default="regulations_dense_1024_v1", description="法规向量集合名称")
|
||||
milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
|
||||
@@ -54,20 +54,20 @@ class Settings(BaseSettings):
|
||||
parser_failure_mode: str = Field(default="fail", description="解析失败策略")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址")
|
||||
minio_endpoint: str = Field(default="6.86.80.8:9000", description="MinIO服务地址")
|
||||
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
|
||||
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
|
||||
minio_bucket: str = Field(default="upload-files", description="文档存储桶名称")
|
||||
minio_secure: bool = Field(default=False, description="是否使用HTTPS")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
redis_host: str = Field(default="localhost", description="Redis服务地址")
|
||||
redis_host: str = Field(default="6.86.80.8", description="Redis服务地址")
|
||||
redis_port: int = Field(default=6379, description="Redis服务端口")
|
||||
redis_password: str = Field(default="", description="Redis密码")
|
||||
redis_db: int = Field(default=0, description="Redis数据库编号")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址")
|
||||
postgres_host: str = Field(default="6.86.80.8", description="PostgreSQL服务地址")
|
||||
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
|
||||
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
|
||||
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
|
||||
@@ -80,6 +80,7 @@ class Settings(BaseSettings):
|
||||
document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径")
|
||||
parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)")
|
||||
chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)")
|
||||
document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
api_host: str = Field(default="0.0.0.0", description="API服务地址")
|
||||
@@ -104,9 +105,16 @@ class Settings(BaseSettings):
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
rag_top_k: int = Field(default=5, description="检索召回数量")
|
||||
rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量(reranker 启用时生效)")
|
||||
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
|
||||
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
|
||||
|
||||
reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排")
|
||||
reranker_base_url: str = Field(default="", description="Reranker API 地址")
|
||||
reranker_model: str = Field(default="BAAI/bge-reranker-v2-m3", description="Reranker 模型名称")
|
||||
reranker_api_key: str = Field(default="", description="Reranker API 密钥")
|
||||
reranker_top_k: int = Field(default=5, description="精排后保留的最终结果数量")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型")
|
||||
milvus_nlist: int = Field(default=128, description="Milvus nlist参数")
|
||||
|
||||
@@ -25,7 +25,7 @@ class Settings(BaseSettings):
|
||||
dashscope_api_key: str = ""
|
||||
|
||||
# Milvus
|
||||
milvus_host: str = "localhost"
|
||||
milvus_host: str = "6.86.80.8"
|
||||
milvus_port: int = 19530
|
||||
milvus_collection: str = "regulations_dense_1024_v1"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Initialize the app.domain.documents package."""
|
||||
|
||||
from .models import Chunk, Document, DocumentStatus, ParsedDocument
|
||||
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository
|
||||
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository, ParseArtifactStore
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
@@ -14,4 +14,5 @@ __all__ = [
|
||||
"DocumentBinaryStore",
|
||||
"DocumentParser",
|
||||
"DocumentRepository",
|
||||
"ParseArtifactStore",
|
||||
]
|
||||
|
||||
@@ -31,6 +31,11 @@ class DocumentRepository(ABC):
|
||||
"""Handle list for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
"""Delete a document record. Returns True if deleted, False if not found."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_status(
|
||||
self,
|
||||
@@ -94,3 +99,32 @@ class ChunkBuilder(ABC):
|
||||
) -> list[Chunk]:
|
||||
"""Handle build for the Chunk Builder instance."""
|
||||
pass
|
||||
|
||||
|
||||
class ParseArtifactStore(ABC):
    """Storage port for the artifacts produced when a document is parsed.

    Two artifact kinds are covered, structure nodes and semantic blocks. Both
    are exchanged as plain dictionaries and are grouped by the owning
    document id.
    """

    @abstractmethod
    def save(
        self,
        doc_id: str,
        structure_nodes: list[dict],
        semantic_blocks: list[dict],
    ) -> None:
        """Store the structure nodes and semantic blocks belonging to ``doc_id``."""
        ...

    @abstractmethod
    def delete(self, doc_id: str) -> None:
        """Drop every parse artifact stored for ``doc_id``."""
        ...

    @abstractmethod
    def get_semantic_blocks(self, doc_id: str) -> list[dict]:
        """Fetch all semantic blocks stored for ``doc_id``."""
        ...

    @abstractmethod
    def get_structure_nodes(self, doc_id: str) -> list[dict]:
        """Fetch all structure nodes stored for ``doc_id``."""
        ...
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Initialize the app.domain.retrieval package."""
|
||||
|
||||
from .models import RetrievalQuery, RetrievedChunk
|
||||
from .ports import EmbeddingProvider, Retriever, VectorIndex
|
||||
from .ports import EmbeddingProvider, Reranker, Retriever, VectorIndex
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Retriever", "VectorIndex"]
|
||||
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Reranker", "Retriever", "VectorIndex"]
|
||||
|
||||
@@ -10,7 +10,6 @@ from .models import RetrievalQuery, RetrievedChunk
|
||||
# Keep domain contracts explicit so adapters can swap implementations cleanly.
|
||||
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Provide the Embedding Provider provider."""
|
||||
@abstractmethod
|
||||
@@ -41,12 +40,35 @@ class VectorIndex(ABC):
|
||||
"""Handle search for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
def count_by_document(self) -> dict[str, int]:
    """Report how many chunks the vector store holds for each document.

    The result maps ``doc_id`` to that document's chunk count.
    """
    ...
|
||||
|
||||
@abstractmethod
def list_document_metadata(self) -> list[dict]:
    """List one metadata row per document known to the vector store.

    Required keys in every row: ``doc_id``, ``doc_name``, ``chunk_count``.
    Keys that may additionally be present: ``regulation_type``, ``version``.
    """
    ...
|
||||
|
||||
@abstractmethod
def health(self) -> dict:
    """Report the health of the vector index as a dictionary."""
    ...
|
||||
|
||||
|
||||
class Reranker(ABC):
    """Port for re-scoring and re-ordering retrieval candidates with a cross-encoder model."""

    @abstractmethod
    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Score ``chunks`` against ``query`` and return the ``top_k`` best, highest score first."""
        ...
|
||||
|
||||
|
||||
class Retriever(ABC):
|
||||
"""Provide the Retriever retriever."""
|
||||
@abstractmethod
|
||||
|
||||
@@ -289,7 +289,7 @@ def build_vector_chunks(
|
||||
{
|
||||
"doc_id": doc_id,
|
||||
"doc_title": doc_title,
|
||||
"chunk_id": f"chunk-{chunk_index}",
|
||||
"chunk_id": f"{doc_id}-chunk-{chunk_index}",
|
||||
"chunk_index": chunk_index,
|
||||
"semantic_id": block["semantic_id"],
|
||||
"chunk_type": block["block_type"],
|
||||
|
||||
@@ -75,6 +75,15 @@ class JsonDocumentRepository(DocumentRepository):
|
||||
documents.sort(key=lambda item: item.updated_at, reverse=True)
|
||||
return documents[:limit] if limit is not None else documents
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
    """Drop ``doc_id`` from the stored payload.

    Returns ``True`` after the record was removed and the payload saved, or
    ``False`` (without saving) when the id is not present.
    """
    records = self._load()
    try:
        del records[doc_id]
    except KeyError:
        # Unknown id: nothing changed, so there is nothing to write back.
        return False
    self._save(records)
    return True
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Implement infrastructure support for postgres document repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import Document, DocumentRepository, DocumentStatus
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
doc_id VARCHAR(128) PRIMARY KEY,
|
||||
doc_name VARCHAR(512) NOT NULL DEFAULT '',
|
||||
file_name VARCHAR(512) NOT NULL DEFAULT '',
|
||||
object_name VARCHAR(1024) NOT NULL DEFAULT '',
|
||||
content_type VARCHAR(128) NOT NULL DEFAULT '',
|
||||
size_bytes BIGINT NOT NULL DEFAULT 0,
|
||||
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
||||
regulation_type VARCHAR(128) NOT NULL DEFAULT '',
|
||||
version VARCHAR(64) NOT NULL DEFAULT '',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
summary_latency_ms INTEGER NOT NULL DEFAULT 0,
|
||||
chunk_count INTEGER NOT NULL DEFAULT 0,
|
||||
parser_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
index_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
error_message TEXT NOT NULL DEFAULT '',
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
_COLUMNS = (
|
||||
"doc_id", "doc_name", "file_name", "object_name", "content_type",
|
||||
"size_bytes", "status", "regulation_type", "version", "summary",
|
||||
"summary_latency_ms", "chunk_count", "parser_name", "index_name",
|
||||
"error_message", "metadata", "created_at", "updated_at",
|
||||
)
|
||||
|
||||
|
||||
class PostgresDocumentRepository(DocumentRepository):
    """DocumentRepository implementation backed by PostgreSQL.

    Each instance owns a ``ThreadedConnectionPool`` configured from
    ``settings`` and creates the ``documents`` table on construction when it
    does not exist yet.
    """

    def __init__(self) -> None:
        """Open the connection pool and make sure the ``documents`` table exists."""
        self._pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=5,
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
        )
        self._ensure_schema()

    def _ensure_schema(self) -> None:
        """Create the ``documents`` table if it is missing."""
        self._execute(_CREATE_TABLE)

    @contextmanager
    def _conn(self):
        """Borrow a pooled connection and always hand it back.

        On any exception the open transaction is rolled back explicitly
        before the connection returns to the pool, then the exception is
        re-raised unchanged.
        """
        conn = self._pool.getconn()
        try:
            yield conn
        except Exception:
            if not conn.closed:
                conn.rollback()
            raise
        finally:
            self._pool.putconn(conn)

    def _execute(self, sql: str, params: Any = None) -> int:
        """Run one write statement, commit it, and return the affected row count."""
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                affected = cur.rowcount
            conn.commit()
        return affected

    def _fetch(self, sql: str, params: Any = None) -> list[dict[str, Any]]:
        """Run one query and return every row as a plain ``dict``."""
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, params)
                return [dict(row) for row in cur.fetchall()]

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    def _row_to_document(self, row: dict[str, Any]) -> Document:
        """Build a ``Document`` from a ``documents`` table row."""
        raw_metadata = row["metadata"]
        # The metadata column may arrive already decoded (dict) or as JSON text.
        metadata = raw_metadata if isinstance(raw_metadata, dict) else json.loads(raw_metadata or "{}")
        return Document(
            doc_id=row["doc_id"],
            doc_name=row["doc_name"],
            file_name=row["file_name"],
            object_name=row["object_name"],
            content_type=row["content_type"],
            size_bytes=row["size_bytes"],
            status=DocumentStatus(row["status"]),
            regulation_type=row["regulation_type"],
            version=row["version"],
            summary=row["summary"],
            summary_latency_ms=row["summary_latency_ms"],
            chunk_count=row["chunk_count"],
            parser_name=row["parser_name"],
            index_name=row["index_name"],
            error_message=row["error_message"],
            metadata=metadata,
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    # ------------------------------------------------------------------
    # DocumentRepository interface
    # ------------------------------------------------------------------

    def create(self, document: Document) -> Document:
        """Insert ``document``; an existing row with the same ``doc_id`` is left untouched.

        Returns the ``document`` argument unchanged in both cases.
        """
        sql = """
            INSERT INTO documents
                (doc_id, doc_name, file_name, object_name, content_type, size_bytes,
                 status, regulation_type, version, summary, summary_latency_ms,
                 chunk_count, parser_name, index_name, error_message, metadata,
                 created_at, updated_at)
            VALUES
                (%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s,
                 %(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s,
                 %(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s,
                 %(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s)
            ON CONFLICT (doc_id) DO NOTHING
        """
        self._execute(sql, self._to_params(document))
        return document

    def update(self, document: Document) -> Document:
        """Write every mutable column of ``document`` back to its row.

        ``document.updated_at`` is set to the current UTC time before the
        write. The same ``document`` object is returned.
        """
        document.updated_at = datetime.now(UTC)
        sql = """
            UPDATE documents SET
                doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s,
                content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s,
                regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s,
                summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s,
                parser_name=%(parser_name)s, index_name=%(index_name)s,
                error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s
            WHERE doc_id=%(doc_id)s
        """
        self._execute(sql, self._to_params(document))
        return document

    def get(self, doc_id: str) -> Document | None:
        """Return the document with ``doc_id``, or ``None`` when no such row exists."""
        rows = self._fetch("SELECT * FROM documents WHERE doc_id = %s", (doc_id,))
        return self._row_to_document(rows[0]) if rows else None

    def list(self, limit: int | None = None) -> list[Document]:
        """Return documents ordered by ``updated_at`` descending, optionally capped at ``limit``."""
        sql = "SELECT * FROM documents ORDER BY updated_at DESC"
        params = None
        if limit is not None:
            # Bind the limit as a query parameter instead of splicing it into the SQL text.
            sql += " LIMIT %s"
            params = (int(limit),)
        return [self._row_to_document(row) for row in self._fetch(sql, params)]

    def delete(self, doc_id: str) -> bool:
        """Delete the row for ``doc_id``; return ``True`` when a row was removed."""
        return self._execute("DELETE FROM documents WHERE doc_id = %s", (doc_id,)) > 0

    def update_status(
        self,
        doc_id: str,
        status: DocumentStatus,
        *,
        error_message: str = "",
        chunk_count: int | None = None,
        summary: str | None = None,
        summary_latency_ms: int | None = None,
        parser_name: str | None = None,
        index_name: str | None = None,
        metadata: dict | None = None,
    ) -> Document | None:
        """Load a document, apply the given changes, and persist it.

        ``status`` and ``error_message`` are always overwritten. The other
        keyword arguments are applied only when they are not ``None``, and a
        non-empty ``metadata`` mapping is merged into the stored metadata.
        Returns ``None`` when ``doc_id`` does not exist.
        """
        document = self.get(doc_id)
        if not document:
            return None
        document.status = status
        document.error_message = error_message
        optional_fields = {
            "chunk_count": chunk_count,
            "summary": summary,
            "summary_latency_ms": summary_latency_ms,
            "parser_name": parser_name,
            "index_name": index_name,
        }
        for field_name, value in optional_fields.items():
            if value is not None:
                setattr(document, field_name, value)
        if metadata:
            document.metadata.update(metadata)
        return self.update(document)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _to_params(self, document: Document) -> dict[str, Any]:
        """Flatten ``document`` into the named parameters used by the INSERT/UPDATE SQL."""
        return {
            "doc_id": document.doc_id,
            "doc_name": document.doc_name,
            "file_name": document.file_name,
            "object_name": document.object_name,
            "content_type": document.content_type,
            "size_bytes": document.size_bytes,
            "status": document.status.value,
            "regulation_type": document.regulation_type,
            "version": document.version,
            "summary": document.summary,
            "summary_latency_ms": document.summary_latency_ms,
            "chunk_count": document.chunk_count,
            "parser_name": document.parser_name,
            "index_name": document.index_name,
            "error_message": document.error_message,
            # Serialized as JSON text for the JSONB column.
            "metadata": json.dumps(document.metadata, ensure_ascii=False),
            "created_at": document.created_at,
            "updated_at": document.updated_at,
        }
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Implement infrastructure support for postgres parse artifact store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import ParseArtifactStore
|
||||
|
||||
_CREATE_STRUCTURE_NODES = """
|
||||
CREATE TABLE IF NOT EXISTS structure_nodes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
doc_id VARCHAR(128) NOT NULL,
|
||||
unique_id VARCHAR(128),
|
||||
page INTEGER NOT NULL DEFAULT 0,
|
||||
idx INTEGER NOT NULL DEFAULT 0,
|
||||
level INTEGER NOT NULL DEFAULT 0,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
type VARCHAR(64),
|
||||
sub_type VARCHAR(64),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT fk_sn_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_structure_nodes_doc_id ON structure_nodes(doc_id);
|
||||
"""
|
||||
|
||||
_CREATE_SEMANTIC_BLOCKS = """
|
||||
CREATE TABLE IF NOT EXISTS semantic_blocks (
|
||||
id SERIAL PRIMARY KEY,
|
||||
doc_id VARCHAR(128) NOT NULL,
|
||||
semantic_id VARCHAR(128) NOT NULL,
|
||||
block_type VARCHAR(64) NOT NULL DEFAULT '',
|
||||
page_start INTEGER NOT NULL DEFAULT 0,
|
||||
page_end INTEGER NOT NULL DEFAULT 0,
|
||||
section_path JSONB NOT NULL DEFAULT '[]',
|
||||
section_level INTEGER NOT NULL DEFAULT 0,
|
||||
section_title VARCHAR(512) NOT NULL DEFAULT '',
|
||||
source_ids JSONB NOT NULL DEFAULT '[]',
|
||||
text TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT fk_sb_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
|
||||
CONSTRAINT uq_semantic_blocks UNIQUE (doc_id, semantic_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_semantic_blocks_doc_id ON semantic_blocks(doc_id);
|
||||
"""
|
||||
|
||||
|
||||
class PostgresParseArtifactStore(ParseArtifactStore):
    """ParseArtifactStore implementation backed by PostgreSQL.

    Requires the `documents` table to exist first (created by PostgresDocumentRepository),
    because both artifact tables declare a foreign key to it.
    """

    def __init__(self) -> None:
        """Open the connection pool and make sure both artifact tables exist."""
        self._pool = ThreadedConnectionPool(
            minconn=1,
            maxconn=5,
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
        )
        self._ensure_schema()

    def _ensure_schema(self) -> None:
        """Create the ``structure_nodes`` and ``semantic_blocks`` tables if missing."""
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute(_CREATE_STRUCTURE_NODES)
                cur.execute(_CREATE_SEMANTIC_BLOCKS)
            conn.commit()

    @contextmanager
    def _conn(self):
        """Borrow a pooled connection and always hand it back.

        On any exception the open transaction is rolled back explicitly
        before the connection returns to the pool, then the exception is
        re-raised unchanged.
        """
        conn = self._pool.getconn()
        try:
            yield conn
        except Exception:
            if not conn.closed:
                conn.rollback()
            raise
        finally:
            self._pool.putconn(conn)

    # ------------------------------------------------------------------
    # Row builders
    # ------------------------------------------------------------------

    @staticmethod
    def _as_int(value: Any) -> int:
        """Coerce ``value`` to ``int``, treating ``None`` and other falsy values as 0."""
        return int(value or 0)

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce ``value`` to ``str``; ``None`` becomes ``""`` rather than the text ``"None"``."""
        return "" if value is None else str(value)

    @classmethod
    def _structure_node_row(cls, doc_id: str, node: dict) -> tuple:
        """Build the INSERT tuple for one structure node."""
        return (
            doc_id,
            node.get("unique_id"),
            cls._as_int(node.get("page")),
            # The node key is "index"; it is stored in the "idx" column.
            cls._as_int(node.get("index")),
            cls._as_int(node.get("level")),
            cls._as_text(node.get("title")),
            node.get("type"),
            node.get("sub_type"),
        )

    @classmethod
    def _semantic_block_row(cls, doc_id: str, block: dict) -> tuple:
        """Build the INSERT tuple for one semantic block."""
        return (
            doc_id,
            cls._as_text(block.get("semantic_id")),
            cls._as_text(block.get("block_type")),
            cls._as_int(block.get("page_start")),
            cls._as_int(block.get("page_end")),
            # A missing or None list is stored as [] instead of JSON null.
            json.dumps(block.get("section_path") or [], ensure_ascii=False),
            cls._as_int(block.get("section_level")),
            cls._as_text(block.get("section_title")),
            json.dumps(block.get("source_ids") or [], ensure_ascii=False),
            cls._as_text(block.get("text")),
        )

    # ------------------------------------------------------------------
    # ParseArtifactStore interface
    # ------------------------------------------------------------------

    def save(
        self,
        doc_id: str,
        structure_nodes: list[dict],
        semantic_blocks: list[dict],
    ) -> None:
        """Persist structure nodes and semantic blocks, replacing any existing records.

        The deletes and inserts share one transaction that is committed at the end.
        """
        with self._conn() as conn:
            with conn.cursor() as cur:
                # Delete existing records first to keep save idempotent.
                cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
                cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))

                if structure_nodes:
                    psycopg2.extras.execute_values(
                        cur,
                        """
                        INSERT INTO structure_nodes
                            (doc_id, unique_id, page, idx, level, title, type, sub_type)
                        VALUES %s
                        """,
                        [self._structure_node_row(doc_id, node) for node in structure_nodes],
                    )

                if semantic_blocks:
                    psycopg2.extras.execute_values(
                        cur,
                        """
                        INSERT INTO semantic_blocks
                            (doc_id, semantic_id, block_type, page_start, page_end,
                             section_path, section_level, section_title, source_ids, text)
                        VALUES %s
                        """,
                        [self._semantic_block_row(doc_id, block) for block in semantic_blocks],
                    )
            conn.commit()

    def delete(self, doc_id: str) -> None:
        """Remove all parse artifacts for a document.

        Rows are deleted explicitly from both tables, so this works even when
        the parent ``documents`` row is kept.
        """
        with self._conn() as conn:
            with conn.cursor() as cur:
                cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
                cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
            conn.commit()

    def get_semantic_blocks(self, doc_id: str) -> list[dict[str, Any]]:
        """Return all semantic blocks for a document ordered by id."""
        sql = """
            SELECT semantic_id, block_type, page_start, page_end,
                   section_path, section_level, section_title, source_ids, text
            FROM semantic_blocks
            WHERE doc_id = %s
            ORDER BY id
        """
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, (doc_id,))
                rows = cur.fetchall()

        results = []
        for row in rows:
            item = dict(row)
            # Decode the JSON columns only when they arrive as text.
            for key in ("section_path", "source_ids"):
                if isinstance(item[key], str):
                    item[key] = json.loads(item[key])
            results.append(item)
        return results

    def get_structure_nodes(self, doc_id: str) -> list[dict[str, Any]]:
        """Return all structure nodes for a document ordered by idx.

        Rows carry the column name ``idx``, not the ``index`` key read by ``save``.
        """
        sql = """
            SELECT unique_id, page, idx, level, title, type, sub_type
            FROM structure_nodes
            WHERE doc_id = %s
            ORDER BY idx
        """
        with self._conn() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(sql, (doc_id,))
                rows = cur.fetchall()
        return [dict(r) for r in rows]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.retrieval import Reranker, RetrievedChunk
|
||||
|
||||
|
||||
class OpenAICompatibleReranker(Reranker):
    """Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        api_key: str | None = None,
        timeout: int = 30,
    ) -> None:
        """Configure the client; omitted arguments fall back to the matching ``settings.reranker_*`` value.

        Args:
            base_url: Endpoint root; a trailing slash is stripped.
            model: Model name sent on the ``/v1/rerank`` fallback request.
            api_key: Optional bearer token.
            timeout: Per-request timeout in seconds.
        """
        self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
        self._model = model or settings.reranker_model
        self._api_key = api_key or settings.reranker_api_key
        self._timeout = timeout

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Return up to top_k chunks re-sorted by cross-encoder score.

        Each returned chunk has its ``score`` attribute overwritten with the
        reranker score. When the reranker call fails, or its response does not
        cover every chunk exactly once, the first ``top_k`` chunks are
        returned in their original order with scores untouched.
        """
        if not chunks:
            return []

        texts = [chunk.content for chunk in chunks]
        start = time.perf_counter()
        try:
            scores = self._call_reranker(query, texts)
        except Exception as exc:
            logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
            return chunks[:top_k]

        elapsed_ms = int((time.perf_counter() - start) * 1000)
        logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)

        # sorted() is stable, so equal scores keep their retrieval order.
        ranked = sorted(zip(scores, chunks), key=lambda pair: pair[0], reverse=True)
        result = []
        for score, chunk in ranked[:top_k]:
            chunk.score = float(score)
            result.append(chunk)
        return result

    def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
        """Call the reranker API and return one score per text, in ``texts`` order.

        Raises:
            requests.HTTPError: the endpoint answered with an error status.
            ValueError: the response does not contain exactly one score per text.
        """
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["Authorization"] = f"Bearer {self._api_key}"

        # Try TEI format first: POST /rerank
        payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
        url = f"{self._base_url}/rerank"
        resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)

        if resp.status_code == 404:
            # Fall back to Cohere / OpenAI-style: POST /v1/rerank
            payload_v1 = {"model": self._model, "query": query, "documents": texts}
            url = f"{self._base_url}/v1/rerank"
            resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)

        resp.raise_for_status()
        data = resp.json()

        if isinstance(data, list):
            # TEI response: list of {"index": N, "score": F}
            pairs = [(item["index"], item["score"]) for item in data]
        else:
            # Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
            pairs = [
                (item["index"], item.get("relevance_score", item.get("score", 0)))
                for item in data.get("results", [])
            ]
        return self._align_scores(pairs, len(texts))

    @staticmethod
    def _align_scores(pairs: list[tuple], expected: int) -> list[float]:
        """Place each ``(index, score)`` pair at its index and require full coverage.

        Scores are positioned by the index the API reported, never by response
        order, so a partial or duplicated response is rejected with
        ``ValueError`` rather than mapped onto the wrong chunks.
        """
        aligned: list = [None] * expected
        for index, score in pairs:
            if not isinstance(index, int) or not 0 <= index < expected:
                raise ValueError(f"reranker returned out-of-range index {index!r}")
            if aligned[index] is not None:
                raise ValueError(f"reranker returned index {index} more than once")
            aligned[index] = float(score)
        missing = sum(1 for value in aligned if value is None)
        if missing:
            raise ValueError(f"reranker response is missing scores for {missing} of {expected} texts")
        return aligned
|
||||
@@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex):
|
||||
result = self.collection.delete(f'doc_id == "{doc_id}"')
|
||||
return len(result.primary_keys)
|
||||
|
||||
def _parse_filters(self, filters: str | None) -> str | None:
    """Turn a user-supplied filter string into a Milvus boolean expression.

    Accepted inputs:
      * ``None`` / blank            -> ``None`` (no filtering)
      * an existing expression     -> returned as-is (stripped)
      * ``"GB"``                   -> ``regulation_type == "GB"``
      * ``"GB, UN-ECE"``           -> ``regulation_type in ["GB", "UN-ECE"]``

    A string counts as an expression when it contains a comparison symbol
    or a stand-alone keyword operator (``in``, ``and``, ``or``, ``not``,
    ``like``). Keywords are matched as whole tokens, so a plain value such
    as ``"Mandatory"`` is not mistaken for an expression.
    """
    import re  # function-local: the module's import block is not visible from this hunk

    if not filters or not filters.strip():
        return None

    filters = filters.strip()

    # Comparison symbols can never be part of a plain type list.
    has_symbol_operator = any(op in filters for op in ("==", "!=", ">", "<"))
    # Keyword operators must not be glued to word characters or hyphens.
    has_keyword_operator = re.search(r"(?<![\w-])(?:in|and|or|not|like)(?![\w-])", filters) is not None
    if has_symbol_operator or has_keyword_operator:
        return filters

    # Parse simple regulation_type filter
    # Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
    types = [t.strip() for t in filters.split(",") if t.strip()]
    if not types:
        return None

    # Escape backslashes and double quotes so a value cannot end the string literal early.
    quoted_types = ['"' + t.replace("\\", "\\\\").replace('"', '\\"') + '"' for t in types]

    if len(quoted_types) == 1:
        return f"regulation_type == {quoted_types[0]}"
    return f'regulation_type in [{", ".join(quoted_types)}]'
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Milvus Vector Index instance."""
|
||||
milvus_expr = self._parse_filters(filters)
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
expr=milvus_expr,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_name",
|
||||
@@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex):
|
||||
)
|
||||
return payload
|
||||
|
||||
def count_by_document(self) -> dict[str, int]:
    """Tally Milvus rows per ``doc_id``.

    Rows are fetched with a single query for all non-empty ``doc_id``
    values. If the query raises, an empty mapping is returned.
    """
    try:
        rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
    except Exception:
        # Best effort: a failed query is reported as "no counts".
        return {}

    tally: dict[str, int] = {}
    for doc_id in (row.get("doc_id", "") for row in rows):
        if not doc_id:
            continue
        tally[doc_id] = tally.get(doc_id, 0) + 1
    return tally
|
||||
|
||||
def list_document_metadata(self) -> list[dict]:
    """Collapse Milvus rows into one metadata entry per document.

    A single query fetches ``doc_id``, ``doc_name``, ``regulation_type`` and
    ``version`` for all rows with a non-empty ``doc_id``. For each document
    the field values of the first row seen are kept and ``chunk_count``
    holds the number of rows sharing that ``doc_id``. Entries come back in
    first-seen order. If the query raises, an empty list is returned.
    """
    try:
        rows = self.collection.query(
            expr="doc_id != \"\"",
            output_fields=["doc_id", "doc_name", "regulation_type", "version"],
        )
    except Exception:
        # Best effort: a failed query is reported as "no documents".
        return []

    by_doc: dict[str, dict] = {}
    for row in rows:
        doc_id = row.get("doc_id", "")
        if not doc_id:
            continue
        entry = by_doc.get(doc_id)
        if entry is None:
            entry = {
                "doc_id": doc_id,
                "doc_name": row.get("doc_name", ""),
                "regulation_type": row.get("regulation_type", ""),
                "version": row.get("version", ""),
                "chunk_count": 0,
            }
            by_doc[doc_id] = entry
        entry["chunk_count"] += 1

    return [dict(entry) for entry in by_doc.values()]
|
||||
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Milvus Vector Index instance."""
|
||||
return {
|
||||
|
||||
@@ -74,6 +74,7 @@ class ComplianceResult(BaseModel):
|
||||
class ComplianceChatRequest(BaseModel):
|
||||
"""Define the Compliance Chat Request API model."""
|
||||
query: str
|
||||
segment_context: Optional[str] = None
|
||||
|
||||
|
||||
class AnalyzeResponse(BaseModel):
|
||||
|
||||
@@ -10,6 +10,8 @@ class RagChatRequest(BaseModel):
|
||||
"""Define the Rag Chat Request API model."""
|
||||
query: str
|
||||
top_k: int = 5
|
||||
session_id: Optional[str] = None
|
||||
filters: Optional[str] = None
|
||||
|
||||
|
||||
class RetrievedDoc(BaseModel):
|
||||
|
||||
@@ -95,6 +95,57 @@ class DeepSeekClient(BaseLLMClient):
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def stream_chat(
    self,
    messages: List[Dict[str, str]],
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    **kwargs
):
    """Stream a chat completion and yield the content deltas as they arrive.

    Args:
        messages: Chat messages forwarded unchanged in the request payload.
        max_tokens: Overrides ``self.config.max_tokens`` when truthy.
        temperature: Overrides ``self.config.temperature`` when not ``None``;
            an explicit ``0.0`` is honoured.
        **kwargs: Only ``top_p`` is read; it defaults to ``self.config.top_p``.

    Yields:
        Each non-empty ``choices[0].delta.content`` string from the
        server-sent event stream. On any error the failure is logged and a
        single empty string is yielded instead of raising.
    """
    import json  # kept function-local, as in the original implementation

    try:
        payload = {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": max_tokens or self.config.max_tokens,
            # "is not None" rather than "or": a temperature of 0.0 is a valid request.
            "temperature": temperature if temperature is not None else self.config.temperature,
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": True
        }

        with self._client.stream("POST", "/chat/completions", json=payload) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                # httpx yields str lines; decode defensively if a client yields bytes.
                if isinstance(line, bytes):
                    line = line.decode("utf-8")
                line = line.strip()
                # Skips blank lines, ":" comment/keep-alive lines and non-data fields.
                if not line.startswith("data:"):
                    continue

                # The space after "data:" is optional in server-sent events.
                data_str = line[len("data:"):].strip()
                if data_str == "[DONE]":
                    break

                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Ignore malformed events and keep streaming.
                    continue

                choices = data.get("choices", [])
                if not choices:
                    continue
                delta = choices[0].get("delta") or {}
                content = delta.get("content", "")
                if content:
                    yield content

    except httpx.HTTPStatusError as e:
        logger.error(f"DeepSeek Stream API错误: {e.response.status_code}")
        yield ""
    except Exception as e:
        logger.error(f"DeepSeek Stream调用失败: {e}")
        yield ""
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""Return available models for the Deep Seek Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
@@ -17,18 +17,31 @@ from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuil
|
||||
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
|
||||
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
|
||||
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
|
||||
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
|
||||
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
|
||||
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
|
||||
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
|
||||
# Keep shared wiring centralized so dependency construction remains consistent.
|
||||
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_document_repository() -> JsonDocumentRepository:
|
||||
"""Return document repository."""
|
||||
def get_document_repository():
    """Build the document repository selected by ``settings.document_repository_backend``.

    The value ``"postgres"`` yields a ``PostgresDocumentRepository``; any
    other value yields a ``JsonDocumentRepository`` reading
    ``settings.document_metadata_path``.
    """
    use_postgres = settings.document_repository_backend == "postgres"
    if not use_postgres:
        return JsonDocumentRepository(settings.document_metadata_path)
    return PostgresDocumentRepository()
|
||||
|
||||
|
||||
@lru_cache
def get_parse_artifact_store():
    """Build the parse artifact store for the configured backend.

    A ``PostgresParseArtifactStore`` is returned only when
    ``settings.document_repository_backend`` equals ``"postgres"``; for every
    other value the result is ``None``.
    """
    postgres_enabled = settings.document_repository_backend == "postgres"
    return PostgresParseArtifactStore() if postgres_enabled else None
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_binary_store() -> MinioDocumentBinaryStore:
|
||||
"""Return binary store."""
|
||||
@@ -66,6 +79,14 @@ def get_vector_index() -> MilvusVectorIndex:
|
||||
return MilvusVectorIndex()
|
||||
|
||||
|
||||
@lru_cache
def get_reranker():
    """Build the reranker when one is configured.

    An ``OpenAICompatibleReranker`` is returned only when
    ``settings.reranker_enabled`` is truthy and ``settings.reranker_base_url``
    is non-empty; otherwise the result is ``None``.
    """
    if not (settings.reranker_enabled and settings.reranker_base_url):
        return None
    return OpenAICompatibleReranker()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_retrieval_service() -> KnowledgeRetrievalService:
|
||||
"""Return retrieval service."""
|
||||
@@ -73,7 +94,11 @@ def get_retrieval_service() -> KnowledgeRetrievalService:
|
||||
embedding_provider=get_embedding_provider(),
|
||||
vector_index=get_vector_index(),
|
||||
)
|
||||
return KnowledgeRetrievalService(retriever=retriever)
|
||||
return KnowledgeRetrievalService(
|
||||
retriever=retriever,
|
||||
reranker=get_reranker(),
|
||||
reranker_top_k=settings.reranker_top_k,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
@@ -86,6 +111,7 @@ def get_document_command_service() -> DocumentCommandService:
|
||||
chunk_builder=get_chunk_builder(),
|
||||
embedding_provider=get_embedding_provider(),
|
||||
vector_index=get_vector_index(),
|
||||
parse_artifact_store=get_parse_artifact_store(),
|
||||
)
|
||||
|
||||
|
||||
@@ -95,6 +121,7 @@ def get_document_query_service() -> DocumentQueryService:
|
||||
return DocumentQueryService(
|
||||
document_repository=get_document_repository(),
|
||||
binary_store=get_binary_store(),
|
||||
vector_index=get_vector_index(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ tenacity>=8.2.0
|
||||
|
||||
pymilvus>=2.4.0
|
||||
minio>=7.1.0
|
||||
psycopg2-binary>=2.9.0
|
||||
|
||||
pymupdf>=1.24.0
|
||||
python-docx>=1.1.0
|
||||
|
||||
118
dev.sh
118
dev.sh
@@ -12,7 +12,8 @@ API_PID_FILE="$LOG_DIR/api.pid"
|
||||
FRONTEND_PID_FILE="$LOG_DIR/frontend.pid"
|
||||
API_LOG_FILE="$LOG_DIR/api.log"
|
||||
FRONTEND_LOG_FILE="$LOG_DIR/frontend.log"
|
||||
DOCKER_CONTAINERS="milvus minio redis postgres"
|
||||
DISPLAY_HOST="localhost"
|
||||
SERVICE_HOST="6.86.80.8"
|
||||
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
@@ -54,6 +55,51 @@ ensure_log_dir() {
|
||||
mkdir -p "$LOG_DIR"
|
||||
}
|
||||
|
||||
# Probe whether a TCP connection to host:port can be opened (3 second timeout).
# Arguments: $1 - host name or IP address, $2 - TCP port.
# Returns:   0 when the connection succeeds, non-zero otherwise.
# All output of the probe itself is discarded.
check_tcp_connectivity() {
    local host="$1"
    local port="$2"

    # Prefer nc when it is installed; the bare `return` propagates nc's exit status.
    if command -v nc > /dev/null 2>&1; then
        nc -z -w 3 "$host" "$port" > /dev/null 2>&1
        return
    fi

    # Fallback: Python socket probe. host and port travel as argv and the heredoc
    # delimiter is quoted, so the shell never splices their values into the
    # Python source (a stray quote or non-numeric port can no longer alter the code).
    require_python_bootstrap
    "$PYTHON_BOOTSTRAP" - "$host" "$port" <<'PY' > /dev/null 2>&1
import socket
import sys

try:
    with socket.create_connection((sys.argv[1], int(sys.argv[2])), timeout=3):
        pass
except Exception:
    sys.exit(1)

sys.exit(0)
PY
}
|
||||
|
||||
# Report TCP reachability of each foundation service on $SERVICE_HOST.
# Prints a success line for every reachable "name:port" entry and a warning
# for every unreachable one.
check_foundation_services() {
    # Declare the loop variable local too, so it does not leak into the
    # script's global scope.
    local service
    local name
    local port

    for service in \
        "Milvus:19530" \
        "MinIO API:9000" \
        "MinIO Console:9001" \
        "Redis:6379" \
        "PostgreSQL:5432"
    do
        name="${service%%:*}"   # text before the first ':'
        port="${service##*:}"   # text after the last ':'
        if check_tcp_connectivity "$SERVICE_HOST" "$port"; then
            success "${name}: ${SERVICE_HOST}:${port} 可连通"
        else
            warn "${name}: ${SERVICE_HOST}:${port} 不可连通"
        fi
    done
}
|
||||
|
||||
print_header() {
|
||||
echo ""
|
||||
echo -e "${CYAN}========================================${NC}"
|
||||
@@ -197,21 +243,8 @@ run_setup() {
|
||||
success "前端依赖安装完成"
|
||||
echo ""
|
||||
|
||||
info "[4/4] 检查 Docker 基础服务"
|
||||
if command -v docker > /dev/null 2>&1; then
|
||||
local container
|
||||
for container in $DOCKER_CONTAINERS; do
|
||||
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
|
||||
success "${container}: 运行中"
|
||||
elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then
|
||||
warn "${container}: 已创建但未运行"
|
||||
else
|
||||
warn "${container}: 未找到容器"
|
||||
fi
|
||||
done
|
||||
else
|
||||
warn "未检测到 Docker,已跳过容器检查"
|
||||
fi
|
||||
info "[4/4] 检查 6.86.80.8 基础服务连通性"
|
||||
check_foundation_services
|
||||
|
||||
echo ""
|
||||
success "环境初始化完成"
|
||||
@@ -223,7 +256,7 @@ run_setup() {
|
||||
|
||||
api_health_ok() {
|
||||
if command -v curl > /dev/null 2>&1; then
|
||||
curl -fsS "http://localhost:$API_PORT/health" > /dev/null 2>&1
|
||||
curl -fsS "http://${DISPLAY_HOST}:$API_PORT/health" > /dev/null 2>&1
|
||||
return
|
||||
fi
|
||||
|
||||
@@ -233,7 +266,7 @@ import sys
|
||||
from urllib.request import urlopen
|
||||
|
||||
try:
|
||||
with urlopen("http://localhost:${API_PORT}/health", timeout=3) as response:
|
||||
with urlopen("http://${DISPLAY_HOST}:${API_PORT}/health", timeout=3) as response:
|
||||
body = response.read().decode("utf-8", errors="ignore")
|
||||
sys.exit(0 if "healthy" in body.lower() else 1)
|
||||
except Exception:
|
||||
@@ -260,9 +293,9 @@ start_api() {
|
||||
if [ "$mode" = "foreground" ]; then
|
||||
print_header "AI+合规智能中枢 - 启动 API"
|
||||
echo "运行模式: 前台调试(带 --reload)"
|
||||
echo "服务地址: http://localhost:$API_PORT"
|
||||
echo "文档地址: http://localhost:$API_PORT/docs"
|
||||
echo "健康检查: http://localhost:$API_PORT/health"
|
||||
echo "服务地址: http://${DISPLAY_HOST}:$API_PORT"
|
||||
echo "文档地址: http://${DISPLAY_HOST}:$API_PORT/docs"
|
||||
echo "健康检查: http://${DISPLAY_HOST}:$API_PORT/health"
|
||||
echo ""
|
||||
exec "$VENV_PYTHON" -m uvicorn app.main:app --host "$API_HOST" --port "$API_PORT" --reload
|
||||
fi
|
||||
@@ -274,8 +307,8 @@ start_api() {
|
||||
|
||||
if is_pid_running "$pid"; then
|
||||
success "API 启动成功 (PID: $pid)"
|
||||
echo " 地址: http://localhost:$API_PORT"
|
||||
echo " 文档: http://localhost:$API_PORT/docs"
|
||||
echo " 地址: http://${DISPLAY_HOST}:$API_PORT"
|
||||
echo " 文档: http://${DISPLAY_HOST}:$API_PORT/docs"
|
||||
echo " 日志: $API_LOG_FILE"
|
||||
else
|
||||
rm -f "$API_PID_FILE"
|
||||
@@ -316,7 +349,7 @@ start_frontend() {
|
||||
|
||||
if is_pid_running "$pid"; then
|
||||
success "前端启动成功 (PID: $pid)"
|
||||
echo " 地址: http://localhost:$FRONTEND_PORT"
|
||||
echo " 地址: http://${DISPLAY_HOST}:$FRONTEND_PORT"
|
||||
echo " 模式: $mode"
|
||||
echo " 日志: $FRONTEND_LOG_FILE"
|
||||
else
|
||||
@@ -407,6 +440,8 @@ run_status() {
|
||||
local frontend_running=false
|
||||
local pid
|
||||
local port_listener
|
||||
local service_name
|
||||
local service_port
|
||||
|
||||
echo -e "${YELLOW}API 服务:${NC}"
|
||||
pid="$(read_pid "$API_PID_FILE")"
|
||||
@@ -432,8 +467,8 @@ run_status() {
|
||||
warn " 健康检查: 未通过"
|
||||
fi
|
||||
fi
|
||||
echo " 地址: http://localhost:$API_PORT"
|
||||
echo " 文档: http://localhost:${API_PORT}/docs"
|
||||
echo " 地址: http://${DISPLAY_HOST}:$API_PORT"
|
||||
echo " 文档: http://${DISPLAY_HOST}:${API_PORT}/docs"
|
||||
echo ""
|
||||
|
||||
echo -e "${YELLOW}前端服务:${NC}"
|
||||
@@ -453,24 +488,25 @@ run_status() {
|
||||
fi
|
||||
fi
|
||||
echo " 模式: $FRONTEND_MODE"
|
||||
echo " 地址: http://localhost:$FRONTEND_PORT"
|
||||
echo " 地址: http://${DISPLAY_HOST}:$FRONTEND_PORT"
|
||||
echo ""
|
||||
|
||||
echo -e "${YELLOW}Docker 服务:${NC}"
|
||||
if command -v docker > /dev/null 2>&1; then
|
||||
local container
|
||||
for container in $DOCKER_CONTAINERS; do
|
||||
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
|
||||
success " ${container}: 运行中"
|
||||
elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then
|
||||
warn " ${container}: 已停止"
|
||||
echo -e "${YELLOW}基础服务连通性:${NC}"
|
||||
for service in \
|
||||
"Milvus:19530" \
|
||||
"MinIO API:9000" \
|
||||
"MinIO Console:9001" \
|
||||
"Redis:6379" \
|
||||
"PostgreSQL:5432"
|
||||
do
|
||||
service_name="${service%%:*}"
|
||||
service_port="${service##*:}"
|
||||
if check_tcp_connectivity "$SERVICE_HOST" "$service_port"; then
|
||||
success " ${service_name}: ${SERVICE_HOST}:${service_port} 可连通"
|
||||
else
|
||||
warn " ${container}: 未创建"
|
||||
warn " ${service_name}: ${SERVICE_HOST}:${service_port} 不可连通"
|
||||
fi
|
||||
done
|
||||
else
|
||||
warn " Docker 未安装,已跳过"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
if [ "$api_running" = true ] && [ "$frontend_running" = true ]; then
|
||||
@@ -526,7 +562,7 @@ AI+合规智能中枢统一脚本
|
||||
setup
|
||||
进行一次性的本地初始化。
|
||||
包含 Python 版本检查、.venv 虚拟环境创建、后端依赖安装、前端 npm install、
|
||||
以及 Docker 基础容器状态检查。
|
||||
以及 6.86.80.8 基础服务端口连通性检查。
|
||||
|
||||
start
|
||||
启动服务。默认行为等同于 ./dev.sh start all。
|
||||
@@ -548,7 +584,7 @@ AI+合规智能中枢统一脚本
|
||||
restart frontend --mode static 可直接切换前端启动模式。
|
||||
|
||||
status
|
||||
查看 API、前端、Docker 基础容器的状态。
|
||||
查看 API、前端、6.86.80.8 基础服务的状态。
|
||||
API 状态包含健康检查;前端状态包含当前模式和访问地址。
|
||||
|
||||
logs
|
||||
|
||||
@@ -58,7 +58,7 @@ services:
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
# PostgreSQL数据库 (可选)
|
||||
# PostgreSQL数据库 (可选,启用 DOCUMENT_REPOSITORY_BACKEND=postgres 时使用)
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
container_name: postgres
|
||||
@@ -71,7 +71,7 @@ services:
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U compliance"]
|
||||
test: ["CMD-SHELL", "pg_isready -U compliance -d compliance_db"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
# 法规对话模块优化设计文档
|
||||
|
||||
**日期**: 2026-05-20
|
||||
**状态**: 已批准,实施中
|
||||
|
||||
---
|
||||
|
||||
## 背景
|
||||
|
||||
当前法规对话模块存在以下问题:
|
||||
1. `compliance.py` `/compliance/chat/{segment_id}` 返回硬编码 Mock 数据
|
||||
2. `rag.py` `/rag/chat` 返回硬编码 Mock 数据(前端实际调用 `/agent/chat/stream`,此路由可统一)
|
||||
3. 前端快速问题列表硬编码在 `rag.ts`,未调用后端
|
||||
4. 仅有 Dense 向量检索(COSINE),无 BM25 混合检索与精排
|
||||
5. LLM 输出的 `[1][2]` 引用编号未在前端内联高亮
|
||||
6. 会话存储仅为内存(100条上限,30分钟过期)
|
||||
|
||||
---
|
||||
|
||||
## 方案:分层优先(4个阶段)
|
||||
|
||||
### Phase 1 — 接入真实服务(Week 1)
|
||||
|
||||
**目标**:消灭所有 Mock,让系统真正可用。
|
||||
|
||||
#### 后端变更
|
||||
|
||||
**`backend/app/schemas/compliance.py`**
|
||||
- `ComplianceChatRequest` 新增 `segment_context: str | None = None`
|
||||
|
||||
**`backend/app/api/routes/compliance.py`**
|
||||
- 移除 `get_mock_compliance_chat_response` 导入
|
||||
- `/compliance/chat/{segment_id}` 接入 `get_agent_conversation_service().stream_chat()`
|
||||
- 将 `segment_context` 拼接到 query 前缀作为上下文
|
||||
- 将 agent `content` 事件翻译为 `{"type":"chunk","text":"..."}` 格式,保持前端兼容
|
||||
|
||||
**`backend/app/api/routes/rag.py`**
|
||||
- 移除 `get_mock_retrieval`、`get_mock_rag_answer` 导入
|
||||
- `/rag/chat` 接入 `get_agent_conversation_service().stream_chat()`
|
||||
- 翻译 agent 事件为 rag 格式(`retrieved`/`chunk`/`done`)
|
||||
|
||||
**`backend/app/schemas/rag.py`**
|
||||
- `RagChatRequest` 新增 `session_id: str | None = None`,`filters: str | None = None`
|
||||
|
||||
#### 前端变更
|
||||
|
||||
**`frontend/src/api/rag.ts`**
|
||||
- `getQuickQuestions()` 改为真实调用 `GET /api/v1/rag/quick-questions`,失败时降级为本地数组
|
||||
|
||||
**`frontend/src/api/compliance.ts`**
|
||||
- `complianceChat()` 新增第三个参数 `segmentContext: string | undefined`,传入 request body
|
||||
|
||||
**`frontend/src/pages/Compliance/CompliancePage.tsx`**
|
||||
- `sendChatMessage()` 中构建 `segmentContext`(intent + content 摘要 + 法规名)
|
||||
- 传给 `complianceChat()`
|
||||
|
||||
---
|
||||
|
||||
### Phase 2 — 混合检索 + Reranking(Week 2-3)
|
||||
|
||||
**目标**:提升召回质量(BM25 + dense RRF融合)+ 精排。
|
||||
|
||||
#### 2a — Cross-Encoder Reranking(优先,无需 schema 变更)
|
||||
|
||||
- 新增端口 `backend/app/domain/retrieval/ports.py`: `Reranker` ABC
|
||||
- 新增适配器 `backend/app/infrastructure/vectorstore/cross_encoder_reranker.py`
|
||||
- 调用 OpenAI-compatible reranker API(BAAI/bge-reranker-v2-m3)
|
||||
- 修改 `KnowledgeRetrievalService`:先 retrieve top-20,再 rerank 到 top-5
|
||||
- 新增 settings: `reranker_enabled: bool = False`,`reranker_base_url`(reranker 服务地址,未配置时不启用精排),`reranker_model: str`,`reranker_top_k: int = 5`
|
||||
|
||||
#### 2b — Milvus Sparse BM25(需 schema 迁移)
|
||||
|
||||
- Milvus collection 新增 `sparse_embedding SPARSE_FLOAT_VECTOR` 字段
|
||||
- 新增端口 `SparseEmbeddingProvider`(sparse embed 接口)
|
||||
- 适配器优先使用 BGE-M3 API(同时输出 dense + sparse),不可用时降级为 TF-IDF keyword weights
|
||||
- `MilvusVectorIndex.upsert()` 同时写入 sparse 向量
|
||||
- `MilvusVectorIndex.search()` 改为 hybrid search(`WeightedRanker` 或 `RRFRanker`)
|
||||
- 提供一次性迁移脚本:dump 所有 chunks → recreate collection → re-embed → re-insert
|
||||
|
||||
---
|
||||
|
||||
### Phase 3 — 引用溯源 + 筛选 UI(Week 4)
|
||||
|
||||
**目标**:答案文本中 `[1][2]` 可点击跳转原文片段;法规类型/版本可筛选。
|
||||
|
||||
#### 引用内联解析
|
||||
|
||||
- 新增 React 组件 `CitedAnswer`:接受 `text` + `sources[]`
|
||||
- 用正则 `/\[(\d+)\]/g` 拆分文本,将 `[N]` 渲染为可点击 `<cite>` 元素
|
||||
- 点击高亮/滚动到来源面板对应条目
|
||||
- `RagChatPage.tsx` 将 assistant 消息改用 `CitedAnswer` 组件渲染
|
||||
- 来源面板各条目加 `id="source-N"` 属性
|
||||
|
||||
#### 法规筛选 UI
|
||||
|
||||
- `RagChatPage.tsx` 在输入框上方加筛选栏:`regulation_type` 下拉 + `version` 输入
|
||||
- 筛选值作为 `filters` 参数传到 `/agent/chat/stream`
|
||||
- `MilvusVectorIndex.search()` 解析 `filters` 字符串为 Milvus expr(如 `regulation_type == "GB"`)
|
||||
|
||||
#### 快速问题动态化
|
||||
|
||||
- 后端 `/rag/quick-questions` 改为从 settings 或配置文件加载,不硬编码
|
||||
- 前端已在 Phase 1 中改为调用后端
|
||||
|
||||
---
|
||||
|
||||
### Phase 4 — 会话持久化 + 上下文压缩(Week 5)
|
||||
|
||||
**目标**:长对话不丢失;上下文过长时智能压缩。
|
||||
|
||||
#### PostgreSQL 会话存储
|
||||
|
||||
- 新增 `backend/app/infrastructure/session/postgres_conversation_store.py`
|
||||
- 实现 `ConversationStore` port
|
||||
- 复用现有 PostgreSQL 连接
|
||||
- DDL: `conversation_sessions` + `conversation_messages` 两张表
|
||||
- 新增 settings: `conversation_store_backend: str = "memory"`
|
||||
- `bootstrap.py` 按 settings 切换实现
|
||||
|
||||
#### 上下文压缩
|
||||
|
||||
- 在 `AgentConversationService` 中,当对话轮数 > N 时,将早期消息用 LLM 摘要
|
||||
- 摘要替换为一条 `system` 消息 `"对话摘要:..."`
|
||||
- 新增 `AnswerGenerator.summarize(messages) -> str` 方法
|
||||
|
||||
---
|
||||
|
||||
## 关键设计决策
|
||||
|
||||
| 决策 | 选择 | 原因 |
|
||||
|------|------|------|
|
||||
| BM25 实现 | Milvus sparse vectors | 不引入新服务,Milvus 2.4+ 原生支持 |
|
||||
| Reranking | API-based cross-encoder | 复用现有 OpenAI-compatible 接口 |
|
||||
| 会话持久化 | PostgreSQL | 复用现有 PostgreSQL 基础设施 |
|
||||
| 引用解析 | 前端纯客户端 | LLM 已输出 [N] 编号,无需后端改动 |
|
||||
| segment_context | 前端构建传后端 | 合规结果存在前端状态,无需后端持久化 |
|
||||
|
||||
---
|
||||
|
||||
## 验证标准
|
||||
|
||||
- Phase 1: 合规对话面板返回真实法规检索结果,不再固定响应
|
||||
- Phase 2: 相同 query 在 hybrid 模式下比 dense-only 召回更相关结果
|
||||
- Phase 3: 点击答案中的 `[1]` 能跳转到来源面板第一条
|
||||
- Phase 4: 刷新页面后历史对话仍可恢复
|
||||
@@ -31,11 +31,18 @@ export async function getComplianceResult(
|
||||
export function complianceChat(
|
||||
segmentId: number,
|
||||
query: string,
|
||||
segmentContext: string | undefined,
|
||||
onMessage: (data: SSEMessage) => void,
|
||||
onError?: (error: Error) => void,
|
||||
onComplete?: () => void
|
||||
): void {
|
||||
void streamSSE<SSEMessage>(`/compliance/chat/${segmentId}`, { query }, onMessage, onError, onComplete);
|
||||
void streamSSE<SSEMessage>(
|
||||
`/compliance/chat/${segmentId}`,
|
||||
{ query, segment_context: segmentContext },
|
||||
onMessage,
|
||||
onError,
|
||||
onComplete,
|
||||
);
|
||||
}
|
||||
|
||||
export type { ComplianceResult, SSEMessage };
|
||||
|
||||
@@ -6,7 +6,11 @@ interface BackendDocumentItem {
|
||||
doc_name: string;
|
||||
status: string;
|
||||
chunk_count: number;
|
||||
size_bytes?: number;
|
||||
summary?: string;
|
||||
updated_at?: string;
|
||||
regulation_type?: string;
|
||||
version?: string;
|
||||
}
|
||||
|
||||
interface BackendDocumentListResponse {
|
||||
@@ -44,6 +48,7 @@ export interface RegulationSearchResponse {
|
||||
}
|
||||
|
||||
function mapDoc(item: BackendDocumentItem): DocInfo {
|
||||
const sizeMB = item.size_bytes ? (item.size_bytes / (1024 * 1024)).toFixed(1) + 'MB' : '';
|
||||
return {
|
||||
id: item.doc_id,
|
||||
name: item.doc_name,
|
||||
@@ -51,14 +56,23 @@ function mapDoc(item: BackendDocumentItem): DocInfo {
|
||||
status: item.status,
|
||||
updated_at: item.updated_at,
|
||||
download_url: `${API_BASE_URL}/documents/download/${item.doc_id}`,
|
||||
size_text: sizeMB,
|
||||
summary: item.summary,
|
||||
regulation_type: item.regulation_type,
|
||||
version: item.version,
|
||||
};
|
||||
}
|
||||
|
||||
export async function uploadDocument(file: File): Promise<DocUploadResponse> {
|
||||
export async function uploadDocument(
|
||||
file: File,
|
||||
opts?: { regulationType?: string; version?: string }
|
||||
): Promise<DocUploadResponse> {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('doc_name', file.name);
|
||||
formData.append('generate_summary', 'true');
|
||||
if (opts?.regulationType) formData.append('regulation_type', opts.regulationType);
|
||||
if (opts?.version) formData.append('version', opts.version);
|
||||
|
||||
const response = await fetch(`${API_BASE_URL}/documents/upload`, {
|
||||
method: 'POST',
|
||||
@@ -92,6 +106,29 @@ export async function getDocumentList(): Promise<DocListResponse> {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Fetch the current processing status of one document.
 * Issues a GET to `/documents/status/{docId}` and resolves with the parsed JSON body.
 * @throws Error when the HTTP response is not OK.
 */
export async function getDocumentStatus(docId: string): Promise<DocUploadResponse> {
  const statusUrl = `${API_BASE_URL}/documents/status/${docId}`;
  const res = await fetch(statusUrl);
  if (res.ok) {
    return res.json() as Promise<DocUploadResponse>;
  }
  throw new Error(`Status check failed: ${res.status}`);
}
|
||||
|
||||
/**
 * Delete one document via `DELETE /documents/{docId}`.
 * Resolves with no value on success.
 * @throws Error when the HTTP response is not OK.
 */
export async function deleteDocument(docId: string): Promise<void> {
  const res = await fetch(`${API_BASE_URL}/documents/${docId}`, { method: 'DELETE' });
  if (res.ok) return;
  throw new Error(`Delete failed: ${res.status}`);
}
|
||||
|
||||
/**
 * Ask the backend to re-process one document via `POST /documents/{docId}/retry`.
 * Resolves with the parsed JSON body of the response.
 * @throws Error when the HTTP response is not OK.
 */
export async function retryDocument(docId: string): Promise<DocUploadResponse> {
  const retryUrl = `${API_BASE_URL}/documents/${docId}/retry`;
  const res = await fetch(retryUrl, { method: 'POST' });
  if (res.ok) {
    return res.json() as Promise<DocUploadResponse>;
  }
  throw new Error(`Retry failed: ${res.status}`);
}
|
||||
|
||||
export async function searchRegulations(query: string, topK: number = 8): Promise<RegulationSearchResponse> {
|
||||
const response = await fetch(`${API_BASE_URL}/knowledge/retrieval`, {
|
||||
method: 'POST',
|
||||
|
||||
@@ -136,6 +136,9 @@ export interface DocInfo {
|
||||
updated_at?: string;
|
||||
download_url?: string;
|
||||
size_text?: string;
|
||||
summary?: string;
|
||||
regulation_type?: string;
|
||||
version?: string;
|
||||
}
|
||||
|
||||
export interface DocListResponse {
|
||||
@@ -149,6 +152,8 @@ export interface DocUploadResponse {
|
||||
message?: string;
|
||||
num_chunks?: number;
|
||||
summary?: string;
|
||||
regulation_type?: string;
|
||||
version?: string;
|
||||
}
|
||||
|
||||
export interface QuickQuestion {
|
||||
|
||||
@@ -2,15 +2,21 @@ import type { QuickQuestionsResponse, SSEMessage } from './index';
|
||||
|
||||
const AGENT_API_BASE = '/api/v1';
|
||||
|
||||
export async function getQuickQuestions(): Promise<QuickQuestionsResponse> {
|
||||
return {
|
||||
questions: [
|
||||
const _FALLBACK_QUESTIONS = [
|
||||
{ id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' },
|
||||
{ id: '2', question: '我上传的制度文档与新能源法规有哪些潜在冲突?', category: '差距分析' },
|
||||
{ id: '3', question: '请给出法规依据,并按条款列出整改建议', category: '整改建议' },
|
||||
{ id: '4', question: '请解释 UN-ECE 与 GB 标准在网络安全方面的差异', category: '标准对比' },
|
||||
],
|
||||
};
|
||||
];
|
||||
|
||||
/**
 * Load the quick-question list from `GET {AGENT_API_BASE}/rag/quick-questions`.
 * Falls back to the local `_FALLBACK_QUESTIONS` array when the request fails,
 * the response is not OK, or the body cannot be parsed as JSON.
 */
export async function getQuickQuestions(): Promise<QuickQuestionsResponse> {
  try {
    const response = await fetch(`${AGENT_API_BASE}/rag/quick-questions`);
    if (!response.ok) throw new Error(`status ${response.status}`);
    // Await inside the try block: a bare `return response.json()` hands the
    // pending promise to the caller, so a JSON parse rejection would bypass
    // the catch below and the fallback would never be used.
    return (await response.json()) as QuickQuestionsResponse;
  } catch {
    return { questions: _FALLBACK_QUESTIONS };
  }
}
|
||||
|
||||
function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) {
|
||||
@@ -67,7 +73,8 @@ export async function ragChat(
|
||||
topK: number = 5,
|
||||
onMessage: (data: SSEMessage) => void,
|
||||
onError?: (error: Error) => void,
|
||||
onComplete?: () => void
|
||||
onComplete?: () => void,
|
||||
filters?: string
|
||||
): Promise<void> {
|
||||
try {
|
||||
const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, {
|
||||
@@ -76,7 +83,7 @@ export async function ragChat(
|
||||
'Content-Type': 'application/json',
|
||||
Accept: 'text/event-stream',
|
||||
},
|
||||
body: JSON.stringify({ query, top_k: topK }),
|
||||
body: JSON.stringify({ query, top_k: topK, ...(filters ? { filters } : {}) }),
|
||||
});
|
||||
|
||||
if (!response.ok || !response.body) {
|
||||
|
||||
@@ -250,11 +250,20 @@ export const CompliancePage: React.FC = () => {
|
||||
setChatInput('');
|
||||
setChatLoading(true);
|
||||
|
||||
const segmentContext = [
|
||||
`意图:${chunk.intent}`,
|
||||
`内容:${chunk.content.slice(0, 300)}`,
|
||||
chunk.regulations.length > 0
|
||||
? `相关法规:${chunk.regulations.slice(0, 3).map(r => `${r.name}${r.clause ? ' ' + r.clause : ''}(相关性 ${Math.round(r.score * 100)}%)`).join(';')}`
|
||||
: '',
|
||||
].filter(Boolean).join('\n');
|
||||
|
||||
let currentResponse = '';
|
||||
|
||||
complianceChat(
|
||||
activeChunkId,
|
||||
chatInput,
|
||||
segmentContext,
|
||||
(data: unknown) => {
|
||||
const sseData = data as { type: string; text?: string };
|
||||
if (sseData.type === 'chunk' && sseData.text) {
|
||||
|
||||
@@ -2,7 +2,7 @@ import React, { useEffect, useRef, useState } from 'react';
|
||||
import { useTheme } from '../../contexts';
|
||||
import { Content } from '../../components/layout/Content';
|
||||
import { TPattern } from '../../components/common/TPattern';
|
||||
import { getDocumentList, searchRegulations, uploadDocument, type RegulationSearchItem } from '../../api/docs';
|
||||
import { getDocumentList, getDocumentStatus, searchRegulations, uploadDocument, deleteDocument, retryDocument, type RegulationSearchItem } from '../../api/docs';
|
||||
import type { Doc } from '../../types';
|
||||
|
||||
type PipelineStatus = 'idle' | 'running' | 'completed' | 'error';
|
||||
@@ -15,13 +15,13 @@ const PIPELINE_STEPS = [
|
||||
{ name: 'STORE' },
|
||||
];
|
||||
|
||||
const REGULATION_TYPES = ['', '国家标准', '行业标准', '地方标准', '企业标准', '法律法规', '监管规定'];
|
||||
|
||||
const STEP_DURATION_MS = 700;
|
||||
const INITIAL_SEARCH_QUERY = '新能源汽车电池安全要求';
|
||||
|
||||
function wait(ms: number) {
|
||||
return new Promise<void>((resolve) => {
|
||||
window.setTimeout(resolve, ms);
|
||||
});
|
||||
return new Promise<void>((resolve) => { window.setTimeout(resolve, ms); });
|
||||
}
|
||||
|
||||
export const DocsPage: React.FC = () => {
|
||||
@@ -41,6 +41,13 @@ export const DocsPage: React.FC = () => {
|
||||
const [searchLoading, setSearchLoading] = useState(false);
|
||||
const [searchError, setSearchError] = useState('');
|
||||
|
||||
// Upload metadata
|
||||
const [regulationType, setRegulationType] = useState('');
|
||||
const [version, setVersion] = useState('');
|
||||
|
||||
// Batch queue: files waiting to be uploaded after the current one finishes
|
||||
const batchQueueRef = useRef<File[]>([]);
|
||||
|
||||
async function loadDocuments() {
|
||||
setLoading(true);
|
||||
try {
|
||||
@@ -54,6 +61,9 @@ export const DocsPage: React.FC = () => {
|
||||
docId: doc.id,
|
||||
downloadUrl: doc.download_url,
|
||||
updatedAt: doc.updated_at,
|
||||
summary: doc.summary,
|
||||
regulationType: doc.regulation_type,
|
||||
version: doc.version,
|
||||
}));
|
||||
setDocs(apiDocs);
|
||||
} catch (error) {
|
||||
@@ -81,62 +91,71 @@ export const DocsPage: React.FC = () => {
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
const timerId = window.setTimeout(() => {
|
||||
void loadDocuments();
|
||||
}, 0);
|
||||
const timerId = window.setTimeout(() => { void loadDocuments(); }, 0);
|
||||
return () => window.clearTimeout(timerId);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const timerId = window.setTimeout(() => {
|
||||
void runSearch(INITIAL_SEARCH_QUERY);
|
||||
}, 0);
|
||||
const timerId = window.setTimeout(() => { void runSearch(INITIAL_SEARCH_QUERY); }, 0);
|
||||
return () => window.clearTimeout(timerId);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
pipelineRunIdRef.current += 1;
|
||||
};
|
||||
return () => { pipelineRunIdRef.current += 1; };
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const parsingDocs = docs.filter(
|
||||
(doc) => doc.status === 'parsing' && doc.docId && !doc.docId.startsWith('pending-')
|
||||
);
|
||||
if (parsingDocs.length === 0) return;
|
||||
|
||||
const timerId = window.setInterval(() => {
|
||||
parsingDocs.forEach((doc) => {
|
||||
void getDocumentStatus(doc.docId!).then((res) => {
|
||||
if (res.status === 'indexed' || res.status === 'failed') {
|
||||
setDocs((prev) =>
|
||||
prev.map((d) =>
|
||||
d.docId === doc.docId
|
||||
? {
|
||||
...d,
|
||||
status: res.status === 'indexed' ? 'indexed' : 'failed',
|
||||
chunks: res.num_chunks ?? d.chunks,
|
||||
summary: res.summary ?? d.summary,
|
||||
regulationType: res.regulation_type ?? d.regulationType,
|
||||
version: res.version ?? d.version,
|
||||
}
|
||||
: d
|
||||
)
|
||||
);
|
||||
}
|
||||
}).catch(() => {});
|
||||
});
|
||||
}, 5000);
|
||||
|
||||
return () => window.clearInterval(timerId);
|
||||
}, [docs]);
|
||||
|
||||
const runPipelineFlow = async (runId: number, uploadPromise: Promise<Awaited<ReturnType<typeof uploadDocument>>>) => {
|
||||
const guardedSetActiveStep = (step: number) => {
|
||||
if (pipelineRunIdRef.current !== runId) return false;
|
||||
setActiveStep(step);
|
||||
return true;
|
||||
};
|
||||
const guard = (fn: () => void) => { if (pipelineRunIdRef.current !== runId) return false; fn(); return true; };
|
||||
|
||||
const guardedCompleteStep = (step: number) => {
|
||||
if (pipelineRunIdRef.current !== runId) return false;
|
||||
setCompletedSteps((prev) => (prev.includes(step) ? prev : [...prev, step]));
|
||||
return true;
|
||||
};
|
||||
|
||||
for (let index = 0; index < PIPELINE_STEPS.length - 1; index += 1) {
|
||||
if (!guardedSetActiveStep(index)) return;
|
||||
for (let i = 0; i < PIPELINE_STEPS.length - 1; i++) {
|
||||
if (!guard(() => setActiveStep(i))) return;
|
||||
await wait(STEP_DURATION_MS);
|
||||
if (!guardedCompleteStep(index)) return;
|
||||
if (!guard(() => setCompletedSteps((p) => p.includes(i) ? p : [...p, i]))) return;
|
||||
}
|
||||
|
||||
if (!guardedSetActiveStep(PIPELINE_STEPS.length - 1)) return;
|
||||
if (!guard(() => setActiveStep(PIPELINE_STEPS.length - 1))) return;
|
||||
await uploadPromise;
|
||||
if (!guardedCompleteStep(PIPELINE_STEPS.length - 1)) return;
|
||||
if (!guard(() => setCompletedSteps((p) => { const last = PIPELINE_STEPS.length - 1; return p.includes(last) ? p : [...p, last]; }))) return;
|
||||
|
||||
await wait(240);
|
||||
if (pipelineRunIdRef.current !== runId) return;
|
||||
|
||||
setActiveStep(-1);
|
||||
setPipelineStatus('completed');
|
||||
};
|
||||
|
||||
const handleFileSelect = async (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const file = event.target.files?.[0];
|
||||
if (!file || uploading) return;
|
||||
|
||||
const runId = pipelineRunIdRef.current + 1;
|
||||
pipelineRunIdRef.current = runId;
|
||||
|
||||
const uploadSingleFile = async (file: File, runId: number) => {
|
||||
setUploading(true);
|
||||
setUploadFileName(file.name);
|
||||
setActiveStep(-1);
|
||||
@@ -152,11 +171,16 @@ export const DocsPage: React.FC = () => {
|
||||
size: `${fileSizeMB}MB`,
|
||||
status: 'parsing',
|
||||
docId: tempDocId,
|
||||
regulationType: regulationType || undefined,
|
||||
version: version || undefined,
|
||||
};
|
||||
|
||||
setDocs((prev) => [newDoc, ...prev]);
|
||||
|
||||
const uploadPromise = uploadDocument(file);
|
||||
const uploadPromise = uploadDocument(file, {
|
||||
regulationType: regulationType || undefined,
|
||||
version: version || undefined,
|
||||
});
|
||||
void runPipelineFlow(runId, uploadPromise);
|
||||
|
||||
try {
|
||||
@@ -166,143 +190,123 @@ export const DocsPage: React.FC = () => {
|
||||
setDocs((prev) =>
|
||||
prev.map((doc) =>
|
||||
doc.id === newDoc.id
|
||||
? {
|
||||
...doc,
|
||||
status: 'indexed',
|
||||
docId: uploadRes.doc_id,
|
||||
chunks: uploadRes.num_chunks || doc.chunks,
|
||||
summary: uploadRes.summary,
|
||||
}
|
||||
? { ...doc, status: 'indexed', docId: uploadRes.doc_id, chunks: uploadRes.num_chunks || doc.chunks, summary: uploadRes.summary }
|
||||
: doc
|
||||
)
|
||||
);
|
||||
|
||||
setUploading(false);
|
||||
setUploadFileName('');
|
||||
void loadDocuments();
|
||||
} catch (error) {
|
||||
console.error('Upload failed:', error);
|
||||
if (pipelineRunIdRef.current !== runId) return;
|
||||
|
||||
setUploading(false);
|
||||
setUploadFileName('');
|
||||
setDocs((prev) => prev.filter((doc) => doc.id !== newDoc.id));
|
||||
setPipelineStatus('error');
|
||||
setActiveStep(-1);
|
||||
setCompletedSteps([]);
|
||||
} finally {
|
||||
if (fileInputRef.current) {
|
||||
fileInputRef.current.value = '';
|
||||
setUploading(false);
|
||||
setUploadFileName('');
|
||||
if (fileInputRef.current) fileInputRef.current.value = '';
|
||||
|
||||
// Process next file in batch queue
|
||||
const next = batchQueueRef.current.shift();
|
||||
if (next) {
|
||||
const nextRunId = pipelineRunIdRef.current + 1;
|
||||
pipelineRunIdRef.current = nextRunId;
|
||||
void uploadSingleFile(next, nextRunId);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const triggerFileUpload = () => {
|
||||
if (uploading) return;
|
||||
fileInputRef.current?.click();
|
||||
const handleFileSelect = async (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const files = Array.from(event.target.files ?? []);
|
||||
if (files.length === 0 || uploading) return;
|
||||
|
||||
const [first, ...rest] = files;
|
||||
batchQueueRef.current = rest;
|
||||
|
||||
const runId = pipelineRunIdRef.current + 1;
|
||||
pipelineRunIdRef.current = runId;
|
||||
await uploadSingleFile(first, runId);
|
||||
};
|
||||
|
||||
const handleDragOver = (event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
// Delete a document on the backend, then drop it from the local list.
// A failure is only logged to the console; the list is left unchanged.
const handleDelete = async (docId: string) => {
  try {
    await deleteDocument(docId);
    setDocs((current) => current.filter((entry) => entry.docId !== docId));
  } catch (error) {
    console.error('Delete failed:', error);
  }
};
|
||||
|
||||
// Re-run processing for one document. The row is marked 'parsing' right away,
// then 'indexed' (with the chunk count refreshed) once the retry request
// resolves, or 'failed' when it rejects.
const handleRetry = async (docId: string) => {
  setDocs((current) =>
    current.map((entry) => (entry.docId === docId ? { ...entry, status: 'parsing' } : entry))
  );
  try {
    const retried = await retryDocument(docId);
    setDocs((current) =>
      current.map((entry) =>
        entry.docId === docId
          ? { ...entry, status: 'indexed', chunks: retried.num_chunks || entry.chunks }
          : entry
      )
    );
  } catch (error) {
    console.error('Retry failed:', error);
    setDocs((current) =>
      current.map((entry) => (entry.docId === docId ? { ...entry, status: 'failed' } : entry))
    );
  }
};
|
||||
|
||||
const triggerFileUpload = () => { if (uploading) return; fileInputRef.current?.click(); };
|
||||
|
||||
const handleDragOver = (event: React.DragEvent) => { event.preventDefault(); event.stopPropagation(); };
|
||||
|
||||
const handleDrop = (event: React.DragEvent) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
|
||||
const files = event.dataTransfer.files;
|
||||
const files = Array.from(event.dataTransfer.files);
|
||||
if (files.length === 0 || uploading) return;
|
||||
|
||||
const droppedFile = files[0];
|
||||
if (fileInputRef.current) {
|
||||
const dataTransfer = new DataTransfer();
|
||||
dataTransfer.items.add(droppedFile);
|
||||
fileInputRef.current.files = dataTransfer.files;
|
||||
}
|
||||
|
||||
void handleFileSelect({
|
||||
target: { files: [droppedFile] as unknown as FileList },
|
||||
} as React.ChangeEvent<HTMLInputElement>);
|
||||
const [first, ...rest] = files;
|
||||
batchQueueRef.current = rest;
|
||||
const runId = pipelineRunIdRef.current + 1;
|
||||
pipelineRunIdRef.current = runId;
|
||||
void uploadSingleFile(first, runId);
|
||||
};
|
||||
|
||||
const getStepStyle = (index: number) => {
|
||||
const isActive = activeStep === index;
|
||||
const isCompleted = completedSteps.includes(index);
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
background: theme.bgCard,
|
||||
border: `2px solid ${theme.accent}`,
|
||||
boxShadow: `0 0 12px ${theme.accent}40`,
|
||||
};
|
||||
}
|
||||
|
||||
if (isCompleted) {
|
||||
return {
|
||||
background: theme.bgCard,
|
||||
border: `1px solid ${theme.green}`,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
background: theme.bgCard,
|
||||
border: `1px solid ${theme.border}`,
|
||||
};
|
||||
if (activeStep === index) return { background: theme.bgCard, border: `2px solid ${theme.accent}`, boxShadow: `0 0 12px ${theme.accent}40` };
|
||||
if (completedSteps.includes(index)) return { background: theme.bgCard, border: `1px solid ${theme.green}` };
|
||||
return { background: theme.bgCard, border: `1px solid ${theme.border}` };
|
||||
};
|
||||
|
||||
const getCheckStyle = (index: number) => {
|
||||
const isActive = activeStep === index;
|
||||
const isCompleted = completedSteps.includes(index);
|
||||
|
||||
if (isActive) {
|
||||
return {
|
||||
background: theme.gradientAccent,
|
||||
color: '#fff',
|
||||
animation: 'pulse 0.6s infinite',
|
||||
};
|
||||
}
|
||||
|
||||
if (isCompleted) {
|
||||
return {
|
||||
background: theme.green,
|
||||
color: '#fff',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
background: theme.bgHover,
|
||||
color: theme.text3,
|
||||
};
|
||||
if (activeStep === index) return { background: theme.gradientAccent, color: '#fff', animation: 'pulse 0.6s infinite' };
|
||||
if (completedSteps.includes(index)) return { background: theme.green, color: '#fff' };
|
||||
return { background: theme.bgHover, color: theme.text3 };
|
||||
};
|
||||
|
||||
// Status line shown above the pipeline steps.
//  - running:   "<current step name or LOAD> · <file name>", followed by
//               " (+N 待上传)" when N more files are still waiting in the batch queue
//  - completed: 'PIPELINE COMPLETE'
//  - error:     'PIPELINE FAILED'
//  - otherwise: 'WAITING FOR UPLOAD'
// The queue suffix used to sit after `return 'PIPELINE FAILED'` inside the error
// branch, so it never executed; it now lives in the running branch it was written for.
const getPipelineHint = () => {
  if (pipelineStatus === 'running') {
    // Files still queued behind the one currently being uploaded.
    const queueLen = batchQueueRef.current.length;
    const suffix = queueLen > 0 ? ` (+${queueLen} 待上传)` : '';
    // A negative activeStep means no pipeline step has started yet.
    const stage = activeStep >= 0 ? PIPELINE_STEPS[activeStep].name : 'LOAD';
    return `${stage} · ${uploadFileName}${suffix}`;
  }

  if (pipelineStatus === 'completed') return 'PIPELINE COMPLETE';
  if (pipelineStatus === 'error') return 'PIPELINE FAILED';
  return 'WAITING FOR UPLOAD';
};
|
||||
|
||||
// Base style shared by the upload-metadata controls (the regulation-type
// <select> and the version <input>); each spreads it and adds `flex: 1`.
const inputStyle: React.CSSProperties = {
  // Colours come from the active theme.
  color: theme.text,
  background: theme.bgCard,
  border: `1px solid ${theme.border}`,
  // Geometry and typography.
  borderRadius: 8,
  padding: '8px 12px',
  fontSize: 13,
  outline: 'none',
};
|
||||
|
||||
return (
|
||||
<Content>
|
||||
<TPattern />
|
||||
|
||||
<section style={{ marginBottom: 56 }}>
|
||||
<h2
|
||||
style={{
|
||||
fontSize: 14,
|
||||
fontWeight: 600,
|
||||
color: theme.accent,
|
||||
marginBottom: 20,
|
||||
letterSpacing: '1px',
|
||||
}}
|
||||
>
|
||||
<h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
|
||||
UPLOAD
|
||||
</h2>
|
||||
|
||||
@@ -310,10 +314,30 @@ export const DocsPage: React.FC = () => {
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".pdf,.docx,.doc"
|
||||
multiple
|
||||
onChange={handleFileSelect}
|
||||
style={{ display: 'none' }}
|
||||
/>
|
||||
|
||||
{/* Metadata row */}
|
||||
<div style={{ display: 'flex', gap: 12, marginBottom: 12 }}>
|
||||
<select
|
||||
value={regulationType}
|
||||
onChange={(e) => setRegulationType(e.target.value)}
|
||||
style={{ ...inputStyle, flex: 1 }}
|
||||
>
|
||||
{REGULATION_TYPES.map((t) => (
|
||||
<option key={t} value={t}>{t || '法规类型(可选)'}</option>
|
||||
))}
|
||||
</select>
|
||||
<input
|
||||
value={version}
|
||||
onChange={(e) => setVersion(e.target.value)}
|
||||
placeholder="版本号(可选,如 2024)"
|
||||
style={{ ...inputStyle, flex: 1 }}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div
|
||||
onClick={triggerFileUpload}
|
||||
onDragOver={handleDragOver}
|
||||
@@ -330,30 +354,11 @@ export const DocsPage: React.FC = () => {
|
||||
opacity: uploading ? 0.78 : 1,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
width: 80,
|
||||
height: 80,
|
||||
borderRadius: 20,
|
||||
background: theme.bgHover,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
margin: '0 auto 20px',
|
||||
}}
|
||||
>
|
||||
<div style={{ width: 80, height: 80, borderRadius: 20, background: theme.bgHover, display: 'flex', alignItems: 'center', justifyContent: 'center', margin: '0 auto 20px' }}>
|
||||
{uploading ? (
|
||||
<div style={{ animation: 'spin 1s linear infinite' }}>
|
||||
<svg width="36" height="36" viewBox="0 0 24 24" fill="none">
|
||||
<circle
|
||||
cx="12"
|
||||
cy="12"
|
||||
r="10"
|
||||
stroke={theme.accent}
|
||||
strokeWidth="2"
|
||||
strokeDasharray="60"
|
||||
strokeDashoffset="20"
|
||||
/>
|
||||
<circle cx="12" cy="12" r="10" stroke={theme.accent} strokeWidth="2" strokeDasharray="60" strokeDashoffset="20" />
|
||||
</svg>
|
||||
</div>
|
||||
) : (
|
||||
@@ -363,26 +368,17 @@ export const DocsPage: React.FC = () => {
|
||||
</svg>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div style={{ fontSize: 16, fontWeight: 500, marginBottom: 8 }}>
|
||||
{uploading ? '正在上传并启动处理链路...' : '拖拽文件或点击上传'}
|
||||
{uploading ? '正在上传并启动处理链路...' : '拖拽文件或点击上传(支持多选)'}
|
||||
</div>
|
||||
<div className="mono" style={{ fontSize: 12, color: theme.text3 }}>
|
||||
{uploading ? uploadFileName : 'PDF · DOCX · DOC · MAX 100MB'}
|
||||
{uploading ? uploadFileName : 'PDF · DOCX · DOC · MAX 100MB · 支持批量'}
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section style={{ marginBottom: 40 }}>
|
||||
<h2
|
||||
style={{
|
||||
fontSize: 14,
|
||||
fontWeight: 600,
|
||||
color: theme.accent,
|
||||
marginBottom: 20,
|
||||
letterSpacing: '1px',
|
||||
}}
|
||||
>
|
||||
<h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
|
||||
PROCESSING PIPELINE
|
||||
</h2>
|
||||
|
||||
@@ -390,12 +386,7 @@ export const DocsPage: React.FC = () => {
|
||||
className="mono"
|
||||
style={{
|
||||
fontSize: 11,
|
||||
color:
|
||||
pipelineStatus === 'error'
|
||||
? '#d64545'
|
||||
: pipelineStatus === 'completed'
|
||||
? theme.green
|
||||
: theme.text3,
|
||||
color: pipelineStatus === 'error' ? '#d64545' : pipelineStatus === 'completed' ? theme.green : theme.text3,
|
||||
letterSpacing: '1px',
|
||||
marginBottom: 12,
|
||||
}}
|
||||
@@ -405,63 +396,24 @@ export const DocsPage: React.FC = () => {
|
||||
|
||||
<div style={{ display: 'flex', gap: 16 }}>
|
||||
{PIPELINE_STEPS.map((step, index) => {
|
||||
const stepStyle = getStepStyle(index);
|
||||
const checkStyle = getCheckStyle(index);
|
||||
const arrowActive = activeStep > index || completedSteps.includes(index);
|
||||
const isCompleted = completedSteps.includes(index);
|
||||
const isActive = activeStep === index;
|
||||
const arrowActive = activeStep > index || isCompleted;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={step.name}
|
||||
style={{
|
||||
flex: 1,
|
||||
padding: 20,
|
||||
textAlign: 'center',
|
||||
borderRadius: 12,
|
||||
position: 'relative',
|
||||
boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none',
|
||||
transition: 'all 0.3s ease',
|
||||
...stepStyle,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
width: 36,
|
||||
height: 36,
|
||||
borderRadius: 8,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
margin: '0 auto 12px',
|
||||
fontSize: 16,
|
||||
transition: 'all 0.3s ease',
|
||||
...checkStyle,
|
||||
}}
|
||||
style={{ flex: 1, padding: 20, textAlign: 'center', borderRadius: 12, position: 'relative', boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none', transition: 'all 0.3s ease', ...getStepStyle(index) }}
|
||||
>
|
||||
<div style={{ width: 36, height: 36, borderRadius: 8, display: 'flex', alignItems: 'center', justifyContent: 'center', margin: '0 auto 12px', fontSize: 16, transition: 'all 0.3s ease', ...getCheckStyle(index) }}>
|
||||
{isActive ? step.name : isCompleted ? '✓' : step.name}
|
||||
</div>
|
||||
|
||||
<div className="mono" style={{ fontSize: 12, fontWeight: 600 }}>
|
||||
{step.name}
|
||||
</div>
|
||||
<div className="mono" style={{ fontSize: 12, fontWeight: 600 }}>{step.name}</div>
|
||||
<div className="mono" style={{ fontSize: 10, color: theme.text3, marginTop: 8 }}>
|
||||
{isCompleted ? 'DONE' : isActive ? 'RUNNING' : 'PENDING'}
|
||||
</div>
|
||||
|
||||
{index < PIPELINE_STEPS.length - 1 && (
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
right: -8,
|
||||
top: '50%',
|
||||
transform: 'translateY(-50%)',
|
||||
color: arrowActive ? theme.green : theme.borderLight,
|
||||
fontWeight: arrowActive ? 700 : 400,
|
||||
opacity: arrowActive ? 1 : 0.45,
|
||||
transition: 'all 0.3s ease',
|
||||
}}
|
||||
>
|
||||
<div style={{ position: 'absolute', right: -8, top: '50%', transform: 'translateY(-50%)', color: arrowActive ? theme.green : theme.borderLight, fontWeight: arrowActive ? 700 : 400, opacity: arrowActive ? 1 : 0.45, transition: 'all 0.3s ease' }}>
|
||||
→
|
||||
</div>
|
||||
)}
|
||||
@@ -473,15 +425,7 @@ export const DocsPage: React.FC = () => {
|
||||
|
||||
<section style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 24, marginBottom: 56 }}>
|
||||
<div>
|
||||
<h2
|
||||
style={{
|
||||
fontSize: 14,
|
||||
fontWeight: 600,
|
||||
color: theme.accent,
|
||||
marginBottom: 20,
|
||||
letterSpacing: '1px',
|
||||
}}
|
||||
>
|
||||
<h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
|
||||
文档管理清单 ({loading ? '...' : docs.length})
|
||||
</h2>
|
||||
|
||||
@@ -489,36 +433,12 @@ export const DocsPage: React.FC = () => {
|
||||
{docs.map((doc) => (
|
||||
<div
|
||||
key={doc.id}
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
padding: 20,
|
||||
background: theme.bgCard,
|
||||
borderRadius: 12,
|
||||
border: `1px solid ${doc.status === 'parsing' ? theme.accent : theme.border}`,
|
||||
transition: 'all 0.2s ease',
|
||||
boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none',
|
||||
}}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 16 }}>
|
||||
<div
|
||||
style={{
|
||||
width: 44,
|
||||
height: 44,
|
||||
borderRadius: 10,
|
||||
background: theme.bgHover,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
}}
|
||||
style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', padding: 20, background: theme.bgCard, borderRadius: 12, border: `1px solid ${doc.status === 'parsing' ? theme.accent : theme.border}`, transition: 'all 0.2s ease', boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none' }}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 16 }}>
|
||||
<div style={{ width: 44, height: 44, borderRadius: 10, background: theme.bgHover, display: 'flex', alignItems: 'center', justifyContent: 'center', flexShrink: 0 }}>
|
||||
<svg width="20" height="20" viewBox="0 0 24 24" fill="none">
|
||||
<path
|
||||
d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z"
|
||||
stroke={theme.accent}
|
||||
strokeWidth="1.5"
|
||||
/>
|
||||
<path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke={theme.accent} strokeWidth="1.5" />
|
||||
<path d="M14 2V8H20" stroke={theme.accent} strokeWidth="1.5" />
|
||||
</svg>
|
||||
</div>
|
||||
@@ -529,36 +449,48 @@ export const DocsPage: React.FC = () => {
|
||||
{doc.updatedAt ? new Date(doc.updatedAt).toLocaleString() : doc.size}
|
||||
{doc.docId ? ` · ${doc.docId}` : ''}
|
||||
</div>
|
||||
{/* Tags row */}
|
||||
{(doc.regulationType || doc.version) && (
|
||||
<div style={{ display: 'flex', gap: 6, marginTop: 5, flexWrap: 'wrap' }}>
|
||||
{doc.regulationType && (
|
||||
<span style={{ fontSize: 11, padding: '2px 8px', borderRadius: 4, background: `${theme.accent}18`, color: theme.accent, fontWeight: 500 }}>
|
||||
{doc.regulationType}
|
||||
</span>
|
||||
)}
|
||||
{doc.version && (
|
||||
<span style={{ fontSize: 11, padding: '2px 8px', borderRadius: 4, background: theme.bgHover, color: theme.text2 }}>
|
||||
v{doc.version}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
{doc.summary && (
|
||||
<div style={{ fontSize: 12, color: theme.text2, marginTop: 6, lineHeight: 1.5, maxWidth: 320, display: '-webkit-box', WebkitLineClamp: 2, WebkitBoxOrient: 'vertical', overflow: 'hidden' }}>
|
||||
{doc.summary}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 20 }}>
|
||||
{doc.downloadUrl && (
|
||||
<a
|
||||
href={doc.downloadUrl}
|
||||
target="_blank"
|
||||
rel="noreferrer"
|
||||
style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }}
|
||||
>
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 12, flexShrink: 0 }}>
|
||||
{doc.status === 'failed' && doc.docId && !doc.docId.startsWith('pending-') && (
|
||||
<button onClick={() => void handleRetry(doc.docId!)} style={{ fontSize: 12, padding: '6px 12px', background: 'transparent', border: `1px solid ${theme.accent}`, borderRadius: 6, color: theme.accent, cursor: 'pointer' }}>
|
||||
重试
|
||||
</button>
|
||||
)}
|
||||
{doc.downloadUrl && doc.status === 'indexed' && (
|
||||
<a href={doc.downloadUrl} target="_blank" rel="noreferrer" style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }}>
|
||||
下载
|
||||
</a>
|
||||
)}
|
||||
<div
|
||||
className="mono"
|
||||
style={{
|
||||
fontSize: 12,
|
||||
padding: '6px 12px',
|
||||
background: theme.bgHover,
|
||||
borderRadius: 6,
|
||||
color: doc.status === 'failed' ? '#d64545' : theme.text2,
|
||||
}}
|
||||
>
|
||||
{doc.status === 'parsing'
|
||||
? '处理中...'
|
||||
: doc.status === 'failed'
|
||||
? '处理失败'
|
||||
: `${doc.chunks} chunks`}
|
||||
<div className="mono" style={{ fontSize: 12, padding: '6px 12px', background: theme.bgHover, borderRadius: 6, color: doc.status === 'failed' ? '#d64545' : theme.text2 }}>
|
||||
{doc.status === 'parsing' ? '处理中...' : doc.status === 'failed' ? '处理失败' : `${doc.chunks} chunks`}
|
||||
</div>
|
||||
{doc.docId && !doc.docId.startsWith('pending-') && (
|
||||
<button onClick={() => void handleDelete(doc.docId!)} style={{ fontSize: 12, padding: '6px 10px', background: 'transparent', border: `1px solid ${theme.border}`, borderRadius: 6, color: theme.text3, cursor: 'pointer' }}>
|
||||
删除
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
@@ -566,15 +498,7 @@ export const DocsPage: React.FC = () => {
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<h2
|
||||
style={{
|
||||
fontSize: 14,
|
||||
fontWeight: 600,
|
||||
color: theme.accent,
|
||||
marginBottom: 20,
|
||||
letterSpacing: '1px',
|
||||
}}
|
||||
>
|
||||
<h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
|
||||
文档管理内法规检索
|
||||
</h2>
|
||||
|
||||
@@ -582,86 +506,37 @@ export const DocsPage: React.FC = () => {
|
||||
<input
|
||||
value={searchQuery}
|
||||
onChange={(event) => setSearchQuery(event.target.value)}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === 'Enter') {
|
||||
void runSearch(searchQuery);
|
||||
}
|
||||
}}
|
||||
onKeyDown={(event) => { if (event.key === 'Enter') void runSearch(searchQuery); }}
|
||||
placeholder="输入法规关键词、条款或制度主题"
|
||||
style={{
|
||||
flex: 1,
|
||||
padding: 12,
|
||||
fontSize: 14,
|
||||
background: theme.bgCard,
|
||||
border: `1px solid ${theme.border}`,
|
||||
borderRadius: 8,
|
||||
color: theme.text,
|
||||
outline: 'none',
|
||||
}}
|
||||
style={{ flex: 1, padding: 12, fontSize: 14, background: theme.bgCard, border: `1px solid ${theme.border}`, borderRadius: 8, color: theme.text, outline: 'none' }}
|
||||
/>
|
||||
<button
|
||||
onClick={() => void runSearch(searchQuery)}
|
||||
disabled={searchLoading || !searchQuery.trim()}
|
||||
style={{
|
||||
padding: '12px 18px',
|
||||
fontSize: 13,
|
||||
fontWeight: 600,
|
||||
background: searchLoading || !searchQuery.trim() ? theme.bgHover : theme.gradientAccent,
|
||||
color: searchLoading || !searchQuery.trim() ? theme.text3 : '#fff',
|
||||
border: 'none',
|
||||
borderRadius: 8,
|
||||
cursor: searchLoading || !searchQuery.trim() ? 'not-allowed' : 'pointer',
|
||||
}}
|
||||
style={{ padding: '12px 18px', fontSize: 13, fontWeight: 600, background: searchLoading || !searchQuery.trim() ? theme.bgHover : theme.gradientAccent, color: searchLoading || !searchQuery.trim() ? theme.text3 : '#fff', border: 'none', borderRadius: 8, cursor: searchLoading || !searchQuery.trim() ? 'not-allowed' : 'pointer' }}
|
||||
>
|
||||
检索
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{searchError && (
|
||||
<div style={{ marginBottom: 12, fontSize: 13, color: '#d64545' }}>
|
||||
{searchError}
|
||||
</div>
|
||||
)}
|
||||
{searchError && <div style={{ marginBottom: 12, fontSize: 13, color: '#d64545' }}>{searchError}</div>}
|
||||
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
|
||||
{searchResults.map((item) => (
|
||||
<div
|
||||
key={`${item.id}-${item.file}`}
|
||||
style={{
|
||||
padding: 18,
|
||||
background: theme.bgCard,
|
||||
borderRadius: 12,
|
||||
border: `1px solid ${theme.border}`,
|
||||
boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none',
|
||||
}}
|
||||
>
|
||||
<div key={`${item.id}-${item.file}`} style={{ padding: 18, background: theme.bgCard, borderRadius: 12, border: `1px solid ${theme.border}`, boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', gap: 12, marginBottom: 6 }}>
|
||||
<div style={{ fontSize: 14, fontWeight: 600, color: theme.text }}>{item.file}</div>
|
||||
<div className="mono" style={{ fontSize: 11, color: theme.accent }}>
|
||||
{(item.score * 100).toFixed(1)}%
|
||||
<div className="mono" style={{ fontSize: 11, color: theme.accent }}>{(item.score * 100).toFixed(1)}%</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mono" style={{ fontSize: 11, color: theme.text3, marginBottom: 8 }}>
|
||||
{item.clause}
|
||||
{item.tags.length > 0 ? ` · ${item.tags.join(' · ')}` : ''}
|
||||
{item.clause}{item.tags.length > 0 ? ` · ${item.tags.join(' · ')}` : ''}
|
||||
</div>
|
||||
|
||||
<div style={{ fontSize: 12, color: theme.text2, lineHeight: 1.6 }}>{item.content}</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{!searchLoading && searchResults.length === 0 && (
|
||||
<div
|
||||
style={{
|
||||
padding: 24,
|
||||
borderRadius: 12,
|
||||
background: theme.bgCard,
|
||||
border: `1px solid ${theme.border}`,
|
||||
textAlign: 'center',
|
||||
color: theme.text3,
|
||||
}}
|
||||
>
|
||||
<div style={{ padding: 24, borderRadius: 12, background: theme.bgCard, border: `1px solid ${theme.border}`, textAlign: 'center', color: theme.text3 }}>
|
||||
暂无检索结果
|
||||
</div>
|
||||
)}
|
||||
|
||||
58
frontend/src/pages/RagChat/CitedAnswer.tsx
Normal file
58
frontend/src/pages/RagChat/CitedAnswer.tsx
Normal file
@@ -0,0 +1,58 @@
|
||||
import React, { useRef } from 'react';
|
||||
import { useTheme } from '../../contexts';
|
||||
import type { RetrievalData } from '../../types';
|
||||
|
||||
/** Props for CitedAnswer. */
interface CitedAnswerProps {
  /** Answer text; may contain citation markers of the form `[N]`. */
  text: string;
  /** Sources the markers refer to: marker `[N]` resolves to `sources[N - 1]`. */
  sources: RetrievalData[];
  /** Called with the 1-based citation number when a resolvable marker is activated. */
  onCiteClick: (index: number) => void;
}

/**
 * Renders answer text with every `[N]` marker turned into a superscript badge.
 *
 * A badge whose number resolves to an entry in `sources` is interactive: it can
 * be clicked, focused, and activated with Enter or Space, and it reports its
 * 1-based number through `onCiteClick`. A badge with no matching source
 * (for example `[0]` or a number beyond `sources.length`) is rendered inert, so
 * the parent is never asked to jump to a source that does not exist.
 *
 * Returns `null` when `text` is empty.
 */
export const CitedAnswer: React.FC<CitedAnswerProps> = ({ text, sources, onCiteClick }) => {
  const { theme } = useTheme();

  if (!text) return null;

  // The capture group makes split() keep the `[N]` markers as their own parts.
  const parts = text.split(/(\[\d+\])/g);

  return (
    <span style={{ whiteSpace: 'pre-wrap', lineHeight: 1.6 }}>
      {parts.map((part, i) => {
        const match = part.match(/^\[(\d+)\]$/);
        if (!match) return <React.Fragment key={i}>{part}</React.Fragment>;

        const idx = parseInt(match[1], 10);
        // Markers are 1-based; the sources array is 0-based.
        const source = sources[idx - 1];
        const clickable = Boolean(source);

        const handleKeyDown = (event: React.KeyboardEvent<HTMLElement>) => {
          if (event.key === 'Enter' || event.key === ' ') {
            // Stop Space from scrolling the page while activating the badge.
            event.preventDefault();
            onCiteClick(idx);
          }
        };

        return (
          <sup
            key={i}
            title={source ? `${source.file} · ${source.clause}` : `引用 ${idx}`}
            role={clickable ? 'button' : undefined}
            tabIndex={clickable ? 0 : undefined}
            onClick={clickable ? () => onCiteClick(idx) : undefined}
            onKeyDown={clickable ? handleKeyDown : undefined}
            style={{
              display: 'inline-flex',
              alignItems: 'center',
              justifyContent: 'center',
              minWidth: 18,
              height: 18,
              padding: '0 4px',
              marginLeft: 1,
              fontSize: 10,
              fontWeight: 700,
              background: theme.gradientAccent,
              color: '#fff',
              borderRadius: 4,
              cursor: clickable ? 'pointer' : 'default',
              verticalAlign: 'super',
              lineHeight: 1,
              userSelect: 'none',
            }}
          >
            {idx}
          </sup>
        );
      })}
    </span>
  );
};
|
||||
@@ -2,6 +2,7 @@ import React, { useEffect, useRef, useState } from 'react';
|
||||
import { useTheme } from '../../contexts';
|
||||
import type { ChatMessage, RetrievalData } from '../../types';
|
||||
import { getQuickQuestions, ragChat } from '../../api/rag';
|
||||
import { CitedAnswer } from './CitedAnswer';
|
||||
|
||||
const ragQuickQuestionsDefault = [
|
||||
'电动自行车上路需要什么条件?',
|
||||
@@ -24,6 +25,8 @@ export const RagChatPage: React.FC = () => {
|
||||
const [showClearConfirm, setShowClearConfirm] = useState<boolean>(false);
|
||||
const [selectedRetrieval, setSelectedRetrieval] = useState<RetrievalData | null>(null);
|
||||
const [quickQuestions, setQuickQuestions] = useState<string[]>(ragQuickQuestionsDefault);
|
||||
const [filterRegulationType, setFilterRegulationType] = useState<string>('');
|
||||
const [highlightedSourceIdx, setHighlightedSourceIdx] = useState<number | null>(null);
|
||||
|
||||
function nextMessageId() {
|
||||
const currentId = nextMessageIdRef.current;
|
||||
@@ -55,8 +58,10 @@ export const RagChatPage: React.FC = () => {
|
||||
setInput('');
|
||||
setLoading(true);
|
||||
setRetrievals([]);
|
||||
setHighlightedSourceIdx(null);
|
||||
|
||||
let currentResponse = '';
|
||||
const activeFilters = filterRegulationType.trim() || undefined;
|
||||
|
||||
void ragChat(
|
||||
text,
|
||||
@@ -112,7 +117,8 @@ export const RagChatPage: React.FC = () => {
|
||||
},
|
||||
() => {
|
||||
setLoading(false);
|
||||
}
|
||||
},
|
||||
activeFilters
|
||||
);
|
||||
};
|
||||
|
||||
@@ -130,8 +136,10 @@ export const RagChatPage: React.FC = () => {
|
||||
setLoading(true);
|
||||
setMessages((prev) => [...prev.slice(0, -1)]);
|
||||
setRetrievals([]);
|
||||
setHighlightedSourceIdx(null);
|
||||
|
||||
let currentResponse = '';
|
||||
const activeFilters = filterRegulationType.trim() || undefined;
|
||||
|
||||
void ragChat(
|
||||
lastUserMsg.content,
|
||||
@@ -185,7 +193,8 @@ export const RagChatPage: React.FC = () => {
|
||||
},
|
||||
() => {
|
||||
setLoading(false);
|
||||
}
|
||||
},
|
||||
activeFilters
|
||||
);
|
||||
};
|
||||
|
||||
@@ -267,7 +276,17 @@ export const RagChatPage: React.FC = () => {
|
||||
whiteSpace: 'pre-wrap',
|
||||
border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none',
|
||||
}}>
|
||||
{msg.content}
|
||||
{msg.role === 'assistant' ? (
|
||||
<CitedAnswer
|
||||
text={msg.content}
|
||||
sources={retrievals}
|
||||
onCiteClick={(idx) => {
|
||||
setHighlightedSourceIdx(idx);
|
||||
const el = document.getElementById(`source-${idx}`);
|
||||
if (el) el.scrollIntoView({ behavior: 'smooth', block: 'center' });
|
||||
}}
|
||||
/>
|
||||
) : msg.content}
|
||||
{msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && (
|
||||
<div style={{
|
||||
marginTop: 10,
|
||||
@@ -331,6 +350,31 @@ export const RagChatPage: React.FC = () => {
|
||||
background: theme.bg,
|
||||
borderTop: `1px solid ${theme.border}`,
|
||||
}}>
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
gap: 8,
|
||||
marginBottom: 10,
|
||||
alignItems: 'center',
|
||||
}}>
|
||||
<span className="mono" style={{ fontSize: 11, color: theme.text3, whiteSpace: 'nowrap' }}>法规类型</span>
|
||||
<input
|
||||
value={filterRegulationType}
|
||||
onChange={(e) => setFilterRegulationType(e.target.value)}
|
||||
placeholder="如: GB / UN-ECE / IATF(留空不过滤)"
|
||||
style={{
|
||||
flex: 1,
|
||||
maxWidth: 280,
|
||||
padding: '5px 10px',
|
||||
fontSize: 12,
|
||||
background: theme.bgHover,
|
||||
border: `1px solid ${theme.border}`,
|
||||
borderRadius: 6,
|
||||
color: theme.text,
|
||||
outline: 'none',
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
gap: 8,
|
||||
@@ -468,14 +512,16 @@ export const RagChatPage: React.FC = () => {
|
||||
{retrievals.map((r, i) => (
|
||||
<div
|
||||
key={r.id}
|
||||
id={`source-${i + 1}`}
|
||||
onClick={() => setSelectedRetrieval(r)}
|
||||
style={{
|
||||
padding: 16,
|
||||
background: theme.bgHover,
|
||||
background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover,
|
||||
borderRadius: 10,
|
||||
border: `1px solid ${theme.border}`,
|
||||
border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`,
|
||||
cursor: 'pointer',
|
||||
position: 'relative',
|
||||
transition: 'border-color 0.2s, background 0.2s',
|
||||
}}
|
||||
>
|
||||
<div style={{
|
||||
|
||||
@@ -8,6 +8,8 @@ export interface Doc {
|
||||
downloadUrl?: string;
|
||||
summary?: string;
|
||||
updatedAt?: string;
|
||||
regulationType?: string;
|
||||
version?: string;
|
||||
}
|
||||
|
||||
export interface SearchResult {
|
||||
|
||||
@@ -4,7 +4,7 @@ import react from '@vitejs/plugin-react'
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig(({ mode }) => {
|
||||
const env = loadEnv(mode, process.cwd(), '')
|
||||
const apiHost = env.API_HOST || 'localhost'
|
||||
const apiHost = env.API_HOST || '6.86.80.8'
|
||||
const apiPort = env.API_PORT || '8000'
|
||||
const proxyTarget = env.VITE_API_PROXY_TARGET || `http://${apiHost}:${apiPort}`
|
||||
|
||||
|
||||
Reference in New Issue
Block a user