1. Add 登陆功能
2. 调整字体大小 3. 新增部分功能
This commit is contained in:
5
backend/app/api/dependencies/__init__.py
Normal file
5
backend/app/api/dependencies/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""FastAPI dependency functions for authentication and authorisation.
|
||||
|
||||
Import `get_current_user` or `require_role` into route modules to protect
|
||||
endpoints. Both use the shared JWTHandler wired through bootstrap.
|
||||
"""
|
||||
72
backend/app/api/dependencies/auth.py
Normal file
72
backend/app/api/dependencies/auth.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""FastAPI dependencies for JWT authentication.
|
||||
|
||||
Usage in a route:
|
||||
from app.api.dependencies.auth import get_current_user, require_role
|
||||
from app.domain.auth.models import UserRole
|
||||
|
||||
@router.get("/protected")
|
||||
async def protected(user: UserClaims = Depends(get_current_user)):
|
||||
return {"user": user.username}
|
||||
|
||||
@router.delete("/admin-only")
|
||||
async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))):
|
||||
...
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.domain.auth.models import UserClaims, UserRole
|
||||
from app.shared.bootstrap import get_jwt_handler
|
||||
|
||||
# Use Bearer token scheme — client sends `Authorization: Bearer <token>`.
|
||||
_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
|
||||
) -> UserClaims:
|
||||
"""Extract and validate the JWT from the Authorization header.
|
||||
|
||||
Returns the decoded UserClaims on success.
|
||||
Raises HTTP 401 when the token is missing, expired, or invalid.
|
||||
When auth_enabled=False (development), returns a synthetic admin user.
|
||||
"""
|
||||
if not settings.auth_enabled:
|
||||
# Development bypass — never enable this in production.
|
||||
return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN)
|
||||
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Missing authentication token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
return get_jwt_handler().decode_token(credentials.credentials)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(exc),
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
) from exc
|
||||
|
||||
|
||||
def require_role(*roles: UserRole):
|
||||
"""Return a dependency that enforces one of the given roles.
|
||||
|
||||
Example:
|
||||
Depends(require_role(UserRole.ADMIN, UserRole.LEGAL))
|
||||
"""
|
||||
async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims:
|
||||
"""Verify the user holds one of the required roles."""
|
||||
if user.role not in roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Role '{user.role}' is not permitted. Required: {[r.value for r in roles]}",
|
||||
)
|
||||
return user
|
||||
return _check
|
||||
@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.api.middleware.audit import AuditMiddleware
|
||||
from app.api.models import ErrorResponse
|
||||
from app.api.routes import api_router
|
||||
from app.config.logging import setup_logging
|
||||
@@ -46,14 +47,23 @@ app = FastAPI(
|
||||
redoc_url="/redoc",
|
||||
)
|
||||
|
||||
# Tighten CORS — only allow configured origins.
|
||||
# Set CORS_ALLOW_ORIGINS in .env to the real frontend URL in production.
|
||||
_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()]
|
||||
if not _ORIGINS:
|
||||
_ORIGINS = ["http://localhost:5173"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Audit middleware logs every authenticated API call for compliance traceability.
|
||||
app.add_middleware(AuditMiddleware)
|
||||
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
|
||||
|
||||
1
backend/app/api/middleware/__init__.py
Normal file
1
backend/app/api/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""HTTP middleware for cross-cutting concerns: audit logging."""
|
||||
56
backend/app/api/middleware/audit.py
Normal file
56
backend/app/api/middleware/audit.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Audit logging middleware.
|
||||
|
||||
Logs every API request with method, path, status code, response time,
|
||||
and the authenticated user identity (extracted from the JWT when present).
|
||||
Log lines are structured so they can be ingested by ELK / Loki.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import Request, Response
|
||||
from loguru import logger
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
|
||||
class AuditMiddleware(BaseHTTPMiddleware):
|
||||
"""Log all API calls. Skips health/docs paths to reduce noise."""
|
||||
|
||||
# Paths that produce no audit log entry.
|
||||
_SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json")
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""Intercept the request, call the handler, and log the outcome."""
|
||||
path = request.url.path
|
||||
if path == "/" or any(path == p or path.startswith(p + "/") for p in self._SKIP_PREFIXES):
|
||||
return await call_next(request)
|
||||
|
||||
start = time.perf_counter()
|
||||
response = await call_next(request)
|
||||
elapsed_ms = int((time.perf_counter() - start) * 1000)
|
||||
|
||||
# Extract user identity from JWT header for structured audit records.
|
||||
# The token is not re-validated here — auth dependencies do that upstream.
|
||||
user_id = "anonymous"
|
||||
username = "anonymous"
|
||||
auth_header = request.headers.get("authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
try:
|
||||
from app.shared.bootstrap import get_jwt_handler
|
||||
claims = get_jwt_handler().decode_token(auth_header[7:])
|
||||
user_id = claims.user_id
|
||||
username = claims.username
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
"AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}",
|
||||
request.method,
|
||||
path,
|
||||
response.status_code,
|
||||
elapsed_ms,
|
||||
user_id,
|
||||
username,
|
||||
)
|
||||
return response
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Initialize the app.api.routes package."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from .auth import router as auth_router
|
||||
from .compliance import router as compliance_router
|
||||
from .documents import router as documents_router
|
||||
from .knowledge import router as knowledge_router
|
||||
@@ -14,7 +15,8 @@ from .rag import router as rag_router
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
api_router = APIRouter()
|
||||
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
# Auth routes first so /auth/token is easy to discover.
|
||||
api_router.include_router(auth_router)
|
||||
api_router.include_router(documents_router)
|
||||
api_router.include_router(knowledge_router)
|
||||
api_router.include_router(agent_router)
|
||||
@@ -25,6 +27,7 @@ api_router.include_router(rag_router)
|
||||
|
||||
__all__ = [
|
||||
"api_router",
|
||||
"auth_router",
|
||||
"documents_router",
|
||||
"knowledge_router",
|
||||
"agent_router",
|
||||
|
||||
63
backend/app/api/routes/auth.py
Normal file
63
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Authentication routes — token issuance only.
|
||||
|
||||
POST /auth/token — exchange username + password for a JWT.
|
||||
GET /auth/me — return the current user identity (requires token).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.domain.auth.models import UserClaims
|
||||
from app.shared.bootstrap import get_jwt_handler, get_user_store
|
||||
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""JWT token response body."""
|
||||
|
||||
access_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int
|
||||
|
||||
|
||||
@router.post("/token", response_model=TokenResponse)
|
||||
async def login(form: OAuth2PasswordRequestForm = Depends()):
|
||||
"""Issue a JWT for valid username + password credentials.
|
||||
|
||||
Uses standard OAuth2 password grant form fields — compatible with
|
||||
Swagger UI Authorize button.
|
||||
"""
|
||||
user = get_user_store().authenticate(form.username, form.password)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
token = get_jwt_handler().create_access_token(
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
role=user.role,
|
||||
)
|
||||
return TokenResponse(
|
||||
access_token=token,
|
||||
token_type="bearer",
|
||||
expires_in=settings.auth_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(current_user: UserClaims = Depends(get_current_user)):
|
||||
"""Return the identity of the currently authenticated user."""
|
||||
return {
|
||||
"user_id": current_user.user_id,
|
||||
"username": current_user.username,
|
||||
"role": current_user.role.value,
|
||||
}
|
||||
@@ -7,10 +7,12 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from fastapi import APIRouter, File, Form, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, Form, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.domain.auth.models import UserClaims
|
||||
from app.schemas.compliance import (
|
||||
AnalyzeResponse,
|
||||
ComplianceChatRequest,
|
||||
@@ -75,6 +77,7 @@ async def analyze_stream(
|
||||
file: Optional[UploadFile] = File(None),
|
||||
domains: Optional[str] = Form(None),
|
||||
title: Optional[str] = Form(None),
|
||||
current_user: UserClaims = Depends(get_current_user),
|
||||
):
|
||||
"""Stream compliance analysis as SSE events.
|
||||
|
||||
|
||||
@@ -5,12 +5,15 @@ from __future__ import annotations
|
||||
from io import BytesIO
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.models import DocumentUploadResponse
|
||||
from app.application.documents import DocumentProcessResult
|
||||
from app.config.settings import settings
|
||||
from app.domain.auth.models import UserClaims
|
||||
from app.shared.bootstrap import get_document_command_service, get_document_query_service
|
||||
# Keep route handlers close to their transport-layer wiring for easier auditing.
|
||||
|
||||
@@ -31,16 +34,60 @@ def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
|
||||
)
|
||||
|
||||
|
||||
def _run_process_in_background(
|
||||
*,
|
||||
doc_id: str,
|
||||
file_name: str,
|
||||
final_doc_name: str,
|
||||
content: bytes,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
generate_summary: bool,
|
||||
run_id: str | None,
|
||||
) -> None:
|
||||
"""Run document processing synchronously inside a FastAPI BackgroundTask thread.
|
||||
|
||||
FastAPI executes BackgroundTasks in a threadpool executor, so blocking I/O
|
||||
(parser API calls, embedding, Milvus upsert) is safe here.
|
||||
"""
|
||||
try:
|
||||
svc = get_document_command_service()
|
||||
svc._process_document(
|
||||
doc_id=doc_id,
|
||||
file_name=file_name,
|
||||
final_doc_name=final_doc_name,
|
||||
content=content,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
generate_summary=generate_summary,
|
||||
run_id=run_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("BackgroundTask document processing failed: doc_id={}", doc_id)
|
||||
|
||||
|
||||
@router.post("/upload", response_model=DocumentUploadResponse)
|
||||
async def upload_document(
|
||||
background_tasks: BackgroundTasks,
|
||||
file: UploadFile = File(..., description="上传的文档文件"),
|
||||
doc_id: str | None = Form(None, description="客户端预分配的文档ID,不传则自动生成"),
|
||||
doc_name: str | None = Form(None, description="文档名称"),
|
||||
regulation_type: str | None = Form(None, description="法规类型"),
|
||||
version: str | None = Form(None, description="文档版本"),
|
||||
generate_summary: bool = Form(False, description="是否生成摘要"),
|
||||
sync: bool = Form(False, description="同步处理(演示/测试用,默认异步处理)"),
|
||||
current_user: UserClaims = Depends(get_current_user),
|
||||
):
|
||||
"""Handle upload document."""
|
||||
"""Upload a document and process it asynchronously.
|
||||
|
||||
Default path (sync=false):
|
||||
1. Store binary to MinIO immediately — returns within seconds.
|
||||
2. Schedule parse→embed→index as a FastAPI BackgroundTask (same process,
|
||||
threadpool) OR enqueue to Celery workers when USE_CELERY_WORKER=true.
|
||||
3. Poll GET /documents/status/{doc_id} for progress.
|
||||
|
||||
sync=true path: full inline processing, blocks until complete (demo / CI use).
|
||||
"""
|
||||
content = await file.read()
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="文件名不能为空")
|
||||
@@ -48,19 +95,73 @@ async def upload_document(
|
||||
raise HTTPException(status_code=400, detail="上传文件为空")
|
||||
|
||||
try:
|
||||
result = get_document_command_service().upload_and_process(
|
||||
doc_id=doc_id,
|
||||
file_name=file.filename,
|
||||
content=content,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
)
|
||||
svc = get_document_command_service()
|
||||
|
||||
if sync:
|
||||
# Synchronous fallback: full inline processing.
|
||||
result = svc.upload_and_process(
|
||||
doc_id=doc_id,
|
||||
file_name=file.filename,
|
||||
content=content,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
)
|
||||
else:
|
||||
# Step 1: store binary and create the document record (fast, sync).
|
||||
stored_doc_id, run_id = svc.store_document(
|
||||
doc_id=doc_id,
|
||||
file_name=file.filename,
|
||||
content=content,
|
||||
content_type=file.content_type or "application/octet-stream",
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
)
|
||||
final_doc_name = doc_name or file.filename
|
||||
|
||||
# Step 2: schedule processing via Celery worker OR FastAPI BackgroundTask.
|
||||
if settings.use_celery_worker:
|
||||
from app.infrastructure.tasks.document_tasks import process_document_task
|
||||
process_document_task.delay(
|
||||
doc_id=stored_doc_id,
|
||||
file_name=file.filename,
|
||||
doc_name=final_doc_name,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
run_id=run_id,
|
||||
)
|
||||
processing_note = "已入 Celery 队列,由 Worker 处理。"
|
||||
else:
|
||||
# Default: run in FastAPI's threadpool — no external worker needed.
|
||||
background_tasks.add_task(
|
||||
_run_process_in_background,
|
||||
doc_id=stored_doc_id,
|
||||
file_name=file.filename,
|
||||
final_doc_name=final_doc_name,
|
||||
content=content,
|
||||
regulation_type=regulation_type or "",
|
||||
version=version or "",
|
||||
generate_summary=generate_summary,
|
||||
run_id=run_id,
|
||||
)
|
||||
processing_note = "正在后台处理。"
|
||||
|
||||
result = DocumentProcessResult(
|
||||
doc_id=stored_doc_id,
|
||||
doc_name=final_doc_name,
|
||||
status="stored",
|
||||
message=f"文件已存储,{processing_note}请轮询 GET /documents/status/{{doc_id}} 查看进度。",
|
||||
)
|
||||
|
||||
if result.status == "failed":
|
||||
raise HTTPException(status_code=500, detail=result.message)
|
||||
return _document_response(result)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -106,7 +207,7 @@ async def download_document(doc_id: str):
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def list_documents():
|
||||
async def list_documents(current_user: UserClaims = Depends(get_current_user)):
|
||||
"""List documents."""
|
||||
documents = get_document_query_service().list_documents()
|
||||
return {
|
||||
@@ -148,7 +249,7 @@ async def get_document_management_list():
|
||||
|
||||
|
||||
@router.delete("/{doc_id}")
|
||||
async def delete_document(doc_id: str):
|
||||
async def delete_document(doc_id: str, current_user: UserClaims = Depends(get_current_user)):
|
||||
"""Delete a document and its associated data."""
|
||||
deleted = get_document_command_service().delete(doc_id)
|
||||
if not deleted:
|
||||
|
||||
@@ -5,10 +5,12 @@ from __future__ import annotations
|
||||
import json
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.domain.auth.models import UserClaims
|
||||
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
|
||||
from app.shared.async_utils import iter_in_thread
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
@@ -27,7 +29,10 @@ _DEFAULT_QUICK_QUESTIONS = [
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def rag_chat(request: RagChatRequest):
|
||||
async def rag_chat(
|
||||
request: RagChatRequest,
|
||||
current_user: UserClaims = Depends(get_current_user),
|
||||
):
|
||||
"""Stream RAG Q&A using the real agent service."""
|
||||
session_id, event_stream = get_agent_conversation_service().stream_chat(
|
||||
query=request.query,
|
||||
|
||||
@@ -277,7 +277,6 @@ class DocumentCommandService:
|
||||
message="Document record created",
|
||||
)
|
||||
|
||||
temp_path = ""
|
||||
try:
|
||||
self.binary_store.save(
|
||||
object_name=object_name,
|
||||
@@ -297,117 +296,20 @@ class DocumentCommandService:
|
||||
stage="store",
|
||||
message="Source file stored",
|
||||
)
|
||||
|
||||
suffix = os.path.splitext(file_name)[1]
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_path = temp_file.name
|
||||
|
||||
parsed_document = self.parser.parse(
|
||||
file_path=temp_path,
|
||||
# Delegate parse → embed → index to the shared processing method.
|
||||
# This same method is invoked by the Celery worker for async processing.
|
||||
return self._process_document(
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
)
|
||||
self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)
|
||||
|
||||
artifact_keys: dict[str, str] = {}
|
||||
try:
|
||||
artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
|
||||
except Exception:
|
||||
logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
DocumentStatus.PARSED,
|
||||
parser_name=parsed_document.parser_name,
|
||||
metadata={
|
||||
"parser_backend": parsed_document.parser_name,
|
||||
"parse_task_id": parsed_document.metadata.get("task_id", ""),
|
||||
"layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
|
||||
"structure_node_count": len(parsed_document.structure_nodes),
|
||||
"semantic_block_count": len(parsed_document.semantic_blocks),
|
||||
"vector_chunk_count": len(parsed_document.vector_chunks),
|
||||
"artifact_keys": artifact_keys,
|
||||
"processing_stage": "parsed",
|
||||
},
|
||||
)
|
||||
current_status = DocumentStatus.PARSED
|
||||
current_stage = "embed"
|
||||
self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
|
||||
self._safe_append_status_event(
|
||||
doc_id=doc_id,
|
||||
run_id=run_id,
|
||||
from_status=DocumentStatus.STORED.value,
|
||||
to_status=DocumentStatus.PARSED.value,
|
||||
stage="parse",
|
||||
message="Document parsed",
|
||||
metadata={"artifact_count": len(artifact_keys)},
|
||||
)
|
||||
if self.parse_artifact_store:
|
||||
try:
|
||||
self.parse_artifact_store.save(
|
||||
doc_id,
|
||||
parsed_document.structure_nodes,
|
||||
parsed_document.semantic_blocks,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
|
||||
|
||||
chunks = self.chunk_builder.build(
|
||||
parsed_document=parsed_document,
|
||||
file_name=file_name,
|
||||
final_doc_name=final_doc_name,
|
||||
content=content,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
)
|
||||
if not chunks:
|
||||
raise ValueError("解析完成但没有生成可入库的 chunks")
|
||||
|
||||
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
|
||||
current_stage = "index"
|
||||
inserted = self.vector_index.upsert(chunks, vectors)
|
||||
if inserted != len(chunks):
|
||||
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
|
||||
|
||||
health = self.vector_index.health()
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
DocumentStatus.INDEXED,
|
||||
chunk_count=len(chunks),
|
||||
summary="",
|
||||
summary_latency_ms=0,
|
||||
index_name=health.get("collection_name", ""),
|
||||
metadata={
|
||||
"index_collection": health.get("collection_name", ""),
|
||||
"processing_stage": "indexed",
|
||||
},
|
||||
)
|
||||
current_status = DocumentStatus.INDEXED
|
||||
index_name = health.get("collection_name", "")
|
||||
self._safe_mark_run_indexed(
|
||||
doc_id=doc_id,
|
||||
generate_summary=generate_summary,
|
||||
run_id=run_id,
|
||||
chunk_count=len(chunks),
|
||||
index_name=index_name,
|
||||
)
|
||||
self._safe_append_status_event(
|
||||
doc_id=doc_id,
|
||||
run_id=run_id,
|
||||
from_status=DocumentStatus.PARSED.value,
|
||||
to_status=DocumentStatus.INDEXED.value,
|
||||
stage="index",
|
||||
message="Document indexed",
|
||||
metadata={"chunk_count": len(chunks), "index_name": index_name},
|
||||
)
|
||||
stored = self.document_repository.get(doc_id)
|
||||
return DocumentProcessResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
|
||||
message="处理成功",
|
||||
num_chunks=len(chunks),
|
||||
summary=stored.summary if stored else "",
|
||||
summary_latency_ms=stored.summary_latency_ms if stored else 0,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("文档处理失败: doc_id={}", doc_id)
|
||||
logger.exception("文档存储失败: doc_id={}", doc_id)
|
||||
failure_stage = current_stage
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
@@ -439,6 +341,183 @@ class DocumentCommandService:
|
||||
status=DocumentStatus.FAILED.value,
|
||||
message=f"文档处理失败: {exc}",
|
||||
)
|
||||
|
||||
def store_document(
|
||||
self,
|
||||
*,
|
||||
doc_id: str | None = None,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
content_type: str,
|
||||
doc_name: str | None,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
generate_summary: bool,
|
||||
) -> tuple[str, str | None]:
|
||||
"""Store the binary file and create the Document record.
|
||||
|
||||
Returns (doc_id, run_id). Does NOT parse, embed, or index.
|
||||
This is the fast synchronous first step; processing is enqueued separately.
|
||||
The caller is responsible for enqueuing the follow-up process_document_task.
|
||||
"""
|
||||
doc_id = doc_id or str(uuid.uuid4())[:8]
|
||||
final_doc_name = doc_name or file_name
|
||||
object_name = f"{doc_id}/{file_name}"
|
||||
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
file_name=file_name,
|
||||
object_name=object_name,
|
||||
content_type=content_type,
|
||||
size_bytes=len(content),
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
metadata={"generate_summary": generate_summary},
|
||||
)
|
||||
self.document_repository.create(document)
|
||||
run_id = self._safe_create_processing_run(
|
||||
doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary
|
||||
)
|
||||
self.binary_store.save(
|
||||
object_name=object_name, data=content,
|
||||
content_type=content_type, metadata={"doc_id": doc_id},
|
||||
)
|
||||
self.document_repository.update_status(doc_id, DocumentStatus.STORED)
|
||||
self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
|
||||
self._safe_append_status_event(
|
||||
doc_id=doc_id, run_id=run_id,
|
||||
from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value,
|
||||
stage="store", message="Source file stored",
|
||||
)
|
||||
return doc_id, run_id
|
||||
|
||||
def _process_document(
|
||||
self,
|
||||
*,
|
||||
doc_id: str,
|
||||
file_name: str,
|
||||
final_doc_name: str,
|
||||
content: bytes,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
generate_summary: bool,
|
||||
run_id: str | None = None,
|
||||
) -> DocumentProcessResult:
|
||||
"""Run parse → chunk → embed → index for a document that is already stored.
|
||||
|
||||
Called both synchronously (from upload_and_process) and asynchronously
|
||||
(from the Celery process_document_task worker). All side-effects write
|
||||
through DocumentProcessingStore so callers can poll progress.
|
||||
"""
|
||||
current_status = DocumentStatus.STORED
|
||||
current_stage = "parse"
|
||||
temp_path = ""
|
||||
try:
|
||||
suffix = os.path.splitext(file_name)[1]
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_path = temp_file.name
|
||||
|
||||
parsed_document = self.parser.parse(
|
||||
file_path=temp_path,
|
||||
doc_id=doc_id,
|
||||
doc_name=final_doc_name,
|
||||
)
|
||||
self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)
|
||||
|
||||
artifact_keys: dict[str, str] = {}
|
||||
try:
|
||||
artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
|
||||
except Exception:
|
||||
logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)
|
||||
|
||||
self.document_repository.update_status(
|
||||
doc_id,
|
||||
DocumentStatus.PARSED,
|
||||
parser_name=parsed_document.parser_name,
|
||||
metadata={
|
||||
"parser_backend": parsed_document.parser_name,
|
||||
"parse_task_id": parsed_document.metadata.get("task_id", ""),
|
||||
"layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
|
||||
"structure_node_count": len(parsed_document.structure_nodes),
|
||||
"semantic_block_count": len(parsed_document.semantic_blocks),
|
||||
"vector_chunk_count": len(parsed_document.vector_chunks),
|
||||
"artifact_keys": artifact_keys,
|
||||
"processing_stage": "parsed",
|
||||
},
|
||||
)
|
||||
current_status = DocumentStatus.PARSED
|
||||
current_stage = "embed"
|
||||
self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
|
||||
self._safe_append_status_event(
|
||||
doc_id=doc_id, run_id=run_id,
|
||||
from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value,
|
||||
stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)},
|
||||
)
|
||||
if self.parse_artifact_store:
|
||||
try:
|
||||
self.parse_artifact_store.save(
|
||||
doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
|
||||
|
||||
chunks = self.chunk_builder.build(
|
||||
parsed_document=parsed_document,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
)
|
||||
if not chunks:
|
||||
raise ValueError("解析完成但没有生成可入库的 chunks")
|
||||
|
||||
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
|
||||
current_stage = "index"
|
||||
inserted = self.vector_index.upsert(chunks, vectors)
|
||||
if inserted != len(chunks):
|
||||
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
|
||||
|
||||
health = self.vector_index.health()
|
||||
index_name = health.get("collection_name", "")
|
||||
self.document_repository.update_status(
|
||||
doc_id, DocumentStatus.INDEXED,
|
||||
chunk_count=len(chunks), summary="", summary_latency_ms=0,
|
||||
index_name=index_name,
|
||||
metadata={"index_collection": index_name, "processing_stage": "indexed"},
|
||||
)
|
||||
self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name)
|
||||
self._safe_append_status_event(
|
||||
doc_id=doc_id, run_id=run_id,
|
||||
from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value,
|
||||
stage="index", message="Document indexed",
|
||||
metadata={"chunk_count": len(chunks), "index_name": index_name},
|
||||
)
|
||||
stored = self.document_repository.get(doc_id)
|
||||
return DocumentProcessResult(
|
||||
doc_id=doc_id, doc_name=final_doc_name,
|
||||
status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
|
||||
message="处理成功", num_chunks=len(chunks),
|
||||
summary=stored.summary if stored else "",
|
||||
summary_latency_ms=stored.summary_latency_ms if stored else 0,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.exception("文档处理失败: doc_id={}", doc_id)
|
||||
self.document_repository.update_status(
|
||||
doc_id, DocumentStatus.FAILED, error_message=str(exc),
|
||||
metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage},
|
||||
)
|
||||
self._safe_mark_run_failed(
|
||||
doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc)
|
||||
)
|
||||
self._safe_append_status_event(
|
||||
doc_id=doc_id, run_id=run_id,
|
||||
from_status=current_status.value, to_status=DocumentStatus.FAILED.value,
|
||||
stage=current_stage, message=str(exc),
|
||||
)
|
||||
return DocumentProcessResult(
|
||||
doc_id=doc_id, doc_name=final_doc_name,
|
||||
status=DocumentStatus.FAILED.value, message=f"文档处理失败: {exc}",
|
||||
)
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
@@ -446,7 +525,6 @@ class DocumentCommandService:
|
||||
except OSError:
|
||||
logger.warning("临时文件清理失败: {}", temp_path)
|
||||
|
||||
|
||||
def delete(self, doc_id: str) -> bool:
|
||||
"""Delete document record, binary file, and vector chunks."""
|
||||
document = self.document_repository.get(doc_id)
|
||||
|
||||
@@ -82,6 +82,10 @@ class Settings(BaseSettings):
|
||||
parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)")
|
||||
chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)")
|
||||
document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)")
|
||||
# When True, document processing is enqueued to Celery workers via Redis.
|
||||
# When False (default), processing runs in a FastAPI BackgroundTask in the same process —
|
||||
# no external worker needed. Switch to True only when a Celery worker is running.
|
||||
use_celery_worker: bool = Field(default=False, description="使用 Celery Worker 异步处理文档 (需要 Worker 运行中)")
|
||||
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
api_host: str = Field(default="0.0.0.0", description="API服务地址")
|
||||
@@ -109,6 +113,7 @@ class Settings(BaseSettings):
|
||||
rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量(reranker 启用时生效)")
|
||||
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
|
||||
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
|
||||
rag_skills_max_tokens: int = Field(default=2048, description="技能类 RAG 最大 token 数")
|
||||
|
||||
reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排")
|
||||
reranker_base_url: str = Field(default="", description="Reranker API 地址")
|
||||
@@ -124,6 +129,26 @@ class Settings(BaseSettings):
|
||||
# Keep configuration setup explicit so runtime behavior is easy to reason about.
|
||||
session_max_sessions: int = Field(default=100, description="最大会话数量")
|
||||
session_timeout_minutes: int = Field(default=30, description="会话超时时间(分钟)")
|
||||
session_backend: str = Field(
|
||||
default="memory",
|
||||
description="会话存储后端 (memory | redis)。redis 需要 Redis 可用。",
|
||||
)
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────
|
||||
# Generate a strong secret: python -c "import secrets; print(secrets.token_hex(32))"
|
||||
auth_secret_key: str = Field(
|
||||
default="change-me-in-production-must-be-32-or-more-characters-long",
|
||||
description="JWT signing secret. MUST be changed in production.",
|
||||
)
|
||||
auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.")
|
||||
auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default 8 hours).")
|
||||
auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).")
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────
|
||||
cors_allow_origins: str = Field(
|
||||
default="http://localhost:5173",
|
||||
description="Comma-separated allowed CORS origins. Never use * in production.",
|
||||
)
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
|
||||
10
backend/app/domain/auth/__init__.py
Normal file
10
backend/app/domain/auth/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Auth domain: role definitions and token claim models.
|
||||
|
||||
The domain layer defines what a user identity looks like (UserClaims) and
|
||||
what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL)
|
||||
live under infrastructure/auth and never leak into this package.
|
||||
"""
|
||||
|
||||
from .models import UserClaims, UserRole
|
||||
|
||||
__all__ = ["UserClaims", "UserRole"]
|
||||
42
backend/app/domain/auth/models.py
Normal file
42
backend/app/domain/auth/models.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Auth domain models: roles and token claims.
|
||||
|
||||
UserRole defines the four roles from PPT Slide 12.
|
||||
UserClaims is what the JWT decodes to — it is the identity object passed
|
||||
through FastAPI dependency injection to route handlers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class UserRole(str, enum.Enum):
|
||||
"""Access roles mirroring the four-role RBAC matrix from the product spec.
|
||||
|
||||
ADMIN — full platform access including system management.
|
||||
LEGAL — knowledge query, document review, compliance checks.
|
||||
EHS — knowledge query, perception/regulatory signals.
|
||||
READONLY — knowledge query only.
|
||||
"""
|
||||
|
||||
ADMIN = "admin"
|
||||
LEGAL = "legal"
|
||||
EHS = "ehs"
|
||||
READONLY = "readonly"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserClaims:
|
||||
"""Decoded JWT payload representing an authenticated user.
|
||||
|
||||
Instances are created by JWTHandler.decode_token() and injected into
|
||||
route handlers via the get_current_user FastAPI dependency.
|
||||
"""
|
||||
|
||||
# Unique user identifier (UUID string stored in PostgreSQL users table).
|
||||
user_id: str
|
||||
# Display name used for audit log entries.
|
||||
username: str
|
||||
# Role determines which resources the user may access.
|
||||
role: UserRole
|
||||
5
backend/app/infrastructure/auth/__init__.py
Normal file
5
backend/app/infrastructure/auth/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""JWT token creation and validation infrastructure.
|
||||
|
||||
JWTHandler is the only component in this package. It is wired through
|
||||
shared/bootstrap.py and injected into FastAPI dependencies.
|
||||
"""
|
||||
82
backend/app/infrastructure/auth/jwt_handler.py
Normal file
82
backend/app/infrastructure/auth/jwt_handler.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""JWT access token creation and decoding.
|
||||
|
||||
Uses python-jose for HS256 token signing. Token expiry is enforced at
|
||||
decode time so expired tokens are rejected even if the signature is valid.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.auth.models import UserClaims, UserRole
|
||||
|
||||
|
||||
class JWTHandler:
|
||||
"""Create and validate HS256 JWT access tokens.
|
||||
|
||||
A single shared instance is wired by bootstrap.py. Use
|
||||
get_jwt_handler() from shared.bootstrap for all token operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
secret_key: str,
|
||||
algorithm: str = "HS256",
|
||||
expire_minutes: int = 480,
|
||||
) -> None:
|
||||
"""Initialise the handler with signing credentials and token lifetime."""
|
||||
self._secret = secret_key
|
||||
self._algorithm = algorithm
|
||||
self._expire_minutes = expire_minutes
|
||||
|
||||
def create_access_token(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
username: str,
|
||||
role: str,
|
||||
) -> str:
|
||||
"""Return a signed JWT containing user identity and role claims."""
|
||||
now = datetime.now(UTC)
|
||||
payload: dict[str, Any] = {
|
||||
"sub": user_id,
|
||||
"username": username,
|
||||
"role": role,
|
||||
"iat": now,
|
||||
"exp": now + timedelta(minutes=self._expire_minutes),
|
||||
}
|
||||
return jwt.encode(payload, self._secret, algorithm=self._algorithm)
|
||||
|
||||
def decode_token(self, token: str) -> UserClaims:
|
||||
"""Decode and validate a JWT, returning UserClaims.
|
||||
|
||||
Raises ValueError with a descriptive message on expiry, tampering,
|
||||
or any other validation failure so callers do not need to know jose.
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
|
||||
except JWTError as exc:
|
||||
msg = str(exc).lower()
|
||||
if "expired" in msg:
|
||||
raise ValueError("Token expired") from exc
|
||||
raise ValueError(f"Invalid token: {exc}") from exc
|
||||
|
||||
user_id = payload.get("sub")
|
||||
username = payload.get("username", "")
|
||||
role_str = payload.get("role", UserRole.READONLY.value)
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("Token missing subject claim")
|
||||
|
||||
try:
|
||||
role = UserRole(role_str)
|
||||
except ValueError:
|
||||
logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
|
||||
role = UserRole.READONLY
|
||||
|
||||
return UserClaims(user_id=user_id, username=username, role=role)
|
||||
113
backend/app/infrastructure/auth/user_store.py
Normal file
113
backend/app/infrastructure/auth/user_store.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""PostgreSQL-backed user store for authentication.
|
||||
|
||||
Manages a `users` table with hashed passwords and roles.
|
||||
Provides lookup by username for the login flow.
|
||||
Table DDL is auto-applied on first connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from loguru import logger
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
# bcrypt context — work factor 12 is a good production default.
|
||||
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# DDL executed once to ensure the table exists.
|
||||
_CREATE_TABLE_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
username VARCHAR(100) UNIQUE NOT NULL,
|
||||
hashed_pw TEXT NOT NULL,
|
||||
role VARCHAR(50) NOT NULL DEFAULT 'readonly',
|
||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserRecord:
|
||||
"""A single row from the users table."""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
hashed_pw: str
|
||||
role: str
|
||||
is_active: bool
|
||||
|
||||
|
||||
class PostgresUserStore:
|
||||
"""Read and verify users stored in the PostgreSQL users table.
|
||||
|
||||
The connection is opened on first use and shared for the lifetime
|
||||
of the singleton instance wired by bootstrap.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the store and ensure the users table exists."""
|
||||
self._conn = psycopg2.connect(
|
||||
host=settings.postgres_host,
|
||||
port=settings.postgres_port,
|
||||
user=settings.postgres_user,
|
||||
password=settings.postgres_password,
|
||||
dbname=settings.postgres_db,
|
||||
cursor_factory=psycopg2.extras.RealDictCursor,
|
||||
)
|
||||
self._conn.autocommit = True
|
||||
self._ensure_table()
|
||||
|
||||
def _ensure_table(self) -> None:
|
||||
"""Create the users table if it does not already exist."""
|
||||
with self._conn.cursor() as cur:
|
||||
# Enable pgcrypto so gen_random_uuid() is available for UUID primary keys.
|
||||
try:
|
||||
cur.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
cur.execute(_CREATE_TABLE_SQL)
|
||||
|
||||
def get_by_username(self, username: str) -> Optional[UserRecord]:
|
||||
"""Return a UserRecord for the given username, or None if not found."""
|
||||
with self._conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT id, username, hashed_pw, role, is_active "
|
||||
"FROM users WHERE username = %s",
|
||||
(username,),
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return UserRecord(
|
||||
id=str(row["id"]),
|
||||
username=row["username"],
|
||||
hashed_pw=row["hashed_pw"],
|
||||
role=row["role"],
|
||||
is_active=row["is_active"],
|
||||
)
|
||||
|
||||
def verify_password(self, plain: str, hashed: str) -> bool:
|
||||
"""Return True if `plain` matches the stored bcrypt hash."""
|
||||
return _PWD_CTX.verify(plain, hashed)
|
||||
|
||||
def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
|
||||
"""Return the UserRecord if credentials are valid, else None."""
|
||||
user = self.get_by_username(username)
|
||||
if user is None or not user.is_active:
|
||||
return None
|
||||
if not self.verify_password(password, user.hashed_pw):
|
||||
return None
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def hash_password(plain: str) -> str:
|
||||
"""Hash a plain-text password with bcrypt."""
|
||||
return _PWD_CTX.hash(plain)
|
||||
169
backend/app/infrastructure/session/redis_conversation_store.py
Normal file
169
backend/app/infrastructure/session/redis_conversation_store.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Redis-backed conversation store for persistent chat sessions.
|
||||
|
||||
Sessions are stored as JSON strings under the key `session:{session_id}`.
|
||||
The Redis TTL is refreshed on every write so active sessions stay alive.
|
||||
On expiry, `get_session` returns None — callers should create a new session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
|
||||
|
||||
|
||||
class RedisConversationStore(ConversationStore):
|
||||
"""Store conversation sessions in Redis with automatic TTL expiry.
|
||||
|
||||
Each session is serialised as a JSON object at key ``session:{session_id}``.
|
||||
The TTL is reset on every write so sessions stay alive as long as they are active.
|
||||
"""
|
||||
|
||||
# Prefix for all session keys to avoid collisions with other Redis consumers.
|
||||
_PREFIX = "session:"
|
||||
|
||||
def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
|
||||
"""Initialise the store with an existing Redis client and a TTL in seconds."""
|
||||
self._redis = redis_client
|
||||
self._ttl = timeout_seconds
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _key(self, session_id: str) -> str:
|
||||
"""Build the Redis key for a session."""
|
||||
return f"{self._PREFIX}{session_id}"
|
||||
|
||||
def _serialise(self, session: ConversationSession) -> str:
|
||||
"""Serialise a ConversationSession to a JSON string."""
|
||||
return json.dumps(
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"created_at": session.created_at,
|
||||
"updated_at": session.updated_at,
|
||||
"metadata": session.metadata,
|
||||
"messages": [
|
||||
{
|
||||
"role": msg.role,
|
||||
"content": msg.content,
|
||||
"timestamp": msg.timestamp,
|
||||
"sources": msg.sources,
|
||||
}
|
||||
for msg in session.messages
|
||||
],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
def _deserialise(self, raw: bytes | str) -> ConversationSession:
|
||||
"""Deserialise a JSON string back into a ConversationSession."""
|
||||
data = json.loads(raw)
|
||||
messages = [
|
||||
ConversationMessage(
|
||||
role=m["role"],
|
||||
content=m["content"],
|
||||
timestamp=m["timestamp"],
|
||||
sources=m.get("sources", []),
|
||||
)
|
||||
for m in data.get("messages", [])
|
||||
]
|
||||
session = ConversationSession(
|
||||
session_id=data["session_id"],
|
||||
created_at=data.get("created_at", 0),
|
||||
updated_at=data.get("updated_at", 0),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
session.messages = messages
|
||||
return session
|
||||
|
||||
def _save(self, session: ConversationSession) -> None:
|
||||
"""Persist a session to Redis and refresh its TTL."""
|
||||
self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# ConversationStore protocol
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def create_session(self, metadata: dict | None = None) -> ConversationSession:
|
||||
"""Create a new empty session and persist it immediately."""
|
||||
now = int(time.time())
|
||||
session = ConversationSession(
|
||||
session_id=str(uuid.uuid4())[:8],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._save(session)
|
||||
return session
|
||||
|
||||
def get_session(self, session_id: str) -> ConversationSession | None:
|
||||
"""Return a session by ID, or None if it does not exist or has expired."""
|
||||
raw = self._redis.get(self._key(session_id))
|
||||
if raw is None:
|
||||
return None
|
||||
try:
|
||||
return self._deserialise(raw)
|
||||
except Exception:
|
||||
logger.warning("Failed to deserialise session: {}", session_id)
|
||||
return None
|
||||
|
||||
def save_message(
|
||||
self,
|
||||
session_id: str,
|
||||
*,
|
||||
role: str,
|
||||
content: str,
|
||||
sources: list[dict] | None = None,
|
||||
) -> ConversationSession | None:
|
||||
"""Append a message to a session and refresh its TTL."""
|
||||
session = self.get_session(session_id)
|
||||
if session is None:
|
||||
return None
|
||||
session.messages.append(
|
||||
ConversationMessage(
|
||||
role=role,
|
||||
content=content,
|
||||
timestamp=int(time.time()),
|
||||
sources=sources or [],
|
||||
)
|
||||
)
|
||||
session.updated_at = int(time.time())
|
||||
self._save(session)
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session. Returns True if it existed, False otherwise."""
|
||||
deleted = self._redis.delete(self._key(session_id))
|
||||
return bool(deleted)
|
||||
|
||||
def list_sessions(self) -> list[dict]:
|
||||
"""Return summary dicts for all live sessions visible in this Redis DB.
|
||||
|
||||
Note: KEYS is used for simplicity; replace with SCAN for large deployments.
|
||||
"""
|
||||
pattern = f"{self._PREFIX}*"
|
||||
keys = self._redis.keys(pattern)
|
||||
result = []
|
||||
for key in keys:
|
||||
raw = self._redis.get(key)
|
||||
if raw is None:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
result.append(
|
||||
{
|
||||
"session_id": data["session_id"],
|
||||
"message_count": len(data.get("messages", [])),
|
||||
"created_at": data.get("created_at", 0),
|
||||
"updated_at": data.get("updated_at", 0),
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
return result
|
||||
5
backend/app/infrastructure/tasks/__init__.py
Normal file
5
backend/app/infrastructure/tasks/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Celery task definitions for background processing.
|
||||
|
||||
This package exposes the shared Celery application instance and all
|
||||
registered task functions used by API routes to enqueue work.
|
||||
"""
|
||||
45
backend/app/infrastructure/tasks/celery_app.py
Normal file
45
backend/app/infrastructure/tasks/celery_app.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Shared Celery application instance for background task processing.
|
||||
|
||||
All workers and enqueueing call sites import `celery_app` from this module
|
||||
so the broker/backend configuration stays in one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
def _redis_url() -> str:
|
||||
"""Return a Redis connection URL from application settings."""
|
||||
if settings.redis_password:
|
||||
return (
|
||||
f"redis://:{settings.redis_password}@"
|
||||
f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
|
||||
)
|
||||
return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
|
||||
|
||||
|
||||
_BROKER = _redis_url()
|
||||
_BACKEND = _redis_url()
|
||||
|
||||
celery_app = Celery(
|
||||
"compliance_hub",
|
||||
broker=_BROKER,
|
||||
backend=_BACKEND,
|
||||
include=["app.infrastructure.tasks.document_tasks"],
|
||||
)
|
||||
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
result_serializer="json",
|
||||
accept_content=["json"],
|
||||
timezone="UTC",
|
||||
enable_utc=True,
|
||||
# Acknowledge task only after successful execution to avoid data loss.
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
# Keep results for 1 hour for status polling.
|
||||
result_expires=3600,
|
||||
)
|
||||
73
backend/app/infrastructure/tasks/document_tasks.py
Normal file
73
backend/app/infrastructure/tasks/document_tasks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Celery tasks for document processing.
|
||||
|
||||
Each task is a thin wrapper that retrieves the already-stored document
|
||||
binary and delegates to DocumentCommandService._process_document.
|
||||
The task does not accept raw file bytes — it reads them from the binary
|
||||
store using the doc_id, so the Celery message payload stays small.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.infrastructure.tasks.document_tasks.process_document_task",
|
||||
bind=True,
|
||||
max_retries=3,
|
||||
default_retry_delay=30,
|
||||
acks_late=True,
|
||||
)
|
||||
def process_document_task(
|
||||
self,
|
||||
doc_id: str,
|
||||
file_name: str,
|
||||
doc_name: str,
|
||||
regulation_type: str,
|
||||
version: str,
|
||||
generate_summary: bool,
|
||||
run_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Parse, embed, and index a document that has already been stored.
|
||||
|
||||
The task reads the file binary from MinIO using doc_id so the Celery
|
||||
message stays small. Retries up to 3 times with a 30-second delay on
|
||||
transient infrastructure errors.
|
||||
"""
|
||||
# Import inside the task function to avoid pickling issues and to ensure
|
||||
# that each worker process initialises its own bootstrap singletons.
|
||||
from app.shared.bootstrap import get_document_command_service, get_document_query_service
|
||||
|
||||
logger.info("process_document_task started: doc_id={}", doc_id)
|
||||
try:
|
||||
svc = get_document_command_service()
|
||||
doc = get_document_query_service().get(doc_id)
|
||||
if not doc:
|
||||
raise ValueError(f"Document record not found: {doc_id}")
|
||||
|
||||
# Read the stored binary from MinIO — avoids passing raw bytes in the task message.
|
||||
content = svc.binary_store.read(doc.object_name)
|
||||
|
||||
result = svc._process_document(
|
||||
doc_id=doc_id,
|
||||
file_name=file_name,
|
||||
final_doc_name=doc_name,
|
||||
content=content,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
generate_summary=generate_summary,
|
||||
run_id=run_id,
|
||||
)
|
||||
logger.info(
|
||||
"process_document_task completed: doc_id={} status={} chunks={}",
|
||||
doc_id, result.status, result.num_chunks,
|
||||
)
|
||||
return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks}
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("process_document_task failed: doc_id={}", doc_id)
|
||||
# Retry on transient errors; permanent errors (bad file, parse failure)
|
||||
# will exhaust retries and leave the document in FAILED state.
|
||||
raise self.retry(exc=exc)
|
||||
@@ -252,7 +252,31 @@ def get_document_query_service() -> DocumentQueryService:
|
||||
|
||||
@lru_cache
|
||||
def get_conversation_store() -> InMemoryConversationStore:
|
||||
"""Return conversation store."""
|
||||
"""Return the active conversation store based on settings.
|
||||
|
||||
When session_backend='redis', sessions survive backend restarts and scale
|
||||
across multiple API worker processes. When session_backend='memory' (default),
|
||||
sessions are process-local and lost on restart.
|
||||
"""
|
||||
if settings.session_backend == "redis":
|
||||
import redis as redis_lib
|
||||
from app.infrastructure.session.redis_conversation_store import RedisConversationStore
|
||||
|
||||
# Build the Redis client from the same connection settings used by Celery.
|
||||
kwargs: dict = {
|
||||
"host": settings.redis_host,
|
||||
"port": settings.redis_port,
|
||||
"db": settings.redis_db,
|
||||
"decode_responses": False,
|
||||
}
|
||||
if settings.redis_password:
|
||||
kwargs["password"] = settings.redis_password
|
||||
|
||||
redis_client = redis_lib.Redis(**kwargs)
|
||||
return RedisConversationStore( # type: ignore[return-value]
|
||||
redis_client=redis_client,
|
||||
timeout_seconds=settings.session_timeout_minutes * 60,
|
||||
)
|
||||
return InMemoryConversationStore(
|
||||
max_sessions=settings.session_max_sessions,
|
||||
timeout_minutes=settings.session_timeout_minutes,
|
||||
@@ -284,6 +308,35 @@ def get_agent_session_service() -> AgentSessionService:
|
||||
return AgentSessionService(conversation_store=get_conversation_store())
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_celery_app():
|
||||
"""Return the shared Celery application instance.
|
||||
|
||||
Imported lazily so Celery is not required when running without workers
|
||||
(e.g., tests that mock bootstrap or dev without Redis).
|
||||
"""
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
return celery_app
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_jwt_handler():
|
||||
"""Return the shared JWTHandler instance for token creation and validation."""
|
||||
from app.infrastructure.auth.jwt_handler import JWTHandler
|
||||
return JWTHandler(
|
||||
secret_key=settings.auth_secret_key,
|
||||
algorithm=settings.auth_algorithm,
|
||||
expire_minutes=settings.auth_token_expire_minutes,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_user_store():
|
||||
"""Return the PostgreSQL user store (lazy-connects on first call)."""
|
||||
from app.infrastructure.auth.user_store import PostgresUserStore
|
||||
return PostgresUserStore()
|
||||
|
||||
|
||||
def preload_runtime_dependencies() -> None:
|
||||
"""Warm dependencies that are safe and useful to preload during startup."""
|
||||
LLMFactory.preload_clients(["qwen", "deepseek"])
|
||||
|
||||
Reference in New Issue
Block a user