1. Add 登陆功能

2. 调整字体大小
3. 新增部分功能
This commit is contained in:
2026-06-05 18:00:31 +08:00
parent 06e0967128
commit 9fea9c6a53
58 changed files with 5028 additions and 322 deletions

View File

@@ -0,0 +1,5 @@
"""FastAPI dependency functions for authentication and authorisation.
Import `get_current_user` or `require_role` into route modules to protect
endpoints. Both use the shared JWTHandler wired through bootstrap.
"""

View File

@@ -0,0 +1,72 @@
"""FastAPI dependencies for JWT authentication.
Usage in a route:
from app.api.dependencies.auth import get_current_user, require_role
from app.domain.auth.models import UserRole
@router.get("/protected")
async def protected(user: UserClaims = Depends(get_current_user)):
return {"user": user.username}
@router.delete("/admin-only")
async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))):
...
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.config.settings import settings
from app.domain.auth.models import UserClaims, UserRole
from app.shared.bootstrap import get_jwt_handler
# Use Bearer token scheme — client sends `Authorization: Bearer <token>`.
_bearer = HTTPBearer(auto_error=False)
async def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
) -> UserClaims:
"""Extract and validate the JWT from the Authorization header.
Returns the decoded UserClaims on success.
Raises HTTP 401 when the token is missing, expired, or invalid.
When auth_enabled=False (development), returns a synthetic admin user.
"""
if not settings.auth_enabled:
# Development bypass — never enable this in production.
return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN)
if credentials is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing authentication token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
return get_jwt_handler().decode_token(credentials.credentials)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(exc),
headers={"WWW-Authenticate": "Bearer"},
) from exc
def require_role(*roles: UserRole):
"""Return a dependency that enforces one of the given roles.
Example:
Depends(require_role(UserRole.ADMIN, UserRole.LEGAL))
"""
async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims:
"""Verify the user holds one of the required roles."""
if user.role not in roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Role '{user.role}' is not permitted. Required: {[r.value for r in roles]}",
)
return user
return _check

View File

@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from loguru import logger
from app.api.middleware.audit import AuditMiddleware
from app.api.models import ErrorResponse
from app.api.routes import api_router
from app.config.logging import setup_logging
@@ -46,14 +47,23 @@ app = FastAPI(
redoc_url="/redoc",
)
# Tighten CORS — only allow configured origins.
# Set CORS_ALLOW_ORIGINS in .env to the real frontend URL in production.
_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()]
if not _ORIGINS:
_ORIGINS = ["http://localhost:5173"]
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Audit middleware logs every authenticated API call for compliance traceability.
app.add_middleware(AuditMiddleware)
app.include_router(api_router, prefix="/api/v1")

View File

@@ -0,0 +1 @@
"""HTTP middleware for cross-cutting concerns: audit logging."""

View File

@@ -0,0 +1,56 @@
"""Audit logging middleware.
Logs every API request with method, path, status code, response time,
and the authenticated user identity (extracted from the JWT when present).
Log lines are structured so they can be ingested by ELK / Loki.
"""
from __future__ import annotations
import time
from fastapi import Request, Response
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
class AuditMiddleware(BaseHTTPMiddleware):
"""Log all API calls. Skips health/docs paths to reduce noise."""
# Paths that produce no audit log entry.
_SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json")
async def dispatch(self, request: Request, call_next) -> Response:
"""Intercept the request, call the handler, and log the outcome."""
path = request.url.path
if path == "/" or any(path == p or path.startswith(p + "/") for p in self._SKIP_PREFIXES):
return await call_next(request)
start = time.perf_counter()
response = await call_next(request)
elapsed_ms = int((time.perf_counter() - start) * 1000)
# Extract user identity from JWT header for structured audit records.
# The token is not re-validated here — auth dependencies do that upstream.
user_id = "anonymous"
username = "anonymous"
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
try:
from app.shared.bootstrap import get_jwt_handler
claims = get_jwt_handler().decode_token(auth_header[7:])
user_id = claims.user_id
username = claims.username
except Exception:
pass
logger.info(
"AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}",
request.method,
path,
response.status_code,
elapsed_ms,
user_id,
username,
)
return response

View File

@@ -1,6 +1,7 @@
"""Initialize the app.api.routes package."""
from fastapi import APIRouter
from .auth import router as auth_router
from .compliance import router as compliance_router
from .documents import router as documents_router
from .knowledge import router as knowledge_router
@@ -14,7 +15,8 @@ from .rag import router as rag_router
# Keep package boundaries explicit so backend imports stay predictable.
api_router = APIRouter()
# Keep package boundaries explicit so backend imports stay predictable.
# Auth routes first so /auth/token is easy to discover.
api_router.include_router(auth_router)
api_router.include_router(documents_router)
api_router.include_router(knowledge_router)
api_router.include_router(agent_router)
@@ -25,6 +27,7 @@ api_router.include_router(rag_router)
__all__ = [
"api_router",
"auth_router",
"documents_router",
"knowledge_router",
"agent_router",

View File

@@ -0,0 +1,63 @@
"""Authentication routes — token issuance only.
POST /auth/token — exchange username + password for a JWT.
GET /auth/me — return the current user identity (requires token).
"""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel
from app.api.dependencies.auth import get_current_user
from app.config.settings import settings
from app.domain.auth.models import UserClaims
from app.shared.bootstrap import get_jwt_handler, get_user_store
router = APIRouter(prefix="/auth", tags=["认证"])
class TokenResponse(BaseModel):
"""JWT token response body."""
access_token: str
token_type: str = "bearer"
expires_in: int
@router.post("/token", response_model=TokenResponse)
async def login(form: OAuth2PasswordRequestForm = Depends()):
"""Issue a JWT for valid username + password credentials.
Uses standard OAuth2 password grant form fields — compatible with
Swagger UI Authorize button.
"""
user = get_user_store().authenticate(form.username, form.password)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
token = get_jwt_handler().create_access_token(
user_id=user.id,
username=user.username,
role=user.role,
)
return TokenResponse(
access_token=token,
token_type="bearer",
expires_in=settings.auth_token_expire_minutes * 60,
)
@router.get("/me")
async def get_me(current_user: UserClaims = Depends(get_current_user)):
"""Return the identity of the currently authenticated user."""
return {
"user_id": current_user.user_id,
"username": current_user.username,
"role": current_user.role.value,
}

View File

@@ -7,10 +7,12 @@ import json
from pathlib import Path
from typing import AsyncGenerator, Optional
from fastapi import APIRouter, File, Form, UploadFile
from fastapi import APIRouter, Depends, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
from app.schemas.compliance import (
AnalyzeResponse,
ComplianceChatRequest,
@@ -75,6 +77,7 @@ async def analyze_stream(
file: Optional[UploadFile] = File(None),
domains: Optional[str] = Form(None),
title: Optional[str] = Form(None),
current_user: UserClaims = Depends(get_current_user),
):
"""Stream compliance analysis as SSE events.

View File

@@ -5,12 +5,15 @@ from __future__ import annotations
from io import BytesIO
from urllib.parse import quote
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from loguru import logger
from app.api.dependencies.auth import get_current_user
from app.api.models import DocumentUploadResponse
from app.application.documents import DocumentProcessResult
from app.config.settings import settings
from app.domain.auth.models import UserClaims
from app.shared.bootstrap import get_document_command_service, get_document_query_service
# Keep route handlers close to their transport-layer wiring for easier auditing.
@@ -31,16 +34,60 @@ def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse:
)
def _run_process_in_background(
*,
doc_id: str,
file_name: str,
final_doc_name: str,
content: bytes,
regulation_type: str,
version: str,
generate_summary: bool,
run_id: str | None,
) -> None:
"""Run document processing synchronously inside a FastAPI BackgroundTask thread.
FastAPI executes BackgroundTasks in a threadpool executor, so blocking I/O
(parser API calls, embedding, Milvus upsert) is safe here.
"""
try:
svc = get_document_command_service()
svc._process_document(
doc_id=doc_id,
file_name=file_name,
final_doc_name=final_doc_name,
content=content,
regulation_type=regulation_type,
version=version,
generate_summary=generate_summary,
run_id=run_id,
)
except Exception:
logger.exception("BackgroundTask document processing failed: doc_id={}", doc_id)
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
background_tasks: BackgroundTasks,
file: UploadFile = File(..., description="上传的文档文件"),
doc_id: str | None = Form(None, description="客户端预分配的文档ID不传则自动生成"),
doc_name: str | None = Form(None, description="文档名称"),
regulation_type: str | None = Form(None, description="法规类型"),
version: str | None = Form(None, description="文档版本"),
generate_summary: bool = Form(False, description="是否生成摘要"),
sync: bool = Form(False, description="同步处理(演示/测试用,默认异步处理)"),
current_user: UserClaims = Depends(get_current_user),
):
"""Handle upload document."""
"""Upload a document and process it asynchronously.
Default path (sync=false):
1. Store binary to MinIO immediately — returns within seconds.
2. Schedule parse→embed→index as a FastAPI BackgroundTask (same process,
threadpool) OR enqueue to Celery workers when USE_CELERY_WORKER=true.
3. Poll GET /documents/status/{doc_id} for progress.
sync=true path: full inline processing, blocks until complete (demo / CI use).
"""
content = await file.read()
if not file.filename:
raise HTTPException(status_code=400, detail="文件名不能为空")
@@ -48,19 +95,73 @@ async def upload_document(
raise HTTPException(status_code=400, detail="上传文件为空")
try:
result = get_document_command_service().upload_and_process(
doc_id=doc_id,
file_name=file.filename,
content=content,
content_type=file.content_type or "application/octet-stream",
doc_name=doc_name,
regulation_type=regulation_type or "",
version=version or "",
generate_summary=generate_summary,
)
svc = get_document_command_service()
if sync:
# Synchronous fallback: full inline processing.
result = svc.upload_and_process(
doc_id=doc_id,
file_name=file.filename,
content=content,
content_type=file.content_type or "application/octet-stream",
doc_name=doc_name,
regulation_type=regulation_type or "",
version=version or "",
generate_summary=generate_summary,
)
else:
# Step 1: store binary and create the document record (fast, sync).
stored_doc_id, run_id = svc.store_document(
doc_id=doc_id,
file_name=file.filename,
content=content,
content_type=file.content_type or "application/octet-stream",
doc_name=doc_name,
regulation_type=regulation_type or "",
version=version or "",
generate_summary=generate_summary,
)
final_doc_name = doc_name or file.filename
# Step 2: schedule processing via Celery worker OR FastAPI BackgroundTask.
if settings.use_celery_worker:
from app.infrastructure.tasks.document_tasks import process_document_task
process_document_task.delay(
doc_id=stored_doc_id,
file_name=file.filename,
doc_name=final_doc_name,
regulation_type=regulation_type or "",
version=version or "",
generate_summary=generate_summary,
run_id=run_id,
)
processing_note = "已入 Celery 队列,由 Worker 处理。"
else:
# Default: run in FastAPI's threadpool — no external worker needed.
background_tasks.add_task(
_run_process_in_background,
doc_id=stored_doc_id,
file_name=file.filename,
final_doc_name=final_doc_name,
content=content,
regulation_type=regulation_type or "",
version=version or "",
generate_summary=generate_summary,
run_id=run_id,
)
processing_note = "正在后台处理。"
result = DocumentProcessResult(
doc_id=stored_doc_id,
doc_name=final_doc_name,
status="stored",
message=f"文件已存储,{processing_note}请轮询 GET /documents/status/{{doc_id}} 查看进度。",
)
if result.status == "failed":
raise HTTPException(status_code=500, detail=result.message)
return _document_response(result)
except HTTPException:
raise
except Exception as exc:
@@ -106,7 +207,7 @@ async def download_document(doc_id: str):
@router.get("/list")
async def list_documents():
async def list_documents(current_user: UserClaims = Depends(get_current_user)):
"""List documents."""
documents = get_document_query_service().list_documents()
return {
@@ -148,7 +249,7 @@ async def get_document_management_list():
@router.delete("/{doc_id}")
async def delete_document(doc_id: str):
async def delete_document(doc_id: str, current_user: UserClaims = Depends(get_current_user)):
"""Delete a document and its associated data."""
deleted = get_document_command_service().delete(doc_id)
if not deleted:

View File

@@ -5,10 +5,12 @@ from __future__ import annotations
import json
from typing import AsyncGenerator
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from app.api.dependencies.auth import get_current_user
from app.config.settings import settings
from app.domain.auth.models import UserClaims
from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion
from app.shared.async_utils import iter_in_thread
from app.shared.bootstrap import get_agent_conversation_service
@@ -27,7 +29,10 @@ _DEFAULT_QUICK_QUESTIONS = [
@router.post("/chat")
async def rag_chat(request: RagChatRequest):
async def rag_chat(
request: RagChatRequest,
current_user: UserClaims = Depends(get_current_user),
):
"""Stream RAG Q&A using the real agent service."""
session_id, event_stream = get_agent_conversation_service().stream_chat(
query=request.query,

View File

@@ -277,7 +277,6 @@ class DocumentCommandService:
message="Document record created",
)
temp_path = ""
try:
self.binary_store.save(
object_name=object_name,
@@ -297,117 +296,20 @@ class DocumentCommandService:
stage="store",
message="Source file stored",
)
suffix = os.path.splitext(file_name)[1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
temp_file.write(content)
temp_path = temp_file.name
parsed_document = self.parser.parse(
file_path=temp_path,
# Delegate parse → embed → index to the shared processing method.
# This same method is invoked by the Celery worker for async processing.
return self._process_document(
doc_id=doc_id,
doc_name=final_doc_name,
)
self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)
artifact_keys: dict[str, str] = {}
try:
artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
except Exception:
logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)
self.document_repository.update_status(
doc_id,
DocumentStatus.PARSED,
parser_name=parsed_document.parser_name,
metadata={
"parser_backend": parsed_document.parser_name,
"parse_task_id": parsed_document.metadata.get("task_id", ""),
"layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
"structure_node_count": len(parsed_document.structure_nodes),
"semantic_block_count": len(parsed_document.semantic_blocks),
"vector_chunk_count": len(parsed_document.vector_chunks),
"artifact_keys": artifact_keys,
"processing_stage": "parsed",
},
)
current_status = DocumentStatus.PARSED
current_stage = "embed"
self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status=DocumentStatus.STORED.value,
to_status=DocumentStatus.PARSED.value,
stage="parse",
message="Document parsed",
metadata={"artifact_count": len(artifact_keys)},
)
if self.parse_artifact_store:
try:
self.parse_artifact_store.save(
doc_id,
parsed_document.structure_nodes,
parsed_document.semantic_blocks,
)
except Exception:
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
chunks = self.chunk_builder.build(
parsed_document=parsed_document,
file_name=file_name,
final_doc_name=final_doc_name,
content=content,
regulation_type=regulation_type,
version=version,
)
if not chunks:
raise ValueError("解析完成但没有生成可入库的 chunks")
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
current_stage = "index"
inserted = self.vector_index.upsert(chunks, vectors)
if inserted != len(chunks):
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
health = self.vector_index.health()
self.document_repository.update_status(
doc_id,
DocumentStatus.INDEXED,
chunk_count=len(chunks),
summary="",
summary_latency_ms=0,
index_name=health.get("collection_name", ""),
metadata={
"index_collection": health.get("collection_name", ""),
"processing_stage": "indexed",
},
)
current_status = DocumentStatus.INDEXED
index_name = health.get("collection_name", "")
self._safe_mark_run_indexed(
doc_id=doc_id,
generate_summary=generate_summary,
run_id=run_id,
chunk_count=len(chunks),
index_name=index_name,
)
self._safe_append_status_event(
doc_id=doc_id,
run_id=run_id,
from_status=DocumentStatus.PARSED.value,
to_status=DocumentStatus.INDEXED.value,
stage="index",
message="Document indexed",
metadata={"chunk_count": len(chunks), "index_name": index_name},
)
stored = self.document_repository.get(doc_id)
return DocumentProcessResult(
doc_id=doc_id,
doc_name=final_doc_name,
status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
message="处理成功",
num_chunks=len(chunks),
summary=stored.summary if stored else "",
summary_latency_ms=stored.summary_latency_ms if stored else 0,
)
except Exception as exc:
logger.exception("文档处理失败: doc_id={}", doc_id)
logger.exception("文档存储失败: doc_id={}", doc_id)
failure_stage = current_stage
self.document_repository.update_status(
doc_id,
@@ -439,6 +341,183 @@ class DocumentCommandService:
status=DocumentStatus.FAILED.value,
message=f"文档处理失败: {exc}",
)
def store_document(
self,
*,
doc_id: str | None = None,
file_name: str,
content: bytes,
content_type: str,
doc_name: str | None,
regulation_type: str,
version: str,
generate_summary: bool,
) -> tuple[str, str | None]:
"""Store the binary file and create the Document record.
Returns (doc_id, run_id). Does NOT parse, embed, or index.
This is the fast synchronous first step; processing is enqueued separately.
The caller is responsible for enqueuing the follow-up process_document_task.
"""
doc_id = doc_id or str(uuid.uuid4())[:8]
final_doc_name = doc_name or file_name
object_name = f"{doc_id}/{file_name}"
document = Document(
doc_id=doc_id,
doc_name=final_doc_name,
file_name=file_name,
object_name=object_name,
content_type=content_type,
size_bytes=len(content),
regulation_type=regulation_type,
version=version,
metadata={"generate_summary": generate_summary},
)
self.document_repository.create(document)
run_id = self._safe_create_processing_run(
doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary
)
self.binary_store.save(
object_name=object_name, data=content,
content_type=content_type, metadata={"doc_id": doc_id},
)
self.document_repository.update_status(doc_id, DocumentStatus.STORED)
self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
self._safe_append_status_event(
doc_id=doc_id, run_id=run_id,
from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value,
stage="store", message="Source file stored",
)
return doc_id, run_id
def _process_document(
self,
*,
doc_id: str,
file_name: str,
final_doc_name: str,
content: bytes,
regulation_type: str,
version: str,
generate_summary: bool,
run_id: str | None = None,
) -> DocumentProcessResult:
"""Run parse → chunk → embed → index for a document that is already stored.
Called both synchronously (from upload_and_process) and asynchronously
(from the Celery process_document_task worker). All side-effects write
through DocumentProcessingStore so callers can poll progress.
"""
current_status = DocumentStatus.STORED
current_stage = "parse"
temp_path = ""
try:
suffix = os.path.splitext(file_name)[1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
temp_file.write(content)
temp_path = temp_file.name
parsed_document = self.parser.parse(
file_path=temp_path,
doc_id=doc_id,
doc_name=final_doc_name,
)
self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)
artifact_keys: dict[str, str] = {}
try:
artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
except Exception:
logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)
self.document_repository.update_status(
doc_id,
DocumentStatus.PARSED,
parser_name=parsed_document.parser_name,
metadata={
"parser_backend": parsed_document.parser_name,
"parse_task_id": parsed_document.metadata.get("task_id", ""),
"layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
"structure_node_count": len(parsed_document.structure_nodes),
"semantic_block_count": len(parsed_document.semantic_blocks),
"vector_chunk_count": len(parsed_document.vector_chunks),
"artifact_keys": artifact_keys,
"processing_stage": "parsed",
},
)
current_status = DocumentStatus.PARSED
current_stage = "embed"
self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
self._safe_append_status_event(
doc_id=doc_id, run_id=run_id,
from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value,
stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)},
)
if self.parse_artifact_store:
try:
self.parse_artifact_store.save(
doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks,
)
except Exception:
logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)
chunks = self.chunk_builder.build(
parsed_document=parsed_document,
regulation_type=regulation_type,
version=version,
)
if not chunks:
raise ValueError("解析完成但没有生成可入库的 chunks")
vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
current_stage = "index"
inserted = self.vector_index.upsert(chunks, vectors)
if inserted != len(chunks):
logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
health = self.vector_index.health()
index_name = health.get("collection_name", "")
self.document_repository.update_status(
doc_id, DocumentStatus.INDEXED,
chunk_count=len(chunks), summary="", summary_latency_ms=0,
index_name=index_name,
metadata={"index_collection": index_name, "processing_stage": "indexed"},
)
self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name)
self._safe_append_status_event(
doc_id=doc_id, run_id=run_id,
from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value,
stage="index", message="Document indexed",
metadata={"chunk_count": len(chunks), "index_name": index_name},
)
stored = self.document_repository.get(doc_id)
return DocumentProcessResult(
doc_id=doc_id, doc_name=final_doc_name,
status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
message="处理成功", num_chunks=len(chunks),
summary=stored.summary if stored else "",
summary_latency_ms=stored.summary_latency_ms if stored else 0,
)
except Exception as exc:
logger.exception("文档处理失败: doc_id={}", doc_id)
self.document_repository.update_status(
doc_id, DocumentStatus.FAILED, error_message=str(exc),
metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage},
)
self._safe_mark_run_failed(
doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc)
)
self._safe_append_status_event(
doc_id=doc_id, run_id=run_id,
from_status=current_status.value, to_status=DocumentStatus.FAILED.value,
stage=current_stage, message=str(exc),
)
return DocumentProcessResult(
doc_id=doc_id, doc_name=final_doc_name,
status=DocumentStatus.FAILED.value, message=f"文档处理失败: {exc}",
)
finally:
if temp_path and os.path.exists(temp_path):
try:
@@ -446,7 +525,6 @@ class DocumentCommandService:
except OSError:
logger.warning("临时文件清理失败: {}", temp_path)
def delete(self, doc_id: str) -> bool:
"""Delete document record, binary file, and vector chunks."""
document = self.document_repository.get(doc_id)

View File

@@ -82,6 +82,10 @@ class Settings(BaseSettings):
parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)")
chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)")
document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)")
# When True, document processing is enqueued to Celery workers via Redis.
# When False (default), processing runs in a FastAPI BackgroundTask in the same process —
# no external worker needed. Switch to True only when a Celery worker is running.
use_celery_worker: bool = Field(default=False, description="使用 Celery Worker 异步处理文档 (需要 Worker 运行中)")
# Keep configuration setup explicit so runtime behavior is easy to reason about.
api_host: str = Field(default="0.0.0.0", description="API服务地址")
@@ -109,6 +113,7 @@ class Settings(BaseSettings):
rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量reranker 启用时生效)")
rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数")
rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数")
rag_skills_max_tokens: int = Field(default=2048, description="技能类 RAG 最大 token 数")
reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排")
reranker_base_url: str = Field(default="", description="Reranker API 地址")
@@ -124,6 +129,26 @@ class Settings(BaseSettings):
# Keep configuration setup explicit so runtime behavior is easy to reason about.
session_max_sessions: int = Field(default=100, description="最大会话数量")
session_timeout_minutes: int = Field(default=30, description="会话超时时间(分钟)")
session_backend: str = Field(
default="memory",
description="会话存储后端 (memory | redis)。redis 需要 Redis 可用。",
)
# ── Auth ──────────────────────────────────────────────────────────────────
# Generate a strong secret: python -c "import secrets; print(secrets.token_hex(32))"
auth_secret_key: str = Field(
default="change-me-in-production-must-be-32-or-more-characters-long",
description="JWT signing secret. MUST be changed in production.",
)
auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.")
auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default 8 hours).")
auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).")
# ── CORS ──────────────────────────────────────────────────────────────────
cors_allow_origins: str = Field(
default="http://localhost:5173",
description="Comma-separated allowed CORS origins. Never use * in production.",
)
@lru_cache
def get_settings() -> Settings:

View File

@@ -0,0 +1,10 @@
"""Auth domain: role definitions and token claim models.
The domain layer defines what a user identity looks like (UserClaims) and
what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL)
live under infrastructure/auth and never leak into this package.
"""
from .models import UserClaims, UserRole
__all__ = ["UserClaims", "UserRole"]

View File

@@ -0,0 +1,42 @@
"""Auth domain models: roles and token claims.
UserRole defines the four roles from PPT Slide 12.
UserClaims is what the JWT decodes to — it is the identity object passed
through FastAPI dependency injection to route handlers.
"""
from __future__ import annotations
import enum
from dataclasses import dataclass
class UserRole(str, enum.Enum):
"""Access roles mirroring the four-role RBAC matrix from the product spec.
ADMIN — full platform access including system management.
LEGAL — knowledge query, document review, compliance checks.
EHS — knowledge query, perception/regulatory signals.
READONLY — knowledge query only.
"""
ADMIN = "admin"
LEGAL = "legal"
EHS = "ehs"
READONLY = "readonly"
@dataclass
class UserClaims:
"""Decoded JWT payload representing an authenticated user.
Instances are created by JWTHandler.decode_token() and injected into
route handlers via the get_current_user FastAPI dependency.
"""
# Unique user identifier (UUID string stored in PostgreSQL users table).
user_id: str
# Display name used for audit log entries.
username: str
# Role determines which resources the user may access.
role: UserRole

View File

@@ -0,0 +1,5 @@
"""JWT token creation and validation infrastructure.
JWTHandler is the only component in this package. It is wired through
shared/bootstrap.py and injected into FastAPI dependencies.
"""

View File

@@ -0,0 +1,82 @@
"""JWT access token creation and decoding.
Uses python-jose for HS256 token signing. Token expiry is enforced at
decode time so expired tokens are rejected even if the signature is valid.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Any
from jose import JWTError, jwt
from loguru import logger
from app.domain.auth.models import UserClaims, UserRole
class JWTHandler:
"""Create and validate HS256 JWT access tokens.
A single shared instance is wired by bootstrap.py. Use
get_jwt_handler() from shared.bootstrap for all token operations.
"""
def __init__(
self,
*,
secret_key: str,
algorithm: str = "HS256",
expire_minutes: int = 480,
) -> None:
"""Initialise the handler with signing credentials and token lifetime."""
self._secret = secret_key
self._algorithm = algorithm
self._expire_minutes = expire_minutes
def create_access_token(
self,
*,
user_id: str,
username: str,
role: str,
) -> str:
"""Return a signed JWT containing user identity and role claims."""
now = datetime.now(UTC)
payload: dict[str, Any] = {
"sub": user_id,
"username": username,
"role": role,
"iat": now,
"exp": now + timedelta(minutes=self._expire_minutes),
}
return jwt.encode(payload, self._secret, algorithm=self._algorithm)
def decode_token(self, token: str) -> UserClaims:
"""Decode and validate a JWT, returning UserClaims.
Raises ValueError with a descriptive message on expiry, tampering,
or any other validation failure so callers do not need to know jose.
"""
try:
payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
except JWTError as exc:
msg = str(exc).lower()
if "expired" in msg:
raise ValueError("Token expired") from exc
raise ValueError(f"Invalid token: {exc}") from exc
user_id = payload.get("sub")
username = payload.get("username", "")
role_str = payload.get("role", UserRole.READONLY.value)
if not user_id:
raise ValueError("Token missing subject claim")
try:
role = UserRole(role_str)
except ValueError:
logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
role = UserRole.READONLY
return UserClaims(user_id=user_id, username=username, role=role)

View File

@@ -0,0 +1,113 @@
"""PostgreSQL-backed user store for authentication.
Manages a `users` table with hashed passwords and roles.
Provides lookup by username for the login flow.
Table DDL is auto-applied on first connection.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import psycopg2
import psycopg2.extras
from loguru import logger
from passlib.context import CryptContext
from app.config.settings import settings
# bcrypt context — work factor 12 is a good production default.
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")
# DDL executed once to ensure the table exists.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
username VARCHAR(100) UNIQUE NOT NULL,
hashed_pw TEXT NOT NULL,
role VARCHAR(50) NOT NULL DEFAULT 'readonly',
is_active BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
@dataclass
class UserRecord:
"""A single row from the users table."""
id: str
username: str
hashed_pw: str
role: str
is_active: bool
class PostgresUserStore:
"""Read and verify users stored in the PostgreSQL users table.
The connection is opened on first use and shared for the lifetime
of the singleton instance wired by bootstrap.
"""
def __init__(self) -> None:
"""Initialise the store and ensure the users table exists."""
self._conn = psycopg2.connect(
host=settings.postgres_host,
port=settings.postgres_port,
user=settings.postgres_user,
password=settings.postgres_password,
dbname=settings.postgres_db,
cursor_factory=psycopg2.extras.RealDictCursor,
)
self._conn.autocommit = True
self._ensure_table()
def _ensure_table(self) -> None:
"""Create the users table if it does not already exist."""
with self._conn.cursor() as cur:
# Enable pgcrypto so gen_random_uuid() is available for UUID primary keys.
try:
cur.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
except Exception:
self._conn.rollback()
cur.execute(_CREATE_TABLE_SQL)
def get_by_username(self, username: str) -> Optional[UserRecord]:
"""Return a UserRecord for the given username, or None if not found."""
with self._conn.cursor() as cur:
cur.execute(
"SELECT id, username, hashed_pw, role, is_active "
"FROM users WHERE username = %s",
(username,),
)
row = cur.fetchone()
if row is None:
return None
return UserRecord(
id=str(row["id"]),
username=row["username"],
hashed_pw=row["hashed_pw"],
role=row["role"],
is_active=row["is_active"],
)
def verify_password(self, plain: str, hashed: str) -> bool:
"""Return True if `plain` matches the stored bcrypt hash."""
return _PWD_CTX.verify(plain, hashed)
def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
"""Return the UserRecord if credentials are valid, else None."""
user = self.get_by_username(username)
if user is None or not user.is_active:
return None
if not self.verify_password(password, user.hashed_pw):
return None
return user
@staticmethod
def hash_password(plain: str) -> str:
"""Hash a plain-text password with bcrypt."""
return _PWD_CTX.hash(plain)

View File

@@ -0,0 +1,169 @@
"""Redis-backed conversation store for persistent chat sessions.
Sessions are stored as JSON strings under the key `session:{session_id}`.
The Redis TTL is refreshed on every write so active sessions stay alive.
On expiry, `get_session` returns None — callers should create a new session.
"""
from __future__ import annotations
import json
import time
import uuid
from typing import Any
from loguru import logger
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
class RedisConversationStore(ConversationStore):
"""Store conversation sessions in Redis with automatic TTL expiry.
Each session is serialised as a JSON object at key ``session:{session_id}``.
The TTL is reset on every write so sessions stay alive as long as they are active.
"""
# Prefix for all session keys to avoid collisions with other Redis consumers.
_PREFIX = "session:"
def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
"""Initialise the store with an existing Redis client and a TTL in seconds."""
self._redis = redis_client
self._ttl = timeout_seconds
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _key(self, session_id: str) -> str:
"""Build the Redis key for a session."""
return f"{self._PREFIX}{session_id}"
def _serialise(self, session: ConversationSession) -> str:
"""Serialise a ConversationSession to a JSON string."""
return json.dumps(
{
"session_id": session.session_id,
"created_at": session.created_at,
"updated_at": session.updated_at,
"metadata": session.metadata,
"messages": [
{
"role": msg.role,
"content": msg.content,
"timestamp": msg.timestamp,
"sources": msg.sources,
}
for msg in session.messages
],
},
ensure_ascii=False,
)
def _deserialise(self, raw: bytes | str) -> ConversationSession:
"""Deserialise a JSON string back into a ConversationSession."""
data = json.loads(raw)
messages = [
ConversationMessage(
role=m["role"],
content=m["content"],
timestamp=m["timestamp"],
sources=m.get("sources", []),
)
for m in data.get("messages", [])
]
session = ConversationSession(
session_id=data["session_id"],
created_at=data.get("created_at", 0),
updated_at=data.get("updated_at", 0),
metadata=data.get("metadata", {}),
)
session.messages = messages
return session
def _save(self, session: ConversationSession) -> None:
"""Persist a session to Redis and refresh its TTL."""
self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))
# ------------------------------------------------------------------
# ConversationStore protocol
# ------------------------------------------------------------------
def create_session(self, metadata: dict | None = None) -> ConversationSession:
"""Create a new empty session and persist it immediately."""
now = int(time.time())
session = ConversationSession(
session_id=str(uuid.uuid4())[:8],
created_at=now,
updated_at=now,
metadata=metadata or {},
)
self._save(session)
return session
def get_session(self, session_id: str) -> ConversationSession | None:
"""Return a session by ID, or None if it does not exist or has expired."""
raw = self._redis.get(self._key(session_id))
if raw is None:
return None
try:
return self._deserialise(raw)
except Exception:
logger.warning("Failed to deserialise session: {}", session_id)
return None
def save_message(
self,
session_id: str,
*,
role: str,
content: str,
sources: list[dict] | None = None,
) -> ConversationSession | None:
"""Append a message to a session and refresh its TTL."""
session = self.get_session(session_id)
if session is None:
return None
session.messages.append(
ConversationMessage(
role=role,
content=content,
timestamp=int(time.time()),
sources=sources or [],
)
)
session.updated_at = int(time.time())
self._save(session)
return session
def delete_session(self, session_id: str) -> bool:
"""Delete a session. Returns True if it existed, False otherwise."""
deleted = self._redis.delete(self._key(session_id))
return bool(deleted)
def list_sessions(self) -> list[dict]:
"""Return summary dicts for all live sessions visible in this Redis DB.
Note: KEYS is used for simplicity; replace with SCAN for large deployments.
"""
pattern = f"{self._PREFIX}*"
keys = self._redis.keys(pattern)
result = []
for key in keys:
raw = self._redis.get(key)
if raw is None:
continue
try:
data = json.loads(raw)
result.append(
{
"session_id": data["session_id"],
"message_count": len(data.get("messages", [])),
"created_at": data.get("created_at", 0),
"updated_at": data.get("updated_at", 0),
}
)
except Exception:
continue
return result

View File

@@ -0,0 +1,5 @@
"""Celery task definitions for background processing.
This package exposes the shared Celery application instance and all
registered task functions used by API routes to enqueue work.
"""

View File

@@ -0,0 +1,45 @@
"""Shared Celery application instance for background task processing.
All workers and enqueueing call sites import `celery_app` from this module
so the broker/backend configuration stays in one place.
"""
from __future__ import annotations
from celery import Celery
from app.config.settings import settings
def _redis_url() -> str:
"""Return a Redis connection URL from application settings."""
if settings.redis_password:
return (
f"redis://:{settings.redis_password}@"
f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
)
return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
_BROKER = _redis_url()
_BACKEND = _redis_url()
celery_app = Celery(
"compliance_hub",
broker=_BROKER,
backend=_BACKEND,
include=["app.infrastructure.tasks.document_tasks"],
)
celery_app.conf.update(
task_serializer="json",
result_serializer="json",
accept_content=["json"],
timezone="UTC",
enable_utc=True,
# Acknowledge task only after successful execution to avoid data loss.
task_acks_late=True,
task_reject_on_worker_lost=True,
# Keep results for 1 hour for status polling.
result_expires=3600,
)

View File

@@ -0,0 +1,73 @@
"""Celery tasks for document processing.
Each task is a thin wrapper that retrieves the already-stored document
binary and delegates to DocumentCommandService._process_document.
The task does not accept raw file bytes — it reads them from the binary
store using the doc_id, so the Celery message payload stays small.
"""
from __future__ import annotations
from loguru import logger
from app.infrastructure.tasks.celery_app import celery_app
@celery_app.task(
name="app.infrastructure.tasks.document_tasks.process_document_task",
bind=True,
max_retries=3,
default_retry_delay=30,
acks_late=True,
)
def process_document_task(
self,
doc_id: str,
file_name: str,
doc_name: str,
regulation_type: str,
version: str,
generate_summary: bool,
run_id: str | None = None,
) -> dict:
"""Parse, embed, and index a document that has already been stored.
The task reads the file binary from MinIO using doc_id so the Celery
message stays small. Retries up to 3 times with a 30-second delay on
transient infrastructure errors.
"""
# Import inside the task function to avoid pickling issues and to ensure
# that each worker process initialises its own bootstrap singletons.
from app.shared.bootstrap import get_document_command_service, get_document_query_service
logger.info("process_document_task started: doc_id={}", doc_id)
try:
svc = get_document_command_service()
doc = get_document_query_service().get(doc_id)
if not doc:
raise ValueError(f"Document record not found: {doc_id}")
# Read the stored binary from MinIO — avoids passing raw bytes in the task message.
content = svc.binary_store.read(doc.object_name)
result = svc._process_document(
doc_id=doc_id,
file_name=file_name,
final_doc_name=doc_name,
content=content,
regulation_type=regulation_type,
version=version,
generate_summary=generate_summary,
run_id=run_id,
)
logger.info(
"process_document_task completed: doc_id={} status={} chunks={}",
doc_id, result.status, result.num_chunks,
)
return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks}
except Exception as exc:
logger.exception("process_document_task failed: doc_id={}", doc_id)
# Retry on transient errors; permanent errors (bad file, parse failure)
# will exhaust retries and leave the document in FAILED state.
raise self.retry(exc=exc)

View File

@@ -252,7 +252,31 @@ def get_document_query_service() -> DocumentQueryService:
@lru_cache
def get_conversation_store() -> InMemoryConversationStore:
"""Return conversation store."""
"""Return the active conversation store based on settings.
When session_backend='redis', sessions survive backend restarts and scale
across multiple API worker processes. When session_backend='memory' (default),
sessions are process-local and lost on restart.
"""
if settings.session_backend == "redis":
import redis as redis_lib
from app.infrastructure.session.redis_conversation_store import RedisConversationStore
# Build the Redis client from the same connection settings used by Celery.
kwargs: dict = {
"host": settings.redis_host,
"port": settings.redis_port,
"db": settings.redis_db,
"decode_responses": False,
}
if settings.redis_password:
kwargs["password"] = settings.redis_password
redis_client = redis_lib.Redis(**kwargs)
return RedisConversationStore( # type: ignore[return-value]
redis_client=redis_client,
timeout_seconds=settings.session_timeout_minutes * 60,
)
return InMemoryConversationStore(
max_sessions=settings.session_max_sessions,
timeout_minutes=settings.session_timeout_minutes,
@@ -284,6 +308,35 @@ def get_agent_session_service() -> AgentSessionService:
return AgentSessionService(conversation_store=get_conversation_store())
@lru_cache
def get_celery_app():
"""Return the shared Celery application instance.
Imported lazily so Celery is not required when running without workers
(e.g., tests that mock bootstrap or dev without Redis).
"""
from app.infrastructure.tasks.celery_app import celery_app
return celery_app
@lru_cache
def get_jwt_handler():
"""Return the shared JWTHandler instance for token creation and validation."""
from app.infrastructure.auth.jwt_handler import JWTHandler
return JWTHandler(
secret_key=settings.auth_secret_key,
algorithm=settings.auth_algorithm,
expire_minutes=settings.auth_token_expire_minutes,
)
@lru_cache
def get_user_store():
"""Return the PostgreSQL user store (lazy-connects on first call)."""
from app.infrastructure.auth.user_store import PostgresUserStore
return PostgresUserStore()
def preload_runtime_dependencies() -> None:
"""Warm dependencies that are safe and useful to preload during startup."""
LLMFactory.preload_clients(["qwen", "deepseek"])

View File

@@ -1,30 +1,46 @@
# ── Web framework ─────────────────────────────────────────────────────────────
fastapi>=0.110.0
uvicorn[standard]>=0.27.0
python-multipart>=0.0.9
# ── Config & utilities ────────────────────────────────────────────────────────
pydantic>=2.0.0
pydantic-settings>=2.0.0
python-dotenv>=1.0.0
loguru>=0.7.0
httpx>=0.25.0
tiktoken>=0.5.0
tenacity>=8.2.0
# ── Auth ──────────────────────────────────────────────────────────────────────
python-jose[cryptography]>=3.3.0
# passlib is incompatible with bcrypt>=4.0 (removed __about__, strict 72-byte limit).
# Pin bcrypt to 3.x until passlib ships a fix.
passlib[bcrypt]>=1.7.4
bcrypt>=3.2.0,<4.0.0
# ── Async task queue ──────────────────────────────────────────────────────────
celery>=5.3.0
redis>=4.5.0
# ── Storage & databases ───────────────────────────────────────────────────────
pymilvus>=2.4.0
minio>=7.1.0
psycopg2-binary>=2.9.0
# ── Document parsing ─────────────────────────────────────────────────────────
pymupdf>=1.24.0
python-docx>=1.1.0
numpy>=1.24.0
alibabacloud-docmind-api20220711>=1.0.6
alibabacloud-tea-openapi>=0.3.11
alibabacloud-tea-util>=0.3.13
# ── RAG / LangChain ───────────────────────────────────────────────────────────
langchain>=0.1.0
langchain-milvus>=0.1.0
numpy>=1.24.0
# ── Testing ───────────────────────────────────────────────────────────────────
pytest>=7.4.0
pytest-asyncio>=0.21.0
fakeredis>=2.0.0