fix 文档管理模块 & 法规对话模块
This commit is contained in:
@@ -23,6 +23,8 @@ class DocumentUploadResponse(BaseModel):
|
||||
num_chunks: int = Field(default=0, description="分块数量")
|
||||
summary: str = Field(default="", description="LLM生成的文档摘要")
|
||||
summary_latency_ms: int = Field(default=0, description="摘要生成耗时(ms)")
|
||||
regulation_type: str = Field(default="", description="法规类型")
|
||||
version: str = Field(default="", description="文档版本号")
|
||||
timestamp: datetime = Field(default_factory=datetime.now, description="时间戳")
|
||||
|
||||
|
||||
|
||||
@@ -14,30 +14,22 @@ from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
)
|
||||
from app.services.mock_data import (
|
||||
generate_task_id,
|
||||
get_mock_compliance_result,
|
||||
get_mock_compliance_chat_response,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
from app.services.mock_data import generate_task_id, get_mock_compliance_result
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["合规分析"])
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store: dict[str, dict] = {}
|
||||
|
||||
# Store uploaded compliance files inside the local backend data directory.
|
||||
RAW_DATA_DIR = Path(__file__).resolve().parents[3] / "data" / "raw"
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=AnalyzeResponse)
|
||||
async def analyze_document(file: UploadFile = File(...)):
|
||||
"""Handle analyze document."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
task_id = generate_task_id()
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
RAW_DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
file_path = RAW_DATA_DIR / f"compliance_{task_id}_{file.filename}"
|
||||
|
||||
@@ -45,7 +37,6 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
with file_path.open("wb") as f:
|
||||
f.write(content)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id] = {
|
||||
"task_id": task_id,
|
||||
"file_path": str(file_path),
|
||||
@@ -53,8 +44,6 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
"result": None,
|
||||
}
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
tasks_store[task_id]["status"] = "completed"
|
||||
tasks_store[task_id]["result"] = get_mock_compliance_result(task_id)
|
||||
|
||||
@@ -65,47 +54,44 @@ async def analyze_document(file: UploadFile = File(...)):
|
||||
async def get_result(task_id: str):
|
||||
"""Return result."""
|
||||
if task_id not in tasks_store:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
return get_mock_compliance_result(task_id)
|
||||
|
||||
task = tasks_store[task_id]
|
||||
|
||||
if task["status"] == "processing":
|
||||
return {"status": "processing", "message": "分析进行中"}
|
||||
|
||||
return task["result"]
|
||||
|
||||
|
||||
@router.post("/chat/{segment_id}")
|
||||
async def compliance_chat(segment_id: int, request: ComplianceChatRequest):
|
||||
"""Handle compliance chat."""
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
intent_map = {
|
||||
1: "车身结构设计",
|
||||
2: "动力系统配置",
|
||||
3: "安全配置设计",
|
||||
}
|
||||
intent = intent_map.get(segment_id, "车身结构设计")
|
||||
"""Stream compliance Q&A grounded in real vector retrieval."""
|
||||
query = request.query
|
||||
if request.segment_context:
|
||||
query = f"[段落分析上下文]\n{request.segment_context}\n\n用户问题:{request.query}"
|
||||
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
top_k=5,
|
||||
prompt_template="compliance_qa",
|
||||
)
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
response = get_mock_compliance_chat_response(intent, request.query)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = response.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05)
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
"""Translate agent SSE events to compliance chunk/done format."""
|
||||
for event in event_stream:
|
||||
event_type = event.get("event", "")
|
||||
if event_type == "content":
|
||||
text = event.get("data", "")
|
||||
if text:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': text}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
|
||||
@@ -80,6 +80,8 @@ async def get_document_status(doc_id: str):
|
||||
num_chunks=document.chunk_count,
|
||||
summary=document.summary,
|
||||
summary_latency_ms=document.summary_latency_ms,
|
||||
regulation_type=document.regulation_type,
|
||||
version=document.version,
|
||||
)
|
||||
|
||||
|
||||
@@ -123,7 +125,7 @@ async def list_documents():
|
||||
@router.get("/management-list")
|
||||
async def get_document_management_list():
|
||||
"""Return document management list."""
|
||||
documents = get_document_query_service().list_documents(limit=10)
|
||||
documents = get_document_query_service().list_documents()
|
||||
return {
|
||||
"documents": [
|
||||
{
|
||||
@@ -131,10 +133,37 @@ async def get_document_management_list():
|
||||
"doc_name": item.doc_name,
|
||||
"status": item.status.value,
|
||||
"chunk_count": item.chunk_count,
|
||||
"size_bytes": item.size_bytes,
|
||||
"summary": item.summary,
|
||||
"updated_at": item.updated_at.isoformat(),
|
||||
"regulation_type": item.regulation_type,
|
||||
"version": item.version,
|
||||
}
|
||||
for item in documents
|
||||
],
|
||||
"total": len(documents),
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/{doc_id}")
async def delete_document(doc_id: str):
    """Delete the document identified by ``doc_id``.

    Delegates to the document command service's ``delete`` call and reports
    the outcome.

    Args:
        doc_id: Identifier of the document to remove, taken from the URL path.

    Returns:
        A dict echoing ``doc_id`` together with ``"deleted": True``.

    Raises:
        HTTPException: 404 when the command service reports a falsy result
            for the delete call.
    """
    command_service = get_document_command_service()
    if command_service.delete(doc_id):
        return {"doc_id": doc_id, "deleted": True}
    # A falsy result from delete() is surfaced to the client as "not found".
    raise HTTPException(status_code=404, detail="文档不存在")
|
||||
|
||||
|
||||
@router.post("/{doc_id}/retry", response_model=DocumentUploadResponse)
async def retry_document(doc_id: str):
    """Re-process a failed document.

    Asks the document command service to retry ``doc_id`` and converts the
    outcome into the upload response shape via ``_document_response``.

    Args:
        doc_id: Identifier of the document to retry, taken from the URL path.

    Returns:
        The value produced by ``_document_response`` for the retry result.

    Raises:
        HTTPException: 500 with ``result.message`` when the retry result has
            status ``"failed"``; 500 with the stringified error when any other
            exception escapes the retry or the response conversion.
    """
    try:
        result = get_document_command_service().retry(doc_id)
        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)
    except HTTPException:
        # Already an HTTP-level error (including the one raised above); let it
        # propagate untouched instead of re-wrapping it below.
        raise
    except Exception as exc:
        logger.exception("文档重试失败")
        # Chain the original error so the traceback keeps the real cause.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
@@ -9,73 +9,74 @@ from typing import AsyncGenerator
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.services.mock_data import (
|
||||
get_mock_quick_questions,
|
||||
get_mock_retrieval,
|
||||
get_mock_rag_answer,
|
||||
)
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/rag", tags=["RAG问答"])
|
||||
|
||||
# Built-in quick-question presets; each entry carries the "id", "question" and
# "category" keys that are unpacked into QuickQuestion by get_quick_questions.
_DEFAULT_QUICK_QUESTIONS = [
    {
        "id": "1",
        "question": "请总结最新入库法规对电池安全的核心要求",
        "category": "法规解读",
    },
    {
        "id": "2",
        "question": "我上传的制度文档与新能源法规有哪些潜在冲突?",
        "category": "差距分析",
    },
    {
        "id": "3",
        "question": "请给出法规依据,并按条款列出整改建议",
        "category": "整改建议",
    },
    {
        "id": "4",
        "question": "请解释 UN-ECE 与 GB 标准在网络安全方面的差异",
        "category": "标准对比",
    },
    {
        "id": "5",
        "question": "IATF 16949 对供应商质量管理有哪些强制要求?",
        "category": "法规解读",
    },
    {
        "id": "6",
        "question": "ISO 45001 与 AQ 标准在职业健康安全方面的主要差异是什么?",
        "category": "标准对比",
    },
]
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
"""Handle rag chat."""
|
||||
"""Stream RAG Q&A using the real agent service."""
|
||||
_, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=request.query,
|
||||
session_id=request.session_id,
|
||||
filters=request.filters,
|
||||
top_k=request.top_k or settings.rag_top_k,
|
||||
)
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
"""Handle generate."""
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieving'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
docs = get_mock_retrieval(request.query, top_k=request.top_k)
|
||||
|
||||
retrieved_data = [
|
||||
{
|
||||
"id": d["id"],
|
||||
"score": d["score"],
|
||||
"preview": d["preview"],
|
||||
"doc_name": d.get("doc_name", ""),
|
||||
"clause": d.get("clause", ""),
|
||||
}
|
||||
for d in docs
|
||||
]
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'retrieved', 'docs': retrieved_data}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
f"event: message\ndata: "
|
||||
f"{json.dumps({'type': 'generating', 'text': '正在生成答案...'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
answer = get_mock_rag_answer(request.query)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
sentences = answer.split("\n\n")
|
||||
for sentence in sentences:
|
||||
if sentence.strip():
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
chunks = sentence.split("\n")
|
||||
for chunk in chunks:
|
||||
if chunk.strip():
|
||||
await asyncio.sleep(0.05) # Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': chunk + chr(10)}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
yield f"event: message\ndata: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
"""Translate agent SSE events to rag format."""
|
||||
for event in event_stream:
|
||||
event_type = event.get("event", "")
|
||||
data = event.get("data", "")
|
||||
if event_type == "sources":
|
||||
docs = [
|
||||
{
|
||||
"id": str(s.get("chunk_id") or s.get("doc_id") or idx + 1),
|
||||
"score": s.get("score", 0),
|
||||
"preview": s.get("content", "")[:200],
|
||||
"doc_name": s.get("doc_name", ""),
|
||||
"clause": s.get("section_title", "法规片段"),
|
||||
"doc_id": s.get("doc_id"),
|
||||
"download_url": (
|
||||
f"/api/v1/documents/download/{s['doc_id']}" if s.get("doc_id") else None
|
||||
),
|
||||
}
|
||||
for idx, s in enumerate(data if isinstance(data, list) else [])
|
||||
]
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'retrieved', 'docs': docs}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "content":
|
||||
if data:
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'chunk', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "done":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
elif event_type == "status":
|
||||
yield (
|
||||
"event: message\n"
|
||||
f"data: {json.dumps({'type': 'status', 'text': data}, ensure_ascii=False)}\n\n"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
@@ -86,9 +87,13 @@ async def rag_chat(request: RagChatRequest):
|
||||
|
||||
@router.get("/quick-questions", response_model=QuickQuestionsResponse)
async def get_quick_questions():
    """Return configurable quick questions from settings or defaults.

    Reads ``settings.rag_quick_questions`` if that attribute exists. When it is
    a non-empty list, each entry is converted to a ``QuickQuestion``:

    * a ``str`` entry becomes the question text with the default category;
    * a ``dict`` entry contributes its ``"question"`` (default ``""``) and
      ``"category"`` (default category) values;
    * any other entry is skipped instead of raising ``AttributeError``.

    Ids are the 1-based position of the entry in the configured list, so a
    skipped entry leaves a gap rather than renumbering its successors. When the
    setting is missing, empty, not a list, or yields no usable entry, the
    built-in ``_DEFAULT_QUICK_QUESTIONS`` presets are returned.

    Returns:
        A ``QuickQuestionsResponse`` wrapping the resolved question list.
    """
    default_category = "法规问答"
    configured = getattr(settings, "rag_quick_questions", None)

    questions = []
    if configured and isinstance(configured, list):
        for position, entry in enumerate(configured, start=1):
            if isinstance(entry, str):
                text, category = entry, default_category
            elif isinstance(entry, dict):
                text = entry.get("question", "")
                category = entry.get("category", default_category)
            else:
                # Malformed config entry (None, number, ...): ignore it rather
                # than failing the whole endpoint.
                continue
            questions.append(
                QuickQuestion(id=str(position), question=text, category=category)
            )

    if not questions:
        # Nothing usable was configured; fall back to the built-in presets.
        questions = [QuickQuestion(**preset) for preset in _DEFAULT_QUICK_QUESTIONS]

    return QuickQuestionsResponse(questions=questions)
|
||||
|
||||
Reference in New Issue
Block a user