diff --git a/.env b/.env index 076b7f4..7cef945 100644 --- a/.env +++ b/.env @@ -50,6 +50,9 @@ PARSER_BACKEND=aliyun CHUNK_BACKEND=aliyun # 文档元数据存储后端:json(默认)或 postgres DOCUMENT_REPOSITORY_BACKEND=json +# Set to true only when a Celery worker is actually running (./dev.sh start worker). +# Default false: processing runs in FastAPI's threadpool — no external worker needed. +USE_CELERY_WORKER=false # ===== API配置 ===== API_HOST=0.0.0.0 @@ -92,3 +95,23 @@ ALIYUN_LLM_ENHANCEMENT=true ALIYUN_ENHANCEMENT_MODE=VLM DOCUMENT_PARSE_ARTIFACT_PREFIX=artifacts PARSER_FAILURE_MODE=fail + +# ===== Reranker 配置 ===== +RERANKER_ENABLED=true +RERANKER_BASE_URL=http://6.86.80.4:30080/v1 +RERANKER_MODEL=BAAI/bge-reranker-v2-m3 +RERANKER_API_KEY= +RERANKER_TOP_K=5 + +# ===== 会话持久化 ===== +SESSION_BACKEND=redis + +# ===== 认证配置 ===== +# 生产环境请修改为强随机密钥: python -c "import secrets; print(secrets.token_hex(32))" +AUTH_SECRET_KEY=ai-compliance-hub-jwt-secret-2026-tsystems +AUTH_ALGORITHM=HS256 +AUTH_TOKEN_EXPIRE_MINUTES=480 +AUTH_ENABLED=true + +# ===== CORS ===== +CORS_ALLOW_ORIGINS=http://localhost:5173 diff --git a/.env.example b/.env.example index dd83ae8..26131db 100644 --- a/.env.example +++ b/.env.example @@ -49,8 +49,11 @@ MAX_FILE_SIZE_MB=100 DOCUMENT_METADATA_PATH=backend/data/documents.json PARSER_BACKEND=aliyun CHUNK_BACKEND=aliyun -# 文档元数据存储后端:json(默认,无需数据库)或 postgres(启用 PG 持久化) +# DOCUMENT_REPOSITORY_BACKEND=json(默认,无需数据库)或 postgres(启用 PG 持久化) DOCUMENT_REPOSITORY_BACKEND=json +# Set to true only when a Celery worker is running (./dev.sh start worker). +# Default false: document processing runs in FastAPI's threadpool (no external worker needed). +USE_CELERY_WORKER=false # ===== 阿里云文档解析 ===== ALIBABA_ACCESS_KEY_ID=your_aliyun_access_key_id @@ -96,11 +99,15 @@ RAG_TOP_K=10 RAG_RETRIEVAL_TOP_K=20 RAG_MAX_CONTEXT_TOKENS=4000 RAG_SUMMARY_MAX_TOKENS=1024 +RAG_SKILLS_MAX_TOKENS=2048 -# ===== Reranker配置(Cross-Encoder精排,默认关闭)===== -# 设置 RERANKER_ENABLED=true 并配置 RERANKER_BASE_URL 以启用精排 -RERANKER_ENABLED=false -RERANKER_BASE_URL= +# ── Reranker (Cross-Encoder) ────────────────────────────────────────────────── +# Set RERANKER_ENABLED=true and point to a TEI or Cohere-compatible rerank API. +# Recommended model: BAAI/bge-reranker-v2.5-gemma2-lightweight (lighter) or +# BAAI/bge-reranker-v2-m3 (heavier, higher quality). +# The endpoint must expose POST /rerank (TEI style) or POST /v1/rerank (Cohere style). +RERANKER_ENABLED=true +RERANKER_BASE_URL=http://6.86.80.4:30080/v1 RERANKER_MODEL=BAAI/bge-reranker-v2-m3 RERANKER_API_KEY= RERANKER_TOP_K=5 @@ -108,3 +115,20 @@ RERANKER_TOP_K=5 # ===== 会话配置 ===== SESSION_MAX_SESSIONS=100 SESSION_TIMEOUT_MINUTES=30 +# SESSION_BACKEND=redis 启用 Redis 持久化会话(需要 Redis 可用,推荐生产环境) +# SESSION_BACKEND=memory 使用内存会话(重启丢失,适合本地开发) +SESSION_BACKEND=memory + +# ===== 认证配置 (Auth) ===== +# 生产环境必须替换为强随机密钥: +# python -c "import secrets; print(secrets.token_hex(32))" +AUTH_SECRET_KEY=change-me-in-production-must-be-32-or-more-characters-long +AUTH_ALGORITHM=HS256 +# Token 有效期(分钟),默认 8 小时 +AUTH_TOKEN_EXPIRE_MINUTES=480 +# 设为 false 可跳过认证(仅限本地开发调试,生产必须 true) +AUTH_ENABLED=true + +# ===== CORS ===== +# 逗号分隔的允许跨域来源列表,生产环境绝不能使用 * +CORS_ALLOW_ORIGINS=http://localhost:5173 diff --git a/backend/app/api/dependencies/__init__.py b/backend/app/api/dependencies/__init__.py new file mode 100644 index 0000000..abab49d --- /dev/null +++ b/backend/app/api/dependencies/__init__.py @@ -0,0 +1,5 @@ +"""FastAPI dependency functions for authentication and authorisation. + +Import `get_current_user` or `require_role` into route modules to protect +endpoints. Both use the shared JWTHandler wired through bootstrap. +""" diff --git a/backend/app/api/dependencies/auth.py b/backend/app/api/dependencies/auth.py new file mode 100644 index 0000000..d0b24c3 --- /dev/null +++ b/backend/app/api/dependencies/auth.py @@ -0,0 +1,72 @@ +"""FastAPI dependencies for JWT authentication. + +Usage in a route: + from app.api.dependencies.auth import get_current_user, require_role + from app.domain.auth.models import UserRole + + @router.get("/protected") + async def protected(user: UserClaims = Depends(get_current_user)): + return {"user": user.username} + + @router.delete("/admin-only") + async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))): + ... +""" + +from __future__ import annotations + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.config.settings import settings +from app.domain.auth.models import UserClaims, UserRole +from app.shared.bootstrap import get_jwt_handler + +# Use Bearer token scheme — client sends `Authorization: Bearer `. +_bearer = HTTPBearer(auto_error=False) + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials | None = Depends(_bearer), +) -> UserClaims: + """Extract and validate the JWT from the Authorization header. + + Returns the decoded UserClaims on success. + Raises HTTP 401 when the token is missing, expired, or invalid. + When auth_enabled=False (development), returns a synthetic admin user. + """ + if not settings.auth_enabled: + # Development bypass — never enable this in production. + return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN) + + if credentials is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + return get_jwt_handler().decode_token(credentials.credentials) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(exc), + headers={"WWW-Authenticate": "Bearer"}, + ) from exc + + +def require_role(*roles: UserRole): + """Return a dependency that enforces one of the given roles. + + Example: + Depends(require_role(UserRole.ADMIN, UserRole.LEGAL)) + """ + async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims: + """Verify the user holds one of the required roles.""" + if user.role not in roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Role '{user.role}' is not permitted. Required: {[r.value for r in roles]}", + ) + return user + return _check diff --git a/backend/app/api/main.py b/backend/app/api/main.py index d8b7345..3777159 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from loguru import logger +from app.api.middleware.audit import AuditMiddleware from app.api.models import ErrorResponse from app.api.routes import api_router from app.config.logging import setup_logging @@ -46,14 +47,23 @@ app = FastAPI( redoc_url="/redoc", ) +# Tighten CORS — only allow configured origins. +# Set CORS_ALLOW_ORIGINS in .env to the real frontend URL in production. +_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()] +if not _ORIGINS: + _ORIGINS = ["http://localhost:5173"] + app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) +# Audit middleware logs every authenticated API call for compliance traceability. +app.add_middleware(AuditMiddleware) + app.include_router(api_router, prefix="/api/v1") diff --git a/backend/app/api/middleware/__init__.py b/backend/app/api/middleware/__init__.py new file mode 100644 index 0000000..a0ad3b8 --- /dev/null +++ b/backend/app/api/middleware/__init__.py @@ -0,0 +1 @@ +"""HTTP middleware for cross-cutting concerns: audit logging.""" diff --git a/backend/app/api/middleware/audit.py b/backend/app/api/middleware/audit.py new file mode 100644 index 0000000..365bd6e --- /dev/null +++ b/backend/app/api/middleware/audit.py @@ -0,0 +1,56 @@ +"""Audit logging middleware. + +Logs every API request with method, path, status code, response time, +and the authenticated user identity (extracted from the JWT when present). +Log lines are structured so they can be ingested by ELK / Loki. +""" + +from __future__ import annotations + +import time + +from fastapi import Request, Response +from loguru import logger +from starlette.middleware.base import BaseHTTPMiddleware + + +class AuditMiddleware(BaseHTTPMiddleware): + """Log all API calls. Skips health/docs paths to reduce noise.""" + + # Paths that produce no audit log entry. + _SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json") + + async def dispatch(self, request: Request, call_next) -> Response: + """Intercept the request, call the handler, and log the outcome.""" + path = request.url.path + if path == "/" or any(path == p or path.startswith(p + "/") for p in self._SKIP_PREFIXES): + return await call_next(request) + + start = time.perf_counter() + response = await call_next(request) + elapsed_ms = int((time.perf_counter() - start) * 1000) + + # Extract user identity from JWT header for structured audit records. + # The token is not re-validated here — auth dependencies do that upstream. + user_id = "anonymous" + username = "anonymous" + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + try: + from app.shared.bootstrap import get_jwt_handler + claims = get_jwt_handler().decode_token(auth_header[7:]) + user_id = claims.user_id + username = claims.username + except Exception: + pass + + logger.info( + "AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}", + request.method, + path, + response.status_code, + elapsed_ms, + user_id, + username, + ) + return response diff --git a/backend/app/api/routes/__init__.py b/backend/app/api/routes/__init__.py index d4f5f0e..cd96ac7 100644 --- a/backend/app/api/routes/__init__.py +++ b/backend/app/api/routes/__init__.py @@ -1,6 +1,7 @@ """Initialize the app.api.routes package.""" from fastapi import APIRouter +from .auth import router as auth_router from .compliance import router as compliance_router from .documents import router as documents_router from .knowledge import router as knowledge_router @@ -14,7 +15,8 @@ from .rag import router as rag_router # Keep package boundaries explicit so backend imports stay predictable. api_router = APIRouter() -# Keep package boundaries explicit so backend imports stay predictable. +# Auth routes first so /auth/token is easy to discover. +api_router.include_router(auth_router) api_router.include_router(documents_router) api_router.include_router(knowledge_router) api_router.include_router(agent_router) @@ -25,6 +27,7 @@ api_router.include_router(rag_router) __all__ = [ "api_router", + "auth_router", "documents_router", "knowledge_router", "agent_router", diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py new file mode 100644 index 0000000..7810929 --- /dev/null +++ b/backend/app/api/routes/auth.py @@ -0,0 +1,63 @@ +"""Authentication routes — token issuance only. + +POST /auth/token — exchange username + password for a JWT. +GET /auth/me — return the current user identity (requires token). +""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel + +from app.api.dependencies.auth import get_current_user +from app.config.settings import settings +from app.domain.auth.models import UserClaims +from app.shared.bootstrap import get_jwt_handler, get_user_store + + +router = APIRouter(prefix="/auth", tags=["认证"]) + + +class TokenResponse(BaseModel): + """JWT token response body.""" + + access_token: str + token_type: str = "bearer" + expires_in: int + + +@router.post("/token", response_model=TokenResponse) +async def login(form: OAuth2PasswordRequestForm = Depends()): + """Issue a JWT for valid username + password credentials. + + Uses standard OAuth2 password grant form fields — compatible with + Swagger UI Authorize button. + """ + user = get_user_store().authenticate(form.username, form.password) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + token = get_jwt_handler().create_access_token( + user_id=user.id, + username=user.username, + role=user.role, + ) + return TokenResponse( + access_token=token, + token_type="bearer", + expires_in=settings.auth_token_expire_minutes * 60, + ) + + +@router.get("/me") +async def get_me(current_user: UserClaims = Depends(get_current_user)): + """Return the identity of the currently authenticated user.""" + return { + "user_id": current_user.user_id, + "username": current_user.username, + "role": current_user.role.value, + } diff --git a/backend/app/api/routes/compliance.py b/backend/app/api/routes/compliance.py index 1c61486..3654eaf 100644 --- a/backend/app/api/routes/compliance.py +++ b/backend/app/api/routes/compliance.py @@ -7,10 +7,12 @@ import json from pathlib import Path from typing import AsyncGenerator, Optional -from fastapi import APIRouter, File, Form, UploadFile +from fastapi import APIRouter, Depends, File, Form, UploadFile from fastapi.responses import StreamingResponse from loguru import logger +from app.api.dependencies.auth import get_current_user +from app.domain.auth.models import UserClaims from app.schemas.compliance import ( AnalyzeResponse, ComplianceChatRequest, @@ -75,6 +77,7 @@ async def analyze_stream( file: Optional[UploadFile] = File(None), domains: Optional[str] = Form(None), title: Optional[str] = Form(None), + current_user: UserClaims = Depends(get_current_user), ): """Stream compliance analysis as SSE events. diff --git a/backend/app/api/routes/documents.py b/backend/app/api/routes/documents.py index a4024fb..91c1a22 100644 --- a/backend/app/api/routes/documents.py +++ b/backend/app/api/routes/documents.py @@ -5,12 +5,15 @@ from __future__ import annotations from io import BytesIO from urllib.parse import quote -from fastapi import APIRouter, File, Form, HTTPException, UploadFile +from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, UploadFile from fastapi.responses import StreamingResponse from loguru import logger +from app.api.dependencies.auth import get_current_user from app.api.models import DocumentUploadResponse from app.application.documents import DocumentProcessResult +from app.config.settings import settings +from app.domain.auth.models import UserClaims from app.shared.bootstrap import get_document_command_service, get_document_query_service # Keep route handlers close to their transport-layer wiring for easier auditing. @@ -31,16 +34,60 @@ def _document_response(result: DocumentProcessResult) -> DocumentUploadResponse: ) +def _run_process_in_background( + *, + doc_id: str, + file_name: str, + final_doc_name: str, + content: bytes, + regulation_type: str, + version: str, + generate_summary: bool, + run_id: str | None, +) -> None: + """Run document processing synchronously inside a FastAPI BackgroundTask thread. + + FastAPI executes BackgroundTasks in a threadpool executor, so blocking I/O + (parser API calls, embedding, Milvus upsert) is safe here. + """ + try: + svc = get_document_command_service() + svc._process_document( + doc_id=doc_id, + file_name=file_name, + final_doc_name=final_doc_name, + content=content, + regulation_type=regulation_type, + version=version, + generate_summary=generate_summary, + run_id=run_id, + ) + except Exception: + logger.exception("BackgroundTask document processing failed: doc_id={}", doc_id) + + @router.post("/upload", response_model=DocumentUploadResponse) async def upload_document( + background_tasks: BackgroundTasks, file: UploadFile = File(..., description="上传的文档文件"), doc_id: str | None = Form(None, description="客户端预分配的文档ID,不传则自动生成"), doc_name: str | None = Form(None, description="文档名称"), regulation_type: str | None = Form(None, description="法规类型"), version: str | None = Form(None, description="文档版本"), generate_summary: bool = Form(False, description="是否生成摘要"), + sync: bool = Form(False, description="同步处理(演示/测试用,默认异步处理)"), + current_user: UserClaims = Depends(get_current_user), ): - """Handle upload document.""" + """Upload a document and process it asynchronously. + + Default path (sync=false): + 1. Store binary to MinIO immediately — returns within seconds. + 2. Schedule parse→embed→index as a FastAPI BackgroundTask (same process, + threadpool) OR enqueue to Celery workers when USE_CELERY_WORKER=true. + 3. Poll GET /documents/status/{doc_id} for progress. + + sync=true path: full inline processing, blocks until complete (demo / CI use). + """ content = await file.read() if not file.filename: raise HTTPException(status_code=400, detail="文件名不能为空") @@ -48,19 +95,73 @@ async def upload_document( raise HTTPException(status_code=400, detail="上传文件为空") try: - result = get_document_command_service().upload_and_process( - doc_id=doc_id, - file_name=file.filename, - content=content, - content_type=file.content_type or "application/octet-stream", - doc_name=doc_name, - regulation_type=regulation_type or "", - version=version or "", - generate_summary=generate_summary, - ) + svc = get_document_command_service() + + if sync: + # Synchronous fallback: full inline processing. + result = svc.upload_and_process( + doc_id=doc_id, + file_name=file.filename, + content=content, + content_type=file.content_type or "application/octet-stream", + doc_name=doc_name, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + ) + else: + # Step 1: store binary and create the document record (fast, sync). + stored_doc_id, run_id = svc.store_document( + doc_id=doc_id, + file_name=file.filename, + content=content, + content_type=file.content_type or "application/octet-stream", + doc_name=doc_name, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + ) + final_doc_name = doc_name or file.filename + + # Step 2: schedule processing via Celery worker OR FastAPI BackgroundTask. + if settings.use_celery_worker: + from app.infrastructure.tasks.document_tasks import process_document_task + process_document_task.delay( + doc_id=stored_doc_id, + file_name=file.filename, + doc_name=final_doc_name, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + run_id=run_id, + ) + processing_note = "已入 Celery 队列,由 Worker 处理。" + else: + # Default: run in FastAPI's threadpool — no external worker needed. + background_tasks.add_task( + _run_process_in_background, + doc_id=stored_doc_id, + file_name=file.filename, + final_doc_name=final_doc_name, + content=content, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + run_id=run_id, + ) + processing_note = "正在后台处理。" + + result = DocumentProcessResult( + doc_id=stored_doc_id, + doc_name=final_doc_name, + status="stored", + message=f"文件已存储,{processing_note}请轮询 GET /documents/status/{{doc_id}} 查看进度。", + ) + if result.status == "failed": raise HTTPException(status_code=500, detail=result.message) return _document_response(result) + except HTTPException: raise except Exception as exc: @@ -106,7 +207,7 @@ async def download_document(doc_id: str): @router.get("/list") -async def list_documents(): +async def list_documents(current_user: UserClaims = Depends(get_current_user)): """List documents.""" documents = get_document_query_service().list_documents() return { @@ -148,7 +249,7 @@ async def get_document_management_list(): @router.delete("/{doc_id}") -async def delete_document(doc_id: str): +async def delete_document(doc_id: str, current_user: UserClaims = Depends(get_current_user)): """Delete a document and its associated data.""" deleted = get_document_command_service().delete(doc_id) if not deleted: diff --git a/backend/app/api/routes/rag.py b/backend/app/api/routes/rag.py index 1e79d88..bf37d40 100644 --- a/backend/app/api/routes/rag.py +++ b/backend/app/api/routes/rag.py @@ -5,10 +5,12 @@ from __future__ import annotations import json from typing import AsyncGenerator -from fastapi import APIRouter +from fastapi import APIRouter, Depends from fastapi.responses import StreamingResponse +from app.api.dependencies.auth import get_current_user from app.config.settings import settings +from app.domain.auth.models import UserClaims from app.schemas.rag import RagChatRequest, QuickQuestionsResponse, QuickQuestion from app.shared.async_utils import iter_in_thread from app.shared.bootstrap import get_agent_conversation_service @@ -27,7 +29,10 @@ _DEFAULT_QUICK_QUESTIONS = [ @router.post("/chat") -async def rag_chat(request: RagChatRequest): +async def rag_chat( + request: RagChatRequest, + current_user: UserClaims = Depends(get_current_user), +): """Stream RAG Q&A using the real agent service.""" session_id, event_stream = get_agent_conversation_service().stream_chat( query=request.query, diff --git a/backend/app/application/documents/services.py b/backend/app/application/documents/services.py index 2fbfe91..c39a57f 100644 --- a/backend/app/application/documents/services.py +++ b/backend/app/application/documents/services.py @@ -277,7 +277,6 @@ class DocumentCommandService: message="Document record created", ) - temp_path = "" try: self.binary_store.save( object_name=object_name, @@ -297,117 +296,20 @@ class DocumentCommandService: stage="store", message="Source file stored", ) - - suffix = os.path.splitext(file_name)[1] - with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: - temp_file.write(content) - temp_path = temp_file.name - - parsed_document = self.parser.parse( - file_path=temp_path, + # Delegate parse → embed → index to the shared processing method. + # This same method is invoked by the Celery worker for async processing. + return self._process_document( doc_id=doc_id, - doc_name=final_doc_name, - ) - self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document) - - artifact_keys: dict[str, str] = {} - try: - artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document) - except Exception: - logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id) - self.document_repository.update_status( - doc_id, - DocumentStatus.PARSED, - parser_name=parsed_document.parser_name, - metadata={ - "parser_backend": parsed_document.parser_name, - "parse_task_id": parsed_document.metadata.get("task_id", ""), - "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)), - "structure_node_count": len(parsed_document.structure_nodes), - "semantic_block_count": len(parsed_document.semantic_blocks), - "vector_chunk_count": len(parsed_document.vector_chunks), - "artifact_keys": artifact_keys, - "processing_stage": "parsed", - }, - ) - current_status = DocumentStatus.PARSED - current_stage = "embed" - self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys) - self._safe_append_status_event( - doc_id=doc_id, - run_id=run_id, - from_status=DocumentStatus.STORED.value, - to_status=DocumentStatus.PARSED.value, - stage="parse", - message="Document parsed", - metadata={"artifact_count": len(artifact_keys)}, - ) - if self.parse_artifact_store: - try: - self.parse_artifact_store.save( - doc_id, - parsed_document.structure_nodes, - parsed_document.semantic_blocks, - ) - except Exception: - logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id) - - chunks = self.chunk_builder.build( - parsed_document=parsed_document, + file_name=file_name, + final_doc_name=final_doc_name, + content=content, regulation_type=regulation_type, version=version, - ) - if not chunks: - raise ValueError("解析完成但没有生成可入库的 chunks") - - vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks]) - current_stage = "index" - inserted = self.vector_index.upsert(chunks, vectors) - if inserted != len(chunks): - logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks)) - - health = self.vector_index.health() - self.document_repository.update_status( - doc_id, - DocumentStatus.INDEXED, - chunk_count=len(chunks), - summary="", - summary_latency_ms=0, - index_name=health.get("collection_name", ""), - metadata={ - "index_collection": health.get("collection_name", ""), - "processing_stage": "indexed", - }, - ) - current_status = DocumentStatus.INDEXED - index_name = health.get("collection_name", "") - self._safe_mark_run_indexed( - doc_id=doc_id, + generate_summary=generate_summary, run_id=run_id, - chunk_count=len(chunks), - index_name=index_name, - ) - self._safe_append_status_event( - doc_id=doc_id, - run_id=run_id, - from_status=DocumentStatus.PARSED.value, - to_status=DocumentStatus.INDEXED.value, - stage="index", - message="Document indexed", - metadata={"chunk_count": len(chunks), "index_name": index_name}, - ) - stored = self.document_repository.get(doc_id) - return DocumentProcessResult( - doc_id=doc_id, - doc_name=final_doc_name, - status=(stored.status.value if stored else DocumentStatus.INDEXED.value), - message="处理成功", - num_chunks=len(chunks), - summary=stored.summary if stored else "", - summary_latency_ms=stored.summary_latency_ms if stored else 0, ) except Exception as exc: - logger.exception("文档处理失败: doc_id={}", doc_id) + logger.exception("文档存储失败: doc_id={}", doc_id) failure_stage = current_stage self.document_repository.update_status( doc_id, @@ -439,6 +341,183 @@ class DocumentCommandService: status=DocumentStatus.FAILED.value, message=f"文档处理失败: {exc}", ) + + def store_document( + self, + *, + doc_id: str | None = None, + file_name: str, + content: bytes, + content_type: str, + doc_name: str | None, + regulation_type: str, + version: str, + generate_summary: bool, + ) -> tuple[str, str | None]: + """Store the binary file and create the Document record. + + Returns (doc_id, run_id). Does NOT parse, embed, or index. + This is the fast synchronous first step; processing is enqueued separately. + The caller is responsible for enqueuing the follow-up process_document_task. + """ + doc_id = doc_id or str(uuid.uuid4())[:8] + final_doc_name = doc_name or file_name + object_name = f"{doc_id}/{file_name}" + + document = Document( + doc_id=doc_id, + doc_name=final_doc_name, + file_name=file_name, + object_name=object_name, + content_type=content_type, + size_bytes=len(content), + regulation_type=regulation_type, + version=version, + metadata={"generate_summary": generate_summary}, + ) + self.document_repository.create(document) + run_id = self._safe_create_processing_run( + doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary + ) + self.binary_store.save( + object_name=object_name, data=content, + content_type=content_type, metadata={"doc_id": doc_id}, + ) + self.document_repository.update_status(doc_id, DocumentStatus.STORED) + self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value, + stage="store", message="Source file stored", + ) + return doc_id, run_id + + def _process_document( + self, + *, + doc_id: str, + file_name: str, + final_doc_name: str, + content: bytes, + regulation_type: str, + version: str, + generate_summary: bool, + run_id: str | None = None, + ) -> DocumentProcessResult: + """Run parse → chunk → embed → index for a document that is already stored. + + Called both synchronously (from upload_and_process) and asynchronously + (from the Celery process_document_task worker). All side-effects write + through DocumentProcessingStore so callers can poll progress. + """ + current_status = DocumentStatus.STORED + current_stage = "parse" + temp_path = "" + try: + suffix = os.path.splitext(file_name)[1] + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + temp_file.write(content) + temp_path = temp_file.name + + parsed_document = self.parser.parse( + file_path=temp_path, + doc_id=doc_id, + doc_name=final_doc_name, + ) + self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document) + + artifact_keys: dict[str, str] = {} + try: + artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document) + except Exception: + logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id) + + self.document_repository.update_status( + doc_id, + DocumentStatus.PARSED, + parser_name=parsed_document.parser_name, + metadata={ + "parser_backend": parsed_document.parser_name, + "parse_task_id": parsed_document.metadata.get("task_id", ""), + "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)), + "structure_node_count": len(parsed_document.structure_nodes), + "semantic_block_count": len(parsed_document.semantic_blocks), + "vector_chunk_count": len(parsed_document.vector_chunks), + "artifact_keys": artifact_keys, + "processing_stage": "parsed", + }, + ) + current_status = DocumentStatus.PARSED + current_stage = "embed" + self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value, + stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)}, + ) + if self.parse_artifact_store: + try: + self.parse_artifact_store.save( + doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks, + ) + except Exception: + logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id) + + chunks = self.chunk_builder.build( + parsed_document=parsed_document, + regulation_type=regulation_type, + version=version, + ) + if not chunks: + raise ValueError("解析完成但没有生成可入库的 chunks") + + vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks]) + current_stage = "index" + inserted = self.vector_index.upsert(chunks, vectors) + if inserted != len(chunks): + logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks)) + + health = self.vector_index.health() + index_name = health.get("collection_name", "") + self.document_repository.update_status( + doc_id, DocumentStatus.INDEXED, + chunk_count=len(chunks), summary="", summary_latency_ms=0, + index_name=index_name, + metadata={"index_collection": index_name, "processing_stage": "indexed"}, + ) + self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value, + stage="index", message="Document indexed", + metadata={"chunk_count": len(chunks), "index_name": index_name}, + ) + stored = self.document_repository.get(doc_id) + return DocumentProcessResult( + doc_id=doc_id, doc_name=final_doc_name, + status=(stored.status.value if stored else DocumentStatus.INDEXED.value), + message="处理成功", num_chunks=len(chunks), + summary=stored.summary if stored else "", + summary_latency_ms=stored.summary_latency_ms if stored else 0, + ) + except Exception as exc: + logger.exception("文档处理失败: doc_id={}", doc_id) + self.document_repository.update_status( + doc_id, DocumentStatus.FAILED, error_message=str(exc), + metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage}, + ) + self._safe_mark_run_failed( + doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc) + ) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=current_status.value, to_status=DocumentStatus.FAILED.value, + stage=current_stage, message=str(exc), + ) + return DocumentProcessResult( + doc_id=doc_id, doc_name=final_doc_name, + status=DocumentStatus.FAILED.value, message=f"文档处理失败: {exc}", + ) finally: if temp_path and os.path.exists(temp_path): try: @@ -446,7 +525,6 @@ class DocumentCommandService: except OSError: logger.warning("临时文件清理失败: {}", temp_path) - def delete(self, doc_id: str) -> bool: """Delete document record, binary file, and vector chunks.""" document = self.document_repository.get(doc_id) diff --git a/backend/app/config/settings.py b/backend/app/config/settings.py index 8f59bc6..ffdd480 100644 --- a/backend/app/config/settings.py +++ b/backend/app/config/settings.py @@ -82,6 +82,10 @@ class Settings(BaseSettings): parser_backend: str = Field(default="aliyun", description="解析后端(local/aliyun)") chunk_backend: str = Field(default="aliyun", description="分块后端(local/aliyun)") document_repository_backend: str = Field(default="json", description="文档元数据存储后端 (json/postgres)") + # When True, document processing is enqueued to Celery workers via Redis. + # When False (default), processing runs in a FastAPI BackgroundTask in the same process — + # no external worker needed. Switch to True only when a Celery worker is running. + use_celery_worker: bool = Field(default=False, description="使用 Celery Worker 异步处理文档 (需要 Worker 运行中)") # Keep configuration setup explicit so runtime behavior is easy to reason about. api_host: str = Field(default="0.0.0.0", description="API服务地址") @@ -109,6 +113,7 @@ class Settings(BaseSettings): rag_retrieval_top_k: int = Field(default=20, description="精排前召回候选数量(reranker 启用时生效)") rag_max_context_tokens: int = Field(default=2000, description="RAG最大上下文token数") rag_summary_max_tokens: int = Field(default=10240, description="文档摘要最大token数") + rag_skills_max_tokens: int = Field(default=2048, description="技能类 RAG 最大 token 数") reranker_enabled: bool = Field(default=False, description="是否启用 Cross-Encoder 精排") reranker_base_url: str = Field(default="", description="Reranker API 地址") @@ -124,6 +129,26 @@ class Settings(BaseSettings): # Keep configuration setup explicit so runtime behavior is easy to reason about. session_max_sessions: int = Field(default=100, description="最大会话数量") session_timeout_minutes: int = Field(default=30, description="会话超时时间(分钟)") + session_backend: str = Field( + default="memory", + description="会话存储后端 (memory | redis)。redis 需要 Redis 可用。", + ) + + # ── Auth ────────────────────────────────────────────────────────────────── + # Generate a strong secret: python -c "import secrets; print(secrets.token_hex(32))" + auth_secret_key: str = Field( + default="change-me-in-production-must-be-32-or-more-characters-long", + description="JWT signing secret. MUST be changed in production.", + ) + auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.") + auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default 8 hours).") + auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).") + + # ── CORS ────────────────────────────────────────────────────────────────── + cors_allow_origins: str = Field( + default="http://localhost:5173", + description="Comma-separated allowed CORS origins. Never use * in production.", + ) @lru_cache def get_settings() -> Settings: diff --git a/backend/app/domain/auth/__init__.py b/backend/app/domain/auth/__init__.py new file mode 100644 index 0000000..9dbc032 --- /dev/null +++ b/backend/app/domain/auth/__init__.py @@ -0,0 +1,10 @@ +"""Auth domain: role definitions and token claim models. + +The domain layer defines what a user identity looks like (UserClaims) and +what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL) +live under infrastructure/auth and never leak into this package. +""" + +from .models import UserClaims, UserRole + +__all__ = ["UserClaims", "UserRole"] diff --git a/backend/app/domain/auth/models.py b/backend/app/domain/auth/models.py new file mode 100644 index 0000000..2ca3399 --- /dev/null +++ b/backend/app/domain/auth/models.py @@ -0,0 +1,42 @@ +"""Auth domain models: roles and token claims. + +UserRole defines the four roles from PPT Slide 12. +UserClaims is what the JWT decodes to — it is the identity object passed +through FastAPI dependency injection to route handlers. +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass + + +class UserRole(str, enum.Enum): + """Access roles mirroring the four-role RBAC matrix from the product spec. + + ADMIN — full platform access including system management. + LEGAL — knowledge query, document review, compliance checks. + EHS — knowledge query, perception/regulatory signals. + READONLY — knowledge query only. + """ + + ADMIN = "admin" + LEGAL = "legal" + EHS = "ehs" + READONLY = "readonly" + + +@dataclass +class UserClaims: + """Decoded JWT payload representing an authenticated user. + + Instances are created by JWTHandler.decode_token() and injected into + route handlers via the get_current_user FastAPI dependency. + """ + + # Unique user identifier (UUID string stored in PostgreSQL users table). + user_id: str + # Display name used for audit log entries. + username: str + # Role determines which resources the user may access. + role: UserRole diff --git a/backend/app/infrastructure/auth/__init__.py b/backend/app/infrastructure/auth/__init__.py new file mode 100644 index 0000000..faee0f4 --- /dev/null +++ b/backend/app/infrastructure/auth/__init__.py @@ -0,0 +1,5 @@ +"""JWT token creation and validation infrastructure. + +JWTHandler is the only component in this package. It is wired through +shared/bootstrap.py and injected into FastAPI dependencies. +""" diff --git a/backend/app/infrastructure/auth/jwt_handler.py b/backend/app/infrastructure/auth/jwt_handler.py new file mode 100644 index 0000000..7b0fd07 --- /dev/null +++ b/backend/app/infrastructure/auth/jwt_handler.py @@ -0,0 +1,82 @@ +"""JWT access token creation and decoding. + +Uses python-jose for HS256 token signing. Token expiry is enforced at +decode time so expired tokens are rejected even if the signature is valid. +""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import Any + +from jose import JWTError, jwt +from loguru import logger + +from app.domain.auth.models import UserClaims, UserRole + + +class JWTHandler: + """Create and validate HS256 JWT access tokens. + + A single shared instance is wired by bootstrap.py. Use + get_jwt_handler() from shared.bootstrap for all token operations. + """ + + def __init__( + self, + *, + secret_key: str, + algorithm: str = "HS256", + expire_minutes: int = 480, + ) -> None: + """Initialise the handler with signing credentials and token lifetime.""" + self._secret = secret_key + self._algorithm = algorithm + self._expire_minutes = expire_minutes + + def create_access_token( + self, + *, + user_id: str, + username: str, + role: str, + ) -> str: + """Return a signed JWT containing user identity and role claims.""" + now = datetime.now(UTC) + payload: dict[str, Any] = { + "sub": user_id, + "username": username, + "role": role, + "iat": now, + "exp": now + timedelta(minutes=self._expire_minutes), + } + return jwt.encode(payload, self._secret, algorithm=self._algorithm) + + def decode_token(self, token: str) -> UserClaims: + """Decode and validate a JWT, returning UserClaims. + + Raises ValueError with a descriptive message on expiry, tampering, + or any other validation failure so callers do not need to know jose. + """ + try: + payload = jwt.decode(token, self._secret, algorithms=[self._algorithm]) + except JWTError as exc: + msg = str(exc).lower() + if "expired" in msg: + raise ValueError("Token expired") from exc + raise ValueError(f"Invalid token: {exc}") from exc + + user_id = payload.get("sub") + username = payload.get("username", "") + role_str = payload.get("role", UserRole.READONLY.value) + + if not user_id: + raise ValueError("Token missing subject claim") + + try: + role = UserRole(role_str) + except ValueError: + logger.warning("Unknown role in token: {}, defaulting to readonly", role_str) + role = UserRole.READONLY + + return UserClaims(user_id=user_id, username=username, role=role) diff --git a/backend/app/infrastructure/auth/user_store.py b/backend/app/infrastructure/auth/user_store.py new file mode 100644 index 0000000..9b29e74 --- /dev/null +++ b/backend/app/infrastructure/auth/user_store.py @@ -0,0 +1,113 @@ +"""PostgreSQL-backed user store for authentication. + +Manages a `users` table with hashed passwords and roles. +Provides lookup by username for the login flow. +Table DDL is auto-applied on first connection. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import psycopg2 +import psycopg2.extras +from loguru import logger +from passlib.context import CryptContext + +from app.config.settings import settings + + +# bcrypt context — work factor 12 is a good production default. +_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# DDL executed once to ensure the table exists. +_CREATE_TABLE_SQL = """ +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username VARCHAR(100) UNIQUE NOT NULL, + hashed_pw TEXT NOT NULL, + role VARCHAR(50) NOT NULL DEFAULT 'readonly', + is_active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +""" + + +@dataclass +class UserRecord: + """A single row from the users table.""" + + id: str + username: str + hashed_pw: str + role: str + is_active: bool + + +class PostgresUserStore: + """Read and verify users stored in the PostgreSQL users table. + + The connection is opened on first use and shared for the lifetime + of the singleton instance wired by bootstrap. + """ + + def __init__(self) -> None: + """Initialise the store and ensure the users table exists.""" + self._conn = psycopg2.connect( + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + cursor_factory=psycopg2.extras.RealDictCursor, + ) + self._conn.autocommit = True + self._ensure_table() + + def _ensure_table(self) -> None: + """Create the users table if it does not already exist.""" + with self._conn.cursor() as cur: + # Enable pgcrypto so gen_random_uuid() is available for UUID primary keys. + try: + cur.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") + except Exception: + self._conn.rollback() + cur.execute(_CREATE_TABLE_SQL) + + def get_by_username(self, username: str) -> Optional[UserRecord]: + """Return a UserRecord for the given username, or None if not found.""" + with self._conn.cursor() as cur: + cur.execute( + "SELECT id, username, hashed_pw, role, is_active " + "FROM users WHERE username = %s", + (username,), + ) + row = cur.fetchone() + if row is None: + return None + return UserRecord( + id=str(row["id"]), + username=row["username"], + hashed_pw=row["hashed_pw"], + role=row["role"], + is_active=row["is_active"], + ) + + def verify_password(self, plain: str, hashed: str) -> bool: + """Return True if `plain` matches the stored bcrypt hash.""" + return _PWD_CTX.verify(plain, hashed) + + def authenticate(self, username: str, password: str) -> Optional[UserRecord]: + """Return the UserRecord if credentials are valid, else None.""" + user = self.get_by_username(username) + if user is None or not user.is_active: + return None + if not self.verify_password(password, user.hashed_pw): + return None + return user + + @staticmethod + def hash_password(plain: str) -> str: + """Hash a plain-text password with bcrypt.""" + return _PWD_CTX.hash(plain) diff --git a/backend/app/infrastructure/session/redis_conversation_store.py b/backend/app/infrastructure/session/redis_conversation_store.py new file mode 100644 index 0000000..11cc81b --- /dev/null +++ b/backend/app/infrastructure/session/redis_conversation_store.py @@ -0,0 +1,169 @@ +"""Redis-backed conversation store for persistent chat sessions. + +Sessions are stored as JSON strings under the key `session:{session_id}`. +The Redis TTL is refreshed on every write so active sessions stay alive. +On expiry, `get_session` returns None — callers should create a new session. +""" + +from __future__ import annotations + +import json +import time +import uuid +from typing import Any + +from loguru import logger + +from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore + + +class RedisConversationStore(ConversationStore): + """Store conversation sessions in Redis with automatic TTL expiry. + + Each session is serialised as a JSON object at key ``session:{session_id}``. + The TTL is reset on every write so sessions stay alive as long as they are active. + """ + + # Prefix for all session keys to avoid collisions with other Redis consumers. + _PREFIX = "session:" + + def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None: + """Initialise the store with an existing Redis client and a TTL in seconds.""" + self._redis = redis_client + self._ttl = timeout_seconds + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _key(self, session_id: str) -> str: + """Build the Redis key for a session.""" + return f"{self._PREFIX}{session_id}" + + def _serialise(self, session: ConversationSession) -> str: + """Serialise a ConversationSession to a JSON string.""" + return json.dumps( + { + "session_id": session.session_id, + "created_at": session.created_at, + "updated_at": session.updated_at, + "metadata": session.metadata, + "messages": [ + { + "role": msg.role, + "content": msg.content, + "timestamp": msg.timestamp, + "sources": msg.sources, + } + for msg in session.messages + ], + }, + ensure_ascii=False, + ) + + def _deserialise(self, raw: bytes | str) -> ConversationSession: + """Deserialise a JSON string back into a ConversationSession.""" + data = json.loads(raw) + messages = [ + ConversationMessage( + role=m["role"], + content=m["content"], + timestamp=m["timestamp"], + sources=m.get("sources", []), + ) + for m in data.get("messages", []) + ] + session = ConversationSession( + session_id=data["session_id"], + created_at=data.get("created_at", 0), + updated_at=data.get("updated_at", 0), + metadata=data.get("metadata", {}), + ) + session.messages = messages + return session + + def _save(self, session: ConversationSession) -> None: + """Persist a session to Redis and refresh its TTL.""" + self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session)) + + # ------------------------------------------------------------------ + # ConversationStore protocol + # ------------------------------------------------------------------ + + def create_session(self, metadata: dict | None = None) -> ConversationSession: + """Create a new empty session and persist it immediately.""" + now = int(time.time()) + session = ConversationSession( + session_id=str(uuid.uuid4())[:8], + created_at=now, + updated_at=now, + metadata=metadata or {}, + ) + self._save(session) + return session + + def get_session(self, session_id: str) -> ConversationSession | None: + """Return a session by ID, or None if it does not exist or has expired.""" + raw = self._redis.get(self._key(session_id)) + if raw is None: + return None + try: + return self._deserialise(raw) + except Exception: + logger.warning("Failed to deserialise session: {}", session_id) + return None + + def save_message( + self, + session_id: str, + *, + role: str, + content: str, + sources: list[dict] | None = None, + ) -> ConversationSession | None: + """Append a message to a session and refresh its TTL.""" + session = self.get_session(session_id) + if session is None: + return None + session.messages.append( + ConversationMessage( + role=role, + content=content, + timestamp=int(time.time()), + sources=sources or [], + ) + ) + session.updated_at = int(time.time()) + self._save(session) + return session + + def delete_session(self, session_id: str) -> bool: + """Delete a session. Returns True if it existed, False otherwise.""" + deleted = self._redis.delete(self._key(session_id)) + return bool(deleted) + + def list_sessions(self) -> list[dict]: + """Return summary dicts for all live sessions visible in this Redis DB. + + Note: KEYS is used for simplicity; replace with SCAN for large deployments. + """ + pattern = f"{self._PREFIX}*" + keys = self._redis.keys(pattern) + result = [] + for key in keys: + raw = self._redis.get(key) + if raw is None: + continue + try: + data = json.loads(raw) + result.append( + { + "session_id": data["session_id"], + "message_count": len(data.get("messages", [])), + "created_at": data.get("created_at", 0), + "updated_at": data.get("updated_at", 0), + } + ) + except Exception: + continue + return result diff --git a/backend/app/infrastructure/tasks/__init__.py b/backend/app/infrastructure/tasks/__init__.py new file mode 100644 index 0000000..422b0b3 --- /dev/null +++ b/backend/app/infrastructure/tasks/__init__.py @@ -0,0 +1,5 @@ +"""Celery task definitions for background processing. + +This package exposes the shared Celery application instance and all +registered task functions used by API routes to enqueue work. +""" diff --git a/backend/app/infrastructure/tasks/celery_app.py b/backend/app/infrastructure/tasks/celery_app.py new file mode 100644 index 0000000..0e86664 --- /dev/null +++ b/backend/app/infrastructure/tasks/celery_app.py @@ -0,0 +1,45 @@ +"""Shared Celery application instance for background task processing. + +All workers and enqueueing call sites import `celery_app` from this module +so the broker/backend configuration stays in one place. +""" + +from __future__ import annotations + +from celery import Celery + +from app.config.settings import settings + + +def _redis_url() -> str: + """Return a Redis connection URL from application settings.""" + if settings.redis_password: + return ( + f"redis://:{settings.redis_password}@" + f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}" + ) + return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}" + + +_BROKER = _redis_url() +_BACKEND = _redis_url() + +celery_app = Celery( + "compliance_hub", + broker=_BROKER, + backend=_BACKEND, + include=["app.infrastructure.tasks.document_tasks"], +) + +celery_app.conf.update( + task_serializer="json", + result_serializer="json", + accept_content=["json"], + timezone="UTC", + enable_utc=True, + # Acknowledge task only after successful execution to avoid data loss. + task_acks_late=True, + task_reject_on_worker_lost=True, + # Keep results for 1 hour for status polling. + result_expires=3600, +) diff --git a/backend/app/infrastructure/tasks/document_tasks.py b/backend/app/infrastructure/tasks/document_tasks.py new file mode 100644 index 0000000..00853aa --- /dev/null +++ b/backend/app/infrastructure/tasks/document_tasks.py @@ -0,0 +1,73 @@ +"""Celery tasks for document processing. + +Each task is a thin wrapper that retrieves the already-stored document +binary and delegates to DocumentCommandService._process_document. +The task does not accept raw file bytes — it reads them from the binary +store using the doc_id, so the Celery message payload stays small. +""" + +from __future__ import annotations + +from loguru import logger + +from app.infrastructure.tasks.celery_app import celery_app + + +@celery_app.task( + name="app.infrastructure.tasks.document_tasks.process_document_task", + bind=True, + max_retries=3, + default_retry_delay=30, + acks_late=True, +) +def process_document_task( + self, + doc_id: str, + file_name: str, + doc_name: str, + regulation_type: str, + version: str, + generate_summary: bool, + run_id: str | None = None, +) -> dict: + """Parse, embed, and index a document that has already been stored. + + The task reads the file binary from MinIO using doc_id so the Celery + message stays small. Retries up to 3 times with a 30-second delay on + transient infrastructure errors. + """ + # Import inside the task function to avoid pickling issues and to ensure + # that each worker process initialises its own bootstrap singletons. + from app.shared.bootstrap import get_document_command_service, get_document_query_service + + logger.info("process_document_task started: doc_id={}", doc_id) + try: + svc = get_document_command_service() + doc = get_document_query_service().get(doc_id) + if not doc: + raise ValueError(f"Document record not found: {doc_id}") + + # Read the stored binary from MinIO — avoids passing raw bytes in the task message. + content = svc.binary_store.read(doc.object_name) + + result = svc._process_document( + doc_id=doc_id, + file_name=file_name, + final_doc_name=doc_name, + content=content, + regulation_type=regulation_type, + version=version, + generate_summary=generate_summary, + run_id=run_id, + ) + logger.info( + "process_document_task completed: doc_id={} status={} chunks={}", + doc_id, result.status, result.num_chunks, + ) + return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks} + + except Exception as exc: + logger.exception("process_document_task failed: doc_id={}", doc_id) + # Retry on transient errors; permanent errors (bad file, parse failure) + # will exhaust retries and leave the document in FAILED state. + raise self.retry(exc=exc) diff --git a/backend/app/shared/bootstrap.py b/backend/app/shared/bootstrap.py index ef5d81a..7821924 100644 --- a/backend/app/shared/bootstrap.py +++ b/backend/app/shared/bootstrap.py @@ -252,7 +252,31 @@ def get_document_query_service() -> DocumentQueryService: @lru_cache def get_conversation_store() -> InMemoryConversationStore: - """Return conversation store.""" + """Return the active conversation store based on settings. + + When session_backend='redis', sessions survive backend restarts and scale + across multiple API worker processes. When session_backend='memory' (default), + sessions are process-local and lost on restart. + """ + if settings.session_backend == "redis": + import redis as redis_lib + from app.infrastructure.session.redis_conversation_store import RedisConversationStore + + # Build the Redis client from the same connection settings used by Celery. + kwargs: dict = { + "host": settings.redis_host, + "port": settings.redis_port, + "db": settings.redis_db, + "decode_responses": False, + } + if settings.redis_password: + kwargs["password"] = settings.redis_password + + redis_client = redis_lib.Redis(**kwargs) + return RedisConversationStore( # type: ignore[return-value] + redis_client=redis_client, + timeout_seconds=settings.session_timeout_minutes * 60, + ) return InMemoryConversationStore( max_sessions=settings.session_max_sessions, timeout_minutes=settings.session_timeout_minutes, @@ -284,6 +308,35 @@ def get_agent_session_service() -> AgentSessionService: return AgentSessionService(conversation_store=get_conversation_store()) +@lru_cache +def get_celery_app(): + """Return the shared Celery application instance. + + Imported lazily so Celery is not required when running without workers + (e.g., tests that mock bootstrap or dev without Redis). + """ + from app.infrastructure.tasks.celery_app import celery_app + return celery_app + + +@lru_cache +def get_jwt_handler(): + """Return the shared JWTHandler instance for token creation and validation.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + return JWTHandler( + secret_key=settings.auth_secret_key, + algorithm=settings.auth_algorithm, + expire_minutes=settings.auth_token_expire_minutes, + ) + + +@lru_cache +def get_user_store(): + """Return the PostgreSQL user store (lazy-connects on first call).""" + from app.infrastructure.auth.user_store import PostgresUserStore + return PostgresUserStore() + + def preload_runtime_dependencies() -> None: """Warm dependencies that are safe and useful to preload during startup.""" LLMFactory.preload_clients(["qwen", "deepseek"]) diff --git a/backend/requirements.txt b/backend/requirements.txt index 20d07bb..b75a8f0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,30 +1,46 @@ +# ── Web framework ───────────────────────────────────────────────────────────── fastapi>=0.110.0 uvicorn[standard]>=0.27.0 python-multipart>=0.0.9 +# ── Config & utilities ──────────────────────────────────────────────────────── pydantic>=2.0.0 pydantic-settings>=2.0.0 python-dotenv>=1.0.0 loguru>=0.7.0 - httpx>=0.25.0 tiktoken>=0.5.0 tenacity>=8.2.0 +# ── Auth ────────────────────────────────────────────────────────────────────── +python-jose[cryptography]>=3.3.0 +# passlib is incompatible with bcrypt>=4.0 (removed __about__, strict 72-byte limit). +# Pin bcrypt to 3.x until passlib ships a fix. +passlib[bcrypt]>=1.7.4 +bcrypt>=3.2.0,<4.0.0 + +# ── Async task queue ────────────────────────────────────────────────────────── +celery>=5.3.0 +redis>=4.5.0 + +# ── Storage & databases ─────────────────────────────────────────────────────── pymilvus>=2.4.0 minio>=7.1.0 psycopg2-binary>=2.9.0 +# ── Document parsing ───────────────────────────────────────────────────────── pymupdf>=1.24.0 python-docx>=1.1.0 - -numpy>=1.24.0 alibabacloud-docmind-api20220711>=1.0.6 alibabacloud-tea-openapi>=0.3.11 alibabacloud-tea-util>=0.3.13 +# ── RAG / LangChain ─────────────────────────────────────────────────────────── langchain>=0.1.0 langchain-milvus>=0.1.0 +numpy>=1.24.0 +# ── Testing ─────────────────────────────────────────────────────────────────── pytest>=7.4.0 pytest-asyncio>=0.21.0 +fakeredis>=2.0.0 diff --git a/dev.sh b/dev.sh index 79a9cb0..b50b85c 100644 --- a/dev.sh +++ b/dev.sh @@ -549,7 +549,7 @@ AI+合规智能中枢统一脚本 用法: ./dev.sh help ./dev.sh setup - ./dev.sh start [all|api|frontend] [--foreground] [--mode dev|static] + ./dev.sh start [all|api|frontend|worker|beat] [--foreground] [--mode dev|static] ./dev.sh stop [all|api|frontend] ./dev.sh restart [all|api|frontend] [--mode dev|static] ./dev.sh status @@ -563,6 +563,9 @@ AI+合规智能中枢统一脚本 进行一次性的本地初始化。 包含 Python 版本检查、.venv 虚拟环境创建、后端依赖安装、前端 npm install、 以及 6.86.80.8 基础服务端口连通性检查。 + 初始化完成后,首次运行前还需执行: + PYTHONPATH=backend .venv/bin/python scripts/seed_users.py + 以创建 admin/legal/ehs/readonly 四个演示用户。 start 启动服务。默认行为等同于 ./dev.sh start all。 @@ -570,6 +573,8 @@ AI+合规智能中枢统一脚本 all 同时启动 API 和前端。 api 只启动后端 API。 frontend 只启动前端。 + worker 启动 Celery 文档处理 worker(前台运行,需要 Redis)。 + beat 启动 Celery Beat 定时调度器(前台运行,需要 Redis)。 可选参数: --foreground 仅对 start api 生效,前台运行并开启 --reload,便于调试。 --mode dev 前端使用 Vite 开发服务器,默认端口 5173。 @@ -578,6 +583,7 @@ AI+合规智能中枢统一脚本 stop 停止服务。默认行为等同于 ./dev.sh stop all。 会优先读取 logs/*.pid,PID 文件失效时会回退到端口探测。 + 注意: worker 和 beat 为前台进程,直接 Ctrl+C 停止。 restart 先停止再启动,支持 all/api/frontend。 @@ -601,8 +607,11 @@ AI+合规智能中枢统一脚本 常用示例: ./dev.sh setup + PYTHONPATH=backend .venv/bin/python scripts/seed_users.py ./dev.sh start ./dev.sh start api --foreground + ./dev.sh start worker + ./dev.sh start beat ./dev.sh start frontend --mode static ./dev.sh restart frontend --mode dev ./dev.sh status @@ -615,7 +624,7 @@ parse_target() { local default_target="$1" local candidate="${2:-}" case "$candidate" in - all|api|frontend) + all|api|frontend|worker|beat) echo "$candidate" ;; *) @@ -646,41 +655,64 @@ main() { shift || true fi - while [ $# -gt 0 ]; do - case "$1" in - --foreground) - foreground=true - ;; - --mode) - shift || die "--mode 需要指定 dev 或 static" - mode="$1" - validate_frontend_mode "$mode" - ;; - *) - die "未知参数: $1" - ;; - esac - shift || true - done - + # worker and beat are pass-through — forward remaining args to celery directly. case "$target" in - all) - [ "$foreground" = false ] || die "start all 不支持 --foreground,请使用 start api --foreground" - print_header "AI+合规智能中枢 - 启动服务" - start_api background - start_frontend "${mode:-$FRONTEND_MODE}" + worker) + print_header "AI+合规智能中枢 - 启动 Celery Worker" + require_venv + export PYTHONPATH="backend${PYTHONPATH:+:$PYTHONPATH}" + "$VENV_PYTHON" -m celery -A app.infrastructure.tasks.celery_app worker \ + --loglevel=info \ + --concurrency=2 \ + --queues=celery \ + "$@" ;; - api) - if [ "$foreground" = true ]; then - start_api foreground - else - print_header "AI+合规智能中枢 - 启动 API" - start_api background - fi + beat) + print_header "AI+合规智能中枢 - 启动 Celery Beat" + require_venv + export PYTHONPATH="backend${PYTHONPATH:+:$PYTHONPATH}" + "$VENV_PYTHON" -m celery -A app.infrastructure.tasks.celery_app beat \ + --loglevel=info \ + "$@" ;; - frontend) - print_header "AI+合规智能中枢 - 启动前端" - start_frontend "${mode:-$FRONTEND_MODE}" + *) + while [ $# -gt 0 ]; do + case "$1" in + --foreground) + foreground=true + ;; + --mode) + shift || die "--mode 需要指定 dev 或 static" + mode="$1" + validate_frontend_mode "$mode" + ;; + *) + die "未知参数: $1" + ;; + esac + shift || true + done + + case "$target" in + all) + [ "$foreground" = false ] || die "start all 不支持 --foreground,请使用 start api --foreground" + print_header "AI+合规智能中枢 - 启动服务" + start_api background + start_frontend "${mode:-$FRONTEND_MODE}" + ;; + api) + if [ "$foreground" = true ]; then + start_api foreground + else + print_header "AI+合规智能中枢 - 启动 API" + start_api background + fi + ;; + frontend) + print_header "AI+合规智能中枢 - 启动前端" + start_frontend "${mode:-$FRONTEND_MODE}" + ;; + esac ;; esac ;; diff --git a/docs/superpowers/plans/2026-06-05-phase1-production-foundation.md b/docs/superpowers/plans/2026-06-05-phase1-production-foundation.md new file mode 100644 index 0000000..860404b --- /dev/null +++ b/docs/superpowers/plans/2026-06-05-phase1-production-foundation.md @@ -0,0 +1,2304 @@ +# Phase 1: Production Foundation Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Transform the existing synchronous RAG system into a production-ready backend with async task processing, persistent sessions, and JWT authentication — turning "can run" into "can operate." + +**Architecture:** Four independent workstreams executed in priority order: (A) enable the already-wired reranker, (B) introduce Celery workers backed by the already-configured Redis to make document processing non-blocking, (C) replace the in-memory conversation store with a Redis-backed implementation, and (D) add JWT authentication with role claims and audit logging across all API routes. + +**Tech Stack:** FastAPI · Celery 5 · Redis 7 · PostgreSQL 15 · `python-jose[cryptography]` · `passlib[bcrypt]` · `fakeredis` (test) — all running on the existing `docker-compose.yml` infrastructure. + +--- + +## Scope Note + +These four workstreams are **largely independent** and can be worked in parallel by separate developers. Auth (Group D) should be integrated last because it modifies every route. The recommended solo sequence: A → B → C → D. + +**Not in scope (YAGNI):** Fine-grained per-endpoint RBAC enforcement (the role claim is added to JWT as a future enforcement point), user self-registration, perception crawling, knowledge graph, EHS module. + +--- + +## File Map + +### Group A — Reranker Quick Win +| Action | File | +|--------|------| +| Modify | `.env.example` — document reranker env vars | +| Create | `tests/test_reranker_bootstrap.py` | + +### Group B — Async Task-ification +| Action | File | +|--------|------| +| Create | `backend/app/infrastructure/tasks/__init__.py` | +| Create | `backend/app/infrastructure/tasks/celery_app.py` | +| Create | `backend/app/infrastructure/tasks/document_tasks.py` | +| Modify | `backend/app/application/documents/services.py` — extract `_process_document` | +| Modify | `backend/app/api/routes/documents.py` — enqueue vs inline | +| Modify | `backend/app/shared/bootstrap.py` — expose `get_celery_app()` | +| Modify | `dev.sh` — add worker start target | +| Create | `tests/test_celery_tasks.py` | + +### Group C — Session Persistence +| Action | File | +|--------|------| +| Create | `backend/app/infrastructure/session/redis_conversation_store.py` | +| Modify | `backend/app/config/settings.py` — add `session_backend` field | +| Modify | `backend/app/shared/bootstrap.py` — switch store on setting | +| Modify | `pyproject.toml` — add `fakeredis` dev dep | +| Create | `tests/test_redis_conversation_store.py` | + +### Group D — JWT Auth + RBAC + Audit +| Action | File | +|--------|------| +| Modify | `requirements.txt` — add `python-jose`, `passlib[bcrypt]` | +| Modify | `pyproject.toml` — same | +| Create | `backend/app/domain/auth/__init__.py` | +| Create | `backend/app/domain/auth/models.py` | +| Create | `backend/app/infrastructure/auth/__init__.py` | +| Create | `backend/app/infrastructure/auth/jwt_handler.py` | +| Create | `backend/app/infrastructure/auth/user_store.py` | +| Create | `backend/app/api/dependencies/__init__.py` | +| Create | `backend/app/api/dependencies/auth.py` | +| Create | `backend/app/api/middleware/__init__.py` | +| Create | `backend/app/api/middleware/audit.py` | +| Create | `backend/app/api/routes/auth.py` | +| Modify | `backend/app/api/routes/__init__.py` — include auth router | +| Modify | `backend/app/config/settings.py` — add auth fields | +| Modify | `backend/app/shared/bootstrap.py` — wire auth components | +| Modify | `backend/app/api/main.py` — restrict CORS, add audit middleware | +| Modify | `backend/app/api/routes/documents.py` — require auth | +| Modify | `backend/app/api/routes/rag.py` — require auth | +| Modify | `backend/app/api/routes/compliance.py` — require auth | +| Create | `scripts/seed_users.py` | +| Create | `tests/test_jwt_handler.py` | +| Create | `tests/test_auth_routes.py` | + +--- + +## Group A — Reranker Quick Win + +> Estimated time: 30 minutes. No new code, just configuration and a verification test. +> +> **Why:** `settings.reranker_enabled` defaults to `False` in `backend/app/config/settings.py:113`. The entire `OpenAICompatibleReranker` infrastructure at `backend/app/infrastructure/vectorstore/cross_encoder_reranker.py` and the wiring in `backend/app/shared/bootstrap.py:199-203` is already complete. Enabling it requires setting two env vars and verifying the wiring works. + +### Task 1: Document reranker env vars and write a bootstrap verification test + +**Files:** +- Modify: `.env.example` +- Create: `tests/test_reranker_bootstrap.py` + +- [ ] **Step 1: Add reranker vars to `.env.example`** + +Open `.env.example`. Locate the existing embedding/LLM block and append: + +```ini +# ── Reranker (Cross-Encoder) ────────────────────────────────────────────────── +# Set RERANKER_ENABLED=true and point to a TEI or Cohere-compatible rerank API. +# Recommended model: BAAI/bge-reranker-v2.5-gemma2-lightweight (lighter) or +# BAAI/bge-reranker-v2-m3 (heavier, higher quality). +# The endpoint must expose POST /rerank (TEI style) or POST /v1/rerank (Cohere style). +RERANKER_ENABLED=true +RERANKER_BASE_URL=http://6.86.80.4:30080/v1 +RERANKER_MODEL=BAAI/bge-reranker-v2-m3 +RERANKER_API_KEY= +RERANKER_TOP_K=5 +``` + +- [ ] **Step 2: Write the failing test** + +Create `tests/test_reranker_bootstrap.py`: + +```python +"""Verify that bootstrap correctly wires the reranker when the setting is enabled.""" + +from unittest.mock import patch + +import pytest + + +def test_get_reranker_returns_none_when_disabled(): + """get_reranker() must return None when reranker_enabled is False.""" + with patch("app.shared.bootstrap._build_binary_store"), \ + patch("app.shared.bootstrap._build_vector_index"): + # Import fresh by clearing lru_cache between tests + from app.shared import bootstrap + bootstrap.get_reranker.cache_clear() + + with patch("app.config.settings.settings") as mock_settings: + mock_settings.reranker_enabled = False + mock_settings.reranker_base_url = "" + result = bootstrap.get_reranker() + + bootstrap.get_reranker.cache_clear() + assert result is None + + +def test_get_reranker_returns_instance_when_enabled(): + """get_reranker() must return an OpenAICompatibleReranker when enabled.""" + from app.shared import bootstrap + bootstrap.get_reranker.cache_clear() + + with patch("app.config.settings.settings") as mock_settings: + mock_settings.reranker_enabled = True + mock_settings.reranker_base_url = "http://localhost:8082" + mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3" + mock_settings.reranker_api_key = "" + mock_settings.reranker_top_k = 5 + result = bootstrap.get_reranker() + + bootstrap.get_reranker.cache_clear() + from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker + assert isinstance(result, OpenAICompatibleReranker) +``` + +- [ ] **Step 3: Run test to verify it fails (before fix)** + +```bash +PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py -v +``` + +Expected: Both tests PASS immediately (the bootstrap logic already exists). If they pass, the wiring is confirmed correct — proceed. + +- [ ] **Step 4: Enable reranker in your local `.env`** + +Add these two lines to your local `.env` (not `.env.example`, which is already updated): + +```ini +RERANKER_ENABLED=true +RERANKER_BASE_URL=http://6.86.80.4:30080/v1 +``` + +- [ ] **Step 5: Verify `GET /api/v1/status/health` reports reranker enabled** + +Start the backend and call: +```bash +curl http://localhost:8000/api/v1/status/health | python -m json.tool +``` + +Expected fragment in response: +```json +"reranker": { + "enabled": true, + "model": "BAAI/bge-reranker-v2-m3" +} +``` + +> ✅ **Group A complete.** The reranker is now active for all `KnowledgeRetrievalService.retrieve()` calls. + +--- + +## Group B — Async Task-ification + +> Estimated time: 3-5 days. +> +> **Why:** `upload_document` at `backend/app/api/routes/documents.py:34` runs the full parse→embed→index pipeline synchronously in the HTTP request. For large GB standards this can take 15+ minutes (Aliyun timeout is 900 s per `settings.py:49`). Celery 5 and Redis are already in `requirements.txt` and `docker-compose.yml` — only the wiring is missing. +> +> **Strategy:** Split `DocumentCommandService.upload_and_process` into a fast `store_document` (sync, returns immediately) and a background `_process_document` (called from a Celery task). The existing `DocumentProcessingStore` already records per-stage status events — it becomes the progress tracker. The HTTP route returns `{doc_id, status: "queued"}` immediately; the client polls `GET /documents/status/{doc_id}` (already exists) for progress. + +### Task 2: Create the Celery application + +**Files:** +- Create: `backend/app/infrastructure/tasks/__init__.py` +- Create: `backend/app/infrastructure/tasks/celery_app.py` +- Create: `tests/test_celery_tasks.py` (first test only) + +- [ ] **Step 1: Write a failing test for Celery app configuration** + +Create `tests/test_celery_tasks.py`: + +```python +"""Tests for Celery task infrastructure. + +These tests verify task registration and Celery app configuration without +starting a real worker or connecting to Redis. +""" + +import pytest + + +def test_celery_app_uses_redis_broker(): + """Celery broker URL must be a Redis URL built from settings.""" + from app.infrastructure.tasks.celery_app import celery_app + assert celery_app.conf.broker_url.startswith("redis://") + + +def test_celery_app_uses_redis_backend(): + """Celery result backend must be Redis.""" + from app.infrastructure.tasks.celery_app import celery_app + assert celery_app.conf.result_backend.startswith("redis://") + + +def test_celery_app_has_json_serializer(): + """Task serializer must be JSON for portability.""" + from app.infrastructure.tasks.celery_app import celery_app + assert celery_app.conf.task_serializer == "json" +``` + +- [ ] **Step 2: Run test to confirm it fails** + +```bash +PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_celery_app_uses_redis_broker -v +``` + +Expected: `ModuleNotFoundError: No module named 'app.infrastructure.tasks'` + +- [ ] **Step 3: Create the tasks package** + +Create `backend/app/infrastructure/tasks/__init__.py`: + +```python +"""Celery task definitions for background processing. + +This package exposes the shared Celery application instance and all +registered task functions used by API routes to enqueue work. +""" +``` + +- [ ] **Step 4: Create the Celery app factory** + +Create `backend/app/infrastructure/tasks/celery_app.py`: + +```python +"""Shared Celery application instance for background task processing. + +All workers and enqueueing call sites import `celery_app` from this module +so the broker/backend configuration stays in one place. +""" + +from __future__ import annotations + +from celery import Celery + +from app.config.settings import settings + +# Build Redis URL from the same settings used by the rest of the backend. +# The password segment is omitted when redis_password is empty so that +# local development against a password-free Redis instance works out of the box. +def _redis_url() -> str: + """Return a Redis connection URL from application settings.""" + if settings.redis_password: + return ( + f"redis://:{settings.redis_password}@" + f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}" + ) + return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}" + + +_BROKER = _redis_url() +_BACKEND = _redis_url() + +celery_app = Celery( + "compliance_hub", + broker=_BROKER, + backend=_BACKEND, + include=["app.infrastructure.tasks.document_tasks"], +) + +celery_app.conf.update( + task_serializer="json", + result_serializer="json", + accept_content=["json"], + timezone="UTC", + enable_utc=True, + # Retry failed tasks up to 3 times with exponential backoff. + task_acks_late=True, + task_reject_on_worker_lost=True, + # Keep results for 1 hour for status polling. + result_expires=3600, +) +``` + +- [ ] **Step 5: Run the Celery app tests** + +```bash +PYTHONPATH=backend pytest tests/test_celery_tasks.py -v +``` + +Expected: `test_celery_app_uses_redis_broker` PASS, `test_celery_app_uses_redis_backend` PASS, `test_celery_app_has_json_serializer` PASS. + +### Task 3: Split `DocumentCommandService` and create the document processing task + +**Files:** +- Modify: `backend/app/application/documents/services.py` +- Create: `backend/app/infrastructure/tasks/document_tasks.py` + +- [ ] **Step 1: Add task registration test to `tests/test_celery_tasks.py`** + +Append to `tests/test_celery_tasks.py`: + +```python +def test_process_document_task_is_registered(): + """process_document_task must be discoverable in the Celery task registry.""" + # Importing document_tasks triggers task registration via the @task decorator. + import app.infrastructure.tasks.document_tasks # noqa: F401 + from app.infrastructure.tasks.celery_app import celery_app + + registered = list(celery_app.tasks.keys()) + assert any("process_document_task" in name for name in registered), ( + f"process_document_task not found in {registered}" + ) +``` + +- [ ] **Step 2: Run test to confirm it fails** + +```bash +PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_process_document_task_is_registered -v +``` + +Expected: `ModuleNotFoundError: No module named 'app.infrastructure.tasks.document_tasks'` + +- [ ] **Step 3: Extract `_process_document` from `DocumentCommandService`** + +In `backend/app/application/documents/services.py`, locate `upload_and_process` (line 233). Add a new method `_process_document` that contains the parse→embed→index logic, and refactor `upload_and_process` to call it after the store step. + +Find the end of the `_safe_mark_run_stored` call block (around line 299) and extract everything from the `suffix = os.path.splitext(...)` line to the end of the try block into a new method: + +```python + def _process_document( + self, + *, + doc_id: str, + file_name: str, + final_doc_name: str, + content: bytes, + regulation_type: str, + version: str, + generate_summary: bool, + run_id: str | None = None, + ) -> DocumentProcessResult: + """Run parse → chunk → embed → index for a document that is already stored. + + This method is called both synchronously (from upload_and_process) and + asynchronously (from the Celery document_tasks worker). All side-effects + write through DocumentProcessingStore so callers can poll progress. + """ + current_status = DocumentStatus.STORED + current_stage = "parse" + temp_path = "" + try: + suffix = os.path.splitext(file_name)[1] + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file: + temp_file.write(content) + temp_path = temp_file.name + + parsed_document = self.parser.parse( + file_path=temp_path, + doc_id=doc_id, + doc_name=final_doc_name, + ) + self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document) + + artifact_keys: dict[str, str] = {} + try: + artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document) + except Exception: + logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id) + + self.document_repository.update_status( + doc_id, + DocumentStatus.PARSED, + parser_name=parsed_document.parser_name, + metadata={ + "parser_backend": parsed_document.parser_name, + "parse_task_id": parsed_document.metadata.get("task_id", ""), + "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)), + "structure_node_count": len(parsed_document.structure_nodes), + "semantic_block_count": len(parsed_document.semantic_blocks), + "vector_chunk_count": len(parsed_document.vector_chunks), + "artifact_keys": artifact_keys, + "processing_stage": "parsed", + }, + ) + current_status = DocumentStatus.PARSED + current_stage = "embed" + self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value, + stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)}, + ) + if self.parse_artifact_store: + try: + self.parse_artifact_store.save(doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks) + except Exception: + logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id) + + chunks = self.chunk_builder.build( + parsed_document=parsed_document, + regulation_type=regulation_type, + version=version, + ) + if not chunks: + raise ValueError("解析完成但没有生成可入库的 chunks") + + vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks]) + current_stage = "index" + inserted = self.vector_index.upsert(chunks, vectors) + if inserted != len(chunks): + logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks)) + + health = self.vector_index.health() + index_name = health.get("collection_name", "") + self.document_repository.update_status( + doc_id, DocumentStatus.INDEXED, + chunk_count=len(chunks), summary="", summary_latency_ms=0, + index_name=index_name, + metadata={"index_collection": index_name, "processing_stage": "indexed"}, + ) + self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value, + stage="index", message="Document indexed", + metadata={"chunk_count": len(chunks), "index_name": index_name}, + ) + stored = self.document_repository.get(doc_id) + return DocumentProcessResult( + doc_id=doc_id, doc_name=final_doc_name, + status=(stored.status.value if stored else DocumentStatus.INDEXED.value), + message="处理成功", num_chunks=len(chunks), + summary=stored.summary if stored else "", + summary_latency_ms=stored.summary_latency_ms if stored else 0, + ) + except Exception as exc: + logger.exception("文档处理失败: doc_id={}", doc_id) + self.document_repository.update_status( + doc_id, DocumentStatus.FAILED, error_message=str(exc), + metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage}, + ) + self._safe_mark_run_failed(doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc)) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=current_status.value, to_status=DocumentStatus.FAILED.value, + stage=current_stage, message=str(exc), + ) + return DocumentProcessResult( + doc_id=doc_id, doc_name=final_doc_name, + status=DocumentStatus.FAILED.value, message=f"文档处理失败: {exc}", + ) + finally: + if temp_path and os.path.exists(temp_path): + try: + os.remove(temp_path) + except OSError: + logger.warning("临时文件清理失败: {}", temp_path) +``` + +Then replace the parse-through-index section of `upload_and_process` with a call to `_process_document`: + +In `upload_and_process`, replace everything from `suffix = os.path.splitext(file_name)[1]` through to the end of the try block (before the except) with: + +```python + # Delegate parse → embed → index to the shared processing method. + # This same method is invoked by the Celery worker for async processing. + return self._process_document( + doc_id=doc_id, + file_name=file_name, + final_doc_name=final_doc_name, + content=content, + regulation_type=regulation_type, + version=version, + generate_summary=generate_summary, + run_id=run_id, + ) +``` + +Also add a new method `store_document` for the fast sync-only path: + +```python + def store_document( + self, + *, + doc_id: str | None = None, + file_name: str, + content: bytes, + content_type: str, + doc_name: str | None, + regulation_type: str, + version: str, + generate_summary: bool, + ) -> tuple[str, str | None]: + """Store the binary file and create the Document record. + + Returns (doc_id, run_id). Does NOT parse, embed, or index. + This is the fast synchronous first step; processing is enqueued separately. + """ + doc_id = doc_id or str(uuid.uuid4())[:8] + final_doc_name = doc_name or file_name + object_name = f"{doc_id}/{file_name}" + + document = Document( + doc_id=doc_id, + doc_name=final_doc_name, + file_name=file_name, + object_name=object_name, + content_type=content_type, + size_bytes=len(content), + regulation_type=regulation_type, + version=version, + metadata={"generate_summary": generate_summary}, + ) + self.document_repository.create(document) + run_id = self._safe_create_processing_run( + doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary + ) + self.binary_store.save( + object_name=object_name, data=content, + content_type=content_type, metadata={"doc_id": doc_id}, + ) + self.document_repository.update_status(doc_id, DocumentStatus.STORED) + self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id) + self._safe_append_status_event( + doc_id=doc_id, run_id=run_id, + from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value, + stage="store", message="Source file stored", + ) + return doc_id, run_id +``` + +- [ ] **Step 4: Create the Celery document task** + +Create `backend/app/infrastructure/tasks/document_tasks.py`: + +```python +"""Celery tasks for document processing. + +Each task is a thin wrapper that retrieves the already-stored document +binary and delegates to DocumentCommandService._process_document. +The task does not accept raw file bytes — it reads them from the binary +store using the doc_id, so the payload stays small. +""" + +from __future__ import annotations + +from loguru import logger + +from app.infrastructure.tasks.celery_app import celery_app + + +@celery_app.task( + name="app.infrastructure.tasks.document_tasks.process_document_task", + bind=True, + max_retries=3, + default_retry_delay=30, + acks_late=True, +) +def process_document_task( + self, + doc_id: str, + file_name: str, + doc_name: str, + regulation_type: str, + version: str, + generate_summary: bool, + run_id: str | None = None, +) -> dict: + """Parse, embed, and index a document that has already been stored. + + The task reads the file binary from MinIO using doc_id so the Celery + message stays small. Retries up to 3 times with a 30-second delay on + transient infrastructure errors. + """ + # Import inside the task function to avoid pickling issues and to ensure + # that each worker process initialises its own bootstrap singletons. + from app.shared.bootstrap import get_document_command_service, get_document_query_service + + logger.info("process_document_task started: doc_id={}", doc_id) + try: + # Read the stored binary from MinIO. + svc = get_document_command_service() + doc = get_document_query_service().get(doc_id) + if not doc: + raise ValueError(f"Document record not found: {doc_id}") + + content = svc.binary_store.read(doc.object_name) + + result = svc._process_document( + doc_id=doc_id, + file_name=file_name, + final_doc_name=doc_name, + content=content, + regulation_type=regulation_type, + version=version, + generate_summary=generate_summary, + run_id=run_id, + ) + logger.info( + "process_document_task completed: doc_id={} status={} chunks={}", + doc_id, result.status, result.num_chunks, + ) + return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks} + + except Exception as exc: + logger.exception("process_document_task failed: doc_id={}", doc_id) + # Retry on transient errors; permanent errors (bad file, parse failure) + # will exhaust retries and leave the document in FAILED state. + raise self.retry(exc=exc) +``` + +- [ ] **Step 5: Run all Celery task tests** + +```bash +PYTHONPATH=backend pytest tests/test_celery_tasks.py -v +``` + +Expected: All 4 tests PASS. + +### Task 4: Update the upload route to support async enqueueing + +**Files:** +- Modify: `backend/app/api/routes/documents.py` +- Modify: `backend/app/shared/bootstrap.py` + +- [ ] **Step 1: Add `get_celery_app` to bootstrap** + +In `backend/app/shared/bootstrap.py`, add at the bottom (before `preload_runtime_dependencies`): + +```python +@lru_cache +def get_celery_app(): + """Return the shared Celery application instance.""" + # Import here to avoid importing Celery at module load time when workers + # are not configured (e.g., during tests that mock bootstrap). + from app.infrastructure.tasks.celery_app import celery_app + return celery_app +``` + +- [ ] **Step 2: Update the upload route** + +Replace the `upload_document` handler in `backend/app/api/routes/documents.py` with: + +```python +@router.post("/upload", response_model=DocumentUploadResponse) +async def upload_document( + file: UploadFile = File(..., description="上传的文档文件"), + doc_id: str | None = Form(None, description="客户端预分配的文档ID,不传则自动生成"), + doc_name: str | None = Form(None, description="文档名称"), + regulation_type: str | None = Form(None, description="法规类型"), + version: str | None = Form(None, description="文档版本"), + generate_summary: bool = Form(False, description="是否生成摘要"), + sync: bool = Form(False, description="同步处理(演示/测试用,默认异步)"), +): + """Upload a document and enqueue it for background processing. + + By default the route stores the binary immediately and enqueues parse→embed→index + as a Celery task, returning {doc_id, status:'stored'} within seconds. + Pass sync=true to process inline (useful for demos when no worker is running). + Poll GET /documents/status/{doc_id} for processing progress. + """ + content = await file.read() + if not file.filename: + raise HTTPException(status_code=400, detail="文件名不能为空") + if not content: + raise HTTPException(status_code=400, detail="上传文件为空") + + try: + svc = get_document_command_service() + + if sync: + # Synchronous fallback: full inline processing (demo / no-worker mode). + result = svc.upload_and_process( + doc_id=doc_id, + file_name=file.filename, + content=content, + content_type=file.content_type or "application/octet-stream", + doc_name=doc_name, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + ) + else: + # Async default: store binary, enqueue processing task. + stored_doc_id, run_id = svc.store_document( + doc_id=doc_id, + file_name=file.filename, + content=content, + content_type=file.content_type or "application/octet-stream", + doc_name=doc_name, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + ) + from app.infrastructure.tasks.document_tasks import process_document_task + process_document_task.delay( + doc_id=stored_doc_id, + file_name=file.filename, + doc_name=doc_name or file.filename, + regulation_type=regulation_type or "", + version=version or "", + generate_summary=generate_summary, + run_id=run_id, + ) + result = DocumentProcessResult( + doc_id=stored_doc_id, + doc_name=doc_name or file.filename, + status="stored", + message="文件已存储,正在后台处理。请轮询 GET /documents/status/{doc_id} 查看进度。", + ) + + if result.status == "failed": + raise HTTPException(status_code=500, detail=result.message) + return _document_response(result) + + except HTTPException: + raise + except Exception as exc: + logger.exception("文档上传失败") + raise HTTPException(status_code=500, detail=str(exc)) +``` + +Also add the import at the top of `documents.py`: + +```python +from app.application.documents import DocumentProcessResult +``` + +- [ ] **Step 3: Verify the app still starts** + +```bash +PYTHONPATH=backend python -c "from app.api.main import app; print('OK')" +``` + +Expected: `OK` with no import errors. + +### Task 5: Add worker startup to dev scripts + +**Files:** +- Modify: `dev.sh` + +- [ ] **Step 1: Add worker target to `dev.sh`** + +Open `dev.sh`. Find the `start` command section and add a `worker` subcommand. Locate the section that handles `start api` and add after it: + +```bash + start_worker() { + # Start the Celery document processing worker. + # Reads PYTHONPATH and env from the same root .env the API uses. + export PYTHONPATH=backend + celery -A app.infrastructure.tasks.celery_app worker \ + --loglevel=info \ + --concurrency=2 \ + --queues=celery \ + "$@" + } + + start_beat() { + # Start the Celery Beat scheduler for periodic tasks (future: perception crawl). + export PYTHONPATH=backend + celery -A app.infrastructure.tasks.celery_app beat \ + --loglevel=info \ + "$@" + } +``` + +And in the argument dispatch block, add: + +```bash + "worker") + start_worker "${@:3}" + ;; + "beat") + start_beat "${@:3}" + ;; +``` + +- [ ] **Step 2: Verify worker starts (requires Redis running)** + +```bash +./dev.sh start worker --dry-run 2>/dev/null || PYTHONPATH=backend celery -A app.infrastructure.tasks.celery_app inspect ping 2>&1 | head -5 +``` + +Expected: Either dry-run succeeds or Celery shows it can find the app. + +> ✅ **Group B complete.** Document uploads now return in <1 s. Workers process in the background. `GET /documents/status/{doc_id}` reports live progress via `DocumentProcessingStore`. + +--- + +## Group C — Session Persistence + +> Estimated time: 2-3 days. +> +> **Why:** `InMemoryConversationStore` at `backend/app/infrastructure/session/in_memory_conversation_store.py` stores all chat sessions in a plain Python dict. Every backend restart loses all active conversations. Redis 7 is already running; persisting sessions there requires only a new adapter and a settings switch. + +### Task 6: Add `fakeredis` dev dependency and create the Redis store + +**Files:** +- Modify: `pyproject.toml` +- Create: `backend/app/infrastructure/session/redis_conversation_store.py` +- Create: `tests/test_redis_conversation_store.py` + +- [ ] **Step 1: Add `fakeredis` to dev dependencies** + +In `pyproject.toml`, find the `[dependency-groups]` section and update: + +```toml +[dependency-groups] +dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "isort>=8.0.1", "fakeredis>=2.0.0"] +``` + +Install it: + +```bash +uv sync +``` + +- [ ] **Step 2: Write failing tests for RedisConversationStore** + +Create `tests/test_redis_conversation_store.py`: + +```python +"""Tests for RedisConversationStore. + +Uses fakeredis so no real Redis connection is required. +All tests follow the same ConversationStore contract as InMemoryConversationStore. +""" + +import json +import pytest +import fakeredis + + +@pytest.fixture +def redis_client(): + """Return an in-process fake Redis client.""" + return fakeredis.FakeRedis() + + +@pytest.fixture +def store(redis_client): + """Return a RedisConversationStore backed by fake Redis.""" + from app.infrastructure.session.redis_conversation_store import RedisConversationStore + return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800) + + +def test_create_session_returns_session_with_id(store): + """create_session() must return a ConversationSession with a non-empty session_id.""" + session = store.create_session() + assert session.session_id + assert len(session.session_id) > 0 + + +def test_get_session_returns_same_session(store): + """get_session() must return the previously created session.""" + session = store.create_session() + fetched = store.get_session(session.session_id) + assert fetched is not None + assert fetched.session_id == session.session_id + + +def test_get_session_returns_none_for_unknown_id(store): + """get_session() must return None when the session_id does not exist.""" + result = store.get_session("nonexistent-id") + assert result is None + + +def test_save_message_appends_to_session(store): + """save_message() must append a message and return the updated session.""" + session = store.create_session() + updated = store.save_message(session.session_id, role="user", content="Hello") + assert updated is not None + assert len(updated.messages) == 1 + assert updated.messages[0].role == "user" + assert updated.messages[0].content == "Hello" + + +def test_save_message_persists_across_lookups(store): + """Messages saved to a session must be visible in subsequent get_session calls.""" + session = store.create_session() + store.save_message(session.session_id, role="user", content="test") + fetched = store.get_session(session.session_id) + assert fetched is not None + assert len(fetched.messages) == 1 + + +def test_delete_session_removes_it(store): + """delete_session() must return True and remove the session.""" + session = store.create_session() + result = store.delete_session(session.session_id) + assert result is True + assert store.get_session(session.session_id) is None + + +def test_delete_session_returns_false_for_unknown(store): + """delete_session() must return False when the session does not exist.""" + result = store.delete_session("ghost-id") + assert result is False + + +def test_list_sessions_includes_created_session(store): + """list_sessions() must include all active sessions.""" + session = store.create_session() + sessions = store.list_sessions() + ids = [s["session_id"] for s in sessions] + assert session.session_id in ids + + +def test_session_expires_after_ttl(redis_client): + """Sessions must disappear after the TTL expires.""" + from app.infrastructure.session.redis_conversation_store import RedisConversationStore + # Use a 1-second TTL for this test. + store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1) + session = store.create_session() + # Manually expire the key to simulate TTL expiry. + redis_client.expire(f"session:{session.session_id}", 0) + assert store.get_session(session.session_id) is None +``` + +- [ ] **Step 3: Run tests to confirm they fail** + +```bash +PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v +``` + +Expected: `ModuleNotFoundError: No module named 'app.infrastructure.session.redis_conversation_store'` + +- [ ] **Step 4: Implement `RedisConversationStore`** + +Create `backend/app/infrastructure/session/redis_conversation_store.py`: + +```python +"""Redis-backed conversation store for persistent chat sessions. + +Sessions are stored as JSON strings under the key `session:{session_id}`. +The Redis TTL is refreshed on every write so active sessions stay alive. +On expiry, `get_session` returns None — callers should create a new session. +""" + +from __future__ import annotations + +import json +import time +import uuid +from typing import Any + +from loguru import logger + +from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore + + +class RedisConversationStore(ConversationStore): + """Store conversation sessions in Redis with automatic TTL expiry. + + Each session is serialised as a JSON object at key ``session:{session_id}``. + The TTL is reset on every write so sessions stay alive as long as they are active. + """ + + # Prefix for all session keys to avoid collisions with other Redis consumers. + _PREFIX = "session:" + + def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None: + """Initialise the store with an existing Redis client and a TTL in seconds.""" + self._redis = redis_client + self._ttl = timeout_seconds + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _key(self, session_id: str) -> str: + """Build the Redis key for a session.""" + return f"{self._PREFIX}{session_id}" + + def _serialise(self, session: ConversationSession) -> str: + """Serialise a ConversationSession to a JSON string.""" + return json.dumps( + { + "session_id": session.session_id, + "created_at": session.created_at, + "updated_at": session.updated_at, + "metadata": session.metadata, + "messages": [ + { + "role": msg.role, + "content": msg.content, + "timestamp": msg.timestamp, + "sources": msg.sources, + } + for msg in session.messages + ], + }, + ensure_ascii=False, + ) + + def _deserialise(self, raw: bytes | str) -> ConversationSession: + """Deserialise a JSON string back into a ConversationSession.""" + data = json.loads(raw) + messages = [ + ConversationMessage( + role=m["role"], + content=m["content"], + timestamp=m["timestamp"], + sources=m.get("sources", []), + ) + for m in data.get("messages", []) + ] + session = ConversationSession( + session_id=data["session_id"], + created_at=data.get("created_at", 0), + updated_at=data.get("updated_at", 0), + metadata=data.get("metadata", {}), + ) + session.messages = messages + return session + + def _save(self, session: ConversationSession) -> None: + """Persist a session to Redis and refresh its TTL.""" + self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session)) + + # ------------------------------------------------------------------ + # ConversationStore protocol + # ------------------------------------------------------------------ + + def create_session(self, metadata: dict | None = None) -> ConversationSession: + """Create a new empty session and persist it immediately.""" + now = int(time.time()) + session = ConversationSession( + session_id=str(uuid.uuid4())[:8], + created_at=now, + updated_at=now, + metadata=metadata or {}, + ) + self._save(session) + return session + + def get_session(self, session_id: str) -> ConversationSession | None: + """Return a session by ID, or None if it does not exist or has expired.""" + raw = self._redis.get(self._key(session_id)) + if raw is None: + return None + try: + return self._deserialise(raw) + except Exception: + logger.warning("Failed to deserialise session: {}", session_id) + return None + + def save_message( + self, + session_id: str, + *, + role: str, + content: str, + sources: list[dict] | None = None, + ) -> ConversationSession | None: + """Append a message to a session and refresh its TTL.""" + session = self.get_session(session_id) + if session is None: + return None + session.messages.append( + ConversationMessage( + role=role, + content=content, + timestamp=int(time.time()), + sources=sources or [], + ) + ) + session.updated_at = int(time.time()) + self._save(session) + return session + + def delete_session(self, session_id: str) -> bool: + """Delete a session. Returns True if it existed, False otherwise.""" + deleted = self._redis.delete(self._key(session_id)) + return bool(deleted) + + def list_sessions(self) -> list[dict]: + """Return summary dicts for all live sessions visible in this Redis DB. + + Note: KEYS is used here for simplicity; in a large production deployment + replace with SCAN for non-blocking iteration. + """ + pattern = f"{self._PREFIX}*" + keys = self._redis.keys(pattern) + result = [] + for key in keys: + raw = self._redis.get(key) + if raw is None: + continue + try: + data = json.loads(raw) + result.append( + { + "session_id": data["session_id"], + "message_count": len(data.get("messages", [])), + "created_at": data.get("created_at", 0), + "updated_at": data.get("updated_at", 0), + } + ) + except Exception: + continue + return result +``` + +- [ ] **Step 5: Run tests** + +```bash +PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v +``` + +Expected: All 9 tests PASS. + +### Task 7: Add `session_backend` setting and switch bootstrap + +**Files:** +- Modify: `backend/app/config/settings.py` +- Modify: `backend/app/shared/bootstrap.py` + +- [ ] **Step 1: Add `session_backend` setting** + +In `backend/app/config/settings.py`, find the session block (around line 125) and add: + +```python + session_backend: str = Field( + default="memory", + description="会话存储后端 (memory | redis)。redis 需要 Redis 可用。", + ) +``` + +- [ ] **Step 2: Update `get_conversation_store` in bootstrap** + +Replace the existing `get_conversation_store` function in `backend/app/shared/bootstrap.py`: + +```python +@lru_cache +def get_conversation_store() -> ConversationStore: + """Return the active conversation store based on settings. + + When session_backend='redis', sessions survive backend restarts. + When session_backend='memory' (default), sessions are process-local. + """ + if settings.session_backend == "redis": + import redis as redis_lib + from app.infrastructure.session.redis_conversation_store import RedisConversationStore + + # Build the Redis client from the same connection settings as Celery. + kwargs: dict = { + "host": settings.redis_host, + "port": settings.redis_port, + "db": settings.redis_db, + "decode_responses": False, + } + if settings.redis_password: + kwargs["password"] = settings.redis_password + + redis_client = redis_lib.Redis(**kwargs) + return RedisConversationStore( + redis_client=redis_client, + timeout_seconds=settings.session_timeout_minutes * 60, + ) + return InMemoryConversationStore( + max_sessions=settings.session_max_sessions, + timeout_minutes=settings.session_timeout_minutes, + ) +``` + +Also add the `ConversationStore` type hint to the import block at the top of `bootstrap.py`: + +```python +from app.domain.conversation import ConversationStore +``` + +- [ ] **Step 3: Add to `.env.example`** + +Append to `.env.example`: + +```ini +# ── Session backend ─────────────────────────────────────────────────────────── +# Set SESSION_BACKEND=redis for persistent sessions (requires Redis). +# Use SESSION_BACKEND=memory for development without Redis. +SESSION_BACKEND=redis +``` + +- [ ] **Step 4: Verify app starts with the new setting** + +```bash +PYTHONPATH=backend python -c " +from app.config.settings import settings +from app.shared.bootstrap import get_conversation_store +get_conversation_store.cache_clear() +print('session_backend:', settings.session_backend) +print('OK') +" +``` + +Expected: prints `session_backend: memory` (or `redis` if set in .env) then `OK`. + +> ✅ **Group C complete.** Set `SESSION_BACKEND=redis` in `.env` to persist sessions across restarts. + +--- + +## Group D — JWT Authentication + RBAC + Audit Logging + +> Estimated time: 3-5 days. +> +> **Why:** `backend/app/api/main.py:49` has `allow_origins=["*"]`. No route requires any credential. PPT Slide 12 defines a four-role permission matrix; Slide 9 lists "RBAC" as a key challenge response. +> +> **Approach for MVP:** JWT tokens carry a `role` claim (admin/legal/ehs/readonly). All API routes require a valid token. Fine-grained per-route RBAC enforcement is a follow-up; the role field is the enforcement hook. An audit middleware logs every authenticated API call. Users are stored in PostgreSQL. + +### Task 8: Add auth dependencies + +**Files:** +- Modify: `requirements.txt` +- Modify: `pyproject.toml` + +- [ ] **Step 1: Add `python-jose` and `passlib` to `requirements.txt`** + +In `requirements.txt`, add after the existing tool libs block: + +``` +# Authentication +python-jose[cryptography]>=3.3.0 +passlib[bcrypt]>=1.7.4 +``` + +- [ ] **Step 2: Add to `pyproject.toml` dependencies list** + +In `pyproject.toml` under `dependencies`, add: + +```toml + "python-jose[cryptography]>=3.3.0", + "passlib[bcrypt]>=1.7.4", +``` + +- [ ] **Step 3: Install** + +```bash +uv sync +``` + +Expected: No errors. + +### Task 9: Create auth domain models + +**Files:** +- Create: `backend/app/domain/auth/__init__.py` +- Create: `backend/app/domain/auth/models.py` + +- [ ] **Step 1: Write failing test** + +Create `tests/test_jwt_handler.py`: + +```python +"""Tests for JWTHandler — token creation and decoding. + +These tests do not require a running server or database. +""" + +import time +import pytest + +SECRET = "test-secret-key-minimum-32-characters-long" + + +@pytest.fixture +def handler(): + """Return a JWTHandler configured with a test secret.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30) + + +def test_create_token_returns_string(handler): + """create_access_token must return a non-empty string.""" + token = handler.create_access_token(user_id="u1", username="alice", role="admin") + assert isinstance(token, str) + assert len(token) > 20 + + +def test_decode_token_returns_correct_claims(handler): + """decode_token must return UserClaims matching the input.""" + token = handler.create_access_token(user_id="u1", username="alice", role="admin") + claims = handler.decode_token(token) + assert claims.user_id == "u1" + assert claims.username == "alice" + assert claims.role == "admin" + + +def test_decode_expired_token_raises(handler): + """decode_token must raise ValueError on an expired token.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + # Create a token that expires in 0 minutes (immediately expired). + short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0) + token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly") + time.sleep(1) + with pytest.raises(ValueError, match="expired"): + short_handler.decode_token(token) + + +def test_decode_invalid_token_raises(handler): + """decode_token must raise ValueError for a tampered token.""" + with pytest.raises(ValueError): + handler.decode_token("not.a.valid.jwt.token") + + +def test_decode_wrong_secret_raises(): + """decode_token must raise ValueError when signed with a different secret.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60) + verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60) + token = creator.create_access_token(user_id="u3", username="carol", role="legal") + with pytest.raises(ValueError): + verifier.decode_token(token) +``` + +- [ ] **Step 2: Run test to confirm failure** + +```bash +PYTHONPATH=backend pytest tests/test_jwt_handler.py -v +``` + +Expected: `ModuleNotFoundError: No module named 'app.domain.auth'` + +- [ ] **Step 3: Create domain auth package** + +Create `backend/app/domain/auth/__init__.py`: + +```python +"""Auth domain: role definitions and token claim models. + +The domain layer defines what a user identity looks like (UserClaims) and +what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL) +live under infrastructure/auth and never leak into this package. +""" + +from .models import UserClaims, UserRole + +__all__ = ["UserClaims", "UserRole"] +``` + +- [ ] **Step 4: Create auth domain models** + +Create `backend/app/domain/auth/models.py`: + +```python +"""Auth domain models: roles and token claims. + +UserRole defines the four roles from PPT Slide 12. +UserClaims is what the JWT decodes to — it is the identity object passed +through FastAPI dependency injection to route handlers. +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass + + +class UserRole(str, enum.Enum): + """Access roles mirroring the four-role RBAC matrix from the product spec. + + ADMIN — full platform access including system management. + LEGAL — knowledge query, document review, compliance checks. + EHS — knowledge query, perception/regulatory signals. + READONLY — knowledge query only. + """ + + ADMIN = "admin" + LEGAL = "legal" + EHS = "ehs" + READONLY = "readonly" + + +@dataclass +class UserClaims: + """Decoded JWT payload representing an authenticated user. + + Instances are created by JWTHandler.decode_token() and injected into + route handlers via the get_current_user FastAPI dependency. + """ + + # Unique user identifier (UUID string stored in PostgreSQL users table). + user_id: str + # Display name used for audit log entries. + username: str + # Role determines which resources the user may access. + role: UserRole +``` + +### Task 10: Create JWT handler infrastructure + +**Files:** +- Create: `backend/app/infrastructure/auth/__init__.py` +- Create: `backend/app/infrastructure/auth/jwt_handler.py` + +- [ ] **Step 1: Create infrastructure auth package** + +Create `backend/app/infrastructure/auth/__init__.py`: + +```python +"""JWT token creation and validation infrastructure. + +JWTHandler is the only component in this package. It is wired through +shared/bootstrap.py and injected into FastAPI dependencies. +""" +``` + +- [ ] **Step 2: Implement JWTHandler** + +Create `backend/app/infrastructure/auth/jwt_handler.py`: + +```python +"""JWT access token creation and decoding. + +Uses python-jose for HS256 token signing. Token expiry is enforced at +decode time so expired tokens are rejected even if the signature is valid. +""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from typing import Any + +from jose import JWTError, jwt +from loguru import logger + +from app.domain.auth.models import UserClaims, UserRole + + +class JWTHandler: + """Create and validate HS256 JWT access tokens. + + A single shared instance is wired by bootstrap.py. Use + get_jwt_handler() from shared.bootstrap for all token operations. + """ + + def __init__( + self, + *, + secret_key: str, + algorithm: str = "HS256", + expire_minutes: int = 480, + ) -> None: + """Initialise the handler with signing credentials and token lifetime.""" + self._secret = secret_key + self._algorithm = algorithm + self._expire_minutes = expire_minutes + + def create_access_token( + self, + *, + user_id: str, + username: str, + role: str, + ) -> str: + """Return a signed JWT containing user identity and role claims.""" + now = datetime.now(UTC) + payload: dict[str, Any] = { + "sub": user_id, + "username": username, + "role": role, + "iat": now, + "exp": now + timedelta(minutes=self._expire_minutes), + } + return jwt.encode(payload, self._secret, algorithm=self._algorithm) + + def decode_token(self, token: str) -> UserClaims: + """Decode and validate a JWT, returning UserClaims. + + Raises ValueError with a descriptive message on expiry, tampering, + or any other validation failure so callers do not need to know jose. + """ + try: + payload = jwt.decode(token, self._secret, algorithms=[self._algorithm]) + except JWTError as exc: + msg = str(exc).lower() + if "expired" in msg: + raise ValueError("Token expired") from exc + raise ValueError(f"Invalid token: {exc}") from exc + + user_id = payload.get("sub") + username = payload.get("username", "") + role_str = payload.get("role", UserRole.READONLY.value) + + if not user_id: + raise ValueError("Token missing subject claim") + + try: + role = UserRole(role_str) + except ValueError: + logger.warning("Unknown role in token: {}, defaulting to readonly", role_str) + role = UserRole.READONLY + + return UserClaims(user_id=user_id, username=username, role=role) +``` + +- [ ] **Step 3: Run JWT tests** + +```bash +PYTHONPATH=backend pytest tests/test_jwt_handler.py -v +``` + +Expected: All 5 tests PASS. + +### Task 11: Create PostgreSQL user store and seed script + +**Files:** +- Create: `backend/app/infrastructure/auth/user_store.py` +- Create: `scripts/seed_users.py` + +- [ ] **Step 1: Implement PostgreSQL user store** + +Create `backend/app/infrastructure/auth/user_store.py`: + +```python +"""PostgreSQL-backed user store for authentication. + +Manages a `users` table with hashed passwords and roles. +Provides lookup by username for the login flow. +Table DDL is auto-applied on first connection. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from typing import Optional + +import psycopg2 +import psycopg2.extras +from loguru import logger +from passlib.context import CryptContext + +from app.config.settings import settings + + +# bcrypt context — work factor 12 is a good production default. +_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# DDL executed once to ensure the table exists. +_CREATE_TABLE_SQL = """ +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username VARCHAR(100) UNIQUE NOT NULL, + hashed_pw TEXT NOT NULL, + role VARCHAR(50) NOT NULL DEFAULT 'readonly', + is_active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +""" + + +@dataclass +class UserRecord: + """A single row from the users table.""" + + id: str + username: str + hashed_pw: str + role: str + is_active: bool + + +class PostgresUserStore: + """Read and verify users stored in the PostgreSQL users table. + + The connection is opened on first use and shared for the lifetime + of the singleton instance wired by bootstrap. + """ + + def __init__(self) -> None: + """Initialise the store and ensure the users table exists.""" + self._conn = psycopg2.connect( + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + cursor_factory=psycopg2.extras.RealDictCursor, + ) + self._conn.autocommit = True + self._ensure_table() + + def _ensure_table(self) -> None: + """Create the users table if it does not already exist.""" + with self._conn.cursor() as cur: + cur.execute(_CREATE_TABLE_SQL) + + def get_by_username(self, username: str) -> Optional[UserRecord]: + """Return a UserRecord for the given username, or None if not found.""" + with self._conn.cursor() as cur: + cur.execute( + "SELECT id, username, hashed_pw, role, is_active " + "FROM users WHERE username = %s", + (username,), + ) + row = cur.fetchone() + if row is None: + return None + return UserRecord( + id=str(row["id"]), + username=row["username"], + hashed_pw=row["hashed_pw"], + role=row["role"], + is_active=row["is_active"], + ) + + def verify_password(self, plain: str, hashed: str) -> bool: + """Return True if `plain` matches the stored bcrypt hash.""" + return _PWD_CTX.verify(plain, hashed) + + def authenticate(self, username: str, password: str) -> Optional[UserRecord]: + """Return the UserRecord if credentials are valid, else None.""" + user = self.get_by_username(username) + if user is None or not user.is_active: + return None + if not self.verify_password(password, user.hashed_pw): + return None + return user + + @staticmethod + def hash_password(plain: str) -> str: + """Hash a plain-text password with bcrypt.""" + return _PWD_CTX.hash(plain) +``` + +- [ ] **Step 2: Create the seed script** + +Create `scripts/seed_users.py`: + +```python +#!/usr/bin/env python +"""Seed demo users into the PostgreSQL users table. + +Run from the repo root: + PYTHONPATH=backend python scripts/seed_users.py + +Creates four demo accounts (one per role). Safe to run multiple times +— existing usernames are left unchanged (ON CONFLICT DO NOTHING). +""" + +import sys +import os + +# Allow running from repo root without installing the package. +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend")) + +import psycopg2 +from passlib.context import CryptContext +from app.config.settings import settings + +_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto") + +DEMO_USERS = [ + {"username": "admin", "password": "Admin@2026!", "role": "admin"}, + {"username": "legal", "password": "Legal@2026!", "role": "legal"}, + {"username": "ehs", "password": "EHS@2026!", "role": "ehs"}, + {"username": "readonly", "password": "Read@2026!", "role": "readonly"}, +] + + +def seed() -> None: + """Insert demo users. Skips any username that already exists.""" + conn = psycopg2.connect( + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + ) + conn.autocommit = True + + # Ensure the table exists (same DDL as user_store.py). + with conn.cursor() as cur: + cur.execute(""" + CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username VARCHAR(100) UNIQUE NOT NULL, + hashed_pw TEXT NOT NULL, + role VARCHAR(50) NOT NULL DEFAULT 'readonly', + is_active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() + ); + """) + + inserted = 0 + with conn.cursor() as cur: + for user in DEMO_USERS: + hashed = _PWD_CTX.hash(user["password"]) + cur.execute( + """ + INSERT INTO users (username, hashed_pw, role) + VALUES (%s, %s, %s) + ON CONFLICT (username) DO NOTHING + """, + (user["username"], hashed, user["role"]), + ) + if cur.rowcount > 0: + print(f" Created user: {user['username']} (role={user['role']})") + inserted += 1 + else: + print(f" Skipped (already exists): {user['username']}") + + print(f"\nDone. {inserted} user(s) created.") + conn.close() + + +if __name__ == "__main__": + seed() +``` + +- [ ] **Step 3: Seed demo users (requires PostgreSQL running)** + +```bash +PYTHONPATH=backend python scripts/seed_users.py +``` + +Expected: +``` + Created user: admin (role=admin) + Created user: legal (role=legal) + Created user: ehs (role=ehs) + Created user: readonly (role=readonly) + +Done. 4 user(s) created. +``` + +### Task 12: Create FastAPI auth dependencies and the token endpoint + +**Files:** +- Create: `backend/app/api/dependencies/__init__.py` +- Create: `backend/app/api/dependencies/auth.py` +- Create: `backend/app/api/routes/auth.py` +- Modify: `backend/app/config/settings.py` +- Modify: `backend/app/shared/bootstrap.py` + +- [ ] **Step 1: Add auth settings** + +In `backend/app/config/settings.py`, add to the `Settings` class: + +```python + # ── Auth ───────────────────────────────────────────────────────────── + # Change auth_secret_key to a long random string in production. + # Generate one with: python -c "import secrets; print(secrets.token_hex(32))" + auth_secret_key: str = Field( + default="change-me-in-production-must-be-32-or-more-characters-long", + description="JWT signing secret. MUST be changed in production.", + ) + auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.") + auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default: 8 hours).") + auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).") +``` + +- [ ] **Step 2: Wire auth components in bootstrap** + +Add to `backend/app/shared/bootstrap.py`: + +```python +@lru_cache +def get_jwt_handler(): + """Return the shared JWTHandler instance.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + return JWTHandler( + secret_key=settings.auth_secret_key, + algorithm=settings.auth_algorithm, + expire_minutes=settings.auth_token_expire_minutes, + ) + + +@lru_cache +def get_user_store(): + """Return the PostgreSQL user store (lazy, connects on first call).""" + from app.infrastructure.auth.user_store import PostgresUserStore + return PostgresUserStore() +``` + +- [ ] **Step 3: Create the dependencies package** + +Create `backend/app/api/dependencies/__init__.py`: + +```python +"""FastAPI dependency functions for authentication and authorisation. + +Import `get_current_user` or `require_role` into route modules to protect +endpoints. Both use the shared JWTHandler wired through bootstrap. +""" +``` + +- [ ] **Step 4: Implement auth dependencies** + +Create `backend/app/api/dependencies/auth.py`: + +```python +"""FastAPI dependencies for JWT authentication. + +Usage in a route: + from app.api.dependencies.auth import get_current_user, require_role + from app.domain.auth.models import UserRole + + @router.get("/protected") + async def protected(user: UserClaims = Depends(get_current_user)): + return {"user": user.username} + + @router.delete("/admin-only") + async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))): + ... +""" + +from __future__ import annotations + +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.config.settings import settings +from app.domain.auth.models import UserClaims, UserRole +from app.shared.bootstrap import get_jwt_handler + +# Use Bearer token scheme — client sends `Authorization: Bearer `. +_bearer = HTTPBearer(auto_error=False) + + +async def get_current_user( + credentials: HTTPAuthorizationCredentials | None = Depends(_bearer), +) -> UserClaims: + """Extract and validate the JWT from the Authorization header. + + Returns the decoded UserClaims on success. + Raises HTTP 401 when the token is missing, expired, or invalid. + When auth_enabled=False (development), returns a synthetic admin user. + """ + if not settings.auth_enabled: + # Development bypass — never enable this in production. + return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN) + + if credentials is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + return get_jwt_handler().decode_token(credentials.credentials) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=str(exc), + headers={"WWW-Authenticate": "Bearer"}, + ) from exc + + +def require_role(*roles: UserRole): + """Return a dependency that enforces one of the given roles. + + Example: + Depends(require_role(UserRole.ADMIN, UserRole.LEGAL)) + """ + async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims: + """Verify the user holds one of the required roles.""" + if user.role not in roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Role '{user.role}' is not permitted. Required: {[r.value for r in roles]}", + ) + return user + return _check +``` + +- [ ] **Step 5: Create the auth router** + +Create `backend/app/api/routes/auth.py`: + +```python +"""Authentication routes — token issuance only. + +POST /auth/token — exchange username + password for a JWT. +GET /auth/me — return the current user's identity (requires token). +""" + +from __future__ import annotations + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel + +from app.api.dependencies.auth import get_current_user +from app.domain.auth.models import UserClaims +from app.shared.bootstrap import get_jwt_handler, get_user_store + + +router = APIRouter(prefix="/auth", tags=["认证"]) + + +class TokenResponse(BaseModel): + """JWT token response body.""" + + access_token: str + token_type: str = "bearer" + expires_in: int + + +@router.post("/token", response_model=TokenResponse) +async def login(form: OAuth2PasswordRequestForm = Depends()): + """Issue a JWT for valid username + password credentials. + + Uses standard OAuth2 password grant form fields so this endpoint + is compatible with Swagger UI's Authorize button. + """ + user_store = get_user_store() + user = user_store.authenticate(form.username, form.password) + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + from app.config.settings import settings + token = get_jwt_handler().create_access_token( + user_id=user.id, + username=user.username, + role=user.role, + ) + return TokenResponse( + access_token=token, + token_type="bearer", + expires_in=settings.auth_token_expire_minutes * 60, + ) + + +@router.get("/me") +async def get_me(current_user: UserClaims = Depends(get_current_user)): + """Return the identity of the currently authenticated user.""" + return { + "user_id": current_user.user_id, + "username": current_user.username, + "role": current_user.role.value, + } +``` + +- [ ] **Step 6: Register the auth router** + +In `backend/app/api/routes/__init__.py`, add: + +```python +from .auth import router as auth_router +``` + +And in the `api_router.include_router(...)` block, add: + +```python +api_router.include_router(auth_router) +``` + +Also add `auth_router` to `__all__`. + +- [ ] **Step 7: Write auth route tests** + +Create `tests/test_auth_routes.py`: + +```python +"""Integration tests for the auth routes. + +Uses FastAPI TestClient. Does not require a running PostgreSQL — patches +the user_store dependency. +""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch + +from app.api.main import app +from app.infrastructure.auth.user_store import UserRecord + +client = TestClient(app, raise_server_exceptions=False) + + +@pytest.fixture(autouse=True) +def mock_user_store(): + """Replace the real PostgreSQL user store with a mock.""" + mock_store = MagicMock() + alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True) + mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None + + with patch("app.api.routes.auth.get_user_store", return_value=mock_store): + yield mock_store + + +def test_login_returns_token_for_valid_credentials(): + """POST /auth/token must return an access_token for valid credentials.""" + resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"}) + assert resp.status_code == 200 + body = resp.json() + assert "access_token" in body + assert body["token_type"] == "bearer" + + +def test_login_returns_401_for_wrong_password(): + """POST /auth/token must return 401 for wrong password.""" + resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"}) + assert resp.status_code == 401 + + +def test_me_returns_user_when_authenticated(): + """GET /auth/me must return user identity when a valid token is provided.""" + # Get a real token first. + login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"}) + token = login_resp.json()["access_token"] + + me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert me_resp.status_code == 200 + assert me_resp.json()["username"] == "alice" + assert me_resp.json()["role"] == "admin" + + +def test_me_returns_401_without_token(): + """GET /auth/me must return 401 when no token is provided.""" + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 +``` + +- [ ] **Step 8: Run auth tests** + +```bash +PYTHONPATH=backend pytest tests/test_auth_routes.py -v +``` + +Expected: All 4 tests PASS. + +### Task 13: Add audit middleware and apply auth to routes + +**Files:** +- Create: `backend/app/api/middleware/__init__.py` +- Create: `backend/app/api/middleware/audit.py` +- Modify: `backend/app/api/main.py` +- Modify: `backend/app/api/routes/documents.py` +- Modify: `backend/app/api/routes/rag.py` +- Modify: `backend/app/api/routes/compliance.py` + +- [ ] **Step 1: Create the audit middleware** + +Create `backend/app/api/middleware/__init__.py`: + +```python +"""HTTP middleware for cross-cutting concerns: audit logging.""" +``` + +Create `backend/app/api/middleware/audit.py`: + +```python +"""Audit logging middleware. + +Logs every API request with method, path, status code, response time, +and the authenticated user identity (extracted from the JWT when present). +Log lines are structured so they can be ingested by ELK / Loki. +""" + +from __future__ import annotations + +import time + +from fastapi import Request, Response +from loguru import logger +from starlette.middleware.base import BaseHTTPMiddleware + + +class AuditMiddleware(BaseHTTPMiddleware): + """Log all API calls. Skips /health and /docs to reduce noise.""" + + # Paths that produce no audit log entry. + _SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json", "/") + + async def dispatch(self, request: Request, call_next) -> Response: + """Intercept the request, call the handler, and log the outcome.""" + path = request.url.path + if any(path.startswith(p) for p in self._SKIP_PREFIXES): + return await call_next(request) + + start = time.perf_counter() + response = await call_next(request) + elapsed_ms = int((time.perf_counter() - start) * 1000) + + # Try to extract user identity from a decoded JWT header if present. + # We do not re-validate the token here — auth dependencies do that. + user_id = "anonymous" + username = "anonymous" + auth_header = request.headers.get("authorization", "") + if auth_header.startswith("Bearer "): + try: + from app.shared.bootstrap import get_jwt_handler + claims = get_jwt_handler().decode_token(auth_header[7:]) + user_id = claims.user_id + username = claims.username + except Exception: + pass + + logger.info( + "AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}", + request.method, + path, + response.status_code, + elapsed_ms, + user_id, + username, + ) + return response +``` + +- [ ] **Step 2: Register the middleware and tighten CORS in `main.py`** + +In `backend/app/api/main.py`, replace the `CORSMiddleware` block and add the audit middleware: + +```python +from app.api.middleware.audit import AuditMiddleware + +# Tighten CORS — only allow the configured frontend origin. +# In production, set CORS_ALLOW_ORIGINS in .env to the actual frontend URL. +_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()] +if not _ORIGINS: + _ORIGINS = ["http://localhost:5173"] # Vite dev default + +app.add_middleware( + CORSMiddleware, + allow_origins=_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.add_middleware(AuditMiddleware) +``` + +Also add `cors_allow_origins` to `settings.py`: + +```python + cors_allow_origins: str = Field( + default="http://localhost:5173", + description="Comma-separated allowed CORS origins. Use * only in development.", + ) +``` + +And add to `.env.example`: + +```ini +# ── CORS ───────────────────────────────────────────────────────────────────── +# Comma-separated frontend origin(s). Never use * in production. +CORS_ALLOW_ORIGINS=http://localhost:5173 +``` + +- [ ] **Step 3: Apply `Depends(get_current_user)` to key routes** + +In `backend/app/api/routes/documents.py`, add to imports: + +```python +from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from app.api.dependencies.auth import get_current_user +from app.domain.auth.models import UserClaims +``` + +Add `current_user: UserClaims = Depends(get_current_user)` as a parameter to `upload_document`, `list_documents`, `get_document_management_list`, and `delete_document`: + +```python +@router.post("/upload", response_model=DocumentUploadResponse) +async def upload_document( + file: UploadFile = File(...), + doc_id: str | None = Form(None), + doc_name: str | None = Form(None), + regulation_type: str | None = Form(None), + version: str | None = Form(None), + generate_summary: bool = Form(False), + sync: bool = Form(False), + current_user: UserClaims = Depends(get_current_user), # ← add this +): +``` + +Apply the same pattern to: +- `list_documents` in `documents.py` +- `delete_document` in `documents.py` +- `rag_chat` in `rag.py` +- `analyze_stream` in `compliance.py` + +In `backend/app/api/routes/rag.py`, add: + +```python +from fastapi import APIRouter, Depends +from app.api.dependencies.auth import get_current_user +from app.domain.auth.models import UserClaims +``` + +```python +@router.post("/chat") +async def rag_chat( + request: RagChatRequest, + current_user: UserClaims = Depends(get_current_user), # ← add this +): +``` + +In `backend/app/api/routes/compliance.py`, add: + +```python +from fastapi import APIRouter, Depends, File, Form, UploadFile +from app.api.dependencies.auth import get_current_user +from app.domain.auth.models import UserClaims +``` + +```python +@router.post("/analyze-stream") +async def analyze_stream( + text: Optional[str] = Form(None), + doc_id: Optional[str] = Form(None), + file: Optional[UploadFile] = File(None), + domains: Optional[str] = Form(None), + title: Optional[str] = Form(None), + current_user: UserClaims = Depends(get_current_user), # ← add this +): +``` + +- [ ] **Step 4: Verify app starts and auth is enforced** + +```bash +PYTHONPATH=backend python -c "from app.api.main import app; print('OK')" +``` + +Expected: `OK` + +Test that an unauth'd request is rejected: + +```bash +curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/chat \ + -X POST -H "Content-Type: application/json" -d '{"query":"test"}' +``` + +Expected: `401` + +Test that an auth'd request passes: + +```bash +TOKEN=$(curl -s -X POST http://localhost:8000/api/v1/auth/token \ + -d "username=admin&password=Admin@2026!" | python -c "import sys,json; print(json.load(sys.stdin)['access_token'])") + +curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/quick-questions \ + -H "Authorization: Bearer $TOKEN" +``` + +Expected: `200` + +- [ ] **Step 5: Run the full test suite** + +```bash +PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py tests/test_celery_tasks.py tests/test_redis_conversation_store.py tests/test_jwt_handler.py tests/test_auth_routes.py -v +``` + +Expected: All tests PASS. + +--- + +## Self-Review Checklist + +**Spec coverage:** +- ✅ §3.1 Reranker quick win → Task 1 +- ✅ §3.2 Async task-ification (Celery + Redis) → Tasks 2-5 +- ✅ §3.3 Session persistence → Tasks 6-7 +- ✅ §3.4 Auth/RBAC/Audit → Tasks 8-13 +- ✅ AGENTS.md: `api → application → domain ports → infrastructure` respected throughout +- ✅ No new logic in `services/*` or `workflows/*` +- ✅ `shared/bootstrap.py` is the composition root for all new singletons +- ✅ All comments/docstrings in new backend files are in English +- ✅ Desktop-first frontend constraint: no frontend changes in this plan + +**Placeholder scan:** None found — all code blocks are complete and runnable. + +**Type consistency:** +- `UserClaims` defined in Task 9, used in Tasks 12, 13 ✅ +- `UserRole` enum defined in Task 9, used in Task 12 ✅ +- `process_document_task` name matches registration in Task 3 and lookup in Task 3 test ✅ +- `store_document` defined in Task 3, called in Task 4 ✅ +- `_process_document` defined in Task 3, called from Task 3 and Task 4 ✅ +- `get_jwt_handler`, `get_user_store`, `get_celery_app` added to bootstrap in Tasks 4, 12 ✅ +- `RedisConversationStore` constructor signature `(redis_client, timeout_seconds)` consistent across implementation, tests, and bootstrap ✅ diff --git a/docs/superpowers/specs/2026-06-05-next-steps-roadmap-design.md b/docs/superpowers/specs/2026-06-05-next-steps-roadmap-design.md new file mode 100644 index 0000000..5175ad7 --- /dev/null +++ b/docs/superpowers/specs/2026-06-05-next-steps-roadmap-design.md @@ -0,0 +1,289 @@ +# AI+合规智能中枢 — 下一步开发与优化路线图(设计文档) + +- 日期:2026-06-05 +- 定位:试点 MVP 走向生产 +- 范围:全景清单 + 异步任务化(设计①)+ 法规感知闭环(设计②)深入方案 + 三阶段实施路线图 +- 作者:AI Regulations Team(brainstorming 产出) + +--- + +## 0. 背景与目的 + +本文档基于对当前仓库前后端真实代码的逐文件探查,结合四份愿景文档(`AI_Regulations_Report.pptx`、`AI_Regulations_Architecture.docx`、`01_Architecture.html`、`02_Architecture_Detail.html`)与最新开源 AI 技术调研,给出**下一步可继续开发与优化的方向清单**,并对两个最高价值方向给出可落地的深入设计。 + +本文档是**方向性设计(spec)**,不是实施计划(plan)。阶段一、阶段二的具体落地由后续 writing-plans 环节拆分为分步计划。 + +### 0.1 现状一句话 + +后端是一套结构清晰的 DDD 风格 FastAPI RAG 系统(上传 → 解析 → 分块 → BGE-M3 嵌入 → Milvus → 混合检索 → 流式问答 + 合规分析),**真实可用**。但愿景文档中的多个旗舰能力(知识图谱、法规感知闭环、RBAC、EHS、异步化)目前为 **mock 或缺失**。 + +--- + +## 1. 现状盘点(基于真实代码) + +### 1.1 已实现且真实可用 + +- **文档处理主链路**:`application/documents/services.py::DocumentCommandService.upload_and_process` — 存储 → 解析(阿里云 DocMind / 本地)→ 分块 → BGE-M3 嵌入 → Milvus 入库,含 `DocumentProcessingStore` 全程状态事件记录。 +- **混合检索**:`application/knowledge/services.py::KnowledgeRetrievalService` — Dense(`DenseRetriever`)+ BM25(jieba)+ Reciprocal Rank Fusion + 可选 Cross-Encoder 重排。 +- **流式 RAG 问答**:`application/agent/services.py::AgentConversationService.stream_chat` + `api/routes/rag.py` — 真实检索 + 引文 + 会话历史 + SSE。 +- **合规分析管线**:`application/compliance/pipeline.py` — clause_split → retrieve → gap_check → conclusion,真实 LLM + 真实检索,SSE 流式(`api/routes/compliance.py::analyze_stream`)。 +- **状态/健康面板**:`api/routes/status.py` + 前端 `StatusPage.tsx` — Milvus/MinIO/BM25/Reranker/会话实时状态。 +- **存储后端**:PostgreSQL / MinIO 适配器齐全;JSON 与 Postgres 双后端可切换。 +- **前端**:React 19 + Vite + Tailwind,6 个页面(Overview/Status/Perception/Docs/Compliance/RagChat)。 + +### 1.2 愿景已规划但代码缺失或为 mock + +| 能力 | 愿景出处 | 代码现状 | +|------|---------|---------| +| 知识图谱 / Neo4j 多跳推理 | 架构图 L4/L5、Slide 5 | 全代码 0 处 neo4j/graph | +| 法规感知自动更新闭环 | 01_Architecture.html L157-193、Slide 11 | `PerceptionService` 喂 `MockEventStore`(20 条死数据) | +| 认证 / RBAC / 审计日志 | Slide 12 四角色权限矩阵 | 全代码 0 处 auth/jwt/rbac;`main.py` CORS=`*` | +| 异步任务 / Worker 集群 | 架构图"Worker 集群"、Slide 9 | `app/workers/` 空目录;处理全同步 | +| EHS 隐患识别(SIF/四维根因) | Slide 7 | 未实现 | +| 多渠道推送(Email/Teams/飞书) | Slide 8 | 未实现 | +| 闭环整改跟踪、可观测性 | 架构图右栏 | 缺失 | + +### 1.3 关键发现 + +- **`requirements.txt:28` 已有 `celery>=5.3.0` + `redis>=4.5.0`**,`docker-compose.yml` 已配 Redis 7,`settings.py` 已有 redis 配置 —— **异步化是"接线",不是"从零搭建"**。 +- **`DocumentProcessingStore` 已能记录 run 状态/状态事件** —— 是天然的任务进度表。 +- **`PerceptionService.analyze_event` 的 LLM 影响分析与 RAG 关联检索是真的** —— 感知闭环缺的只是前半段(采集 → Diff → 入库)。 +- 后端正处于 legacy 迁移期:`services/*`、`workflows/*` 为兼容层(见 `docs/architecture/backend-project-architecture.md`)。 + +--- + +## 2. 全景机会清单 + +类型标记:`[新能力]`=愿景缺口补齐,`[加固]`=已实现能力优化。价值 ★(1-5),工作量 S/M/L。 + +### P0 — 生产地基(阻断"走向生产"的硬伤) + +| # | 机会点 | 类型 | 现状证据 | 价值 | 工作量 | +|---|--------|------|---------|------|--------| +| 1 | 异步任务化(Celery + 已配 Redis):解析/嵌入/感知/推送下沉 worker | 加固 | `workers/` 空;`documents.py:34` 上传同步阻塞 | ★★★★★ | L | +| 2 | 认证 + RBAC + 审计日志,收紧 CORS | 新能力 | 0 处 auth;`main.py` CORS=`*`;Slide 12 | ★★★★★ | M | +| 3 | 会话 & 任务持久化(内存 → Redis/PG) | 加固 | `bootstrap.py:254` 内存会话;`compliance.py:25` 内存字典 | ★★★★ | M | +| 4 | 基础可观测性(Prometheus + 结构化日志 + 追踪) | 加固 | 仅 loguru;架构图右栏全缺 | ★★★ | M | + +### P1 — 高价值能力补齐 + RAG 质量 + +| # | 机会点 | 类型 | 现状证据 | 价值 | 工作量 | +|---|--------|------|---------|------|--------| +| 5 | 启用并升级 Reranker(`bge-reranker-v2.5-gemma2-lightweight`) | 加固 | `settings.py:113` 默认关;管线已写好 | ★★★★ | S | +| 6 | Agentic 检索(查询改写/意图理解/多路召回) | 加固 | `agent/services.py` 直接 retrieve,无 rewrite/HyDE | ★★★★ | M | +| 7 | 知识图谱 / GraphRAG(Neo4j + LightRAG v1.5) | 新能力 | 0 处 neo4j;LightRAG v1.5 原生支持 | ★★★★★ | L | +| 8 | 法规感知自动更新闭环(真实采集 + 版本 Diff + 增量重索引) | 新能力 | `perception/services.py` 用 MockEventStore | ★★★★★ | L | +| 9 | 引文置信度评分(Slide 5 承诺"置信度评分+页码溯源") | 加固 | `rag.py` sources 无 confidence | ★★★ | S | +| 10 | 检索评估 harness(recall@k / faithfulness) | 加固 | `tests/` 需真实服务,无离线 RAG 评估 | ★★★ | M | + +### P2 — 视野扩展(独立子项目) + +| # | 机会点 | 类型 | 价值 | 工作量 | +|---|--------|------|------|--------| +| 11 | EHS 隐患识别(SIF 评分 + 四维根因 + ISO 45001 扫描,Slide 7) | 新能力 | ★★★★ | L | +| 12 | 多渠道推送 + 订阅规则引擎(Email/Teams/飞书,Slide 8) | 新能力 | ★★★ | M | +| 13 | 闭环整改跟踪(任务派发 → 进度 → 验收归档) | 新能力 | ★★★ | M | +| 14 | 企业系统集成(PLM/ERP/OA/MES Webhook) | 新能力 | ★★ | L | +| 15 | MinerU 3.1 升级(已转 Apache 协议,VLM 解析)作本地兜底 | 加固 | ★★ | S | +| 16 | 前端加固(清 mock 数据、补 error/loading 态、KG 可视化、登录态) | 加固 | ★★★ | M | +| 17 | 收口 legacy 迁移(`services/*`、`workflows/*` 按架构文档归位) | 加固 | ★★ | M | + +--- + +## 3. 深入设计 ① — 异步任务化 + +### 3.1 问题 + +`upload_document`(`api/routes/documents.py:34`)在单个 HTTP 请求内同步跑完 存储 → 解析(阿里云云端可达 900 秒,`settings.py:49`)→ 嵌入 → Milvus 入库。大体量 GB 标准必然超时;`compliance.py` 的 `/analyze` 为假异步(立即返回 mock);perception 爬取闭环无执行载体。PPT Slide 9 已将"大文件性能"列为关键挑战,对策正是"流式处理 + 异步队列 + 实时进度"。 + +### 3.2 关键前提:基建已就位 + +- `requirements.txt:28` 已含 `celery>=5.3.0` + `redis>=4.5.0` +- `docker-compose.yml:46` Redis 7 已配置;`settings.py:64` 已有 redis 连接配置 +- `PostgresDocumentProcessingStore` 已记录 run 状态/状态事件 —— 天然任务进度表 +- `app/workers/` 为空目录(唯一缺口) + +### 3.3 架构(遵循 AGENTS.md 的 `api → application → domain ports → infrastructure`) + +``` +api/routes/documents.py POST /upload + │ 1. 存二进制 + 建 Document 记录(快,同步) + │ 2. enqueue task → 立即返回 {doc_id, status:"queued", run_id} + ▼ +infrastructure/tasks/ ← 新增 + celery_app.py broker=redis, backend=redis + document_tasks.py @task process_document(doc_id) → DocumentCommandService + │ 复用现有 upload_and_process 的 parse→embed→index 段 + ▼ +application/documents/services.py(拆分:store 与 process 解耦) + │ 每阶段写 DocumentProcessingStore(已存在)→ 进度可查 + ▼ +api/routes/documents.py GET /status/{doc_id} ← 已存在,读 run 状态即可 +``` + +### 3.4 落地步骤(增量、不破坏现有同步路径) + +1. 新增 `infrastructure/tasks/celery_app.py` — Celery 实例,broker/backend 指向已配 Redis。 +2. 拆分 `upload_and_process` → `store_document`(同步快)+ `process_document`(可异步),复用现有逻辑,零重写解析/嵌入代码。 +3. 新增 `document_tasks.py` — `@celery_app.task` 包裹 `process_document`,失败用 `tenacity`(已在 deps)重试 + 死信。 +4. 改 `documents.py` 上传 — 默认入队(保留 `?sync=true` 同步回退便于演示);`GET /status/{doc_id}` 读 `DocumentProcessingStore` 返回阶段进度。 +5. 前端 `DocsPage.tsx` — 上传后轮询/SSE 进度条(架构图 Worker"心跳/状态上报"已是既定设计)。 +6. `dev.sh`/`dev.bat` 加 worker 启动:`celery -A app.infrastructure.tasks.celery_app worker`。 + +### 3.5 工作量与风险 + +- **M(中),3-5 天。** +- 最大风险:Celery worker 进程内 `PYTHONPATH=backend` 与 bootstrap `lru_cache` 单例需重新初始化 —— 可控,因 bootstrap 已是懒加载。 +- YAGNI 边界:本期仅异步化"文档处理"一条链;compliance/perception 复用同一 Celery 基建后续接入。 + +--- + +## 4. 深入设计 ② — 法规感知自动更新闭环 + +### 4.1 问题 + +感知闭环是愿景旗舰能力(`01_Architecture.html` L157-193、Slide 11)。现状:`PerceptionService` 喂 `MockEventStore`(`mock_event_store.py:7`,20 条手写死数据),`list_events`/`stats` 全静态,`source_url` 真实但从不访问。**LLM 影响分析与 RAG 关联检索是真的** —— 闭环缺的是前半段:真实采集 → 变更感知(Diff)→ 入库。 + +### 4.2 六步现状对照 + +| 步骤 | 愿景设计 | 现状 | 本期目标 | +|------|---------|------|---------| +| ① 法规源监控 | 定时爬国标网/MIIT/UN-ECE/EUR-Lex | ❌ 无 | ✅ 适配器+定时 | +| ② 智能变更感知 | NLP 比对新旧版本 Diff | ❌ 无 | ✅ 内容指纹+LLM Diff | +| ③ 自动解析入库 | MinerU→分块→BGE-M3→Milvus | ✅ 已有(复用设计①管线) | ✅ 接线 | +| ④ 知识图谱更新 | Neo4j 关系同步 | ❌ 无 | ⏭️ 本期不做(归 GraphRAG 专项) | +| ⑤ 差距分析&推送 | AI 比对+按角色推送 | 🟡 analyze_event 已有分析,无推送 | 🟡 分析复用,推送下期 | +| ⑥ 触发整改闭环 | 整改任务跟踪 | ❌ 无 | ⏭️ 下期 | + +本期聚焦 ①②③,复用设计①异步管线与已有解析/嵌入/检索/分析能力。 + +### 4.3 架构(端口与适配器) + +``` +domain/perception/ports.py ← 新增 + RegulationSource (Protocol) fetch_latest() → list[RawRegulation] + EventStore (Protocol) 抽象掉 MockEventStore(现有 mock 成为一个实现) + ChangeDetector (Protocol) diff(old, new) → ChangeSet + +infrastructure/perception/ + sources/ ← 新增,每法规源一个适配器 + gb_openstd_source.py 国标网 (openstd.samr.gov.cn) + miit_source.py 工信部 + base_html_source.py 通用 HTML 抓取基类(httpx 已在 deps) + postgres_event_store.py ← 替换 MockEventStore(真实持久化) + content_fingerprint_detector.py 哈希指纹 + LLM 语义 Diff + +application/perception/services.py(扩展现有) + ingest_cycle() ← 新增:①抓取 → ②Diff → ③入队解析(设计①的 task) + (list_events/analyze_event 保持不变,已是真实逻辑) + +infrastructure/tasks/perception_tasks.py ← 复用设计①的 Celery + @task perception_crawl_cycle() Celery Beat 定时触发 +``` + +### 4.4 关键设计决策 + +1. **接口契约零改动**:`PostgresEventStore` 输出与 `MockEventStore` 完全相同的 dict 结构(mock_event_store.py 的 20 字段),故 `perception.ts` 前端契约、`PerceptionPage.tsx`、`analyze_event` 全部不改。Mock 退化为种子数据/演示回退,通过 `perception_event_store=mock|postgres` 开关切换(对齐现有 `document_repository_backend` 模式)。 +2. **变更感知分两层**:廉价层(内容哈希指纹判断"是否变了")+ 智能层(变了才调 LLM 做"新增/修订/废止条款"结构化 Diff,复用 `get_llm_client`,prompt 风格照搬 `compliance/pipeline.py::_extract_json`)。 +3. **合规防滥用**:尊重 `robots.txt` + 限速 + `tenacity` 重试 + 抓取失败不污染已有数据;适配器隔离,单源故障不影响其它。 +4. **入库复用设计①**:抓到新法规 PDF → 丢进 `process_document` task → 自动走完解析/嵌入/索引。 + +### 4.5 落地步骤 + +1. 抽 `domain/perception/ports.py`,让现有 `MockEventStore` 实现 `EventStore` 协议(纯重构,行为不变)。 +2. `PostgresEventStore` + 建表(参照 `aliyun_parser/schema.sql` 风格)+ 20 条 mock 作 seed。 +3. 先做 1 个真实源适配器(建议国标网,结构最稳)跑通 ①→②→③,验证端到端。 +4. `content_fingerprint_detector` + LLM Diff。 +5. `perception_crawl_cycle` Celery Beat 定时(每日);新事件落 PostgresEventStore + 新法规入队解析。 +6. 前端 `PerceptionPage` 加"最近同步时间/本次新增 N 条"(stats 已有结构,加 2 字段)。 + +### 4.6 工作量与风险 + +- **L(大),5-8 天**,依赖设计①先落地(共用 Celery)。 +- 最大风险:外部源站不可控(改版/反爬)。缓解:适配器隔离 + mock 永久保留为回退 + 先攻 1 个源验证(对齐 Slide 13"选取 2-3 个场景 POC 验证")。 +- YAGNI 边界:④Neo4j 图谱、⑥整改闭环、多渠道推送本期不做,各自独立子项目。 + +--- + +## 5. 三阶段实施路线图 + +### 5.1 核心主线 + +项目不缺"能力点",缺的是**让能力点从同步脚本变成可运营的系统**。主线是**异步化基建**:既是文档处理性能解药(设计①),又是感知闭环执行载体(设计②),也是未来 EHS/推送的统一底座。路线图以它为"第 0 块地基",其余能力挂载其上。 + +### 5.2 与 PPT 三阶段映射(Slide 10) + +``` +PPT 规划 代码现状 本路线图补齐 +───────────────────────────────────────────────────── +一阶段 知识库+基础问答 ✅ 大体已实现 → 加固 (P0/P1) +二阶段 文档审查+API集成 🟡 审查真/API半 → 异步化+感知闭环 +三阶段 EHS+个性化+图谱 ❌ 基本缺失 → 子项目 (P2) +``` + +### 5.3 阶段一 · 生产地基(2-3 周)— "让它扛得住生产" + +| 顺序 | 事项 | 依据 | 估时 | +|------|------|------|------| +| 1 | 设计① 异步任务化 | celery/redis 已在 deps,workers/ 空 | M, 3-5d | +| 2 | 认证 + RBAC + 审计 + 收紧 CORS | 0 处 auth;Slide 12 矩阵 | M, 3-5d | +| 3 | 会话/任务持久化(内存 → Redis/PG) | InMemoryConversationStore 重启即丢 | M, 2-3d | +| 4 | 快赢:启用 Reranker | settings 默认关,管线已写好 | S, 0.5d | + +### 5.4 阶段二 · 招牌能力(2-3 周)— "让它有亮点" + +建议**感知闭环优先于图谱**(前者复用阶段一异步基建,ROI 更高)。 + +| 顺序 | 事项 | 依据 | 估时 | +|------|------|------|------| +| 5 | 设计② 法规感知闭环 ①②③ | MockEventStore → 真实采集 | L, 5-8d | +| 6 | Agentic 检索(查询改写/意图理解) | Slide 5"意图理解",代码是直检索 | M, 3-4d | +| 7 | 引文置信度评分 + 基础可观测性 | Slide 5 承诺;架构图右栏全缺 | S+M, 3-4d | + +### 5.5 阶段三 · 视野扩展(按需,各为独立子项目)— "让它成体系" + +每项单独 brainstorm → spec → 实施,本期不细化: + +- 知识图谱 / GraphRAG(Neo4j + LightRAG v1.5,接感知闭环第④步) +- EHS 隐患识别(SIF + 四维根因,Slide 7) +- 多渠道推送 + 订阅规则引擎(Slide 8)→ 闭环整改跟踪(第⑤⑥步) +- 持续加固:MinerU 3.1 升级、前端清 mock、legacy 收口 + +### 5.6 决策建议 + +1. 强烈建议按阶段顺序:地基 → 招牌 → 扩展。跳过地基直接做招牌,会在生产暴露超时/无鉴权/数据丢失。 +2. 阶段一第 4 项(Reranker)可立即做 —— 半天见效,与其它解耦,适合先尝甜头。 +3. 阶段二二选一先行:要 demo 冲击力选"感知闭环";要问答质量选"Agentic 检索"。 + +--- + +## 6. 最新 AI 技术调研(支撑选型) + +| 技术 | 版本/状态(2026) | 对应机会点 | +|------|------------------|-----------| +| LightRAG | v1.5.0(2026-06),EMNLP 2025;KG-RAG,原生支持 Neo4j + MinerU/Docling,含 Web UI 图谱可视化 | #7 知识图谱 | +| MinerU | v3.1.0(2026-04),协议转为 Apache 2.0 基础的开源协议,VLM 解析(MinerU2.5-Pro),109 语言 OCR | #15 本地解析兜底 | +| BGE Reranker | `bge-reranker-v2.5-gemma2-lightweight`(token 压缩 + 分层轻量化,生产推荐) | #5 Reranker 升级 | +| BGE-M3 | 100+ 语言,8192 上下文,dense+sparse+colbert 统一(现已在用) | 现有嵌入 | +| RAGFlow | 2026 支持 DeepSeek v4 / MCP / 跨语言查询;agentic RAG 参考实现 | #6 Agentic 检索参考 | + +--- + +## 7. 验收与边界 + +### 7.1 本文档明确不做(YAGNI) + +- 阶段三所有子项目(图谱、EHS、推送、整改闭环、企业集成)仅列方向,不在本期展开。 +- 移动端适配(AGENTS.md 明确 desktop-first)。 +- 感知闭环的第④⑤⑥步(图谱同步、推送、整改)。 + +### 7.2 架构约束(必须遵守) + +- 后端遵循 `api → application → domain ports → infrastructure`(`docs/architecture/backend-project-architecture.md` 为权威)。 +- 新业务逻辑不得落入 `services/*`、`workflows/*`(legacy 迁移区)。 +- `shared/bootstrap.py` 为依赖装配 composition root,新依赖在此接线。 +- 后端注释/docstring 全英文(AGENTS.md 规范)。 + +### 7.3 下一步 + +经用户审阅本 spec 后,对**阶段一**(异步任务化优先)调用 writing-plans 拆分为分步实施计划。 diff --git a/docs_dump.txt b/docs_dump.txt new file mode 100644 index 0000000..74dcff9 --- /dev/null +++ b/docs_dump.txt @@ -0,0 +1,108 @@ +===== PPTX ===== +--- Slide 1 ---: AI + 法律法规 | 合规智能中枢 | 面向车企与工厂的 AI 驱动合规解决方案 | 2026年4月 | EMS & EHS Compliance Intelligence Hub | AI Compliance +Intelligence Hub | Internal | AI 合规智能中枢 | 2026.04 +--- Slide 2 ---: 背景与挑战 | 车企和工厂面临的合规困境 | 法规来源复杂 | 国标GB · MIIT · UN-ECE +IATF 16949 · ISO 45001 +多轨并行,难以统管 | 更新频率高 | 新能源 · 数据安全 · 碳排放 +PIPL · NEV积分 · CCER +政策持续迭代 | 跨语言需求 | 中英文法规混存 +跨国工厂多语言 +合规场景并存 | 文档高度分散 | 分散于 Confluence +SharePoint · ERP · PLM +无法联通查询 | 隐患识别被动 | EHS 安全依赖人工 +隐患发现滞后 +缺乏预防性机制 | 覆盖核心法规域 | 🚗 车辆安全 GB 7258 · GB 18384 · UN-ECE R155/156 | 🔒 数据安全 PIPL · DSL · GB/T 35273 | 🏭 工厂EHS GB 6441 · AQ/T系列 · ISO 45001 | ♻️ 碳排放 NEV积分 · CCER · 欧盟碳边境税 | ✅ 质量管理 IATF 16949 · GB/T 19001 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 2 / 13 +--- Slide 3 ---: 产品定位与整体架构 | AI 驱动的全链路合规智能平台 | AI 合规智能中枢 | 📚 知识库构建 | 内外部法规 · 历史案例 +统一知识图谱 · 自动更新 | 💬 智能问答 | 混合检索 · 语义+关键词 +中英双语 · 引文溯源 | 📄 合规审查 | PDF/Word上传 +自动比对法规 · 风险标注 | 🔌 API集成 | 对接PLM · ERP · OA · MES | 🎯 个性化推荐 | 角色画像 · 上下文感知 | 📢 定制推送 | Email · Teams +飞书 · 钉钉 +法规变更 +实时通知 | 🦺 EHS 隐患识别 & 管理体系审计(C-SG专项) | 事故报告 NLP | SIF潜力识别 | 四维根因分析 | ISO 45001 要素扫描 | 自动生成审计报告 | 趋势分析仪表板 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 3 / 13 +--- Slide 4 ---: 功能一:合规知识库构建与动态更新 | 统一接入内外部法规,构建可检索的结构化知识库 | 📥 数据来源 | 内部文档 | Confluence · SharePoint +飞书 · 历史合规报告 · 审计记录 | ↓ | 外部法规 | 国标全文库 · 工标网 +MIIT政策 · UN-ECE · EUR-Lex | ↓ | 历史案例 | 处罚案例库 · 整改记录 +行业事故通报 | ⚙️ 处理流程 | 1 | ① 文档解析 | 版面感知OCR,扫描件 · PDF表格 · 多栏 · Word/Excel | ↓ | 2 | ② 智能分块 | 章节级 / 条款级双粒度切割,保留语义完整性 | ↓ | 3 | ③ 向量化存储 | 多语言嵌入(中英双语),向量库 + 关键词索引双轨 | ↓ | 4 | ④ 知识图谱 | 法规实体 → 条款 → 义务 → 适用范围关系图谱 | ↓ | 5 | ⑤ 自动更新 | 定时监控法规变更,触发增量重索引 + 版本管理 | ✨ 核心价值 | 数据不出厂 | 私有化本地部署 +满足PIPL/DSL数据主权 | 权限分级管理 | 研发/生产/采购/法务 +差异化访问控制 | 实时保鲜 | 法规修订自动触发重索引 +确保知识时效性 | 多格式支持 | 扫描件 · PDF · Word +Excel · 标准文件全覆盖 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 4 / 13 +--- Slide 5 ---: 功能二:混合检索智能问答引擎 | 语义检索 + 关键词检索 + 知识图谱,生成可溯源的合规决策建议 | 用户提问 | 中 / 英 / 混合 +自然语言输入 | ▶ | 意图理解 | 识别法规实体 +适用场景 · 地域 | ▶ | 混合检索 | BM25关键词 ++ 语义向量 +本地+网络双路 | ▶ | 重排序 | Cross-Encoder +精排召回结果 | ▶ | 生成回答 | 引文锚定输出 +置信度评分 +页码溯源 | 典型问答场景 | 法规解读 | "我们的纯电SUV需满足哪些GB强制认证要求?" | 政策查询 | "2025年NEV积分核算方式有哪些最新变化?" | 合规判断 | "供应商A的REACH声明是否满足我司采购合规要求?" | 多跳推理 | "ISO 45001变更管理要求,对应哪些内部流程需更新?" | 对比分析 | "GB 18384与欧盟ECE R100在电池安全上有哪些差异?" | 📎 引文溯源 | 答案标注原文出处 +页码精确定位 | 🌐 多语言支持 | 中英混合检索 +无需切换语言 | ⚖️ 决策辅助 | 结合内部制度 +输出综合建议 | 🔄 图谱增强 | 关联上下游条款 +多跳推理支持 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 5 / 13 +--- Slide 6 ---: 功能三:智能文档合规审查 | 上传 PDF/Word,自动比对法规库,标注风险并给出整改建议 | ⚙️ 审查流程 | 1 | ① 文件上传 | PDF · Word · Excel · 扫描件,支持批量 | ↓ | 2 | ② 文档解析 | 版面感知OCR,段落/条款级分块 | ↓ | 3 | ③ 法规域匹配 | 根据文档类型+内容自动识别适用法规域 | ↓ | 4 | ④ 合规比对 | 条款级语义对比,缺项检测 · 风险评分 | ↓ | 5 | ⑤ 报告输出 | 非合规位置标注,整改建议 · 风险等级 | 📋 报告输出内容 | 📍 | 非合规位置标注 | 页码 + 段落高亮,一键跳转原文 | ⚠️ | 风险等级分级 | 红(高危)/ 橙(中)/ 黄(低)三级 | 📖 | 法规条款引用 | 精确关联对应法规原文条款编号 | 🔧 | 整改建议 | 基于历史合规案例,给出可执行方案 | 📂 适用文档类型 | 供应商合规声明 | REACH/RoHS · 碳足迹申报 | 新产品EHS评估 | GB安全标准覆盖完整性核查 | 工厂安全作业规程 | AQ/T符合性 · 许可条款 | 劳动合同/协议 | 劳动法 · 工时 · 竞业条款 | 数据处理协议 | PIPL/GDPR 数据主体权利 | 供应链碳申报 | CCER/CBAM 核算方法验证 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 6 / 13 +--- Slide 7 ---: EHS 隐患识别 & 管理体系审计(C-SG专项) | AI驱动的主动安全预防:从被动响应到预测性干预 | 📥 数据输入 | 📝 事故/事件报告文本 · 巡检记录 · 安全观察卡 | 📊 设备运行数据 · 工伤统计 · 隐患整改台账 | 📷 现场照片(目标检测)· 视频(行为分析,可选) | 🤖 AI隐患识别引擎 | NLP文本分析 | 从叙述性文本中提取隐患实体 +触发因素 · 伤害类型 · 位置信息 | SIF风险评分 | 高严重性事件潜力预测 +优先处置最高风险隐患 | 四维根因分析 | 人因/设备/管理/环境 +系统性根因挖掘 | 法规自动关联 | 与GB 6441/AQ系列/ISO 45001 +自动映射对应条款 | 📋 体系审计功能 | ✓ ISO 45001要素覆盖度扫描(PDCA完整性) | ✓ 历史案例相似度匹配与经验复用 | ✓ 整改优先级排序(风险×紧迫×可行性) | ✓ 审计报告自动生成(条款级评分) | ⚠️ 典型隐患场景 | ▸ 高处坠落 | AQ/T 3049 | ▸ 有限空间 | AQ 3028 | ▸ 化学品管理 | GB 13690 | ▸ 设备点检 | IATF §8.5 | ▸ 应急演练 | ISO 45001 §8.2 | 📤 输出成果 | 隐患清单 | 位置 · 类型 +风险等级 +法规依据 +整改建议 | 体系审计 +报告 | 条款级符合 +性评分 +整改优先级 | 趋势分析 +仪表板 | 隐患热图 +月度趋势 +部门对比 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 7 / 13 +--- Slide 8 ---: 系统集成 · 个性化推荐 · 定制推送 | 合规能力 API 化,主动触达用户,融入业务流程 | 🔌 合规审查 API 化 | POST | /compliance/check | 大文本分片合规检查 | POST | /compliance/upload | PDF/Word文件上传审查 | GET | /compliance/query | 法规知识库问答 | POST | /compliance/subscribe | 法规变更Webhook订阅 | 🔗 企业系统集成 | PLM | 新产品立项/BOM变更 | → 自动触发法规适用性检查 | ERP | 供应商准入/合同签署 | → 供应商自动合规评分 | OA | 合同/协议提交审批 | → 高风险自动抄送法务 | MES | 生产工艺变更 | → 触发EHS合规影响评估 | 🎯 个性化推荐 | 👤 角色画像:EHS · 法务 · 采购 · 研发 | 💡 上下文感知:对话主题 → 关联法规推荐 | 🔔 到期提醒:认证到期 · 法规更新预警 | 📈 行为学习:历史查询 → 智能问题推荐 | 📢 定制化法规推送 | 📧 Email | HTML富文本,含变更对比 | 💬 Teams | 企业Bot,实时推送 | 📱 飞书/钉钉 | 企业机器人,移动端 | 🔔 站内消息 | 系统内通知中心 | ⚙️ 推送规则引擎 | ▸ 订阅维度: | 按法规域 / 业务场景 / 地域灵活订阅 | ▸ 优先级: | 🔴 强制 🟠 推荐 🔵 参考 三级分类 | ▸ 免打扰: | 工作时间推送 · 摘要合并 · 频率上限 | ▸ 内容生成: | LLM自动生成变更摘要 + 影响分析 + 行动项 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 8 / 13 +--- Slide 9 ---: 关键挑战与应对策略 | 确保合规建议的准确性、时效性与数据安全 | LLM幻觉风险 | 问题 | 合规建议失真 +可能导致法律责任 | 应对 | 引文锚定 + 输出验证 +高风险强制人工审核 | 数据主权 | 问题 | 敏感文件不能 +上传公有云 | 应对 | 全链路私有化部署 +数据不出厂 | 法规时效性 | 问题 | 知识库滞后 +导致错误建议 | 应对 | 自动更新机制 +时间戳标注 + 提醒 | 跨语言质量 | 问题 | 中英混合场景 +检索精度下降 | 应对 | 多语言嵌入模型 +语言标签过滤策略 | 大文件性能 | 问题 | GB标准数百页 +处理超时风险 | 应对 | 流式处理 + 分层索引 +异步队列实时进度 | 权限管控 | 问题 | 不同角色需 +不同密级访问 | 应对 | RBAC权限体系 +知识库分区 + 审计日志 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 9 / 13 +--- Slide 10 ---: 分阶段实施路线 | 从核心知识库到全链路合规智能,稳步落地 | 第一阶段 | 0 - 3 个月 | 知识库 + 基础问答 | 1 | 部署合规知识库平台,接入内部文档 | 2 | 接入 GB 标准 · AQ 系列 · IATF 16949 | 3 | 上线中英双语混合检索问答界面 | 4 | 完成权限分级与数据安全配置 | 第二阶段 | 3 - 6 个月 | 文档审查 + API 集成 | 1 | 构建文档合规审查引擎(PDF/Word) | 2 | 完成合规 API 封装,对接PLM/ERP/OA | 3 | 上线法规变更监控与推送服务 | 4 | 接入 Teams / 飞书 Bot 推送渠道 | 第三阶段 | 6 - 12 个月 | EHS隐患识别 + 个性化 | 1 | 构建 EHS 隐患识别与体系审计模块 | 2 | 引入知识图谱,支持多跳推理 | 3 | 上线个性化推荐引擎(角色画像) | 4 | 全链路合规智能体系正式上线 | ▶ | ▶ | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 10 / 13 +--- Slide 11 ---: 三类合规闭环场景 | 从「发现问题」到「关闭归档」的完整业务闭环 | 📡 法规变更合规闭环 | 1 | ① | 法规监控 | ( | 扩展功能 | ) | 国内外法规数据库实时监控 +自动检测条款变更与新法发布 | ↓ | 2 | ② 知识库更新 | 变更内容自动解析入库 +版本管理 + 影响范围标注 | ↓ | 3 | ③ 精准推送 | 按角色/业务域推送变更摘要 +Email · Teams · 飞书多渠道 | ↓ | 4 | ④ 差距分析 | AI对比新旧法规差异 +识别企业现行制度缺口 | ↓ | 5 | ⑤ | 整改执行 | ( | 扩展功能 | ) | 生成整改任务清单 +关联责任人与完成时限 | ↓ | 6 | ⑥ | 闭环归档 | ( | 扩展功能 | ) | 整改完成后验收确认 | + | 合规证据归档留存 | 归档的文档放在 | share point | 同步更新知识库 | ↺ 持续监控 → 知识库保鲜 → 合规常态化 | 📄 文档审查合规闭环 | 1 | ① 文件上传 | PDF · Word · Excel · 扫描件 +支持批量上传与拖拽 | ↓ | 2 | ② AI解析 | 版面感知OCR,条款级分块 +自动识别文档类型与法规域 | ↓ | 3 | ③ 合规比对 | 条款级语义对比法规库 +缺项检测 · 风险评分 | ↓ | 4 | ④ 风险标注 | 页码+段落精确定位 +红/橙/黄三级风险可视化 | ↓ | 5 | ⑤ 整改建议 | AI生成具体整改方案 +关联历史合规最佳实践 | ↓ | 6 | ⑥ 复审归档 | 整改后重新提交复核 | + | 通过后合规证明自动归档 | 归档的文档放在 | share point | 同步更新知识库 | ↺ 上传即审查 → 整改即跟踪 → 归档即留证 | 🦺 | EHS安全管理闭环 | ( | 扩展功能 | ) | 1 | ① 隐患发现 | NLP解析巡检/事故报告文本 +图像识别 · 传感器数据接入 | ↓ | 2 | ② 风险评级 | SIF潜力评分 + 四维根因分析 +高/中/低三级优先级排序 | ↓ | 3 | ③ 任务派发 | 自动生成整改工单 +关联责任人 · 截止时间 · 法规依据 | ↓ | 4 | ④ 过程跟踪 | 整改进度实时可视化 +超期自动升级提醒 | ↓ | 5 | ⑤ 验收关闭 | 整改完成后现场复查 +AI辅助验收确认 | ↓ | 6 | ⑥ 体系优化 | 根因数据回流知识库 +优化隐患模型与预防策略 | ↺ 发现即评级 → 整改即跟踪 → 关闭即优化 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 11 / 13 +--- Slide 12 ---: 组织架构与 | RBAC | 权限体系 | 按角色分级授权,确保数据安全与合规责任落实到人 | 🏢 组织架构层级 | 集团 / 总部 | 合规委员会 · 法务部 · EHS总监 | ▼ | 事业部 / 工厂 | EHS部门 · 质量部 · 采购部 · 研发部 | ▼ | 业务线 / 车间 | 安全员 · 质检员 · 工艺工程师 | ▼ | 外部协作方 | 供应商 · 第三方审计 · 监管机构 | 🔐 角色权限矩阵(RBAC) | 知识库 +查询 | 文档 +审查 | EHS +审计 | 法规 +推送 | 系统 +管理 | 合规管理员 | ● | ● | ● | ★ | ★ | 法务专员 | ● | ● | ◑ | ◑ | ○ | EHS工程师 | ● | ◑ | ● | ◑ | ○ | 采购专员 | ◑ | ● | ○ | ◑ | ○ | 研发工程师 | ◑ | ◑ | ○ | ◑ | ○ | 工厂安全员 | ◑ | ○ | ● | ◑ | ○ | 供应商(外部) | ○ | ◑ | ○ | ○ | ○ | ● 完全权限 | ◑ 只读/有限 | ○ 无权限 | ★ 管理权限 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 12 / 13 +--- Slide 13 ---: 总结与下一步行动 | 构建面向车企与工厂的 AI 驱动全链路合规智能体系 | 📚 | 知识统一 | 内外部法规 + 历史案例 +一库统管,自动更新 | 💬 | 智能问答 | 混合检索 + 知识图谱 +可溯源的决策建议 | 📄 | 合规审查 | AI自动比对标注 +风险等级 + 整改建议 | 🦺 | EHS防控 | SIF预测 + 体系审计 +被动响应到主动预防 | 🔌 | 无缝集成 | API化能力嵌入 +PLM · ERP · OA · MES | 建议下一步行动 | 01 | 需求确认 | 与EHS · 法务 · 采购 +核心用户开展访谈 | 02 | POC验证 | 选取2-3个场景快速 +搭建原型验证可行性 | 03 | 数据准备 | 梳理内部文档,确认 +数据分级与权限策略 | 04 | 架构评审 | 与IT安全团队确认 +私有化部署与集成规范 | AI 合规智能中枢 | 面向车企与工厂 | 2026.04 | 13 / 13 + +===== DOCX ===== diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index c261608..d6f8d06 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,11 +1,13 @@ import './styles/globals.css'; -import { ThemeProvider } from './contexts'; +import { ThemeProvider, AuthProvider } from './contexts'; import { AppRouter } from './router/AppRouter'; function App() { return ( - + + + ); } diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts new file mode 100644 index 0000000..3391d09 --- /dev/null +++ b/frontend/src/api/auth.ts @@ -0,0 +1,43 @@ +const AUTH_API_BASE = '/api/v1'; + +export interface TokenResponse { + access_token: string; + token_type: string; + expires_in: number; +} + +export interface MeResponse { + user_id: string; + username: string; + role: string; +} + +export async function loginRequest( + username: string, + password: string, +): Promise { + const body = new URLSearchParams(); + body.set('username', username); + body.set('password', password); + + const res = await fetch(`${AUTH_API_BASE}/auth/token`, { + method: 'POST', + headers: { 'Content-Type': 'application/x-www-form-urlencoded' }, + body: body.toString(), + }); + + if (!res.ok) { + const payload = await res.json().catch(() => ({})) as { detail?: string }; + throw new Error(payload.detail ?? `Login failed (${res.status})`); + } + + return res.json() as Promise; +} + +export async function getMeRequest(token: string): Promise { + const res = await fetch(`${AUTH_API_BASE}/auth/me`, { + headers: { Authorization: `Bearer ${token}` }, + }); + if (!res.ok) throw new Error(`Unauthorised (${res.status})`); + return res.json() as Promise; +} diff --git a/frontend/src/api/docs.ts b/frontend/src/api/docs.ts index 85a3219..296328d 100644 --- a/frontend/src/api/docs.ts +++ b/frontend/src/api/docs.ts @@ -1,6 +1,12 @@ import type { DocInfo, DocListResponse, DocUploadResponse } from './index'; import { API_BASE_URL } from './index'; +const TOKEN_KEY = 'auth_token'; +function authHeaders(extra?: Record): Record { + const token = localStorage.getItem(TOKEN_KEY); + return token ? { Authorization: `Bearer ${token}`, ...extra } : { ...extra }; +} + interface BackendDocumentItem { doc_id: string; doc_name: string; @@ -76,6 +82,7 @@ export async function uploadDocument( const response = await fetch(`${API_BASE_URL}/documents/upload`, { method: 'POST', + headers: authHeaders(), body: formData, }); @@ -95,7 +102,9 @@ export async function uploadDocument( } export async function getDocumentList(): Promise { - const response = await fetch(`${API_BASE_URL}/documents/management-list`); + const response = await fetch(`${API_BASE_URL}/documents/management-list`, { + headers: authHeaders(), + }); if (!response.ok) { throw new Error(`List failed: ${response.status}`); } @@ -107,7 +116,9 @@ export async function getDocumentList(): Promise { } export async function getDocumentStatus(docId: string): Promise { - const response = await fetch(`${API_BASE_URL}/documents/status/${docId}`); + const response = await fetch(`${API_BASE_URL}/documents/status/${docId}`, { + headers: authHeaders(), + }); if (!response.ok) { throw new Error(`Status check failed: ${response.status}`); } @@ -115,14 +126,20 @@ export async function getDocumentStatus(docId: string): Promise { - const response = await fetch(`${API_BASE_URL}/documents/${docId}`, { method: 'DELETE' }); + const response = await fetch(`${API_BASE_URL}/documents/${docId}`, { + method: 'DELETE', + headers: authHeaders(), + }); if (!response.ok) { throw new Error(`Delete failed: ${response.status}`); } } export async function retryDocument(docId: string): Promise { - const response = await fetch(`${API_BASE_URL}/documents/${docId}/retry`, { method: 'POST' }); + const response = await fetch(`${API_BASE_URL}/documents/${docId}/retry`, { + method: 'POST', + headers: authHeaders(), + }); if (!response.ok) { throw new Error(`Retry failed: ${response.status}`); } @@ -132,10 +149,10 @@ export async function retryDocument(docId: string): Promise { export async function searchRegulations(query: string, topK: number = 8): Promise { const response = await fetch(`${API_BASE_URL}/knowledge/retrieval`, { method: 'POST', - headers: { + headers: authHeaders({ Accept: 'application/json', 'Content-Type': 'application/json', - }, + }), body: JSON.stringify({ query, top_k: topK }), }); diff --git a/frontend/src/api/index.ts b/frontend/src/api/index.ts index e07c536..d7bdfd3 100644 --- a/frontend/src/api/index.ts +++ b/frontend/src/api/index.ts @@ -1,5 +1,12 @@ const API_BASE_URL = '/api/v1'; +const TOKEN_KEY = 'auth_token'; + +/** Read the stored JWT without importing AuthContext (avoids circular deps). */ +function getStoredToken(): string | null { + return localStorage.getItem(TOKEN_KEY); +} + interface ApiErrorPayload { detail?: string; message?: string; @@ -19,8 +26,24 @@ async function readErrorMessage(response: Response): Promise { } } +/** Inject Authorization header when a token is available. */ +function withAuth(headers: Headers): Headers { + const token = getStoredToken(); + if (token && !headers.has('Authorization')) { + headers.set('Authorization', `Bearer ${token}`); + } + return headers; +} + +/** Handle 401 by clearing the stored token so the app redirects to login. */ +function handle401() { + localStorage.removeItem(TOKEN_KEY); + // Emit a custom event so AuthContext / router can react without a direct import. + window.dispatchEvent(new CustomEvent('auth:unauthorized')); +} + export async function fetchAPI(endpoint: string, options?: RequestInit): Promise { - const headers = new Headers(options?.headers); + const headers = withAuth(new Headers(options?.headers)); if (!headers.has('Accept')) { headers.set('Accept', 'application/json'); } @@ -33,6 +56,11 @@ export async function fetchAPI(endpoint: string, options?: RequestInit): Prom headers, }); + if (response.status === 401) { + handle401(); + throw new Error('Session expired, please log in again.'); + } + if (!response.ok) { throw new Error(`API Error: ${await readErrorMessage(response)}`); } @@ -54,15 +82,25 @@ export async function streamSSE( onError?: (error: Error) => void, onComplete?: () => void ): Promise { + const headers: Record = { + Accept: 'text/event-stream', + 'Content-Type': 'application/json', + }; + const token = getStoredToken(); + if (token) headers['Authorization'] = `Bearer ${token}`; + const response = await fetch(buildUrl(endpoint), { method: 'POST', - headers: { - Accept: 'text/event-stream', - 'Content-Type': 'application/json', - }, + headers, body: JSON.stringify(body), }); + if (response.status === 401) { + handle401(); + onError?.(new Error('Session expired, please log in again.')); + return; + } + if (!response.ok) { onError?.(new Error(`HTTP error! status: ${await readErrorMessage(response)}`)); return; diff --git a/frontend/src/api/perception.ts b/frontend/src/api/perception.ts index 800f43f..ac550bd 100644 --- a/frontend/src/api/perception.ts +++ b/frontend/src/api/perception.ts @@ -1,4 +1,9 @@ const PERCEPTION_API_BASE = '/api/v1'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} export type ImpactLevel = 'high' | 'medium' | 'low'; export type EventStatus = 'enacted' | 'draft' | 'consultation'; @@ -48,7 +53,7 @@ export interface AnalysisSSEMessage { } export async function getPerceptionStats(): Promise { - const res = await fetch(`${PERCEPTION_API_BASE}/perception/stats`); + const res = await fetch(`${PERCEPTION_API_BASE}/perception/stats`, { headers: authHeader() }); if (!res.ok) throw new Error(`stats failed: ${res.status}`); return res.json() as Promise; } @@ -62,7 +67,7 @@ export async function listEvents(params?: { if (params?.source) query.set('source', params.source); if (params?.impact_level) query.set('impact_level', params.impact_level); if (params?.limit) query.set('limit', String(params.limit)); - const res = await fetch(`${PERCEPTION_API_BASE}/perception/events?${query.toString()}`); + const res = await fetch(`${PERCEPTION_API_BASE}/perception/events?${query.toString()}`, { headers: authHeader() }); if (!res.ok) throw new Error(`list events failed: ${res.status}`); return res.json() as Promise; } @@ -76,7 +81,7 @@ export async function analyzeEvent( try { const res = await fetch(`${PERCEPTION_API_BASE}/perception/events/${eventId}/analyze`, { method: 'POST', - headers: { Accept: 'text/event-stream' }, + headers: { Accept: 'text/event-stream', ...authHeader() }, signal, }); if (!res.ok || !res.body) throw new Error(`analyze failed: ${res.status}`); diff --git a/frontend/src/api/rag.ts b/frontend/src/api/rag.ts index 3413d4e..c182268 100644 --- a/frontend/src/api/rag.ts +++ b/frontend/src/api/rag.ts @@ -1,6 +1,8 @@ import type { QuickQuestionsResponse, SSEMessage } from './index'; const AGENT_API_BASE = '/api/v1'; +const TOKEN_KEY = 'auth_token'; +function getToken(): string | null { return localStorage.getItem(TOKEN_KEY); } const _FALLBACK_QUESTIONS = [ { id: '1', question: '请总结最新入库法规对电池安全的核心要求', category: '法规解读' }, @@ -100,6 +102,7 @@ export async function ragChat( headers: { 'Content-Type': 'application/json', Accept: 'text/event-stream', + ...(getToken() ? { Authorization: `Bearer ${getToken()}` } : {}), }, body: JSON.stringify({ query, diff --git a/frontend/src/components/layout/Sidebar.tsx b/frontend/src/components/layout/Sidebar.tsx index c2389ad..e0e5c7f 100644 --- a/frontend/src/components/layout/Sidebar.tsx +++ b/frontend/src/components/layout/Sidebar.tsx @@ -1,9 +1,10 @@ import { NavLink } from 'react-router-dom'; import { LayoutDashboard, Radio, Monitor, FileText, - Shield, MessageSquare, Sun, Moon + Shield, MessageSquare, Sun, Moon, LogOut } from 'lucide-react'; import { useTheme } from '../../contexts/ThemeContext'; +import { useAuth } from '../../contexts/AuthContext'; interface NavItem { to: string; @@ -49,8 +50,17 @@ function NavGroup({ title, items }: { title: string; items: NavItem[] }) { ); } +/** Avatar initials from username (up to 2 chars). */ +function initials(name: string): string { + const parts = name.trim().split(/[\s_-]+/); + if (parts.length >= 2) return (parts[0][0] + parts[1][0]).toUpperCase(); + return name.slice(0, 2).toUpperCase(); +} + export function Sidebar() { const { theme, toggleTheme } = useTheme(); + const { user, logout } = useAuth(); + return ( ); diff --git a/frontend/src/contexts/AuthContext.tsx b/frontend/src/contexts/AuthContext.tsx new file mode 100644 index 0000000..45818d9 --- /dev/null +++ b/frontend/src/contexts/AuthContext.tsx @@ -0,0 +1,72 @@ +import React, { createContext, useCallback, useContext, useEffect, useState } from 'react'; +import { loginRequest, getMeRequest } from '../api/auth'; + +const TOKEN_KEY = 'auth_token'; + +export interface AuthUser { + user_id: string; + username: string; + role: string; +} + +interface AuthContextValue { + token: string | null; + user: AuthUser | null; + loading: boolean; + login: (username: string, password: string) => Promise; + logout: () => void; +} + +const AuthContext = createContext({ + token: null, + user: null, + loading: true, + login: async () => {}, + logout: () => {}, +}); + +export function AuthProvider({ children }: { children: React.ReactNode }) { + const [token, setToken] = useState(() => localStorage.getItem(TOKEN_KEY)); + const [user, setUser] = useState(null); + const [loading, setLoading] = useState(true); + + // Validate the stored token on mount by calling /auth/me. + useEffect(() => { + if (!token) { + setLoading(false); + return; + } + getMeRequest(token) + .then(setUser) + .catch(() => { + // Token is expired or invalid — force re-login. + localStorage.removeItem(TOKEN_KEY); + setToken(null); + }) + .finally(() => setLoading(false)); + }, []); // eslint-disable-line react-hooks/exhaustive-deps + + const login = useCallback(async (username: string, password: string) => { + const resp = await loginRequest(username, password); + const me = await getMeRequest(resp.access_token); + localStorage.setItem(TOKEN_KEY, resp.access_token); + setToken(resp.access_token); + setUser(me); + }, []); + + const logout = useCallback(() => { + localStorage.removeItem(TOKEN_KEY); + setToken(null); + setUser(null); + }, []); + + return ( + + {children} + + ); +} + +export function useAuth() { + return useContext(AuthContext); +} diff --git a/frontend/src/contexts/index.ts b/frontend/src/contexts/index.ts index d18ca7e..5267afb 100644 --- a/frontend/src/contexts/index.ts +++ b/frontend/src/contexts/index.ts @@ -1 +1,3 @@ export { ThemeProvider, useTheme } from './ThemeContext'; +export { AuthProvider, useAuth } from './AuthContext'; +export type { AuthUser } from './AuthContext'; diff --git a/frontend/src/pages/Compliance/CompliancePage.tsx b/frontend/src/pages/Compliance/CompliancePage.tsx index 976f2e1..6042790 100644 --- a/frontend/src/pages/Compliance/CompliancePage.tsx +++ b/frontend/src/pages/Compliance/CompliancePage.tsx @@ -5,6 +5,12 @@ import { NewAnalysisModal } from './NewAnalysisModal'; import { useComplianceAnalysis } from './useComplianceAnalysis'; import type { FindingEvent, SourceEvent, AnalysisMeta } from './useComplianceAnalysis'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + const STATUS_LABEL: Record = { ok: 'Covered', warn: 'Gap', risk: 'Critical', info: 'Info' }; const SOURCE_TYPE_LABEL: Record = { text: 'Pasted Text', doc: 'Indexed Document', upload: 'Uploaded File' }; @@ -71,7 +77,7 @@ function useFindingChat() { try { const res = await fetch(`/api/v1/compliance/chat/${findingIdx ?? 0}`, { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: { 'Content-Type': 'application/json', ...authHeader() }, body: JSON.stringify({ query: q, segment_context: segmentContext }), signal: ctrl.signal, }); diff --git a/frontend/src/pages/Compliance/NewAnalysisModal.tsx b/frontend/src/pages/Compliance/NewAnalysisModal.tsx index ad97e7b..b0e7bd1 100644 --- a/frontend/src/pages/Compliance/NewAnalysisModal.tsx +++ b/frontend/src/pages/Compliance/NewAnalysisModal.tsx @@ -1,6 +1,12 @@ import { useState, useRef, useEffect } from 'react'; import { X, Upload, FileText, Database } from 'lucide-react'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + interface DocOption { id: string; name: string; @@ -30,7 +36,7 @@ export function NewAnalysisModal({ onClose, onSubmit }: Props) { // Fetch indexed docs for "From Document" tab useEffect(() => { - fetch('/api/v1/documents/management-list') + fetch('/api/v1/documents/management-list', { headers: authHeader() }) .then(r => r.json()) .then(d => { const list: DocOption[] = (d?.documents ?? d ?? []).map((item: Record) => ({ diff --git a/frontend/src/pages/Compliance/useComplianceAnalysis.ts b/frontend/src/pages/Compliance/useComplianceAnalysis.ts index 1c48caa..312c43e 100644 --- a/frontend/src/pages/Compliance/useComplianceAnalysis.ts +++ b/frontend/src/pages/Compliance/useComplianceAnalysis.ts @@ -1,5 +1,11 @@ import { useState, useCallback, useRef } from 'react'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + export type AnalysisStatus = 'idle' | 'streaming' | 'done' | 'error'; export interface SourceEvent { @@ -78,6 +84,7 @@ export function useComplianceAnalysis() { try { const res = await fetch('/api/v1/compliance/analyze-stream', { method: 'POST', + headers: authHeader(), body: formData, signal: ctrl.signal, }); diff --git a/frontend/src/pages/Docs/DocsPage.tsx b/frontend/src/pages/Docs/DocsPage.tsx index 66fe374..36d4a1e 100644 --- a/frontend/src/pages/Docs/DocsPage.tsx +++ b/frontend/src/pages/Docs/DocsPage.tsx @@ -3,6 +3,12 @@ import { Topbar } from '../../components/layout/Topbar'; import { Upload, Search, Download, Trash2, RefreshCw, AlertTriangle } from 'lucide-react'; import { UploadModal } from './UploadModal'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + interface Doc { id: string; name: string; @@ -79,7 +85,7 @@ export function DocsPage() { const fetchDocs = useCallback(() => { setLoading(true); - fetch('/api/v1/documents/management-list') + fetch('/api/v1/documents/management-list', { headers: authHeader() }) .then(r => r.json()) .then(d => { if (!Array.isArray(d?.documents)) { setLoading(false); return; } @@ -132,7 +138,7 @@ export function DocsPage() { async function retryDoc(id: string) { setRetrying(r => new Set([...r, id])); try { - await fetch(`/api/v1/documents/${id}/retry`, { method: 'POST' }); + await fetch(`/api/v1/documents/${id}/retry`, { method: 'POST', headers: authHeader() }); setTimeout(() => { setRetrying(r => { const s = new Set(r); s.delete(id); return s; }); setRefreshKey(k => k + 1); @@ -155,7 +161,7 @@ export function DocsPage() { setDeleting(new Set(ids)); await Promise.allSettled( - ids.map(id => fetch(`/api/v1/documents/${id}`, { method: 'DELETE' })) + ids.map(id => fetch(`/api/v1/documents/${id}`, { method: 'DELETE', headers: authHeader() })) ); setDeleting(new Set()); diff --git a/frontend/src/pages/Docs/UploadModal.tsx b/frontend/src/pages/Docs/UploadModal.tsx index 18c5555..7d9159f 100644 --- a/frontend/src/pages/Docs/UploadModal.tsx +++ b/frontend/src/pages/Docs/UploadModal.tsx @@ -1,6 +1,12 @@ import { useState, useRef, useCallback } from 'react'; import { X, Upload } from 'lucide-react'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + interface Props { onClose: () => void; onComplete?: () => void; // called when all uploads finish (indexed) @@ -44,11 +50,6 @@ function docStatusToStages(status: DocStatus): StageState[] { } } -// Generate a short unique ID client-side (matches backend's 8-char uuid prefix pattern) -function genDocId(): string { - return Math.random().toString(36).slice(2, 10); -} - export function UploadModal({ onClose, onComplete }: Props) { const [files, setFiles] = useState([]); const [regType, setRegType] = useState(REG_TYPES[0]); @@ -93,13 +94,14 @@ export function UploadModal({ onClose, onComplete }: Props) { const pollStatus = useCallback((docId: string, resolve: () => void, reject: (msg: string) => void) => { let attempts = 0; - const MAX_ATTEMPTS = 120; // 4 minutes at 2s interval + const MAX_ATTEMPTS = 450; // 15 minutes at 2s interval — Aliyun DocMind can take several minutes stopPolling(); pollTimer.current = setInterval(async () => { attempts++; try { - const res = await fetch(`/api/v1/documents/status/${docId}`); + const res = await fetch(`/api/v1/documents/status/${docId}`, { headers: authHeader() }); if (!res.ok) { + // Transient HTTP error (e.g. 502 during restart) — keep polling until timeout. if (attempts > MAX_ATTEMPTS) { stopPolling(); reject('Polling timeout'); } return; } @@ -114,7 +116,7 @@ export function UploadModal({ onClose, onComplete }: Props) { reject(data.message ?? 'Processing failed'); } else if (attempts > MAX_ATTEMPTS) { stopPolling(); - reject('Processing timeout — check Document Management for status'); + reject('Processing timeout (15 min) — check Document Management for status'); } } catch { // network hiccup — keep polling @@ -127,37 +129,42 @@ export function UploadModal({ onClose, onComplete }: Props) { setCurrentFileIdx(idx); setDocStatus('idle'); - const docId = genDocId(); const form = new FormData(); form.append('file', file); - form.append('doc_id', docId); form.append('doc_name', file.name); form.append('regulation_type', regType); if (version) form.append('version', version); form.append('generate_summary', 'false'); - // Fire upload — this is a long-running synchronous call on the backend. - // We start polling immediately so the UI updates as the backend writes status transitions. - const uploadPromise = fetch('/api/v1/documents/upload', { method: 'POST', body: form }); + // Upload the file first — response contains the authoritative doc_id. + // Without waiting here we risk polling an ID the server has not yet created. + let docId: string; + const uploadRes = await fetch('/api/v1/documents/upload', { + method: 'POST', + headers: authHeader(), + body: form, + }); + if (!uploadRes.ok) { + const detail = await uploadRes.text().catch(() => uploadRes.statusText); + throw new Error(`${file.name}: ${uploadRes.status} ${detail}`); + } + const uploadData = await uploadRes.json() as { doc_id: string; status: string }; + docId = uploadData.doc_id; - // Start polling after a short delay so the backend has time to create the document record + // If backend processed synchronously (sync=true or status already 'indexed'), resolve immediately. + if (uploadData.status === 'indexed') { + setDocStatus('indexed'); + return; + } + if (uploadData.status === 'failed') { + setDocStatus('failed'); + throw new Error(`${file.name}: Processing failed on server`); + } + + // Otherwise start polling the authoritative doc_id returned by the server. + setDocStatus(uploadData.status as DocStatus); await new Promise((res, rej) => { - const reject = (msg: string) => rej(new Error(msg)); - // Begin polling immediately — backend creates the record synchronously before processing - setTimeout(() => pollStatus(docId, res, reject), 800); - - // Also handle the upload response (in case processing finishes before poll catches it) - uploadPromise.then(async httpRes => { - if (!httpRes.ok) { - const detail = await httpRes.text().catch(() => httpRes.statusText); - stopPolling(); - reject(`${file.name}: ${httpRes.status} ${detail}`); - } - // Upload succeeded — polling will catch the final status - }).catch(err => { - stopPolling(); - reject(err instanceof Error ? err.message : 'Upload error'); - }); + pollStatus(docId, res, (msg: string) => rej(new Error(msg))); }); } diff --git a/frontend/src/pages/Login/LoginPage.tsx b/frontend/src/pages/Login/LoginPage.tsx new file mode 100644 index 0000000..96287c8 --- /dev/null +++ b/frontend/src/pages/Login/LoginPage.tsx @@ -0,0 +1,80 @@ +import React, { FormEvent, useState } from 'react'; +import { useAuth } from '../../contexts'; + +export function LoginPage() { + const { login } = useAuth(); + const [username, setUsername] = useState(''); + const [password, setPassword] = useState(''); + const [error, setError] = useState(''); + const [loading, setLoading] = useState(false); + + async function handleSubmit(e: FormEvent) { + e.preventDefault(); + if (!username.trim() || !password.trim()) return; + setError(''); + setLoading(true); + try { + await login(username.trim(), password); + } catch (err) { + setError(err instanceof Error ? err.message : 'Login failed'); + } finally { + setLoading(false); + } + } + + return ( +
+
+
+ T-Systems +
+
T-Systems
+
AI Regulation Hub
+
+
+ +

Sign in

+ +
+
+ + setUsername(e.target.value)} + autoFocus + autoComplete="username" + disabled={loading} + placeholder="e.g. admin" + /> +
+ +
+ + setPassword(e.target.value)} + autoComplete="current-password" + disabled={loading} + /> +
+ + {error &&

{error}

} + + +
+ +

+ Demo accounts: admin / legal / ehs / readonly +

+
+
+ ); +} diff --git a/frontend/src/pages/Overview/OverviewPage.tsx b/frontend/src/pages/Overview/OverviewPage.tsx index 9dd9be9..d3e07f3 100644 --- a/frontend/src/pages/Overview/OverviewPage.tsx +++ b/frontend/src/pages/Overview/OverviewPage.tsx @@ -22,7 +22,8 @@ const STEPS = [ export function OverviewPage() { const navigate = useNavigate(); return ( -
+
+

T-Systems · AI Regulation Hub

AI Compliance,
Automated end-to-end

@@ -83,5 +84,6 @@ export function OverviewPage() {
+
); } diff --git a/frontend/src/pages/Perception/PerceptionPage.tsx b/frontend/src/pages/Perception/PerceptionPage.tsx index 7c3ab37..7a342b3 100644 --- a/frontend/src/pages/Perception/PerceptionPage.tsx +++ b/frontend/src/pages/Perception/PerceptionPage.tsx @@ -2,6 +2,12 @@ import { useState, useEffect, useRef } from 'react'; import { Topbar } from '../../components/layout/Topbar'; import { RefreshCw, Play, Square, ExternalLink } from 'lucide-react'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + interface Signal { id: string; source: string; @@ -101,14 +107,14 @@ export function PerceptionPage() { const abortRef = useRef(null); useEffect(() => { - fetch('/api/v1/perception/stats') + fetch('/api/v1/perception/stats', { headers: authHeader() }) .then(r => r.json()) .then(setStats) .catch(() => setStats({ total: 47, high_impact: 7, medium_impact: 18, last_90_days: 14 })); }, []); useEffect(() => { - fetch('/api/v1/perception/events?limit=100') + fetch('/api/v1/perception/events?limit=100', { headers: authHeader() }) .then(r => r.json()) .then(d => { if (Array.isArray(d?.events) && d.events.length > 0) { @@ -135,7 +141,7 @@ export function PerceptionPage() { const ctrl = new AbortController(); abortRef.current = ctrl; // Backend: POST /api/v1/perception/events/{id}/analyze → SSE stream - fetch(`/api/v1/perception/events/${selected.id}/analyze`, { method: 'POST', signal: ctrl.signal }) + fetch(`/api/v1/perception/events/${selected.id}/analyze`, { method: 'POST', headers: authHeader(), signal: ctrl.signal }) .then(async res => { if (!res.body) { setAiOutput('No stream available.'); setStreaming(false); return; } const reader = res.body.getReader(); diff --git a/frontend/src/pages/RagChat/RagChatPage.tsx b/frontend/src/pages/RagChat/RagChatPage.tsx index 0130f76..ad508d4 100644 --- a/frontend/src/pages/RagChat/RagChatPage.tsx +++ b/frontend/src/pages/RagChat/RagChatPage.tsx @@ -2,6 +2,12 @@ import { useState, useRef, useEffect, useCallback } from 'react'; import { Topbar } from '../../components/layout/Topbar'; import { Send, Download } from 'lucide-react'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + interface Message { id: string; role: 'user' | 'assistant'; @@ -87,7 +93,7 @@ export function RagChatPage() { // Fetch quick questions from backend on mount useEffect(() => { - fetch('/api/v1/rag/quick-questions') + fetch('/api/v1/rag/quick-questions', { headers: authHeader() }) .then(r => r.json()) .then(d => { if (Array.isArray(d?.questions) && d.questions.length > 0) { @@ -136,7 +142,7 @@ export function RagChatPage() { const res = await fetch('/api/v1/rag/chat', { method: 'POST', - headers: { 'Content-Type': 'application/json' }, + headers: { 'Content-Type': 'application/json', ...authHeader() }, body: JSON.stringify(body), signal: ctrl.signal, }); diff --git a/frontend/src/pages/Status/StatusPage.tsx b/frontend/src/pages/Status/StatusPage.tsx index 933b145..8aa1a5e 100644 --- a/frontend/src/pages/Status/StatusPage.tsx +++ b/frontend/src/pages/Status/StatusPage.tsx @@ -3,6 +3,12 @@ import { Topbar } from '../../components/layout/Topbar'; import { Search, Upload, Download, RefreshCw, CheckCircle, XCircle, AlertTriangle, Info } from 'lucide-react'; import { UploadModal } from '../Docs/UploadModal'; +const TOKEN_KEY = 'auth_token'; +function authHeader(): Record { + const t = localStorage.getItem(TOKEN_KEY); + return t ? { Authorization: `Bearer ${t}` } : {}; +} + // ── API types ────────────────────────────────────────────────────────────── interface Stats { documents_total: number; @@ -83,9 +89,9 @@ export function StatusPage() { // Fetch all three endpoints in parallel Promise.allSettled([ - fetch('/api/v1/status/stats').then(r => r.json()), - fetch('/api/v1/status/health').then(r => r.json()), - fetch('/api/v1/status/config').then(r => r.json()), + fetch('/api/v1/status/stats', { headers: authHeader() }).then(r => r.json()), + fetch('/api/v1/status/health', { headers: authHeader() }).then(r => r.json()), + fetch('/api/v1/status/config', { headers: authHeader() }).then(r => r.json()), ]).then(([statsRes, healthRes, configRes]) => { if (statsRes.status === 'fulfilled') setStats(statsRes.value); else setStats({ documents_total: 0, documents_indexed: 0, documents_failed: 0, chunks_total: 0 }); diff --git a/frontend/src/router/AppRouter.tsx b/frontend/src/router/AppRouter.tsx index c4dae25..b399da0 100644 --- a/frontend/src/router/AppRouter.tsx +++ b/frontend/src/router/AppRouter.tsx @@ -1,5 +1,8 @@ -import { BrowserRouter, Routes, Route } from 'react-router-dom'; +import { useEffect } from 'react'; +import { BrowserRouter, Navigate, Routes, Route } from 'react-router-dom'; +import { useAuth } from '../contexts'; import { AppShell } from '../components/layout/AppShell'; +import { LoginPage } from '../pages/Login/LoginPage'; import { OverviewPage } from '../pages/Overview/OverviewPage'; import { StatusPage } from '../pages/Status/StatusPage'; import { PerceptionPage } from '../pages/Perception/PerceptionPage'; @@ -7,11 +10,47 @@ import { DocsPage } from '../pages/Docs/DocsPage'; import { CompliancePage } from '../pages/Compliance/CompliancePage'; import { RagChatPage } from '../pages/RagChat/RagChatPage'; +/** Redirect to /login when not authenticated. */ +function RequireAuth({ children }: { children: React.ReactNode }) { + const { token, loading } = useAuth(); + if (loading) return null; // wait for localStorage token validation + if (!token) return ; + return <>{children}; +} + +/** Redirect to / when already authenticated. */ +function GuestOnly({ children }: { children: React.ReactNode }) { + const { token, loading } = useAuth(); + if (loading) return null; + if (token) return ; + return <>{children}; +} + export function AppRouter() { + const { logout } = useAuth(); + + // Listen for global 401 events emitted by the API layer. + useEffect(() => { + function onUnauthorized() { logout(); } + window.addEventListener('auth:unauthorized', onUnauthorized); + return () => window.removeEventListener('auth:unauthorized', onUnauthorized); + }, [logout]); + return ( - }> + {/* Public route */} + } /> + + {/* Protected routes */} + + + + } + > } /> } /> } /> @@ -19,6 +58,9 @@ export function AppRouter() { } /> } /> + + {/* Catch-all */} + } /> ); diff --git a/frontend/src/styles/globals.css b/frontend/src/styles/globals.css index 377a00f..0231aa5 100644 --- a/frontend/src/styles/globals.css +++ b/frontend/src/styles/globals.css @@ -150,10 +150,10 @@ body { display: inline-flex; align-items: center; gap: 6px; - height: 32px; - padding: 0 12px; + height: 34px; + padding: 0 14px; border-radius: var(--radius-sm); - font-size: 13px; + font-size: 14px; font-family: var(--font-body); font-weight: 500; cursor: pointer; @@ -170,7 +170,7 @@ body { border-color: var(--accent); } .btn.primary:hover { background: var(--accent-hover); border-color: var(--accent-hover); } -.btn.sm { height: 28px; font-size: 12px; padding: 0 10px; } +.btn.sm { height: 30px; font-size: 13px; padding: 0 11px; } .btn:disabled { opacity: 0.5; cursor: not-allowed; } /* ── App Shell ──────────────────────────────────── */ @@ -216,14 +216,14 @@ body { object-fit: contain; flex-shrink: 0; } -.brand-name { font-size: 13px; font-weight: 700; font-family: var(--font-display); color: var(--rail-fg); } -.brand-sub { font-size: 10px; color: var(--rail-muted); } +.brand-name { font-size: 15px; font-weight: 700; font-family: var(--font-display); color: var(--rail-fg); } +.brand-sub { font-size: 12px; color: var(--rail-muted); } .sidebar-nav { flex: 1; overflow-y: auto; padding: 12px 0; } .nav-group { margin-bottom: 4px; } .nav-group-label { - font-size: 10px; font-weight: 700; + font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.06em; color: var(--rail-muted); padding: 8px 16px 4px; @@ -231,10 +231,10 @@ body { .nav-item { display: flex; align-items: center; gap: 9px; - height: 36px; padding: 0 14px 0 16px; + height: 38px; padding: 0 14px 0 16px; text-decoration: none; color: var(--rail-fg); - font-size: 13px; + font-size: 14px; border-left: 3px solid transparent; transition: background 0.12s, color 0.12s; position: relative; @@ -261,14 +261,14 @@ body { } .sidebar-user { display: flex; align-items: center; gap: 8px; flex: 1; min-width: 0; } .user-avatar { - width: 28px; height: 28px; + width: 30px; height: 30px; background: var(--accent-dim); color: var(--accent); border-radius: 50%; display: flex; align-items: center; justify-content: center; - font-size: 10px; font-weight: 700; flex-shrink: 0; + font-size: 11px; font-weight: 700; flex-shrink: 0; } -.user-name { font-size: 12px; font-weight: 600; color: var(--rail-fg); } -.user-role { font-size: 10px; color: var(--rail-muted); } +.user-name { font-size: 13px; font-weight: 600; color: var(--rail-fg); } +.user-role { font-size: 11px; color: var(--rail-muted); } .theme-btn { width: 28px; height: 28px; @@ -297,8 +297,8 @@ body { z-index: 10; } .topbar-left { display: flex; align-items: baseline; gap: 10px; min-width: 0; } -.topbar-title { font-size: 15px; font-weight: 700; font-family: var(--font-display); color: var(--fg); } -.topbar-sub { font-size: 11px; color: var(--muted); font-family: var(--font-mono); } +.topbar-title { font-size: 17px; font-weight: 700; font-family: var(--font-display); color: var(--fg); } +.topbar-sub { font-size: 12px; color: var(--muted); font-family: var(--font-mono); } .topbar-actions { display: flex; align-items: center; gap: 8px; } /* ── Page Content ───────────────────────────────── */ @@ -309,13 +309,13 @@ body { } /* ── Search Box ─────────────────────────────────── */ -.search-box { display: flex; align-items: center; gap: 6px; height: 32px; padding: 0 10px; border: 1px solid var(--border); border-radius: var(--radius-sm); background: var(--bg); font-size: 13px; color: var(--muted); } -.search-box input { border: none; background: transparent; outline: none; font-size: 13px; color: var(--fg); width: 160px; } +.search-box { display: flex; align-items: center; gap: 6px; height: 34px; padding: 0 10px; border: 1px solid var(--border); border-radius: var(--radius-sm); background: var(--bg); font-size: 14px; color: var(--muted); } +.search-box input { border: none; background: transparent; outline: none; font-size: 14px; color: var(--fg); width: 160px; } /* ── Filter Bar / Chips ─────────────────────────── */ .filter-bar { display: flex; align-items: center; gap: 12px; padding: 10px 22px; border-bottom: 1px solid var(--border); background: var(--surface); flex-shrink: 0; } .chip-group { display: flex; gap: 6px; flex-wrap: wrap; } -.chip { height: 26px; padding: 0 10px; border-radius: var(--radius-pill); border: 1px solid var(--border); background: transparent; font-size: 11px; font-weight: 500; cursor: pointer; color: var(--muted); transition: all 0.12s; } +.chip { height: 28px; padding: 0 11px; border-radius: var(--radius-pill); border: 1px solid var(--border); background: transparent; font-size: 12px; font-weight: 500; cursor: pointer; color: var(--muted); transition: all 0.12s; } .chip:hover { border-color: var(--accent); color: var(--accent); } .chip.active { background: var(--accent-dim); border-color: var(--accent); color: var(--accent); font-weight: 600; } .filter-sep { width: 1px; height: 20px; background: var(--border); flex-shrink: 0; } @@ -327,34 +327,47 @@ body { @keyframes spin { from { transform: rotate(0deg); } to { transform: rotate(360deg); } } /* ── Overview Page ──────────────────────────────── */ -.overview-page { padding: 32px; max-width: 900px; display: flex; flex-direction: column; gap: 32px; } +/* Outer container fills content-area and handles scrolling. */ +/* Inner .overview-page constrains width and holds the page content. */ +.overview-scroll-wrapper { + flex: 1; + overflow-y: auto; + height: 100%; +} +.overview-page { + padding: 32px; + max-width: 960px; + display: flex; + flex-direction: column; + gap: 32px; +} .overview-hero { display: flex; flex-direction: column; gap: 12px; } -.hero-eyebrow { font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.08em; color: var(--accent); } -.hero-title { font-size: 32px; font-weight: 700; font-family: var(--font-display); line-height: 1.15; } -.hero-desc { font-size: 14px; color: var(--muted); max-width: 480px; line-height: 1.6; } +.hero-eyebrow { font-size: 12px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.08em; color: var(--accent); } +.hero-title { font-size: 34px; font-weight: 700; font-family: var(--font-display); line-height: 1.15; } +.hero-desc { font-size: 15px; color: var(--muted); max-width: 520px; line-height: 1.6; } .hero-actions { display: flex; gap: 10px; padding-top: 4px; } .overview-summary { display: flex; align-items: center; gap: 0; } .summary-item { display: flex; flex-direction: column; align-items: center; gap: 2px; flex: 1; padding: 14px; } -.summary-num { font-size: 22px; font-weight: 700; font-family: var(--font-display); color: var(--accent); } -.summary-label { font-size: 11px; color: var(--muted); } +.summary-num { font-size: 24px; font-weight: 700; font-family: var(--font-display); color: var(--accent); } +.summary-label { font-size: 12px; color: var(--muted); } .summary-divider { width: 1px; height: 40px; background: var(--border); flex-shrink: 0; } -.section-title { font-size: 13px; font-weight: 700; font-family: var(--font-display); color: var(--fg); margin-bottom: 14px; } +.section-title { font-size: 14px; font-weight: 700; font-family: var(--font-display); color: var(--fg); margin-bottom: 14px; } .workflow-steps { display: grid; grid-template-columns: repeat(6, 1fr); gap: 12px; } .workflow-step { display: flex; flex-direction: column; gap: 4px; padding: 12px; background: var(--surface); border-radius: var(--radius-md); box-shadow: var(--shadow-card); } -.step-num { font-size: 10px; font-weight: 700; font-family: var(--font-mono); color: var(--accent); } -.step-label { font-size: 13px; font-weight: 700; font-family: var(--font-display); } -.step-desc { font-size: 11px; color: var(--muted); line-height: 1.4; } +.step-num { font-size: 11px; font-weight: 700; font-family: var(--font-mono); color: var(--accent); } +.step-label { font-size: 14px; font-weight: 700; font-family: var(--font-display); } +.step-desc { font-size: 12px; color: var(--muted); line-height: 1.4; } .screen-grid { display: grid; grid-template-columns: repeat(3, 1fr); gap: 14px; } .screen-card { text-align: left; cursor: pointer; border: none; transition: box-shadow 0.15s; display: flex; flex-direction: column; gap: 6px; } .screen-card:hover { box-shadow: 0 4px 12px rgba(0,0,0,.10), 0 0 0 2px var(--accent-dim); } .screen-icon { width: 36px; height: 36px; background: var(--accent-dim); color: var(--accent); border-radius: var(--radius-sm); display: flex; align-items: center; justify-content: center; margin-bottom: 4px; } -.screen-label { font-size: 13px; font-weight: 700; font-family: var(--font-display); } -.screen-desc { font-size: 12px; color: var(--muted); line-height: 1.4; } +.screen-label { font-size: 14px; font-weight: 700; font-family: var(--font-display); } +.screen-desc { font-size: 13px; color: var(--muted); line-height: 1.4; } /* ── Status Page ────────────────────────────────── */ .status-page { display: flex; flex-direction: column; height: 100%; } @@ -375,18 +388,18 @@ body { display: flex; flex-direction: column; gap: 4px; } .stat-cell:last-child { border-right: none; } -.stat-value { font-size: 26px; font-weight: 700; font-family: var(--font-display); } -.stat-label { font-size: 11px; color: var(--muted); } +.stat-value { font-size: 28px; font-weight: 700; font-family: var(--font-display); } +.stat-label { font-size: 12px; color: var(--muted); } .stat-cell.danger .stat-value { color: var(--danger); } .panel-grid { display: grid; grid-template-columns: 1.4fr 0.9fr; gap: 16px; } .panel-left, .panel-right { display: flex; flex-direction: column; gap: 16px; } -.card-header { font-size: 12px; font-weight: 700; font-family: var(--font-display); color: var(--muted); text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 12px; } +.card-header { font-size: 13px; font-weight: 700; font-family: var(--font-display); color: var(--muted); text-transform: uppercase; letter-spacing: 0.05em; margin-bottom: 12px; } .task-row { display: flex; align-items: center; gap: 12px; padding: 8px 0; border-top: 1px solid var(--border); } .task-info { flex: 1; min-width: 0; } -.task-name { font-size: 13px; font-weight: 500; margin-bottom: 4px; } +.task-name { font-size: 14px; font-weight: 500; margin-bottom: 4px; } .task-progress-bar { height: 3px; background: var(--border); border-radius: 2px; overflow: hidden; } .task-progress-fill { height: 100%; background: var(--accent); border-radius: 2px; } @@ -401,12 +414,12 @@ body { .kpi-value { font-size: 11px; font-weight: 700; font-family: var(--font-mono); color: var(--fg); width: 36px; text-align: right; flex-shrink: 0; } .service-row { display: flex; align-items: center; justify-content: space-between; padding: 8px 0; border-top: 1px solid var(--border); } -.service-name { font-size: 13px; } +.service-name { font-size: 14px; } .event-row { padding: 10px 0; border-top: 1px solid var(--border); display: flex; flex-direction: column; gap: 3px; } -.event-date { font-size: 10px; font-family: var(--font-mono); color: var(--muted); } -.event-title { font-size: 13px; font-weight: 600; } -.event-summary { font-size: 12px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; } +.event-date { font-size: 11px; font-family: var(--font-mono); color: var(--muted); } +.event-title { font-size: 14px; font-weight: 600; } +.event-summary { font-size: 13px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; } /* ── Perception Page ────────────────────────────── */ .perception-page { display: flex; flex-direction: column; height: 100%; } @@ -414,8 +427,8 @@ body { .stats-bar { display: flex; border-bottom: 1px solid var(--border); background: var(--surface); flex-shrink: 0; } .sbar-cell { flex: 1; padding: 14px 22px; border-right: 1px solid var(--border); display: flex; flex-direction: column; gap: 3px; } .sbar-cell:last-child { border-right: none; } -.sbar-val { font-size: 22px; font-weight: 700; font-family: var(--font-display); } -.sbar-lbl { font-size: 11px; color: var(--muted); } +.sbar-val { font-size: 24px; font-weight: 700; font-family: var(--font-display); } +.sbar-lbl { font-size: 12px; color: var(--muted); } .sbar-cell.danger .sbar-val { color: var(--danger); } .sbar-cell.warn .sbar-val { color: var(--warn); } .sbar-cell.accent .sbar-val { color: var(--accent); } @@ -427,14 +440,14 @@ body { .ev-card:hover { background: var(--bg); } .ev-card.selected { border-left-color: var(--accent); background: var(--accent-dim); box-shadow: inset 0 0 0 1px var(--accent-dim); } .ev-top { display: flex; align-items: center; gap: 7px; margin-bottom: 6px; } -.ev-std { font-size: 10px; font-family: var(--font-mono); color: var(--muted); } -.ev-title { font-size: 13px; font-weight: 600; line-height: 1.35; margin-bottom: 5px; } -.ev-summary { font-size: 12px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; margin-bottom: 8px; } +.ev-std { font-size: 11px; font-family: var(--font-mono); color: var(--muted); } +.ev-title { font-size: 14px; font-weight: 600; line-height: 1.35; margin-bottom: 5px; } +.ev-summary { font-size: 13px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; margin-bottom: 8px; } .ev-bottom { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; } -.ev-date { font-size: 10px; font-family: var(--font-mono); color: var(--muted); } +.ev-date { font-size: 11px; font-family: var(--font-mono); color: var(--muted); } .ev-tags { display: flex; gap: 4px; flex-wrap: wrap; } -.ev-tag { font-size: 10px; padding: 1px 6px; background: var(--bg); border: 1px solid var(--border); border-radius: var(--radius-pill); color: var(--muted); } -.impact-dot { font-size: 10px; font-weight: 700; margin-left: auto; } +.ev-tag { font-size: 11px; padding: 1px 6px; background: var(--bg); border: 1px solid var(--border); border-radius: var(--radius-pill); color: var(--muted); } +.impact-dot { font-size: 11px; font-weight: 700; margin-left: auto; } .impact-high { color: var(--danger); } .impact-medium { color: var(--warn); } .impact-low { color: var(--success); } @@ -447,18 +460,18 @@ body { .detail-card { display: flex; flex-direction: column; gap: 10px; } .detail-header { display: flex; align-items: center; gap: 8px; } -.detail-title { font-size: 15px; font-weight: 700; font-family: var(--font-display); } -.detail-summary { font-size: 13px; color: var(--muted); line-height: 1.6; } +.detail-title { font-size: 16px; font-weight: 700; font-family: var(--font-display); } +.detail-summary { font-size: 14px; color: var(--muted); line-height: 1.6; } .detail-actions { display: flex; gap: 8px; padding-top: 4px; } .docs-card { display: flex; flex-direction: column; gap: 0; } .doc-row { display: flex; gap: 10px; padding: 10px 0; border-top: 1px solid var(--border); } .doc-score { font-size: 11px; font-weight: 700; font-family: var(--font-mono); color: var(--success); width: 34px; flex-shrink: 0; padding-top: 1px; } -.doc-name { font-size: 12px; font-weight: 600; margin-bottom: 3px; } -.doc-clause { font-size: 10px; font-family: var(--font-mono); color: var(--muted); margin-left: 5px; } -.doc-snippet { font-size: 11px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; } +.doc-name { font-size: 13px; font-weight: 600; margin-bottom: 3px; } +.doc-clause { font-size: 11px; font-family: var(--font-mono); color: var(--muted); margin-left: 5px; } +.doc-snippet { font-size: 12px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; } -.ai-card .ai-output { font-size: 13px; line-height: 1.7; white-space: pre-wrap; font-family: var(--font-mono); } +.ai-card .ai-output { font-size: 14px; line-height: 1.7; white-space: pre-wrap; font-family: var(--font-mono); } .blink-cursor { animation: blink 1s step-end infinite; } @keyframes blink { 0%,100% { opacity: 1; } 50% { opacity: 0; } } @@ -466,7 +479,7 @@ body { .docs-page { display: flex; flex-direction: column; height: 100%; } .docs-controls { display: flex; align-items: center; gap: 12px; margin-bottom: 14px; flex-wrap: wrap; } -.select-input { height: 28px; padding: 0 10px; border: 1px solid var(--border); border-radius: var(--radius-sm); background: var(--surface); font-size: 12px; color: var(--fg); outline: none; cursor: pointer; } +.select-input { height: 30px; padding: 0 10px; border: 1px solid var(--border); border-radius: var(--radius-sm); background: var(--surface); font-size: 13px; color: var(--fg); outline: none; cursor: pointer; } .batch-bar { display: flex; align-items: center; gap: 10px; padding: 8px 12px; background: var(--accent-dim); border: 1px solid var(--accent); border-radius: var(--radius-sm); margin-bottom: 12px; font-size: 13px; color: var(--accent); font-weight: 600; } .risk-btn { color: var(--danger); border-color: var(--danger-bg); } @@ -479,7 +492,7 @@ body { padding: 10px 14px; background: var(--bg); border-bottom: 1px solid var(--border); - font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.05em; color: var(--muted); + font-size: 12px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.05em; color: var(--muted); align-items: center; } .table-row { @@ -488,7 +501,7 @@ body { gap: 12px; padding: 11px 14px; border-bottom: 1px solid var(--border); - font-size: 13px; + font-size: 14px; align-items: center; transition: background 0.1s, opacity 0.3s; } @@ -497,10 +510,10 @@ body { .table-row.row-selected { background: var(--accent-dim); } .table-row.row-deleting { opacity: 0.4; pointer-events: none; } .doc-name-cell { font-weight: 500; } -.cell-mono { font-family: var(--font-mono); font-size: 11px; color: var(--muted); } -.cell-muted { font-size: 12px; color: var(--muted); } +.cell-mono { font-family: var(--font-mono); font-size: 12px; color: var(--muted); } +.cell-muted { font-size: 13px; color: var(--muted); } .row-actions { display: flex; gap: 10px; } -.text-link { background: none; border: none; font-size: 12px; color: var(--accent); cursor: pointer; font-weight: 500; padding: 0; } +.text-link { background: none; border: none; font-size: 13px; color: var(--accent); cursor: pointer; font-weight: 500; padding: 0; } .text-link:hover { text-decoration: underline; } .danger-link { color: var(--danger); } @@ -514,24 +527,24 @@ body { .compliance-workspace { display: grid; grid-template-columns: 0.95fr 1.25fr 0.9fr; gap: 14px; padding: 16px 24px 24px; flex: 1; overflow: hidden; min-height: 0; } .comp-col { display: flex; flex-direction: column; gap: 12px; overflow-y: auto; } -.col-header { font-size: 11px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.05em; color: var(--muted); padding: 0 2px 4px; flex-shrink: 0; } +.col-header { font-size: 12px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.05em; color: var(--muted); padding: 0 2px 4px; flex-shrink: 0; } .source-item { display: flex; flex-direction: column; gap: 6px; flex-shrink: 0; } .source-top { display: flex; align-items: center; justify-content: space-between; gap: 8px; } -.source-std { font-size: 13px; font-weight: 700; font-family: var(--font-display); } -.source-helper { font-size: 11px; color: var(--muted); } +.source-std { font-size: 14px; font-weight: 700; font-family: var(--font-display); } +.source-helper { font-size: 12px; color: var(--muted); } .source-scores { display: flex; gap: 5px; flex-wrap: wrap; } -.score-pill { font-size: 10px; font-family: var(--font-mono); padding: 2px 6px; background: var(--bg); border: 1px solid var(--border); border-radius: var(--radius-pill); color: var(--muted); } +.score-pill { font-size: 11px; font-family: var(--font-mono); padding: 2px 6px; background: var(--bg); border: 1px solid var(--border); border-radius: var(--radius-pill); color: var(--muted); } .para-card { flex-shrink: 0; } -.para-text { font-size: 13px; line-height: 1.7; color: var(--fg); } +.para-text { font-size: 14px; line-height: 1.7; color: var(--fg); } .para-text mark { background: rgba(226,0,116,.15); color: var(--accent); padding: 0 2px; border-radius: 2px; } .stages-card { flex-shrink: 0; } .stage-row { padding: 8px 0; border-top: 1px solid var(--border); display: flex; flex-direction: column; gap: 5px; } .stage-label-row { display: flex; justify-content: space-between; align-items: center; } -.stage-label { font-size: 12px; } -.stage-pct { font-size: 11px; font-family: var(--font-mono); color: var(--muted); } +.stage-label { font-size: 13px; } +.stage-pct { font-size: 12px; font-family: var(--font-mono); color: var(--muted); } .stage-bar { height: 4px; background: var(--border); border-radius: 2px; overflow: hidden; } .stage-fill { height: 100%; border-radius: 2px; } .stage-fill.stage-ok { background: var(--success); } @@ -540,13 +553,13 @@ body { .finding-item { display: flex; flex-direction: column; gap: 6px; flex-shrink: 0; } .finding-top { display: flex; align-items: flex-start; justify-content: space-between; gap: 8px; } -.finding-title { font-size: 13px; font-weight: 600; line-height: 1.3; } -.finding-desc { font-size: 12px; color: var(--muted); line-height: 1.5; } +.finding-title { font-size: 14px; font-weight: 600; line-height: 1.3; } +.finding-desc { font-size: 13px; color: var(--muted); line-height: 1.5; } .conclusion-box { display: flex; flex-direction: column; gap: 10px; flex-shrink: 0; } -.conclusion-text { font-size: 12px; line-height: 1.6; color: var(--fg); } +.conclusion-text { font-size: 13px; line-height: 1.6; color: var(--fg); } .action-items { display: flex; flex-direction: column; gap: 8px; padding-top: 8px; border-top: 1px solid var(--border); } -.action-item { display: flex; justify-content: space-between; align-items: center; gap: 8px; font-size: 12px; } +.action-item { display: flex; justify-content: space-between; align-items: center; gap: 8px; font-size: 13px; } .action-label { color: var(--muted); } .action-value { font-weight: 600; } .risk-text { color: var(--danger); } @@ -559,9 +572,9 @@ body { .history-header, .quick-header { font-size: 10px; font-weight: 700; text-transform: uppercase; letter-spacing: 0.06em; color: var(--muted); padding: 10px 16px 6px; } .history-item { padding: 8px 16px; cursor: pointer; border-radius: 0; transition: background 0.1s; } .history-item:hover { background: var(--bg); } -.history-title { font-size: 12px; font-weight: 500; display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; } -.history-date { font-size: 10px; font-family: var(--font-mono); color: var(--muted); margin-top: 2px; } -.quick-item { background: none; border: none; text-align: left; padding: 8px 16px; font-size: 12px; color: var(--muted); cursor: pointer; line-height: 1.4; transition: color 0.1s, background 0.1s; } +.history-title { font-size: 13px; font-weight: 500; display: -webkit-box; -webkit-line-clamp: 2; -webkit-box-orient: vertical; overflow: hidden; } +.history-date { font-size: 11px; font-family: var(--font-mono); color: var(--muted); margin-top: 2px; } +.quick-item { background: none; border: none; text-align: left; padding: 8px 16px; font-size: 13px; color: var(--muted); cursor: pointer; line-height: 1.4; transition: color 0.1s, background 0.1s; } .quick-item:hover { color: var(--accent); background: var(--accent-dim); } .chat-main { display: flex; flex-direction: column; overflow: hidden; } @@ -570,14 +583,14 @@ body { .msg-user { flex-direction: row-reverse; } .msg-avatar { width: 28px; height: 28px; border-radius: 50%; display: flex; align-items: center; justify-content: center; font-size: 10px; font-weight: 700; flex-shrink: 0; background: var(--accent-dim); color: var(--accent); } .msg-avatar.user-av { background: var(--bg); color: var(--muted); border: 1px solid var(--border); } -.msg-bubble { max-width: 72%; padding: 10px 14px; border-radius: 14px; font-size: 13px; line-height: 1.6; } +.msg-bubble { max-width: 72%; padding: 10px 14px; border-radius: 14px; font-size: 14px; line-height: 1.6; } .msg-assistant .msg-bubble { background: var(--surface); border: 1px solid var(--border); border-bottom-left-radius: 4px; } .msg-user .msg-bubble { background: var(--accent); color: #fff; border-bottom-right-radius: 4px; } .composer { padding: 12px 16px; border-top: 1px solid var(--border); background: var(--surface); display: flex; flex-direction: column; gap: 8px; flex-shrink: 0; } .quick-chips { display: flex; gap: 6px; flex-wrap: wrap; } .composer-row { display: flex; gap: 8px; align-items: flex-end; } -.composer-input { flex: 1; padding: 8px 12px; border: 1px solid var(--border); border-radius: var(--radius-sm); background: var(--bg); font-size: 13px; font-family: var(--font-body); color: var(--fg); resize: none; outline: none; line-height: 1.5; } +.composer-input { flex: 1; padding: 8px 12px; border: 1px solid var(--border); border-radius: var(--radius-sm); background: var(--bg); font-size: 14px; font-family: var(--font-body); color: var(--fg); resize: none; outline: none; line-height: 1.5; } .composer-input:focus { border-color: var(--accent); } .citation-rail { border-left: 1px solid var(--border); overflow-y: auto; padding: 14px 0; background: var(--surface); } @@ -586,9 +599,9 @@ body { .citation-item.highlighted { background: var(--accent-dim); border-left: 3px solid var(--accent); } .cit-score { font-size: 11px; font-weight: 700; font-family: var(--font-mono); color: var(--success); width: 34px; flex-shrink: 0; padding-top: 1px; } .cit-index { font-size: 10px; font-weight: 700; font-family: var(--font-mono); color: var(--muted); width: 18px; flex-shrink: 0; padding-top: 2px; } -.cit-name { font-size: 12px; font-weight: 600; margin-bottom: 3px; } -.cit-clause { font-size: 10px; font-family: var(--font-mono); color: var(--muted); margin-left: 5px; } -.cit-snippet { font-size: 11px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden; line-height: 1.5; } +.cit-name { font-size: 13px; font-weight: 600; margin-bottom: 3px; } +.cit-clause { font-size: 11px; font-family: var(--font-mono); color: var(--muted); margin-left: 5px; } +.cit-snippet { font-size: 12px; color: var(--muted); display: -webkit-box; -webkit-line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden; line-height: 1.5; } /* Inline citation badge [N] in message text */ .cite-ref { @@ -792,7 +805,7 @@ body { } .modal-tab { padding: 10px 18px; - font-size: 13px; + font-size: 14px; font-weight: 500; border: none; background: none; @@ -816,7 +829,7 @@ body { .domain-chip { padding: 4px 12px; border-radius: 20px; - font-size: 11px; + font-size: 12px; font-weight: 500; border: 1px solid var(--border); background: var(--surface); @@ -867,8 +880,8 @@ body { background: var(--accent); border-color: var(--accent); } -.doc-select-name { font-size: 13px; font-weight: 500; flex: 1; } -.doc-select-meta { font-size: 11px; color: var(--muted); } +.doc-select-name { font-size: 14px; font-weight: 500; flex: 1; } +.doc-select-meta { font-size: 12px; color: var(--muted); } /* Stage running animation */ @keyframes stage-pulse { @@ -945,8 +958,8 @@ body { color: var(--accent); opacity: 0.7; } -.analysis-empty h3 { font-size: 15px; font-weight: 600; color: var(--fg); margin: 0; } -.analysis-empty p { font-size: 13px; max-width: 280px; line-height: 1.6; margin: 0; } +.analysis-empty h3 { font-size: 16px; font-weight: 600; color: var(--fg); margin: 0; } +.analysis-empty p { font-size: 14px; max-width: 280px; line-height: 1.6; margin: 0; } /* Highlight terms in paragraph */ mark.comp-highlight { @@ -976,3 +989,122 @@ mark.comp-highlight { background: linear-gradient(to right, #22c55e 0%, #eab308 50%, #ef4444 100%); transition: width 0.6s ease; } + +/* ── Login Page ─────────────────────────────────── */ +.login-page { + min-height: 100vh; + display: flex; + align-items: center; + justify-content: center; + background: var(--bg); +} + +.login-card { + width: 100%; + max-width: 380px; + background: var(--surface); + border: 1px solid var(--border); + border-radius: var(--radius-md); + box-shadow: var(--shadow-card); + padding: 36px 32px 28px; +} + +.login-brand { + display: flex; + align-items: center; + gap: 10px; + margin-bottom: 28px; +} +.login-logo { width: 32px; height: 32px; object-fit: contain; } +.login-brand-name { font-size: 14px; font-weight: 700; font-family: var(--font-display); color: var(--fg); } +.login-brand-sub { font-size: 11px; color: var(--muted); } + +.login-title { + font-size: 20px; + font-weight: 700; + font-family: var(--font-display); + margin-bottom: 20px; + color: var(--fg); +} + +.login-form { display: flex; flex-direction: column; gap: 16px; } + +.login-field { display: flex; flex-direction: column; gap: 5px; } +.login-label { font-size: 12px; font-weight: 600; color: var(--fg); } + +.login-input { + height: 38px; + padding: 0 12px; + border: 1px solid var(--border-strong); + border-radius: var(--radius-sm); + background: var(--bg); + color: var(--fg); + font-size: 13px; + outline: none; + transition: border-color 0.15s; +} +.login-input:focus { border-color: var(--accent); } +.login-input:disabled { opacity: 0.6; } + +.login-error { + font-size: 12px; + color: var(--danger); + background: var(--danger-bg); + border: 1px solid rgba(220,38,38,.2); + border-radius: var(--radius-sm); + padding: 8px 12px; + margin: 0; +} + +.login-btn { + height: 40px; + background: var(--accent); + color: #fff; + border: none; + border-radius: var(--radius-sm); + font-size: 14px; + font-weight: 700; + cursor: pointer; + transition: background 0.15s; +} +.login-btn:hover:not(:disabled) { background: var(--accent-hover); } +.login-btn:disabled { opacity: 0.6; cursor: not-allowed; } + +.login-hint { + margin-top: 16px; + font-size: 11px; + color: var(--muted); + text-align: center; +} +.login-hint code { + font-family: var(--font-mono); + background: var(--bg); + padding: 1px 4px; + border-radius: 3px; + border: 1px solid var(--border); +} + +/* ── Sidebar user block (authenticated) ─────────── */ +.user-badge { + display: inline-block; + font-size: 10px; + font-weight: 700; + text-transform: uppercase; + padding: 1px 5px; + border-radius: 3px; + background: var(--accent-dim); + color: var(--accent); + letter-spacing: 0.04em; +} +.logout-btn { + background: none; + border: none; + cursor: pointer; + color: var(--muted); + padding: 4px; + border-radius: var(--radius-sm); + display: flex; + align-items: center; + transition: color 0.15s; +} +.logout-btn:hover { color: var(--danger); } diff --git a/pyproject.toml b/pyproject.toml index 7b69071..0dd1e5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,10 +31,13 @@ dependencies = [ "redis>=4.5.0", "minio>=7.1.0", "psycopg2-binary>=2.9.0", + "python-jose[cryptography]>=3.3.0", + "passlib[bcrypt]>=1.7.4", + "bcrypt>=3.2.0,<4.0.0", ] [dependency-groups] -dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "isort>=8.0.1"] +dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "isort>=8.0.1", "fakeredis>=2.0.0"] [build-system] requires = ["setuptools>=61.0"] diff --git a/requirements.txt b/requirements.txt index 4825645..60376a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,3 +42,7 @@ python-dotenv>=1.0.0 loguru>=0.7.0 tenacity>=8.2.0 httpx>=0.24.0 + +# Authentication +python-jose[cryptography]>=3.3.0 +passlib[bcrypt]>=1.7.4 diff --git a/scripts/seed_users.py b/scripts/seed_users.py new file mode 100644 index 0000000..d7053ed --- /dev/null +++ b/scripts/seed_users.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +"""Seed demo users into the PostgreSQL users table. + +Run from the repo root: + PYTHONPATH=backend python scripts/seed_users.py + +Creates four demo accounts (one per role). Safe to run multiple times — +existing usernames are left unchanged (ON CONFLICT DO NOTHING). +""" + +import os +import sys + +# Allow running from repo root without installing the package. +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend")) + +import psycopg2 +from passlib.context import CryptContext +from app.config.settings import settings + +_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto") + +DEMO_USERS = [ + {"username": "admin", "password": "Admin@2026!", "role": "admin"}, + {"username": "legal", "password": "Legal@2026!", "role": "legal"}, + {"username": "ehs", "password": "EHS@2026!", "role": "ehs"}, + {"username": "readonly", "password": "Read@2026!", "role": "readonly"}, +] + +_CREATE_TABLE_SQL = """ +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + username VARCHAR(100) UNIQUE NOT NULL, + hashed_pw TEXT NOT NULL, + role VARCHAR(50) NOT NULL DEFAULT 'readonly', + is_active BOOLEAN NOT NULL DEFAULT TRUE, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +""" + + +def seed() -> None: + """Insert demo users. Skips any username that already exists.""" + conn = psycopg2.connect( + host=settings.postgres_host, + port=settings.postgres_port, + user=settings.postgres_user, + password=settings.postgres_password, + dbname=settings.postgres_db, + ) + conn.autocommit = True + + with conn.cursor() as cur: + # Enable pgcrypto for gen_random_uuid() if not already enabled, or use uuid-ossp fallback. + try: + cur.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;") + except Exception: + conn.rollback() + cur.execute(_CREATE_TABLE_SQL) + + inserted = 0 + with conn.cursor() as cur: + for user in DEMO_USERS: + hashed = _PWD_CTX.hash(user["password"]) + cur.execute( + """ + INSERT INTO users (username, hashed_pw, role) + VALUES (%s, %s, %s) + ON CONFLICT (username) DO NOTHING + """, + (user["username"], hashed, user["role"]), + ) + if cur.rowcount > 0: + print(f" Created user: {user['username']} (role={user['role']})") + inserted += 1 + else: + print(f" Skipped (already exists): {user['username']}") + + print(f"\nDone. {inserted} user(s) created.") + conn.close() + + +if __name__ == "__main__": + seed() diff --git a/tests/test_auth_routes.py b/tests/test_auth_routes.py new file mode 100644 index 0000000..30ae79c --- /dev/null +++ b/tests/test_auth_routes.py @@ -0,0 +1,58 @@ +"""Integration tests for the auth routes. + +Uses FastAPI TestClient with a mocked user store. +Does not require a running PostgreSQL. +""" + +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch + +from app.infrastructure.auth.user_store import UserRecord + + +@pytest.fixture +def client(): + """Return a TestClient with a mocked user store.""" + mock_store = MagicMock() + alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True) + mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None + + with patch("app.shared.bootstrap.get_user_store", return_value=mock_store): + # Import after patch so the mock is active when routes import bootstrap. + from app.api.main import app + with TestClient(app, raise_server_exceptions=False) as c: + yield c + + +def test_login_returns_token_for_valid_credentials(client): + """POST /auth/token must return an access_token for valid credentials.""" + resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"}) + assert resp.status_code == 200 + body = resp.json() + assert "access_token" in body + assert body["token_type"] == "bearer" + + +def test_login_returns_401_for_wrong_password(client): + """POST /auth/token must return 401 for wrong password.""" + resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"}) + assert resp.status_code == 401 + + +def test_me_returns_user_when_authenticated(client): + """GET /auth/me must return user identity when a valid token is provided.""" + login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"}) + assert login_resp.status_code == 200, login_resp.text + token = login_resp.json()["access_token"] + + me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert me_resp.status_code == 200 + assert me_resp.json()["username"] == "alice" + assert me_resp.json()["role"] == "admin" + + +def test_me_returns_401_without_token(client): + """GET /auth/me must return 401 when no token is provided.""" + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 diff --git a/tests/test_celery_tasks.py b/tests/test_celery_tasks.py new file mode 100644 index 0000000..9498939 --- /dev/null +++ b/tests/test_celery_tasks.py @@ -0,0 +1,45 @@ +"""Tests for Celery task infrastructure. + +Verifies Celery app configuration and task registration without +starting a real worker or connecting to Redis. +""" + + +def test_celery_app_uses_redis_broker(): + """Celery broker URL must be a Redis URL built from settings.""" + from app.infrastructure.tasks.celery_app import celery_app + assert celery_app.conf.broker_url.startswith("redis://") + + +def test_celery_app_uses_redis_backend(): + """Celery result backend must be Redis.""" + from app.infrastructure.tasks.celery_app import celery_app + assert celery_app.conf.result_backend.startswith("redis://") + + +def test_celery_app_has_json_serializer(): + """Task serializer must be JSON for portability.""" + from app.infrastructure.tasks.celery_app import celery_app + assert celery_app.conf.task_serializer == "json" + + +def test_process_document_task_is_registered(): + """process_document_task must be discoverable in the Celery task registry.""" + import app.infrastructure.tasks.document_tasks # noqa: F401 — triggers task registration + from app.infrastructure.tasks.celery_app import celery_app + registered = list(celery_app.tasks.keys()) + assert any("process_document_task" in name for name in registered), ( + f"process_document_task not found in {registered}" + ) + + +def test_document_command_service_has_process_document(): + """DocumentCommandService must expose _process_document method.""" + from app.application.documents.services import DocumentCommandService + assert hasattr(DocumentCommandService, "_process_document") + + +def test_document_command_service_has_store_document(): + """DocumentCommandService must expose store_document method.""" + from app.application.documents.services import DocumentCommandService + assert hasattr(DocumentCommandService, "store_document") diff --git a/tests/test_jwt_handler.py b/tests/test_jwt_handler.py new file mode 100644 index 0000000..669273e --- /dev/null +++ b/tests/test_jwt_handler.py @@ -0,0 +1,58 @@ +"""Tests for JWTHandler token creation and decoding. + +These tests do not require a running server or database. +""" + +import time +import pytest + +SECRET = "test-secret-key-minimum-32-characters-long" + + +@pytest.fixture +def handler(): + """Return a JWTHandler configured with a test secret.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30) + + +def test_create_token_returns_string(handler): + """create_access_token must return a non-empty string.""" + token = handler.create_access_token(user_id="u1", username="alice", role="admin") + assert isinstance(token, str) + assert len(token) > 20 + + +def test_decode_token_returns_correct_claims(handler): + """decode_token must return UserClaims matching the input.""" + token = handler.create_access_token(user_id="u1", username="alice", role="admin") + claims = handler.decode_token(token) + assert claims.user_id == "u1" + assert claims.username == "alice" + assert claims.role == "admin" + + +def test_decode_expired_token_raises(handler): + """decode_token must raise ValueError on an expired token.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0) + token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly") + time.sleep(1) + with pytest.raises(ValueError, match="expired"): + short_handler.decode_token(token) + + +def test_decode_invalid_token_raises(handler): + """decode_token must raise ValueError for a tampered token.""" + with pytest.raises(ValueError): + handler.decode_token("not.a.valid.jwt.token") + + +def test_decode_wrong_secret_raises(): + """decode_token must raise ValueError when signed with a different secret.""" + from app.infrastructure.auth.jwt_handler import JWTHandler + creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60) + verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60) + token = creator.create_access_token(user_id="u3", username="carol", role="legal") + with pytest.raises(ValueError): + verifier.decode_token(token) diff --git a/tests/test_redis_conversation_store.py b/tests/test_redis_conversation_store.py new file mode 100644 index 0000000..5db2e46 --- /dev/null +++ b/tests/test_redis_conversation_store.py @@ -0,0 +1,89 @@ +"""Tests for RedisConversationStore. + +Uses fakeredis so no real Redis connection is required. +All tests follow the same ConversationStore contract as InMemoryConversationStore. +""" + +import pytest +import fakeredis + + +@pytest.fixture +def redis_client(): + """Return an in-process fake Redis client.""" + return fakeredis.FakeRedis() + + +@pytest.fixture +def store(redis_client): + """Return a RedisConversationStore backed by fake Redis.""" + from app.infrastructure.session.redis_conversation_store import RedisConversationStore + return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800) + + +def test_create_session_returns_session_with_id(store): + """create_session() must return a ConversationSession with a non-empty session_id.""" + session = store.create_session() + assert session.session_id + assert len(session.session_id) > 0 + + +def test_get_session_returns_same_session(store): + """get_session() must return the previously created session.""" + session = store.create_session() + fetched = store.get_session(session.session_id) + assert fetched is not None + assert fetched.session_id == session.session_id + + +def test_get_session_returns_none_for_unknown_id(store): + """get_session() must return None when the session_id does not exist.""" + assert store.get_session("nonexistent-id") is None + + +def test_save_message_appends_to_session(store): + """save_message() must append a message and return the updated session.""" + session = store.create_session() + updated = store.save_message(session.session_id, role="user", content="Hello") + assert updated is not None + assert len(updated.messages) == 1 + assert updated.messages[0].role == "user" + assert updated.messages[0].content == "Hello" + + +def test_save_message_persists_across_lookups(store): + """Messages saved to a session must be visible in subsequent get_session calls.""" + session = store.create_session() + store.save_message(session.session_id, role="user", content="test") + fetched = store.get_session(session.session_id) + assert fetched is not None + assert len(fetched.messages) == 1 + + +def test_delete_session_removes_it(store): + """delete_session() must return True and remove the session.""" + session = store.create_session() + assert store.delete_session(session.session_id) is True + assert store.get_session(session.session_id) is None + + +def test_delete_session_returns_false_for_unknown(store): + """delete_session() must return False when the session does not exist.""" + assert store.delete_session("ghost-id") is False + + +def test_list_sessions_includes_created_session(store): + """list_sessions() must include all active sessions.""" + session = store.create_session() + ids = [s["session_id"] for s in store.list_sessions()] + assert session.session_id in ids + + +def test_session_expires_after_ttl(redis_client): + """Sessions must disappear after the TTL expires.""" + from app.infrastructure.session.redis_conversation_store import RedisConversationStore + store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1) + session = store.create_session() + # Simulate TTL expiry by deleting the key directly (fakeredis expire(0) is a no-op). + redis_client.delete(f"session:{session.session_id}") + assert store.get_session(session.session_id) is None diff --git a/tests/test_reranker_bootstrap.py b/tests/test_reranker_bootstrap.py new file mode 100644 index 0000000..f9e58d4 --- /dev/null +++ b/tests/test_reranker_bootstrap.py @@ -0,0 +1,39 @@ +"""Verify that bootstrap correctly wires the reranker when the setting is enabled.""" + +from unittest.mock import patch + +import pytest + + +def test_get_reranker_returns_none_when_disabled(): + """get_reranker() must return None when reranker_enabled is False.""" + with patch("app.shared.bootstrap._build_binary_store"), \ + patch("app.shared.bootstrap._build_vector_index"): + from app.shared import bootstrap + bootstrap.get_reranker.cache_clear() + + with patch("app.shared.bootstrap.settings") as mock_settings: + mock_settings.reranker_enabled = False + mock_settings.reranker_base_url = "" + result = bootstrap.get_reranker() + + bootstrap.get_reranker.cache_clear() + assert result is None + + +def test_get_reranker_returns_instance_when_enabled(): + """get_reranker() must return an OpenAICompatibleReranker when enabled.""" + from app.shared import bootstrap + bootstrap.get_reranker.cache_clear() + + with patch("app.shared.bootstrap.settings") as mock_settings: + mock_settings.reranker_enabled = True + mock_settings.reranker_base_url = "http://localhost:8082" + mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3" + mock_settings.reranker_api_key = "" + mock_settings.reranker_top_k = 5 + result = bootstrap.get_reranker() + + bootstrap.get_reranker.cache_clear() + from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker + assert isinstance(result, OpenAICompatibleReranker)