AIRegulation-DocAnalysis/docs/superpowers/plans/2026-06-05-phase1-production-foundation.md
wangwei 9fea9c6a53 1. Add login feature
2. Adjust font size
3. Add some new features
2026-06-05 18:00:31 +08:00


# Phase 1: Production Foundation Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Transform the existing synchronous RAG system into a production-ready backend with async task processing, persistent sessions, and JWT authentication — turning "can run" into "can operate."

**Architecture:** Four independent workstreams executed in priority order: (A) enable the already-wired reranker, (B) introduce Celery workers backed by the already-configured Redis to make document processing non-blocking, (C) replace the in-memory conversation store with a Redis-backed implementation, and (D) add JWT authentication with role claims and audit logging across all API routes.

**Tech Stack:** FastAPI · Celery 5 · Redis 7 · PostgreSQL 15 · `python-jose[cryptography]` · `passlib[bcrypt]` · `fakeredis` (test) — all running on the existing `docker-compose.yml` infrastructure.

---

## Scope Note

These four workstreams are **largely independent** and can be worked in parallel by separate developers. Auth (Group D) should be integrated last because it modifies every route. The recommended solo sequence: A → B → C → D.

**Not in scope (YAGNI):** Fine-grained per-endpoint RBAC enforcement (the role claim is added to JWT as a future enforcement point), user self-registration, perception crawling, knowledge graph, EHS module.

---
## File Map
### Group A — Reranker Quick Win
| Action | File |
|--------|------|
| Modify | `.env.example` — document reranker env vars |
| Create | `tests/test_reranker_bootstrap.py` |
### Group B — Async Task-ification
| Action | File |
|--------|------|
| Create | `backend/app/infrastructure/tasks/__init__.py` |
| Create | `backend/app/infrastructure/tasks/celery_app.py` |
| Create | `backend/app/infrastructure/tasks/document_tasks.py` |
| Modify | `backend/app/application/documents/services.py` — extract `_process_document` |
| Modify | `backend/app/api/routes/documents.py` — enqueue vs inline |
| Modify | `backend/app/shared/bootstrap.py` — expose `get_celery_app()` |
| Modify | `dev.sh` — add worker start target |
| Create | `tests/test_celery_tasks.py` |
### Group C — Session Persistence
| Action | File |
|--------|------|
| Create | `backend/app/infrastructure/session/redis_conversation_store.py` |
| Modify | `backend/app/config/settings.py` — add `session_backend` field |
| Modify | `backend/app/shared/bootstrap.py` — switch store on setting |
| Modify | `pyproject.toml` — add `fakeredis` dev dep |
| Create | `tests/test_redis_conversation_store.py` |
### Group D — JWT Auth + RBAC + Audit
| Action | File |
|--------|------|
| Modify | `requirements.txt` — add `python-jose`, `passlib[bcrypt]` |
| Modify | `pyproject.toml` — same |
| Create | `backend/app/domain/auth/__init__.py` |
| Create | `backend/app/domain/auth/models.py` |
| Create | `backend/app/infrastructure/auth/__init__.py` |
| Create | `backend/app/infrastructure/auth/jwt_handler.py` |
| Create | `backend/app/infrastructure/auth/user_store.py` |
| Create | `backend/app/api/dependencies/__init__.py` |
| Create | `backend/app/api/dependencies/auth.py` |
| Create | `backend/app/api/middleware/__init__.py` |
| Create | `backend/app/api/middleware/audit.py` |
| Create | `backend/app/api/routes/auth.py` |
| Modify | `backend/app/api/routes/__init__.py` — include auth router |
| Modify | `backend/app/config/settings.py` — add auth fields |
| Modify | `backend/app/shared/bootstrap.py` — wire auth components |
| Modify | `backend/app/api/main.py` — restrict CORS, add audit middleware |
| Modify | `backend/app/api/routes/documents.py` — require auth |
| Modify | `backend/app/api/routes/rag.py` — require auth |
| Modify | `backend/app/api/routes/compliance.py` — require auth |
| Create | `scripts/seed_users.py` |
| Create | `tests/test_jwt_handler.py` |
| Create | `tests/test_auth_routes.py` |
---
## Group A — Reranker Quick Win
> Estimated time: 30 minutes. No new code, just configuration and a verification test.
>
> **Why:** `settings.reranker_enabled` defaults to `False` in `backend/app/config/settings.py:113`. The entire `OpenAICompatibleReranker` infrastructure at `backend/app/infrastructure/vectorstore/cross_encoder_reranker.py` and the wiring in `backend/app/shared/bootstrap.py:199-203` is already complete. Enabling it requires setting two env vars and verifying the wiring works.
### Task 1: Document reranker env vars and write a bootstrap verification test
**Files:**
- Modify: `.env.example`
- Create: `tests/test_reranker_bootstrap.py`
- [ ] **Step 1: Add reranker vars to `.env.example`**
Open `.env.example`. Locate the existing embedding/LLM block and append:
```ini
# ── Reranker (Cross-Encoder) ──────────────────────────────────────────────────
# Set RERANKER_ENABLED=true and point to a TEI or Cohere-compatible rerank API.
# Recommended models: BAAI/bge-reranker-v2-m3 (lightweight, multilingual) or
# BAAI/bge-reranker-v2.5-gemma2-lightweight (larger LLM-based model, needs more GPU memory).
# The endpoint must expose POST /rerank (TEI style) or POST /v1/rerank (Cohere style).
RERANKER_ENABLED=true
RERANKER_BASE_URL=http://6.86.80.4:30080/v1
RERANKER_MODEL=BAAI/bge-reranker-v2-m3
RERANKER_API_KEY=
RERANKER_TOP_K=5
```
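The `.env.example` comment names two endpoint styles. The request format used by `OpenAICompatibleReranker` is not shown in this plan, so the following is only a rough sketch of the public conventions as we understand them. TEI takes `query`/`texts` and returns a list of `{index, score}`. Cohere-style APIs take `model`/`query`/`documents`/`top_n` and return `{results: [{index, relevance_score}]}`. The helper names are hypothetical. It also assumes `RERANKER_BASE_URL` already carries any version prefix, so a TEI server (which serves `/rerank` at its root) would need a base URL without `/v1`.

```python
from __future__ import annotations

from typing import Any


def build_rerank_request(
    style: str, base_url: str, model: str, query: str, texts: list[str], top_k: int
) -> tuple[str, dict[str, Any]]:
    """Return (url, json_body) for a rerank call (hypothetical helper).

    Assumes base_url already includes any version prefix (e.g. ".../v1" for
    Cohere-style servers), so both styles simply append "/rerank".
    """
    url = base_url.rstrip("/") + "/rerank"
    if style == "tei":
        return url, {"query": query, "texts": texts}
    if style == "cohere":
        return url, {"model": model, "query": query, "documents": texts, "top_n": top_k}
    raise ValueError(f"unknown rerank style: {style!r}")


def top_indices(style: str, response: Any, top_k: int) -> list[int]:
    """Return candidate indices ordered by descending relevance, truncated to top_k."""
    if style == "tei":
        scored = [(item["index"], item["score"]) for item in response]
    elif style == "cohere":
        scored = [(item["index"], item["relevance_score"]) for item in response["results"]]
    else:
        raise ValueError(f"unknown rerank style: {style!r}")
    # Highest score first; the reranker's order replaces the vector-search order.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [index for index, _ in scored[:top_k]]
```

For example, `build_rerank_request("cohere", "http://localhost:8082/v1", "BAAI/bge-reranker-v2-m3", "q", ["a", "b"], 2)` targets `http://localhost:8082/v1/rerank`. The retriever would then keep only the passages whose indices `top_indices` returns.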
- [ ] **Step 2: Write the failing test**
Create `tests/test_reranker_bootstrap.py`:
```python
"""Verify that bootstrap correctly wires the reranker when the setting is enabled."""
from unittest.mock import patch

from app.shared import bootstrap

# Patch `settings` where bootstrap looks it up. This assumes bootstrap.py binds it
# with `from app.config.settings import settings`; patching
# "app.config.settings.settings" would leave bootstrap's own reference untouched.


def test_get_reranker_returns_none_when_disabled():
    """get_reranker() must return None when reranker_enabled is False."""
    with patch("app.shared.bootstrap._build_binary_store"), \
            patch("app.shared.bootstrap._build_vector_index"):
        # Clear the lru_cache so get_reranker() re-reads the patched settings.
        bootstrap.get_reranker.cache_clear()
        with patch("app.shared.bootstrap.settings") as mock_settings:
            mock_settings.reranker_enabled = False
            mock_settings.reranker_base_url = ""
            result = bootstrap.get_reranker()
        bootstrap.get_reranker.cache_clear()
    assert result is None


def test_get_reranker_returns_instance_when_enabled():
    """get_reranker() must return an OpenAICompatibleReranker when enabled."""
    from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker

    bootstrap.get_reranker.cache_clear()
    with patch("app.shared.bootstrap.settings") as mock_settings:
        mock_settings.reranker_enabled = True
        mock_settings.reranker_base_url = "http://localhost:8082"
        mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3"
        mock_settings.reranker_api_key = ""
        mock_settings.reranker_top_k = 5
        result = bootstrap.get_reranker()
    bootstrap.get_reranker.cache_clear()
    assert isinstance(result, OpenAICompatibleReranker)
```
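The tests above patch a `settings` object. `unittest.mock.patch` replaces a name in the namespace where it is looked up, not where it was defined. So if `bootstrap.py` binds the object with `from app.config.settings import settings`, only a patch on `app.shared.bootstrap.settings` changes what `get_reranker()` sees. The self-contained demonstration below uses throwaway modules whose names are hypothetical stand-ins for the real ones.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical stand-ins for app.config.settings and app.shared.bootstrap.
config_module = types.ModuleType("demo_settings_module")
config_module.settings = types.SimpleNamespace(reranker_enabled=False)

bootstrap_module = types.ModuleType("demo_bootstrap_module")
# Equivalent of `from demo_settings_module import settings`: copies the reference.
bootstrap_module.settings = config_module.settings

sys.modules[config_module.__name__] = config_module
sys.modules[bootstrap_module.__name__] = bootstrap_module


def reranker_enabled() -> bool:
    # Reads the name from the consumer module at call time, as bootstrap code would.
    return bootstrap_module.settings.reranker_enabled


def demo() -> tuple[bool, bool]:
    with patch("demo_settings_module.settings") as mock_settings:
        mock_settings.reranker_enabled = True
        via_source = reranker_enabled()  # still False: consumer holds the old object
    with patch("demo_bootstrap_module.settings") as mock_settings:
        mock_settings.reranker_enabled = True
        via_consumer = reranker_enabled()  # True: patched where it is looked up
    return via_source, via_consumer


# demo() returns (False, True)
```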
- [ ] **Step 3: Run the tests to confirm the existing wiring**
```bash
PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py -v
```
Expected: both tests PASS immediately, because the bootstrap logic already exists. Passing confirms the wiring is correct; proceed.
- [ ] **Step 4: Enable reranker in your local `.env`**
Add these two lines to your local `.env` (not `.env.example`, which is already updated):
```ini
RERANKER_ENABLED=true
RERANKER_BASE_URL=http://6.86.80.4:30080/v1
```
- [ ] **Step 5: Verify `GET /api/v1/status/health` reports reranker enabled**
Start the backend and call:
```bash
curl http://localhost:8000/api/v1/status/health | python -m json.tool
```
Expected fragment in response:
```json
"reranker": {
"enabled": true,
"model": "BAAI/bge-reranker-v2-m3"
}
```
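For a scripted version of this check, a small parser over the response body is enough. This is a sketch that assumes `reranker` is a top-level key of the health JSON, as in the fragment shown; adjust the lookup if it is nested. `reranker_status` is a hypothetical helper name.

```python
import json


def reranker_status(health_body: str) -> tuple[bool, str]:
    """Return (enabled, model) from a /api/v1/status/health response body.

    Assumes "reranker" is a top-level key; a missing key or missing fields
    are reported as disabled with an empty model name.
    """
    payload = json.loads(health_body)
    reranker = payload.get("reranker") or {}
    return bool(reranker.get("enabled", False)), str(reranker.get("model", ""))


# Usage against a running backend (not executed here):
#   from urllib.request import urlopen
#   with urlopen("http://localhost:8000/api/v1/status/health", timeout=10) as resp:
#       print(reranker_status(resp.read().decode("utf-8")))
```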
> ✅ **Group A complete.** The reranker is now active for all `KnowledgeRetrievalService.retrieve()` calls.
---
## Group B — Async Task-ification
> Estimated time: 3-5 days.
>
> **Why:** `upload_document` at `backend/app/api/routes/documents.py:34` runs the full parse→embed→index pipeline synchronously in the HTTP request. For large GB (Chinese national) standards this can take 15+ minutes (the Aliyun timeout is 900 s per `settings.py:49`). Celery 5 and Redis are already in `requirements.txt` and `docker-compose.yml`; only the wiring is missing.
>
> **Strategy:** Split `DocumentCommandService.upload_and_process` into a fast `store_document` (synchronous, returns immediately) and a background `_process_document` (called from a Celery task). The existing `DocumentProcessingStore` already records per-stage status events, so it becomes the progress tracker. The HTTP route returns `{doc_id, status: "stored"}` immediately; the client polls `GET /documents/status/{doc_id}` (already exists) for progress.
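
On the client side, this enqueue-then-poll flow reduces to a loop like the sketch below. The terminal status values `"indexed"` and `"failed"`, the `"status"` field in the response body, and the `fetch_status` callable (a wrapper around `GET /documents/status/{doc_id}`) are all assumptions, not the endpoint's documented contract. The default timeout simply mirrors the 900 s figure above.

```python
from __future__ import annotations

import time
from typing import Any, Callable

# Assumed terminal values of DocumentStatus; adjust to the real enum.
TERMINAL_STATUSES = frozenset({"indexed", "failed"})


def wait_for_processing(
    fetch_status: Callable[[str], dict[str, Any]],
    doc_id: str,
    *,
    interval: float = 2.0,
    timeout: float = 900.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> dict[str, Any]:
    """Poll until the document reaches a terminal status or the timeout elapses.

    `fetch_status` is a hypothetical callable that performs the HTTP GET and
    returns the decoded JSON body. `sleep` and `clock` are injectable for tests.
    """
    deadline = clock() + timeout
    while True:
        body = fetch_status(doc_id)
        if body.get("status") in TERMINAL_STATUSES:
            return body
        if clock() >= deadline:
            raise TimeoutError(f"{doc_id} still {body.get('status')!r} after {timeout}s")
        sleep(interval)
```

Injecting `sleep` and `clock` keeps the loop testable without real waiting. A production client might also grow the interval for long GB documents.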
### Task 2: Create the Celery application
**Files:**
- Create: `backend/app/infrastructure/tasks/__init__.py`
- Create: `backend/app/infrastructure/tasks/celery_app.py`
- Create: `tests/test_celery_tasks.py` (first test only)
- [ ] **Step 1: Write a failing test for Celery app configuration**
Create `tests/test_celery_tasks.py`:
```python
"""Tests for Celery task infrastructure.

These tests verify task registration and Celery app configuration without
starting a real worker or connecting to Redis.
"""
import pytest


def test_celery_app_uses_redis_broker():
    """Celery broker URL must be a Redis URL built from settings."""
    from app.infrastructure.tasks.celery_app import celery_app

    assert celery_app.conf.broker_url.startswith("redis://")


def test_celery_app_uses_redis_backend():
    """Celery result backend must be Redis."""
    from app.infrastructure.tasks.celery_app import celery_app

    assert celery_app.conf.result_backend.startswith("redis://")


def test_celery_app_has_json_serializer():
    """Task serializer must be JSON for portability."""
    from app.infrastructure.tasks.celery_app import celery_app

    assert celery_app.conf.task_serializer == "json"
```
- [ ] **Step 2: Run test to confirm it fails**
```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_celery_app_uses_redis_broker -v
```
Expected: `ModuleNotFoundError: No module named 'app.infrastructure.tasks'`
- [ ] **Step 3: Create the tasks package**
Create `backend/app/infrastructure/tasks/__init__.py`:
```python
"""Celery task definitions for background processing.
This package exposes the shared Celery application instance and all
registered task functions used by API routes to enqueue work.
"""
```
- [ ] **Step 4: Create the Celery app factory**
Create `backend/app/infrastructure/tasks/celery_app.py`:
```python
"""Shared Celery application instance for background task processing.

All workers and enqueueing call sites import `celery_app` from this module
so the broker/backend configuration stays in one place.
"""
from __future__ import annotations

from celery import Celery

from app.config.settings import settings


def _redis_url() -> str:
    """Return a Redis connection URL built from application settings.

    The password segment is omitted when redis_password is empty so that
    local development against a password-free Redis instance works out of the box.
    """
    if settings.redis_password:
        return (
            f"redis://:{settings.redis_password}@"
            f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
        )
    return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"


# Broker and result backend share the same Redis database.
_BROKER = _redis_url()
_BACKEND = _redis_url()

celery_app = Celery(
    "compliance_hub",
    broker=_BROKER,
    backend=_BACKEND,
    # Imported by the worker at startup so the document tasks are registered.
    include=["app.infrastructure.tasks.document_tasks"],
)

celery_app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="UTC",
    enable_utc=True,
    # Acknowledge a message only after the task finishes, and requeue it if the
    # worker process dies mid-task, so crashed work is redelivered, not lost.
    # (Retry counts and delays are configured per task in document_tasks.py.)
    task_acks_late=True,
    task_reject_on_worker_lost=True,
    # Keep results for 1 hour for status polling.
    result_expires=3600,
)
```
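`_redis_url` interpolates `redis_password` verbatim. If a production password contains URL-reserved characters such as `@`, `:` or `/`, the resulting URL will not parse as intended. A hedged variant percent-encodes the password with `urllib.parse.quote`. As far as we know, both kombu (Celery's transport layer) and redis-py percent-decode the password when parsing such URLs, but confirm this against the versions pinned in `requirements.txt`. `redis_url` is a hypothetical helper name.

```python
from urllib.parse import quote


def redis_url(host: str, port: int, db: int, password: str = "") -> str:
    """Build a redis:// URL, percent-encoding the password (hypothetical helper)."""
    if password:
        # safe="" also encodes "/" so it cannot be mistaken for the path separator.
        return f"redis://:{quote(password, safe='')}@{host}:{port}/{db}"
    return f"redis://{host}:{port}/{db}"
```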
- [ ] **Step 5: Run the Celery app tests**
```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py -v
```
Expected: `test_celery_app_uses_redis_broker` PASS, `test_celery_app_uses_redis_backend` PASS, `test_celery_app_has_json_serializer` PASS.
### Task 3: Split `DocumentCommandService` and create the document processing task
**Files:**
- Modify: `backend/app/application/documents/services.py`
- Create: `backend/app/infrastructure/tasks/document_tasks.py`
- [ ] **Step 1: Add task registration test to `tests/test_celery_tasks.py`**
Append to `tests/test_celery_tasks.py`:
```python
def test_process_document_task_is_registered():
    """process_document_task must be discoverable in the Celery task registry."""
    # Importing document_tasks triggers task registration via the @task decorator.
    import app.infrastructure.tasks.document_tasks  # noqa: F401
    from app.infrastructure.tasks.celery_app import celery_app

    registered = list(celery_app.tasks.keys())
    assert any("process_document_task" in name for name in registered), (
        f"process_document_task not found in {registered}"
    )
```
- [ ] **Step 2: Run test to confirm it fails**
```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_process_document_task_is_registered -v
```
Expected: `ModuleNotFoundError: No module named 'app.infrastructure.tasks.document_tasks'`
- [ ] **Step 3: Extract `_process_document` from `DocumentCommandService`**
In `backend/app/application/documents/services.py`, locate `upload_and_process` (line 233). Add a new method `_process_document` that contains the parse→embed→index logic, and refactor `upload_and_process` to call it after the store step.
Find the end of the `_safe_mark_run_stored` call block (around line 299) and extract everything from the `suffix = os.path.splitext(...)` line to the end of the try block into a new method:
```python
def _process_document(
    self,
    *,
    doc_id: str,
    file_name: str,
    final_doc_name: str,
    content: bytes,
    regulation_type: str,
    version: str,
    generate_summary: bool,
    run_id: str | None = None,
) -> DocumentProcessResult:
    """Run parse → chunk → embed → index for a document that is already stored.

    This method is called both synchronously (from upload_and_process) and
    asynchronously (from the Celery document_tasks worker). All side-effects
    write through DocumentProcessingStore so callers can poll progress.
    Processing errors are caught here: the document is marked FAILED and a
    FAILED result is returned instead of raising.
    """
    current_status = DocumentStatus.STORED
    current_stage = "parse"
    temp_path = ""
    try:
        # The parser works on a file path, so spill the bytes to a temp file.
        suffix = os.path.splitext(file_name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(content)
            temp_path = temp_file.name

        # Parse
        parsed_document = self.parser.parse(
            file_path=temp_path,
            doc_id=doc_id,
            doc_name=final_doc_name,
        )
        self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)

        # Artifact persistence is best-effort; a failure here must not abort processing.
        artifact_keys: dict[str, str] = {}
        try:
            artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
        except Exception:
            logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)

        self.document_repository.update_status(
            doc_id,
            DocumentStatus.PARSED,
            parser_name=parsed_document.parser_name,
            metadata={
                "parser_backend": parsed_document.parser_name,
                "parse_task_id": parsed_document.metadata.get("task_id", ""),
                "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
                "structure_node_count": len(parsed_document.structure_nodes),
                "semantic_block_count": len(parsed_document.semantic_blocks),
                "vector_chunk_count": len(parsed_document.vector_chunks),
                "artifact_keys": artifact_keys,
                "processing_stage": "parsed",
            },
        )
        current_status = DocumentStatus.PARSED
        current_stage = "embed"
        self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
        self._safe_append_status_event(
            doc_id=doc_id, run_id=run_id,
            from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value,
            stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)},
        )
        if self.parse_artifact_store:
            try:
                self.parse_artifact_store.save(doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks)
            except Exception:
                logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)

        # Chunk and embed
        chunks = self.chunk_builder.build(
            parsed_document=parsed_document,
            regulation_type=regulation_type,
            version=version,
        )
        if not chunks:
            # "Parsing finished but produced no chunks that can be indexed"
            raise ValueError("解析完成但没有生成可入库的 chunks")
        vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])

        # Index
        current_stage = "index"
        inserted = self.vector_index.upsert(chunks, vectors)
        if inserted != len(chunks):
            logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))
        health = self.vector_index.health()
        index_name = health.get("collection_name", "")
        # No summary is generated in this path; the summary fields are stored empty.
        self.document_repository.update_status(
            doc_id, DocumentStatus.INDEXED,
            chunk_count=len(chunks), summary="", summary_latency_ms=0,
            index_name=index_name,
            metadata={"index_collection": index_name, "processing_stage": "indexed"},
        )
        self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name)
        self._safe_append_status_event(
            doc_id=doc_id, run_id=run_id,
            from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value,
            stage="index", message="Document indexed",
            metadata={"chunk_count": len(chunks), "index_name": index_name},
        )

        stored = self.document_repository.get(doc_id)
        return DocumentProcessResult(
            doc_id=doc_id, doc_name=final_doc_name,
            status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
            message="处理成功",  # "Processed successfully"
            num_chunks=len(chunks),
            summary=stored.summary if stored else "",
            summary_latency_ms=stored.summary_latency_ms if stored else 0,
        )
    except Exception as exc:
        logger.exception("文档处理失败: doc_id={}", doc_id)  # "Document processing failed"
        self.document_repository.update_status(
            doc_id, DocumentStatus.FAILED, error_message=str(exc),
            metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage},
        )
        self._safe_mark_run_failed(doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc))
        self._safe_append_status_event(
            doc_id=doc_id, run_id=run_id,
            from_status=current_status.value, to_status=DocumentStatus.FAILED.value,
            stage=current_stage, message=str(exc),
        )
        return DocumentProcessResult(
            doc_id=doc_id, doc_name=final_doc_name,
            status=DocumentStatus.FAILED.value,
            message=f"文档处理失败: {exc}",  # "Document processing failed: ..."
        )
    finally:
        # Always remove the temp file, whether processing succeeded or failed.
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning("临时文件清理失败: {}", temp_path)  # "Temp file cleanup failed"
```
Then, in `upload_and_process`, replace everything from `suffix = os.path.splitext(file_name)[1]` through the end of the `try` block (just before the `except`) with a call to `_process_document`:
```python
# Delegate parse → embed → index to the shared processing method.
# This same method is invoked by the Celery worker for async processing.
return self._process_document(
    doc_id=doc_id,
    file_name=file_name,
    final_doc_name=final_doc_name,
    content=content,
    regulation_type=regulation_type,
    version=version,
    generate_summary=generate_summary,
    run_id=run_id,
)
```
Also add a new method, `store_document`, for the fast store-only path:
```python
def store_document(
    self,
    *,
    doc_id: str | None = None,
    file_name: str,
    content: bytes,
    content_type: str,
    doc_name: str | None,
    regulation_type: str,
    version: str,
    generate_summary: bool,
) -> tuple[str, str | None]:
    """Store the binary file and create the Document record.

    Returns (doc_id, run_id). Does NOT parse, embed, or index.
    This is the fast synchronous first step; processing is enqueued separately.
    """
    doc_id = doc_id or str(uuid.uuid4())[:8]
    final_doc_name = doc_name or file_name
    object_name = f"{doc_id}/{file_name}"
    document = Document(
        doc_id=doc_id,
        doc_name=final_doc_name,
        file_name=file_name,
        object_name=object_name,
        content_type=content_type,
        size_bytes=len(content),
        regulation_type=regulation_type,
        version=version,
        metadata={"generate_summary": generate_summary},
    )
    self.document_repository.create(document)
    run_id = self._safe_create_processing_run(
        doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary
    )
    self.binary_store.save(
        object_name=object_name, data=content,
        content_type=content_type, metadata={"doc_id": doc_id},
    )
    self.document_repository.update_status(doc_id, DocumentStatus.STORED)
    self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
    self._safe_append_status_event(
        doc_id=doc_id, run_id=run_id,
        from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value,
        stage="store", message="Source file stored",
    )
    return doc_id, run_id
```
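Taken together, `store_document` and `_process_document` emit a small, fixed set of status transitions: PENDING → STORED (store), STORED → PARSED (parse), PARSED → INDEXED (index), and STORED or PARSED → FAILED on error. A guard like the one below could help a test or a polling client assert that nothing else occurs. It is hypothetical: the enum is a stand-in for the domain `DocumentStatus`, and its lowercase string values are assumed from the literal `"stored"` and `"failed"` strings used in the upload route.

```python
from __future__ import annotations

from enum import Enum


class DocumentStatus(str, Enum):
    """Hypothetical mirror of the domain enum; the string values are assumed."""
    PENDING = "pending"
    STORED = "stored"
    PARSED = "parsed"
    INDEXED = "indexed"
    FAILED = "failed"


# Transitions emitted by store_document and _process_document in this plan.
ALLOWED_TRANSITIONS: dict[DocumentStatus, frozenset[DocumentStatus]] = {
    DocumentStatus.PENDING: frozenset({DocumentStatus.STORED}),
    DocumentStatus.STORED: frozenset({DocumentStatus.PARSED, DocumentStatus.FAILED}),
    DocumentStatus.PARSED: frozenset({DocumentStatus.INDEXED, DocumentStatus.FAILED}),
    DocumentStatus.INDEXED: frozenset(),
    DocumentStatus.FAILED: frozenset(),
}


def check_transition(src: DocumentStatus, dst: DocumentStatus) -> None:
    """Raise ValueError if src -> dst is not a transition the pipeline produces."""
    if dst not in ALLOWED_TRANSITIONS[src]:
        raise ValueError(f"unexpected status transition {src.value} -> {dst.value}")


def is_terminal(status: DocumentStatus) -> bool:
    """INDEXED and FAILED have no outgoing transitions."""
    return not ALLOWED_TRANSITIONS[status]
```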
- [ ] **Step 4: Create the Celery document task**
Create `backend/app/infrastructure/tasks/document_tasks.py`:
```python
"""Celery tasks for document processing.

Each task is a thin wrapper that retrieves the already-stored document
binary and delegates to DocumentCommandService._process_document.
The task does not accept raw file bytes; it reads them from the binary
store using the doc_id, so the message payload stays small.
"""
from __future__ import annotations

from loguru import logger

from app.infrastructure.tasks.celery_app import celery_app


@celery_app.task(
    name="app.infrastructure.tasks.document_tasks.process_document_task",
    bind=True,
    max_retries=3,
    default_retry_delay=30,
    acks_late=True,
)
def process_document_task(
    self,
    doc_id: str,
    file_name: str,
    doc_name: str,
    regulation_type: str,
    version: str,
    generate_summary: bool,
    run_id: str | None = None,
) -> dict:
    """Parse, embed, and index a document that has already been stored.

    The task reads the file binary from MinIO using doc_id so the Celery
    message stays small. Retries up to 3 times with a 30-second delay when
    the record lookup or the binary read fails.
    """
    # Import lazily so that importing this module (e.g. by the worker at startup)
    # stays lightweight and free of import cycles; the bootstrap getters then
    # create their singletons per worker process on first call.
    from app.shared.bootstrap import get_document_command_service, get_document_query_service

    logger.info("process_document_task started: doc_id={}", doc_id)
    try:
        # Read the stored binary from MinIO.
        svc = get_document_command_service()
        doc = get_document_query_service().get(doc_id)
        if not doc:
            raise ValueError(f"Document record not found: {doc_id}")
        content = svc.binary_store.read(doc.object_name)

        result = svc._process_document(
            doc_id=doc_id,
            file_name=file_name,
            final_doc_name=doc_name,
            content=content,
            regulation_type=regulation_type,
            version=version,
            generate_summary=generate_summary,
            run_id=run_id,
        )
        logger.info(
            "process_document_task completed: doc_id={} status={} chunks={}",
            doc_id, result.status, result.num_chunks,
        )
        return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks}
    except Exception as exc:
        logger.exception("process_document_task failed: doc_id={}", doc_id)
        # _process_document catches its own errors (bad file, parse/embed/index
        # failures), marks the document FAILED and returns normally, so those are
        # not retried. Only failures before it (missing record, binary read errors)
        # reach this handler; they are retried up to max_retries, then re-raised.
        raise self.retry(exc=exc)
```
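`process_document_task` retries with a fixed 30-second delay (`default_retry_delay=30`). If exponential backoff is preferred, one hedged option is to compute the delay from the retry count. Celery's `Task.retry` accepts a `countdown` argument, and bound tasks expose the number of retries already attempted as `self.request.retries`. `backoff_countdown` and its defaults are hypothetical.

```python
def backoff_countdown(retries: int, *, base: int = 30, cap: int = 600) -> int:
    """Seconds to wait before the next retry, doubling each time up to `cap`.

    `retries` is the number of retries already attempted (0 on the first
    failure), i.e. what Celery exposes as `self.request.retries`.
    """
    return min(cap, base * 2 ** retries)


# Inside process_document_task's except block this would become:
#     raise self.retry(exc=exc, countdown=backoff_countdown(self.request.retries))
```

Celery also supports declarative retries through the `autoretry_for`, `retry_backoff` and `retry_backoff_max` task options, which add random jitter by default (`retry_jitter`). That approach avoids hand-written retry code, but it only retries the exception types listed in `autoretry_for`.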
- [ ] **Step 5: Run all Celery task tests**
```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py -v
```
Expected: All 4 tests PASS.
### Task 4: Update the upload route to support async enqueueing
**Files:**
- Modify: `backend/app/api/routes/documents.py`
- Modify: `backend/app/shared/bootstrap.py`
- [ ] **Step 1: Add `get_celery_app` to bootstrap**
In `backend/app/shared/bootstrap.py`, add the following near the end of the module, just above `preload_runtime_dependencies`:
```python
@lru_cache
def get_celery_app():
    """Return the shared Celery application instance."""
    # Import here to avoid importing Celery at module load time when workers
    # are not configured (e.g., during tests that mock bootstrap).
    from app.infrastructure.tasks.celery_app import celery_app

    return celery_app
```
- [ ] **Step 2: Update the upload route**
Replace the `upload_document` handler in `backend/app/api/routes/documents.py` with:
```python
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(..., description="上传的文档文件"),
    doc_id: str | None = Form(None, description="客户端预分配的文档ID不传则自动生成"),
    doc_name: str | None = Form(None, description="文档名称"),
    regulation_type: str | None = Form(None, description="法规类型"),
    version: str | None = Form(None, description="文档版本"),
    generate_summary: bool = Form(False, description="是否生成摘要"),
    sync: bool = Form(False, description="同步处理(演示/测试用,默认异步)"),
):
    """Upload a document and enqueue it for background processing.

    By default the route stores the binary immediately and enqueues parse→embed→index
    as a Celery task, returning {doc_id, status:'stored'} within seconds.
    Pass sync=true to process inline (useful for demos when no worker is running).
    Poll GET /documents/status/{doc_id} for processing progress.
    """
    content = await file.read()
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")  # "File name must not be empty"
    if not content:
        raise HTTPException(status_code=400, detail="上传文件为空")  # "Uploaded file is empty"
    try:
        svc = get_document_command_service()
        if sync:
            # Synchronous fallback: full inline processing (demo / no-worker mode).
            result = svc.upload_and_process(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
        else:
            # Async default: store binary, enqueue processing task.
            stored_doc_id, run_id = svc.store_document(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
            # Imported lazily so the API does not load the task module at import time.
            from app.infrastructure.tasks.document_tasks import process_document_task

            process_document_task.delay(
                doc_id=stored_doc_id,
                file_name=file.filename,
                doc_name=doc_name or file.filename,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
                run_id=run_id,
            )
            result = DocumentProcessResult(
                doc_id=stored_doc_id,
                doc_name=doc_name or file.filename,
                status="stored",
                # "File stored; processing in the background. Poll GET ... for progress."
                message="文件已存储,正在后台处理。请轮询 GET /documents/status/{doc_id} 查看进度。",
            )
        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("文档上传失败")  # "Document upload failed"
        raise HTTPException(status_code=500, detail=str(exc)) from exc
```
Also add the import at the top of `documents.py`:
```python
from app.application.documents import DocumentProcessResult
```
- [ ] **Step 3: Verify the app still starts**
```bash
PYTHONPATH=backend python -c "from app.api.main import app; print('OK')"
```
Expected: `OK` with no import errors.
### Task 5: Add worker startup to dev scripts
**Files:**
- Modify: `dev.sh`
- [ ] **Step 1: Add worker target to `dev.sh`**
Open `dev.sh`. Next to the existing handler for `start api`, define two new functions. Place them above the argument-dispatch block so they are already defined when the dispatch runs:
```bash
start_worker() {
    # Start the Celery document processing worker.
    # Uses the same PYTHONPATH as the API; app settings load the same root .env.
    export PYTHONPATH=backend
    celery -A app.infrastructure.tasks.celery_app worker \
        --loglevel=info \
        --concurrency=2 \
        --queues=celery \
        "$@"
}

start_beat() {
    # Start the Celery Beat scheduler for periodic tasks (future: perception crawl).
    export PYTHONPATH=backend
    celery -A app.infrastructure.tasks.celery_app beat \
        --loglevel=info \
        "$@"
}
```
And in the argument dispatch block, add:
```bash
"worker")
start_worker "${@:3}"
;;
"beat")
start_beat "${@:3}"
;;
```
- [ ] **Step 2: Verify the worker starts (requires Redis running)**
Start the worker in one terminal, then ping it from another:
```bash
./dev.sh start worker
PYTHONPATH=backend celery -A app.infrastructure.tasks.celery_app inspect ping
```
Expected: the worker's startup log lists `app.infrastructure.tasks.document_tasks.process_document_task` under `[tasks]`, and `inspect ping` gets a `pong` reply from `celery@<hostname>`.
> ✅ **Group B complete.** Document uploads now return as soon as the file is stored, without waiting for parse/embed/index. Workers process in the background. `GET /documents/status/{doc_id}` reports live progress via `DocumentProcessingStore`.
---
## Group C — Session Persistence
> Estimated time: 2-3 days.
>
> **Why:** `InMemoryConversationStore` at `backend/app/infrastructure/session/in_memory_conversation_store.py` stores all chat sessions in a plain Python dict. Every backend restart loses all active conversations. Redis 7 is already running; persisting sessions there requires only a new adapter and a settings switch.
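
Such an adapter relies on a sliding-TTL pattern: one JSON value per session, with the expiry reset on every write. The sketch below shows that pattern on its own. `DictRedis` is a hypothetical in-memory stand-in that mimics only redis-py's `set(name, value, ex=...)` and `get(name)`, and the session dict shape is assumed. It is not the `RedisConversationStore` implementation.

```python
from __future__ import annotations

import json
import time
from typing import Any, Callable


class DictRedis:
    """Hypothetical in-memory stand-in for the two redis-py calls used below."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._data: dict[str, tuple[str, float]] = {}
        self._clock = clock

    def set(self, name: str, value: str, ex: int | None = None) -> bool:
        # Like SET with EX: writing the key replaces any previous expiry.
        expires_at = self._clock() + ex if ex else float("inf")
        self._data[name] = (value, expires_at)
        return True

    def get(self, name: str) -> str | None:
        item = self._data.get(name)
        if item is None or item[1] <= self._clock():
            self._data.pop(name, None)
            return None
        return item[0]


def save_session(client: Any, session: dict[str, Any], ttl_seconds: int) -> None:
    """Serialise the whole session and refresh its TTL on every write."""
    key = f"session:{session['session_id']}"
    client.set(key, json.dumps(session, ensure_ascii=False), ex=ttl_seconds)


def load_session(client: Any, session_id: str) -> dict[str, Any] | None:
    """Return the session dict, or None if it expired or never existed."""
    raw = client.get(f"session:{session_id}")
    if raw is None:
        return None  # caller should start a new session
    if isinstance(raw, bytes):  # redis-py returns bytes unless decode_responses=True
        raw = raw.decode("utf-8")
    return json.loads(raw)
```

Because every save rewrites the key with a fresh `ex`, a session written at t=0 and again at t=1000 with a 1800 s TTL stays alive until t=2800 rather than t=1800.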
### Task 6: Add `fakeredis` dev dependency and create the Redis store
**Files:**
- Modify: `pyproject.toml`
- Create: `backend/app/infrastructure/session/redis_conversation_store.py`
- Create: `tests/test_redis_conversation_store.py`
- [ ] **Step 1: Add `fakeredis` to dev dependencies**
In `pyproject.toml`, find the `[dependency-groups]` section and update:
```toml
[dependency-groups]
dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "isort>=8.0.1", "fakeredis>=2.0.0"]
```
Install it:
```bash
uv sync
```
- [ ] **Step 2: Write failing tests for RedisConversationStore**
Create `tests/test_redis_conversation_store.py`:
```python
"""Tests for RedisConversationStore.

Uses fakeredis so no real Redis connection is required.
All tests follow the same ConversationStore contract as InMemoryConversationStore.
"""
import json

import fakeredis
import pytest


@pytest.fixture
def redis_client():
    """Return an in-process fake Redis client."""
    return fakeredis.FakeRedis()


@pytest.fixture
def store(redis_client):
    """Return a RedisConversationStore backed by fake Redis."""
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore

    return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800)


def test_create_session_returns_session_with_id(store):
    """create_session() must return a ConversationSession with a non-empty session_id."""
    session = store.create_session()
    assert session.session_id
    assert len(session.session_id) > 0


def test_get_session_returns_same_session(store):
    """get_session() must return the previously created session."""
    session = store.create_session()
    fetched = store.get_session(session.session_id)
    assert fetched is not None
    assert fetched.session_id == session.session_id


def test_get_session_returns_none_for_unknown_id(store):
    """get_session() must return None when the session_id does not exist."""
    result = store.get_session("nonexistent-id")
    assert result is None


def test_save_message_appends_to_session(store):
    """save_message() must append a message and return the updated session."""
    session = store.create_session()
    updated = store.save_message(session.session_id, role="user", content="Hello")
    assert updated is not None
    assert len(updated.messages) == 1
    assert updated.messages[0].role == "user"
    assert updated.messages[0].content == "Hello"


def test_save_message_persists_across_lookups(store):
    """Messages saved to a session must be visible in subsequent get_session calls."""
    session = store.create_session()
    store.save_message(session.session_id, role="user", content="test")
    fetched = store.get_session(session.session_id)
    assert fetched is not None
    assert len(fetched.messages) == 1


def test_delete_session_removes_it(store):
    """delete_session() must return True and remove the session."""
    session = store.create_session()
    result = store.delete_session(session.session_id)
    assert result is True
    assert store.get_session(session.session_id) is None


def test_delete_session_returns_false_for_unknown(store):
    """delete_session() must return False when the session does not exist."""
    result = store.delete_session("ghost-id")
    assert result is False


def test_list_sessions_includes_created_session(store):
    """list_sessions() must include all active sessions."""
    session = store.create_session()
    sessions = store.list_sessions()
    ids = [s["session_id"] for s in sessions]
    assert session.session_id in ids


def test_session_expires_after_ttl(redis_client):
    """Sessions must disappear after the TTL expires."""
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore

    # Use a 1-second TTL for this test.
    store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1)
    session = store.create_session()
    # Manually expire the key to simulate TTL expiry.
    redis_client.expire(f"session:{session.session_id}", 0)
    assert store.get_session(session.session_id) is None
```
- [ ] **Step 3: Run tests to confirm they fail**
```bash
PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v
```
Expected: `ModuleNotFoundError: No module named 'app.infrastructure.session.redis_conversation_store'`
- [ ] **Step 4: Implement `RedisConversationStore`**
Create `backend/app/infrastructure/session/redis_conversation_store.py`:
```python
"""Redis-backed conversation store for persistent chat sessions.

Sessions are stored as JSON strings under the key `session:{session_id}`.
The Redis TTL is refreshed on every write so active sessions stay alive.
On expiry, `get_session` returns None; callers should create a new session.
"""

from __future__ import annotations

import json
import time
import uuid
from typing import Any

from loguru import logger

from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore


class RedisConversationStore(ConversationStore):
    """Store conversation sessions in Redis with automatic TTL expiry.

    Each session is serialised as a JSON object at key ``session:{session_id}``.
    The TTL is reset on every write so sessions stay alive as long as they are active.
    """

    # Prefix for all session keys to avoid collisions with other Redis consumers.
    _PREFIX = "session:"

    def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
        """Initialise the store with an existing Redis client and a TTL in seconds."""
        self._redis = redis_client
        self._ttl = timeout_seconds

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _key(self, session_id: str) -> str:
        """Build the Redis key for a session."""
        return f"{self._PREFIX}{session_id}"

    def _serialise(self, session: ConversationSession) -> str:
        """Serialise a ConversationSession to a JSON string."""
        return json.dumps(
            {
                "session_id": session.session_id,
                "created_at": session.created_at,
                "updated_at": session.updated_at,
                "metadata": session.metadata,
                "messages": [
                    {
                        "role": msg.role,
                        "content": msg.content,
                        "timestamp": msg.timestamp,
                        "sources": msg.sources,
                    }
                    for msg in session.messages
                ],
            },
            ensure_ascii=False,
        )

    def _deserialise(self, raw: bytes | str) -> ConversationSession:
        """Deserialise a JSON string back into a ConversationSession."""
        data = json.loads(raw)
        messages = [
            ConversationMessage(
                role=m["role"],
                content=m["content"],
                timestamp=m["timestamp"],
                sources=m.get("sources", []),
            )
            for m in data.get("messages", [])
        ]
        session = ConversationSession(
            session_id=data["session_id"],
            created_at=data.get("created_at", 0),
            updated_at=data.get("updated_at", 0),
            metadata=data.get("metadata", {}),
        )
        session.messages = messages
        return session

    def _save(self, session: ConversationSession) -> None:
        """Persist a session to Redis and refresh its TTL."""
        self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))

    # ------------------------------------------------------------------
    # ConversationStore protocol
    # ------------------------------------------------------------------

    def create_session(self, metadata: dict | None = None) -> ConversationSession:
        """Create a new empty session and persist it immediately."""
        now = int(time.time())
        session = ConversationSession(
            # Full 128-bit UUID: an 8-character prefix leaves only 32 bits, and
            # SETEX would silently overwrite a colliding session.
            session_id=uuid.uuid4().hex,
            created_at=now,
            updated_at=now,
            metadata=metadata or {},
        )
        self._save(session)
        return session

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Return a session by ID, or None if it does not exist or has expired."""
        raw = self._redis.get(self._key(session_id))
        if raw is None:
            return None
        try:
            return self._deserialise(raw)
        except Exception:
            logger.warning("Failed to deserialise session: {}", session_id)
            return None

    def save_message(
        self,
        session_id: str,
        *,
        role: str,
        content: str,
        sources: list[dict] | None = None,
    ) -> ConversationSession | None:
        """Append a message to a session and refresh its TTL.

        This is a read-modify-write without locking: two concurrent writes to
        the same session can lose one of the messages.
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        session.messages.append(
            ConversationMessage(
                role=role,
                content=content,
                timestamp=int(time.time()),
                sources=sources or [],
            )
        )
        session.updated_at = int(time.time())
        self._save(session)
        return session

    def delete_session(self, session_id: str) -> bool:
        """Delete a session. Returns True if it existed, False otherwise."""
        deleted = self._redis.delete(self._key(session_id))
        return bool(deleted)

    def list_sessions(self) -> list[dict]:
        """Return summary dicts for all live sessions visible in this Redis DB.

        Uses SCAN (via ``scan_iter``) instead of KEYS so that iterating a large
        keyspace does not block the Redis server.
        """
        result = []
        for key in self._redis.scan_iter(match=f"{self._PREFIX}*"):
            raw = self._redis.get(key)
            if raw is None:
                # The key expired between SCAN and GET.
                continue
            try:
                data = json.loads(raw)
                result.append(
                    {
                        "session_id": data["session_id"],
                        "message_count": len(data.get("messages", [])),
                        "created_at": data.get("created_at", 0),
                        "updated_at": data.get("updated_at", 0),
                    }
                )
            except Exception:
                continue
        return result
```
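The store relies on one Redis behaviour: `SETEX` overwrites the value and restarts the TTL, so every `_save` extends a session's lifetime (a sliding expiry). The sketch below illustrates that behaviour without a Redis server, using a hypothetical `FakeTTLClient` that implements only `setex`, `get`, and `delete` against an injectable clock. It is a teaching aid under those assumptions, not a substitute for `fakeredis` in the tests above.

```python
from __future__ import annotations

import json
from typing import Callable


class FakeTTLClient:
    """Hypothetical Redis stand-in supporting only setex/get/delete with expiry."""

    def __init__(self, clock: Callable[[], float]) -> None:
        self._clock = clock
        # key -> (value, absolute expiry time in seconds)
        self._data: dict[str, tuple[str, float]] = {}

    def setex(self, key: str, ttl: int, value: str) -> None:
        # Like Redis SETEX: store the value and restart the TTL from "now".
        self._data[key] = (value, self._clock() + ttl)

    def get(self, key: str) -> str | None:
        entry = self._data.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if self._clock() >= expires_at:
            # Drop the expired key on access.
            del self._data[key]
            return None
        return value

    def delete(self, key: str) -> int:
        # Like Redis DEL: return the number of keys removed.
        return 1 if self._data.pop(key, None) is not None else 0


TTL = 1800  # the store's default timeout_seconds
now = [0.0]  # simulated clock, in seconds
client = FakeTTLClient(clock=lambda: now[0])
key = "session:demo"

client.setex(key, TTL, json.dumps({"session_id": "demo", "messages": []}))

now[0] = 1700.0  # 1700 s after creation: still alive
alive_before_write = client.get(key) is not None

# A write (as in save_message) re-serialises the session and restarts the TTL.
client.setex(key, TTL, json.dumps({"session_id": "demo", "messages": [{"role": "user"}]}))

now[0] = 3000.0  # past the original expiry (1800 s), but only 1300 s after the last write
alive_after_refresh = client.get(key) is not None

now[0] = 3500.0  # exactly TTL seconds after the last write
expired = client.get(key) is None
```

Without the second `setex`, the session would have been gone at 1800 s; the write moved its expiry to 3500 s.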
- [ ] **Step 5: Run tests**
```bash
PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v
```
Expected: All 9 tests PASS.
### Task 7: Add `session_backend` setting and switch bootstrap
**Files:**
- Modify: `backend/app/config/settings.py`
- Modify: `backend/app/shared/bootstrap.py`
- [ ] **Step 1: Add `session_backend` setting**
In `backend/app/config/settings.py`, find the session block (around line 125) and add:
```python
session_backend: str = Field(
    default="memory",
    description="Session storage backend (memory | redis). redis requires a reachable Redis server.",
)
```
- [ ] **Step 2: Update `get_conversation_store` in bootstrap**
Replace the existing `get_conversation_store` function in `backend/app/shared/bootstrap.py`:
```python
@lru_cache
def get_conversation_store() -> ConversationStore:
    """Return the active conversation store based on settings.

    When session_backend='redis', sessions survive backend restarts.
    When session_backend='memory' (default), sessions are process-local.
    """
    if settings.session_backend == "redis":
        import redis as redis_lib

        from app.infrastructure.session.redis_conversation_store import RedisConversationStore

        # Build the Redis client from the same connection settings as Celery.
        kwargs: dict = {
            "host": settings.redis_host,
            "port": settings.redis_port,
            "db": settings.redis_db,
            "decode_responses": False,
        }
        if settings.redis_password:
            kwargs["password"] = settings.redis_password
        redis_client = redis_lib.Redis(**kwargs)
        return RedisConversationStore(
            redis_client=redis_client,
            timeout_seconds=settings.session_timeout_minutes * 60,
        )
    return InMemoryConversationStore(
        max_sessions=settings.session_max_sessions,
        timeout_minutes=settings.session_timeout_minutes,
    )
```
Also add the `ConversationStore` import (used as the return type hint) to the import block at the top of `bootstrap.py`:
```python
from app.domain.conversation import ConversationStore
```
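Because `get_conversation_store` is wrapped in `@lru_cache`, its backend branch runs once per process; later changes to `settings.session_backend` are ignored until the cache is cleared, which is why the verification step calls `cache_clear()`. A minimal sketch of that behaviour, with a hypothetical `_Settings` object standing in for `app.config.settings.settings`:

```python
from functools import lru_cache


class _Settings:
    """Hypothetical stand-in for app.config.settings.settings."""

    session_backend = "memory"


settings = _Settings()


@lru_cache
def get_store_kind() -> str:
    """Mimics get_conversation_store(): the branch runs once, then the result is cached."""
    return "redis" if settings.session_backend == "redis" else "memory"


first = get_store_kind()
settings.session_backend = "redis"
stale = get_store_kind()  # cached result from the first call
get_store_kind.cache_clear()
fresh = get_store_kind()  # re-evaluated against the new setting
```

The same applies in tests that switch `SESSION_BACKEND`: call `get_conversation_store.cache_clear()` after changing the setting.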
- [ ] **Step 3: Add to `.env.example`**
Append to `.env.example`:
```ini
# ── Session backend ───────────────────────────────────────────────────────────
# Set SESSION_BACKEND=redis for persistent sessions (requires Redis).
# Use SESSION_BACKEND=memory for development without Redis.
SESSION_BACKEND=redis
```
- [ ] **Step 4: Verify app starts with the new setting**
```bash
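PYTHONPATH=backend python -c "
from app.config.settings import settings
from app.shared.bootstrap import get_conversation_store
get_conversation_store.cache_clear()
store = get_conversation_store()
print('session_backend:', settings.session_backend)
print('store:', type(store).__name__)
print('OK')
"
```
Expected: prints `session_backend: memory` and `store: InMemoryConversationStore` (or `redis` and `RedisConversationStore` if `SESSION_BACKEND=redis` is set in `.env`), then `OK`. The Redis client connects lazily, so this check does not prove that Redis is reachable.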
> ✅ **Group C complete.** Set `SESSION_BACKEND=redis` in `.env` to persist sessions across restarts.
---
## Group D — JWT Authentication + RBAC + Audit Logging
> Estimated time: 3-5 days.
>
> **Why:** `backend/app/api/main.py:49` has `allow_origins=["*"]`. No route requires any credential. PPT Slide 12 defines a four-role permission matrix; Slide 9 lists "RBAC" as a key challenge response.
>
> **Approach for MVP:** JWT tokens carry a `role` claim (admin/legal/ehs/readonly). All API routes require a valid token. Fine-grained per-route RBAC enforcement is a follow-up; the role field is the enforcement hook. An audit middleware logs every authenticated API call. Users are stored in PostgreSQL.
### Task 8: Add auth dependencies
**Files:**
- Modify: `requirements.txt`
- Modify: `pyproject.toml`
- [ ] **Step 1: Add `python-jose` and `passlib` to `requirements.txt`**
In `requirements.txt`, add after the existing tool libs block:
```
# Authentication
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
# passlib 1.7.4 predates bcrypt 5.x, whose 72-byte password limit breaks passlib's bcrypt self-test.
bcrypt<5.0
```
- [ ] **Step 2: Add to `pyproject.toml` dependencies list**
In `pyproject.toml` under `dependencies`, add:
```toml
    "python-jose[cryptography]>=3.3.0",
    "passlib[bcrypt]>=1.7.4",
    "bcrypt<5.0",
```
- [ ] **Step 3: Install**
```bash
uv sync
```
Expected: No errors.
### Task 9: Create auth domain models
**Files:**
- Create: `backend/app/domain/auth/__init__.py`
- Create: `backend/app/domain/auth/models.py`
- [ ] **Step 1: Write failing test**
Create `tests/test_jwt_handler.py`:
```python
"""Tests for JWTHandler: token creation and decoding.

These tests do not require a running server or database.
"""

import time

import pytest

SECRET = "test-secret-key-minimum-32-characters-long"


@pytest.fixture
def handler():
    """Return a JWTHandler configured with a test secret."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30)


def test_create_token_returns_string(handler):
    """create_access_token must return a non-empty string."""
    token = handler.create_access_token(user_id="u1", username="alice", role="admin")
    assert isinstance(token, str)
    assert len(token) > 20


def test_decode_token_returns_correct_claims(handler):
    """decode_token must return UserClaims matching the input."""
    token = handler.create_access_token(user_id="u1", username="alice", role="admin")
    claims = handler.decode_token(token)
    assert claims.user_id == "u1"
    assert claims.username == "alice"
    assert claims.role == "admin"


def test_decode_expired_token_raises(handler):
    """decode_token must raise ValueError on an expired token."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    # Create a token that expires in 0 minutes (immediately expired).
    short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0)
    token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly")
    # Sleep past the one-second resolution of the exp claim.
    time.sleep(1)
    with pytest.raises(ValueError, match="expired"):
        short_handler.decode_token(token)


def test_decode_invalid_token_raises(handler):
    """decode_token must raise ValueError for a tampered token."""
    with pytest.raises(ValueError):
        handler.decode_token("not.a.valid.jwt.token")


def test_decode_wrong_secret_raises():
    """decode_token must raise ValueError when signed with a different secret."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60)
    verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60)
    token = creator.create_access_token(user_id="u3", username="carol", role="legal")
    with pytest.raises(ValueError):
        verifier.decode_token(token)
```
- [ ] **Step 2: Run test to confirm failure**
```bash
PYTHONPATH=backend pytest tests/test_jwt_handler.py -v
```
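Expected: every test fails (or errors during fixture setup) with `ModuleNotFoundError: No module named 'app.infrastructure.auth'`. The handler module is created in Task 10, so these tests keep failing until then.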
- [ ] **Step 3: Create domain auth package**
Create `backend/app/domain/auth/__init__.py`:
```python
"""Auth domain: role definitions and token claim models.
The domain layer defines what a user identity looks like (UserClaims) and
what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL)
live under infrastructure/auth and never leak into this package.
"""
from .models import UserClaims, UserRole
__all__ = ["UserClaims", "UserRole"]
```
- [ ] **Step 4: Create auth domain models**
Create `backend/app/domain/auth/models.py`:
```python
"""Auth domain models: roles and token claims.

UserRole defines the four roles from PPT Slide 12.
UserClaims is what the JWT decodes to; it is the identity object passed
through FastAPI dependency injection to route handlers.
"""

from __future__ import annotations

import enum
from dataclasses import dataclass


class UserRole(str, enum.Enum):
    """Access roles mirroring the four-role RBAC matrix from the product spec.

    ADMIN:    full platform access including system management.
    LEGAL:    knowledge query, document review, compliance checks.
    EHS:      knowledge query, perception/regulatory signals.
    READONLY: knowledge query only.
    """

    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


@dataclass
class UserClaims:
    """Decoded JWT payload representing an authenticated user.

    Instances are created by JWTHandler.decode_token() and injected into
    route handlers via the get_current_user FastAPI dependency.
    """

    # Unique user identifier (UUID string stored in PostgreSQL users table).
    user_id: str
    # Display name used for audit log entries.
    username: str
    # Role determines which resources the user may access.
    role: UserRole
```
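Two properties of the `(str, Enum)` mixin matter downstream: members compare equal to plain strings (so `claims.role == "admin"` holds in the tests), and `UserRole(raw)` raises `ValueError` for unknown values, which `decode_token` turns into a READONLY fallback. One caveat: on Python 3.11 and later, an f-string such as `f"{UserRole.ADMIN}"` renders as `UserRole.ADMIN`, not `admin`, so user-facing messages should use `.value`. A self-contained sketch, where `parse_role` is a hypothetical helper mirroring the fallback in `decode_token`:

```python
import enum


class UserRole(str, enum.Enum):
    # Mirrors backend/app/domain/auth/models.py.
    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


def parse_role(raw: str) -> UserRole:
    """Map a raw role claim to UserRole, falling back to READONLY for unknown values."""
    try:
        return UserRole(raw)
    except ValueError:
        return UserRole.READONLY


legal = parse_role("legal")
unknown = parse_role("superuser")
```

Falling back to the least-privileged role means a token carrying a retired or misspelled role still authenticates but cannot reach restricted routes.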
### Task 10: Create JWT handler infrastructure
**Files:**
- Create: `backend/app/infrastructure/auth/__init__.py`
- Create: `backend/app/infrastructure/auth/jwt_handler.py`
- [ ] **Step 1: Create infrastructure auth package**
Create `backend/app/infrastructure/auth/__init__.py`:
```python
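"""Authentication infrastructure: JWT handling and the PostgreSQL user store.

JWTHandler (jwt_handler.py) signs and validates tokens; PostgresUserStore
(user_store.py, added in Task 11) verifies credentials. Both are wired
through shared/bootstrap.py and injected into FastAPI dependencies.
"""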
```
- [ ] **Step 2: Implement JWTHandler**
Create `backend/app/infrastructure/auth/jwt_handler.py`:
```python
"""JWT access token creation and decoding.

Uses python-jose for token signing (HS256 by default). Token expiry is
enforced at decode time, so expired tokens are rejected even if the
signature is valid.
"""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import Any

from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError
from loguru import logger

from app.domain.auth.models import UserClaims, UserRole


class JWTHandler:
    """Create and validate HS256 JWT access tokens.

    A single shared instance is wired by bootstrap.py. Use
    get_jwt_handler() from shared.bootstrap for all token operations.
    """

    def __init__(
        self,
        *,
        secret_key: str,
        algorithm: str = "HS256",
        expire_minutes: int = 480,
    ) -> None:
        """Initialise the handler with signing credentials and token lifetime."""
        self._secret = secret_key
        self._algorithm = algorithm
        self._expire_minutes = expire_minutes

    def create_access_token(
        self,
        *,
        user_id: str,
        username: str,
        role: str,
    ) -> str:
        """Return a signed JWT containing user identity and role claims."""
        now = datetime.now(UTC)
        payload: dict[str, Any] = {
            "sub": user_id,
            "username": username,
            "role": role,
            # python-jose converts datetime values for iat/exp to epoch seconds.
            "iat": now,
            "exp": now + timedelta(minutes=self._expire_minutes),
        }
        return jwt.encode(payload, self._secret, algorithm=self._algorithm)

    def decode_token(self, token: str) -> UserClaims:
        """Decode and validate a JWT, returning UserClaims.

        Raises ValueError with a descriptive message on expiry, tampering,
        or any other validation failure so callers do not need to know jose.
        """
        try:
            payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
        except ExpiredSignatureError as exc:
            # ExpiredSignatureError subclasses JWTError, so it must be caught first.
            raise ValueError("Token expired") from exc
        except JWTError as exc:
            raise ValueError(f"Invalid token: {exc}") from exc

        user_id = payload.get("sub")
        username = payload.get("username", "")
        role_str = payload.get("role", UserRole.READONLY.value)
        if not user_id:
            raise ValueError("Token missing subject claim")

        try:
            role = UserRole(role_str)
        except ValueError:
            logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
            role = UserRole.READONLY

        return UserClaims(user_id=user_id, username=username, role=role)
```
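To make the token format concrete, here is a standard-library-only sketch of what HS256 signing amounts to (RFC 7515/7519): base64url-encoded header and payload joined by a dot, an HMAC-SHA256 over that string, and an `exp` check at decode time. The helper names are hypothetical and the sketch is meant for reading tokens while debugging, not as a replacement for python-jose, which also validates `iat`/`nbf`, supports leeway, and handles other algorithms.

```python
from __future__ import annotations

import base64
import hashlib
import hmac
import json
import time


def _b64url(data: bytes) -> str:
    # JWS segments are base64url-encoded without "=" padding (RFC 7515).
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    # Restore the padding that _b64url stripped before decoding.
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _sign(signing_input: str, secret: str) -> bytes:
    return hmac.new(secret.encode(), signing_input.encode("ascii"), hashlib.sha256).digest()


def encode_hs256(claims: dict, secret: str) -> str:
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = f"{_b64url(json.dumps(header).encode())}.{_b64url(json.dumps(claims).encode())}"
    return f"{signing_input}.{_b64url(_sign(signing_input, secret))}"


def decode_hs256(token: str, secret: str, *, now: float | None = None) -> dict:
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("Invalid token: expected 3 segments")
    header_b64, payload_b64, sig_b64 = parts
    header = json.loads(_b64url_decode(header_b64))
    if header.get("alg") != "HS256":
        # Refuse any other algorithm, including "none".
        raise ValueError("Invalid token: unexpected algorithm")
    expected = _sign(f"{header_b64}.{payload_b64}", secret)
    # compare_digest runs in constant time, so timing does not leak the signature.
    if not hmac.compare_digest(expected, _b64url_decode(sig_b64)):
        raise ValueError("Invalid token: signature mismatch")
    claims = json.loads(_b64url_decode(payload_b64))
    current = time.time() if now is None else now
    # RFC 7519: the token is valid only while the current time is before "exp".
    if "exp" in claims and current >= claims["exp"]:
        raise ValueError("Token expired")
    return claims


SECRET = "test-secret-key-minimum-32-characters-long"
token = encode_hs256({"sub": "u1", "username": "alice", "role": "admin", "exp": 2_000_000_000}, SECRET)
claims = decode_hs256(token, SECRET, now=1_900_000_000)
```

The payload is only encoded, not encrypted: anyone holding a token can read `username` and `role`, so no secrets belong in the claims.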
- [ ] **Step 3: Run JWT tests**
```bash
PYTHONPATH=backend pytest tests/test_jwt_handler.py -v
```
Expected: All 5 tests PASS.
### Task 11: Create PostgreSQL user store and seed script
**Files:**
- Create: `backend/app/infrastructure/auth/user_store.py`
- Create: `scripts/seed_users.py`
- [ ] **Step 1: Implement PostgreSQL user store**
Create `backend/app/infrastructure/auth/user_store.py`:
```python
"""PostgreSQL-backed user store for authentication.

Manages a `users` table with hashed passwords and roles, and provides
lookup by username for the login flow. The table DDL is applied when
the store is constructed.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import psycopg2
import psycopg2.extras
from passlib.context import CryptContext

from app.config.settings import settings

# bcrypt context. passlib's default bcrypt cost factor (12) is a reasonable
# production baseline; pass bcrypt__default_rounds=<n> to CryptContext to change it.
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

# DDL executed once to ensure the table exists.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS users (
    id          UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    username    VARCHAR(100) UNIQUE NOT NULL,
    hashed_pw   TEXT NOT NULL,
    role        VARCHAR(50) NOT NULL DEFAULT 'readonly',
    is_active   BOOLEAN NOT NULL DEFAULT TRUE,
    created_at  TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""


@dataclass
class UserRecord:
    """A single row from the users table."""

    id: str
    username: str
    hashed_pw: str
    role: str
    is_active: bool


class PostgresUserStore:
    """Read and verify users stored in the PostgreSQL users table.

    The connection is opened when the store is constructed; bootstrap defers
    that to the first get_user_store() call and then shares the single instance.
    A dropped connection is not re-established automatically.
    """

    def __init__(self) -> None:
        """Initialise the store and ensure the users table exists."""
        self._conn = psycopg2.connect(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
            cursor_factory=psycopg2.extras.RealDictCursor,
        )
        self._conn.autocommit = True
        self._ensure_table()

    def _ensure_table(self) -> None:
        """Create the users table if it does not already exist."""
        with self._conn.cursor() as cur:
            cur.execute(_CREATE_TABLE_SQL)

    def get_by_username(self, username: str) -> Optional[UserRecord]:
        """Return a UserRecord for the given username, or None if not found."""
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, hashed_pw, role, is_active "
                "FROM users WHERE username = %s",
                (username,),
            )
            row = cur.fetchone()
        if row is None:
            return None
        return UserRecord(
            id=str(row["id"]),
            username=row["username"],
            hashed_pw=row["hashed_pw"],
            role=row["role"],
            is_active=row["is_active"],
        )

    def verify_password(self, plain: str, hashed: str) -> bool:
        """Return True if `plain` matches the stored bcrypt hash."""
        return _PWD_CTX.verify(plain, hashed)

    def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
        """Return the UserRecord if credentials are valid, else None."""
        user = self.get_by_username(username)
        if user is None or not user.is_active:
            return None
        if not self.verify_password(password, user.hashed_pw):
            return None
        return user

    @staticmethod
    def hash_password(plain: str) -> str:
        """Hash a plain-text password with bcrypt."""
        return _PWD_CTX.hash(plain)
```
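As written, `authenticate` returns immediately when the username is unknown but pays for a full bcrypt check when it exists, so response times can reveal which usernames are registered. A common mitigation is to run one password verification in every case, against a dummy hash when the user is missing. The sketch below shows the pattern with the standard library's PBKDF2 (`hashlib.pbkdf2_hmac`) standing in for bcrypt so it runs without passlib; the helper names are hypothetical and the iteration count is illustrative only. passlib's `CryptContext` also documents a `dummy_verify()` helper for this case, worth checking against the installed version.

```python
from __future__ import annotations

import hashlib
import hmac
import os
from dataclasses import dataclass

_ITERATIONS = 100_000  # illustrative; bcrypt's cost factor plays this role in the real store


def hash_pw(plain: str, salt: bytes | None = None) -> str:
    """Stand-in for bcrypt: PBKDF2-HMAC-SHA256 stored as 'salt_hex$digest_hex'."""
    salt = salt or os.urandom(16)
    digest = hashlib.pbkdf2_hmac("sha256", plain.encode(), salt, _ITERATIONS)
    return f"{salt.hex()}${digest.hex()}"


def verify_pw(plain: str, stored: str) -> bool:
    salt_hex, digest_hex = stored.split("$")
    candidate = hashlib.pbkdf2_hmac("sha256", plain.encode(), bytes.fromhex(salt_hex), _ITERATIONS)
    return hmac.compare_digest(candidate.hex(), digest_hex)


# Computed once; verified against whenever the username does not exist.
_DUMMY_HASH = hash_pw("dummy-password-that-never-authenticates")


@dataclass
class User:
    username: str
    hashed_pw: str
    is_active: bool = True


def authenticate(users: dict[str, User], username: str, password: str) -> User | None:
    user = users.get(username)
    # Always perform exactly one hash verification so unknown usernames cost the same time.
    stored = user.hashed_pw if user is not None else _DUMMY_HASH
    password_ok = verify_pw(password, stored)
    if user is None or not user.is_active or not password_ok:
        return None
    return user


users = {"alice": User("alice", hash_pw("correct"))}
```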
- [ ] **Step 2: Create the seed script**
Create `scripts/seed_users.py`:
```python
#!/usr/bin/env python
"""Seed demo users into the PostgreSQL users table.

Run from the repo root:

    PYTHONPATH=backend python scripts/seed_users.py

Creates four demo accounts (one per role). Safe to run multiple times:
existing usernames are left unchanged (ON CONFLICT DO NOTHING).
"""

import os
import sys

# Allow running from repo root without installing the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))

import psycopg2  # noqa: E402
from passlib.context import CryptContext  # noqa: E402

from app.config.settings import settings  # noqa: E402

_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

DEMO_USERS = [
    {"username": "admin", "password": "Admin@2026!", "role": "admin"},
    {"username": "legal", "password": "Legal@2026!", "role": "legal"},
    {"username": "ehs", "password": "EHS@2026!", "role": "ehs"},
    {"username": "readonly", "password": "Read@2026!", "role": "readonly"},
]


def seed() -> None:
    """Insert demo users. Skips any username that already exists."""
    conn = psycopg2.connect(
        host=settings.postgres_host,
        port=settings.postgres_port,
        user=settings.postgres_user,
        password=settings.postgres_password,
        dbname=settings.postgres_db,
    )
    conn.autocommit = True

    # Ensure the table exists (same DDL as user_store.py).
    with conn.cursor() as cur:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id          UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                username    VARCHAR(100) UNIQUE NOT NULL,
                hashed_pw   TEXT NOT NULL,
                role        VARCHAR(50) NOT NULL DEFAULT 'readonly',
                is_active   BOOLEAN NOT NULL DEFAULT TRUE,
                created_at  TIMESTAMPTZ NOT NULL DEFAULT NOW()
            );
        """)

    inserted = 0
    with conn.cursor() as cur:
        for user in DEMO_USERS:
            hashed = _PWD_CTX.hash(user["password"])
            cur.execute(
                """
                INSERT INTO users (username, hashed_pw, role)
                VALUES (%s, %s, %s)
                ON CONFLICT (username) DO NOTHING
                """,
                (user["username"], hashed, user["role"]),
            )
            if cur.rowcount > 0:
                print(f" Created user: {user['username']} (role={user['role']})")
                inserted += 1
            else:
                print(f" Skipped (already exists): {user['username']}")

    print(f"\nDone. {inserted} user(s) created.")
    conn.close()


if __name__ == "__main__":
    seed()
```
- [ ] **Step 3: Seed demo users (requires PostgreSQL running)**
```bash
PYTHONPATH=backend python scripts/seed_users.py
```
Expected:
```
 Created user: admin (role=admin)
 Created user: legal (role=legal)
 Created user: ehs (role=ehs)
 Created user: readonly (role=readonly)

Done. 4 user(s) created.
```
### Task 12: Create FastAPI auth dependencies and the token endpoint
**Files:**
- Create: `backend/app/api/dependencies/__init__.py`
- Create: `backend/app/api/dependencies/auth.py`
- Create: `backend/app/api/routes/auth.py`
- Modify: `backend/app/config/settings.py`
- Modify: `backend/app/shared/bootstrap.py`
- [ ] **Step 1: Add auth settings**
In `backend/app/config/settings.py`, add to the `Settings` class:
```python
# ── Auth ─────────────────────────────────────────────────────────────
# Change auth_secret_key to a long random string in production.
# Generate one with: python -c "import secrets; print(secrets.token_hex(32))"
auth_secret_key: str = Field(
    default="change-me-in-production-must-be-32-or-more-characters-long",
    description="JWT signing secret. MUST be changed in production.",
)
auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.")
auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default: 8 hours).")
auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).")
```
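Since the default `auth_secret_key` is committed in source, a deployment that forgets to override it would sign tokens anyone can forge. One option is a fail-fast check at startup. The sketch below is a hypothetical `validate_auth_settings` helper (not part of the plan's files), assuming the 32-character minimum implied by the default value and that `AUTH_SECRET_KEY` is the `.env` name for `auth_secret_key`:

```python
import secrets

# Mirrors the default in settings.py.
DEFAULT_AUTH_SECRET = "change-me-in-production-must-be-32-or-more-characters-long"


def validate_auth_settings(*, auth_enabled: bool, auth_secret_key: str, min_length: int = 32) -> None:
    """Hypothetical startup check: fail fast if auth is on with a weak JWT secret."""
    if not auth_enabled:
        # Auth bypass (development only): tokens are not checked, so skip the secret check.
        return
    if auth_secret_key == DEFAULT_AUTH_SECRET:
        raise RuntimeError("AUTH_SECRET_KEY still has its default value; set a random secret in .env")
    if len(auth_secret_key) < min_length:
        raise RuntimeError(f"AUTH_SECRET_KEY must be at least {min_length} characters long")


# secrets.token_hex(32) returns 64 hex characters (32 random bytes).
generated = secrets.token_hex(32)
validate_auth_settings(auth_enabled=True, auth_secret_key=generated)  # passes
validate_auth_settings(auth_enabled=False, auth_secret_key=DEFAULT_AUTH_SECRET)  # bypassed
```

If a check like this runs at application startup, tests that import `app.api.main` (such as `tests/test_auth_routes.py`) need a non-default `AUTH_SECRET_KEY` in their environment.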
- [ ] **Step 2: Wire auth components in bootstrap**
Add to `backend/app/shared/bootstrap.py`:
```python
@lru_cache
def get_jwt_handler():
    """Return the shared JWTHandler instance."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    return JWTHandler(
        secret_key=settings.auth_secret_key,
        algorithm=settings.auth_algorithm,
        expire_minutes=settings.auth_token_expire_minutes,
    )


@lru_cache
def get_user_store():
    """Return the PostgreSQL user store (lazy: connects on the first call)."""
    from app.infrastructure.auth.user_store import PostgresUserStore

    return PostgresUserStore()
```
- [ ] **Step 3: Create the dependencies package**
Create `backend/app/api/dependencies/__init__.py`:
```python
"""FastAPI dependency functions for authentication and authorisation.
Import `get_current_user` or `require_role` into route modules to protect
endpoints. Both use the shared JWTHandler wired through bootstrap.
"""
```
- [ ] **Step 4: Implement auth dependencies**
Create `backend/app/api/dependencies/auth.py`:
```python
"""FastAPI dependencies for JWT authentication.

Usage in a route:

    from app.api.dependencies.auth import get_current_user, require_role
    from app.domain.auth.models import UserClaims, UserRole

    @router.get("/protected")
    async def protected(user: UserClaims = Depends(get_current_user)):
        return {"user": user.username}

    @router.delete("/admin-only")
    async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))):
        ...
"""

from __future__ import annotations

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.config.settings import settings
from app.domain.auth.models import UserClaims, UserRole
from app.shared.bootstrap import get_jwt_handler

# Use the Bearer token scheme: the client sends `Authorization: Bearer <token>`.
# auto_error=False lets get_current_user produce its own 401 when the header is missing.
_bearer = HTTPBearer(auto_error=False)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
) -> UserClaims:
    """Extract and validate the JWT from the Authorization header.

    Returns the decoded UserClaims on success.
    Raises HTTP 401 when the token is missing, expired, or invalid.
    When auth_enabled=False (development), returns a synthetic admin user.
    """
    if not settings.auth_enabled:
        # Development bypass: never disable auth in production.
        return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN)

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        return get_jwt_handler().decode_token(credentials.credentials)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc


def require_role(*roles: UserRole):
    """Return a dependency that enforces one of the given roles.

    Example:
        Depends(require_role(UserRole.ADMIN, UserRole.LEGAL))
    """

    async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims:
        """Verify the user holds one of the required roles."""
        if user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                # Use .value: on Python 3.11+, an f-string of a (str, Enum) member
                # renders as "UserRole.ADMIN" rather than "admin".
                detail=f"Role '{user.role.value}' is not permitted. Required: {[r.value for r in roles]}",
            )
        return user

    return _check
```
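`require_role` takes an explicit list of roles per route. When per-route enforcement is added later, keeping the role-to-permission mapping in one table avoids scattering role lists across route modules. The sketch below derives a hypothetical permission matrix from the `UserRole` docstring (the permission names are invented for illustration; the authoritative matrix is PPT Slide 12) and a `roles_with` helper whose result could be passed as `require_role(*roles_with("compliance:check"))`:

```python
from __future__ import annotations

import enum


class UserRole(str, enum.Enum):
    # Mirrors backend/app/domain/auth/models.py.
    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


# Hypothetical permission names derived from the UserRole docstring.
ROLE_PERMISSIONS: dict[UserRole, frozenset[str]] = {
    UserRole.ADMIN: frozenset(
        {"knowledge:query", "document:review", "compliance:check", "perception:read", "system:manage"}
    ),
    UserRole.LEGAL: frozenset({"knowledge:query", "document:review", "compliance:check"}),
    UserRole.EHS: frozenset({"knowledge:query", "perception:read"}),
    UserRole.READONLY: frozenset({"knowledge:query"}),
}


def has_permission(role: UserRole, permission: str) -> bool:
    return permission in ROLE_PERMISSIONS.get(role, frozenset())


def roles_with(permission: str) -> tuple[UserRole, ...]:
    """Roles holding a permission, in UserRole definition order."""
    return tuple(role for role in UserRole if has_permission(role, permission))
```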
- [ ] **Step 5: Create the auth router**
Create `backend/app/api/routes/auth.py`:
```python
"""Authentication routes: token issuance and identity lookup.

POST /auth/token  Exchange username + password for a JWT.
GET  /auth/me     Return the current user's identity (requires a token).
"""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel

from app.api.dependencies.auth import get_current_user
from app.config.settings import settings
from app.domain.auth.models import UserClaims
from app.shared.bootstrap import get_jwt_handler, get_user_store

# OpenAPI tag "认证" means "Authentication".
router = APIRouter(prefix="/auth", tags=["认证"])


class TokenResponse(BaseModel):
    """JWT token response body."""

    access_token: str
    token_type: str = "bearer"
    expires_in: int


@router.post("/token", response_model=TokenResponse)
def login(form: OAuth2PasswordRequestForm = Depends()) -> TokenResponse:
    """Issue a JWT for valid username + password credentials.

    Uses standard OAuth2 password grant form fields so this endpoint
    is compatible with Swagger UI's Authorize button. Declared as a plain
    ``def`` so FastAPI runs the blocking bcrypt check and psycopg2 query in
    its threadpool instead of on the event loop.
    """
    user_store = get_user_store()
    user = user_store.authenticate(form.username, form.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = get_jwt_handler().create_access_token(
        user_id=user.id,
        username=user.username,
        role=user.role,
    )
    return TokenResponse(
        access_token=token,
        token_type="bearer",
        expires_in=settings.auth_token_expire_minutes * 60,
    )


@router.get("/me")
async def get_me(current_user: UserClaims = Depends(get_current_user)):
    """Return the identity of the currently authenticated user."""
    return {
        "user_id": current_user.user_id,
        "username": current_user.username,
        "role": current_user.role.value,
    }
```
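Clients other than Swagger UI must send the credentials as an OAuth2 password-grant form (`application/x-www-form-urlencoded`), not JSON, and then pass the token as a Bearer header. A standard-library sketch of both requests; the host and port in `BASE_URL` are assumptions, and the `/api/v1` prefix is taken from the route tests:

```python
import urllib.parse
import urllib.request

# Hypothetical base URL: host/port depend on the deployment; /api/v1 matches the tests.
BASE_URL = "http://localhost:8000/api/v1"


def build_token_request(username: str, password: str) -> urllib.request.Request:
    """POST /auth/token with OAuth2 password-grant form fields (not JSON)."""
    body = urllib.parse.urlencode({"username": username, "password": password}).encode()
    return urllib.request.Request(
        f"{BASE_URL}/auth/token",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def build_me_request(access_token: str) -> urllib.request.Request:
    """GET /auth/me with the token in the Authorization header."""
    return urllib.request.Request(
        f"{BASE_URL}/auth/me",
        headers={"Authorization": f"Bearer {access_token}"},
    )


token_req = build_token_request("alice", "p@ss word")
me_req = build_me_request("example-token")

# Sending requires a running backend, for example:
#   with urllib.request.urlopen(token_req) as resp:
#       access_token = json.load(resp)["access_token"]
```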
- [ ] **Step 6: Register the auth router**
In `backend/app/api/routes/__init__.py`, add:
```python
from .auth import router as auth_router
```
And in the `api_router.include_router(...)` block, add:
```python
api_router.include_router(auth_router)
```
Also add `auth_router` to `__all__`.
- [ ] **Step 7: Write auth route tests**
Create `tests/test_auth_routes.py`:
```python
"""Integration tests for the auth routes.

Uses FastAPI TestClient. Does not require a running PostgreSQL: the
user_store dependency is patched.
"""

from unittest.mock import MagicMock, patch

import pytest
from fastapi.testclient import TestClient

from app.api.main import app
from app.infrastructure.auth.user_store import UserRecord

client = TestClient(app, raise_server_exceptions=False)


@pytest.fixture(autouse=True)
def mock_user_store():
    """Replace the real PostgreSQL user store with a mock."""
    mock_store = MagicMock()
    alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True)
    mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None
    with patch("app.api.routes.auth.get_user_store", return_value=mock_store):
        yield mock_store


def test_login_returns_token_for_valid_credentials():
    """POST /auth/token must return an access_token for valid credentials."""
    resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
    assert resp.status_code == 200
    body = resp.json()
    assert "access_token" in body
    assert body["token_type"] == "bearer"


def test_login_returns_401_for_wrong_password():
    """POST /auth/token must return 401 for wrong password."""
    resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"})
    assert resp.status_code == 401


def test_me_returns_user_when_authenticated():
    """GET /auth/me must return user identity when a valid token is provided."""
    # Get a real token first.
    login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
    token = login_resp.json()["access_token"]

    me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
    assert me_resp.status_code == 200
    assert me_resp.json()["username"] == "alice"
    assert me_resp.json()["role"] == "admin"


def test_me_returns_401_without_token():
    """GET /auth/me must return 401 when no token is provided."""
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
```
- [ ] **Step 8: Run auth tests**
```bash
PYTHONPATH=backend pytest tests/test_auth_routes.py -v
```
Expected: All 4 tests PASS.
### Task 13: Add audit middleware and apply auth to routes
**Files:**
- Create: `backend/app/api/middleware/__init__.py`
- Create: `backend/app/api/middleware/audit.py`
- Modify: `backend/app/api/main.py`
- Modify: `backend/app/api/routes/documents.py`
- Modify: `backend/app/api/routes/rag.py`
- Modify: `backend/app/api/routes/compliance.py`
- [ ] **Step 1: Create the audit middleware**
Create `backend/app/api/middleware/__init__.py`:
```python
"""HTTP middleware for cross-cutting concerns: audit logging."""
```
Create `backend/app/api/middleware/audit.py`:
```python
"""Audit logging middleware.

Logs every API request with method, path, status code, response time,
and the authenticated user identity (extracted from the JWT when present).
Log lines are structured so they can be ingested by ELK / Loki.
"""

from __future__ import annotations

import time

from fastapi import Request, Response
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware


class AuditMiddleware(BaseHTTPMiddleware):
    """Log all API calls. Skips the root path, /health, and the API docs to reduce noise."""

    # The root path is matched exactly: every path starts with "/", so treating
    # it as a prefix would silently skip auditing for every request.
    _SKIP_EXACT = frozenset({"/"})
    # Path prefixes that produce no audit log entry.
    _SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Intercept the request, call the handler, and log the outcome."""
        path = request.url.path
        if path in self._SKIP_EXACT or path.startswith(self._SKIP_PREFIXES):
            return await call_next(request)

        start = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = int((time.perf_counter() - start) * 1000)

        # Decode the Bearer token only to label the log line; rejecting
        # invalid tokens is left to the auth dependencies.
        user_id = "anonymous"
        username = "anonymous"
        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            try:
                from app.shared.bootstrap import get_jwt_handler

                claims = get_jwt_handler().decode_token(auth_header[7:])
                user_id = claims.user_id
                username = claims.username
            except Exception:
                pass

        logger.info(
            "AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}",
            request.method,
            path,
            response.status_code,
            elapsed_ms,
            user_id,
            username,
        )
        return response
```
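The path filter and header parsing are easy to get subtly wrong, so it can help to pin them down in a plain-Python test. The following is a minimal sketch with hypothetical `should_audit` and `bearer_token` helpers, which are not part of the plan's files. It also shows why the root path needs an exact match. `str.startswith` with a tuple returns `True` if any entry matches, so a bare `"/"` prefix would match every path.
```python
"""Stand-alone model of the audit middleware's path filter and token parsing."""
from __future__ import annotations

# Hypothetical mirrors of the middleware's skip rules.
SKIP_EXACT = frozenset({"/"})
SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json")


def should_audit(path: str) -> bool:
    """Return True when a request path should produce an audit log line."""
    # str.startswith accepts a tuple and returns True if ANY entry matches.
    return not (path in SKIP_EXACT or path.startswith(SKIP_PREFIXES))


def bearer_token(authorization: str) -> str | None:
    """Extract <token> from an 'Authorization: Bearer <token>' header value."""
    scheme, _, token = authorization.partition(" ")
    # Authentication scheme names are case-insensitive (RFC 7235), so accept "bearer".
    if scheme.lower() != "bearer" or not token.strip():
        return None
    return token.strip()


# Every path starts with "/", so "/" used as a prefix would skip all requests:
print("/api/v1/rag/chat".startswith(("/health", "/")))  # True
print(should_audit("/"), should_audit("/docs"), should_audit("/api/v1/rag/chat"))  # False False True
```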
- [ ] **Step 2: Register the middleware and tighten CORS in `main.py`**
In `backend/app/api/main.py`, replace the `CORSMiddleware` block and add the audit middleware:
```python
from app.api.middleware.audit import AuditMiddleware

# Tighten CORS: only allow the configured frontend origin(s).
# In production, set CORS_ALLOW_ORIGINS in .env to the actual frontend URL.
_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()]
if not _ORIGINS:
    _ORIGINS = ["http://localhost:5173"]  # Vite dev default

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Starlette wraps middleware in reverse registration order, so adding this after
# CORSMiddleware places it outside CORS; it also logs CORS preflight (OPTIONS) requests.
app.add_middleware(AuditMiddleware)
```
Also add a `cors_allow_origins` field to the settings class in `settings.py`:
```python
cors_allow_origins: str = Field(
    default="http://localhost:5173",
    description="Comma-separated allowed CORS origins. Use * only in development.",
)
```
And add to `.env.example`:
```ini
# ── CORS ─────────────────────────────────────────────────────────────────────
# Comma-separated frontend origin(s). Never use * in production.
CORS_ALLOW_ORIGINS=http://localhost:5173
```
- [ ] **Step 3: Apply `Depends(get_current_user)` to key routes**
In `backend/app/api/routes/documents.py`, add to imports:
```python
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
```
Add `current_user: UserClaims = Depends(get_current_user)` as a parameter to `upload_document`, `list_documents`, `get_document_management_list`, and `delete_document`:
```python
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    doc_id: str | None = Form(None),
    doc_name: str | None = Form(None),
    regulation_type: str | None = Form(None),
    version: str | None = Form(None),
    generate_summary: bool = Form(False),
    sync: bool = Form(False),
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
    ...  # existing body unchanged
```
Apply the same pattern to:
- `list_documents`, `get_document_management_list`, and `delete_document` in `documents.py`
- `rag_chat` in `rag.py`
- `analyze_stream` in `compliance.py`
In `backend/app/api/routes/rag.py`, add:
```python
from fastapi import APIRouter, Depends
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
```
```python
@router.post("/chat")
async def rag_chat(
    request: RagChatRequest,
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
    ...  # existing body unchanged
```
In `backend/app/api/routes/compliance.py`, add:
```python
from fastapi import APIRouter, Depends, File, Form, UploadFile
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
```
```python
@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
    ...  # existing body unchanged
```
- [ ] **Step 4: Verify app starts and auth is enforced**
```bash
PYTHONPATH=backend python -c "from app.api.main import app; print('OK')"
```
Expected: `OK`
With the API server running on `localhost:8000`, check that an unauthenticated request is rejected:
```bash
curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/chat \
  -X POST -H "Content-Type: application/json" -d '{"query":"test"}'
```
Expected: `401`
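Then check that the same request is accepted when a token is supplied. `/rag/quick-questions` is not in the protected list, so it would return `200` even without a token. Single quotes keep the shell from interpreting the `!` in the password.
```bash
TOKEN=$(curl -s -X POST http://localhost:8000/api/v1/auth/token \
  -d 'username=admin&password=Admin@2026!' \
  | python -c "import sys, json; print(json.load(sys.stdin)['access_token'])")
curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/chat \
  -X POST -H "Content-Type: application/json" -d '{"query":"test"}' \
  -H "Authorization: Bearer $TOKEN"
```
Expected: `200`. Any status other than `401` confirms that the token was accepted. Another error code points at the RAG pipeline rather than at authentication.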
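The issued token can be inspected locally to see which claims it carries, such as the role claim reserved for future RBAC. The following is a minimal, stdlib-only sketch; the `peek_jwt_claims` name is hypothetical. It does **not** verify the signature, so it is suitable for debugging only and must never be used for access decisions.
```python
"""Inspect the claims inside a JWT WITHOUT verifying it (local debugging only)."""
from __future__ import annotations

import base64
import json


def peek_jwt_claims(token: str) -> dict:
    """Return the unverified claims from the payload segment of a compact JWT."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("expected a compact JWT: header.payload.signature")
    payload = parts[1]
    # JWT segments are base64url-encoded with padding stripped (RFC 7515); restore it.
    payload += "=" * (-len(payload) % 4)
    # The signature is NOT checked, so never trust the result for authorization.
    return json.loads(base64.urlsafe_b64decode(payload))
```
For example, paste the value of `$TOKEN` into a Python REPL and call `peek_jwt_claims(token)` to check the issued claims against what `/auth/me` returns.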
- [ ] **Step 5: Run the full test suite**
```bash
PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py tests/test_celery_tasks.py tests/test_redis_conversation_store.py tests/test_jwt_handler.py tests/test_auth_routes.py -v
```
Expected: All tests PASS.
---
## Self-Review Checklist
**Spec coverage:**
- ✅ §3.1 Reranker quick win → Task 1
- ✅ §3.2 Async task-ification (Celery + Redis) → Tasks 2-5
- ✅ §3.3 Session persistence → Tasks 6-7
- ✅ §3.4 Auth/RBAC/Audit → Tasks 8-13
- ✅ AGENTS.md: `api → application → domain ports → infrastructure` respected throughout
- ✅ No new logic in `services/*` or `workflows/*`
- ✅ `shared/bootstrap.py` is the composition root for all new singletons
- ✅ All comments/docstrings in new backend files are in English
- ✅ Desktop-first frontend constraint: no frontend changes in this plan
**Placeholder scan:** None found — all code blocks are complete and runnable.
**Type consistency:**
- `UserClaims` defined in Task 9, used in Tasks 12, 13 ✅
- `UserRole` enum defined in Task 9, used in Task 12 ✅
- `process_document_task` name matches registration in Task 3 and lookup in Task 3 test ✅
- `store_document` defined in Task 3, called in Task 4 ✅
- `_process_document` defined in Task 3, called from Task 3 and Task 4 ✅
- `get_jwt_handler`, `get_user_store`, `get_celery_app` added to bootstrap in Tasks 4, 12 ✅
- `RedisConversationStore` constructor signature `(redis_client, timeout_seconds)` consistent across implementation, tests, and bootstrap ✅