fix 文档管理模块 & 法规对话模块

This commit is contained in:
2026-05-20 23:34:08 +08:00
parent c22b03dc07
commit b065d55c86
39 changed files with 1671 additions and 540 deletions

15
.env
View File

@@ -7,26 +7,29 @@ APP_VERSION=0.1.0
DEBUG=false DEBUG=false
# ===== Milvus向量数据库配置已有===== # ===== Milvus向量数据库配置已有=====
MILVUS_HOST=localhost MILVUS_HOST=6.86.80.8
MILVUS_PORT=19530 MILVUS_PORT=19530
MILVUS_COLLECTION=regulations_dense_1024_v1 MILVUS_COLLECTION=regulations_dense_1024_v1
MILVUS_DB_NAME=default MILVUS_DB_NAME=default
MILVUS_INDEX_TYPE=IVF_FLAT
MILVUS_NLIST=128
MILVUS_NPROBE=16
# ===== MinIO对象存储配置已有===== # ===== MinIO对象存储配置已有=====
MINIO_ENDPOINT=localhost:9000 MINIO_ENDPOINT=6.86.80.8:9000
MINIO_ACCESS_KEY=minioadmin MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin MINIO_SECRET_KEY=minioadmin
MINIO_BUCKET=compliance-docs MINIO_BUCKET=compliance-docs
MINIO_SECURE=false MINIO_SECURE=false
# ===== Redis配置已有===== # ===== Redis配置已有=====
REDIS_HOST=localhost REDIS_HOST=6.86.80.8
REDIS_PORT=6379 REDIS_PORT=6379
REDIS_PASSWORD=redis@123 REDIS_PASSWORD=redis@123
REDIS_DB=0 REDIS_DB=0
# ===== PostgreSQL配置已有===== # ===== PostgreSQL配置已有=====
POSTGRES_HOST=localhost POSTGRES_HOST=6.86.80.8
POSTGRES_PORT=5432 POSTGRES_PORT=5432
POSTGRES_USER=postgresql POSTGRES_USER=postgresql
POSTGRES_PASSWORD=postgresql123456 POSTGRES_PASSWORD=postgresql123456
@@ -43,6 +46,10 @@ EMBEDDING_TIMEOUT_SECONDS=120
CHUNK_SIZE=512 CHUNK_SIZE=512
CHUNK_OVERLAP=50 CHUNK_OVERLAP=50
MAX_FILE_SIZE_MB=100 MAX_FILE_SIZE_MB=100
PARSER_BACKEND=aliyun
CHUNK_BACKEND=aliyun
# 文档元数据存储后端json默认或 postgres
DOCUMENT_REPOSITORY_BACKEND=json
# ===== API配置 ===== # ===== API配置 =====
API_HOST=0.0.0.0 API_HOST=0.0.0.0

View File

@@ -6,6 +6,9 @@ MILVUS_HOST=6.86.80.8
MILVUS_PORT=19530 MILVUS_PORT=19530
MILVUS_COLLECTION=regulations_dense_1024_v1 MILVUS_COLLECTION=regulations_dense_1024_v1
MILVUS_DB_NAME=default MILVUS_DB_NAME=default
MILVUS_INDEX_TYPE=IVF_FLAT
MILVUS_NLIST=128
MILVUS_NPROBE=16
# ===== MinIO对象存储配置已有===== # ===== MinIO对象存储配置已有=====
MINIO_ENDPOINT=6.86.80.8:9000 MINIO_ENDPOINT=6.86.80.8:9000
@@ -26,3 +29,7 @@ POSTGRES_PORT=5432
POSTGRES_USER=postgresql POSTGRES_USER=postgresql
POSTGRES_PASSWORD=postgresql123456 POSTGRES_PASSWORD=postgresql123456
POSTGRES_DB=compliance_db POSTGRES_DB=compliance_db
# ===== 文档元数据后端 =====
# 改为 postgres 以启用 PG 持久化structure_nodes + semantic_blocks 入库)
DOCUMENT_REPOSITORY_BACKEND=json

View File

@@ -7,10 +7,13 @@ APP_VERSION=0.1.0
DEBUG=false DEBUG=false
# ===== Milvus向量数据库配置 ===== # ===== Milvus向量数据库配置 =====
MILVUS_HOST=localhost MILVUS_HOST=6.86.80.8
MILVUS_PORT=19530 MILVUS_PORT=19530
MILVUS_COLLECTION=regulations_dense_1024_v1 MILVUS_COLLECTION=regulations_dense_1024_v1
MILVUS_DB_NAME=default MILVUS_DB_NAME=default
MILVUS_INDEX_TYPE=IVF_FLAT
MILVUS_NLIST=128
MILVUS_NPROBE=16
# ===== 嵌入模型配置 ===== # ===== 嵌入模型配置 =====
EMBEDDING_MODEL=text-embedding-v3 EMBEDDING_MODEL=text-embedding-v3
@@ -20,20 +23,20 @@ EMBEDDING_BASE_URL=http://6.86.80.4:30080/v1
EMBEDDING_TIMEOUT_SECONDS=120 EMBEDDING_TIMEOUT_SECONDS=120
# ===== MinIO对象存储配置 ===== # ===== MinIO对象存储配置 =====
MINIO_ENDPOINT=localhost:9000 MINIO_ENDPOINT=6.86.80.8:9000
MINIO_ACCESS_KEY=minioadmin MINIO_ACCESS_KEY=minioadmin
MINIO_SECRET_KEY=minioadmin123 MINIO_SECRET_KEY=minioadmin123
MINIO_BUCKET=compliance-docs MINIO_BUCKET=compliance-docs
MINIO_SECURE=false MINIO_SECURE=false
# ===== Redis配置 ===== # ===== Redis配置 =====
REDIS_HOST=localhost REDIS_HOST=6.86.80.8
REDIS_PORT=6379 REDIS_PORT=6379
REDIS_PASSWORD= REDIS_PASSWORD=
REDIS_DB=0 REDIS_DB=0
# ===== PostgreSQL配置 ===== # ===== PostgreSQL配置 =====
POSTGRES_HOST=localhost POSTGRES_HOST=6.86.80.8
POSTGRES_PORT=5432 POSTGRES_PORT=5432
POSTGRES_USER=compliance POSTGRES_USER=compliance
POSTGRES_PASSWORD=compliance123 POSTGRES_PASSWORD=compliance123
@@ -46,6 +49,8 @@ MAX_FILE_SIZE_MB=100
DOCUMENT_METADATA_PATH=backend/data/documents.json DOCUMENT_METADATA_PATH=backend/data/documents.json
PARSER_BACKEND=aliyun PARSER_BACKEND=aliyun
CHUNK_BACKEND=aliyun CHUNK_BACKEND=aliyun
# 文档元数据存储后端json默认无需数据库或 postgres启用 PG 持久化)
DOCUMENT_REPOSITORY_BACKEND=json
# ===== 阿里云文档解析 ===== # ===== 阿里云文档解析 =====
ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id
@@ -88,9 +93,18 @@ DEEPSEEK_MODEL=deepseek-v4-flash
# ===== RAG配置 ===== # ===== RAG配置 =====
RAG_TOP_K=10 RAG_TOP_K=10
RAG_RETRIEVAL_TOP_K=20
RAG_MAX_CONTEXT_TOKENS=4000 RAG_MAX_CONTEXT_TOKENS=4000
RAG_SUMMARY_MAX_TOKENS=1024 RAG_SUMMARY_MAX_TOKENS=1024
# ===== Reranker配置Cross-Encoder精排默认关闭=====
# 设置 RERANKER_ENABLED=true 并配置 RERANKER_BASE_URL 以启用精排
RERANKER_ENABLED=false
RERANKER_BASE_URL=
RERANKER_MODEL=BAAI/bge-reranker-v2-m3
RERANKER_API_KEY=
RERANKER_TOP_K=5
# ===== 会话配置 ===== # ===== 会话配置 =====
SESSION_MAX_SESSIONS=100 SESSION_MAX_SESSIONS=100
SESSION_TIMEOUT_MINUTES=30 SESSION_TIMEOUT_MINUTES=30

View File

@@ -23,6 +23,8 @@ class DocumentUploadResponse(BaseModel):
num_chunks: int = Field(default=0, description="分块数量") num_chunks: int = Field(default=0, description="分块数量")
summary: str = Field(default="", description="LLM生成的文档摘要") summary: str = Field(default="", description="LLM生成的文档摘要")
summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)") summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)")
regulation_type: str = Field(default="", description="法规类型")
version: str = Field(default="", description="文档版本号")
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳") timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")

View File

@@ -14,30 +14,22 @@ from app.schemas.compliance import (
AnalyzeResponse, AnalyzeResponse,
ComplianceChatRequest, ComplianceChatRequest,
) )
from app.services.mock_data import ( from app.services.mock_data import generate_task_id, get_mock_compliance_result
generate_task_id, from app.shared.bootstrap import get_agent_conversation_service
get_mock_compliance_result,
get_mock_compliance_chat_response,
)
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/compliance", tags=["合规分析"]) router = APIRouter(prefix="/compliance", tags=["合规分析"])
# Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store: dict[str, dict] = {} tasks_store: dict[str, dict] = {}
# Store uploaded compliance files inside the local backend data directory.
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw" RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
@router.post("/analyze", response_model=AnalyzeResponse) @router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_document(file: UploadFile = File(...)): async def analyze_document(file: UploadFile = File(...)):
"""Handle analyze document.""" """Handle analyze document."""
# Keep route handlers close to their transport-layer wiring for easier auditing.
task_id = generate_task_id() task_id = generate_task_id()
# Keep route handlers close to their transport-layer wiring for easier auditing.
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True) RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}" file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
@@ -45,7 +37,6 @@ async def analyze_document(file: UploadFile = File(...)):
with file_path.open("wb") as f: with file_path.open("wb") as f:
f.write(content) f.write(content)
# Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store[task_id] = { tasks_store[task_id] = {
"task_id": task_id, "task_id": task_id,
"file_path": str(file_path), "file_path": str(file_path),
@@ -53,8 +44,6 @@ async def analyze_document(file: UploadFile = File(...)):
"result": None, "result": None,
} }
# Keep route handlers close to their transport-layer wiring for easier auditing.
# Keep route handlers close to their transport-layer wiring for easier auditing.
tasks_store[task_id]["status"] = "completed" tasks_store[task_id]["status"] = "completed"
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id) tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
@@ -65,47 +54,44 @@ async def analyze_document(file: UploadFile = File(...)):
async def get_result(task_id: str): async def get_result(task_id: str):
"""Return result.""" """Return result."""
if task_id not in tasks_store: if task_id not in tasks_store:
# Keep route handlers close to their transport-layer wiring for easier auditing.
return get_mock_compliance_result(task_id) return get_mock_compliance_result(task_id)
task = tasks_store[task_id] task = tasks_store[task_id]
if task["status"] == "processing": if task["status"] == "processing":
return {"status": "processing", "message": "分析进行中"} return {"status": "processing", "message": "分析进行中"}
return task["result"] return task["result"]
@router.post("/chat/{segment_id}") @router.post("/chat/{segment_id}")
async def compliance_chat(segment_id: int, request: ComplianceChatRequest): async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
"""Handle compliance chat.""" """Stream compliance Q&A grounded in real vector retrieval."""
# Keep route handlers close to their transport-layer wiring for easier auditing. query = request.query
intent_map = { if request.segment_context:
1: "车身结构设计", query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"
2: "动力系统配置",
3: "安全配置设计",
}
intent = intent_map.get(segment_id, "车身结构设计")
async def generate() -> AsyncGenerator[str, None]: _, event_stream = get_agent_conversation_service().stream_chat(
# Keep route handlers close to their transport-layer wiring for easier auditing. query=query,
"""Handle generate.""" top_k=5,
response = get_mock_compliance_chat_response(intent, request.query) prompt_template="compliance_qa",
# Keep route handlers close to their transport-layer wiring for easier auditing.
sentences = response.split("\n\n")
for sentence in sentences:
if sentence.strip():
chunks = sentence.split("\n")
for chunk in chunks:
if chunk.strip():
await asyncio.sleep(0.05)
yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
) )
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" async def generate() -> AsyncGenerator[str, None]:
"""Translate agent SSE events to compliance chunk/done format."""
for event in event_stream:
event_type = event.get("event", "")
if event_type == "content":
text = event.get("data", "")
if text:
yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n"
)
elif event_type == "done":
yield (
"event: message\n"
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
)
await asyncio.sleep(0)
return StreamingResponse( return StreamingResponse(
generate(), generate(),

View File

@@ -80,6 +80,8 @@ async def get_document_status(doc_id: str):
num_chunks=document.chunk_count, num_chunks=document.chunk_count,
summary=document.summary, summary=document.summary,
summary_latency_ms=document.summary_latency_ms, summary_latency_ms=document.summary_latency_ms,
regulation_type=document.regulation_type,
version=document.version,
) )
@@ -123,7 +125,7 @@ async def list_documents():
@router.get("/management-list") @router.get("/management-list")
async def get_document_management_list(): async def get_document_management_list():
"""Return document management list.""" """Return document management list."""
documents = get_document_query_service().list_documents(limit=10) documents = get_document_query_service().list_documents()
return { return {
"documents": [ "documents": [
{ {
@@ -131,10 +133,37 @@ async def get_document_management_list():
"doc_name": item.doc_name, "doc_name": item.doc_name,
"status": item.status.value, "status": item.status.value,
"chunk_count": item.chunk_count, "chunk_count": item.chunk_count,
"size_bytes": item.size_bytes,
"summary": item.summary,
"updated_at": item.updated_at.isoformat(), "updated_at": item.updated_at.isoformat(),
"regulation_type": item.regulation_type,
"version": item.version,
} }
for item in documents for item in documents
], ],
"total": len(documents), "total": len(documents),
"limit": 10,
} }
@router.delete("/{doc_id}")
async def delete_document(doc_id: str):
"""Delete a document and its associated data."""
deleted = get_document_command_service().delete(doc_id)
if not deleted:
raise HTTPException(status_code=404, detail="文档不存在")
return {"doc_id": doc_id, "deleted": True}
@router.post("/{doc_id}/retry", response_model=DocumentUploadResponse)
async def retry_document(doc_id: str):
"""Re-process a failed document."""
try:
result = get_document_command_service().retry(doc_id)
if result.status == "failed":
raise HTTPException(status_code=500, detail=result.message)
return _document_response(result)
except HTTPException:
raise
except Exception as exc:
logger.exception("文档重试失败")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -9,73 +9,74 @@ from typing import AsyncGenerator
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from app.config.settings import settings
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.services.mock_data import ( from app.shared.bootstrap import get_agent_conversation_service
get_mock_quick_questions,
get_mock_retrieval,
get_mock_rag_answer,
)
# Keep route handlers close to their transport-layer wiring for easier auditing.
router = APIRouter(prefix="/rag", tags=["RAG问答"]) router = APIRouter(prefix="/rag", tags=["RAG问答"])
_DEFAULT_QUICK_QUESTIONS = [
{"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"},
{"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"},
{"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"},
{"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"},
{"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"},
{"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"},
]
@router.post("/chat") @router.post("/chat")
async def rag_chat(request: RagChatRequest): async def rag_chat(request: RagChatRequest):
"""Handle rag chat.""" """Stream RAG Q&A using the real agent service."""
_, event_stream = get_agent_conversation_service().stream_chat(
query=request.query,
session_id=request.session_id,
filters=request.filters,
top_k=request.top_k or settings.rag_top_k,
)
async def generate() -> AsyncGenerator[str, None]: async def generate() -> AsyncGenerator[str, None]:
# Keep route handlers close to their transport-layer wiring for easier auditing. """Translate agent SSE events to rag format."""
"""Handle generate.""" for event in event_stream:
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n" event_type = event.get("event", "")
data = event.get("data", "")
# Keep route handlers close to their transport-layer wiring for easier auditing. if event_type == "sources":
await asyncio.sleep(0.3) docs = [
# Keep route handlers close to their transport-layer wiring for easier auditing.
docs = get_mock_retrieval(request.query, top_k=request.top_k)
retrieved_data = [
{ {
"id": d["id"], "id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1),
"score": d["score"], "score": s.get("score", 0),
"preview": d["preview"], "preview": s.get("content", "")[:200],
"doc_name": d.get("doc_name", ""), "doc_name": s.get("doc_name", ""),
"clause": d.get("clause", ""), "clause": s.get("section_title", "法规片段"),
"doc_id": s.get("doc_id"),
"download_url": (
f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None
),
} }
for d in docs for idx, s in enumerate(data if isinstance(data, list) else [])
] ]
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
# Keep route handlers close to their transport-layer wiring for easier auditing.
yield (
f"event: message\ndata: "
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
)
# Keep route handlers close to their transport-layer wiring for easier auditing.
await asyncio.sleep(0.2)
# Keep route handlers close to their transport-layer wiring for easier auditing.
answer = get_mock_rag_answer(request.query)
# Keep route handlers close to their transport-layer wiring for easier auditing.
sentences = answer.split("\n\n")
for sentence in sentences:
if sentence.strip():
# Keep route handlers close to their transport-layer wiring for easier auditing.
chunks = sentence.split("\n")
for chunk in chunks:
if chunk.strip():
await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
yield ( yield (
"event: message\n" "event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n" f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n"
) )
elif event_type == "content":
# Keep route handlers close to their transport-layer wiring for easier auditing. if data:
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" yield (
"event: message\n"
f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n"
)
elif event_type == "done":
yield (
"event: message\n"
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
)
elif event_type == "status":
yield (
"event: message\n"
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
)
await asyncio.sleep(0)
return StreamingResponse( return StreamingResponse(
generate(), generate(),
@@ -86,9 +87,13 @@ async def rag_chat(request: RagChatRequest):
@router.get("/quick-questions", response_model=QuickQuestionsResponse) @router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions(): async def get_quick_questions():
"""Return quick questions.""" """Return configurable quick questions from settings or defaults."""
raw = getattr(settings, "rag_quick_questions", None)
if raw and isinstance(raw, list):
questions = [ questions = [
QuickQuestion(id=q["id"], question=q["question"], category=q["category"]) QuickQuestion(id=str(i + 1), question=q if isinstance(q, str) else q.get("question", ""), category=q.get("category", "法规问答") if isinstance(q, dict) else "法规问答")
for q in get_mock_quick_questions() for i, q in enumerate(raw)
] ]
else:
questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS]
return QuickQuestionsResponse(questions=questions) return QuickQuestionsResponse(questions=questions)

View File

@@ -17,6 +17,7 @@ from app.domain.documents import (
DocumentParser, DocumentParser,
DocumentRepository, DocumentRepository,
DocumentStatus, DocumentStatus,
ParseArtifactStore,
ParsedDocument, ParsedDocument,
) )
from app.domain.retrieval import EmbeddingProvider, VectorIndex from app.domain.retrieval import EmbeddingProvider, VectorIndex
@@ -47,6 +48,7 @@ class DocumentCommandService:
chunk_builder: ChunkBuilder, chunk_builder: ChunkBuilder,
embedding_provider: EmbeddingProvider, embedding_provider: EmbeddingProvider,
vector_index: VectorIndex, vector_index: VectorIndex,
parse_artifact_store: ParseArtifactStore | None = None,
) -> None: ) -> None:
"""Initialize the Document Command Service instance.""" """Initialize the Document Command Service instance."""
self.document_repository = document_repository self.document_repository = document_repository
@@ -55,6 +57,7 @@ class DocumentCommandService:
self.chunk_builder = chunk_builder self.chunk_builder = chunk_builder
self.embedding_provider = embedding_provider self.embedding_provider = embedding_provider
self.vector_index = vector_index self.vector_index = vector_index
self.parse_artifact_store = parse_artifact_store
def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]: def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]:
"""Persist parse artifacts so troubleshooting does not depend on provider retention windows.""" """Persist parse artifacts so troubleshooting does not depend on provider retention windows."""
@@ -143,6 +146,15 @@ class DocumentCommandService:
"processing_stage": "parsed", "processing_stage": "parsed",
}, },
) )
if self.parse_artifact_store:
try:
self.parse_artifact_store.save(
doc_id,
parsed_document.structure_nodes,
parsed_document.semantic_blocks,
)
except Exception:
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
chunks = self.chunk_builder.build( chunks = self.chunk_builder.build(
parsed_document=parsed_document, parsed_document=parsed_document,
@@ -205,20 +217,120 @@ class DocumentCommandService:
logger.warning("临时文件清理失败: {}", temp_path) logger.warning("临时文件清理失败: {}", temp_path)
def delete(self, doc_id: str) -> bool:
"""Delete document record, binary file, and vector chunks."""
document = self.document_repository.get(doc_id)
if not document:
return False
try:
self.binary_store.delete(document.object_name)
except Exception:
logger.warning("Binary delete failed for doc_id={}", doc_id)
try:
self.vector_index.delete_by_document(doc_id)
except Exception:
logger.warning("Vector delete failed for doc_id={}", doc_id)
if self.parse_artifact_store:
try:
self.parse_artifact_store.delete(doc_id)
except Exception:
logger.warning("ParseArtifactStore delete failed for doc_id={}", doc_id)
self.document_repository.delete(doc_id)
return True
def retry(self, doc_id: str) -> DocumentProcessResult:
"""Re-process a failed document from its stored binary."""
document = self.document_repository.get(doc_id)
if not document:
return DocumentProcessResult(doc_id=doc_id, doc_name="", status="failed", message="文档不存在")
content = self.binary_store.read(document.object_name)
return self.upload_and_process(
doc_id=doc_id,
file_name=document.file_name,
content=content,
content_type=document.content_type,
doc_name=document.doc_name,
regulation_type=document.regulation_type,
version=document.version,
generate_summary=bool(document.metadata.get("generate_summary", False)),
)
class DocumentQueryService: class DocumentQueryService:
"""Provide the Document Query Service service.""" """Provide the Document Query Service service."""
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore) -> None: def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore, vector_index: VectorIndex) -> None:
"""Initialize the Document Query Service instance.""" """Initialize the Document Query Service instance."""
self.document_repository = document_repository self.document_repository = document_repository
self.binary_store = binary_store self.binary_store = binary_store
self.vector_index = vector_index
def get(self, doc_id: str) -> Document | None: def get(self, doc_id: str) -> Document | None:
"""Handle get for the Document Query Service instance.""" """Handle get for the Document Query Service instance."""
return self.document_repository.get(doc_id) return self.document_repository.get(doc_id)
def list_documents(self, limit: int | None = None) -> list[Document]: def list_documents(self, limit: int | None = None) -> list[Document]:
"""List documents for the Document Query Service instance.""" """Return documents with real-time state from Milvus as the authoritative source.
return self.document_repository.list(limit=limit)
Algorithm:
1. Query Milvus for all doc metadata (doc_id, doc_name, chunk_count, …).
2. Load JSON/PG metadata records and index them by doc_id.
3. Merge: Milvus-present docs get status=INDEXED and live chunk_count;
metadata-only docs with status=INDEXED are demoted to FAILED.
4. Milvus-only docs (no metadata record) are surfaced as synthetic INDEXED
entries so they are never invisible to the management list.
"""
# Fetch live Milvus state first.
try:
milvus_rows = self.vector_index.list_document_metadata()
except Exception:
milvus_rows = []
milvus_by_id: dict[str, dict] = {r["doc_id"]: r for r in milvus_rows}
# Load metadata store records.
meta_docs = self.document_repository.list(limit=limit)
meta_by_id: dict[str, Document] = {d.doc_id: d for d in meta_docs}
result: list[Document] = []
# Reconcile metadata records against Milvus.
for doc in meta_docs:
if doc.doc_id in milvus_by_id:
row = milvus_by_id[doc.doc_id]
doc.chunk_count = row["chunk_count"]
doc.status = DocumentStatus.INDEXED
# Backfill fields that may be missing from older JSON records.
if not doc.doc_name and row.get("doc_name"):
doc.doc_name = row["doc_name"]
if not doc.regulation_type and row.get("regulation_type"):
doc.regulation_type = row["regulation_type"]
if not doc.version and row.get("version"):
doc.version = row["version"]
elif doc.status == DocumentStatus.INDEXED:
# Metadata says indexed but Milvus has no chunks.
doc.status = DocumentStatus.FAILED
doc.error_message = "向量数据库中未找到对应数据"
result.append(doc)
# Surface Milvus-only docs that have no metadata record at all.
for doc_id, row in milvus_by_id.items():
if doc_id not in meta_by_id:
synthetic = Document(
doc_id=doc_id,
doc_name=row.get("doc_name", doc_id),
file_name=row.get("doc_name", doc_id),
object_name="",
content_type="",
size_bytes=0,
status=DocumentStatus.INDEXED,
regulation_type=row.get("regulation_type", ""),
version=row.get("version", ""),
chunk_count=row["chunk_count"],
)
result.append(synthetic)
result.sort(key=lambda d: d.updated_at, reverse=True)
return result[:limit] if limit is not None else result
def download(self, doc_id: str) -> tuple[Document, bytes]: def download(self, doc_id: str) -> tuple[Document, bytes]:
"""Handle download for the Document Query Service instance.""" """Handle download for the Document Query Service instance."""

View File

@@ -3,17 +3,24 @@
from __future__ import annotations from __future__ import annotations
from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk
from app.domain.retrieval.ports import Reranker
# Keep orchestration logic centralized so use-case flow stays easy to trace. # Keep orchestration logic centralized so use-case flow stays easy to trace.
class KnowledgeRetrievalService: class KnowledgeRetrievalService:
"""Provide the Knowledge Retrieval Service service.""" """Provide the Knowledge Retrieval Service service."""
def __init__(self, *, retriever: Retriever) -> None:
def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None:
"""Initialize the Knowledge Retrieval Service instance.""" """Initialize the Knowledge Retrieval Service instance."""
self.retriever = retriever self.retriever = retriever
self.reranker = reranker
self.reranker_top_k = reranker_top_k
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]: def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle retrieve for the Knowledge Retrieval Service instance.""" """Retrieve and optionally rerank chunks for a query."""
retrieval_query = RetrievalQuery(query=query, top_k=top_k, filters=filters) candidate_k = top_k if self.reranker is None else max(top_k * 4, 20)
return self.retriever.retrieve(retrieval_query) retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters)
candidates = self.retriever.retrieve(retrieval_query)
if self.reranker and candidates:
return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k)
return candidates[:top_k]

View File

@@ -31,7 +31,7 @@ class Settings(BaseSettings):
debug: bool = Field(default=False, description="调试模式") debug: bool = Field(default=False, description="调试模式")
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
milvus_host: str = Field(default="localhost", description="Milvus服务地址") milvus_host: str = Field(default="6.86.80.8", description="Milvus服务地址")
milvus_port: int = Field(default=19530, description="Milvus服务端口") milvus_port: int = Field(default=19530, description="Milvus服务端口")
milvus_collection: str = Field(default="regulations_dense_1024_v1", description="法规向量集合名称") milvus_collection: str = Field(default="regulations_dense_1024_v1", description="法规向量集合名称")
milvus_db_name: str = Field(default="default", description="Milvus数据库名称") milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
@@ -54,20 +54,20 @@ class Settings(BaseSettings):
parser_failure_mode: str = Field(default="fail", description="解析失败策略") parser_failure_mode: str = Field(default="fail", description="解析失败策略")
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址") minio_endpoint: str = Field(default="6.86.80.8:9000", description="MinIO服务地址")
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥") minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥") minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
minio_bucket: str = Field(default="upload-files", description="文档存储桶名称") minio_bucket: str = Field(default="upload-files", description="文档存储桶名称")
minio_secure: bool = Field(default=False, description="是否使用HTTPS") minio_secure: bool = Field(default=False, description="是否使用HTTPS")
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
redis_host: str = Field(default="localhost", description="Redis服务地址") redis_host: str = Field(default="6.86.80.8", description="Redis服务地址")
redis_port: int = Field(default=6379, description="Redis服务端口") redis_port: int = Field(default=6379, description="Redis服务端口")
redis_password: str = Field(default="", description="Redis密码") redis_password: str = Field(default="", description="Redis密码")
redis_db: int = Field(default=0, description="Redis数据库编号") redis_db: int = Field(default=0, description="Redis数据库编号")
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址") postgres_host: str = Field(default="6.86.80.8", description="PostgreSQL服务地址")
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口") postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名") postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码") postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
@@ -80,6 +80,7 @@ class Settings(BaseSettings):
document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径") document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径")
parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)") parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)")
chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)") chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)")
document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)")
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
api_host: str = Field(default="0.0.0.0", description="API服务地址") api_host: str = Field(default="0.0.0.0", description="API服务地址")
@@ -104,9 +105,16 @@ class Settings(BaseSettings):
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
rag_top_k: int = Field(default=5, description="检索召回数量") rag_top_k: int = Field(default=5, description="检索召回数量")
rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量reranker 启用时生效)")
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数") rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数") rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排")
reranker_base_url: str = Field(default="", description="Reranker API 地址")
reranker_model: str = Field(default="BAAI/bge-reranker-v2-m3", description="Reranker 模型名称")
reranker_api_key: str = Field(default="", description="Reranker API 密钥")
reranker_top_k: int = Field(default=5, description="精排后保留的最终结果数量")
# Keep configuration setup explicit so runtime behavior is easy to reason about. # Keep configuration setup explicit so runtime behavior is easy to reason about.
milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型") milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型")
milvus_nlist: int = Field(default=128, description="Milvus nlist参数") milvus_nlist: int = Field(default=128, description="Milvus nlist参数")

View File

@@ -25,7 +25,7 @@ class Settings(BaseSettings):
dashscope_api_key: str = "" dashscope_api_key: str = ""
# Milvus # Milvus
milvus_host: str = "localhost" milvus_host: str = "6.86.80.8"
milvus_port: int = 19530 milvus_port: int = 19530
milvus_collection: str = "regulations_dense_1024_v1" milvus_collection: str = "regulations_dense_1024_v1"

View File

@@ -1,7 +1,7 @@
"""Initialize the app.domain.documents package.""" """Initialize the app.domain.documents package."""
from .models import Chunk, Document, DocumentStatus, ParsedDocument from .models import Chunk, Document, DocumentStatus, ParsedDocument
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository, ParseArtifactStore
# Keep package boundaries explicit so backend imports stay predictable. # Keep package boundaries explicit so backend imports stay predictable.
@@ -14,4 +14,5 @@ __all__ = [
"DocumentBinaryStore", "DocumentBinaryStore",
"DocumentParser", "DocumentParser",
"DocumentRepository", "DocumentRepository",
"ParseArtifactStore",
] ]

View File

@@ -31,6 +31,11 @@ class DocumentRepository(ABC):
"""Handle list for the Document Repository instance.""" """Handle list for the Document Repository instance."""
pass pass
@abstractmethod
def delete(self, doc_id: str) -> bool:
"""Delete a document record. Returns True if deleted, False if not found."""
pass
@abstractmethod @abstractmethod
def update_status( def update_status(
self, self,
@@ -94,3 +99,32 @@ class ChunkBuilder(ABC):
) -> list[Chunk]: ) -> list[Chunk]:
"""Handle build for the Chunk Builder instance.""" """Handle build for the Chunk Builder instance."""
pass pass
class ParseArtifactStore(ABC):
"""Persist parse artifacts (structure nodes and semantic blocks) for relational queries."""
@abstractmethod
def save(
self,
doc_id: str,
structure_nodes: list[dict],
semantic_blocks: list[dict],
) -> None:
"""Persist structure nodes and semantic blocks for a document."""
pass
@abstractmethod
def delete(self, doc_id: str) -> None:
"""Remove all parse artifacts for a document."""
pass
@abstractmethod
def get_semantic_blocks(self, doc_id: str) -> list[dict]:
"""Return all semantic blocks for a document."""
pass
@abstractmethod
def get_structure_nodes(self, doc_id: str) -> list[dict]:
"""Return all structure nodes for a document."""
pass

View File

@@ -1,8 +1,8 @@
"""Initialize the app.domain.retrieval package.""" """Initialize the app.domain.retrieval package."""
from .models import RetrievalQuery, RetrievedChunk from .models import RetrievalQuery, RetrievedChunk
from .ports import EmbeddingProvider, Retriever, VectorIndex from .ports import EmbeddingProvider, Reranker, Retriever, VectorIndex
# Keep package boundaries explicit so backend imports stay predictable. # Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Retriever", "VectorIndex"] __all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Reranker", "Retriever", "VectorIndex"]

View File

@@ -10,7 +10,6 @@ from .models import RetrievalQuery, RetrievedChunk
# Keep domain contracts explicit so adapters can swap implementations cleanly. # Keep domain contracts explicit so adapters can swap implementations cleanly.
class EmbeddingProvider(ABC): class EmbeddingProvider(ABC):
"""Provide the Embedding Provider provider.""" """Provide the Embedding Provider provider."""
@abstractmethod @abstractmethod
@@ -41,12 +40,35 @@ class VectorIndex(ABC):
"""Handle search for the Vector Index instance.""" """Handle search for the Vector Index instance."""
pass pass
@abstractmethod
def count_by_document(self) -> dict[str, int]:
"""Return a mapping of doc_id -> chunk count from the vector store."""
pass
@abstractmethod
def list_document_metadata(self) -> list[dict]:
"""Return per-document metadata rows from the vector store.
Each row contains at minimum: doc_id, doc_name, chunk_count.
Optional fields: regulation_type, version.
"""
pass
@abstractmethod @abstractmethod
def health(self) -> dict: def health(self) -> dict:
"""Handle health for the Vector Index instance.""" """Handle health for the Vector Index instance."""
pass pass
class Reranker(ABC):
"""Re-score and re-order a candidate list using a cross-encoder model."""
@abstractmethod
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
"""Return top_k chunks sorted by cross-encoder score (descending)."""
pass
class Retriever(ABC): class Retriever(ABC):
"""Provide the Retriever retriever.""" """Provide the Retriever retriever."""
@abstractmethod @abstractmethod

View File

@@ -289,7 +289,7 @@ def build_vector_chunks(
{ {
"doc_id": doc_id, "doc_id": doc_id,
"doc_title": doc_title, "doc_title": doc_title,
"chunk_id": f"chunk-{chunk_index}", "chunk_id": f"{doc_id}-chunk-{chunk_index}",
"chunk_index": chunk_index, "chunk_index": chunk_index,
"semantic_id": block["semantic_id"], "semantic_id": block["semantic_id"],
"chunk_type": block["block_type"], "chunk_type": block["block_type"],

View File

@@ -75,6 +75,15 @@ class JsonDocumentRepository(DocumentRepository):
documents.sort(key=lambda item: item.updated_at, reverse=True) documents.sort(key=lambda item: item.updated_at, reverse=True)
return documents[:limit] if limit is not None else documents return documents[:limit] if limit is not None else documents
def delete(self, doc_id: str) -> bool:
"""Delete a document record."""
payload = self._load()
if doc_id not in payload:
return False
del payload[doc_id]
self._save(payload)
return True
def update_status( def update_status(
self, self,
doc_id: str, doc_id: str,

View File

@@ -0,0 +1,228 @@
"""Implement infrastructure support for postgres document repository."""
from __future__ import annotations
import json
from contextlib import contextmanager
from datetime import UTC, datetime
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.domain.documents import Document, DocumentRepository, DocumentStatus
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS documents (
doc_id VARCHAR(128) PRIMARY KEY,
doc_name VARCHAR(512) NOT NULL DEFAULT '',
file_name VARCHAR(512) NOT NULL DEFAULT '',
object_name VARCHAR(1024) NOT NULL DEFAULT '',
content_type VARCHAR(128) NOT NULL DEFAULT '',
size_bytes BIGINT NOT NULL DEFAULT 0,
status VARCHAR(32) NOT NULL DEFAULT 'pending',
regulation_type VARCHAR(128) NOT NULL DEFAULT '',
version VARCHAR(64) NOT NULL DEFAULT '',
summary TEXT NOT NULL DEFAULT '',
summary_latency_ms INTEGER NOT NULL DEFAULT 0,
chunk_count INTEGER NOT NULL DEFAULT 0,
parser_name VARCHAR(128) NOT NULL DEFAULT '',
index_name VARCHAR(128) NOT NULL DEFAULT '',
error_message TEXT NOT NULL DEFAULT '',
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
_COLUMNS = (
"doc_id", "doc_name", "file_name", "object_name", "content_type",
"size_bytes", "status", "regulation_type", "version", "summary",
"summary_latency_ms", "chunk_count", "parser_name", "index_name",
"error_message", "metadata", "created_at", "updated_at",
)
class PostgresDocumentRepository(DocumentRepository):
"""DocumentRepository implementation backed by PostgreSQL."""
def __init__(self) -> None:
self._pool = ThreadedConnectionPool(
minconn=1,
maxconn=5,
host=settings.postgres_host,
port=settings.postgres_port,
user=settings.postgres_user,
password=settings.postgres_password,
dbname=settings.postgres_db,
)
self._ensure_schema()
def _ensure_schema(self) -> None:
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(_CREATE_TABLE)
conn.commit()
@contextmanager
def _conn(self):
conn = self._pool.getconn()
try:
yield conn
finally:
self._pool.putconn(conn)
# ------------------------------------------------------------------
# Serialization helpers
# ------------------------------------------------------------------
def _row_to_document(self, row: dict[str, Any]) -> Document:
return Document(
doc_id=row["doc_id"],
doc_name=row["doc_name"],
file_name=row["file_name"],
object_name=row["object_name"],
content_type=row["content_type"],
size_bytes=row["size_bytes"],
status=DocumentStatus(row["status"]),
regulation_type=row["regulation_type"],
version=row["version"],
summary=row["summary"],
summary_latency_ms=row["summary_latency_ms"],
chunk_count=row["chunk_count"],
parser_name=row["parser_name"],
index_name=row["index_name"],
error_message=row["error_message"],
metadata=row["metadata"] if isinstance(row["metadata"], dict) else json.loads(row["metadata"] or "{}"),
created_at=row["created_at"],
updated_at=row["updated_at"],
)
# ------------------------------------------------------------------
# DocumentRepository interface
# ------------------------------------------------------------------
def create(self, document: Document) -> Document:
sql = """
INSERT INTO documents
(doc_id, doc_name, file_name, object_name, content_type, size_bytes,
status, regulation_type, version, summary, summary_latency_ms,
chunk_count, parser_name, index_name, error_message, metadata,
created_at, updated_at)
VALUES
(%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s,
%(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s,
%(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s,
%(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s)
ON CONFLICT (doc_id) DO NOTHING
"""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, self._to_params(document))
conn.commit()
return document
def update(self, document: Document) -> Document:
document.updated_at = datetime.now(UTC)
sql = """
UPDATE documents SET
doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s,
content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s,
regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s,
summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s,
parser_name=%(parser_name)s, index_name=%(index_name)s,
error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s
WHERE doc_id=%(doc_id)s
"""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, self._to_params(document))
conn.commit()
return document
def get(self, doc_id: str) -> Document | None:
sql = "SELECT * FROM documents WHERE doc_id = %s"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (doc_id,))
row = cur.fetchone()
return self._row_to_document(dict(row)) if row else None
def list(self, limit: int | None = None) -> list[Document]:
sql = "SELECT * FROM documents ORDER BY updated_at DESC"
if limit is not None:
sql += f" LIMIT {int(limit)}"
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql)
rows = cur.fetchall()
return [self._row_to_document(dict(r)) for r in rows]
def delete(self, doc_id: str) -> bool:
sql = "DELETE FROM documents WHERE doc_id = %s"
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, (doc_id,))
deleted = cur.rowcount > 0
conn.commit()
return deleted
def update_status(
self,
doc_id: str,
status: DocumentStatus,
*,
error_message: str = "",
chunk_count: int | None = None,
summary: str | None = None,
summary_latency_ms: int | None = None,
parser_name: str | None = None,
index_name: str | None = None,
metadata: dict | None = None,
) -> Document | None:
document = self.get(doc_id)
if not document:
return None
document.status = status
document.error_message = error_message
if chunk_count is not None:
document.chunk_count = chunk_count
if summary is not None:
document.summary = summary
if summary_latency_ms is not None:
document.summary_latency_ms = summary_latency_ms
if parser_name is not None:
document.parser_name = parser_name
if index_name is not None:
document.index_name = index_name
if metadata:
document.metadata.update(metadata)
return self.update(document)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _to_params(self, document: Document) -> dict[str, Any]:
return {
"doc_id": document.doc_id,
"doc_name": document.doc_name,
"file_name": document.file_name,
"object_name": document.object_name,
"content_type": document.content_type,
"size_bytes": document.size_bytes,
"status": document.status.value,
"regulation_type": document.regulation_type,
"version": document.version,
"summary": document.summary,
"summary_latency_ms": document.summary_latency_ms,
"chunk_count": document.chunk_count,
"parser_name": document.parser_name,
"index_name": document.index_name,
"error_message": document.error_message,
"metadata": json.dumps(document.metadata, ensure_ascii=False),
"created_at": document.created_at,
"updated_at": document.updated_at,
}

View File

@@ -0,0 +1,196 @@
"""Implement infrastructure support for postgres parse artifact store."""
from __future__ import annotations
import json
from contextlib import contextmanager
from typing import Any
import psycopg2
import psycopg2.extras
from psycopg2.pool import ThreadedConnectionPool
from app.config.settings import settings
from app.domain.documents import ParseArtifactStore
_CREATE_STRUCTURE_NODES = """
CREATE TABLE IF NOT EXISTS structure_nodes (
id SERIAL PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
unique_id VARCHAR(128),
page INTEGER NOT NULL DEFAULT 0,
idx INTEGER NOT NULL DEFAULT 0,
level INTEGER NOT NULL DEFAULT 0,
title TEXT NOT NULL DEFAULT '',
type VARCHAR(64),
sub_type VARCHAR(64),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_sn_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_structure_nodes_doc_id ON structure_nodes(doc_id);
"""
_CREATE_SEMANTIC_BLOCKS = """
CREATE TABLE IF NOT EXISTS semantic_blocks (
id SERIAL PRIMARY KEY,
doc_id VARCHAR(128) NOT NULL,
semantic_id VARCHAR(128) NOT NULL,
block_type VARCHAR(64) NOT NULL DEFAULT '',
page_start INTEGER NOT NULL DEFAULT 0,
page_end INTEGER NOT NULL DEFAULT 0,
section_path JSONB NOT NULL DEFAULT '[]',
section_level INTEGER NOT NULL DEFAULT 0,
section_title VARCHAR(512) NOT NULL DEFAULT '',
source_ids JSONB NOT NULL DEFAULT '[]',
text TEXT NOT NULL DEFAULT '',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT fk_sb_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
CONSTRAINT uq_semantic_blocks UNIQUE (doc_id, semantic_id)
);
CREATE INDEX IF NOT EXISTS idx_semantic_blocks_doc_id ON semantic_blocks(doc_id);
"""
class PostgresParseArtifactStore(ParseArtifactStore):
"""ParseArtifactStore implementation backed by PostgreSQL.
Requires the `documents` table to exist first (created by PostgresDocumentRepository).
"""
def __init__(self) -> None:
self._pool = ThreadedConnectionPool(
minconn=1,
maxconn=5,
host=settings.postgres_host,
port=settings.postgres_port,
user=settings.postgres_user,
password=settings.postgres_password,
dbname=settings.postgres_db,
)
self._ensure_schema()
def _ensure_schema(self) -> None:
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute(_CREATE_STRUCTURE_NODES)
cur.execute(_CREATE_SEMANTIC_BLOCKS)
conn.commit()
@contextmanager
def _conn(self):
conn = self._pool.getconn()
try:
yield conn
finally:
self._pool.putconn(conn)
# ------------------------------------------------------------------
# ParseArtifactStore interface
# ------------------------------------------------------------------
def save(
self,
doc_id: str,
structure_nodes: list[dict],
semantic_blocks: list[dict],
) -> None:
"""Persist structure nodes and semantic blocks, replacing any existing records."""
with self._conn() as conn:
with conn.cursor() as cur:
# Delete existing records first to keep save idempotent.
cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
if structure_nodes:
psycopg2.extras.execute_values(
cur,
"""
INSERT INTO structure_nodes
(doc_id, unique_id, page, idx, level, title, type, sub_type)
VALUES %s
""",
[
(
doc_id,
node.get("unique_id"),
int(node.get("page", 0) or 0),
int(node.get("index", 0) or 0),
int(node.get("level", 0) or 0),
str(node.get("title", "")),
node.get("type"),
node.get("sub_type"),
)
for node in structure_nodes
],
)
if semantic_blocks:
psycopg2.extras.execute_values(
cur,
"""
INSERT INTO semantic_blocks
(doc_id, semantic_id, block_type, page_start, page_end,
section_path, section_level, section_title, source_ids, text)
VALUES %s
""",
[
(
doc_id,
block.get("semantic_id", ""),
block.get("block_type", ""),
int(block.get("page_start", 0) or 0),
int(block.get("page_end", 0) or 0),
json.dumps(block.get("section_path", []), ensure_ascii=False),
int(block.get("section_level", 0) or 0),
str(block.get("section_title", "")),
json.dumps(block.get("source_ids", []), ensure_ascii=False),
str(block.get("text", "")),
)
for block in semantic_blocks
],
)
conn.commit()
def delete(self, doc_id: str) -> None:
"""Remove all parse artifacts for a document (ON DELETE CASCADE handles child rows)."""
with self._conn() as conn:
with conn.cursor() as cur:
cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
conn.commit()
def get_semantic_blocks(self, doc_id: str) -> list[dict[str, Any]]:
"""Return all semantic blocks for a document ordered by id."""
sql = """
SELECT semantic_id, block_type, page_start, page_end,
section_path, section_level, section_title, source_ids, text
FROM semantic_blocks
WHERE doc_id = %s
ORDER BY id
"""
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (doc_id,))
rows = cur.fetchall()
results = []
for row in rows:
item = dict(row)
for key in ("section_path", "source_ids"):
if isinstance(item[key], str):
item[key] = json.loads(item[key])
results.append(item)
return results
def get_structure_nodes(self, doc_id: str) -> list[dict[str, Any]]:
"""Return all structure nodes for a document ordered by idx."""
sql = """
SELECT unique_id, page, idx, level, title, type, sub_type
FROM structure_nodes
WHERE doc_id = %s
ORDER BY idx
"""
with self._conn() as conn:
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
cur.execute(sql, (doc_id,))
rows = cur.fetchall()
return [dict(r) for r in rows]

View File

@@ -0,0 +1,84 @@
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
from __future__ import annotations
import time
import requests
from loguru import logger
from app.config.settings import settings
from app.domain.retrieval import Reranker, RetrievedChunk
class OpenAICompatibleReranker(Reranker):
"""Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""
def __init__(
self,
base_url: str | None = None,
model: str | None = None,
api_key: str | None = None,
timeout: int = 30,
) -> None:
self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
self._model = model or settings.reranker_model
self._api_key = api_key or settings.reranker_api_key
self._timeout = timeout
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
"""Return up to top_k chunks re-sorted by cross-encoder score."""
if not chunks:
return []
texts = [chunk.content for chunk in chunks]
start = time.time()
try:
scores = self._call_reranker(query, texts)
except Exception as exc:
logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
return chunks[:top_k]
elapsed_ms = int((time.time() - start) * 1000)
logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)
ranked = sorted(
[(score, chunk) for score, chunk in zip(scores, chunks)],
key=lambda x: x[0],
reverse=True,
)
result = []
for score, chunk in ranked[:top_k]:
chunk.score = float(score)
result.append(chunk)
return result
def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
"""Call the reranker API and return a score per text."""
headers = {"Content-Type": "application/json"}
if self._api_key:
headers["Authorization"] = f"Bearer {self._api_key}"
# Try TEI format first: POST /rerank
payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
url = f"{self._base_url}/rerank"
resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
if resp.status_code == 404:
# Fall back to Cohere / OpenAI-style: POST /v1/rerank
payload_v1 = {"model": self._model, "query": query, "documents": texts}
url = f"{self._base_url}/v1/rerank"
resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
resp.raise_for_status()
data = resp.json()
# TEI response: list of {"index": N, "score": F}
if isinstance(data, list):
ordered = sorted(data, key=lambda x: x["index"])
return [float(item["score"]) for item in ordered]
# Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
results = data.get("results", [])
ordered = sorted(results, key=lambda x: x["index"])
return [float(item.get("relevance_score", item.get("score", 0))) for item in ordered]

View File

@@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex):
result = self.collection.delete(f'doc_id == "{doc_id}"') result = self.collection.delete(f'doc_id == "{doc_id}"')
return len(result.primary_keys) return len(result.primary_keys)
def _parse_filters(self, filters: str | None) -> str | None:
"""Parse filter string into Milvus expression."""
if not filters or not filters.strip():
return None
filters = filters.strip()
# Check if already a Milvus expression (contains operators)
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
return filters
# Parse simple regulation_type filter
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
types = [t.strip() for t in filters.split(",") if t.strip()]
if not types:
return None
if len(types) == 1:
# Single value: regulation_type == "GB"
return f'regulation_type == "{types[0]}"'
else:
# Multiple values: regulation_type in ["GB", "UN-ECE"]
quoted_types = [f'"{t}"' for t in types]
return f'regulation_type in [{", ".join(quoted_types)}]'
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]: def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
"""Handle search for the Milvus Vector Index instance.""" """Handle search for the Milvus Vector Index instance."""
milvus_expr = self._parse_filters(filters)
results = self.collection.search( results = self.collection.search(
data=[query_vector], data=[query_vector],
anns_field="embedding", anns_field="embedding",
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}}, param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
limit=top_k, limit=top_k,
filter=filters, expr=milvus_expr,
output_fields=[ output_fields=[
"doc_id", "doc_id",
"doc_name", "doc_name",
@@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex):
) )
return payload return payload
def count_by_document(self) -> dict[str, int]:
"""Return doc_id -> chunk count from Milvus."""
try:
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
except Exception:
return {}
counts: dict[str, int] = {}
for row in rows:
doc_id = row.get("doc_id", "")
if doc_id:
counts[doc_id] = counts.get(doc_id, 0) + 1
return counts
def list_document_metadata(self) -> list[dict]:
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
try:
rows = self.collection.query(
expr="doc_id != \"\"",
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
)
except Exception:
return []
seen: dict[str, dict] = {}
counts: dict[str, int] = {}
for row in rows:
doc_id = row.get("doc_id", "")
if not doc_id:
continue
counts[doc_id] = counts.get(doc_id, 0) + 1
if doc_id not in seen:
seen[doc_id] = {
"doc_id": doc_id,
"doc_name": row.get("doc_name", ""),
"regulation_type": row.get("regulation_type", ""),
"version": row.get("version", ""),
}
return [
{**meta, "chunk_count": counts[meta["doc_id"]]}
for meta in seen.values()
]
def health(self) -> dict: def health(self) -> dict:
"""Handle health for the Milvus Vector Index instance.""" """Handle health for the Milvus Vector Index instance."""
return { return {

View File

@@ -74,6 +74,7 @@ class ComplianceResult(BaseModel):
class ComplianceChatRequest(BaseModel): class ComplianceChatRequest(BaseModel):
"""Define the Compliance Chat Request API model.""" """Define the Compliance Chat Request API model."""
query: str query: str
segment_context: Optional[str] = None
class AnalyzeResponse(BaseModel): class AnalyzeResponse(BaseModel):

View File

@@ -10,6 +10,8 @@ class RagChatRequest(BaseModel):
"""Define the Rag Chat Request API model.""" """Define the Rag Chat Request API model."""
query: str query: str
top_k: int = 5 top_k: int = 5
session_id: Optional[str] = None
filters: Optional[str] = None
class RetrievedDoc(BaseModel): class RetrievedDoc(BaseModel):

View File

@@ -95,6 +95,57 @@ class DeepSeekClient(BaseLLMClient):
error=str(e) error=str(e)
) )
def stream_chat(
self,
messages: List[Dict[str, str]],
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
**kwargs
):
"""Stream chat for the Deep Seek Client instance."""
try:
payload = {
"model": self.config.model,
"messages": messages,
"max_tokens": max_tokens or self.config.max_tokens,
"temperature": temperature or self.config.temperature,
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": True
}
with self._client.stream("POST", "/chat/completions", json=payload) as response:
response.raise_for_status()
for line in response.iter_lines():
if not line or line.startswith(b":"):
continue
line_str = line.decode("utf-8").strip()
if not line_str.startswith("data: "):
continue
data_str = line_str[6:]
if data_str == "[DONE]":
break
try:
import json
data = json.loads(data_str)
choices = data.get("choices", [])
if choices:
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
yield content
except json.JSONDecodeError:
continue
except httpx.HTTPStatusError as e:
logger.error(f"DeepSeek Stream API错误: {e.response.status_code}")
yield ""
except Exception as e:
logger.error(f"DeepSeek Stream调用失败: {e}")
yield ""
def get_available_models(self) -> List[str]: def get_available_models(self) -> List[str]:
"""Return available models for the Deep Seek Client instance.""" """Return available models for the Deep Seek Client instance."""
return self.SUPPORTED_MODELS return self.SUPPORTED_MODELS

View File

@@ -17,18 +17,31 @@ from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuil
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
# Keep shared wiring centralized so dependency construction remains consistent. # Keep shared wiring centralized so dependency construction remains consistent.
@lru_cache @lru_cache
def get_document_repository() -> JsonDocumentRepository: def get_document_repository():
"""Return document repository.""" """Return document repository (json or postgres, controlled by settings)."""
if settings.document_repository_backend == "postgres":
return PostgresDocumentRepository()
return JsonDocumentRepository(settings.document_metadata_path) return JsonDocumentRepository(settings.document_metadata_path)
@lru_cache
def get_parse_artifact_store():
"""Return parse artifact store, or None when postgres backend is not enabled."""
if settings.document_repository_backend == "postgres":
return PostgresParseArtifactStore()
return None
@lru_cache @lru_cache
def get_binary_store() -> MinioDocumentBinaryStore: def get_binary_store() -> MinioDocumentBinaryStore:
"""Return binary store.""" """Return binary store."""
@@ -66,6 +79,14 @@ def get_vector_index() -> MilvusVectorIndex:
return MilvusVectorIndex() return MilvusVectorIndex()
@lru_cache
def get_reranker():
"""Return reranker if enabled, else None."""
if settings.reranker_enabled and settings.reranker_base_url:
return OpenAICompatibleReranker()
return None
@lru_cache @lru_cache
def get_retrieval_service() -> KnowledgeRetrievalService: def get_retrieval_service() -> KnowledgeRetrievalService:
"""Return retrieval service.""" """Return retrieval service."""
@@ -73,7 +94,11 @@ def get_retrieval_service() -> KnowledgeRetrievalService:
embedding_provider=get_embedding_provider(), embedding_provider=get_embedding_provider(),
vector_index=get_vector_index(), vector_index=get_vector_index(),
) )
return KnowledgeRetrievalService(retriever=retriever) return KnowledgeRetrievalService(
retriever=retriever,
reranker=get_reranker(),
reranker_top_k=settings.reranker_top_k,
)
@lru_cache @lru_cache
@@ -86,6 +111,7 @@ def get_document_command_service() -> DocumentCommandService:
chunk_builder=get_chunk_builder(), chunk_builder=get_chunk_builder(),
embedding_provider=get_embedding_provider(), embedding_provider=get_embedding_provider(),
vector_index=get_vector_index(), vector_index=get_vector_index(),
parse_artifact_store=get_parse_artifact_store(),
) )
@@ -95,6 +121,7 @@ def get_document_query_service() -> DocumentQueryService:
return DocumentQueryService( return DocumentQueryService(
document_repository=get_document_repository(), document_repository=get_document_repository(),
binary_store=get_binary_store(), binary_store=get_binary_store(),
vector_index=get_vector_index(),
) )

View File

@@ -13,6 +13,7 @@ tenacity>=8.2.0
pymilvus>=2.4.0 pymilvus>=2.4.0
minio>=7.1.0 minio>=7.1.0
psycopg2-binary>=2.9.0
pymupdf>=1.24.0 pymupdf>=1.24.0
python-docx>=1.1.0 python-docx>=1.1.0

118
dev.sh
View File

@@ -12,7 +12,8 @@ API_PID_FILE="$LOG_DIR/api.pid"
FRONTEND_PID_FILE="$LOG_DIR/frontend.pid" FRONTEND_PID_FILE="$LOG_DIR/frontend.pid"
API_LOG_FILE="$LOG_DIR/api.log" API_LOG_FILE="$LOG_DIR/api.log"
FRONTEND_LOG_FILE="$LOG_DIR/frontend.log" FRONTEND_LOG_FILE="$LOG_DIR/frontend.log"
DOCKER_CONTAINERS="milvus minio redis postgres" DISPLAY_HOST="localhost"
SERVICE_HOST="6.86.80.8"
RED='\033[0;31m' RED='\033[0;31m'
GREEN='\033[0;32m' GREEN='\033[0;32m'
@@ -54,6 +55,51 @@ ensure_log_dir() {
mkdir -p "$LOG_DIR" mkdir -p "$LOG_DIR"
} }
check_tcp_connectivity() {
local host="$1"
local port="$2"
if command -v nc > /dev/null 2>&1; then
nc -z -w 3 "$host" "$port" > /dev/null 2>&1
return
fi
require_python_bootstrap
"$PYTHON_BOOTSTRAP" - <<PY > /dev/null 2>&1
import socket
import sys
try:
with socket.create_connection(("${host}", ${port}), timeout=3):
pass
except Exception:
sys.exit(1)
sys.exit(0)
PY
}
check_foundation_services() {
local name
local port
for service in \
"Milvus:19530" \
"MinIO API:9000" \
"MinIO Console:9001" \
"Redis:6379" \
"PostgreSQL:5432"
do
name="${service%%:*}"
port="${service##*:}"
if check_tcp_connectivity "$SERVICE_HOST" "$port"; then
success "${name}: ${SERVICE_HOST}:${port} 可连通"
else
warn "${name}: ${SERVICE_HOST}:${port} 不可连通"
fi
done
}
print_header() { print_header() {
echo "" echo ""
echo -e "${CYAN}========================================${NC}" echo -e "${CYAN}========================================${NC}"
@@ -197,21 +243,8 @@ run_setup() {
success "前端依赖安装完成" success "前端依赖安装完成"
echo "" echo ""
info "[4/4] 检查 Docker 基础服务" info "[4/4] 检查 6.86.80.8 基础服务连通性"
if command -v docker > /dev/null 2>&1; then check_foundation_services
local container
for container in $DOCKER_CONTAINERS; do
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then
success "${container}: 运行中"
elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then
warn "${container}: 已创建但未运行"
else
warn "${container}: 未找到容器"
fi
done
else
warn "未检测到 Docker已跳过容器检查"
fi
echo "" echo ""
success "环境初始化完成" success "环境初始化完成"
@@ -223,7 +256,7 @@ run_setup() {
api_health_ok() { api_health_ok() {
if command -v curl > /dev/null 2>&1; then if command -v curl > /dev/null 2>&1; then
curl -fsS "http://localhost:$API_PORT/health" > /dev/null 2>&1 curl -fsS "http://${DISPLAY_HOST}:$API_PORT/health" > /dev/null 2>&1
return return
fi fi
@@ -233,7 +266,7 @@ import sys
from urllib.request import urlopen from urllib.request import urlopen
try: try:
with urlopen("http://localhost:${API_PORT}/health", timeout=3) as response: with urlopen("http://${DISPLAY_HOST}:${API_PORT}/health", timeout=3) as response:
body = response.read().decode("utf-8", errors="ignore") body = response.read().decode("utf-8", errors="ignore")
sys.exit(0 if "healthy" in body.lower() else 1) sys.exit(0 if "healthy" in body.lower() else 1)
except Exception: except Exception:
@@ -260,9 +293,9 @@ start_api() {
if [ "$mode" = "foreground" ]; then if [ "$mode" = "foreground" ]; then
print_header "AI+合规智能中枢 - 启动 API" print_header "AI+合规智能中枢 - 启动 API"
echo "运行模式: 前台调试(带 --reload" echo "运行模式: 前台调试(带 --reload"
echo "服务地址: http://localhost:$API_PORT" echo "服务地址: http://${DISPLAY_HOST}:$API_PORT"
echo "文档地址: http://localhost:$API_PORT/docs" echo "文档地址: http://${DISPLAY_HOST}:$API_PORT/docs"
echo "健康检查: http://localhost:$API_PORT/health" echo "健康检查: http://${DISPLAY_HOST}:$API_PORT/health"
echo "" echo ""
exec "$VENV_PYTHON" -m uvicorn app.main:app --host "$API_HOST" --port "$API_PORT" --reload exec "$VENV_PYTHON" -m uvicorn app.main:app --host "$API_HOST" --port "$API_PORT" --reload
fi fi
@@ -274,8 +307,8 @@ start_api() {
if is_pid_running "$pid"; then if is_pid_running "$pid"; then
success "API 启动成功 (PID: $pid)" success "API 启动成功 (PID: $pid)"
echo " 地址: http://localhost:$API_PORT" echo " 地址: http://${DISPLAY_HOST}:$API_PORT"
echo " 文档: http://localhost:$API_PORT/docs" echo " 文档: http://${DISPLAY_HOST}:$API_PORT/docs"
echo " 日志: $API_LOG_FILE" echo " 日志: $API_LOG_FILE"
else else
rm -f "$API_PID_FILE" rm -f "$API_PID_FILE"
@@ -316,7 +349,7 @@ start_frontend() {
if is_pid_running "$pid"; then if is_pid_running "$pid"; then
success "前端启动成功 (PID: $pid)" success "前端启动成功 (PID: $pid)"
echo " 地址: http://localhost:$FRONTEND_PORT" echo " 地址: http://${DISPLAY_HOST}:$FRONTEND_PORT"
echo " 模式: $mode" echo " 模式: $mode"
echo " 日志: $FRONTEND_LOG_FILE" echo " 日志: $FRONTEND_LOG_FILE"
else else
@@ -407,6 +440,8 @@ run_status() {
local frontend_running=false local frontend_running=false
local pid local pid
local port_listener local port_listener
local service_name
local service_port
echo -e "${YELLOW}API 服务:${NC}" echo -e "${YELLOW}API 服务:${NC}"
pid="$(read_pid "$API_PID_FILE")" pid="$(read_pid "$API_PID_FILE")"
@@ -432,8 +467,8 @@ run_status() {
warn " 健康检查: 未通过" warn " 健康检查: 未通过"
fi fi
fi fi
echo " 地址: http://localhost:$API_PORT" echo " 地址: http://${DISPLAY_HOST}:$API_PORT"
echo " 文档: http://localhost:${API_PORT}/docs" echo " 文档: http://${DISPLAY_HOST}:${API_PORT}/docs"
echo "" echo ""
echo -e "${YELLOW}前端服务:${NC}" echo -e "${YELLOW}前端服务:${NC}"
@@ -453,24 +488,25 @@ run_status() {
fi fi
fi fi
echo " 模式: $FRONTEND_MODE" echo " 模式: $FRONTEND_MODE"
echo " 地址: http://localhost:$FRONTEND_PORT" echo " 地址: http://${DISPLAY_HOST}:$FRONTEND_PORT"
echo "" echo ""
echo -e "${YELLOW}Docker 服务:${NC}" echo -e "${YELLOW}基础服务连通性:${NC}"
if command -v docker > /dev/null 2>&1; then for service in \
local container "Milvus:19530" \
for container in $DOCKER_CONTAINERS; do "MinIO API:9000" \
if docker ps --format '{{.Names}}' | grep -q "^${container}$"; then "MinIO Console:9001" \
success " ${container}: 运行中" "Redis:6379" \
elif docker ps -a --format '{{.Names}}' | grep -q "^${container}$"; then "PostgreSQL:5432"
warn " ${container}: 已停止" do
service_name="${service%%:*}"
service_port="${service##*:}"
if check_tcp_connectivity "$SERVICE_HOST" "$service_port"; then
success " ${service_name}: ${SERVICE_HOST}:${service_port} 可连通"
else else
warn " ${container}: 未创建" warn " ${service_name}: ${SERVICE_HOST}:${service_port} 不可连通"
fi fi
done done
else
warn " Docker 未安装,已跳过"
fi
echo "" echo ""
if [ "$api_running" = true ] && [ "$frontend_running" = true ]; then if [ "$api_running" = true ] && [ "$frontend_running" = true ]; then
@@ -526,7 +562,7 @@ AI+合规智能中枢统一脚本
setup setup
进行一次性的本地初始化。 进行一次性的本地初始化。
包含 Python 版本检查、.venv 虚拟环境创建、后端依赖安装、前端 npm install、 包含 Python 版本检查、.venv 虚拟环境创建、后端依赖安装、前端 npm install、
以及 Docker 基础容器状态检查。 以及 6.86.80.8 基础服务端口连通性检查。
start start
启动服务。默认行为等同于 ./dev.sh start all。 启动服务。默认行为等同于 ./dev.sh start all。
@@ -548,7 +584,7 @@ AI+合规智能中枢统一脚本
restart frontend --mode static 可直接切换前端启动模式。 restart frontend --mode static 可直接切换前端启动模式。
status status
查看 API、前端、Docker 基础容器的状态。 查看 API、前端、6.86.80.8 基础服务的状态。
API 状态包含健康检查;前端状态包含当前模式和访问地址。 API 状态包含健康检查;前端状态包含当前模式和访问地址。
logs logs

View File

@@ -58,7 +58,7 @@ services:
retries: 5 retries: 5
restart: unless-stopped restart: unless-stopped
# PostgreSQL数据库 (可选) # PostgreSQL数据库 (可选,启用 DOCUMENT_REPOSITORY_BACKEND=postgres 时使用)
postgres: postgres:
image: postgres:15-alpine image: postgres:15-alpine
container_name: postgres container_name: postgres
@@ -71,7 +71,7 @@ services:
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U compliance"] test: ["CMD-SHELL", "pg_isready -U compliance -d compliance_db"]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 5 retries: 5

View File

@@ -0,0 +1,145 @@
# 法规对话模块优化设计文档
**日期**: 2026-05-20
**状态**: 已批准,实施中
---
## 背景
当前法规对话模块存在以下问题:
1. `compliance.py` `/compliance/chat/{segment_id}` 返回硬编码 Mock 数据
2. `rag.py` `/rag/chat` 返回硬编码 Mock 数据(前端实际调用 `/agent/chat/stream`,此路由可统一)
3. 前端快速问题列表硬编码在 `rag.ts`,未调用后端
4. 仅有 Dense 向量检索COSINE无 BM25 混合检索与精排
5. LLM 输出的 `[1][2]` 引用编号未在前端内联高亮
6. 会话存储仅为内存100条上限30分钟过期
---
## 方案分层优先4个阶段
### Phase 1 — 接入真实服务Week 1
**目标**:消灭所有 Mock让系统真正可用。
#### 后端变更
**`backend/app/schemas/compliance.py`**
- `ComplianceChatRequest` 新增 `segment_context: str | None = None`
**`backend/app/api/routes/compliance.py`**
- 移除 `get_mock_compliance_chat_response` 导入
- `/compliance/chat/{segment_id}` 接入 `get_agent_conversation_service().stream_chat()`
-`segment_context` 拼接到 query 前缀作为上下文
- 将 agent `content` 事件翻译为 `{"type":"chunk","text":"..."}` 格式,保持前端兼容
**`backend/app/api/routes/rag.py`**
- 移除 `get_mock_retrieval``get_mock_rag_answer` 导入
- `/rag/chat` 接入 `get_agent_conversation_service().stream_chat()`
- 翻译 agent 事件为 rag 格式(`retrieved`/`chunk`/`done`
**`backend/app/schemas/rag.py`**
- `RagChatRequest` 新增 `session_id: str | None = None``filters: str | None = None`
#### 前端变更
**`frontend/src/api/rag.ts`**
- `getQuickQuestions()` 改为真实调用 `GET /api/v1/rag/quick-questions`,失败时降级为本地数组
**`frontend/src/api/compliance.ts`**
- `complianceChat()` 新增第三个参数 `segmentContext: string | undefined`,传入 request body
**`frontend/src/pages/Compliance/CompliancePage.tsx`**
- `sendChatMessage()` 中构建 `segmentContext`intent + content 摘要 + 法规名)
- 传给 `complianceChat()`
---
### Phase 2 — 混合检索 + RerankingWeek 2-3
**目标**提升召回质量BM25 + dense RRF融合+ 精排。
#### 2a — Cross-Encoder Reranking优先无需 schema 变更)
- 新增端口 `backend/app/domain/retrieval/ports.py`: `Reranker` ABC
- 新增适配器 `backend/app/infrastructure/vectorstore/cross_encoder_reranker.py`
- 调用 OpenAI-compatible reranker APIBAAI/bge-reranker-v2-m3
- 修改 `KnowledgeRetrievalService`:先 retrieve top-20再 rerank 到 top-5
- 新增 settings: `reranker_enabled: bool = False``reranker_model: str``reranker_top_k: int = 5`
#### 2b — Milvus Sparse BM25需 schema 迁移)
- Milvus collection 新增 `sparse_embedding SPARSE_FLOAT_VECTOR` 字段
- 新增端口 `SparseEmbeddingProvider`sparse embed 接口)
- 适配器优先使用 BGE-M3 API同时输出 dense + sparse不可用时降级为 TF-IDF keyword weights
- `MilvusVectorIndex.upsert()` 同时写入 sparse 向量
- `MilvusVectorIndex.search()` 改为 hybrid search`WeightedRanker``RRFRanker`
- 提供一次性迁移脚本dump 所有 chunks → recreate collection → re-embed → re-insert
---
### Phase 3 — 引用溯源 + 筛选 UIWeek 4
**目标**:答案文本中 `[1][2]` 可点击跳转原文片段;法规类型/版本可筛选。
#### 引用内联解析
- 新增 React 组件 `CitedAnswer`:接受 `text` + `sources[]`
- 用正则 `/\[(\d+)\]/g` 拆分文本,将 `[N]` 渲染为可点击 `<cite>` 元素
- 点击高亮/滚动到来源面板对应条目
- `RagChatPage.tsx` 将 assistant 消息改用 `CitedAnswer` 组件渲染
- 来源面板各条目加 `id="source-N"` 属性
#### 法规筛选 UI
- `RagChatPage.tsx` 在输入框上方加筛选栏:`regulation_type` 下拉 + `version` 输入
- 筛选值作为 `filters` 参数传到 `/agent/chat/stream`
- `MilvusVectorIndex.search()` 解析 `filters` 字符串为 Milvus expr`regulation_type == "GB"`
#### 快速问题动态化
- 后端 `/rag/quick-questions` 改为从 settings 或配置文件加载,不硬编码
- 前端已在 Phase 1 中改为调用后端
---
### Phase 4 — 会话持久化 + 上下文压缩Week 5
**目标**:长对话不丢失;上下文过长时智能压缩。
#### PostgreSQL 会话存储
- 新增 `backend/app/infrastructure/session/postgres_conversation_store.py`
- 实现 `ConversationStore` port
- 复用现有 PostgreSQL 连接
- DDL: `conversation_sessions` + `conversation_messages` 两张表
- 新增 settings: `conversation_store_backend: str = "memory"`
- `bootstrap.py` 按 settings 切换实现
#### 上下文压缩
-`AgentConversationService` 中,当对话轮数 > N 时,将早期消息用 LLM 摘要
- 摘要替换为一条 `system` 消息 `"对话摘要:..."`
- 新增 `AnswerGenerator.summarize(messages) -> str` 方法
---
## 关键设计决策
| 决策 | 选择 | 原因 |
|------|------|------|
| BM25 实现 | Milvus sparse vectors | 不引入新服务Milvus 2.4+ 原生支持 |
| Reranking | API-based cross-encoder | 复用现有 OpenAI-compatible 接口 |
| 会话持久化 | PostgreSQL | 复用现有 PostgreSQL 基础设施 |
| 引用解析 | 前端纯客户端 | LLM 已输出 [N] 编号,无需后端改动 |
| segment_context | 前端构建传后端 | 合规结果存在前端状态,无需后端持久化 |
---
## 验证标准
- Phase 1: 合规对话面板返回真实法规检索结果,不再固定响应
- Phase 2: 相同 query 在 hybrid 模式下比 dense-only 召回更相关结果
- Phase 3: 点击答案中的 `[1]` 能跳转到来源面板第一条
- Phase 4: 刷新页面后历史对话仍可恢复

View File

@@ -31,11 +31,18 @@ export async function getComplianceResult(
export function complianceChat( export function complianceChat(
segmentId: number, segmentId: number,
query: string, query: string,
segmentContext: string | undefined,
onMessage: (data: SSEMessage) => void, onMessage: (data: SSEMessage) => void,
onError?: (error: Error) => void, onError?: (error: Error) => void,
onComplete?: () => void onComplete?: () => void
): void { ): void {
void streamSSE<SSEMessage>(`/compliance/chat/${segmentId}`, { query }, onMessage, onError, onComplete); void streamSSE<SSEMessage>(
`/compliance/chat/${segmentId}`,
{ query, segment_context: segmentContext },
onMessage,
onError,
onComplete,
);
} }
export type { ComplianceResult, SSEMessage }; export type { ComplianceResult, SSEMessage };

View File

@@ -6,7 +6,11 @@ interface BackendDocumentItem {
doc_name: string; doc_name: string;
status: string; status: string;
chunk_count: number; chunk_count: number;
size_bytes?: number;
summary?: string;
updated_at?: string; updated_at?: string;
regulation_type?: string;
version?: string;
} }
interface BackendDocumentListResponse { interface BackendDocumentListResponse {
@@ -44,6 +48,7 @@ export interface RegulationSearchResponse {
} }
function mapDoc(item: BackendDocumentItem): DocInfo { function mapDoc(item: BackendDocumentItem): DocInfo {
const sizeMB = item.size_bytes ? (item.size_bytes / (1024 * 1024)).toFixed(1) + 'MB' : '';
return { return {
id: item.doc_id, id: item.doc_id,
name: item.doc_name, name: item.doc_name,
@@ -51,14 +56,23 @@ function mapDoc(item: BackendDocumentItem): DocInfo {
status: item.status, status: item.status,
updated_at: item.updated_at, updated_at: item.updated_at,
download_url: `${API_BASE_URL}/documents/download/${item.doc_id}`, download_url: `${API_BASE_URL}/documents/download/${item.doc_id}`,
size_text: sizeMB,
summary: item.summary,
regulation_type: item.regulation_type,
version: item.version,
}; };
} }
export async function uploadDocument(file: File): Promise<DocUploadResponse> { export async function uploadDocument(
file: File,
opts?: { regulationType?: string; version?: string }
): Promise<DocUploadResponse> {
const formData = new FormData(); const formData = new FormData();
formData.append('file', file); formData.append('file', file);
formData.append('doc_name', file.name); formData.append('doc_name', file.name);
formData.append('generate_summary', 'true'); formData.append('generate_summary', 'true');
if (opts?.regulationType) formData.append('regulation_type', opts.regulationType);
if (opts?.version) formData.append('version', opts.version);
const response = await fetch(`${API_BASE_URL}/documents/upload`, { const response = await fetch(`${API_BASE_URL}/documents/upload`, {
method: 'POST', method: 'POST',
@@ -92,6 +106,29 @@ export async function getDocumentList(): Promise<DocListResponse> {
}; };
} }
export async function getDocumentStatus(docId: string): Promise<DocUploadResponse> {
const response = await fetch(`${API_BASE_URL}/documents/status/${docId}`);
if (!response.ok) {
throw new Error(`Status check failed: ${response.status}`);
}
return response.json() as Promise<DocUploadResponse>;
}
export async function deleteDocument(docId: string): Promise<void> {
const response = await fetch(`${API_BASE_URL}/documents/${docId}`, { method: 'DELETE' });
if (!response.ok) {
throw new Error(`Delete failed: ${response.status}`);
}
}
export async function retryDocument(docId: string): Promise<DocUploadResponse> {
const response = await fetch(`${API_BASE_URL}/documents/${docId}/retry`, { method: 'POST' });
if (!response.ok) {
throw new Error(`Retry failed: ${response.status}`);
}
return response.json() as Promise<DocUploadResponse>;
}
export async function searchRegulations(query: string, topK: number = 8): Promise<RegulationSearchResponse> { export async function searchRegulations(query: string, topK: number = 8): Promise<RegulationSearchResponse> {
const response = await fetch(`${API_BASE_URL}/knowledge/retrieval`, { const response = await fetch(`${API_BASE_URL}/knowledge/retrieval`, {
method: 'POST', method: 'POST',

View File

@@ -136,6 +136,9 @@ export interface DocInfo {
updated_at?: string; updated_at?: string;
download_url?: string; download_url?: string;
size_text?: string; size_text?: string;
summary?: string;
regulation_type?: string;
version?: string;
} }
export interface DocListResponse { export interface DocListResponse {
@@ -149,6 +152,8 @@ export interface DocUploadResponse {
message?: string; message?: string;
num_chunks?: number; num_chunks?: number;
summary?: string; summary?: string;
regulation_type?: string;
version?: string;
} }
export interface QuickQuestion { export interface QuickQuestion {

View File

@@ -2,15 +2,21 @@ import type { QuickQuestionsResponse, SSEMessage } from './index';
const AGENT_API_BASE = '/api/v1'; const AGENT_API_BASE = '/api/v1';
export async function getQuickQuestions(): Promise<QuickQuestionsResponse> { const _FALLBACK_QUESTIONS = [
return {
questions: [
{ id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' }, { id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' },
{ id: '2', question: '我上传的制度文档与新能源法规有哪些潜在冲突?', category: '差距分析' }, { id: '2', question: '我上传的制度文档与新能源法规有哪些潜在冲突?', category: '差距分析' },
{ id: '3', question: '请给出法规依据,并按条款列出整改建议', category: '整改建议' }, { id: '3', question: '请给出法规依据,并按条款列出整改建议', category: '整改建议' },
{ id: '4', question: '请解释 UN-ECE 与 GB 标准在网络安全方面的差异', category: '标准对比' }, { id: '4', question: '请解释 UN-ECE 与 GB 标准在网络安全方面的差异', category: '标准对比' },
], ];
};
export async function getQuickQuestions(): Promise<QuickQuestionsResponse> {
try {
const response = await fetch(`${AGENT_API_BASE}/rag/quick-questions`);
if (!response.ok) throw new Error(`status ${response.status}`);
return response.json() as Promise<QuickQuestionsResponse>;
} catch {
return { questions: _FALLBACK_QUESTIONS };
}
} }
function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) { function parseSSEChunk(raw: string, onMessage: (data: SSEMessage) => void) {
@@ -67,7 +73,8 @@ export async function ragChat(
topK: number = 5, topK: number = 5,
onMessage: (data: SSEMessage) => void, onMessage: (data: SSEMessage) => void,
onError?: (error: Error) => void, onError?: (error: Error) => void,
onComplete?: () => void onComplete?: () => void,
filters?: string
): Promise<void> { ): Promise<void> {
try { try {
const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, { const response = await fetch(`${AGENT_API_BASE}/agent/chat/stream`, {
@@ -76,7 +83,7 @@ export async function ragChat(
'Content-Type': 'application/json', 'Content-Type': 'application/json',
Accept: 'text/event-stream', Accept: 'text/event-stream',
}, },
body: JSON.stringify({ query, top_k: topK }), body: JSON.stringify({ query, top_k: topK, ...(filters ? { filters } : {}) }),
}); });
if (!response.ok || !response.body) { if (!response.ok || !response.body) {

View File

@@ -250,11 +250,20 @@ export const CompliancePage: React.FC = () => {
setChatInput(''); setChatInput('');
setChatLoading(true); setChatLoading(true);
const segmentContext = [
`意图:${chunk.intent}`,
`内容:${chunk.content.slice(0, 300)}`,
chunk.regulations.length > 0
? `相关法规:${chunk.regulations.slice(0, 3).map(r => `${r.name}${r.clause ? ' ' + r.clause : ''}(相关性 ${Math.round(r.score * 100)}%`).join('')}`
: '',
].filter(Boolean).join('\n');
let currentResponse = ''; let currentResponse = '';
complianceChat( complianceChat(
activeChunkId, activeChunkId,
chatInput, chatInput,
segmentContext,
(data: unknown) => { (data: unknown) => {
const sseData = data as { type: string; text?: string }; const sseData = data as { type: string; text?: string };
if (sseData.type === 'chunk' && sseData.text) { if (sseData.type === 'chunk' && sseData.text) {

View File

@@ -2,7 +2,7 @@ import React, { useEffect, useRef, useState } from 'react';
import { useTheme } from '../../contexts'; import { useTheme } from '../../contexts';
import { Content } from '../../components/layout/Content'; import { Content } from '../../components/layout/Content';
import { TPattern } from '../../components/common/TPattern'; import { TPattern } from '../../components/common/TPattern';
import { getDocumentList, searchRegulations, uploadDocument, type RegulationSearchItem } from '../../api/docs'; import { getDocumentList, getDocumentStatus, searchRegulations, uploadDocument, deleteDocument, retryDocument, type RegulationSearchItem } from '../../api/docs';
import type { Doc } from '../../types'; import type { Doc } from '../../types';
type PipelineStatus = 'idle' | 'running' | 'completed' | 'error'; type PipelineStatus = 'idle' | 'running' | 'completed' | 'error';
@@ -15,13 +15,13 @@ const PIPELINE_STEPS = [
{ name: 'STORE' }, { name: 'STORE' },
]; ];
const REGULATION_TYPES = ['', '国家标准', '行业标准', '地方标准', '企业标准', '法律法规', '监管规定'];
const STEP_DURATION_MS = 700; const STEP_DURATION_MS = 700;
const INITIAL_SEARCH_QUERY = '新能源汽车电池安全要求'; const INITIAL_SEARCH_QUERY = '新能源汽车电池安全要求';
function wait(ms: number) { function wait(ms: number) {
return new Promise<void>((resolve) => { return new Promise<void>((resolve) => { window.setTimeout(resolve, ms); });
window.setTimeout(resolve, ms);
});
} }
export const DocsPage: React.FC = () => { export const DocsPage: React.FC = () => {
@@ -41,6 +41,13 @@ export const DocsPage: React.FC = () => {
const [searchLoading, setSearchLoading] = useState(false); const [searchLoading, setSearchLoading] = useState(false);
const [searchError, setSearchError] = useState(''); const [searchError, setSearchError] = useState('');
// Upload metadata
const [regulationType, setRegulationType] = useState('');
const [version, setVersion] = useState('');
// Batch queue: files waiting to be uploaded after the current one finishes
const batchQueueRef = useRef<File[]>([]);
async function loadDocuments() { async function loadDocuments() {
setLoading(true); setLoading(true);
try { try {
@@ -54,6 +61,9 @@ export const DocsPage: React.FC = () => {
docId: doc.id, docId: doc.id,
downloadUrl: doc.download_url, downloadUrl: doc.download_url,
updatedAt: doc.updated_at, updatedAt: doc.updated_at,
summary: doc.summary,
regulationType: doc.regulation_type,
version: doc.version,
})); }));
setDocs(apiDocs); setDocs(apiDocs);
} catch (error) { } catch (error) {
@@ -81,62 +91,71 @@ export const DocsPage: React.FC = () => {
} }
useEffect(() => { useEffect(() => {
const timerId = window.setTimeout(() => { const timerId = window.setTimeout(() => { void loadDocuments(); }, 0);
void loadDocuments();
}, 0);
return () => window.clearTimeout(timerId); return () => window.clearTimeout(timerId);
}, []); }, []);
useEffect(() => { useEffect(() => {
const timerId = window.setTimeout(() => { const timerId = window.setTimeout(() => { void runSearch(INITIAL_SEARCH_QUERY); }, 0);
void runSearch(INITIAL_SEARCH_QUERY);
}, 0);
return () => window.clearTimeout(timerId); return () => window.clearTimeout(timerId);
}, []); }, []);
useEffect(() => { useEffect(() => {
return () => { return () => { pipelineRunIdRef.current += 1; };
pipelineRunIdRef.current += 1;
};
}, []); }, []);
useEffect(() => {
const parsingDocs = docs.filter(
(doc) => doc.status === 'parsing' && doc.docId && !doc.docId.startsWith('pending-')
);
if (parsingDocs.length === 0) return;
const timerId = window.setInterval(() => {
parsingDocs.forEach((doc) => {
void getDocumentStatus(doc.docId!).then((res) => {
if (res.status === 'indexed' || res.status === 'failed') {
setDocs((prev) =>
prev.map((d) =>
d.docId === doc.docId
? {
...d,
status: res.status === 'indexed' ? 'indexed' : 'failed',
chunks: res.num_chunks ?? d.chunks,
summary: res.summary ?? d.summary,
regulationType: res.regulation_type ?? d.regulationType,
version: res.version ?? d.version,
}
: d
)
);
}
}).catch(() => {});
});
}, 5000);
return () => window.clearInterval(timerId);
}, [docs]);
const runPipelineFlow = async (runId: number, uploadPromise: Promise<Awaited<ReturnType<typeof uploadDocument>>>) => { const runPipelineFlow = async (runId: number, uploadPromise: Promise<Awaited<ReturnType<typeof uploadDocument>>>) => {
const guardedSetActiveStep = (step: number) => { const guard = (fn: () => void) => { if (pipelineRunIdRef.current !== runId) return false; fn(); return true; };
if (pipelineRunIdRef.current !== runId) return false;
setActiveStep(step);
return true;
};
const guardedCompleteStep = (step: number) => { for (let i = 0; i < PIPELINE_STEPS.length - 1; i++) {
if (pipelineRunIdRef.current !== runId) return false; if (!guard(() => setActiveStep(i))) return;
setCompletedSteps((prev) => (prev.includes(step) ? prev : [...prev, step]));
return true;
};
for (let index = 0; index < PIPELINE_STEPS.length - 1; index += 1) {
if (!guardedSetActiveStep(index)) return;
await wait(STEP_DURATION_MS); await wait(STEP_DURATION_MS);
if (!guardedCompleteStep(index)) return; if (!guard(() => setCompletedSteps((p) => p.includes(i) ? p : [...p, i]))) return;
} }
if (!guardedSetActiveStep(PIPELINE_STEPS.length - 1)) return; if (!guard(() => setActiveStep(PIPELINE_STEPS.length - 1))) return;
await uploadPromise; await uploadPromise;
if (!guardedCompleteStep(PIPELINE_STEPS.length - 1)) return; if (!guard(() => setCompletedSteps((p) => { const last = PIPELINE_STEPS.length - 1; return p.includes(last) ? p : [...p, last]; }))) return;
await wait(240); await wait(240);
if (pipelineRunIdRef.current !== runId) return; if (pipelineRunIdRef.current !== runId) return;
setActiveStep(-1); setActiveStep(-1);
setPipelineStatus('completed'); setPipelineStatus('completed');
}; };
const handleFileSelect = async (event: React.ChangeEvent<HTMLInputElement>) => { const uploadSingleFile = async (file: File, runId: number) => {
const file = event.target.files?.[0];
if (!file || uploading) return;
const runId = pipelineRunIdRef.current + 1;
pipelineRunIdRef.current = runId;
setUploading(true); setUploading(true);
setUploadFileName(file.name); setUploadFileName(file.name);
setActiveStep(-1); setActiveStep(-1);
@@ -152,11 +171,16 @@ export const DocsPage: React.FC = () => {
size: `${fileSizeMB}MB`, size: `${fileSizeMB}MB`,
status: 'parsing', status: 'parsing',
docId: tempDocId, docId: tempDocId,
regulationType: regulationType || undefined,
version: version || undefined,
}; };
setDocs((prev) => [newDoc, ...prev]); setDocs((prev) => [newDoc, ...prev]);
const uploadPromise = uploadDocument(file); const uploadPromise = uploadDocument(file, {
regulationType: regulationType || undefined,
version: version || undefined,
});
void runPipelineFlow(runId, uploadPromise); void runPipelineFlow(runId, uploadPromise);
try { try {
@@ -166,143 +190,123 @@ export const DocsPage: React.FC = () => {
setDocs((prev) => setDocs((prev) =>
prev.map((doc) => prev.map((doc) =>
doc.id === newDoc.id doc.id === newDoc.id
? { ? { ...doc, status: 'indexed', docId: uploadRes.doc_id, chunks: uploadRes.num_chunks || doc.chunks, summary: uploadRes.summary }
...doc,
status: 'indexed',
docId: uploadRes.doc_id,
chunks: uploadRes.num_chunks || doc.chunks,
summary: uploadRes.summary,
}
: doc : doc
) )
); );
setUploading(false);
setUploadFileName('');
void loadDocuments(); void loadDocuments();
} catch (error) { } catch (error) {
console.error('Upload failed:', error); console.error('Upload failed:', error);
if (pipelineRunIdRef.current !== runId) return; if (pipelineRunIdRef.current !== runId) return;
setUploading(false);
setUploadFileName('');
setDocs((prev) => prev.filter((doc) => doc.id !== newDoc.id)); setDocs((prev) => prev.filter((doc) => doc.id !== newDoc.id));
setPipelineStatus('error'); setPipelineStatus('error');
setActiveStep(-1); setActiveStep(-1);
setCompletedSteps([]); setCompletedSteps([]);
} finally { } finally {
if (fileInputRef.current) { setUploading(false);
fileInputRef.current.value = ''; setUploadFileName('');
if (fileInputRef.current) fileInputRef.current.value = '';
// Process next file in batch queue
const next = batchQueueRef.current.shift();
if (next) {
const nextRunId = pipelineRunIdRef.current + 1;
pipelineRunIdRef.current = nextRunId;
void uploadSingleFile(next, nextRunId);
} }
} }
}; };
const triggerFileUpload = () => { const handleFileSelect = async (event: React.ChangeEvent<HTMLInputElement>) => {
if (uploading) return; const files = Array.from(event.target.files ?? []);
fileInputRef.current?.click(); if (files.length === 0 || uploading) return;
const [first, ...rest] = files;
batchQueueRef.current = rest;
const runId = pipelineRunIdRef.current + 1;
pipelineRunIdRef.current = runId;
await uploadSingleFile(first, runId);
}; };
const handleDragOver = (event: React.DragEvent) => { const handleDelete = async (docId: string) => {
event.preventDefault(); try {
event.stopPropagation(); await deleteDocument(docId);
setDocs((prev) => prev.filter((doc) => doc.docId !== docId));
} catch (error) {
console.error('Delete failed:', error);
}
}; };
const handleRetry = async (docId: string) => {
setDocs((prev) => prev.map((doc) => doc.docId === docId ? { ...doc, status: 'parsing' } : doc));
try {
const result = await retryDocument(docId);
setDocs((prev) =>
prev.map((doc) => doc.docId === docId ? { ...doc, status: 'indexed', chunks: result.num_chunks || doc.chunks } : doc)
);
} catch (error) {
console.error('Retry failed:', error);
setDocs((prev) => prev.map((doc) => doc.docId === docId ? { ...doc, status: 'failed' } : doc));
}
};
const triggerFileUpload = () => { if (uploading) return; fileInputRef.current?.click(); };
const handleDragOver = (event: React.DragEvent) => { event.preventDefault(); event.stopPropagation(); };
const handleDrop = (event: React.DragEvent) => { const handleDrop = (event: React.DragEvent) => {
event.preventDefault(); event.preventDefault();
event.stopPropagation(); event.stopPropagation();
const files = Array.from(event.dataTransfer.files);
const files = event.dataTransfer.files;
if (files.length === 0 || uploading) return; if (files.length === 0 || uploading) return;
const droppedFile = files[0]; const [first, ...rest] = files;
if (fileInputRef.current) { batchQueueRef.current = rest;
const dataTransfer = new DataTransfer(); const runId = pipelineRunIdRef.current + 1;
dataTransfer.items.add(droppedFile); pipelineRunIdRef.current = runId;
fileInputRef.current.files = dataTransfer.files; void uploadSingleFile(first, runId);
}
void handleFileSelect({
target: { files: [droppedFile] as unknown as FileList },
} as React.ChangeEvent<HTMLInputElement>);
}; };
const getStepStyle = (index: number) => { const getStepStyle = (index: number) => {
const isActive = activeStep === index; if (activeStep === index) return { background: theme.bgCard, border: `2px solid ${theme.accent}`, boxShadow: `0 0 12px ${theme.accent}40` };
const isCompleted = completedSteps.includes(index); if (completedSteps.includes(index)) return { background: theme.bgCard, border: `1px solid ${theme.green}` };
return { background: theme.bgCard, border: `1px solid ${theme.border}` };
if (isActive) {
return {
background: theme.bgCard,
border: `2px solid ${theme.accent}`,
boxShadow: `0 0 12px ${theme.accent}40`,
};
}
if (isCompleted) {
return {
background: theme.bgCard,
border: `1px solid ${theme.green}`,
};
}
return {
background: theme.bgCard,
border: `1px solid ${theme.border}`,
};
}; };
const getCheckStyle = (index: number) => { const getCheckStyle = (index: number) => {
const isActive = activeStep === index; if (activeStep === index) return { background: theme.gradientAccent, color: '#fff', animation: 'pulse 0.6s infinite' };
const isCompleted = completedSteps.includes(index); if (completedSteps.includes(index)) return { background: theme.green, color: '#fff' };
return { background: theme.bgHover, color: theme.text3 };
if (isActive) {
return {
background: theme.gradientAccent,
color: '#fff',
animation: 'pulse 0.6s infinite',
};
}
if (isCompleted) {
return {
background: theme.green,
color: '#fff',
};
}
return {
background: theme.bgHover,
color: theme.text3,
};
}; };
const getPipelineHint = () => { const getPipelineHint = () => {
if (pipelineStatus === 'running') { if (pipelineStatus === 'running') {
return activeStep >= 0 ? `${PIPELINE_STEPS[activeStep].name} · ${uploadFileName}` : `LOAD · ${uploadFileName}`; const queueLen = batchQueueRef.current.length;
} const suffix = queueLen > 0 ? ` (+${queueLen} 待上传)` : '';
if (pipelineStatus === 'completed') { return `${activeStep >= 0 ? PIPELINE_STEPS[activeStep].name : 'LOAD'} · ${uploadFileName}${suffix}`;
return 'PIPELINE COMPLETE';
}
if (pipelineStatus === 'error') {
return 'PIPELINE FAILED';
} }
if (pipelineStatus === 'completed') return 'PIPELINE COMPLETE';
if (pipelineStatus === 'error') return 'PIPELINE FAILED';
return 'WAITING FOR UPLOAD'; return 'WAITING FOR UPLOAD';
}; };
const inputStyle: React.CSSProperties = {
padding: '8px 12px',
fontSize: 13,
background: theme.bgCard,
border: `1px solid ${theme.border}`,
borderRadius: 8,
color: theme.text,
outline: 'none',
};
return ( return (
<Content> <Content>
<TPattern /> <TPattern />
<section style={{ marginBottom: 56 }}> <section style={{ marginBottom: 56 }}>
<h2 <h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
style={{
fontSize: 14,
fontWeight: 600,
color: theme.accent,
marginBottom: 20,
letterSpacing: '1px',
}}
>
UPLOAD UPLOAD
</h2> </h2>
@@ -310,10 +314,30 @@ export const DocsPage: React.FC = () => {
ref={fileInputRef} ref={fileInputRef}
type="file" type="file"
accept=".pdf,.docx,.doc" accept=".pdf,.docx,.doc"
multiple
onChange={handleFileSelect} onChange={handleFileSelect}
style={{ display: 'none' }} style={{ display: 'none' }}
/> />
{/* Metadata row */}
<div style={{ display: 'flex', gap: 12, marginBottom: 12 }}>
<select
value={regulationType}
onChange={(e) => setRegulationType(e.target.value)}
style={{ ...inputStyle, flex: 1 }}
>
{REGULATION_TYPES.map((t) => (
<option key={t} value={t}>{t || '法规类型(可选)'}</option>
))}
</select>
<input
value={version}
onChange={(e) => setVersion(e.target.value)}
placeholder="版本号(可选,如 2024"
style={{ ...inputStyle, flex: 1 }}
/>
</div>
<div <div
onClick={triggerFileUpload} onClick={triggerFileUpload}
onDragOver={handleDragOver} onDragOver={handleDragOver}
@@ -330,30 +354,11 @@ export const DocsPage: React.FC = () => {
opacity: uploading ? 0.78 : 1, opacity: uploading ? 0.78 : 1,
}} }}
> >
<div <div style={{ width: 80, height: 80, borderRadius: 20, background: theme.bgHover, display: 'flex', alignItems: 'center', justifyContent: 'center', margin: '0 auto 20px' }}>
style={{
width: 80,
height: 80,
borderRadius: 20,
background: theme.bgHover,
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
margin: '0 auto 20px',
}}
>
{uploading ? ( {uploading ? (
<div style={{ animation: 'spin 1s linear infinite' }}> <div style={{ animation: 'spin 1s linear infinite' }}>
<svg width="36" height="36" viewBox="0 0 24 24" fill="none"> <svg width="36" height="36" viewBox="0 0 24 24" fill="none">
<circle <circle cx="12" cy="12" r="10" stroke={theme.accent} strokeWidth="2" strokeDasharray="60" strokeDashoffset="20" />
cx="12"
cy="12"
r="10"
stroke={theme.accent}
strokeWidth="2"
strokeDasharray="60"
strokeDashoffset="20"
/>
</svg> </svg>
</div> </div>
) : ( ) : (
@@ -363,26 +368,17 @@ export const DocsPage: React.FC = () => {
</svg> </svg>
)} )}
</div> </div>
<div style={{ fontSize: 16, fontWeight: 500, marginBottom: 8 }}> <div style={{ fontSize: 16, fontWeight: 500, marginBottom: 8 }}>
{uploading ? '正在上传并启动处理链路...' : '拖拽文件或点击上传'} {uploading ? '正在上传并启动处理链路...' : '拖拽文件或点击上传(支持多选)'}
</div> </div>
<div className="mono" style={{ fontSize: 12, color: theme.text3 }}> <div className="mono" style={{ fontSize: 12, color: theme.text3 }}>
{uploading ? uploadFileName : 'PDF · DOCX · DOC · MAX 100MB'} {uploading ? uploadFileName : 'PDF · DOCX · DOC · MAX 100MB · 支持批量'}
</div> </div>
</div> </div>
</section> </section>
<section style={{ marginBottom: 40 }}> <section style={{ marginBottom: 40 }}>
<h2 <h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
style={{
fontSize: 14,
fontWeight: 600,
color: theme.accent,
marginBottom: 20,
letterSpacing: '1px',
}}
>
PROCESSING PIPELINE PROCESSING PIPELINE
</h2> </h2>
@@ -390,12 +386,7 @@ export const DocsPage: React.FC = () => {
className="mono" className="mono"
style={{ style={{
fontSize: 11, fontSize: 11,
color: color: pipelineStatus === 'error' ? '#d64545' : pipelineStatus === 'completed' ? theme.green : theme.text3,
pipelineStatus === 'error'
? '#d64545'
: pipelineStatus === 'completed'
? theme.green
: theme.text3,
letterSpacing: '1px', letterSpacing: '1px',
marginBottom: 12, marginBottom: 12,
}} }}
@@ -405,63 +396,24 @@ export const DocsPage: React.FC = () => {
<div style={{ display: 'flex', gap: 16 }}> <div style={{ display: 'flex', gap: 16 }}>
{PIPELINE_STEPS.map((step, index) => { {PIPELINE_STEPS.map((step, index) => {
const stepStyle = getStepStyle(index);
const checkStyle = getCheckStyle(index);
const arrowActive = activeStep > index || completedSteps.includes(index);
const isCompleted = completedSteps.includes(index); const isCompleted = completedSteps.includes(index);
const isActive = activeStep === index; const isActive = activeStep === index;
const arrowActive = activeStep > index || isCompleted;
return ( return (
<div <div
key={step.name} key={step.name}
style={{ style={{ flex: 1, padding: 20, textAlign: 'center', borderRadius: 12, position: 'relative', boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none', transition: 'all 0.3s ease', ...getStepStyle(index) }}
flex: 1,
padding: 20,
textAlign: 'center',
borderRadius: 12,
position: 'relative',
boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none',
transition: 'all 0.3s ease',
...stepStyle,
}}
>
<div
style={{
width: 36,
height: 36,
borderRadius: 8,
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
margin: '0 auto 12px',
fontSize: 16,
transition: 'all 0.3s ease',
...checkStyle,
}}
> >
<div style={{ width: 36, height: 36, borderRadius: 8, display: 'flex', alignItems: 'center', justifyContent: 'center', margin: '0 auto 12px', fontSize: 16, transition: 'all 0.3s ease', ...getCheckStyle(index) }}>
{isActive ? step.name : isCompleted ? '✓' : step.name} {isActive ? step.name : isCompleted ? '✓' : step.name}
</div> </div>
<div className="mono" style={{ fontSize: 12, fontWeight: 600 }}>{step.name}</div>
<div className="mono" style={{ fontSize: 12, fontWeight: 600 }}>
{step.name}
</div>
<div className="mono" style={{ fontSize: 10, color: theme.text3, marginTop: 8 }}> <div className="mono" style={{ fontSize: 10, color: theme.text3, marginTop: 8 }}>
{isCompleted ? 'DONE' : isActive ? 'RUNNING' : 'PENDING'} {isCompleted ? 'DONE' : isActive ? 'RUNNING' : 'PENDING'}
</div> </div>
{index < PIPELINE_STEPS.length - 1 && ( {index < PIPELINE_STEPS.length - 1 && (
<div <div style={{ position: 'absolute', right: -8, top: '50%', transform: 'translateY(-50%)', color: arrowActive ? theme.green : theme.borderLight, fontWeight: arrowActive ? 700 : 400, opacity: arrowActive ? 1 : 0.45, transition: 'all 0.3s ease' }}>
style={{
position: 'absolute',
right: -8,
top: '50%',
transform: 'translateY(-50%)',
color: arrowActive ? theme.green : theme.borderLight,
fontWeight: arrowActive ? 700 : 400,
opacity: arrowActive ? 1 : 0.45,
transition: 'all 0.3s ease',
}}
>
</div> </div>
)} )}
@@ -473,15 +425,7 @@ export const DocsPage: React.FC = () => {
<section style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 24, marginBottom: 56 }}> <section style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 24, marginBottom: 56 }}>
<div> <div>
<h2 <h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
style={{
fontSize: 14,
fontWeight: 600,
color: theme.accent,
marginBottom: 20,
letterSpacing: '1px',
}}
>
({loading ? '...' : docs.length}) ({loading ? '...' : docs.length})
</h2> </h2>
@@ -489,36 +433,12 @@ export const DocsPage: React.FC = () => {
{docs.map((doc) => ( {docs.map((doc) => (
<div <div
key={doc.id} key={doc.id}
style={{ style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', padding: 20, background: theme.bgCard, borderRadius: 12, border: `1px solid ${doc.status === 'parsing' ? theme.accent : theme.border}`, transition: 'all 0.2s ease', boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none' }}
display: 'flex',
alignItems: 'center',
justifyContent: 'space-between',
padding: 20,
background: theme.bgCard,
borderRadius: 12,
border: `1px solid ${doc.status === 'parsing' ? theme.accent : theme.border}`,
transition: 'all 0.2s ease',
boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none',
}}
>
<div style={{ display: 'flex', alignItems: 'center', gap: 16 }}>
<div
style={{
width: 44,
height: 44,
borderRadius: 10,
background: theme.bgHover,
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
}}
> >
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 16 }}>
<div style={{ width: 44, height: 44, borderRadius: 10, background: theme.bgHover, display: 'flex', alignItems: 'center', justifyContent: 'center', flexShrink: 0 }}>
<svg width="20" height="20" viewBox="0 0 24 24" fill="none"> <svg width="20" height="20" viewBox="0 0 24 24" fill="none">
<path <path d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z" stroke={theme.accent} strokeWidth="1.5" />
d="M14 2H6C5 2 4 3 4 4V20C4 21 5 22 6 22H18C19 22 20 21 20 20V8L14 2Z"
stroke={theme.accent}
strokeWidth="1.5"
/>
<path d="M14 2V8H20" stroke={theme.accent} strokeWidth="1.5" /> <path d="M14 2V8H20" stroke={theme.accent} strokeWidth="1.5" />
</svg> </svg>
</div> </div>
@@ -529,36 +449,48 @@ export const DocsPage: React.FC = () => {
{doc.updatedAt ? new Date(doc.updatedAt).toLocaleString() : doc.size} {doc.updatedAt ? new Date(doc.updatedAt).toLocaleString() : doc.size}
{doc.docId ? ` · ${doc.docId}` : ''} {doc.docId ? ` · ${doc.docId}` : ''}
</div> </div>
{/* Tags row */}
{(doc.regulationType || doc.version) && (
<div style={{ display: 'flex', gap: 6, marginTop: 5, flexWrap: 'wrap' }}>
{doc.regulationType && (
<span style={{ fontSize: 11, padding: '2px 8px', borderRadius: 4, background: `${theme.accent}18`, color: theme.accent, fontWeight: 500 }}>
{doc.regulationType}
</span>
)}
{doc.version && (
<span style={{ fontSize: 11, padding: '2px 8px', borderRadius: 4, background: theme.bgHover, color: theme.text2 }}>
v{doc.version}
</span>
)}
</div>
)}
{doc.summary && (
<div style={{ fontSize: 12, color: theme.text2, marginTop: 6, lineHeight: 1.5, maxWidth: 320, display: '-webkit-box', WebkitLineClamp: 2, WebkitBoxOrient: 'vertical', overflow: 'hidden' }}>
{doc.summary}
</div>
)}
</div> </div>
</div> </div>
<div style={{ display: 'flex', alignItems: 'center', gap: 20 }}> <div style={{ display: 'flex', alignItems: 'flex-start', gap: 12, flexShrink: 0 }}>
{doc.downloadUrl && ( {doc.status === 'failed' && doc.docId && !doc.docId.startsWith('pending-') && (
<a <button onClick={() => void handleRetry(doc.docId!)} style={{ fontSize: 12, padding: '6px 12px', background: 'transparent', border: `1px solid ${theme.accent}`, borderRadius: 6, color: theme.accent, cursor: 'pointer' }}>
href={doc.downloadUrl}
target="_blank" </button>
rel="noreferrer" )}
style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }} {doc.downloadUrl && doc.status === 'indexed' && (
> <a href={doc.downloadUrl} target="_blank" rel="noreferrer" style={{ fontSize: 12, color: theme.accent, textDecoration: 'none' }}>
</a> </a>
)} )}
<div <div className="mono" style={{ fontSize: 12, padding: '6px 12px', background: theme.bgHover, borderRadius: 6, color: doc.status === 'failed' ? '#d64545' : theme.text2 }}>
className="mono" {doc.status === 'parsing' ? '处理中...' : doc.status === 'failed' ? '处理失败' : `${doc.chunks} chunks`}
style={{
fontSize: 12,
padding: '6px 12px',
background: theme.bgHover,
borderRadius: 6,
color: doc.status === 'failed' ? '#d64545' : theme.text2,
}}
>
{doc.status === 'parsing'
? '处理中...'
: doc.status === 'failed'
? '处理失败'
: `${doc.chunks} chunks`}
</div> </div>
{doc.docId && !doc.docId.startsWith('pending-') && (
<button onClick={() => void handleDelete(doc.docId!)} style={{ fontSize: 12, padding: '6px 10px', background: 'transparent', border: `1px solid ${theme.border}`, borderRadius: 6, color: theme.text3, cursor: 'pointer' }}>
</button>
)}
</div> </div>
</div> </div>
))} ))}
@@ -566,15 +498,7 @@ export const DocsPage: React.FC = () => {
</div> </div>
<div> <div>
<h2 <h2 style={{ fontSize: 14, fontWeight: 600, color: theme.accent, marginBottom: 20, letterSpacing: '1px' }}>
style={{
fontSize: 14,
fontWeight: 600,
color: theme.accent,
marginBottom: 20,
letterSpacing: '1px',
}}
>
</h2> </h2>
@@ -582,86 +506,37 @@ export const DocsPage: React.FC = () => {
<input <input
value={searchQuery} value={searchQuery}
onChange={(event) => setSearchQuery(event.target.value)} onChange={(event) => setSearchQuery(event.target.value)}
onKeyDown={(event) => { onKeyDown={(event) => { if (event.key === 'Enter') void runSearch(searchQuery); }}
if (event.key === 'Enter') {
void runSearch(searchQuery);
}
}}
placeholder="输入法规关键词、条款或制度主题" placeholder="输入法规关键词、条款或制度主题"
style={{ style={{ flex: 1, padding: 12, fontSize: 14, background: theme.bgCard, border: `1px solid ${theme.border}`, borderRadius: 8, color: theme.text, outline: 'none' }}
flex: 1,
padding: 12,
fontSize: 14,
background: theme.bgCard,
border: `1px solid ${theme.border}`,
borderRadius: 8,
color: theme.text,
outline: 'none',
}}
/> />
<button <button
onClick={() => void runSearch(searchQuery)} onClick={() => void runSearch(searchQuery)}
disabled={searchLoading || !searchQuery.trim()} disabled={searchLoading || !searchQuery.trim()}
style={{ style={{ padding: '12px 18px', fontSize: 13, fontWeight: 600, background: searchLoading || !searchQuery.trim() ? theme.bgHover : theme.gradientAccent, color: searchLoading || !searchQuery.trim() ? theme.text3 : '#fff', border: 'none', borderRadius: 8, cursor: searchLoading || !searchQuery.trim() ? 'not-allowed' : 'pointer' }}
padding: '12px 18px',
fontSize: 13,
fontWeight: 600,
background: searchLoading || !searchQuery.trim() ? theme.bgHover : theme.gradientAccent,
color: searchLoading || !searchQuery.trim() ? theme.text3 : '#fff',
border: 'none',
borderRadius: 8,
cursor: searchLoading || !searchQuery.trim() ? 'not-allowed' : 'pointer',
}}
> >
</button> </button>
</div> </div>
{searchError && ( {searchError && <div style={{ marginBottom: 12, fontSize: 13, color: '#d64545' }}>{searchError}</div>}
<div style={{ marginBottom: 12, fontSize: 13, color: '#d64545' }}>
{searchError}
</div>
)}
<div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}> <div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
{searchResults.map((item) => ( {searchResults.map((item) => (
<div <div key={`${item.id}-${item.file}`} style={{ padding: 18, background: theme.bgCard, borderRadius: 12, border: `1px solid ${theme.border}`, boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none' }}>
key={`${item.id}-${item.file}`}
style={{
padding: 18,
background: theme.bgCard,
borderRadius: 12,
border: `1px solid ${theme.border}`,
boxShadow: !isDark ? '0 2px 8px rgba(226,0,116,0.04)' : 'none',
}}
>
<div style={{ display: 'flex', justifyContent: 'space-between', gap: 12, marginBottom: 6 }}> <div style={{ display: 'flex', justifyContent: 'space-between', gap: 12, marginBottom: 6 }}>
<div style={{ fontSize: 14, fontWeight: 600, color: theme.text }}>{item.file}</div> <div style={{ fontSize: 14, fontWeight: 600, color: theme.text }}>{item.file}</div>
<div className="mono" style={{ fontSize: 11, color: theme.accent }}> <div className="mono" style={{ fontSize: 11, color: theme.accent }}>{(item.score * 100).toFixed(1)}%</div>
{(item.score * 100).toFixed(1)}%
</div> </div>
</div>
<div className="mono" style={{ fontSize: 11, color: theme.text3, marginBottom: 8 }}> <div className="mono" style={{ fontSize: 11, color: theme.text3, marginBottom: 8 }}>
{item.clause} {item.clause}{item.tags.length > 0 ? ` · ${item.tags.join(' · ')}` : ''}
{item.tags.length > 0 ? ` · ${item.tags.join(' · ')}` : ''}
</div> </div>
<div style={{ fontSize: 12, color: theme.text2, lineHeight: 1.6 }}>{item.content}</div> <div style={{ fontSize: 12, color: theme.text2, lineHeight: 1.6 }}>{item.content}</div>
</div> </div>
))} ))}
{!searchLoading && searchResults.length === 0 && ( {!searchLoading && searchResults.length === 0 && (
<div <div style={{ padding: 24, borderRadius: 12, background: theme.bgCard, border: `1px solid ${theme.border}`, textAlign: 'center', color: theme.text3 }}>
style={{
padding: 24,
borderRadius: 12,
background: theme.bgCard,
border: `1px solid ${theme.border}`,
textAlign: 'center',
color: theme.text3,
}}
>
</div> </div>
)} )}

View File

@@ -0,0 +1,58 @@
import React, { useRef } from 'react';
import { useTheme } from '../../contexts';
import type { RetrievalData } from '../../types';
interface CitedAnswerProps {
text: string;
sources: RetrievalData[];
onCiteClick: (index: number) => void;
}
export const CitedAnswer: React.FC<CitedAnswerProps> = ({ text, sources, onCiteClick }) => {
const { theme } = useTheme();
if (!text) return null;
// Split on [N] patterns, preserving delimiters
const parts = text.split(/(\[\d+\])/g);
return (
<span style={{ whiteSpace: 'pre-wrap', lineHeight: 1.6 }}>
{parts.map((part, i) => {
const match = part.match(/^\[(\d+)\]$/);
if (!match) return <React.Fragment key={i}>{part}</React.Fragment>;
const idx = parseInt(match[1], 10);
const source = sources[idx - 1];
return (
<sup
key={i}
title={source ? `${source.file} · ${source.clause}` : `引用 ${idx}`}
onClick={() => onCiteClick(idx)}
style={{
display: 'inline-flex',
alignItems: 'center',
justifyContent: 'center',
minWidth: 18,
height: 18,
padding: '0 4px',
marginLeft: 1,
fontSize: 10,
fontWeight: 700,
background: theme.gradientAccent,
color: '#fff',
borderRadius: 4,
cursor: source ? 'pointer' : 'default',
verticalAlign: 'super',
lineHeight: 1,
userSelect: 'none',
}}
>
{idx}
</sup>
);
})}
</span>
);
};

View File

@@ -2,6 +2,7 @@ import React, { useEffect, useRef, useState } from 'react';
import { useTheme } from '../../contexts'; import { useTheme } from '../../contexts';
import type { ChatMessage, RetrievalData } from '../../types'; import type { ChatMessage, RetrievalData } from '../../types';
import { getQuickQuestions, ragChat } from '../../api/rag'; import { getQuickQuestions, ragChat } from '../../api/rag';
import { CitedAnswer } from './CitedAnswer';
const ragQuickQuestionsDefault = [ const ragQuickQuestionsDefault = [
'电动自行车上路需要什么条件?', '电动自行车上路需要什么条件?',
@@ -24,6 +25,8 @@ export const RagChatPage: React.FC = () => {
const [showClearConfirm, setShowClearConfirm] = useState<boolean>(false); const [showClearConfirm, setShowClearConfirm] = useState<boolean>(false);
const [selectedRetrieval, setSelectedRetrieval] = useState<RetrievalData | null>(null); const [selectedRetrieval, setSelectedRetrieval] = useState<RetrievalData | null>(null);
const [quickQuestions, setQuickQuestions] = useState<string[]>(ragQuickQuestionsDefault); const [quickQuestions, setQuickQuestions] = useState<string[]>(ragQuickQuestionsDefault);
const [filterRegulationType, setFilterRegulationType] = useState<string>('');
const [highlightedSourceIdx, setHighlightedSourceIdx] = useState<number | null>(null);
function nextMessageId() { function nextMessageId() {
const currentId = nextMessageIdRef.current; const currentId = nextMessageIdRef.current;
@@ -55,8 +58,10 @@ export const RagChatPage: React.FC = () => {
setInput(''); setInput('');
setLoading(true); setLoading(true);
setRetrievals([]); setRetrievals([]);
setHighlightedSourceIdx(null);
let currentResponse = ''; let currentResponse = '';
const activeFilters = filterRegulationType.trim() || undefined;
void ragChat( void ragChat(
text, text,
@@ -112,7 +117,8 @@ export const RagChatPage: React.FC = () => {
}, },
() => { () => {
setLoading(false); setLoading(false);
} },
activeFilters
); );
}; };
@@ -130,8 +136,10 @@ export const RagChatPage: React.FC = () => {
setLoading(true); setLoading(true);
setMessages((prev) => [...prev.slice(0, -1)]); setMessages((prev) => [...prev.slice(0, -1)]);
setRetrievals([]); setRetrievals([]);
setHighlightedSourceIdx(null);
let currentResponse = ''; let currentResponse = '';
const activeFilters = filterRegulationType.trim() || undefined;
void ragChat( void ragChat(
lastUserMsg.content, lastUserMsg.content,
@@ -185,7 +193,8 @@ export const RagChatPage: React.FC = () => {
}, },
() => { () => {
setLoading(false); setLoading(false);
} },
activeFilters
); );
}; };
@@ -267,7 +276,17 @@ export const RagChatPage: React.FC = () => {
whiteSpace: 'pre-wrap', whiteSpace: 'pre-wrap',
border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none', border: msg.role === 'assistant' ? `1px solid ${theme.border}` : 'none',
}}> }}>
{msg.content} {msg.role === 'assistant' ? (
<CitedAnswer
text={msg.content}
sources={retrievals}
onCiteClick={(idx) => {
setHighlightedSourceIdx(idx);
const el = document.getElementById(`source-${idx}`);
if (el) el.scrollIntoView({ behavior: 'smooth', block: 'center' });
}}
/>
) : msg.content}
{msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && ( {msg.role === 'assistant' && msg.retrievalIds && msg.retrievalIds.length > 0 && (
<div style={{ <div style={{
marginTop: 10, marginTop: 10,
@@ -331,6 +350,31 @@ export const RagChatPage: React.FC = () => {
background: theme.bg, background: theme.bg,
borderTop: `1px solid ${theme.border}`, borderTop: `1px solid ${theme.border}`,
}}> }}>
<div style={{
display: 'flex',
gap: 8,
marginBottom: 10,
alignItems: 'center',
}}>
<span className="mono" style={{ fontSize: 11, color: theme.text3, whiteSpace: 'nowrap' }}></span>
<input
value={filterRegulationType}
onChange={(e) => setFilterRegulationType(e.target.value)}
placeholder="如: GB / UN-ECE / IATF留空不过滤"
style={{
flex: 1,
maxWidth: 280,
padding: '5px 10px',
fontSize: 12,
background: theme.bgHover,
border: `1px solid ${theme.border}`,
borderRadius: 6,
color: theme.text,
outline: 'none',
}}
/>
</div>
<div style={{ <div style={{
display: 'flex', display: 'flex',
gap: 8, gap: 8,
@@ -468,14 +512,16 @@ export const RagChatPage: React.FC = () => {
{retrievals.map((r, i) => ( {retrievals.map((r, i) => (
<div <div
key={r.id} key={r.id}
id={`source-${i + 1}`}
onClick={() => setSelectedRetrieval(r)} onClick={() => setSelectedRetrieval(r)}
style={{ style={{
padding: 16, padding: 16,
background: theme.bgHover, background: highlightedSourceIdx === i + 1 ? theme.bgElevated : theme.bgHover,
borderRadius: 10, borderRadius: 10,
border: `1px solid ${theme.border}`, border: `1px solid ${highlightedSourceIdx === i + 1 ? theme.accent : theme.border}`,
cursor: 'pointer', cursor: 'pointer',
position: 'relative', position: 'relative',
transition: 'border-color 0.2s, background 0.2s',
}} }}
> >
<div style={{ <div style={{

View File

@@ -8,6 +8,8 @@ export interface Doc {
downloadUrl?: string; downloadUrl?: string;
summary?: string; summary?: string;
updatedAt?: string; updatedAt?: string;
regulationType?: string;
version?: string;
} }
export interface SearchResult { export interface SearchResult {

View File

@@ -4,7 +4,7 @@ import react from '@vitejs/plugin-react'
// https://vite.dev/config/ // https://vite.dev/config/
export default defineConfig(({ mode }) => { export default defineConfig(({ mode }) => {
const env = loadEnv(mode, process.cwd(), '') const env = loadEnv(mode, process.cwd(), '')
const apiHost = env.API_HOST || 'localhost' const apiHost = env.API_HOST || '6.86.80.8'
const apiPort = env.API_PORT || '8000' const apiPort = env.API_PORT || '8000'
const proxyTarget = env.VITE_API_PROXY_TARGET || `http://${apiHost}:${apiPort}` const proxyTarget = env.VITE_API_PROXY_TARGET || `http://${apiHost}:${apiPort}`