fix 文档管理模块 & 法规对话模块
This commit is contained in:
@@ -23,6 +23,8 @@ class DocumentUploadResponse(BaseModel):
|
||||
num_chunks: int = Field(default=0, description="分块数量")
|
||||
summary: str = Field(default="", description="LLM生成的文档摘要")
|
||||
summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)")
|
||||
regulation_type: str = Field(default="", description="法规类型")
|
||||
version: str = Field(default="", description="文档版本号")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||
|
||||
|
||||
|
||||
@@ -14,30 +14,22 @@ from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
)
|
||||
from app.services.mock_data import (
|
||||
generate_task_id,
|
||||
get_mock_compliance_result,
|
||||
get_mock_compliance_chat_response,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store: dict[str, dict] = {}
|
||||
|
||||
# Store uploaded compliance files inside the local backend data directory.
|
||||
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||
async def analyze_document(file: UploadFile = File(...)):
|
||||
"""Handle analyze document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
task_id = generate_task_id()
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
|
||||
|
||||
@@ -45,7 +37,6 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
with file_path.open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id] = {
|
||||
"task_id": task_id,
|
||||
"file_path": str(file_path),
|
||||
@@ -53,8 +44,6 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
"result": None,
|
||||
}
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id]["status"] = "completed"
|
||||
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
|
||||
|
||||
@@ -65,47 +54,44 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
async def get_result(task_id: str):
|
||||
"""Return result."""
|
||||
if task_id not in tasks_store:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
return get_mock_compliance_result(task_id)
|
||||
|
||||
task = tasks_store[task_id]
|
||||
|
||||
if task["status"] == "processing":
|
||||
return {"status": "processing", "message": "分析进行中"}
|
||||
|
||||
return task["result"]
|
||||
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""Handle compliance chat."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
intent_map = {
|
||||
1: "车身结构设计",
|
||||
2: "动力系统配置",
|
||||
3: "安全配置设计",
|
||||
}
|
||||
intent = intent_map.get(segment_id, "车身结构设计")
|
||||
"""Stream compliance Q&A grounded in real vector retrieval."""
|
||||
query = request.query
|
||||
if request.segment_context:
|
||||
query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"
|
||||
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
top_k=5,
|
||||
prompt_template="compliance_qa",
|
||||
)
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
response = get_mock_compliance_chat_response(intent, request.query)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = response.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05)
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
"""Translate agent SSE events to compliance chunk/done format."""
|
||||
for event in event_stream:
|
||||
event_type = event.get("event", "")
|
||||
if event_type == "content":
|
||||
text = event.get("data", "")
|
||||
if text:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
|
||||
@@ -80,6 +80,8 @@ async def get_document_status(doc_id: str):
|
||||
num_chunks=document.chunk_count,
|
||||
summary=document.summary,
|
||||
summary_latency_ms=document.summary_latency_ms,
|
||||
regulation_type=document.regulation_type,
|
||||
version=document.version,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,7 +125,7 @@ async def list_documents():
|
||||
@router.get("/management-list")
|
||||
async def get_document_management_list():
|
||||
"""Return document management list."""
|
||||
documents = get_document_query_service().list_documents(limit=10)
|
||||
documents = get_document_query_service().list_documents()
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
@@ -131,10 +133,37 @@ async def get_document_management_list():
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"size_bytes": item.size_bytes,
|
||||
"summary": item.summary,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
"regulation_type": item.regulation_type,
|
||||
"version": item.version,
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
"""Delete a document and its associated data."""
|
||||
deleted = get_document_command_service().delete(doc_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="文档不存在")
|
||||
return {"doc_id": doc_id, "deleted": True}
|
||||
|
||||
|
||||
@router.post("/{doc_id}/retry", response_model=DocumentUploadResponse)
|
||||
async def retry_document(doc_id: str):
|
||||
"""Re-process a failed document."""
|
||||
try:
|
||||
result = get_document_command_service().retry(doc_id)
|
||||
if result.status == "failed":
|
||||
raise HTTPException(status_code=500, detail=result.message)
|
||||
return _document_response(result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("文档重试失败")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
@@ -9,73 +9,74 @@ from typing import AsyncGenerator
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.services.mock_data import (
|
||||
get_mock_quick_questions,
|
||||
get_mock_retrieval,
|
||||
get_mock_rag_answer,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||||
|
||||
_DEFAULT_QUICK_QUESTIONS = [
|
||||
{"id": "1", "question": "请总结最新入库法规对电池安全的核心要求", "category": "法规解读"},
|
||||
{"id": "2", "question": "我上传的制度文档与新能源法规有哪些潜在冲突?", "category": "差距分析"},
|
||||
{"id": "3", "question": "请给出法规依据,并按条款列出整改建议", "category": "整改建议"},
|
||||
{"id": "4", "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异", "category": "标准对比"},
|
||||
{"id": "5", "question": "IATF 16949 对供应商质量管理有哪些强制要求?", "category": "法规解读"},
|
||||
{"id": "6", "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?", "category": "标准对比"},
|
||||
]
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
"""Handle rag chat."""
|
||||
"""Stream RAG Q&A using the real agent service."""
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
)
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
docs = get_mock_retrieval(request.query, top_k=request.top_k)
|
||||
|
||||
retrieved_data = [
|
||||
{
|
||||
"id": d["id"],
|
||||
"score": d["score"],
|
||||
"preview": d["preview"],
|
||||
"doc_name": d.get("doc_name", ""),
|
||||
"clause": d.get("clause", ""),
|
||||
}
|
||||
for d in docs
|
||||
]
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
f"event: message\ndata: "
|
||||
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
answer = get_mock_rag_answer(request.query)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = answer.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
"""Translate agent SSE events to rag format."""
|
||||
for event in event_stream:
|
||||
event_type = event.get("event", "")
|
||||
data = event.get("data", "")
|
||||
if event_type == "sources":
|
||||
docs = [
|
||||
{
|
||||
"id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1),
|
||||
"score": s.get("score", 0),
|
||||
"preview": s.get("content", "")[:200],
|
||||
"doc_name": s.get("doc_name", ""),
|
||||
"clause": s.get("section_title", "法规片段"),
|
||||
"doc_id": s.get("doc_id"),
|
||||
"download_url": (
|
||||
f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None
|
||||
),
|
||||
}
|
||||
for idx, s in enumerate(data if isinstance(data, list) else [])
|
||||
]
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "content":
|
||||
if data:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "status":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
@@ -86,9 +87,13 @@ async def rag_chat(request: RagChatRequest):
|
||||
|
||||
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
|
||||
async def get_quick_questions():
|
||||
"""Return quick questions."""
|
||||
questions = [
|
||||
QuickQuestion(id=q["id"], question=q["question"], category=q["category"])
|
||||
for q in get_mock_quick_questions()
|
||||
]
|
||||
"""Return configurable quick questions from settings or defaults."""
|
||||
raw = getattr(settings, "rag_quick_questions", None)
|
||||
if raw and isinstance(raw, list):
|
||||
questions = [
|
||||
QuickQuestion(id=str(i + 1), question=q if isinstance(q, str) else q.get("question", ""), category=q.get("category", "法规问答") if isinstance(q, dict) else "法规问答")
|
||||
for i, q in enumerate(raw)
|
||||
]
|
||||
else:
|
||||
questions = [QuickQuestion(**q) for q in _DEFAULT_QUICK_QUESTIONS]
|
||||
return QuickQuestionsResponse(questions=questions)
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.domain.documents import (
|
||||
DocumentParser,
|
||||
DocumentRepository,
|
||||
DocumentStatus,
|
||||
ParseArtifactStore,
|
||||
ParsedDocument,
|
||||
)
|
||||
from app.domain.retrieval import EmbeddingProvider, VectorIndex
|
||||
@@ -47,6 +48,7 @@ class DocumentCommandService:
|
||||
chunk_builder: ChunkBuilder,
|
||||
embedding_provider: EmbeddingProvider,
|
||||
vector_index: VectorIndex,
|
||||
parse_artifact_store: ParseArtifactStore | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Document Command Service instance."""
|
||||
self.document_repository = document_repository
|
||||
@@ -55,6 +57,7 @@ class DocumentCommandService:
|
||||
self.chunk_builder = chunk_builder
|
||||
self.embedding_provider = embedding_provider
|
||||
self.vector_index = vector_index
|
||||
self.parse_artifact_store = parse_artifact_store
|
||||
|
||||
def _save_parse_artifacts(self, *, doc_id: str, parsed_document: ParsedDocument) -> dict[str, str]:
|
||||
"""Persist parse artifacts so troubleshooting does not depend on provider retention windows."""
|
||||
@@ -143,6 +146,15 @@ class DocumentCommandService:
|
||||
"processing_stage": "parsed",
|
||||
},
|
||||
)
|
||||
if self.parse_artifact_store:
|
||||
try:
|
||||
self.parse_artifact_store.save(
|
||||
doc_id,
|
||||
parsed_document.structure_nodes,
|
||||
parsed_document.semantic_blocks,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
|
||||
|
||||
chunks = self.chunk_builder.build(
|
||||
parsed_document=parsed_document,
|
||||
@@ -205,20 +217,120 @@ class DocumentCommandService:
|
||||
logger.warning("临时文件清理失败: {}", temp_path)
|
||||
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
"""Delete document record, binary file, and vector chunks."""
|
||||
document = self.document_repository.get(doc_id)
|
||||
if not document:
|
||||
return False
|
||||
try:
|
||||
self.binary_store.delete(document.object_name)
|
||||
except Exception:
|
||||
logger.warning("Binary delete failed for doc_id={}", doc_id)
|
||||
try:
|
||||
self.vector_index.delete_by_document(doc_id)
|
||||
except Exception:
|
||||
logger.warning("Vector delete failed for doc_id={}", doc_id)
|
||||
if self.parse_artifact_store:
|
||||
try:
|
||||
self.parse_artifact_store.delete(doc_id)
|
||||
except Exception:
|
||||
logger.warning("ParseArtifactStore delete failed for doc_id={}", doc_id)
|
||||
self.document_repository.delete(doc_id)
|
||||
return True
|
||||
|
||||
def retry(self, doc_id: str) -> DocumentProcessResult:
|
||||
"""Re-process a failed document from its stored binary."""
|
||||
document = self.document_repository.get(doc_id)
|
||||
if not document:
|
||||
return DocumentProcessResult(doc_id=doc_id, doc_name="", status="failed", message="文档不存在")
|
||||
content = self.binary_store.read(document.object_name)
|
||||
return self.upload_and_process(
|
||||
doc_id=doc_id,
|
||||
file_name=document.file_name,
|
||||
content=content,
|
||||
content_type=document.content_type,
|
||||
doc_name=document.doc_name,
|
||||
regulation_type=document.regulation_type,
|
||||
version=document.version,
|
||||
generate_summary=bool(document.metadata.get("generate_summary", False)),
|
||||
)
|
||||
|
||||
|
||||
class DocumentQueryService:
|
||||
"""Provide the Document Query Service service."""
|
||||
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore) -> None:
|
||||
def __init__(self, *, document_repository: DocumentRepository, binary_store: DocumentBinaryStore, vector_index: VectorIndex) -> None:
|
||||
"""Initialize the Document Query Service instance."""
|
||||
self.document_repository = document_repository
|
||||
self.binary_store = binary_store
|
||||
self.vector_index = vector_index
|
||||
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
"""Handle get for the Document Query Service instance."""
|
||||
return self.document_repository.get(doc_id)
|
||||
|
||||
def list_documents(self, limit: int | None = None) -> list[Document]:
|
||||
"""List documents for the Document Query Service instance."""
|
||||
return self.document_repository.list(limit=limit)
|
||||
"""Return documents with real-time state from Milvus as the authoritative source.
|
||||
|
||||
Algorithm:
|
||||
1. Query Milvus for all doc metadata (doc_id, doc_name, chunk_count, …).
|
||||
2. Load JSON/PG metadata records and index them by doc_id.
|
||||
3. Merge: Milvus-present docs get status=INDEXED and live chunk_count;
|
||||
metadata-only docs with status=INDEXED are demoted to FAILED.
|
||||
4. Milvus-only docs (no metadata record) are surfaced as synthetic INDEXED
|
||||
entries so they are never invisible to the management list.
|
||||
"""
|
||||
# Fetch live Milvus state first.
|
||||
try:
|
||||
milvus_rows = self.vector_index.list_document_metadata()
|
||||
except Exception:
|
||||
milvus_rows = []
|
||||
|
||||
milvus_by_id: dict[str, dict] = {r["doc_id"]: r for r in milvus_rows}
|
||||
|
||||
# Load metadata store records.
|
||||
meta_docs = self.document_repository.list(limit=limit)
|
||||
meta_by_id: dict[str, Document] = {d.doc_id: d for d in meta_docs}
|
||||
|
||||
result: list[Document] = []
|
||||
|
||||
# Reconcile metadata records against Milvus.
|
||||
for doc in meta_docs:
|
||||
if doc.doc_id in milvus_by_id:
|
||||
row = milvus_by_id[doc.doc_id]
|
||||
doc.chunk_count = row["chunk_count"]
|
||||
doc.status = DocumentStatus.INDEXED
|
||||
# Backfill fields that may be missing from older JSON records.
|
||||
if not doc.doc_name and row.get("doc_name"):
|
||||
doc.doc_name = row["doc_name"]
|
||||
if not doc.regulation_type and row.get("regulation_type"):
|
||||
doc.regulation_type = row["regulation_type"]
|
||||
if not doc.version and row.get("version"):
|
||||
doc.version = row["version"]
|
||||
elif doc.status == DocumentStatus.INDEXED:
|
||||
# Metadata says indexed but Milvus has no chunks.
|
||||
doc.status = DocumentStatus.FAILED
|
||||
doc.error_message = "向量数据库中未找到对应数据"
|
||||
result.append(doc)
|
||||
|
||||
# Surface Milvus-only docs that have no metadata record at all.
|
||||
for doc_id, row in milvus_by_id.items():
|
||||
if doc_id not in meta_by_id:
|
||||
synthetic = Document(
|
||||
doc_id=doc_id,
|
||||
doc_name=row.get("doc_name", doc_id),
|
||||
file_name=row.get("doc_name", doc_id),
|
||||
object_name="",
|
||||
content_type="",
|
||||
size_bytes=0,
|
||||
status=DocumentStatus.INDEXED,
|
||||
regulation_type=row.get("regulation_type", ""),
|
||||
version=row.get("version", ""),
|
||||
chunk_count=row["chunk_count"],
|
||||
)
|
||||
result.append(synthetic)
|
||||
|
||||
result.sort(key=lambda d: d.updated_at, reverse=True)
|
||||
return result[:limit] if limit is not None else result
|
||||
|
||||
def download(self, doc_id: str) -> tuple[Document, bytes]:
|
||||
"""Handle download for the Document Query Service instance."""
|
||||
|
||||
@@ -3,17 +3,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.domain.retrieval import RetrievalQuery, Retriever, RetrievedChunk
|
||||
from app.domain.retrieval.ports import Reranker
|
||||
# Keep orchestration logic centralized so use-case flow stays easy to trace.
|
||||
|
||||
|
||||
|
||||
class KnowledgeRetrievalService:
|
||||
"""Provide the Knowledge Retrieval Service service."""
|
||||
def __init__(self, *, retriever: Retriever) -> None:
|
||||
|
||||
def __init__(self, *, retriever: Retriever, reranker: Reranker | None = None, reranker_top_k: int = 5) -> None:
|
||||
"""Initialize the Knowledge Retrieval Service instance."""
|
||||
self.retriever = retriever
|
||||
self.reranker = reranker
|
||||
self.reranker_top_k = reranker_top_k
|
||||
|
||||
def retrieve(self, *, query: str, top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle retrieve for the Knowledge Retrieval Service instance."""
|
||||
retrieval_query = RetrievalQuery(query=query, top_k=top_k, filters=filters)
|
||||
return self.retriever.retrieve(retrieval_query)
|
||||
"""Retrieve and optionally rerank chunks for a query."""
|
||||
candidate_k = top_k if self.reranker is None else max(top_k * 4, 20)
|
||||
retrieval_query = RetrievalQuery(query=query, top_k=candidate_k, filters=filters)
|
||||
candidates = self.retriever.retrieve(retrieval_query)
|
||||
if self.reranker and candidates:
|
||||
return self.reranker.rerank(query, candidates, top_k=self.reranker_top_k)
|
||||
return candidates[:top_k]
|
||||
|
||||
@@ -31,7 +31,7 @@ class Settings(BaseSettings):
|
||||
debug: bool = Field(default=False, description="调试模式")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
milvus_host: str = Field(default="localhost", description="Milvus服务地址")
|
||||
milvus_host: str = Field(default="6.86.80.8", description="Milvus服务地址")
|
||||
milvus_port: int = Field(default=19530, description="Milvus服务端口")
|
||||
milvus_collection: str = Field(default="regulations_dense_1024_v1", description="法规向量集合名称")
|
||||
milvus_db_name: str = Field(default="default", description="Milvus数据库名称")
|
||||
@@ -54,20 +54,20 @@ class Settings(BaseSettings):
|
||||
parser_failure_mode: str = Field(default="fail", description="解析失败策略")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
minio_endpoint: str = Field(default="localhost:9000", description="MinIO服务地址")
|
||||
minio_endpoint: str = Field(default="6.86.80.8:9000", description="MinIO服务地址")
|
||||
minio_access_key: str = Field(default="minioadmin", description="MinIO访问密钥")
|
||||
minio_secret_key: str = Field(default="minioadmin123", description="MinIO秘密密钥")
|
||||
minio_bucket: str = Field(default="upload-files", description="文档存储桶名称")
|
||||
minio_secure: bool = Field(default=False, description="是否使用HTTPS")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
redis_host: str = Field(default="localhost", description="Redis服务地址")
|
||||
redis_host: str = Field(default="6.86.80.8", description="Redis服务地址")
|
||||
redis_port: int = Field(default=6379, description="Redis服务端口")
|
||||
redis_password: str = Field(default="", description="Redis密码")
|
||||
redis_db: int = Field(default=0, description="Redis数据库编号")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
postgres_host: str = Field(default="localhost", description="PostgreSQL服务地址")
|
||||
postgres_host: str = Field(default="6.86.80.8", description="PostgreSQL服务地址")
|
||||
postgres_port: int = Field(default=5432, description="PostgreSQL服务端口")
|
||||
postgres_user: str = Field(default="compliance", description="PostgreSQL用户名")
|
||||
postgres_password: str = Field(default="compliance123", description="PostgreSQL密码")
|
||||
@@ -80,6 +80,7 @@ class Settings(BaseSettings):
|
||||
document_metadata_path: str = Field(default="backend/data/documents.json", description="文档元数据存储路径")
|
||||
parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)")
|
||||
chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)")
|
||||
document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
api_host: str = Field(default="0.0.0.0", description="API服务地址")
|
||||
@@ -104,9 +105,16 @@ class Settings(BaseSettings):
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
rag_top_k: int = Field(default=5, description="检索召回数量")
|
||||
rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量(reranker 启用时生效)")
|
||||
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
|
||||
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
|
||||
|
||||
reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排")
|
||||
reranker_base_url: str = Field(default="", description="Reranker API 地址")
|
||||
reranker_model: str = Field(default="BAAI/bge-reranker-v2-m3", description="Reranker 模型名称")
|
||||
reranker_api_key: str = Field(default="", description="Reranker API 密钥")
|
||||
reranker_top_k: int = Field(default=5, description="精排后保留的最终结果数量")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
milvus_index_type: str = Field(default="IVF_FLAT", description="Milvus索引类型")
|
||||
milvus_nlist: int = Field(default=128, description="Milvus nlist参数")
|
||||
|
||||
@@ -25,7 +25,7 @@ class Settings(BaseSettings):
|
||||
dashscope_api_key: str = ""
|
||||
|
||||
# Milvus
|
||||
milvus_host: str = "localhost"
|
||||
milvus_host: str = "6.86.80.8"
|
||||
milvus_port: int = 19530
|
||||
milvus_collection: str = "regulations_dense_1024_v1"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Initialize the app.domain.documents package."""
|
||||
|
||||
from .models import Chunk, Document, DocumentStatus, ParsedDocument
|
||||
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository
|
||||
from .ports import ChunkBuilder, DocumentBinaryStore, DocumentParser, DocumentRepository, ParseArtifactStore
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
@@ -14,4 +14,5 @@ __all__ = [
|
||||
"DocumentBinaryStore",
|
||||
"DocumentParser",
|
||||
"DocumentRepository",
|
||||
"ParseArtifactStore",
|
||||
]
|
||||
|
||||
@@ -31,6 +31,11 @@ class DocumentRepository(ABC):
|
||||
"""Handle list for the Document Repository instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
"""Delete a document record. Returns True if deleted, False if not found."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update_status(
|
||||
self,
|
||||
@@ -94,3 +99,32 @@ class ChunkBuilder(ABC):
|
||||
) -> list[Chunk]:
|
||||
"""Handle build for the Chunk Builder instance."""
|
||||
pass
|
||||
|
||||
|
||||
class ParseArtifactStore(ABC):
|
||||
"""Persist parse artifacts (structure nodes and semantic blocks) for relational queries."""
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
doc_id: str,
|
||||
structure_nodes: list[dict],
|
||||
semantic_blocks: list[dict],
|
||||
) -> None:
|
||||
"""Persist structure nodes and semantic blocks for a document."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, doc_id: str) -> None:
|
||||
"""Remove all parse artifacts for a document."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_semantic_blocks(self, doc_id: str) -> list[dict]:
|
||||
"""Return all semantic blocks for a document."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_structure_nodes(self, doc_id: str) -> list[dict]:
|
||||
"""Return all structure nodes for a document."""
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Initialize the app.domain.retrieval package."""
|
||||
|
||||
from .models import RetrievalQuery, RetrievedChunk
|
||||
from .ports import EmbeddingProvider, Retriever, VectorIndex
|
||||
from .ports import EmbeddingProvider, Reranker, Retriever, VectorIndex
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Retriever", "VectorIndex"]
|
||||
__all__ = ["RetrievalQuery", "RetrievedChunk", "EmbeddingProvider", "Reranker", "Retriever", "VectorIndex"]
|
||||
|
||||
@@ -10,7 +10,6 @@ from .models import RetrievalQuery, RetrievedChunk
|
||||
# Keep domain contracts explicit so adapters can swap implementations cleanly.
|
||||
|
||||
|
||||
|
||||
class EmbeddingProvider(ABC):
|
||||
"""Provide the Embedding Provider provider."""
|
||||
@abstractmethod
|
||||
@@ -41,12 +40,35 @@ class VectorIndex(ABC):
|
||||
"""Handle search for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def count_by_document(self) -> dict[str, int]:
|
||||
"""Return a mapping of doc_id -> chunk count from the vector store."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_document_metadata(self) -> list[dict]:
|
||||
"""Return per-document metadata rows from the vector store.
|
||||
|
||||
Each row contains at minimum: doc_id, doc_name, chunk_count.
|
||||
Optional fields: regulation_type, version.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Vector Index instance."""
|
||||
pass
|
||||
|
||||
|
||||
class Reranker(ABC):
|
||||
"""Re-score and re-order a candidate list using a cross-encoder model."""
|
||||
|
||||
@abstractmethod
|
||||
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
|
||||
"""Return top_k chunks sorted by cross-encoder score (descending)."""
|
||||
pass
|
||||
|
||||
|
||||
class Retriever(ABC):
|
||||
"""Provide the Retriever retriever."""
|
||||
@abstractmethod
|
||||
|
||||
@@ -289,7 +289,7 @@ def build_vector_chunks(
|
||||
{
|
||||
"doc_id": doc_id,
|
||||
"doc_title": doc_title,
|
||||
"chunk_id": f"chunk-{chunk_index}",
|
||||
"chunk_id": f"{doc_id}-chunk-{chunk_index}",
|
||||
"chunk_index": chunk_index,
|
||||
"semantic_id": block["semantic_id"],
|
||||
"chunk_type": block["block_type"],
|
||||
|
||||
@@ -75,6 +75,15 @@ class JsonDocumentRepository(DocumentRepository):
|
||||
documents.sort(key=lambda item: item.updated_at, reverse=True)
|
||||
return documents[:limit] if limit is not None else documents
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
"""Delete a document record."""
|
||||
payload = self._load()
|
||||
if doc_id not in payload:
|
||||
return False
|
||||
del payload[doc_id]
|
||||
self._save(payload)
|
||||
return True
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Implement infrastructure support for postgres document repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import Document, DocumentRepository, DocumentStatus
|
||||
|
||||
_CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
doc_id VARCHAR(128) PRIMARY KEY,
|
||||
doc_name VARCHAR(512) NOT NULL DEFAULT '',
|
||||
file_name VARCHAR(512) NOT NULL DEFAULT '',
|
||||
object_name VARCHAR(1024) NOT NULL DEFAULT '',
|
||||
content_type VARCHAR(128) NOT NULL DEFAULT '',
|
||||
size_bytes BIGINT NOT NULL DEFAULT 0,
|
||||
status VARCHAR(32) NOT NULL DEFAULT 'pending',
|
||||
regulation_type VARCHAR(128) NOT NULL DEFAULT '',
|
||||
version VARCHAR(64) NOT NULL DEFAULT '',
|
||||
summary TEXT NOT NULL DEFAULT '',
|
||||
summary_latency_ms INTEGER NOT NULL DEFAULT 0,
|
||||
chunk_count INTEGER NOT NULL DEFAULT 0,
|
||||
parser_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
index_name VARCHAR(128) NOT NULL DEFAULT '',
|
||||
error_message TEXT NOT NULL DEFAULT '',
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
_COLUMNS = (
|
||||
"doc_id", "doc_name", "file_name", "object_name", "content_type",
|
||||
"size_bytes", "status", "regulation_type", "version", "summary",
|
||||
"summary_latency_ms", "chunk_count", "parser_name", "index_name",
|
||||
"error_message", "metadata", "created_at", "updated_at",
|
||||
)
|
||||
|
||||
|
||||
class PostgresDocumentRepository(DocumentRepository):
|
||||
"""DocumentRepository implementation backed by PostgreSQL."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool = ThreadedConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=5,
|
||||
host=settings.postgres_host,
|
||||
port=settings.postgres_port,
|
||||
user=settings.postgres_user,
|
||||
password=settings.postgres_password,
|
||||
dbname=settings.postgres_db,
|
||||
)
|
||||
self._ensure_schema()
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(_CREATE_TABLE)
|
||||
conn.commit()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = self._pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self._pool.putconn(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Serialization helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _row_to_document(self, row: dict[str, Any]) -> Document:
|
||||
return Document(
|
||||
doc_id=row["doc_id"],
|
||||
doc_name=row["doc_name"],
|
||||
file_name=row["file_name"],
|
||||
object_name=row["object_name"],
|
||||
content_type=row["content_type"],
|
||||
size_bytes=row["size_bytes"],
|
||||
status=DocumentStatus(row["status"]),
|
||||
regulation_type=row["regulation_type"],
|
||||
version=row["version"],
|
||||
summary=row["summary"],
|
||||
summary_latency_ms=row["summary_latency_ms"],
|
||||
chunk_count=row["chunk_count"],
|
||||
parser_name=row["parser_name"],
|
||||
index_name=row["index_name"],
|
||||
error_message=row["error_message"],
|
||||
metadata=row["metadata"] if isinstance(row["metadata"], dict) else json.loads(row["metadata"] or "{}"),
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# DocumentRepository interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create(self, document: Document) -> Document:
|
||||
sql = """
|
||||
INSERT INTO documents
|
||||
(doc_id, doc_name, file_name, object_name, content_type, size_bytes,
|
||||
status, regulation_type, version, summary, summary_latency_ms,
|
||||
chunk_count, parser_name, index_name, error_message, metadata,
|
||||
created_at, updated_at)
|
||||
VALUES
|
||||
(%(doc_id)s, %(doc_name)s, %(file_name)s, %(object_name)s, %(content_type)s,
|
||||
%(size_bytes)s, %(status)s, %(regulation_type)s, %(version)s, %(summary)s,
|
||||
%(summary_latency_ms)s, %(chunk_count)s, %(parser_name)s, %(index_name)s,
|
||||
%(error_message)s, %(metadata)s, %(created_at)s, %(updated_at)s)
|
||||
ON CONFLICT (doc_id) DO NOTHING
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, self._to_params(document))
|
||||
conn.commit()
|
||||
return document
|
||||
|
||||
def update(self, document: Document) -> Document:
|
||||
document.updated_at = datetime.now(UTC)
|
||||
sql = """
|
||||
UPDATE documents SET
|
||||
doc_name=%(doc_name)s, file_name=%(file_name)s, object_name=%(object_name)s,
|
||||
content_type=%(content_type)s, size_bytes=%(size_bytes)s, status=%(status)s,
|
||||
regulation_type=%(regulation_type)s, version=%(version)s, summary=%(summary)s,
|
||||
summary_latency_ms=%(summary_latency_ms)s, chunk_count=%(chunk_count)s,
|
||||
parser_name=%(parser_name)s, index_name=%(index_name)s,
|
||||
error_message=%(error_message)s, metadata=%(metadata)s, updated_at=%(updated_at)s
|
||||
WHERE doc_id=%(doc_id)s
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, self._to_params(document))
|
||||
conn.commit()
|
||||
return document
|
||||
|
||||
def get(self, doc_id: str) -> Document | None:
|
||||
sql = "SELECT * FROM documents WHERE doc_id = %s"
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
row = cur.fetchone()
|
||||
return self._row_to_document(dict(row)) if row else None
|
||||
|
||||
def list(self, limit: int | None = None) -> list[Document]:
|
||||
sql = "SELECT * FROM documents ORDER BY updated_at DESC"
|
||||
if limit is not None:
|
||||
sql += f" LIMIT {int(limit)}"
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql)
|
||||
rows = cur.fetchall()
|
||||
return [self._row_to_document(dict(r)) for r in rows]
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
sql = "DELETE FROM documents WHERE doc_id = %s"
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
deleted = cur.rowcount > 0
|
||||
conn.commit()
|
||||
return deleted
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
doc_id: str,
|
||||
status: DocumentStatus,
|
||||
*,
|
||||
error_message: str = "",
|
||||
chunk_count: int | None = None,
|
||||
summary: str | None = None,
|
||||
summary_latency_ms: int | None = None,
|
||||
parser_name: str | None = None,
|
||||
index_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> Document | None:
|
||||
document = self.get(doc_id)
|
||||
if not document:
|
||||
return None
|
||||
document.status = status
|
||||
document.error_message = error_message
|
||||
if chunk_count is not None:
|
||||
document.chunk_count = chunk_count
|
||||
if summary is not None:
|
||||
document.summary = summary
|
||||
if summary_latency_ms is not None:
|
||||
document.summary_latency_ms = summary_latency_ms
|
||||
if parser_name is not None:
|
||||
document.parser_name = parser_name
|
||||
if index_name is not None:
|
||||
document.index_name = index_name
|
||||
if metadata:
|
||||
document.metadata.update(metadata)
|
||||
return self.update(document)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _to_params(self, document: Document) -> dict[str, Any]:
|
||||
return {
|
||||
"doc_id": document.doc_id,
|
||||
"doc_name": document.doc_name,
|
||||
"file_name": document.file_name,
|
||||
"object_name": document.object_name,
|
||||
"content_type": document.content_type,
|
||||
"size_bytes": document.size_bytes,
|
||||
"status": document.status.value,
|
||||
"regulation_type": document.regulation_type,
|
||||
"version": document.version,
|
||||
"summary": document.summary,
|
||||
"summary_latency_ms": document.summary_latency_ms,
|
||||
"chunk_count": document.chunk_count,
|
||||
"parser_name": document.parser_name,
|
||||
"index_name": document.index_name,
|
||||
"error_message": document.error_message,
|
||||
"metadata": json.dumps(document.metadata, ensure_ascii=False),
|
||||
"created_at": document.created_at,
|
||||
"updated_at": document.updated_at,
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
"""Implement infrastructure support for postgres parse artifact store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.pool import ThreadedConnectionPool
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.documents import ParseArtifactStore
|
||||
|
||||
_CREATE_STRUCTURE_NODES = """
|
||||
CREATE TABLE IF NOT EXISTS structure_nodes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
doc_id VARCHAR(128) NOT NULL,
|
||||
unique_id VARCHAR(128),
|
||||
page INTEGER NOT NULL DEFAULT 0,
|
||||
idx INTEGER NOT NULL DEFAULT 0,
|
||||
level INTEGER NOT NULL DEFAULT 0,
|
||||
title TEXT NOT NULL DEFAULT '',
|
||||
type VARCHAR(64),
|
||||
sub_type VARCHAR(64),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT fk_sn_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_structure_nodes_doc_id ON structure_nodes(doc_id);
|
||||
"""
|
||||
|
||||
_CREATE_SEMANTIC_BLOCKS = """
|
||||
CREATE TABLE IF NOT EXISTS semantic_blocks (
|
||||
id SERIAL PRIMARY KEY,
|
||||
doc_id VARCHAR(128) NOT NULL,
|
||||
semantic_id VARCHAR(128) NOT NULL,
|
||||
block_type VARCHAR(64) NOT NULL DEFAULT '',
|
||||
page_start INTEGER NOT NULL DEFAULT 0,
|
||||
page_end INTEGER NOT NULL DEFAULT 0,
|
||||
section_path JSONB NOT NULL DEFAULT '[]',
|
||||
section_level INTEGER NOT NULL DEFAULT 0,
|
||||
section_title VARCHAR(512) NOT NULL DEFAULT '',
|
||||
source_ids JSONB NOT NULL DEFAULT '[]',
|
||||
text TEXT NOT NULL DEFAULT '',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
CONSTRAINT fk_sb_doc FOREIGN KEY (doc_id) REFERENCES documents(doc_id) ON DELETE CASCADE,
|
||||
CONSTRAINT uq_semantic_blocks UNIQUE (doc_id, semantic_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_semantic_blocks_doc_id ON semantic_blocks(doc_id);
|
||||
"""
|
||||
|
||||
|
||||
class PostgresParseArtifactStore(ParseArtifactStore):
|
||||
"""ParseArtifactStore implementation backed by PostgreSQL.
|
||||
|
||||
Requires the `documents` table to exist first (created by PostgresDocumentRepository).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pool = ThreadedConnectionPool(
|
||||
minconn=1,
|
||||
maxconn=5,
|
||||
host=settings.postgres_host,
|
||||
port=settings.postgres_port,
|
||||
user=settings.postgres_user,
|
||||
password=settings.postgres_password,
|
||||
dbname=settings.postgres_db,
|
||||
)
|
||||
self._ensure_schema()
|
||||
|
||||
def _ensure_schema(self) -> None:
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(_CREATE_STRUCTURE_NODES)
|
||||
cur.execute(_CREATE_SEMANTIC_BLOCKS)
|
||||
conn.commit()
|
||||
|
||||
@contextmanager
|
||||
def _conn(self):
|
||||
conn = self._pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
self._pool.putconn(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ParseArtifactStore interface
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def save(
|
||||
self,
|
||||
doc_id: str,
|
||||
structure_nodes: list[dict],
|
||||
semantic_blocks: list[dict],
|
||||
) -> None:
|
||||
"""Persist structure nodes and semantic blocks, replacing any existing records."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# Delete existing records first to keep save idempotent.
|
||||
cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
|
||||
cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
|
||||
|
||||
if structure_nodes:
|
||||
psycopg2.extras.execute_values(
|
||||
cur,
|
||||
"""
|
||||
INSERT INTO structure_nodes
|
||||
(doc_id, unique_id, page, idx, level, title, type, sub_type)
|
||||
VALUES %s
|
||||
""",
|
||||
[
|
||||
(
|
||||
doc_id,
|
||||
node.get("unique_id"),
|
||||
int(node.get("page", 0) or 0),
|
||||
int(node.get("index", 0) or 0),
|
||||
int(node.get("level", 0) or 0),
|
||||
str(node.get("title", "")),
|
||||
node.get("type"),
|
||||
node.get("sub_type"),
|
||||
)
|
||||
for node in structure_nodes
|
||||
],
|
||||
)
|
||||
|
||||
if semantic_blocks:
|
||||
psycopg2.extras.execute_values(
|
||||
cur,
|
||||
"""
|
||||
INSERT INTO semantic_blocks
|
||||
(doc_id, semantic_id, block_type, page_start, page_end,
|
||||
section_path, section_level, section_title, source_ids, text)
|
||||
VALUES %s
|
||||
""",
|
||||
[
|
||||
(
|
||||
doc_id,
|
||||
block.get("semantic_id", ""),
|
||||
block.get("block_type", ""),
|
||||
int(block.get("page_start", 0) or 0),
|
||||
int(block.get("page_end", 0) or 0),
|
||||
json.dumps(block.get("section_path", []), ensure_ascii=False),
|
||||
int(block.get("section_level", 0) or 0),
|
||||
str(block.get("section_title", "")),
|
||||
json.dumps(block.get("source_ids", []), ensure_ascii=False),
|
||||
str(block.get("text", "")),
|
||||
)
|
||||
for block in semantic_blocks
|
||||
],
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def delete(self, doc_id: str) -> None:
|
||||
"""Remove all parse artifacts for a document (ON DELETE CASCADE handles child rows)."""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("DELETE FROM structure_nodes WHERE doc_id = %s", (doc_id,))
|
||||
cur.execute("DELETE FROM semantic_blocks WHERE doc_id = %s", (doc_id,))
|
||||
conn.commit()
|
||||
|
||||
def get_semantic_blocks(self, doc_id: str) -> list[dict[str, Any]]:
|
||||
"""Return all semantic blocks for a document ordered by id."""
|
||||
sql = """
|
||||
SELECT semantic_id, block_type, page_start, page_end,
|
||||
section_path, section_level, section_title, source_ids, text
|
||||
FROM semantic_blocks
|
||||
WHERE doc_id = %s
|
||||
ORDER BY id
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
rows = cur.fetchall()
|
||||
results = []
|
||||
for row in rows:
|
||||
item = dict(row)
|
||||
for key in ("section_path", "source_ids"):
|
||||
if isinstance(item[key], str):
|
||||
item[key] = json.loads(item[key])
|
||||
results.append(item)
|
||||
return results
|
||||
|
||||
def get_structure_nodes(self, doc_id: str) -> list[dict[str, Any]]:
|
||||
"""Return all structure nodes for a document ordered by idx."""
|
||||
sql = """
|
||||
SELECT unique_id, page, idx, level, title, type, sub_type
|
||||
FROM structure_nodes
|
||||
WHERE doc_id = %s
|
||||
ORDER BY idx
|
||||
"""
|
||||
with self._conn() as conn:
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
|
||||
cur.execute(sql, (doc_id,))
|
||||
rows = cur.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Implement cross-encoder reranking via an OpenAI-compatible reranker API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.retrieval import Reranker, RetrievedChunk
|
||||
|
||||
|
||||
class OpenAICompatibleReranker(Reranker):
|
||||
"""Call a TEI / Cohere-style reranker endpoint to re-score retrieved chunks."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
api_key: str | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
self._base_url = (base_url or settings.reranker_base_url).rstrip("/")
|
||||
self._model = model or settings.reranker_model
|
||||
self._api_key = api_key or settings.reranker_api_key
|
||||
self._timeout = timeout
|
||||
|
||||
def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
|
||||
"""Return up to top_k chunks re-sorted by cross-encoder score."""
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
start = time.time()
|
||||
try:
|
||||
scores = self._call_reranker(query, texts)
|
||||
except Exception as exc:
|
||||
logger.warning("Reranker call failed ({}), falling back to original order: {}", type(exc).__name__, exc)
|
||||
return chunks[:top_k]
|
||||
|
||||
elapsed_ms = int((time.time() - start) * 1000)
|
||||
logger.debug("Reranker scored {} chunks in {}ms", len(chunks), elapsed_ms)
|
||||
|
||||
ranked = sorted(
|
||||
[(score, chunk) for score, chunk in zip(scores, chunks)],
|
||||
key=lambda x: x[0],
|
||||
reverse=True,
|
||||
)
|
||||
result = []
|
||||
for score, chunk in ranked[:top_k]:
|
||||
chunk.score = float(score)
|
||||
result.append(chunk)
|
||||
return result
|
||||
|
||||
def _call_reranker(self, query: str, texts: list[str]) -> list[float]:
|
||||
"""Call the reranker API and return a score per text."""
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self._api_key:
|
||||
headers["Authorization"] = f"Bearer {self._api_key}"
|
||||
|
||||
# Try TEI format first: POST /rerank
|
||||
payload = {"query": query, "texts": texts, "raw_scores": False, "return_text": False}
|
||||
url = f"{self._base_url}/rerank"
|
||||
resp = requests.post(url, json=payload, headers=headers, timeout=self._timeout)
|
||||
|
||||
if resp.status_code == 404:
|
||||
# Fall back to Cohere / OpenAI-style: POST /v1/rerank
|
||||
payload_v1 = {"model": self._model, "query": query, "documents": texts}
|
||||
url = f"{self._base_url}/v1/rerank"
|
||||
resp = requests.post(url, json=payload_v1, headers=headers, timeout=self._timeout)
|
||||
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# TEI response: list of {"index": N, "score": F}
|
||||
if isinstance(data, list):
|
||||
ordered = sorted(data, key=lambda x: x["index"])
|
||||
return [float(item["score"]) for item in ordered]
|
||||
|
||||
# Cohere/OpenAI response: {"results": [{"index": N, "relevance_score": F}]}
|
||||
results = data.get("results", [])
|
||||
ordered = sorted(results, key=lambda x: x["index"])
|
||||
return [float(item.get("relevance_score", item.get("score", 0))) for item in ordered]
|
||||
@@ -100,14 +100,42 @@ class MilvusVectorIndex(VectorIndex):
|
||||
result = self.collection.delete(f'doc_id == "{doc_id}"')
|
||||
return len(result.primary_keys)
|
||||
|
||||
def _parse_filters(self, filters: str | None) -> str | None:
|
||||
"""Parse filter string into Milvus expression."""
|
||||
if not filters or not filters.strip():
|
||||
return None
|
||||
|
||||
filters = filters.strip()
|
||||
|
||||
# Check if already a Milvus expression (contains operators)
|
||||
if any(op in filters for op in ["==", "!=", "in", "not in", ">", "<", ">=", "<=", "and", "or"]):
|
||||
return filters
|
||||
|
||||
# Parse simple regulation_type filter
|
||||
# Support: "GB" or "GB,UN-ECE" or "GB, UN-ECE"
|
||||
types = [t.strip() for t in filters.split(",") if t.strip()]
|
||||
|
||||
if not types:
|
||||
return None
|
||||
|
||||
if len(types) == 1:
|
||||
# Single value: regulation_type == "GB"
|
||||
return f'regulation_type == "{types[0]}"'
|
||||
else:
|
||||
# Multiple values: regulation_type in ["GB", "UN-ECE"]
|
||||
quoted_types = [f'"{t}"' for t in types]
|
||||
return f'regulation_type in [{", ".join(quoted_types)}]'
|
||||
|
||||
def search(self, query_vector: list[float], top_k: int, filters: str | None = None) -> list[RetrievedChunk]:
|
||||
"""Handle search for the Milvus Vector Index instance."""
|
||||
milvus_expr = self._parse_filters(filters)
|
||||
|
||||
results = self.collection.search(
|
||||
data=[query_vector],
|
||||
anns_field="embedding",
|
||||
param={"metric_type": "COSINE", "params": {"nprobe": settings.milvus_nprobe}},
|
||||
limit=top_k,
|
||||
filter=filters,
|
||||
expr=milvus_expr,
|
||||
output_fields=[
|
||||
"doc_id",
|
||||
"doc_name",
|
||||
@@ -145,6 +173,49 @@ class MilvusVectorIndex(VectorIndex):
|
||||
)
|
||||
return payload
|
||||
|
||||
def count_by_document(self) -> dict[str, int]:
|
||||
"""Return doc_id -> chunk count from Milvus."""
|
||||
try:
|
||||
rows = self.collection.query(expr="doc_id != \"\"", output_fields=["doc_id"])
|
||||
except Exception:
|
||||
return {}
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
doc_id = row.get("doc_id", "")
|
||||
if doc_id:
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
return counts
|
||||
|
||||
def list_document_metadata(self) -> list[dict]:
|
||||
"""Return one metadata row per document from Milvus (single query, no embeddings)."""
|
||||
try:
|
||||
rows = self.collection.query(
|
||||
expr="doc_id != \"\"",
|
||||
output_fields=["doc_id", "doc_name", "regulation_type", "version"],
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
seen: dict[str, dict] = {}
|
||||
counts: dict[str, int] = {}
|
||||
for row in rows:
|
||||
doc_id = row.get("doc_id", "")
|
||||
if not doc_id:
|
||||
continue
|
||||
counts[doc_id] = counts.get(doc_id, 0) + 1
|
||||
if doc_id not in seen:
|
||||
seen[doc_id] = {
|
||||
"doc_id": doc_id,
|
||||
"doc_name": row.get("doc_name", ""),
|
||||
"regulation_type": row.get("regulation_type", ""),
|
||||
"version": row.get("version", ""),
|
||||
}
|
||||
|
||||
return [
|
||||
{**meta, "chunk_count": counts[meta["doc_id"]]}
|
||||
for meta in seen.values()
|
||||
]
|
||||
|
||||
def health(self) -> dict:
|
||||
"""Handle health for the Milvus Vector Index instance."""
|
||||
return {
|
||||
|
||||
@@ -74,6 +74,7 @@ class ComplianceResult(BaseModel):
|
||||
class ComplianceChatRequest(BaseModel):
|
||||
"""Define the Compliance Chat Request API model."""
|
||||
query: str
|
||||
segment_context: Optional[str] = None
|
||||
|
||||
|
||||
class AnalyzeResponse(BaseModel):
|
||||
|
||||
@@ -10,6 +10,8 @@ class RagChatRequest(BaseModel):
|
||||
"""Define the Rag Chat Request API model."""
|
||||
query: str
|
||||
top_k: int = 5
|
||||
session_id: Optional[str] = None
|
||||
filters: Optional[str] = None
|
||||
|
||||
|
||||
class RetrievedDoc(BaseModel):
|
||||
|
||||
@@ -95,6 +95,57 @@ class DeepSeekClient(BaseLLMClient):
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""Stream chat for the Deep Seek Client instance."""
|
||||
try:
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens or self.config.max_tokens,
|
||||
"temperature": temperature or self.config.temperature,
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"stream": True
|
||||
}
|
||||
|
||||
with self._client.stream("POST", "/chat/completions", json=payload) as response:
|
||||
response.raise_for_status()
|
||||
for line in response.iter_lines():
|
||||
if not line or line.startswith(b":"):
|
||||
continue
|
||||
|
||||
line_str = line.decode("utf-8").strip()
|
||||
if not line_str.startswith("data: "):
|
||||
continue
|
||||
|
||||
data_str = line_str[6:]
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
|
||||
try:
|
||||
import json
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices:
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"DeepSeek Stream API错误: {e.response.status_code}")
|
||||
yield ""
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek Stream调用失败: {e}")
|
||||
yield ""
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""Return available models for the Deep Seek Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
@@ -17,18 +17,31 @@ from app.infrastructure.parser.vector_chunk_builder import AliyunVectorChunkBuil
|
||||
from app.infrastructure.session.in_memory_conversation_store import InMemoryConversationStore
|
||||
from app.infrastructure.storage.json_document_repository import JsonDocumentRepository
|
||||
from app.infrastructure.storage.minio_binary_store import MinioDocumentBinaryStore
|
||||
from app.infrastructure.storage.postgres_document_repository import PostgresDocumentRepository
|
||||
from app.infrastructure.storage.postgres_parse_artifact_store import PostgresParseArtifactStore
|
||||
from app.infrastructure.vectorstore.dense_retriever import DenseRetriever
|
||||
from app.infrastructure.vectorstore.milvus_vector_index import MilvusVectorIndex
|
||||
from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
|
||||
# Keep shared wiring centralized so dependency construction remains consistent.
|
||||
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_document_repository() -> JsonDocumentRepository:
|
||||
"""Return document repository."""
|
||||
def get_document_repository():
|
||||
"""Return document repository (json or postgres, controlled by settings)."""
|
||||
if settings.document_repository_backend == "postgres":
|
||||
return PostgresDocumentRepository()
|
||||
return JsonDocumentRepository(settings.document_metadata_path)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_parse_artifact_store():
|
||||
"""Return parse artifact store, or None when postgres backend is not enabled."""
|
||||
if settings.document_repository_backend == "postgres":
|
||||
return PostgresParseArtifactStore()
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_binary_store() -> MinioDocumentBinaryStore:
|
||||
"""Return binary store."""
|
||||
@@ -66,6 +79,14 @@ def get_vector_index() -> MilvusVectorIndex:
|
||||
return MilvusVectorIndex()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_reranker():
|
||||
"""Return reranker if enabled, else None."""
|
||||
if settings.reranker_enabled and settings.reranker_base_url:
|
||||
return OpenAICompatibleReranker()
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_retrieval_service() -> KnowledgeRetrievalService:
|
||||
"""Return retrieval service."""
|
||||
@@ -73,7 +94,11 @@ def get_retrieval_service() -> KnowledgeRetrievalService:
|
||||
embedding_provider=get_embedding_provider(),
|
||||
vector_index=get_vector_index(),
|
||||
)
|
||||
return KnowledgeRetrievalService(retriever=retriever)
|
||||
return KnowledgeRetrievalService(
|
||||
retriever=retriever,
|
||||
reranker=get_reranker(),
|
||||
reranker_top_k=settings.reranker_top_k,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
@@ -86,6 +111,7 @@ def get_document_command_service() -> DocumentCommandService:
|
||||
chunk_builder=get_chunk_builder(),
|
||||
embedding_provider=get_embedding_provider(),
|
||||
vector_index=get_vector_index(),
|
||||
parse_artifact_store=get_parse_artifact_store(),
|
||||
)
|
||||
|
||||
|
||||
@@ -95,6 +121,7 @@ def get_document_query_service() -> DocumentQueryService:
|
||||
return DocumentQueryService(
|
||||
document_repository=get_document_repository(),
|
||||
binary_store=get_binary_store(),
|
||||
vector_index=get_vector_index(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user