diff --git a/.env b/.env index d15b01b..01bd0a7 100644 --- a/.env +++ b/.env @@ -7,26 +7,29 @@ APP_VERSION=0.1.0 DEBUG=false # ===== Milvus向量数据库配置(已有)===== -MILVUS_HOST=localhost +MILVUS_HOST=6.86.80.8 MILVUS_PORT=19530 MILVUS_COLLECTION=regulations_dense_1024_v1 MILVUS_DB_NAME=default +MILVUS_INDEX_TYPE=IVF_FLAT +MILVUS_NLIST=128 +MILVUS_NPROBE=16 # ===== MinIO对象存储配置(已有)===== -MINIO_ENDPOINT=localhost:9000 +MINIO_ENDPOINT=6.86.80.8:9000 MINIO_ACCESS_KEY=minioadmin MINIO_SECRET_KEY=minioadmin MINIO_BUCKET=compliance-docs MINIO_SECURE=false # ===== Redis配置(已有)===== -REDIS_HOST=localhost +REDIS_HOST=6.86.80.8 REDIS_PORT=6379 REDIS_PASSWORD=redis@123 REDIS_DB=0 # ===== PostgreSQL配置(已有)===== -POSTGRES_HOST=localhost +POSTGRES_HOST=6.86.80.8 POSTGRES_PORT=5432 POSTGRES_USER=postgresql POSTGRES_PASSWORD=postgresql123456 @@ -43,6 +46,10 @@ EMBEDDING_TIMEOUT_SECONDS=120 CHUNK_SIZE=512 CHUNK_OVERLAP=50 MAX_FILE_SIZE_MB=100 +PARSER_BACKEND=aliyun +CHUNK_BACKEND=aliyun +# 文档元数据存储后端:json(默认)或 postgres +DOCUMENT_REPOSITORY_BACKEND=json # ===== API配置 ===== API_HOST=0.0.0.0 diff --git a/.env.development b/.env.development index 62aa43c..44dd662 100644 --- a/.env.development +++ b/.env.development @@ -6,6 +6,9 @@ MILVUS_HOST=6.86.80.8 MILVUS_PORT=19530 MILVUS_COLLECTION=regulations_dense_1024_v1 MILVUS_DB_NAME=default +MILVUS_INDEX_TYPE=IVF_FLAT +MILVUS_NLIST=128 +MILVUS_NPROBE=16 # ===== MinIO对象存储配置(已有)===== MINIO_ENDPOINT=6.86.80.8:9000 @@ -26,3 +29,7 @@ POSTGRES_PORT=5432 POSTGRES_USER=postgresql POSTGRES_PASSWORD=postgresql123456 POSTGRES_DB=compliance_db + +# ===== 文档元数据后端 ===== +# 改为 postgres 以启用 PG 持久化(structure_nodes + semantic_blocks 入库) +DOCUMENT_REPOSITORY_BACKEND=json diff --git a/.env.example b/.env.example index c581aac..17722ce 100644 --- a/.env.example +++ b/.env.example @@ -7,10 +7,13 @@ APP_VERSION=0.1.0 DEBUG=false # ===== Milvus向量数据库配置 ===== -MILVUS_HOST=localhost +MILVUS_HOST=6.86.80.8 MILVUS_PORT=19530 MILVUS_COLLECTION=regulations_dense_1024_v1 MILVUS_DB_NAME=default +MILVUS_INDEX_TYPE=IVF_FLAT +MILVUS_NLIST=128 +MILVUS_NPROBE=16 # ===== 嵌入模型配置 ===== EMBEDDING_MODEL=text-embedding-v3 @@ -20,20 +23,20 @@ EMBEDDING_BASE_URL=http://6.86.80.4:30080/v1 EMBEDDING_TIMEOUT_SECONDS=120 # ===== MinIO对象存储配置 ===== -MINIO_ENDPOINT=localhost:9000 +MINIO_ENDPOINT=6.86.80.8:9000 MINIO_ACCESS_KEY=minioadmin MINIO_SECRET_KEY=minioadmin123 MINIO_BUCKET=compliance-docs MINIO_SECURE=false # ===== Redis配置 ===== -REDIS_HOST=localhost +REDIS_HOST=6.86.80.8 REDIS_PORT=6379 REDIS_PASSWORD= REDIS_DB=0 # ===== PostgreSQL配置 ===== -POSTGRES_HOST=localhost +POSTGRES_HOST=6.86.80.8 POSTGRES_PORT=5432 POSTGRES_USER=compliance POSTGRES_PASSWORD=compliance123 @@ -46,6 +49,8 @@ MAX_FILE_SIZE_MB=100 DOCUMENT_METADATA_PATH=backend/data/documents.json PARSER_BACKEND=aliyun CHUNK_BACKEND=aliyun +# 文档元数据存储后端:json(默认,无需数据库)或 postgres(启用 PG 持久化) +DOCUMENT_REPOSITORY_BACKEND=json # ===== 阿里云文档解析 ===== ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id @@ -88,9 +93,18 @@ DEEPSEEK_MODEL=deepseek-v4-flash # ===== RAG配置 ===== RAG_TOP_K=10 +RAG_RETRIEVAL_TOP_K=20 RAG_MAX_CONTEXT_TOKENS=4000 RAG_SUMMARY_MAX_TOKENS=1024 +# ===== Reranker配置(Cross-Encoder精排,默认关闭)===== +# 设置 RERANKER_ENABLED=true 并配置 RERANKER_BASE_URL 以启用精排 +RERANKER_ENABLED=false +RERANKER_BASE_URL= +RERANKER_MODEL=BAAI/bge-reranker-v2-m3 +RERANKER_API_KEY= +RERANKER_TOP_K=5 + # ===== 会话配置 ===== SESSION_MAX_SESSIONS=100 SESSION_TIMEOUT_MINUTES=30 diff --git a/backend/app/api/models/document.py b/backend/app/api/models/document.py index c3e175f..4fb4a97 100644 --- a/backend/app/api/models/document.py +++ b/backend/app/api/models/document.py @@ -23,6 +23,8 @@ class DocumentUploadResponse(BaseModel): num_chunks: int = Field(default=0, description="分块数量") summary: str = Field(default="", description="LLM生成的文档摘要") summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)") + regulation_type: str = Field(default="", description="法规类型") + version: str = Field(default="", description="文档版本号") timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") diff --git a/backend/app/api/routes/compliance.py b/backend/app/api/routes/compliance.py index 62f0acf..813a53a 100644 --- a/backend/app/api/routes/compliance.py +++ b/backend/app/api/routes/compliance.py @@ -14,30 +14,22 @@ from app.schemas.compliance import ( AnalyzeResponse, ComplianceChatRequest, ) -from app.services.mock_data import ( - generate_task_id, - get_mock_compliance_result, - get_mock_compliance_chat_response, -) -# Keep route handlers close to their transport-layer wiring for easier auditing. +from app.services.mock_data import generate_task_id, get_mock_compliance_result +from app.shared.bootstrap import get_agent_conversation_service router = APIRouter(prefix="/compliance", tags=["合规分析"]) -# Keep route handlers close to their transport-layer wiring for easier auditing. tasks_store: dict[str, dict] = {} -# Store uploaded compliance files inside the local backend data directory. RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw" @router.post("/analyze", response_model=AnalyzeResponse) async def analyze_document(file: UploadFile = File(...)): """Handle analyze document.""" - # Keep route handlers close to their transport-layer wiring for easier auditing. task_id = generate_task_id() - # Keep route handlers close to their transport-layer wiring for easier auditing. RAW_DATA_DIR.mkdir(parents=True, exist_ok=True) file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}" @@ -45,7 +37,6 @@ async def analyze_document(file: UploadFile = File(...)): with file_path.open("wb") as f: f.write(content) - # Keep route handlers close to their transport-layer wiring for easier auditing. tasks_store[task_id] = { "task_id": task_id, "file_path": str(file_path), @@ -53,8 +44,6 @@ async def analyze_document(file: UploadFile = File(...)): "result": None, } - # Keep route handlers close to their transport-layer wiring for easier auditing. - # Keep route handlers close to their transport-layer wiring for easier auditing. tasks_store[task_id]["status"] = "completed" tasks_store[task_id]["result"] = get_mock_compliance_result(task_id) @@ -65,47 +54,44 @@ async def analyze_document(file: UploadFile = File(...)): async def get_result(task_id: str): """Return result.""" if task_id not in tasks_store: - # Keep route handlers close to their transport-layer wiring for easier auditing. return get_mock_compliance_result(task_id) task = tasks_store[task_id] - if task["status"] == "processing": return {"status": "processing", "message": "分析进行中"} - return task["result"] @router.post("/chat/{segment_id}") async def compliance_chat(segment_id: int, request: ComplianceChatRequest): - """Handle compliance chat.""" - # Keep route handlers close to their transport-layer wiring for easier auditing. - intent_map = { - 1: "车身结构设计", - 2: "动力系统配置", - 3: "安全配置设计", - } - intent = intent_map.get(segment_id, "车身结构设计") + """Stream compliance Q&A grounded in real vector retrieval.""" + query = request.query + if request.segment_context: + query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}" + + _, event_stream = get_agent_conversation_service().stream_chat( + query=query, + top_k=5, + prompt_template="compliance_qa", + ) async def generate() -> AsyncGenerator[str, None]: - # Keep route handlers close to their transport-layer wiring for easier auditing. - """Handle generate.""" - response = get_mock_compliance_chat_response(intent, request.query) - - # Keep route handlers close to their transport-layer wiring for easier auditing. - sentences = response.split("\n\n") - for sentence in sentences: - if sentence.strip(): - chunks = sentence.split("\n") - for chunk in chunks: - if chunk.strip(): - await asyncio.sleep(0.05) - yield ( - "event: message\n" - f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n" - ) - - yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + """Translate agent SSE events to compliance chunk/done format.""" + for event in event_stream: + event_type = event.get("event", "") + if event_type == "content": + text = event.get("data", "") + if text: + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n" + ) + elif event_type == "done": + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + ) + await asyncio.sleep(0) return StreamingResponse( generate(), diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index e211796..e7f8577 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -80,6 +80,8 @@ async def get_document_status(doc_id: str): num_chunks=document.chunk_count, summary=document.summary, summary_latency_ms=document.summary_latency_ms, + regulation_type=document.regulation_type, + version=document.version, ) @@ -123,7 +125,7 @@ async def list_documents(): @router.get("/management-list") async def get_document_management_list(): """Return document management list.""" - documents = get_document_query_service().list_documents(limit=10) + documents = get_document_query_service().list_documents() return { "documents": [ { @@ -131,10 +133,37 @@ async def get_document_management_list(): "doc_name": item.doc_name, "status": item.status.value, "chunk_count": item.chunk_count, + "size_bytes": item.size_bytes, + "summary": item.summary, "updated_at": item.updated_at.isoformat(), + "regulation_type": item.regulation_type, + "version": item.version, } for item in documents ], "total": len(documents), - "limit": 10, } + + +@router.delete("/{doc_id}") +async def delete_document(doc_id: str): + """Delete a document and its associated data.""" + deleted = get_document_command_service().delete(doc_id) + if not deleted: + raise HTTPException(status_code=404, detail="文档不存在") + return {"doc_id": doc_id, "deleted": True} + + +@router.post("/{doc_id}/retry", response_model=DocumentUploadResponse) +async def retry_document(doc_id: str): + """Re-process a failed document.""" + try: + result = get_document_command_service().retry(doc_id) + if result.status == "failed": + raise HTTPException(status_code=500, detail=result.message) + return _document_response(result) + except HTTPException: + raise + except Exception as exc: + logger.exception("文档重试失败") + raise HTTPException(status_code=500, detail=str(exc)) diff --git a/backend/app/api/routes/rag.py b/backend/app/api/routes/rag.py index 58df8cc..bd465dc 100644 --- a/backend/app/api/routes/rag.py +++ b/backend/app/api/routes/rag.py @@ -9,73 +9,74 @@ from typing import AsyncGenerator from fastapi import APIRouter from fastapi.responses import StreamingResponse +from app.config.settings import settings from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion -from app.services.mock_data import ( - get_mock_quick_questions, - get_mock_retrieval, - get_mock_rag_answer, -) -# Keep route handlers close to their transport-layer wiring for easier auditing. +from app.shared.bootstrap import get_agent_conversation_service router = APIRouter(prefix="/rag", tags=["RAG问答"]) +_DEFAULT_QUICK_QUESTIONS = [ + {"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"}, + {"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"}, + {"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"}, + {"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"}, + {"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"}, + {"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"}, +] + @router.post("/chat") async def rag_chat(request: RagChatRequest): - """Handle rag chat.""" + """Stream RAG Q&A using the real agent service.""" + _, event_stream = get_agent_conversation_service().stream_chat( + query=request.query, + session_id=request.session_id, + filters=request.filters, + top_k=request.top_k or settings.rag_top_k, + ) async def generate() -> AsyncGenerator[str, None]: - # Keep route handlers close to their transport-layer wiring for easier auditing. - """Handle generate.""" - yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n" - - # Keep route handlers close to their transport-layer wiring for easier auditing. - await asyncio.sleep(0.3) - - # Keep route handlers close to their transport-layer wiring for easier auditing. - docs = get_mock_retrieval(request.query, top_k=request.top_k) - - retrieved_data = [ - { - "id": d["id"], - "score": d["score"], - "preview": d["preview"], - "doc_name": d.get("doc_name", ""), - "clause": d.get("clause", ""), - } - for d in docs - ] - yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n" - - # Keep route handlers close to their transport-layer wiring for easier auditing. - yield ( - f"event: message\ndata: " - f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n" - ) - - # Keep route handlers close to their transport-layer wiring for easier auditing. - await asyncio.sleep(0.2) - - # Keep route handlers close to their transport-layer wiring for easier auditing. - answer = get_mock_rag_answer(request.query) - - # Keep route handlers close to their transport-layer wiring for easier auditing. - sentences = answer.split("\n\n") - for sentence in sentences: - if sentence.strip(): - # Keep route handlers close to their transport-layer wiring for easier auditing. - chunks = sentence.split("\n") - for chunk in chunks: - if chunk.strip(): - await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing. - yield ( - "event: message\n" - f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n" - ) - - # Keep route handlers close to their transport-layer wiring for easier auditing. - yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + """Translate agent SSE events to rag format.""" + for event in event_stream: + event_type = event.get("event", "") + data = event.get("data", "") + if event_type == "sources": + docs = [ + { + "id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1), + "score": s.get("score", 0), + "preview": s.get("content", "")[:200], + "doc_name": s.get("doc_name", ""), + "clause": s.get("section_title", "法规片段"), + "doc_id": s.get("doc_id"), + "download_url": ( + f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None + ), + } + for idx, s in enumerate(data if isinstance(data, list) else []) + ] + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n" + ) + elif event_type == "content": + if data: + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n" + ) + elif event_type == "done": + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + ) + elif event_type == "status": + yield ( + "event: message\n" + f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n" + ) + await asyncio.sleep(0) return StreamingResponse( generate(), @@ -86,9 +87,13 @@ async def rag_chat(request: RagChatRequest): @router.get("/quick-questions", response_model=QuickQuestionsResponse) async def get_quick_questions(): - """Return quick questions.""" - questions = [ - QuickQuestion(id=q["id"], question=q["question"], category=q["category"]) - for q in get_mock_quick_questions() - ] + """Return configurable quick questions from settings or defaults.""" + raw = getattr(settings, "rag_quick_questions", None) + if raw and isinstance(raw, list): + questions = [ + QuickQuestion(id=str(i + 1), question=q if isinstance(q, str) else q.get("question", ""), category=q.get("category", "法规问答") if isinstance(q, dict) else "法规问答") + for i, q in enumerate(raw) + ] + else: + questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS] return QuickQuestionsResponse(questions=questions) diff --git a/backend/app/application/documents/services.py b/backend/app/application/documents/services.py index b6bbfd6..2a05870 100644 --- a/backend/app/application/documents/services.py +++ b/backend/app/application/documents/services.py @@ -17,6 +17,7 @@ from app.domain.documents import ( DocumentParser, DocumentRepository, DocumentStatus, + ParseArtifactStore, ParsedDocument, ) from app.domain.retrieval import EmbeddingProvider, VectorIndex @@ -47,6 +48,7 @@ class DocumentCommandService: chunk_builder: ChunkBuilder, embedding_provider: EmbeddingProvider, vector_index: VectorIndex, + parse_artifact_store: ParseArtifactStore | None = None, ) -> None: """Initialize the Document Command Service instance.""" self.document_repository = document_repository @@ -55,6 +57,7 @@ class DocumentCommandService: self.chunk_builder = chunk_builder self.embedding_provider = embedding_provider self.vector_index = vector_index + self.parse_artifact_store = parse_artifact_store def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]: """Persist parse artifacts so troubleshooting does not depend on provider retention windows.""" @@ -143,6 +146,15 @@ class DocumentCommandService: "processing_stage": "parsed", }, ) + if self.parse_artifact_store: + try: + self.parse_artifact_store.save( + doc_id, + parsed_document.structure_nodes, + parsed_document.semantic_blocks, + ) + except Exception: + logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id) chunks = self.chunk_builder.build( parsed_document=parsed_document, @@ -205,20 +217,120 @@ class DocumentCommandService: logger.warning("临时文件清理失败: {}", temp_path) + def delete(self, doc_id: str) -> bool: + """Delete document record, binary file, and vector chunks.""" + document = self.document_repository.get(doc_id) + if not document: + return False + try: + self.binary_store.delete(document.object_name) + except Exception: + logger.warning("Binary delete failed for doc_id={}", doc_id) + try: + self.vector_index.delete_by_document(doc_id) + except Exception: + logger.warning("Vector delete failed for doc_id={}", doc_id) + if self.parse_artifact_store: + try: + self.parse_artifact_store.delete(doc_id) + except Exception: + logger.warning("ParseArtifactStore delete failed for doc_id={}", doc_id) + self.document_repository.delete(doc_id) + return True + + def retry(self, doc_id: str) -> DocumentProcessResult: + """Re-process a failed document from its stored binary.""" + document = self.document_repository.get(doc_id) + if not document: + return DocumentProcessResult(doc_id=doc_id, doc_name="", status="failed", message="文档不存在") + content = self.binary_store.read(document.object_name) + return self.upload_and_process( + doc_id=doc_id, + file_name=document.file_name, + content=content, + content_type=document.content_type, + doc_name=document.doc_name, + regulation_type=document.regulation_type, + version=document.version, + generate_summary=bool(document.metadata.get("generate_summary", False)), + ) + + class DocumentQueryService: """Provide the Document Query Service service.""" - def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore) -> None: + def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore, vector_index: VectorIndex) -> None: """Initialize the Document Query Service instance.""" self.document_repository = document_repository self.binary_store = binary_store + self.vector_index = vector_index def get(self, doc_id: str) -> Document | None: """Handle get for the Document Query Service instance.""" return self.document_repository.get(doc_id) def list_documents(self, limit: int | None = None) -> list[Document]: - """List documents for the Document Query Service instance.""" - return self.document_repository.list(limit=limit) + """Return documents with real-time state from Milvus as the authoritative source. + + Algorithm: + 1. Query Milvus for all doc metadata (doc_id, doc_name, chunk_count, …). + 2. Load JSON/PG metadata records and index them by doc_id. + 3. Merge: Milvus-present docs get status=INDEXED and live chunk_count; + metadata-only docs with status=INDEXED are demoted to FAILED. + 4. Milvus-only docs (no metadata record) are surfaced as synthetic INDEXED + entries so they are never invisible to the management list. + """ + # Fetch live Milvus state first. + try: + milvus_rows = self.vector_index.list_document_metadata() + except Exception: + milvus_rows = [] + + milvus_by_id: dict[str, dict] = {r["doc_id"]: r for r in milvus_rows} + + # Load metadata store records. + meta_docs = self.document_repository.list(limit=limit) + meta_by_id: dict[str, Document] = {d.doc_id: d for d in meta_docs} + + result: list[Document] = [] + + # Reconcile metadata records against Milvus. + for doc in meta_docs: + if doc.doc_id in milvus_by_id: + row = milvus_by_id[doc.doc_id] + doc.chunk_count = row["chunk_count"] + doc.status = DocumentStatus.INDEXED + # Backfill fields that may be missing from older JSON records. + if not doc.doc_name and row.get("doc_name"): + doc.doc_name = row["doc_name"] + if not doc.regulation_type and row.get("regulation_type"): + doc.regulation_type = row["regulation_type"] + if not doc.version and row.get("version"): + doc.version = row["version"] + elif doc.status == DocumentStatus.INDEXED: + # Metadata says indexed but Milvus has no chunks. + doc.status = DocumentStatus.FAILED + doc.error_message = "向量数据库中未找到对应数据" + result.append(doc) + + # Surface Milvus-only docs that have no metadata record at all. + for doc_id, row in milvus_by_id.items(): + if doc_id not in meta_by_id: + synthetic = Document( + doc_id=doc_id, + doc_name=row.get("doc_name", doc_id), + file_name=row.get("doc_name", doc_id), + object_name="", + content_type="", + size_bytes=0, + status=DocumentStatus.INDEXED, + regulation_type=row.get("regulation_type", ""), + version=row.get("version", ""), + chunk_count=row["chunk_count"], + ) + result.append(synthetic) + + result.sort(key=lambda d: d.updated_at, reverse=True) + return result[:limit] if limit is not None else result def download(self, doc_id: str) -> tuple[Document, bytes]: """Handle download for the Document Query Service instance.""" diff --git a/backend/app/application/knowledge/services.py b/backend/app/application/knowledge/services.py index ae98969..b8d3685 100644 --- a/backend/app/application/knowledge/services.py +++ b/backend/app/application/knowledge/services.py @@ -3,17 +3,24 @@ from __future__ import annotations from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk +from app.domain.retrieval.ports import Reranker # Keep orchestration logic centralized so use-case flow stays easy to trace. - class KnowledgeRetrievalService: """Provide the Knowledge Retrieval Service service.""" - def __init__(self, *, retriever: Retriever) -> None: + + def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None: """Initialize the Knowledge Retrieval Service instance.""" self.retriever = retriever + self.reranker = reranker + self.reranker_top_k = reranker_top_k def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]: - """Handle retrieve for the Knowledge Retrieval Service instance.""" - retrieval_query = RetrievalQuery(query=query, top_k=top_k, filters=filters) - return self.retriever.retrieve(retrieval_query) + """Retrieve and optionally rerank chunks for a query.""" + candidate_k = top_k if self.reranker is None else max(top_k * 4, 20) + retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters) + candidates = self.retriever.retrieve(retrieval_query) + if self.reranker and candidates: + return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k) + return candidates[:top_k] diff --git a/backend/app/config/settings.py b/backend/app/config/settings.py index f43953f..b8b9a3c 100644 --- a/backend/app/config/settings.py +++ b/backend/app/config/settings.py @@ -31,7 +31,7 @@ class Settings(BaseSettings): debug: bool = Field(default=False, description="调试模式") # Keep configuration setup explicit so runtime behavior is easy to reason about. - milvus_host: str = Field(default="localhost", description="Milvus服务地址") + milvus_host: str = Field(default="6.86.80.8", description="Milvus服务地址") milvus_port: int = Field(default=19530, description="Milvus服务端口") milvus_collection: str = Field(default="regulations_dense_1024_v1", description="法规向量集合名称") milvus_db_name: str = Field(default="default", description="Milvus数据库名称") @@ -54,20 +54,20 @@ class Settings(BaseSettings): parser_failure_mode: str = Field(default="fail", description="解析失败策略") # Keep configuration setup explicit so runtime behavior is easy to reason about. - minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址") + minio_endpoint: str = Field(default="6.86.80.8:9000", description="MinIO服务地址") minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥") minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥") minio_bucket: str = Field(default="upload-files", description="文档存储桶名称") minio_secure: bool = Field(default=False, description="是否使用HTTPS") # Keep configuration setup explicit so runtime behavior is easy to reason about. - redis_host: str = Field(default="localhost", description="Redis服务地址") + redis_host: str = Field(default="6.86.80.8", description="Redis服务地址") redis_port: int = Field(default=6379, description="Redis服务端口") redis_password: str = Field(default="", description="Redis密码") redis_db: int = Field(default=0, description="Redis数据库编号") # Keep configuration setup explicit so runtime behavior is easy to reason about. - postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址") + postgres_host: str = Field(default="6.86.80.8", description="PostgreSQL服务地址") postgres_port: int = Field(default=5432, description="PostgreSQL服务端口") postgres_user: str = Field(default="compliance", description="PostgreSQL用户名") postgres_password: str = Field(default="compliance123", description="PostgreSQL密码") @@ -80,6 +80,7 @@ class Settings(BaseSettings): document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径") parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)") chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)") + document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)") # Keep configuration setup explicit so runtime behavior is easy to reason about. api_host: str = Field(default="0.0.0.0", description="API服务地址") @@ -104,9 +105,16 @@ class Settings(BaseSettings): # Keep configuration setup explicit so runtime behavior is easy to reason about. rag_top_k: int = Field(default=5, description="检索召回数量") + rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量(reranker 启用时生效)") rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数") rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数") + reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排") + reranker_base_url: str = Field(default="", description="Reranker API 地址") + reranker_model: str = Field(default="BAAI/bge-reranker-v2-m3", description="Reranker 模型名称") + reranker_api_key: str = Field(default="", description="Reranker API 密钥") + reranker_top_k: int = Field(default=5, description="精排后保留的最终结果数量") + # Keep configuration setup explicit so runtime behavior is easy to reason about. milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型") milvus_nlist: int = Field(default=128, description="Milvus nlist参数") diff --git a/backend/app/core/config.py b/backend/app/core/config.py index e4535ce..730d3a9 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -25,7 +25,7 @@ class Settings(BaseSettings): dashscope_api_key: str = "" # Milvus - milvus_host: str = "localhost" + milvus_host: str = "6.86.80.8" milvus_port: int = 19530 milvus_collection: str = "regulations_dense_1024_v1" diff --git a/backend/app/domain/documents/__init__.py b/backend/app/domain/documents/__init__.py index d4e36ea..2fcaabc 100644 --- a/backend/app/domain/documents/__init__.py +++ b/backend/app/domain/documents/__init__.py @@ -1,7 +1,7 @@ """Initialize the app.domain.documents package.""" from .models import Chunk, Document, DocumentStatus, ParsedDocument -from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository +from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository, ParseArtifactStore # Keep package boundaries explicit so backend imports stay predictable. @@ -14,4 +14,5 @@ __all__ = [ "DocumentBinaryStore", "DocumentParser", "DocumentRepository", + "ParseArtifactStore", ] diff --git a/backend/app/domain/documents/ports.py b/backend/app/domain/documents/ports.py index 6cbb715..6807f86 100644 --- a/backend/app/domain/documents/ports.py +++ b/backend/app/domain/documents/ports.py @@ -31,6 +31,11 @@ class DocumentRepository(ABC): """Handle list for the Document Repository instance.""" pass + @abstractmethod + def delete(self, doc_id: str) -> bool: + """Delete a document record. Returns True if deleted, False if not found.""" + pass + @abstractmethod def update_status( self, @@ -94,3 +99,32 @@ class ChunkBuilder(ABC): ) -> list[Chunk]: """Handle build for the Chunk Builder instance.""" pass + + +class ParseArtifactStore(ABC): + """Persist parse artifacts (structure nodes and semantic blocks) for relational queries.""" + + @abstractmethod + def save( + self, + doc_id: str, + structure_nodes: list[dict], + semantic_blocks: list[dict], + ) -> None: + """Persist structure nodes and semantic blocks for a document.""" + pass + + @abstractmethod + def delete(self, doc_id: str) -> None: + """Remove all parse artifacts for a document.""" + pass + + @abstractmethod + def get_semantic_blocks(self, doc_id: str) -> list[dict]: + """Return all semantic blocks for a document.""" + pass + + @abstractmethod + def get_structure_nodes(self, doc_id: str) -> list[dict]: + """Return all structure nodes for a document.""" + pass diff --git a/backend/app/domain/retrieval/__init__.py b/backend/app/domain/retrieval/__init__.py index 56426f6..e03de34 100644 --- a/backend/app/domain/retrieval/__init__.py +++ b/backend/app/domain/retrieval/__init__.py @@ -1,8 +1,8 @@ """Initialize the app.domain.retrieval package.""" from .models import RetrievalQuery, RetrievedChunk -from .ports import EmbeddingProvider, Retriever, VectorIndex +from .ports import EmbeddingProvider, Reranker, Retriever, VectorIndex # Keep package boundaries explicit so backend imports stay predictable. -__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Retriever", "VectorIndex"] +__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Reranker", "Retriever", "VectorIndex"] diff --git a/backend/app/domain/retrieval/ports.py b/backend/app/domain/retrieval/ports.py index 8d5e2c1..1edb9d7 100644 --- a/backend/app/domain/retrieval/ports.py +++ b/backend/app/domain/retrieval/ports.py @@ -10,7 +10,6 @@ from .models import RetrievalQuery, RetrievedChunk # Keep domain contracts explicit so adapters can swap implementations cleanly. - class EmbeddingProvider(ABC): """Provide the Embedding Provider provider.""" @abstractmethod @@ -41,12 +40,35 @@ class VectorIndex(ABC): """Handle search for the Vector Index instance.""" pass + @abstractmethod + def count_by_document(self) -> dict[str, int]: + """Return a mapping of doc_id -> chunk count from the vector store.""" + pass + + @abstractmethod + def list_document_metadata(self) -> list[dict]: + """Return per-document metadata rows from the vector store. + + Each row contains at minimum: doc_id, doc_name, chunk_count. + Optional fields: regulation_type, version. + """ + pass + @abstractmethod def health(self) -> dict: """Handle health for the Vector Index instance.""" pass +class Reranker(ABC): + """Re-score and re-order a candidate list using a cross-encoder model.""" + + @abstractmethod + def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]: + """Return top_k chunks sorted by cross-encoder score (descending).""" + pass + + class Retriever(ABC): """Provide the Retriever retriever.""" @abstractmethod diff --git a/backend/app/infrastructure/parser/aliyun_layout_normalizer.py b/backend/app/infrastructure/parser/aliyun_layout_normalizer.py index c6a8ae3..5aaaedb 100644 --- a/backend/app/infrastructure/parser/aliyun_layout_normalizer.py +++ b/backend/app/infrastructure/parser/aliyun_layout_normalizer.py @@ -289,7 +289,7 @@ def build_vector_chunks( { "doc_id": doc_id, "doc_title": doc_title, - "chunk_id": f"chunk-{chunk_index}", + "chunk_id": f"{doc_id}-chunk-{chunk_index}", "chunk_index": chunk_index, "semantic_id": block["semantic_id"], "chunk_type": block["block_type"], diff --git a/backend/app/infrastructure/storage/json_document_repository.py b/backend/app/infrastructure/storage/json_document_repository.py index 3fc910f..6eca725 100644 --- a/backend/app/infrastructure/storage/json_document_repository.py +++ b/backend/app/infrastructure/storage/json_document_repository.py @@ -75,6 +75,15 @@ class JsonDocumentRepository(DocumentRepository): documents.sort(key=lambda item: item.updated_at, reverse=True) return documents[:limit] if limit is not None else documents + def delete(self, doc_id: str) -> bool: + """Delete a document record.""" + payload = self._load() + if doc_id not in payload: + return False + del payload[doc_id] + self._save(payload) + return True + def update_status( self, doc_id: str, diff --git a/backend/app/infrastructure/storage/postgres_document_repository.py b/backend/app/infrastructure/storage/postgres_document_repository.py new file mode 100644 index 0000000..2536c28 --- /dev/null +++ b/backend/app/infrastructure/storage/postgres_document_repository.py @@ -0,0 +1,228 @@ +"""Implement infrastructure support for postgres document repository.""" + +from __future__ import annotations + +import json +from contextlib import contextmanager +from datetime import UTC, datetime +from typing import Any + +import psycopg2 +import psycopg2.extras +from psycopg2.pool import ThreadedConnectionPool + +from app.config.settings import settings +from app.domain.documents import Document, DocumentRepository, DocumentStatus + +_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS documents ( + doc_id VARCHAR(128) PRIMARY KEY, + doc_name VARCHAR(512) NOT NULL DEFAULT '', + file_name VARCHAR(512) NOT NULL DEFAULT '', + object_name VARCHAR(1024) NOT NULL DEFAULT '', + content_type VARCHAR(128) NOT NULL DEFAULT '', + size_bytes BIGINT NOT NULL DEFAULT 0, + status VARCHAR(32) NOT NULL DEFAULT 'pending', + regulation_type VARCHAR(128) NOT NULL DEFAULT '', + version VARCHAR(64) NOT NULL DEFAULT '', + summary TEXT NOT NULL DEFAULT '', + summary_latency_ms INTEGER NOT NULL DEFAULT 0, + chunk_count INTEGER NOT NULL DEFAULT 0, + parser_name VARCHAR(128) NOT NULL DEFAULT '', + index_name VARCHAR(128) NOT NULL DEFAULT '', + error_message TEXT NOT NULL DEFAULT '', + metadata JSONB NOT NULL DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +""" + +_COLUMNS = ( + "doc_id", "doc_name", "file_name", "object_name", "content_type", + "size_bytes", "status", "regulation_type", "version", "summary", + "summary_latency_ms", "chunk_count", "parser_name", "index_name", + "error_message", "metadata", "created_at", "updated_at", +) + + +class PostgresDocumentRepository(DocumentRepository): + """DocumentRepository implementation backed by PostgreSQL.""" + + def __init__(self) -> None: + self._pool = ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + ) + self._ensure_schema() + + def _ensure_schema(self) -> None: + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute(_CREATE_TABLE) + conn.commit() + + @contextmanager + def _conn(self): + conn = self._pool.getconn() + try: + yield conn + finally: + self._pool.putconn(conn) + + # ------------------------------------------------------------------ + # Serialization helpers + # ------------------------------------------------------------------ + + def _row_to_document(self, row: dict[str, Any]) -> Document: + return Document( + doc_id=row["doc_id"], + doc_name=row["doc_name"], + file_name=row["file_name"], + object_name=row["object_name"], + content_type=row["content_type"], + size_bytes=row["size_bytes"], + status=DocumentStatus(row["status"]), + regulation_type=row["regulation_type"], + version=row["version"], + summary=row["summary"], + summary_latency_ms=row["summary_latency_ms"], + chunk_count=row["chunk_count"], + parser_name=row["parser_name"], + index_name=row["index_name"], + error_message=row["error_message"], + metadata=row["metadata"] if isinstance(row["metadata"], dict) else json.loads(row["metadata"] or "{}"), + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + # ------------------------------------------------------------------ + # DocumentRepository interface + # ------------------------------------------------------------------ + + def create(self, document: Document) -> Document: + sql = """ + INSERT INTO documents + (doc_id, doc_name, file_name, object_name, content_type, size_bytes, + status, regulation_type, version, summary, summary_latency_ms, + chunk_count, parser_name, index_name, error_message, metadata, + created_at, updated_at) + VALUES + (%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s, + %(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s, + %(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s, + %(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s) + ON CONFLICT (doc_id) DO NOTHING + """ + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute(sql, self._to_params(document)) + conn.commit() + return document + + def update(self, document: Document) -> Document: + document.updated_at = datetime.now(UTC) + sql = """ + UPDATE documents SET + doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s, + content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s, + regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s, + summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s, + parser_name=%(parser_name)s, index_name=%(index_name)s, + error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s + WHERE doc_id=%(doc_id)s + """ + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute(sql, self._to_params(document)) + conn.commit() + return document + + def get(self, doc_id: str) -> Document | None: + sql = "SELECT * FROM documents WHERE doc_id = %s" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute(sql, (doc_id,)) + row = cur.fetchone() + return self._row_to_document(dict(row)) if row else None + + def list(self, limit: int | None = None) -> list[Document]: + sql = "SELECT * FROM documents ORDER BY updated_at DESC" + if limit is not None: + sql += f" LIMIT {int(limit)}" + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute(sql) + rows = cur.fetchall() + return [self._row_to_document(dict(r)) for r in rows] + + def delete(self, doc_id: str) -> bool: + sql = "DELETE FROM documents WHERE doc_id = %s" + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute(sql, (doc_id,)) + deleted = cur.rowcount > 0 + conn.commit() + return deleted + + def update_status( + self, + doc_id: str, + status: DocumentStatus, + *, + error_message: str = "", + chunk_count: int | None = None, + summary: str | None = None, + summary_latency_ms: int | None = None, + parser_name: str | None = None, + index_name: str | None = None, + metadata: dict | None = None, + ) -> Document | None: + document = self.get(doc_id) + if not document: + return None + document.status = status + document.error_message = error_message + if chunk_count is not None: + document.chunk_count = chunk_count + if summary is not None: + document.summary = summary + if summary_latency_ms is not None: + document.summary_latency_ms = summary_latency_ms + if parser_name is not None: + document.parser_name = parser_name + if index_name is not None: + document.index_name = index_name + if metadata: + document.metadata.update(metadata) + return self.update(document) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _to_params(self, document: Document) -> dict[str, Any]: + return { + "doc_id": document.doc_id, + "doc_name": document.doc_name, + "file_name": document.file_name, + "object_name": document.object_name, + "content_type": document.content_type, + "size_bytes": document.size_bytes, + "status": document.status.value, + "regulation_type": document.regulation_type, + "version": document.version, + "summary": document.summary, + "summary_latency_ms": document.summary_latency_ms, + "chunk_count": document.chunk_count, + "parser_name": document.parser_name, + "index_name": document.index_name, + "error_message": document.error_message, + "metadata": json.dumps(document.metadata, ensure_ascii=False), + "created_at": document.created_at, + "updated_at": document.updated_at, + } diff --git a/backend/app/infrastructure/storage/postgres_parse_artifact_store.py b/backend/app/infrastructure/storage/postgres_parse_artifact_store.py new file mode 100644 index 0000000..c1b8f46 --- /dev/null +++ b/backend/app/infrastructure/storage/postgres_parse_artifact_store.py @@ -0,0 +1,196 @@ +"""Implement infrastructure support for postgres parse artifact store.""" + +from __future__ import annotations + +import json +from contextlib import contextmanager +from typing import Any + +import psycopg2 +import psycopg2.extras +from psycopg2.pool import ThreadedConnectionPool + +from app.config.settings import settings +from app.domain.documents import ParseArtifactStore + +_CREATE_STRUCTURE_NODES = """ +CREATE TABLE IF NOT EXISTS structure_nodes ( + id SERIAL PRIMARY KEY, + doc_id VARCHAR(128) NOT NULL, + unique_id VARCHAR(128), + page INTEGER NOT NULL DEFAULT 0, + idx INTEGER NOT NULL DEFAULT 0, + level INTEGER NOT NULL DEFAULT 0, + title TEXT NOT NULL DEFAULT '', + type VARCHAR(64), + sub_type VARCHAR(64), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT fk_sn_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS idx_structure_nodes_doc_id ON structure_nodes(doc_id); +""" + +_CREATE_SEMANTIC_BLOCKS = """ +CREATE TABLE IF NOT EXISTS semantic_blocks ( + id SERIAL PRIMARY KEY, + doc_id VARCHAR(128) NOT NULL, + semantic_id VARCHAR(128) NOT NULL, + block_type VARCHAR(64) NOT NULL DEFAULT '', + page_start INTEGER NOT NULL DEFAULT 0, + page_end INTEGER NOT NULL DEFAULT 0, + section_path JSONB NOT NULL DEFAULT '[]', + section_level INTEGER NOT NULL DEFAULT 0, + section_title VARCHAR(512) NOT NULL DEFAULT '', + source_ids JSONB NOT NULL DEFAULT '[]', + text TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + CONSTRAINT fk_sb_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE, + CONSTRAINT uq_semantic_blocks UNIQUE (doc_id, semantic_id) +); +CREATE INDEX IF NOT EXISTS idx_semantic_blocks_doc_id ON semantic_blocks(doc_id); +""" + + +class PostgresParseArtifactStore(ParseArtifactStore): + """ParseArtifactStore implementation backed by PostgreSQL. + + Requires the `documents` table to exist first (created by PostgresDocumentRepository). + """ + + def __init__(self) -> None: + self._pool = ThreadedConnectionPool( + minconn=1, + maxconn=5, + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + ) + self._ensure_schema() + + def _ensure_schema(self) -> None: + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute(_CREATE_STRUCTURE_NODES) + cur.execute(_CREATE_SEMANTIC_BLOCKS) + conn.commit() + + @contextmanager + def _conn(self): + conn = self._pool.getconn() + try: + yield conn + finally: + self._pool.putconn(conn) + + # ------------------------------------------------------------------ + # ParseArtifactStore interface + # ------------------------------------------------------------------ + + def save( + self, + doc_id: str, + structure_nodes: list[dict], + semantic_blocks: list[dict], + ) -> None: + """Persist structure nodes and semantic blocks, replacing any existing records.""" + with self._conn() as conn: + with conn.cursor() as cur: + # Delete existing records first to keep save idempotent. + cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,)) + cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,)) + + if structure_nodes: + psycopg2.extras.execute_values( + cur, + """ + INSERT INTO structure_nodes + (doc_id, unique_id, page, idx, level, title, type, sub_type) + VALUES %s + """, + [ + ( + doc_id, + node.get("unique_id"), + int(node.get("page", 0) or 0), + int(node.get("index", 0) or 0), + int(node.get("level", 0) or 0), + str(node.get("title", "")), + node.get("type"), + node.get("sub_type"), + ) + for node in structure_nodes + ], + ) + + if semantic_blocks: + psycopg2.extras.execute_values( + cur, + """ + INSERT INTO semantic_blocks + (doc_id, semantic_id, block_type, page_start, page_end, + section_path, section_level, section_title, source_ids, text) + VALUES %s + """, + [ + ( + doc_id, + block.get("semantic_id", ""), + block.get("block_type", ""), + int(block.get("page_start", 0) or 0), + int(block.get("page_end", 0) or 0), + json.dumps(block.get("section_path", []), ensure_ascii=False), + int(block.get("section_level", 0) or 0), + str(block.get("section_title", "")), + json.dumps(block.get("source_ids", []), ensure_ascii=False), + str(block.get("text", "")), + ) + for block in semantic_blocks + ], + ) + conn.commit() + + def delete(self, doc_id: str) -> None: + """Remove all parse artifacts for a document (ON DELETE CASCADE handles child rows).""" + with self._conn() as conn: + with conn.cursor() as cur: + cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,)) + cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,)) + conn.commit() + + def get_semantic_blocks(self, doc_id: str) -> list[dict[str, Any]]: + """Return all semantic blocks for a document ordered by id.""" + sql = """ + SELECT semantic_id, block_type, page_start, page_end, + section_path, section_level, section_title, source_ids, text + FROM semantic_blocks + WHERE doc_id = %s + ORDER BY id + """ + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute(sql, (doc_id,)) + rows = cur.fetchall() + results = [] + for row in rows: + item = dict(row) + for key in ("section_path", "source_ids"): + if isinstance(item[key], str): + item[key] = json.loads(item[key]) + results.append(item) + return results + + def get_structure_nodes(self, doc_id: str) -> list[dict[str, Any]]: + """Return all structure nodes for a document ordered by idx.""" + sql = """ + SELECT unique_id, page, idx, level, title, type, sub_type + FROM structure_nodes + WHERE doc_id = %s + ORDER BY idx + """ + with self._conn() as conn: + with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: + cur.execute(sql, (doc_id,)) + rows = cur.fetchall() + return [dict(r) for r in rows] diff --git a/backend/app/infrastructure/vectorstore/cross_encoder_reranker.py b/backend/app/infrastructure/vectorstore/cross_encoder_reranker.py new file mode 100644 index 0000000..b1a280e --- /dev/null +++ b/backend/app/infrastructure/vectorstore/cross_encoder_reranker.py @@ -0,0 +1,84 @@ +"""Implement cross-encoder reranking via an OpenAI-compatible reranker API.""" + +from __future__ import annotations + +import time + +import requests +from loguru import logger + +from app.config.settings import settings +from app.domain.retrieval import Reranker, RetrievedChunk + + +class OpenAICompatibleReranker(Reranker): + """Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks.""" + + def __init__( + self, + base_url: str | None = None, + model: str | None = None, + api_key: str | None = None, + timeout: int = 30, + ) -> None: + self._base_url = (base_url or settings.reranker_base_url).rstrip("/") + self._model = model or settings.reranker_model + self._api_key = api_key or settings.reranker_api_key + self._timeout = timeout + + def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]: + """Return up to top_k chunks re-sorted by cross-encoder score.""" + if not chunks: + return [] + + texts = [chunk.content for chunk in chunks] + start = time.time() + try: + scores = self._call_reranker(query, texts) + except Exception as exc: + logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc) + return chunks[:top_k] + + elapsed_ms = int((time.time() - start) * 1000) + logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms) + + ranked = sorted( + [(score, chunk) for score, chunk in zip(scores, chunks)], + key=lambda x: x[0], + reverse=True, + ) + result = [] + for score, chunk in ranked[:top_k]: + chunk.score = float(score) + result.append(chunk) + return result + + def _call_reranker(self, query: str, texts: list[str]) -> list[float]: + """Call the reranker API and return a score per text.""" + headers = {"Content-Type": "application/json"} + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + + # Try TEI format first: POST /rerank + payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False} + url = f"{self._base_url}/rerank" + resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout) + + if resp.status_code == 404: + # Fall back to Cohere / OpenAI-style: POST /v1/rerank + payload_v1 = {"model": self._model, "query": query, "documents": texts} + url = f"{self._base_url}/v1/rerank" + resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout) + + resp.raise_for_status() + data = resp.json() + + # TEI response: list of {"index": N, "score": F} + if isinstance(data, list): + ordered = sorted(data, key=lambda x: x["index"]) + return [float(item["score"]) for item in ordered] + + # Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]} + results = data.get("results", []) + ordered = sorted(results, key=lambda x: x["index"]) + return [float(item.get("relevance_score", item.get("score", 0))) for item in ordered] diff --git a/backend/app/infrastructure/vectorstore/milvus_vector_index.py b/backend/app/infrastructure/vectorstore/milvus_vector_index.py index 6d7ce55..df7a4fc 100644 --- a/backend/app/infrastructure/vectorstore/milvus_vector_index.py +++ b/backend/app/infrastructure/vectorstore/milvus_vector_index.py @@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex): result = self.collection.delete(f'doc_id == "{doc_id}"') return len(result.primary_keys) + def _parse_filters(self, filters: str | None) -> str | None: + """Parse filter string into Milvus expression.""" + if not filters or not filters.strip(): + return None + + filters = filters.strip() + + # Check if already a Milvus expression (contains operators) + if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]): + return filters + + # Parse simple regulation_type filter + # Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE" + types = [t.strip() for t in filters.split(",") if t.strip()] + + if not types: + return None + + if len(types) == 1: + # Single value: regulation_type == "GB" + return f'regulation_type == "{types[0]}"' + else: + # Multiple values: regulation_type in ["GB", "UN-ECE"] + quoted_types = [f'"{t}"' for t in types] + return f'regulation_type in [{", ".join(quoted_types)}]' + def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]: """Handle search for the Milvus Vector Index instance.""" + milvus_expr = self._parse_filters(filters) + results = self.collection.search( data=[query_vector], anns_field="embedding", param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}}, limit=top_k, - filter=filters, + expr=milvus_expr, output_fields=[ "doc_id", "doc_name", @@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex): ) return payload + def count_by_document(self) -> dict[str, int]: + """Return doc_id -> chunk count from Milvus.""" + try: + rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"]) + except Exception: + return {} + counts: dict[str, int] = {} + for row in rows: + doc_id = row.get("doc_id", "") + if doc_id: + counts[doc_id] = counts.get(doc_id, 0) + 1 + return counts + + def list_document_metadata(self) -> list[dict]: + """Return one metadata row per document from Milvus (single query, no embeddings).""" + try: + rows = self.collection.query( + expr="doc_id != \"\"", + output_fields=["doc_id", "doc_name", "regulation_type", "version"], + ) + except Exception: + return [] + + seen: dict[str, dict] = {} + counts: dict[str, int] = {} + for row in rows: + doc_id = row.get("doc_id", "") + if not doc_id: + continue + counts[doc_id] = counts.get(doc_id, 0) + 1 + if doc_id not in seen: + seen[doc_id] = { + "doc_id": doc_id, + "doc_name": row.get("doc_name", ""), + "regulation_type": row.get("regulation_type", ""), + "version": row.get("version", ""), + } + + return [ + {**meta, "chunk_count": counts[meta["doc_id"]]} + for meta in seen.values() + ] + def health(self) -> dict: """Handle health for the Milvus Vector Index instance.""" return { diff --git a/backend/app/schemas/compliance.py b/backend/app/schemas/compliance.py index 9039027..3218a46 100644 --- a/backend/app/schemas/compliance.py +++ b/backend/app/schemas/compliance.py @@ -74,6 +74,7 @@ class ComplianceResult(BaseModel): class ComplianceChatRequest(BaseModel): """Define the Compliance Chat Request API model.""" query: str + segment_context: Optional[str] = None class AnalyzeResponse(BaseModel): diff --git a/backend/app/schemas/rag.py b/backend/app/schemas/rag.py index 5d2e24f..3b72209 100644 --- a/backend/app/schemas/rag.py +++ b/backend/app/schemas/rag.py @@ -10,6 +10,8 @@ class RagChatRequest(BaseModel): """Define the Rag Chat Request API model.""" query: str top_k: int = 5 + session_id: Optional[str] = None + filters: Optional[str] = None class RetrievedDoc(BaseModel): diff --git a/backend/app/services/llm/deepseek_client.py b/backend/app/services/llm/deepseek_client.py index a457062..0e0a7f3 100644 --- a/backend/app/services/llm/deepseek_client.py +++ b/backend/app/services/llm/deepseek_client.py @@ -95,6 +95,57 @@ class DeepSeekClient(BaseLLMClient): error=str(e) ) + def stream_chat( + self, + messages: List[Dict[str, str]], + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + **kwargs + ): + """Stream chat for the Deep Seek Client instance.""" + try: + payload = { + "model": self.config.model, + "messages": messages, + "max_tokens": max_tokens or self.config.max_tokens, + "temperature": temperature or self.config.temperature, + "top_p": kwargs.get("top_p", self.config.top_p), + "stream": True + } + + with self._client.stream("POST", "/chat/completions", json=payload) as response: + response.raise_for_status() + for line in response.iter_lines(): + if not line or line.startswith(b":"): + continue + + line_str = line.decode("utf-8").strip() + if not line_str.startswith("data: "): + continue + + data_str = line_str[6:] + if data_str == "[DONE]": + break + + try: + import json + data = json.loads(data_str) + choices = data.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + content = delta.get("content", "") + if content: + yield content + except json.JSONDecodeError: + continue + + except httpx.HTTPStatusError as e: + logger.error(f"DeepSeek Stream API错误: {e.response.status_code}") + yield "" + except Exception as e: + logger.error(f"DeepSeek Stream调用失败: {e}") + yield "" + def get_available_models(self) -> List[str]: """Return available models for the Deep Seek Client instance.""" return self.SUPPORTED_MODELS diff --git a/backend/app/shared/bootstrap.py b/backend/app/shared/bootstrap.py index 31d74df..4ab08e0 100644 --- a/backend/app/shared/bootstrap.py +++ b/backend/app/shared/bootstrap.py @@ -17,18 +17,31 @@ from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuil from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore from app.infrastructure.storage.json_document_repository import JsonDocumentRepository from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore +from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository +from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore from app.infrastructure.vectorstore.dense_retriever import DenseRetriever from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex +from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker # Keep shared wiring centralized so dependency construction remains consistent. @lru_cache -def get_document_repository() -> JsonDocumentRepository: - """Return document repository.""" +def get_document_repository(): + """Return document repository (json or postgres, controlled by settings).""" + if settings.document_repository_backend == "postgres": + return PostgresDocumentRepository() return JsonDocumentRepository(settings.document_metadata_path) +@lru_cache +def get_parse_artifact_store(): + """Return parse artifact store, or None when postgres backend is not enabled.""" + if settings.document_repository_backend == "postgres": + return PostgresParseArtifactStore() + return None + + @lru_cache def get_binary_store() -> MinioDocumentBinaryStore: """Return binary store.""" @@ -66,6 +79,14 @@ def get_vector_index() -> MilvusVectorIndex: return MilvusVectorIndex() +@lru_cache +def get_reranker(): + """Return reranker if enabled, else None.""" + if settings.reranker_enabled and settings.reranker_base_url: + return OpenAICompatibleReranker() + return None + + @lru_cache def get_retrieval_service() -> KnowledgeRetrievalService: """Return retrieval service.""" @@ -73,7 +94,11 @@ def get_retrieval_service() -> KnowledgeRetrievalService: embedding_provider=get_embedding_provider(), vector_index=get_vector_index(), ) - return KnowledgeRetrievalService(retriever=retriever) + return KnowledgeRetrievalService( + retriever=retriever, + reranker=get_reranker(), + reranker_top_k=settings.reranker_top_k, + ) @lru_cache @@ -86,6 +111,7 @@ def get_document_command_service() -> DocumentCommandService: chunk_builder=get_chunk_builder(), embedding_provider=get_embedding_provider(), vector_index=get_vector_index(), + parse_artifact_store=get_parse_artifact_store(), ) @@ -95,6 +121,7 @@ def get_document_query_service() -> DocumentQueryService: return DocumentQueryService( document_repository=get_document_repository(), binary_store=get_binary_store(), + vector_index=get_vector_index(), ) diff --git a/backend/requirements.txt b/backend/requirements.txt index 0a5f829..20d07bb 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -13,6 +13,7 @@ tenacity>=8.2.0 pymilvus>=2.4.0 minio>=7.1.0 +psycopg2-binary>=2.9.0 pymupdf>=1.24.0 python-docx>=1.1.0 diff --git a/dev.sh b/dev.sh index dea6c1a..79a9cb0 100644 --- a/dev.sh +++ b/dev.sh @@ -12,7 +12,8 @@ API_PID_FILE="$LOG_DIR/api.pid" FRONTEND_PID_FILE="$LOG_DIR/frontend.pid" API_LOG_FILE="$LOG_DIR/api.log" FRONTEND_LOG_FILE="$LOG_DIR/frontend.log" -DOCKER_CONTAINERS="milvus minio redis postgres" +DISPLAY_HOST="localhost" +SERVICE_HOST="6.86.80.8" RED='\033[0;31m' GREEN='\033[0;32m' @@ -54,6 +55,51 @@ ensure_log_dir() { mkdir -p "$LOG_DIR" } +check_tcp_connectivity() { + local host="$1" + local port="$2" + + if command -v nc > /dev/null 2>&1; then + nc -z -w 3 "$host" "$port" > /dev/null 2>&1 + return + fi + + require_python_bootstrap + "$PYTHON_BOOTSTRAP" - < /dev/null 2>&1 +import socket +import sys + +try: + with socket.create_connection(("${host}", ${port}), timeout=3): + pass +except Exception: + sys.exit(1) + +sys.exit(0) +PY +} + +check_foundation_services() { + local name + local port + + for service in \ + "Milvus:19530" \ + "MinIO API:9000" \ + "MinIO Console:9001" \ + "Redis:6379" \ + "PostgreSQL:5432" + do + name="${service%%:*}" + port="${service##*:}" + if check_tcp_connectivity "$SERVICE_HOST" "$port"; then + success "${name}: ${SERVICE_HOST}:${port} 可连通" + else + warn "${name}: ${SERVICE_HOST}:${port} 不可连通" + fi + done +} + print_header() { echo "" echo -e "${CYAN}========================================${NC}" @@ -197,21 +243,8 @@ run_setup() { success "前端依赖安装完成" echo "" - info "[4/4] 检查 Docker 基础服务" - if command -v docker > /dev/null 2>&1; then - local container - for container in $DOCKER_CONTAINERS; do - if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then - success "${container}: 运行中" - elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then - warn "${container}: 已创建但未运行" - else - warn "${container}: 未找到容器" - fi - done - else - warn "未检测到 Docker,已跳过容器检查" - fi + info "[4/4] 检查 6.86.80.8 基础服务连通性" + check_foundation_services echo "" success "环境初始化完成" @@ -223,7 +256,7 @@ run_setup() { api_health_ok() { if command -v curl > /dev/null 2>&1; then - curl -fsS "http://localhost:$API_PORT/health" > /dev/null 2>&1 + curl -fsS "http://${DISPLAY_HOST}:$API_PORT/health" > /dev/null 2>&1 return fi @@ -233,7 +266,7 @@ import sys from urllib.request import urlopen try: - with urlopen("http://localhost:${API_PORT}/health", timeout=3) as response: + with urlopen("http://${DISPLAY_HOST}:${API_PORT}/health", timeout=3) as response: body = response.read().decode("utf-8", errors="ignore") sys.exit(0 if "healthy" in body.lower() else 1) except Exception: @@ -260,9 +293,9 @@ start_api() { if [ "$mode" = "foreground" ]; then print_header "AI+合规智能中枢 - 启动 API" echo "运行模式: 前台调试(带 --reload)" - echo "服务地址: http://localhost:$API_PORT" - echo "文档地址: http://localhost:$API_PORT/docs" - echo "健康检查: http://localhost:$API_PORT/health" + echo "服务地址: http://${DISPLAY_HOST}:$API_PORT" + echo "文档地址: http://${DISPLAY_HOST}:$API_PORT/docs" + echo "健康检查: http://${DISPLAY_HOST}:$API_PORT/health" echo "" exec "$VENV_PYTHON" -m uvicorn app.main:app --host "$API_HOST" --port "$API_PORT" --reload fi @@ -274,8 +307,8 @@ start_api() { if is_pid_running "$pid"; then success "API 启动成功 (PID: $pid)" - echo " 地址: http://localhost:$API_PORT" - echo " 文档: http://localhost:$API_PORT/docs" + echo " 地址: http://${DISPLAY_HOST}:$API_PORT" + echo " 文档: http://${DISPLAY_HOST}:$API_PORT/docs" echo " 日志: $API_LOG_FILE" else rm -f "$API_PID_FILE" @@ -316,7 +349,7 @@ start_frontend() { if is_pid_running "$pid"; then success "前端启动成功 (PID: $pid)" - echo " 地址: http://localhost:$FRONTEND_PORT" + echo " 地址: http://${DISPLAY_HOST}:$FRONTEND_PORT" echo " 模式: $mode" echo " 日志: $FRONTEND_LOG_FILE" else @@ -407,6 +440,8 @@ run_status() { local frontend_running=false local pid local port_listener + local service_name + local service_port echo -e "${YELLOW}API 服务:${NC}" pid="$(read_pid "$API_PID_FILE")" @@ -432,8 +467,8 @@ run_status() { warn " 健康检查: 未通过" fi fi - echo " 地址: http://localhost:$API_PORT" - echo " 文档: http://localhost:${API_PORT}/docs" + echo " 地址: http://${DISPLAY_HOST}:$API_PORT" + echo " 文档: http://${DISPLAY_HOST}:${API_PORT}/docs" echo "" echo -e "${YELLOW}前端服务:${NC}" @@ -453,24 +488,25 @@ run_status() { fi fi echo " 模式: $FRONTEND_MODE" - echo " 地址: http://localhost:$FRONTEND_PORT" + echo " 地址: http://${DISPLAY_HOST}:$FRONTEND_PORT" echo "" - echo -e "${YELLOW}Docker 服务:${NC}" - if command -v docker > /dev/null 2>&1; then - local container - for container in $DOCKER_CONTAINERS; do - if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then - success " ${container}: 运行中" - elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then - warn " ${container}: 已停止" - else - warn " ${container}: 未创建" - fi - done - else - warn " Docker 未安装,已跳过" - fi + echo -e "${YELLOW}基础服务连通性:${NC}" + for service in \ + "Milvus:19530" \ + "MinIO API:9000" \ + "MinIO Console:9001" \ + "Redis:6379" \ + "PostgreSQL:5432" + do + service_name="${service%%:*}" + service_port="${service##*:}" + if check_tcp_connectivity "$SERVICE_HOST" "$service_port"; then + success " ${service_name}: ${SERVICE_HOST}:${service_port} 可连通" + else + warn " ${service_name}: ${SERVICE_HOST}:${service_port} 不可连通" + fi + done echo "" if [ "$api_running" = true ] && [ "$frontend_running" = true ]; then @@ -526,7 +562,7 @@ AI+合规智能中枢统一脚本 setup 进行一次性的本地初始化。 包含 Python 版本检查、.venv 虚拟环境创建、后端依赖安装、前端 npm install、 - 以及 Docker 基础容器状态检查。 + 以及 6.86.80.8 基础服务端口连通性检查。 start 启动服务。默认行为等同于 ./dev.sh start all。 @@ -548,7 +584,7 @@ AI+合规智能中枢统一脚本 restart frontend --mode static 可直接切换前端启动模式。 status - 查看 API、前端、Docker 基础容器的状态。 + 查看 API、前端、6.86.80.8 基础服务的状态。 API 状态包含健康检查;前端状态包含当前模式和访问地址。 logs diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 1c6132b..6427b32 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -58,7 +58,7 @@ services: retries: 5 restart: unless-stopped - # PostgreSQL数据库 (可选) + # PostgreSQL数据库 (可选,启用 DOCUMENT_REPOSITORY_BACKEND=postgres 时使用) postgres: image: postgres:15-alpine container_name: postgres @@ -71,7 +71,7 @@ services: volumes: - postgres_data:/var/lib/postgresql/data healthcheck: - test: ["CMD-SHELL", "pg_isready -U compliance"] + test: ["CMD-SHELL", "pg_isready -U compliance -d compliance_db"] interval: 30s timeout: 10s retries: 5 diff --git a/docs/superpowers/specs/2026-05-20-rag-dialogue-optimization-design.md b/docs/superpowers/specs/2026-05-20-rag-dialogue-optimization-design.md new file mode 100644 index 0000000..b77705d --- /dev/null +++ b/docs/superpowers/specs/2026-05-20-rag-dialogue-optimization-design.md @@ -0,0 +1,145 @@ +# 法规对话模块优化设计文档 + +**日期**: 2026-05-20 +**状态**: 已批准,实施中 + +--- + +## 背景 + +当前法规对话模块存在以下问题: +1. `compliance.py` `/compliance/chat/{segment_id}` 返回硬编码 Mock 数据 +2. `rag.py` `/rag/chat` 返回硬编码 Mock 数据(前端实际调用 `/agent/chat/stream`,此路由可统一) +3. 前端快速问题列表硬编码在 `rag.ts`,未调用后端 +4. 仅有 Dense 向量检索(COSINE),无 BM25 混合检索与精排 +5. LLM 输出的 `[1][2]` 引用编号未在前端内联高亮 +6. 会话存储仅为内存(100条上限,30分钟过期) + +--- + +## 方案:分层优先(4个阶段) + +### Phase 1 — 接入真实服务(Week 1) + +**目标**:消灭所有 Mock,让系统真正可用。 + +#### 后端变更 + +**`backend/app/schemas/compliance.py`** +- `ComplianceChatRequest` 新增 `segment_context: str | None = None` + +**`backend/app/api/routes/compliance.py`** +- 移除 `get_mock_compliance_chat_response` 导入 +- `/compliance/chat/{segment_id}` 接入 `get_agent_conversation_service().stream_chat()` +- 将 `segment_context` 拼接到 query 前缀作为上下文 +- 将 agent `content` 事件翻译为 `{"type":"chunk","text":"..."}` 格式,保持前端兼容 + +**`backend/app/api/routes/rag.py`** +- 移除 `get_mock_retrieval`、`get_mock_rag_answer` 导入 +- `/rag/chat` 接入 `get_agent_conversation_service().stream_chat()` +- 翻译 agent 事件为 rag 格式(`retrieved`/`chunk`/`done`) + +**`backend/app/schemas/rag.py`** +- `RagChatRequest` 新增 `session_id: str | None = None`,`filters: str | None = None` + +#### 前端变更 + +**`frontend/src/api/rag.ts`** +- `getQuickQuestions()` 改为真实调用 `GET /api/v1/rag/quick-questions`,失败时降级为本地数组 + +**`frontend/src/api/compliance.ts`** +- `complianceChat()` 新增第三个参数 `segmentContext: string | undefined`,传入 request body + +**`frontend/src/pages/Compliance/CompliancePage.tsx`** +- `sendChatMessage()` 中构建 `segmentContext`(intent + content 摘要 + 法规名) +- 传给 `complianceChat()` + +--- + +### Phase 2 — 混合检索 + Reranking(Week 2-3) + +**目标**:提升召回质量(BM25 + dense RRF融合)+ 精排。 + +#### 2a — Cross-Encoder Reranking(优先,无需 schema 变更) + +- 新增端口 `backend/app/domain/retrieval/ports.py`: `Reranker` ABC +- 新增适配器 `backend/app/infrastructure/vectorstore/cross_encoder_reranker.py` + - 调用 OpenAI-compatible reranker API(BAAI/bge-reranker-v2-m3) +- 修改 `KnowledgeRetrievalService`:先 retrieve top-20,再 rerank 到 top-5 +- 新增 settings: `reranker_enabled: bool = False`,`reranker_model: str`,`reranker_top_k: int = 5` + +#### 2b — Milvus Sparse BM25(需 schema 迁移) + +- Milvus collection 新增 `sparse_embedding SPARSE_FLOAT_VECTOR` 字段 +- 新增端口 `SparseEmbeddingProvider`(sparse embed 接口) +- 适配器优先使用 BGE-M3 API(同时输出 dense + sparse),不可用时降级为 TF-IDF keyword weights +- `MilvusVectorIndex.upsert()` 同时写入 sparse 向量 +- `MilvusVectorIndex.search()` 改为 hybrid search(`WeightedRanker` 或 `RRFRanker`) +- 提供一次性迁移脚本:dump 所有 chunks → recreate collection → re-embed → re-insert + +--- + +### Phase 3 — 引用溯源 + 筛选 UI(Week 4) + +**目标**:答案文本中 `[1][2]` 可点击跳转原文片段;法规类型/版本可筛选。 + +#### 引用内联解析 + +- 新增 React 组件 `CitedAnswer`:接受 `text` + `sources[]` + - 用正则 `/\[(\d+)\]/g` 拆分文本,将 `[N]` 渲染为可点击 `` 元素 + - 点击高亮/滚动到来源面板对应条目 +- `RagChatPage.tsx` 将 assistant 消息改用 `CitedAnswer` 组件渲染 +- 来源面板各条目加 `id="source-N"` 属性 + +#### 法规筛选 UI + +- `RagChatPage.tsx` 在输入框上方加筛选栏:`regulation_type` 下拉 + `version` 输入 +- 筛选值作为 `filters` 参数传到 `/agent/chat/stream` +- `MilvusVectorIndex.search()` 解析 `filters` 字符串为 Milvus expr(如 `regulation_type == "GB"`) + +#### 快速问题动态化 + +- 后端 `/rag/quick-questions` 改为从 settings 或配置文件加载,不硬编码 +- 前端已在 Phase 1 中改为调用后端 + +--- + +### Phase 4 — 会话持久化 + 上下文压缩(Week 5) + +**目标**:长对话不丢失;上下文过长时智能压缩。 + +#### PostgreSQL 会话存储 + +- 新增 `backend/app/infrastructure/session/postgres_conversation_store.py` + - 实现 `ConversationStore` port + - 复用现有 PostgreSQL 连接 + - DDL: `conversation_sessions` + `conversation_messages` 两张表 +- 新增 settings: `conversation_store_backend: str = "memory"` +- `bootstrap.py` 按 settings 切换实现 + +#### 上下文压缩 + +- 在 `AgentConversationService` 中,当对话轮数 > N 时,将早期消息用 LLM 摘要 + - 摘要替换为一条 `system` 消息 `"对话摘要:..."` +- 新增 `AnswerGenerator.summarize(messages) -> str` 方法 + +--- + +## 关键设计决策 + +| 决策 | 选择 | 原因 | +|------|------|------| +| BM25 实现 | Milvus sparse vectors | 不引入新服务,Milvus 2.4+ 原生支持 | +| Reranking | API-based cross-encoder | 复用现有 OpenAI-compatible 接口 | +| 会话持久化 | PostgreSQL | 复用现有 PostgreSQL 基础设施 | +| 引用解析 | 前端纯客户端 | LLM 已输出 [N] 编号,无需后端改动 | +| segment_context | 前端构建传后端 | 合规结果存在前端状态,无需后端持久化 | + +--- + +## 验证标准 + +- Phase 1: 合规对话面板返回真实法规检索结果,不再固定响应 +- Phase 2: 相同 query 在 hybrid 模式下比 dense-only 召回更相关结果 +- Phase 3: 点击答案中的 `[1]` 能跳转到来源面板第一条 +- Phase 4: 刷新页面后历史对话仍可恢复 diff --git a/frontend/src/api/compliance.ts b/frontend/src/api/compliance.ts index 8a14abe..7ca4c56 100644 --- a/frontend/src/api/compliance.ts +++ b/frontend/src/api/compliance.ts @@ -31,11 +31,18 @@ export async function getComplianceResult( export function complianceChat( segmentId: number, query: string, + segmentContext: string | undefined, onMessage: (data: SSEMessage) => void, onError?: (error: Error) => void, onComplete?: () => void ): void { - void streamSSE(`/compliance/chat/${segmentId}`, { query }, onMessage, onError, onComplete); + void streamSSE( + `/compliance/chat/${segmentId}`, + { query, segment_context: segmentContext }, + onMessage, + onError, + onComplete, + ); } export type { ComplianceResult, SSEMessage }; diff --git a/frontend/src/api/docs.ts b/frontend/src/api/docs.ts index a816acc..85a3219 100644 --- a/frontend/src/api/docs.ts +++ b/frontend/src/api/docs.ts @@ -6,7 +6,11 @@ interface BackendDocumentItem { doc_name: string; status: string; chunk_count: number; + size_bytes?: number; + summary?: string; updated_at?: string; + regulation_type?: string; + version?: string; } interface BackendDocumentListResponse { @@ -44,6 +48,7 @@ export interface RegulationSearchResponse { } function mapDoc(item: BackendDocumentItem): DocInfo { + const sizeMB = item.size_bytes ? (item.size_bytes / (1024 * 1024)).toFixed(1) + 'MB' : ''; return { id: item.doc_id, name: item.doc_name, @@ -51,14 +56,23 @@ function mapDoc(item: BackendDocumentItem): DocInfo { status: item.status, updated_at: item.updated_at, download_url: `${API_BASE_URL}/documents/download/${item.doc_id}`, + size_text: sizeMB, + summary: item.summary, + regulation_type: item.regulation_type, + version: item.version, }; } -export async function uploadDocument(file: File): Promise { +export async function uploadDocument( + file: File, + opts?: { regulationType?: string; version?: string } +): Promise { const formData = new FormData(); formData.append('file', file); formData.append('doc_name', file.name); formData.append('generate_summary', 'true'); + if (opts?.regulationType) formData.append('regulation_type', opts.regulationType); + if (opts?.version) formData.append('version', opts.version); const response = await fetch(`${API_BASE_URL}/documents/upload`, { method: 'POST', @@ -92,6 +106,29 @@ export async function getDocumentList(): Promise { }; } +export async function getDocumentStatus(docId: string): Promise { + const response = await fetch(`${API_BASE_URL}/documents/status/${docId}`); + if (!response.ok) { + throw new Error(`Status check failed: ${response.status}`); + } + return response.json() as Promise; +} + +export async function deleteDocument(docId: string): Promise { + const response = await fetch(`${API_BASE_URL}/documents/${docId}`, { method: 'DELETE' }); + if (!response.ok) { + throw new Error(`Delete failed: ${response.status}`); + } +} + +export async function retryDocument(docId: string): Promise { + const response = await fetch(`${API_BASE_URL}/documents/${docId}/retry`, { method: 'POST' }); + if (!response.ok) { + throw new Error(`Retry failed: ${response.status}`); + } + return response.json() as Promise; +} + export async function searchRegulations(query: string, topK: number = 8): Promise { const response = await fetch(`${API_BASE_URL}/knowledge/retrieval`, { method: 'POST', diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index 793bf39..d64ac34 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -136,6 +136,9 @@ export interface DocInfo { updated_at?: string; download_url?: string; size_text?: string; + summary?: string; + regulation_type?: string; + version?: string; } export interface DocListResponse { @@ -149,6 +152,8 @@ export interface DocUploadResponse { message?: string; num_chunks?: number; summary?: string; + regulation_type?: string; + version?: string; } export interface QuickQuestion { diff --git a/frontend/src/api/rag.ts b/frontend/src/api/rag.ts index 09f2239..2e50ab2 100644 --- a/frontend/src/api/rag.ts +++ b/frontend/src/api/rag.ts @@ -2,15 +2,21 @@ import type { QuickQuestionsResponse, SSEMessage } from './index'; const AGENT_API_BASE = '/api/v1'; +const _FALLBACK_QUESTIONS = [ + { id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' }, + { id: '2', question: '我上传的制度文档与新能源法规有哪些潜在冲突?', category: '差距分析' }, + { id: '3', question: '请给出法规依据,并按条款列出整改建议', category: '整改建议' }, + { id: '4', question: '请解释 UN-ECE 与 GB 标准在网络安全方面的差异', category: '标准对比' }, +]; + export async function getQuickQuestions(): Promise { - return { - questions: [ - { id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' }, - { id: '2', question: '我上传的制度文档与新能源法规有哪些潜在冲突?', category: '差距分析' }, - { id: '3', question: '请给出法规依据,并按条款列出整改建议', category: '整改建议' }, - { id: '4', question: '请解释 UN-ECE 与 GB 标准在网络安全方面的差异', category: '标准对比' }, - ], - }; + try { + const response = await fetch(`${AGENT_API_BASE}/rag/quick-questions`); + if (!response.ok) throw new Error(`status ${response.status}`); + return response.json() as Promise; + } catch { + return { questions: _FALLBACK_QUESTIONS }; + } } function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) { @@ -67,7 +73,8 @@ export async function ragChat( topK: number = 5, onMessage: (data: SSEMessage) => void, onError?: (error: Error) => void, - onComplete?: () => void + onComplete?: () => void, + filters?: string ): Promise { try { const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, { @@ -76,7 +83,7 @@ export async function ragChat( 'Content-Type': 'application/json', Accept: 'text/event-stream', }, - body: JSON.stringify({ query, top_k: topK }), + body: JSON.stringify({ query, top_k: topK, ...(filters ? { filters } : {}) }), }); if (!response.ok || !response.body) { diff --git a/frontend/src/pages/Compliance/CompliancePage.tsx b/frontend/src/pages/Compliance/CompliancePage.tsx index 35ef3b9..3d61cec 100644 --- a/frontend/src/pages/Compliance/CompliancePage.tsx +++ b/frontend/src/pages/Compliance/CompliancePage.tsx @@ -250,11 +250,20 @@ export const CompliancePage: React.FC = () => { setChatInput(''); setChatLoading(true); + const segmentContext = [ + `意图:${chunk.intent}`, + `内容:${chunk.content.slice(0, 300)}`, + chunk.regulations.length > 0 + ? `相关法规:${chunk.regulations.slice(0, 3).map(r => `${r.name}${r.clause ? ' ' + r.clause : ''}(相关性 ${Math.round(r.score * 100)}%)`).join(';')}` + : '', + ].filter(Boolean).join('\n'); + let currentResponse = ''; complianceChat( activeChunkId, chatInput, + segmentContext, (data: unknown) => { const sseData = data as { type: string; text?: string }; if (sseData.type === 'chunk' && sseData.text) { diff --git a/frontend/src/pages/Docs/DocsPage.tsx b/frontend/src/pages/Docs/DocsPage.tsx index 32da143..6ce1414 100644 --- a/frontend/src/pages/Docs/DocsPage.tsx +++ b/frontend/src/pages/Docs/DocsPage.tsx @@ -2,7 +2,7 @@ import React, { useEffect, useRef, useState } from 'react'; import { useTheme } from '../../contexts'; import { Content } from '../../components/layout/Content'; import { TPattern } from '../../components/common/TPattern'; -import { getDocumentList, searchRegulations, uploadDocument, type RegulationSearchItem } from '../../api/docs'; +import { getDocumentList, getDocumentStatus, searchRegulations, uploadDocument, deleteDocument, retryDocument, type RegulationSearchItem } from '../../api/docs'; import type { Doc } from '../../types'; type PipelineStatus = 'idle' | 'running' | 'completed' | 'error'; @@ -15,13 +15,13 @@ const PIPELINE_STEPS = [ { name: 'STORE' }, ]; +const REGULATION_TYPES = ['', '国家标准', '行业标准', '地方标准', '企业标准', '法律法规', '监管规定']; + const STEP_DURATION_MS = 700; const INITIAL_SEARCH_QUERY = '新能源汽车电池安全要求'; function wait(ms: number) { - return new Promise((resolve) => { - window.setTimeout(resolve, ms); - }); + return new Promise((resolve) => { window.setTimeout(resolve, ms); }); } export const DocsPage: React.FC = () => { @@ -41,6 +41,13 @@ export const DocsPage: React.FC = () => { const [searchLoading, setSearchLoading] = useState(false); const [searchError, setSearchError] = useState(''); + // Upload metadata + const [regulationType, setRegulationType] = useState(''); + const [version, setVersion] = useState(''); + + // Batch queue: files waiting to be uploaded after the current one finishes + const batchQueueRef = useRef([]); + async function loadDocuments() { setLoading(true); try { @@ -54,6 +61,9 @@ export const DocsPage: React.FC = () => { docId: doc.id, downloadUrl: doc.download_url, updatedAt: doc.updated_at, + summary: doc.summary, + regulationType: doc.regulation_type, + version: doc.version, })); setDocs(apiDocs); } catch (error) { @@ -81,62 +91,71 @@ export const DocsPage: React.FC = () => { } useEffect(() => { - const timerId = window.setTimeout(() => { - void loadDocuments(); - }, 0); + const timerId = window.setTimeout(() => { void loadDocuments(); }, 0); return () => window.clearTimeout(timerId); }, []); useEffect(() => { - const timerId = window.setTimeout(() => { - void runSearch(INITIAL_SEARCH_QUERY); - }, 0); + const timerId = window.setTimeout(() => { void runSearch(INITIAL_SEARCH_QUERY); }, 0); return () => window.clearTimeout(timerId); }, []); useEffect(() => { - return () => { - pipelineRunIdRef.current += 1; - }; + return () => { pipelineRunIdRef.current += 1; }; }, []); + useEffect(() => { + const parsingDocs = docs.filter( + (doc) => doc.status === 'parsing' && doc.docId && !doc.docId.startsWith('pending-') + ); + if (parsingDocs.length === 0) return; + + const timerId = window.setInterval(() => { + parsingDocs.forEach((doc) => { + void getDocumentStatus(doc.docId!).then((res) => { + if (res.status === 'indexed' || res.status === 'failed') { + setDocs((prev) => + prev.map((d) => + d.docId === doc.docId + ? { + ...d, + status: res.status === 'indexed' ? 'indexed' : 'failed', + chunks: res.num_chunks ?? d.chunks, + summary: res.summary ?? d.summary, + regulationType: res.regulation_type ?? d.regulationType, + version: res.version ?? d.version, + } + : d + ) + ); + } + }).catch(() => {}); + }); + }, 5000); + + return () => window.clearInterval(timerId); + }, [docs]); + const runPipelineFlow = async (runId: number, uploadPromise: Promise>>) => { - const guardedSetActiveStep = (step: number) => { - if (pipelineRunIdRef.current !== runId) return false; - setActiveStep(step); - return true; - }; + const guard = (fn: () => void) => { if (pipelineRunIdRef.current !== runId) return false; fn(); return true; }; - const guardedCompleteStep = (step: number) => { - if (pipelineRunIdRef.current !== runId) return false; - setCompletedSteps((prev) => (prev.includes(step) ? prev : [...prev, step])); - return true; - }; - - for (let index = 0; index < PIPELINE_STEPS.length - 1; index += 1) { - if (!guardedSetActiveStep(index)) return; + for (let i = 0; i < PIPELINE_STEPS.length - 1; i++) { + if (!guard(() => setActiveStep(i))) return; await wait(STEP_DURATION_MS); - if (!guardedCompleteStep(index)) return; + if (!guard(() => setCompletedSteps((p) => p.includes(i) ? p : [...p, i]))) return; } - if (!guardedSetActiveStep(PIPELINE_STEPS.length - 1)) return; + if (!guard(() => setActiveStep(PIPELINE_STEPS.length - 1))) return; await uploadPromise; - if (!guardedCompleteStep(PIPELINE_STEPS.length - 1)) return; + if (!guard(() => setCompletedSteps((p) => { const last = PIPELINE_STEPS.length - 1; return p.includes(last) ? p : [...p, last]; }))) return; await wait(240); if (pipelineRunIdRef.current !== runId) return; - setActiveStep(-1); setPipelineStatus('completed'); }; - const handleFileSelect = async (event: React.ChangeEvent) => { - const file = event.target.files?.[0]; - if (!file || uploading) return; - - const runId = pipelineRunIdRef.current + 1; - pipelineRunIdRef.current = runId; - + const uploadSingleFile = async (file: File, runId: number) => { setUploading(true); setUploadFileName(file.name); setActiveStep(-1); @@ -152,11 +171,16 @@ export const DocsPage: React.FC = () => { size: `${fileSizeMB}MB`, status: 'parsing', docId: tempDocId, + regulationType: regulationType || undefined, + version: version || undefined, }; setDocs((prev) => [newDoc, ...prev]); - const uploadPromise = uploadDocument(file); + const uploadPromise = uploadDocument(file, { + regulationType: regulationType || undefined, + version: version || undefined, + }); void runPipelineFlow(runId, uploadPromise); try { @@ -166,143 +190,123 @@ export const DocsPage: React.FC = () => { setDocs((prev) => prev.map((doc) => doc.id === newDoc.id - ? { - ...doc, - status: 'indexed', - docId: uploadRes.doc_id, - chunks: uploadRes.num_chunks || doc.chunks, - summary: uploadRes.summary, - } + ? { ...doc, status: 'indexed', docId: uploadRes.doc_id, chunks: uploadRes.num_chunks || doc.chunks, summary: uploadRes.summary } : doc ) ); - - setUploading(false); - setUploadFileName(''); void loadDocuments(); } catch (error) { console.error('Upload failed:', error); if (pipelineRunIdRef.current !== runId) return; - - setUploading(false); - setUploadFileName(''); setDocs((prev) => prev.filter((doc) => doc.id !== newDoc.id)); setPipelineStatus('error'); setActiveStep(-1); setCompletedSteps([]); } finally { - if (fileInputRef.current) { - fileInputRef.current.value = ''; + setUploading(false); + setUploadFileName(''); + if (fileInputRef.current) fileInputRef.current.value = ''; + + // Process next file in batch queue + const next = batchQueueRef.current.shift(); + if (next) { + const nextRunId = pipelineRunIdRef.current + 1; + pipelineRunIdRef.current = nextRunId; + void uploadSingleFile(next, nextRunId); } } }; - const triggerFileUpload = () => { - if (uploading) return; - fileInputRef.current?.click(); + const handleFileSelect = async (event: React.ChangeEvent) => { + const files = Array.from(event.target.files ?? []); + if (files.length === 0 || uploading) return; + + const [first, ...rest] = files; + batchQueueRef.current = rest; + + const runId = pipelineRunIdRef.current + 1; + pipelineRunIdRef.current = runId; + await uploadSingleFile(first, runId); }; - const handleDragOver = (event: React.DragEvent) => { - event.preventDefault(); - event.stopPropagation(); + const handleDelete = async (docId: string) => { + try { + await deleteDocument(docId); + setDocs((prev) => prev.filter((doc) => doc.docId !== docId)); + } catch (error) { + console.error('Delete failed:', error); + } }; + const handleRetry = async (docId: string) => { + setDocs((prev) => prev.map((doc) => doc.docId === docId ? { ...doc, status: 'parsing' } : doc)); + try { + const result = await retryDocument(docId); + setDocs((prev) => + prev.map((doc) => doc.docId === docId ? { ...doc, status: 'indexed', chunks: result.num_chunks || doc.chunks } : doc) + ); + } catch (error) { + console.error('Retry failed:', error); + setDocs((prev) => prev.map((doc) => doc.docId === docId ? { ...doc, status: 'failed' } : doc)); + } + }; + + const triggerFileUpload = () => { if (uploading) return; fileInputRef.current?.click(); }; + + const handleDragOver = (event: React.DragEvent) => { event.preventDefault(); event.stopPropagation(); }; + const handleDrop = (event: React.DragEvent) => { event.preventDefault(); event.stopPropagation(); - - const files = event.dataTransfer.files; + const files = Array.from(event.dataTransfer.files); if (files.length === 0 || uploading) return; - const droppedFile = files[0]; - if (fileInputRef.current) { - const dataTransfer = new DataTransfer(); - dataTransfer.items.add(droppedFile); - fileInputRef.current.files = dataTransfer.files; - } - - void handleFileSelect({ - target: { files: [droppedFile] as unknown as FileList }, - } as React.ChangeEvent); + const [first, ...rest] = files; + batchQueueRef.current = rest; + const runId = pipelineRunIdRef.current + 1; + pipelineRunIdRef.current = runId; + void uploadSingleFile(first, runId); }; const getStepStyle = (index: number) => { - const isActive = activeStep === index; - const isCompleted = completedSteps.includes(index); - - if (isActive) { - return { - background: theme.bgCard, - border: `2px solid ${theme.accent}`, - boxShadow: `0 0 12px ${theme.accent}40`, - }; - } - - if (isCompleted) { - return { - background: theme.bgCard, - border: `1px solid ${theme.green}`, - }; - } - - return { - background: theme.bgCard, - border: `1px solid ${theme.border}`, - }; + if (activeStep === index) return { background: theme.bgCard, border: `2px solid ${theme.accent}`, boxShadow: `0 0 12px ${theme.accent}40` }; + if (completedSteps.includes(index)) return { background: theme.bgCard, border: `1px solid ${theme.green}` }; + return { background: theme.bgCard, border: `1px solid ${theme.border}` }; }; const getCheckStyle = (index: number) => { - const isActive = activeStep === index; - const isCompleted = completedSteps.includes(index); - - if (isActive) { - return { - background: theme.gradientAccent, - color: '#fff', - animation: 'pulse 0.6s infinite', - }; - } - - if (isCompleted) { - return { - background: theme.green, - color: '#fff', - }; - } - - return { - background: theme.bgHover, - color: theme.text3, - }; + if (activeStep === index) return { background: theme.gradientAccent, color: '#fff', animation: 'pulse 0.6s infinite' }; + if (completedSteps.includes(index)) return { background: theme.green, color: '#fff' }; + return { background: theme.bgHover, color: theme.text3 }; }; const getPipelineHint = () => { if (pipelineStatus === 'running') { - return activeStep >= 0 ? `${PIPELINE_STEPS[activeStep].name} · ${uploadFileName}` : `LOAD · ${uploadFileName}`; - } - if (pipelineStatus === 'completed') { - return 'PIPELINE COMPLETE'; - } - if (pipelineStatus === 'error') { - return 'PIPELINE FAILED'; + const queueLen = batchQueueRef.current.length; + const suffix = queueLen > 0 ? ` (+${queueLen} 待上传)` : ''; + return `${activeStep >= 0 ? PIPELINE_STEPS[activeStep].name : 'LOAD'} · ${uploadFileName}${suffix}`; } + if (pipelineStatus === 'completed') return 'PIPELINE COMPLETE'; + if (pipelineStatus === 'error') return 'PIPELINE FAILED'; return 'WAITING FOR UPLOAD'; }; + const inputStyle: React.CSSProperties = { + padding: '8px 12px', + fontSize: 13, + background: theme.bgCard, + border: `1px solid ${theme.border}`, + borderRadius: 8, + color: theme.text, + outline: 'none', + }; + return (
-

+

UPLOAD

@@ -310,10 +314,30 @@ export const DocsPage: React.FC = () => { ref={fileInputRef} type="file" accept=".pdf,.docx,.doc" + multiple onChange={handleFileSelect} style={{ display: 'none' }} /> + {/* Metadata row */} +
+ + setVersion(e.target.value)} + placeholder="版本号(可选,如 2024)" + style={{ ...inputStyle, flex: 1 }} + /> +
+
{ opacity: uploading ? 0.78 : 1, }} > -
+
{uploading ? (
- +
) : ( @@ -363,26 +368,17 @@ export const DocsPage: React.FC = () => { )}
-
- {uploading ? '正在上传并启动处理链路...' : '拖拽文件或点击上传'} + {uploading ? '正在上传并启动处理链路...' : '拖拽文件或点击上传(支持多选)'}
- {uploading ? uploadFileName : 'PDF · DOCX · DOC · MAX 100MB'} + {uploading ? uploadFileName : 'PDF · DOCX · DOC · MAX 100MB · 支持批量'}
-

+

PROCESSING PIPELINE

@@ -390,12 +386,7 @@ export const DocsPage: React.FC = () => { className="mono" style={{ fontSize: 11, - color: - pipelineStatus === 'error' - ? '#d64545' - : pipelineStatus === 'completed' - ? theme.green - : theme.text3, + color: pipelineStatus === 'error' ? '#d64545' : pipelineStatus === 'completed' ? theme.green : theme.text3, letterSpacing: '1px', marginBottom: 12, }} @@ -405,63 +396,24 @@ export const DocsPage: React.FC = () => {
{PIPELINE_STEPS.map((step, index) => { - const stepStyle = getStepStyle(index); - const checkStyle = getCheckStyle(index); - const arrowActive = activeStep > index || completedSteps.includes(index); const isCompleted = completedSteps.includes(index); const isActive = activeStep === index; + const arrowActive = activeStep > index || isCompleted; return (
-
+
{isActive ? step.name : isCompleted ? '✓' : step.name}
- -
- {step.name} -
+
{step.name}
{isCompleted ? 'DONE' : isActive ? 'RUNNING' : 'PENDING'}
- {index < PIPELINE_STEPS.length - 1 && ( -
+
)} @@ -473,15 +425,7 @@ export const DocsPage: React.FC = () => {
-

+

文档管理清单 ({loading ? '...' : docs.length})

@@ -489,36 +433,12 @@ export const DocsPage: React.FC = () => { {docs.map((doc) => (
-
-
+
+
- +
@@ -529,36 +449,48 @@ export const DocsPage: React.FC = () => { {doc.updatedAt ? new Date(doc.updatedAt).toLocaleString() : doc.size} {doc.docId ? ` · ${doc.docId}` : ''}
+ {/* Tags row */} + {(doc.regulationType || doc.version) && ( +
+ {doc.regulationType && ( + + {doc.regulationType} + + )} + {doc.version && ( + + v{doc.version} + + )} +
+ )} + {doc.summary && ( +
+ {doc.summary} +
+ )}
-
- {doc.downloadUrl && ( - +
+ {doc.status === 'failed' && doc.docId && !doc.docId.startsWith('pending-') && ( + + )} + {doc.downloadUrl && doc.status === 'indexed' && ( + 下载 )} -
- {doc.status === 'parsing' - ? '处理中...' - : doc.status === 'failed' - ? '处理失败' - : `${doc.chunks} chunks`} +
+ {doc.status === 'parsing' ? '处理中...' : doc.status === 'failed' ? '处理失败' : `${doc.chunks} chunks`}
+ {doc.docId && !doc.docId.startsWith('pending-') && ( + + )}
))} @@ -566,15 +498,7 @@ export const DocsPage: React.FC = () => {
-

+

文档管理内法规检索

@@ -582,86 +506,37 @@ export const DocsPage: React.FC = () => { setSearchQuery(event.target.value)} - onKeyDown={(event) => { - if (event.key === 'Enter') { - void runSearch(searchQuery); - } - }} + onKeyDown={(event) => { if (event.key === 'Enter') void runSearch(searchQuery); }} placeholder="输入法规关键词、条款或制度主题" - style={{ - flex: 1, - padding: 12, - fontSize: 14, - background: theme.bgCard, - border: `1px solid ${theme.border}`, - borderRadius: 8, - color: theme.text, - outline: 'none', - }} + style={{ flex: 1, padding: 12, fontSize: 14, background: theme.bgCard, border: `1px solid ${theme.border}`, borderRadius: 8, color: theme.text, outline: 'none' }} />
- {searchError && ( -
- {searchError} -
- )} + {searchError &&
{searchError}
}
{searchResults.map((item) => ( -
+
{item.file}
-
- {(item.score * 100).toFixed(1)}% -
+
{(item.score * 100).toFixed(1)}%
-
- {item.clause} - {item.tags.length > 0 ? ` · ${item.tags.join(' · ')}` : ''} + {item.clause}{item.tags.length > 0 ? ` · ${item.tags.join(' · ')}` : ''}
-
{item.content}
))} {!searchLoading && searchResults.length === 0 && ( -
+
暂无检索结果
)} diff --git a/frontend/src/pages/RagChat/CitedAnswer.tsx b/frontend/src/pages/RagChat/CitedAnswer.tsx new file mode 100644 index 0000000..42ccfef --- /dev/null +++ b/frontend/src/pages/RagChat/CitedAnswer.tsx @@ -0,0 +1,58 @@ +import React, { useRef } from 'react'; +import { useTheme } from '../../contexts'; +import type { RetrievalData } from '../../types'; + +interface CitedAnswerProps { + text: string; + sources: RetrievalData[]; + onCiteClick: (index: number) => void; +} + +export const CitedAnswer: React.FC = ({ text, sources, onCiteClick }) => { + const { theme } = useTheme(); + + if (!text) return null; + + // Split on [N] patterns, preserving delimiters + const parts = text.split(/(\[\d+\])/g); + + return ( + + {parts.map((part, i) => { + const match = part.match(/^\[(\d+)\]$/); + if (!match) return {part}; + + const idx = parseInt(match[1], 10); + const source = sources[idx - 1]; + + return ( + onCiteClick(idx)} + style={{ + display: 'inline-flex', + alignItems: 'center', + justifyContent: 'center', + minWidth: 18, + height: 18, + padding: '0 4px', + marginLeft: 1, + fontSize: 10, + fontWeight: 700, + background: theme.gradientAccent, + color: '#fff', + borderRadius: 4, + cursor: source ? 'pointer' : 'default', + verticalAlign: 'super', + lineHeight: 1, + userSelect: 'none', + }} + > + {idx} + + ); + })} + + ); +}; diff --git a/frontend/src/pages/RagChat/RagChatPage.tsx b/frontend/src/pages/RagChat/RagChatPage.tsx index 33a804f..e731d7d 100644 --- a/frontend/src/pages/RagChat/RagChatPage.tsx +++ b/frontend/src/pages/RagChat/RagChatPage.tsx @@ -2,6 +2,7 @@ import React, { useEffect, useRef, useState } from 'react'; import { useTheme } from '../../contexts'; import type { ChatMessage, RetrievalData } from '../../types'; import { getQuickQuestions, ragChat } from '../../api/rag'; +import { CitedAnswer } from './CitedAnswer'; const ragQuickQuestionsDefault = [ '电动自行车上路需要什么条件?', @@ -24,6 +25,8 @@ export const RagChatPage: React.FC = () => { const [showClearConfirm, setShowClearConfirm] = useState(false); const [selectedRetrieval, setSelectedRetrieval] = useState(null); const [quickQuestions, setQuickQuestions] = useState(ragQuickQuestionsDefault); + const [filterRegulationType, setFilterRegulationType] = useState(''); + const [highlightedSourceIdx, setHighlightedSourceIdx] = useState(null); function nextMessageId() { const currentId = nextMessageIdRef.current; @@ -55,8 +58,10 @@ export const RagChatPage: React.FC = () => { setInput(''); setLoading(true); setRetrievals([]); + setHighlightedSourceIdx(null); let currentResponse = ''; + const activeFilters = filterRegulationType.trim() || undefined; void ragChat( text, @@ -112,7 +117,8 @@ export const RagChatPage: React.FC = () => { }, () => { setLoading(false); - } + }, + activeFilters ); }; @@ -130,8 +136,10 @@ export const RagChatPage: React.FC = () => { setLoading(true); setMessages((prev) => [...prev.slice(0, -1)]); setRetrievals([]); + setHighlightedSourceIdx(null); let currentResponse = ''; + const activeFilters = filterRegulationType.trim() || undefined; void ragChat( lastUserMsg.content, @@ -185,7 +193,8 @@ export const RagChatPage: React.FC = () => { }, () => { setLoading(false); - } + }, + activeFilters ); }; @@ -267,7 +276,17 @@ export const RagChatPage: React.FC = () => { whiteSpace: 'pre-wrap', border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none', }}> - {msg.content} + {msg.role === 'assistant' ? ( + { + setHighlightedSourceIdx(idx); + const el = document.getElementById(`source-${idx}`); + if (el) el.scrollIntoView({ behavior: 'smooth', block: 'center' }); + }} + /> + ) : msg.content} {msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && (
{ background: theme.bg, borderTop: `1px solid ${theme.border}`, }}> +
+ 法规类型 + setFilterRegulationType(e.target.value)} + placeholder="如: GB / UN-ECE / IATF(留空不过滤)" + style={{ + flex: 1, + maxWidth: 280, + padding: '5px 10px', + fontSize: 12, + background: theme.bgHover, + border: `1px solid ${theme.border}`, + borderRadius: 6, + color: theme.text, + outline: 'none', + }} + /> +
+
{ {retrievals.map((r, i) => (
setSelectedRetrieval(r)} style={{ padding: 16, - background: theme.bgHover, + background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover, borderRadius: 10, - border: `1px solid ${theme.border}`, + border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`, cursor: 'pointer', position: 'relative', + transition: 'border-color 0.2s, background 0.2s', }} >
{ const env = loadEnv(mode, process.cwd(), '') - const apiHost = env.API_HOST || 'localhost' + const apiHost = env.API_HOST || '6.86.80.8' const apiPort = env.API_PORT || '8000' const proxyTarget = env.VITE_API_PROXY_TARGET || `http://${apiHost}:${apiPort}`