# Phase 1: Production Foundation Implementation Plan

> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.

**Goal:** Transform the existing synchronous RAG system into a production-ready backend with async task processing, persistent sessions, and JWT authentication, turning "can run" into "can operate."

**Architecture:** Four independent workstreams executed in priority order: (A) enable the already-wired reranker, (B) introduce Celery workers backed by the already-configured Redis to make document processing non-blocking, (C) replace the in-memory conversation store with a Redis-backed implementation, and (D) add JWT authentication with role claims and audit logging across all API routes.

**Tech Stack:** FastAPI · Celery 5 · Redis 7 · PostgreSQL 15 · `python-jose[cryptography]` · `passlib[bcrypt]` · `fakeredis` (tests only). Everything runs on the existing `docker-compose.yml` infrastructure.

---

## Scope Note

These four workstreams are **largely independent** and can be worked in parallel by separate developers. Auth (Group D) should be integrated last because it modifies every route. For a single developer, the recommended sequence is A → B → C → D.

**Not in scope (YAGNI):** Fine-grained per-endpoint RBAC enforcement (the role claim is added to the JWT as a future enforcement point), user self-registration, perception crawling, knowledge graph, and the EHS module.

---

## File Map

### Group A — Reranker Quick Win

| Action | File |
|--------|------|
| Modify | `.env.example`: document reranker env vars |
| Create | `tests/test_reranker_bootstrap.py` |

### Group B — Async Task-ification

| Action | File |
|--------|------|
| Create | `backend/app/infrastructure/tasks/__init__.py` |
| Create | `backend/app/infrastructure/tasks/celery_app.py` |
| Create | `backend/app/infrastructure/tasks/document_tasks.py` |
| Modify | `backend/app/application/documents/services.py`: extract `_process_document`, add `store_document` |
| Modify | `backend/app/api/routes/documents.py`: enqueue by default, process inline when `sync=true` |
| Modify | `backend/app/shared/bootstrap.py`: expose `get_celery_app()` |
| Modify | `dev.sh`: add a worker start target |
| Create | `tests/test_celery_tasks.py` |

### Group C — Session Persistence

| Action | File |
|--------|------|
| Create | `backend/app/infrastructure/session/redis_conversation_store.py` |
| Modify | `backend/app/config/settings.py`: add `session_backend` field |
| Modify | `backend/app/shared/bootstrap.py`: select the conversation store based on `session_backend` |
| Modify | `pyproject.toml`: add `fakeredis` dev dependency |
| Create | `tests/test_redis_conversation_store.py` |

### Group D — JWT Auth + RBAC + Audit

| Action | File |
|--------|------|
| Modify | `requirements.txt`: add `python-jose`, `passlib[bcrypt]` |
| Modify | `pyproject.toml`: add the same dependencies |
| Create | `backend/app/domain/auth/__init__.py` |
| Create | `backend/app/domain/auth/models.py` |
| Create | `backend/app/infrastructure/auth/__init__.py` |
| Create | `backend/app/infrastructure/auth/jwt_handler.py` |
| Create | `backend/app/infrastructure/auth/user_store.py` |
| Create | `backend/app/api/dependencies/__init__.py` |
| Create | `backend/app/api/dependencies/auth.py` |
| Create | `backend/app/api/middleware/__init__.py` |
| Create | `backend/app/api/middleware/audit.py` |
| Create | `backend/app/api/routes/auth.py` |
| Modify | `backend/app/api/routes/__init__.py`: include the auth router |
| Modify | `backend/app/config/settings.py`: add auth fields |
| Modify | `backend/app/shared/bootstrap.py`: wire auth components |
| Modify | `backend/app/api/main.py`: restrict CORS, add audit middleware |
| Modify | `backend/app/api/routes/documents.py`: require auth |
| Modify | `backend/app/api/routes/rag.py`: require auth |
| Modify | `backend/app/api/routes/compliance.py`: require auth |
| Create | `scripts/seed_users.py` |
| Create | `tests/test_jwt_handler.py` |
| Create | `tests/test_auth_routes.py` |

---

## Group A — Reranker Quick Win

> Estimated time: 30 minutes. No new code, just configuration and a verification test.
>
> **Why:** `settings.reranker_enabled` defaults to `False` in `backend/app/config/settings.py:113`. The `OpenAICompatibleReranker` infrastructure in `backend/app/infrastructure/vectorstore/cross_encoder_reranker.py` and its wiring in `backend/app/shared/bootstrap.py:199-203` are already complete. Enabling the reranker only requires setting two env vars and verifying that the wiring works.

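Both bootstrap tests in Task 1 rely on two properties of `get_reranker()`: it reads `settings` when called, and it caches its return value with `functools.lru_cache`. The sketch below is a hypothetical, self-contained stand-in for that shape (the `Settings` and `Reranker` classes and the base-URL check are assumptions, not the project's code). It shows why each test must call `cache_clear()` whenever the settings change.

```python
from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache


@dataclass
class Settings:
    """Hypothetical stand-in for the project's settings object."""

    reranker_enabled: bool = False
    reranker_base_url: str = ""
    reranker_model: str = "BAAI/bge-reranker-v2-m3"
    reranker_top_k: int = 5


@dataclass(frozen=True)
class Reranker:
    """Hypothetical stand-in for OpenAICompatibleReranker."""

    base_url: str
    model: str
    top_k: int


settings = Settings()


@lru_cache
def get_reranker() -> Reranker | None:
    """Build the reranker once; return None when the feature is switched off."""
    # Settings are read only on the first (uncached) call.
    if not settings.reranker_enabled or not settings.reranker_base_url:
        return None
    return Reranker(
        base_url=settings.reranker_base_url,
        model=settings.reranker_model,
        top_k=settings.reranker_top_k,
    )
```

Calling `get_reranker()` once with the defaults caches `None`; enabling the reranker afterwards has no effect until `get_reranker.cache_clear()` runs, which is why the real tests clear the cache around every case.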
### Task 1: Document reranker env vars and write a bootstrap verification test

**Files:**
- Modify: `.env.example`
- Create: `tests/test_reranker_bootstrap.py`

- [ ] **Step 1: Add reranker vars to `.env.example`**

Open `.env.example`, locate the existing embedding/LLM block, and append:

```ini
# ── Reranker (Cross-Encoder) ──────────────────────────────────────────────────
# Set RERANKER_ENABLED=true and point to a TEI or Cohere-compatible rerank API.
# Recommended models: BAAI/bge-reranker-v2-m3 (lighter, ~0.6B params) or
# BAAI/bge-reranker-v2.5-gemma2-lightweight (heavier, Gemma-2 based, higher quality).
# The endpoint must expose POST /rerank (TEI style) or POST /v1/rerank (Cohere style).
RERANKER_ENABLED=true
RERANKER_BASE_URL=http://6.86.80.4:30080/v1
RERANKER_MODEL=BAAI/bge-reranker-v2-m3
RERANKER_API_KEY=
RERANKER_TOP_K=5
```

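The comment block names two wire formats. As a hedged illustration based on the publicly documented Hugging Face TEI `/rerank` and Cohere `/v1/rerank` APIs (how `OpenAICompatibleReranker` actually builds its requests is not shown in this plan, and both helper names below are hypothetical), the request bodies and response shapes differ roughly like this:

```python
from __future__ import annotations

from typing import Any


def build_rerank_body(style: str, model: str, query: str, docs: list[str], top_k: int) -> dict[str, Any]:
    """Return the JSON body for one rerank call in the given API style."""
    if style == "tei":
        # TEI: POST /rerank. Every text is scored; truncate to top_k client-side.
        return {"query": query, "texts": docs}
    if style == "cohere":
        # Cohere-style: POST /v1/rerank. The server can truncate via top_n.
        return {"model": model, "query": query, "documents": docs, "top_n": top_k}
    raise ValueError(f"unknown rerank style: {style!r}")


def parse_rerank_response(style: str, payload: Any, top_k: int) -> list[tuple[int, float]]:
    """Normalize either response shape to [(doc_index, score), ...], best first."""
    if style == "tei":
        # TEI returns a bare list: [{"index": 0, "score": 0.98}, ...]
        pairs = [(item["index"], float(item["score"])) for item in payload]
    elif style == "cohere":
        # Cohere returns {"results": [{"index": 0, "relevance_score": 0.98}, ...]}
        pairs = [(item["index"], float(item["relevance_score"])) for item in payload["results"]]
    else:
        raise ValueError(f"unknown rerank style: {style!r}")
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs[:top_k]
```

Normalizing to `(index, score)` pairs lets the caller map scores back onto the original retrieval hits regardless of which server sits behind `RERANKER_BASE_URL`.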
- [ ] **Step 2: Write the bootstrap verification test**

Create `tests/test_reranker_bootstrap.py`:

```python
"""Verify that bootstrap correctly wires the reranker when the setting is enabled."""

from unittest.mock import patch

import pytest

from app.shared import bootstrap


@pytest.fixture(autouse=True)
def _clear_reranker_cache():
    # get_reranker() is wrapped in lru_cache; clear it around every test so a
    # result cached under one set of settings never leaks into another test.
    bootstrap.get_reranker.cache_clear()
    yield
    bootstrap.get_reranker.cache_clear()


def test_get_reranker_returns_none_when_disabled():
    """get_reranker() must return None when reranker_enabled is False."""
    # Patch the name bootstrap actually reads. This assumes bootstrap does
    # `from app.config.settings import settings`; patching
    # "app.config.settings.settings" would not affect that already-bound name.
    with patch.object(bootstrap, "settings") as mock_settings:
        mock_settings.reranker_enabled = False
        mock_settings.reranker_base_url = ""
        result = bootstrap.get_reranker()

    assert result is None


def test_get_reranker_returns_instance_when_enabled():
    """get_reranker() must return an OpenAICompatibleReranker when enabled."""
    from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker

    with patch.object(bootstrap, "settings") as mock_settings:
        mock_settings.reranker_enabled = True
        mock_settings.reranker_base_url = "http://localhost:8082"
        mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3"
        mock_settings.reranker_api_key = ""
        mock_settings.reranker_top_k = 5
        result = bootstrap.get_reranker()

    assert isinstance(result, OpenAICompatibleReranker)
```

- [ ] **Step 3: Run the test (expected to pass immediately)**

```bash
PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py -v
```

Expected: both tests PASS without any code change, because the bootstrap logic already exists. A pass confirms that the wiring is correct; proceed to Step 4.

- [ ] **Step 4: Enable the reranker in your local `.env`**

Add these two lines to your local `.env` (not `.env.example`, which Step 1 already updated):

```ini
RERANKER_ENABLED=true
RERANKER_BASE_URL=http://6.86.80.4:30080/v1
```

- [ ] **Step 5: Verify `GET /api/v1/status/health` reports the reranker as enabled**

Start the backend and call:
```bash
curl http://localhost:8000/api/v1/status/health | python -m json.tool
```

Expected fragment in the response:
```json
"reranker": {
  "enabled": true,
  "model": "BAAI/bge-reranker-v2-m3"
}
```

> ✅ **Group A complete.** The reranker is now active for all `KnowledgeRetrievalService.retrieve()` calls.

---

## Group B — Async Task-ification

> Estimated time: 3-5 days.
>
> **Why:** `upload_document` at `backend/app/api/routes/documents.py:34` runs the full parse → embed → index pipeline synchronously inside the HTTP request. For large GB standards this can take more than 15 minutes (the Aliyun timeout is 900 s per `settings.py:49`). Celery 5 and Redis are already in `requirements.txt` and `docker-compose.yml`; only the wiring is missing.
>
> **Strategy:** Split `DocumentCommandService.upload_and_process` into a fast `store_document` (synchronous, returns immediately) and a background `_process_document` (called from a Celery task). The existing `DocumentProcessingStore` already records per-stage status events, so it becomes the progress tracker. The HTTP route returns `{doc_id, status: "stored"}` immediately, and the client polls the existing `GET /documents/status/{doc_id}` endpoint for progress.

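On the client side, this strategy becomes a poll-until-terminal loop. A minimal sketch, assuming a hypothetical `fetch_status` callable that wraps `GET /documents/status/{doc_id}` and returns the status as a lowercase string, and assuming `"indexed"` and `"failed"` are the terminal values (check the real `DocumentStatus` enum):

```python
from __future__ import annotations

import time
from typing import Callable

# Assumed terminal status values; adjust to the project's DocumentStatus enum.
TERMINAL_STATUSES = frozenset({"indexed", "failed"})


def wait_for_processing(
    fetch_status: Callable[[str], str],
    doc_id: str,
    *,
    initial_delay_s: float = 2.0,
    max_delay_s: float = 30.0,
    timeout_s: float = 1800.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Poll until the document reaches a terminal status, doubling the delay each round."""
    deadline = clock() + timeout_s
    delay = initial_delay_s
    while True:
        status = fetch_status(doc_id)
        if status in TERMINAL_STATUSES:
            return status
        if clock() >= deadline:
            raise TimeoutError(f"doc {doc_id} still {status!r} after {timeout_s}s")
        sleep(delay)
        delay = min(delay * 2, max_delay_s)
```

The 30-minute default timeout is an arbitrary choice sized to the 15+ minute worst case above; `sleep` and `clock` are injectable so the loop can be tested without real waiting.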
### Task 2: Create the Celery application

**Files:**
- Create: `backend/app/infrastructure/tasks/__init__.py`
- Create: `backend/app/infrastructure/tasks/celery_app.py`
- Create: `tests/test_celery_tasks.py` (Celery app tests only; Task 3 appends one more)

- [ ] **Step 1: Write failing tests for the Celery app configuration**

Create `tests/test_celery_tasks.py`:

```python
"""Tests for Celery task infrastructure.

These tests verify task registration and Celery app configuration without
starting a real worker or connecting to Redis.
"""


def test_celery_app_uses_redis_broker():
    """Celery broker URL must be a Redis URL built from settings."""
    from app.infrastructure.tasks.celery_app import celery_app
    assert celery_app.conf.broker_url.startswith("redis://")


def test_celery_app_uses_redis_backend():
    """Celery result backend must be Redis."""
    from app.infrastructure.tasks.celery_app import celery_app
    assert celery_app.conf.result_backend.startswith("redis://")


def test_celery_app_has_json_serializer():
    """Task serializer must be JSON for portability."""
    from app.infrastructure.tasks.celery_app import celery_app
    assert celery_app.conf.task_serializer == "json"
```

- [ ] **Step 2: Run a test to confirm it fails**

```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_celery_app_uses_redis_broker -v
```

Expected: `ModuleNotFoundError: No module named 'app.infrastructure.tasks'`

- [ ] **Step 3: Create the tasks package**

Create `backend/app/infrastructure/tasks/__init__.py`:

```python
"""Celery task definitions for background processing.

This package holds the shared Celery application instance and all
registered task functions used by API routes to enqueue work.
"""
```

- [ ] **Step 4: Create the Celery app module**

Create `backend/app/infrastructure/tasks/celery_app.py`:

```python
"""Shared Celery application instance for background task processing.

All workers and enqueueing call sites import `celery_app` from this module
so the broker/backend configuration stays in one place.
"""

from __future__ import annotations

from urllib.parse import quote

from celery import Celery

from app.config.settings import settings


def _redis_url() -> str:
    """Return a Redis connection URL built from application settings.

    The password segment is omitted when redis_password is empty so that local
    development against a password-free Redis instance works out of the box.
    The password is percent-encoded so characters such as '@', ':' or '/'
    do not break URL parsing.
    """
    host_part = f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
    if settings.redis_password:
        return f"redis://:{quote(settings.redis_password, safe='')}@{host_part}"
    return f"redis://{host_part}"


# Broker and result backend share the same Redis instance and database.
_REDIS_URL = _redis_url()

celery_app = Celery(
    "compliance_hub",
    broker=_REDIS_URL,
    backend=_REDIS_URL,
    include=["app.infrastructure.tasks.document_tasks"],
)

celery_app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="UTC",
    enable_utc=True,
    # Acknowledge messages only after the task finishes, and requeue them if the
    # worker process dies mid-task, so a crash does not silently drop a document.
    # (Retry count and delay are configured per task in document_tasks.py.)
    task_acks_late=True,
    task_reject_on_worker_lost=True,
    # Keep Celery task results for 1 hour. Document progress itself is tracked
    # in DocumentProcessingStore, not in the result backend.
    result_expires=3600,
)
```

- [ ] **Step 5: Run the Celery app tests**

```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py -v
```

Expected: `test_celery_app_uses_redis_broker`, `test_celery_app_uses_redis_backend`, and `test_celery_app_has_json_serializer` all PASS.

### Task 3: Split `DocumentCommandService` and create the document processing task

**Files:**
- Modify: `backend/app/application/documents/services.py`
- Create: `backend/app/infrastructure/tasks/document_tasks.py`

- [ ] **Step 1: Add a task registration test to `tests/test_celery_tasks.py`**

Append to `tests/test_celery_tasks.py`:

```python
def test_process_document_task_is_registered():
    """process_document_task must be discoverable in the Celery task registry."""
    # Importing document_tasks triggers task registration via the @task decorator.
    import app.infrastructure.tasks.document_tasks  # noqa: F401
    from app.infrastructure.tasks.celery_app import celery_app

    registered = list(celery_app.tasks.keys())
    assert any("process_document_task" in name for name in registered), (
        f"process_document_task not found in {registered}"
    )
```

- [ ] **Step 2: Run the test to confirm it fails**

```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_process_document_task_is_registered -v
```

Expected: `ModuleNotFoundError: No module named 'app.infrastructure.tasks.document_tasks'`

- [ ] **Step 3: Extract `_process_document` from `DocumentCommandService`**

In `backend/app/application/documents/services.py`, locate `upload_and_process` (line 233). Add a new method, `_process_document`, that contains the parse → embed → index logic, and refactor `upload_and_process` to call it after the store step.

Find the end of the `_safe_mark_run_stored` call block (around line 299), and move everything from the `suffix = os.path.splitext(...)` line to the end of the try block into the new method:

```python
    def _process_document(
        self,
        *,
        doc_id: str,
        file_name: str,
        final_doc_name: str,
        content: bytes,
        regulation_type: str,
        version: str,
        generate_summary: bool,
        run_id: str | None = None,
    ) -> DocumentProcessResult:
        """Run parse → chunk → embed → index for a document that is already stored.

        This method is called both synchronously (from upload_and_process) and
        asynchronously (from the Celery document_tasks worker). All side effects
        write through DocumentProcessingStore so callers can poll progress.

        Note: generate_summary is currently unused here; the index stage writes
        an empty summary and a zero summary latency.
        """
        current_status = DocumentStatus.STORED
        current_stage = "parse"
        temp_path = ""
        try:
            # The parser works on a file path, so spill the bytes to a temp file.
            suffix = os.path.splitext(file_name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(content)
                temp_path = temp_file.name

            parsed_document = self.parser.parse(
                file_path=temp_path,
                doc_id=doc_id,
                doc_name=final_doc_name,
            )
            self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)

            # Persisting parse artifacts is best effort: failures are logged, not fatal.
            artifact_keys: dict[str, str] = {}
            try:
                artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
            except Exception:
                logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)

            self.document_repository.update_status(
                doc_id,
                DocumentStatus.PARSED,
                parser_name=parsed_document.parser_name,
                metadata={
                    "parser_backend": parsed_document.parser_name,
                    "parse_task_id": parsed_document.metadata.get("task_id", ""),
                    "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
                    "structure_node_count": len(parsed_document.structure_nodes),
                    "semantic_block_count": len(parsed_document.semantic_blocks),
                    "vector_chunk_count": len(parsed_document.vector_chunks),
                    "artifact_keys": artifact_keys,
                    "processing_stage": "parsed",
                },
            )
            current_status = DocumentStatus.PARSED
            current_stage = "embed"
            self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
            self._safe_append_status_event(
                doc_id=doc_id, run_id=run_id,
                from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value,
                stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)},
            )
            if self.parse_artifact_store:
                try:
                    self.parse_artifact_store.save(doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks)
                except Exception:
                    logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)

            chunks = self.chunk_builder.build(
                parsed_document=parsed_document,
                regulation_type=regulation_type,
                version=version,
            )
            if not chunks:
                # "Parsing finished but produced no chunks that can be indexed"
                raise ValueError("解析完成但没有生成可入库的 chunks")

            vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
            current_stage = "index"
            inserted = self.vector_index.upsert(chunks, vectors)
            if inserted != len(chunks):
                logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))

            health = self.vector_index.health()
            index_name = health.get("collection_name", "")
            self.document_repository.update_status(
                doc_id, DocumentStatus.INDEXED,
                chunk_count=len(chunks), summary="", summary_latency_ms=0,
                index_name=index_name,
                metadata={"index_collection": index_name, "processing_stage": "indexed"},
            )
            self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name)
            self._safe_append_status_event(
                doc_id=doc_id, run_id=run_id,
                from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value,
                stage="index", message="Document indexed",
                metadata={"chunk_count": len(chunks), "index_name": index_name},
            )
            stored = self.document_repository.get(doc_id)
            return DocumentProcessResult(
                doc_id=doc_id, doc_name=final_doc_name,
                status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
                message="处理成功", num_chunks=len(chunks),  # "Processed successfully"
                summary=stored.summary if stored else "",
                summary_latency_ms=stored.summary_latency_ms if stored else 0,
            )
        except Exception as exc:
            # Any stage failure is recorded as FAILED with the stage it happened in.
            logger.exception("文档处理失败: doc_id={}", doc_id)  # "Document processing failed"
            self.document_repository.update_status(
                doc_id, DocumentStatus.FAILED, error_message=str(exc),
                metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage},
            )
            self._safe_mark_run_failed(doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc))
            self._safe_append_status_event(
                doc_id=doc_id, run_id=run_id,
                from_status=current_status.value, to_status=DocumentStatus.FAILED.value,
                stage=current_stage, message=str(exc),
            )
            return DocumentProcessResult(
                doc_id=doc_id, doc_name=final_doc_name,
                status=DocumentStatus.FAILED.value, message=f"文档处理失败: {exc}",
            )
        finally:
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    logger.warning("临时文件清理失败: {}", temp_path)  # "Temp file cleanup failed"
```

Then replace the parse-through-index section of `upload_and_process` with a call to `_process_document`.

In `upload_and_process`, replace everything from `suffix = os.path.splitext(file_name)[1]` through the end of the try block (just before the `except`) with:

```python
            # Delegate parse → embed → index to the shared processing method.
            # The Celery worker invokes this same method for async processing.
            return self._process_document(
                doc_id=doc_id,
                file_name=file_name,
                final_doc_name=final_doc_name,
                content=content,
                regulation_type=regulation_type,
                version=version,
                generate_summary=generate_summary,
                run_id=run_id,
            )
```

Also add a new method, `store_document`, for the fast synchronous path:

```python
    def store_document(
        self,
        *,
        doc_id: str | None = None,
        file_name: str,
        content: bytes,
        content_type: str,
        doc_name: str | None,
        regulation_type: str,
        version: str,
        generate_summary: bool,
    ) -> tuple[str, str | None]:
        """Store the binary file and create the Document record.

        Returns (doc_id, run_id). Does NOT parse, embed, or index.
        This is the fast synchronous first step; processing is enqueued separately.
        """
        doc_id = doc_id or str(uuid.uuid4())[:8]
        final_doc_name = doc_name or file_name
        object_name = f"{doc_id}/{file_name}"

        document = Document(
            doc_id=doc_id,
            doc_name=final_doc_name,
            file_name=file_name,
            object_name=object_name,
            content_type=content_type,
            size_bytes=len(content),
            regulation_type=regulation_type,
            version=version,
            metadata={"generate_summary": generate_summary},
        )
        self.document_repository.create(document)
        run_id = self._safe_create_processing_run(
            doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary
        )
        self.binary_store.save(
            object_name=object_name, data=content,
            content_type=content_type, metadata={"doc_id": doc_id},
        )
        self.document_repository.update_status(doc_id, DocumentStatus.STORED)
        self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
        self._safe_append_status_event(
            doc_id=doc_id, run_id=run_id,
            from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value,
            stage="store", message="Source file stored",
        )
        return doc_id, run_id
```

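`store_document` keeps the `str(uuid.uuid4())[:8]` ID scheme. The first eight hex digits of a version-4 UUID are all random, so each generated `doc_id` carries 32 random bits, and the birthday approximation p ≈ 1 − exp(−n(n − 1) / 2³³) estimates the chance that n generated IDs contain a duplicate. A quick calculation, as a sketch (the helper name is hypothetical, and whether `document_repository.create` rejects duplicate IDs is not covered in this plan):

```python
import math


def id_collision_probability(n_ids: int, id_bits: int = 32) -> float:
    """Birthday-bound probability that n random IDs of id_bits bits include a duplicate."""
    space = 2 ** id_bits
    # 1 - exp(-n(n-1)/(2N)); expm1 keeps precision when the exponent is tiny.
    return -math.expm1(-n_ids * (n_ids - 1) / (2 * space))


for n in (1_000, 10_000, 77_163):
    print(f"{n:>6} docs -> {id_collision_probability(n):.4f}")
```

That works out to roughly 1.2% at 10,000 documents and about 50% near 77,000. Whether those odds matter depends on the expected corpus size; a longer slice of the UUID lowers them sharply.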
- [ ] **Step 4: Create the Celery document task**

Create `backend/app/infrastructure/tasks/document_tasks.py`:

```python
"""Celery tasks for document processing.

Each task is a thin wrapper that retrieves the already-stored document
binary and delegates to DocumentCommandService._process_document.
The task does not accept raw file bytes; it reads them from the binary
store using the doc_id, so the message payload stays small.
"""

from __future__ import annotations

from loguru import logger

from app.infrastructure.tasks.celery_app import celery_app


@celery_app.task(
    name="app.infrastructure.tasks.document_tasks.process_document_task",
    bind=True,
    max_retries=3,
    default_retry_delay=30,
    acks_late=True,
)
def process_document_task(
    self,
    doc_id: str,
    file_name: str,
    doc_name: str,
    regulation_type: str,
    version: str,
    generate_summary: bool,
    run_id: str | None = None,
) -> dict:
    """Parse, embed, and index a document that has already been stored.

    The task reads the file binary from MinIO using doc_id so the Celery
    message stays small. It retries up to 3 times with a 30-second delay
    when loading the document record or its binary fails.
    """
    # Import inside the task body to defer loading the full application stack
    # until a task actually runs, and so that each worker process initialises
    # its own bootstrap singletons.
    from app.shared.bootstrap import get_document_command_service, get_document_query_service

    logger.info("process_document_task started: doc_id={}", doc_id)
    try:
        # Look up the document record, then read the stored binary from MinIO.
        svc = get_document_command_service()
        doc = get_document_query_service().get(doc_id)
        if not doc:
            raise ValueError(f"Document record not found: {doc_id}")

        content = svc.binary_store.read(doc.object_name)

        result = svc._process_document(
            doc_id=doc_id,
            file_name=file_name,
            final_doc_name=doc_name,
            content=content,
            regulation_type=regulation_type,
            version=version,
            generate_summary=generate_summary,
            run_id=run_id,
        )
        logger.info(
            "process_document_task completed: doc_id={} status={} chunks={}",
            doc_id, result.status, result.num_chunks,
        )
        return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks}

    except Exception as exc:
        logger.exception("process_document_task failed: doc_id={}", doc_id)
        # _process_document catches its own parse/embed/index errors, marks the
        # document FAILED, and returns normally, so those never reach this block.
        # Only failures before it runs (missing record, binary read) are retried;
        # once retries are exhausted, self.retry re-raises exc and the document
        # keeps its current status (normally STORED).
        raise self.retry(exc=exc)
```

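The task retries on a fixed 30-second delay (`default_retry_delay=30`). If exponential backoff is preferred, Celery's `retry_backoff` task option (used together with `autoretry_for`) computes the delay as `factor * 2 ** retries`, capped by `retry_backoff_max` and optionally randomized with full jitter. A stdlib sketch of that schedule, with a hypothetical `backoff_countdown` helper whose result could be passed as `raise self.retry(exc=exc, countdown=backoff_countdown(self.request.retries))`:

```python
from __future__ import annotations

import random


def backoff_countdown(
    retries: int,
    *,
    factor: int = 30,
    maximum: int = 600,
    full_jitter: bool = False,
    rng: random.Random | None = None,
) -> int:
    """Seconds to wait before retry number `retries` (0-based): factor * 2**retries, capped."""
    countdown = min(maximum, factor * (2 ** retries))
    if full_jitter:
        # Full jitter: pick uniformly in [0, countdown] so retrying workers spread out.
        countdown = (rng or random.Random()).randrange(countdown + 1)
    return max(0, countdown)


print([backoff_countdown(r) for r in range(3)])  # [30, 60, 120]
```

With `max_retries=3` this yields waits of 30, 60, and 120 seconds instead of three fixed 30-second waits.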
- [ ] **Step 5: Run all Celery task tests**

```bash
PYTHONPATH=backend pytest tests/test_celery_tasks.py -v
```

Expected: all 4 tests PASS.

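Because `task_acks_late=True` and `task_reject_on_worker_lost=True` are set, a task interrupted by a worker crash is redelivered; with Celery's Redis transport an unacknowledged message is also redelivered after the visibility timeout (one hour by default). `process_document_task` can therefore run more than once for the same `doc_id`. One hedged mitigation is a status guard checked before calling `_process_document`. The enum below is a hypothetical mirror of `DocumentStatus`, and its string values are assumptions:

```python
from enum import Enum


class DocumentStatus(str, Enum):
    """Hypothetical mirror of the project's DocumentStatus; values are assumed."""

    PENDING = "pending"
    STORED = "stored"
    PARSED = "parsed"
    INDEXED = "indexed"
    FAILED = "failed"


# A redelivered task may (re)start from these states: STORED means it never ran,
# PARSED means a previous attempt died between parsing and indexing.
_RUNNABLE = frozenset({DocumentStatus.STORED, DocumentStatus.PARSED})


def should_process(current: DocumentStatus, *, allow_failed_retry: bool = False) -> bool:
    """Return True if the pipeline should run for a document in the `current` state."""
    if current in _RUNNABLE:
        return True
    if current is DocumentStatus.FAILED:
        # Reprocessing a FAILED document is an explicit opt-in (e.g. a manual retry).
        return allow_failed_retry
    # PENDING: the binary is not stored yet; INDEXED: the work is already done.
    return False
```

Inside the task, assuming the looked-up record exposes `status`, the guard would sit right after the record lookup and return early instead of re-parsing. Restarting from `PARSED` is only safe if `vector_index.upsert` overwrites chunks by ID rather than appending duplicates, which this plan does not confirm.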
### Task 4: Update the upload route to support async enqueueing

**Files:**
- Modify: `backend/app/api/routes/documents.py`
- Modify: `backend/app/shared/bootstrap.py`

- [ ] **Step 1: Add `get_celery_app` to bootstrap**

In `backend/app/shared/bootstrap.py`, add the following near the bottom of the module, just before `preload_runtime_dependencies`:

```python
@lru_cache
def get_celery_app():
    """Return the shared Celery application instance."""
    # Import here to avoid importing Celery at module load time when workers
    # are not configured (e.g., during tests that mock bootstrap).
    from app.infrastructure.tasks.celery_app import celery_app
    return celery_app
```

- [ ] **Step 2: Update the upload route**

Replace the `upload_document` handler in `backend/app/api/routes/documents.py` with:

```python
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(..., description="上传的文档文件"),  # "Uploaded document file"
    doc_id: str | None = Form(None, description="客户端预分配的文档ID,不传则自动生成"),  # "Client-assigned doc ID; auto-generated if omitted"
    doc_name: str | None = Form(None, description="文档名称"),  # "Document name"
    regulation_type: str | None = Form(None, description="法规类型"),  # "Regulation type"
    version: str | None = Form(None, description="文档版本"),  # "Document version"
    generate_summary: bool = Form(False, description="是否生成摘要"),  # "Whether to generate a summary"
    sync: bool = Form(False, description="同步处理(演示/测试用,默认异步)"),  # "Process synchronously (demo/testing; async by default)"
):
    """Upload a document and enqueue it for background processing.

    By default the route stores the binary immediately, enqueues
    parse → embed → index as a Celery task, and returns {doc_id, status: "stored"}
    within seconds. Pass sync=true to process inline (useful for demos when no
    worker is running). Poll GET /documents/status/{doc_id} for processing progress.
    """
    content = await file.read()
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")  # "File name must not be empty"
    if not content:
        raise HTTPException(status_code=400, detail="上传文件为空")  # "Uploaded file is empty"

    try:
        svc = get_document_command_service()

        if sync:
            # Synchronous fallback: full inline processing (demo / no-worker mode).
            result = svc.upload_and_process(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
        else:
            # Async default: store the binary, then enqueue the processing task.
            stored_doc_id, run_id = svc.store_document(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
            from app.infrastructure.tasks.document_tasks import process_document_task
            process_document_task.delay(
                doc_id=stored_doc_id,
                file_name=file.filename,
                doc_name=doc_name or file.filename,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
                run_id=run_id,
            )
            result = DocumentProcessResult(
                doc_id=stored_doc_id,
                doc_name=doc_name or file.filename,
                status="stored",
                # "File stored and being processed in the background. Poll GET /documents/status/{doc_id} for progress."
                message="文件已存储,正在后台处理。请轮询 GET /documents/status/{doc_id} 查看进度。",
            )

        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("文档上传失败")  # "Document upload failed"
        raise HTTPException(status_code=500, detail=str(exc))
```

Also add this import at the top of `documents.py`; the async branch builds a `DocumentProcessResult` directly:

```python
from app.application.documents import DocumentProcessResult
```

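Because the default path now returns before processing finishes, clients have to poll `GET /documents/status/{doc_id}`. The sketch below shows a bounded polling loop in plain Python. `poll_until_terminal` and `fetch_status` are hypothetical names, and the status values are assumptions: this plan only names `"stored"` and `"failed"`, so substitute whatever `DocumentProcessingStore` actually reports as finished.

```python
"""Poll a document's processing status until it finishes (illustrative sketch)."""
from __future__ import annotations

import time
from typing import Callable

# Assumed terminal states. "failed" appears in this plan; "completed" is a
# placeholder for whatever success status DocumentProcessingStore reports.
TERMINAL_STATUSES = frozenset({"completed", "failed"})


def poll_until_terminal(
    fetch_status: Callable[[str], dict],
    doc_id: str,
    *,
    interval_s: float = 2.0,
    timeout_s: float = 300.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Call fetch_status(doc_id) until it returns a terminal status or time runs out."""
    deadline = time.monotonic() + timeout_s
    while True:
        body = fetch_status(doc_id)
        if body.get("status") in TERMINAL_STATUSES:
            return body
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{doc_id} still {body.get('status')!r} after {timeout_s}s")
        sleep(interval_s)


# Usage with a stubbed fetcher that walks through an assumed lifecycle.
_replies = iter([{"status": "stored"}, {"status": "processing"}, {"status": "completed"}])
final = poll_until_terminal(lambda _doc_id: next(_replies), "abc123", sleep=lambda _s: None)
print(final)  # {'status': 'completed'}
```

In a real client, `fetch_status` would wrap an HTTP GET to `/documents/status/{doc_id}` (for example with `httpx` or `requests`) and return the parsed JSON body; injecting `sleep` keeps the loop testable without real delays.
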
- [ ] **Step 3: Verify the app still starts**

```bash
PYTHONPATH=backend python -c "from app.api.main import app; print('OK')"
```

Expected: `OK` with no import errors.

### Task 5: Add worker startup to dev scripts

**Files:**
- Modify: `dev.sh`

- [ ] **Step 1: Add worker target to `dev.sh`**

Open `dev.sh` and find the section that handles `start api`. Add the following two functions after it; they back new `start worker` and `start beat` subcommands (run `dev.sh` from the repository root so `PYTHONPATH=backend` resolves):

```bash
start_worker() {
    # Start the Celery document-processing worker.
    # PYTHONPATH=backend lets Celery import the app package; settings come
    # from the same root .env the API uses.
    export PYTHONPATH=backend
    celery -A app.infrastructure.tasks.celery_app worker \
        --loglevel=info \
        --concurrency=2 \
        --queues=celery \
        "$@"
}

start_beat() {
    # Start the Celery Beat scheduler for periodic tasks (future: perception crawl).
    export PYTHONPATH=backend
    celery -A app.infrastructure.tasks.celery_app beat \
        --loglevel=info \
        "$@"
}
```

Then add both subcommands to the `case` statement that dispatches `start` targets. `"${@:3}"` forwards any extra arguments after `./dev.sh start worker` (for example `--pool=solo`) straight to Celery:

```bash
    "worker")
        start_worker "${@:3}"
        ;;
    "beat")
        start_beat "${@:3}"
        ;;
```

- [ ] **Step 2: Verify the worker starts (requires Redis running)**

```bash
# Terminal 1: start the worker (blocks until Ctrl-C)
./dev.sh start worker

# Terminal 2: ask every running worker to reply
PYTHONPATH=backend celery -A app.infrastructure.tasks.celery_app inspect ping
```

Expected: the worker log ends with `celery@<hostname> ready.`, and `inspect ping` prints `pong` for that node followed by `1 node online.`

> ✅ **Group B complete.** Document uploads now return in <1 s. Workers process in the background. `GET /documents/status/{doc_id}` reports live progress via `DocumentProcessingStore`.

---

## Group C — Session Persistence

> Estimated time: 2-3 days.
>
> **Why:** `InMemoryConversationStore` at `backend/app/infrastructure/session/in_memory_conversation_store.py` stores all chat sessions in a plain Python dict. Every backend restart loses all active conversations. Redis 7 is already running; persisting sessions there requires only a new adapter and a settings switch.

### Task 6: Add `fakeredis` dev dependency and create the Redis store

**Files:**
- Modify: `pyproject.toml`
- Create: `backend/app/infrastructure/session/redis_conversation_store.py`
- Create: `tests/test_redis_conversation_store.py`

- [ ] **Step 1: Add `fakeredis` to dev dependencies**

In `pyproject.toml`, find the `[dependency-groups]` section and update:

```toml
[dependency-groups]
dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "isort>=8.0.1", "fakeredis>=2.0.0"]
```

Install it:

```bash
uv sync
```

- [ ] **Step 2: Write failing tests for RedisConversationStore**

Create `tests/test_redis_conversation_store.py`:

```python
"""Tests for RedisConversationStore.

Uses fakeredis so no real Redis connection is required.
All tests follow the same ConversationStore contract as InMemoryConversationStore.
"""

import fakeredis
import pytest


@pytest.fixture
def redis_client():
    """Return an in-process fake Redis client."""
    return fakeredis.FakeRedis()


@pytest.fixture
def store(redis_client):
    """Return a RedisConversationStore backed by fake Redis."""
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore

    return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800)


def test_create_session_returns_session_with_id(store):
    """create_session() must return a ConversationSession with a non-empty session_id."""
    session = store.create_session()
    assert session.session_id
    assert len(session.session_id) > 0


def test_get_session_returns_same_session(store):
    """get_session() must return the previously created session."""
    session = store.create_session()
    fetched = store.get_session(session.session_id)
    assert fetched is not None
    assert fetched.session_id == session.session_id


def test_get_session_returns_none_for_unknown_id(store):
    """get_session() must return None when the session_id does not exist."""
    result = store.get_session("nonexistent-id")
    assert result is None


def test_save_message_appends_to_session(store):
    """save_message() must append a message and return the updated session."""
    session = store.create_session()
    updated = store.save_message(session.session_id, role="user", content="Hello")
    assert updated is not None
    assert len(updated.messages) == 1
    assert updated.messages[0].role == "user"
    assert updated.messages[0].content == "Hello"


def test_save_message_persists_across_lookups(store):
    """Messages saved to a session must be visible in subsequent get_session calls."""
    session = store.create_session()
    store.save_message(session.session_id, role="user", content="test")
    fetched = store.get_session(session.session_id)
    assert fetched is not None
    assert len(fetched.messages) == 1


def test_delete_session_removes_it(store):
    """delete_session() must return True and remove the session."""
    session = store.create_session()
    result = store.delete_session(session.session_id)
    assert result is True
    assert store.get_session(session.session_id) is None


def test_delete_session_returns_false_for_unknown(store):
    """delete_session() must return False when the session does not exist."""
    result = store.delete_session("ghost-id")
    assert result is False


def test_list_sessions_includes_created_session(store):
    """list_sessions() must include all active sessions."""
    session = store.create_session()
    sessions = store.list_sessions()
    ids = [s["session_id"] for s in sessions]
    assert session.session_id in ids


def test_session_expires_after_ttl(redis_client):
    """Sessions must disappear once their TTL has elapsed."""
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore

    store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1)
    session = store.create_session()
    # Force expiry instead of sleeping: EXPIRE with a timeout of 0 deletes the
    # key immediately, which is exactly what a lapsed TTL looks like to the store.
    redis_client.expire(f"session:{session.session_id}", 0)
    assert store.get_session(session.session_id) is None
```

- [ ] **Step 3: Run tests to confirm they fail**

```bash
PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v
```

Expected: `ModuleNotFoundError: No module named 'app.infrastructure.session.redis_conversation_store'`

- [ ] **Step 4: Implement `RedisConversationStore`**

Create `backend/app/infrastructure/session/redis_conversation_store.py`:

```python
"""Redis-backed conversation store for persistent chat sessions.

Sessions are stored as JSON strings under the key `session:{session_id}`.
The Redis TTL is refreshed on every write so active sessions stay alive.
On expiry, `get_session` returns None; callers should create a new session.
"""

from __future__ import annotations

import json
import time
import uuid
from typing import Any

from loguru import logger

from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore


class RedisConversationStore(ConversationStore):
    """Store conversation sessions in Redis with automatic TTL expiry.

    Each session is serialised as a JSON object at key ``session:{session_id}``.
    The TTL is reset on every write (not on reads), so a session stays alive
    as long as messages keep being saved to it.
    """

    # Prefix for all session keys to avoid collisions with other Redis consumers.
    _PREFIX = "session:"

    def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
        """Initialise the store with an existing Redis client and a TTL in seconds."""
        self._redis = redis_client
        self._ttl = timeout_seconds

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _key(self, session_id: str) -> str:
        """Build the Redis key for a session."""
        return f"{self._PREFIX}{session_id}"

    def _serialise(self, session: ConversationSession) -> str:
        """Serialise a ConversationSession to a JSON string."""
        return json.dumps(
            {
                "session_id": session.session_id,
                "created_at": session.created_at,
                "updated_at": session.updated_at,
                "metadata": session.metadata,
                "messages": [
                    {
                        "role": msg.role,
                        "content": msg.content,
                        "timestamp": msg.timestamp,
                        "sources": msg.sources,
                    }
                    for msg in session.messages
                ],
            },
            # Keep non-ASCII (e.g. Chinese) text readable when inspecting keys.
            ensure_ascii=False,
        )

    def _deserialise(self, raw: bytes | str) -> ConversationSession:
        """Deserialise a JSON string back into a ConversationSession."""
        data = json.loads(raw)
        messages = [
            ConversationMessage(
                role=m["role"],
                content=m["content"],
                timestamp=m["timestamp"],
                sources=m.get("sources", []),
            )
            for m in data.get("messages", [])
        ]
        session = ConversationSession(
            session_id=data["session_id"],
            created_at=data.get("created_at", 0),
            updated_at=data.get("updated_at", 0),
            metadata=data.get("metadata", {}),
        )
        session.messages = messages
        return session

    def _save(self, session: ConversationSession) -> None:
        """Persist a session to Redis and refresh its TTL."""
        # SETEX writes the value and (re)sets the expiry in a single command.
        self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))

    # ------------------------------------------------------------------
    # ConversationStore protocol
    # ------------------------------------------------------------------

    def create_session(self, metadata: dict | None = None) -> ConversationSession:
        """Create a new empty session and persist it immediately."""
        now = int(time.time())
        session = ConversationSession(
            # Full UUID: sessions now outlive the process, and an 8-character
            # prefix (32 bits) collides far sooner; SETEX would then silently
            # overwrite the other user's session.
            session_id=str(uuid.uuid4()),
            created_at=now,
            updated_at=now,
            metadata=metadata or {},
        )
        self._save(session)
        return session

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Return a session by ID, or None if it does not exist or has expired."""
        raw = self._redis.get(self._key(session_id))
        if raw is None:
            return None
        try:
            return self._deserialise(raw)
        except Exception:
            logger.warning("Failed to deserialise session: {}", session_id)
            return None

    def save_message(
        self,
        session_id: str,
        *,
        role: str,
        content: str,
        sources: list[dict] | None = None,
    ) -> ConversationSession | None:
        """Append a message to a session and refresh its TTL.

        This is a read-modify-write and is not atomic: two concurrent writers
        to the same session can lose a message. That is acceptable for one
        user per session; use WATCH/MULTI if sessions become shared.
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        session.messages.append(
            ConversationMessage(
                role=role,
                content=content,
                timestamp=int(time.time()),
                sources=sources or [],
            )
        )
        session.updated_at = int(time.time())
        self._save(session)
        return session

    def delete_session(self, session_id: str) -> bool:
        """Delete a session. Returns True if it existed, False otherwise."""
        deleted = self._redis.delete(self._key(session_id))
        return bool(deleted)

    def list_sessions(self) -> list[dict]:
        """Return summary dicts for all live sessions visible in this Redis DB.

        Note: KEYS is used here for simplicity; it blocks Redis while it walks
        the keyspace, so in a large production deployment replace it with SCAN.
        """
        pattern = f"{self._PREFIX}*"
        keys = self._redis.keys(pattern)
        result = []
        for key in keys:
            raw = self._redis.get(key)
            if raw is None:
                # The key expired between KEYS and GET.
                continue
            try:
                data = json.loads(raw)
                result.append(
                    {
                        "session_id": data["session_id"],
                        "message_count": len(data.get("messages", [])),
                        "created_at": data.get("created_at", 0),
                        "updated_at": data.get("updated_at", 0),
                    }
                )
            except Exception:
                # Skip corrupt entries rather than failing the whole listing.
                continue
        return result
```

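`_save` uses `SETEX`, so the expiry slides forward on every write, while `get_session` only issues `GET` and leaves the TTL untouched. A session that is merely read (for example, a user re-opening an old chat without sending anything) therefore still expires `timeout_seconds` after its last saved message. The toy model below illustrates that behaviour with a manual clock instead of real Redis; `FakeClockKV` is a hypothetical stand-in written for this sketch, not part of the codebase or of fakeredis.

```python
"""Toy model of the store's TTL policy: writes refresh expiry, reads do not."""
from __future__ import annotations


class FakeClockKV:
    """Minimal key-value store with SETEX/GET-style semantics and a manual clock."""

    def __init__(self) -> None:
        self.now = 0.0
        self._data: dict[str, tuple[str, float]] = {}

    def setex(self, key: str, ttl: int, value: str) -> None:
        # Overwrite the value and reset the expiry, as SETEX does.
        self._data[key] = (value, self.now + ttl)

    def get(self, key: str) -> str | None:
        entry = self._data.get(key)
        if entry is None or self.now >= entry[1]:
            self._data.pop(key, None)  # lazily evict the expired key
            return None
        return entry[0]  # reading does NOT extend the expiry


kv = FakeClockKV()
kv.setex("session:abc", 10, "v1")  # created at t=0, expires at t=10
kv.now = 8
kv.setex("session:abc", 10, "v2")  # message saved at t=8, now expires at t=18
kv.now = 15
print(kv.get("session:abc"))  # v2  (kept alive by the write at t=8)
kv.now = 18
print(kv.get("session:abc"))  # None (only reads since t=8, so it expired)
```

If reads should also keep a session alive, one option is to call Redis `EXPIRE` on the key inside `get_session`; the trade-off is an extra round trip per read.
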
- [ ] **Step 5: Run tests**

```bash
PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v
```

Expected: All 9 tests PASS.

### Task 7: Add `session_backend` setting and switch bootstrap

**Files:**
- Modify: `backend/app/config/settings.py`
- Modify: `backend/app/shared/bootstrap.py`

- [ ] **Step 1: Add `session_backend` setting**

In `backend/app/config/settings.py`, find the session block (around line 125) and add:

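```python
    # Requires `from typing import Literal`. Literal makes pydantic reject a
    # typo such as "Redis" at startup instead of silently falling back to memory.
    session_backend: Literal["memory", "redis"] = Field(
        default="memory",
        description="Conversation store backend (memory | redis). redis requires a reachable Redis server.",
    )
```
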
- [ ] **Step 2: Update `get_conversation_store` in bootstrap**

Replace the existing `get_conversation_store` function in `backend/app/shared/bootstrap.py`:

```python
@lru_cache
def get_conversation_store() -> ConversationStore:
    """Return the active conversation store based on settings.

    When session_backend='redis', sessions survive backend restarts.
    When session_backend='memory' (default), sessions are process-local.
    """
    if settings.session_backend == "redis":
        # Imported lazily so the memory backend does not need the redis package.
        import redis as redis_lib

        from app.infrastructure.session.redis_conversation_store import RedisConversationStore

        # Build the Redis client from the same connection settings as Celery.
        # decode_responses=False: the store json.loads() the raw bytes itself.
        kwargs: dict = {
            "host": settings.redis_host,
            "port": settings.redis_port,
            "db": settings.redis_db,
            "decode_responses": False,
        }
        if settings.redis_password:
            kwargs["password"] = settings.redis_password

        # redis.Redis connects lazily; connection errors surface on the first command.
        redis_client = redis_lib.Redis(**kwargs)
        return RedisConversationStore(
            redis_client=redis_client,
            timeout_seconds=settings.session_timeout_minutes * 60,
        )
    return InMemoryConversationStore(
        max_sessions=settings.session_max_sessions,
        timeout_minutes=settings.session_timeout_minutes,
    )
```

Also import `ConversationStore` at the top of `bootstrap.py` so the new return annotation resolves (`InMemoryConversationStore` is already imported for the existing function):

```python
from app.domain.conversation import ConversationStore
```

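The bootstrap comment says the client reuses Celery's Redis settings. Celery is configured with a broker URL rather than keyword arguments, so if that URL (set up elsewhere in this plan) is assembled from the same `redis_host`, `redis_port`, `redis_db` and `redis_password` fields, the equivalence looks roughly like the sketch below. `build_redis_url` is a hypothetical helper; the point it shows is that the password must be percent-encoded, otherwise characters such as `@`, `:` and `/` break URL parsing.

```python
"""Map the bootstrap's Redis(**kwargs) settings to an equivalent redis:// URL."""
from __future__ import annotations

from urllib.parse import quote


def build_redis_url(host: str, port: int, db: int, password: str | None = None) -> str:
    """Return redis://[:password@]host:port/db with the password percent-encoded."""
    # An empty username before ':' means password-only AUTH (no ACL user name).
    auth = f":{quote(password, safe='')}@" if password else ""
    return f"redis://{auth}{host}:{port}/{db}"


print(build_redis_url("localhost", 6379, 0))           # redis://localhost:6379/0
print(build_redis_url("redis", 6379, 1, "p@ss:w/rd"))  # redis://:p%40ss%3Aw%2Frd@redis:6379/1
```

redis-py accepts the same string via `redis.Redis.from_url(url)`, which is an alternative to passing the keyword arguments individually.
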
- [ ] **Step 3: Add to `.env.example`**

Append to `.env.example`:

```ini
# ── Session backend ───────────────────────────────────────────────────────────
# Set SESSION_BACKEND=redis for persistent sessions (requires Redis).
# Use SESSION_BACKEND=memory for development without Redis.
SESSION_BACKEND=redis
```

- [ ] **Step 4: Verify app starts with the new setting**

```bash
PYTHONPATH=backend python -c "
from app.config.settings import settings
from app.shared.bootstrap import get_conversation_store

get_conversation_store.cache_clear()
store = get_conversation_store()
print('session_backend:', settings.session_backend)
print('store:', type(store).__name__)
print('sessions:', len(store.list_sessions()))
print('OK')
"
```

Expected: `session_backend: memory` and `store: InMemoryConversationStore` (or `redis` and `RedisConversationStore` if `SESSION_BACKEND=redis` is set in `.env`), a session count, then `OK`. Because the Redis client connects lazily, the `list_sessions()` call is what proves Redis is reachable; with the redis backend and Redis down, it raises a connection error.

> ✅ **Group C complete.** Set `SESSION_BACKEND=redis` in `.env` to persist sessions across restarts.

---

## Group D — JWT Authentication + RBAC + Audit Logging

> Estimated time: 3-5 days.
>
> **Why:** `backend/app/api/main.py:49` has `allow_origins=["*"]`. No route requires any credential. PPT Slide 12 defines a four-role permission matrix; Slide 9 lists "RBAC" as a key challenge response.
>
> **Approach for MVP:** JWT tokens carry a `role` claim (admin/legal/ehs/readonly). All API routes require a valid token. Fine-grained per-route RBAC enforcement is a follow-up; the role field is the enforcement hook. An audit middleware logs every authenticated API call. Users are stored in PostgreSQL.

### Task 8: Add auth dependencies

**Files:**
- Modify: `requirements.txt`
- Modify: `pyproject.toml`

- [ ] **Step 1: Add `python-jose` and `passlib` to `requirements.txt`**

In `requirements.txt`, add after the existing tool libs block:

```
# Authentication
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
```

- [ ] **Step 2: Add to `pyproject.toml` dependencies list**

In `pyproject.toml` under `dependencies`, add:

```toml
"python-jose[cryptography]>=3.3.0",
"passlib[bcrypt]>=1.7.4",
```

- [ ] **Step 3: Install**

```bash
uv sync
```

Expected: No errors.

### Task 9: Create auth domain models

**Files:**
- Create: `backend/app/domain/auth/__init__.py`
- Create: `backend/app/domain/auth/models.py`
- Create: `tests/test_jwt_handler.py`

- [ ] **Step 1: Write failing test**

Create `tests/test_jwt_handler.py`. The tests exercise `JWTHandler` (implemented in Task 10), so they keep failing until both Task 9 and Task 10 are done:

```python
"""Tests for JWTHandler: token creation and decoding.

These tests do not require a running server or database.
"""

import time

import pytest

SECRET = "test-secret-key-minimum-32-characters-long"


@pytest.fixture
def handler():
    """Return a JWTHandler configured with a test secret."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30)


def test_create_token_returns_string(handler):
    """create_access_token must return a non-empty string."""
    token = handler.create_access_token(user_id="u1", username="alice", role="admin")
    assert isinstance(token, str)
    assert len(token) > 20


def test_decode_token_returns_correct_claims(handler):
    """decode_token must return UserClaims matching the input."""
    token = handler.create_access_token(user_id="u1", username="alice", role="admin")
    claims = handler.decode_token(token)
    assert claims.user_id == "u1"
    assert claims.username == "alice"
    assert claims.role == "admin"


def test_decode_expired_token_raises():
    """decode_token must raise ValueError on an expired token."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    # expire_minutes=0 sets exp to the creation time; sleeping one second
    # moves the verifier's clock strictly past it.
    short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0)
    token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly")
    time.sleep(1)
    with pytest.raises(ValueError, match="expired"):
        short_handler.decode_token(token)


def test_decode_invalid_token_raises(handler):
    """decode_token must raise ValueError for a malformed token."""
    with pytest.raises(ValueError):
        handler.decode_token("not.a.valid.jwt.token")


def test_decode_wrong_secret_raises():
    """decode_token must raise ValueError when signed with a different secret."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60)
    verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60)
    token = creator.create_access_token(user_id="u3", username="carol", role="legal")
    with pytest.raises(ValueError):
        verifier.decode_token(token)
```

- [ ] **Step 2: Run test to confirm failure**

```bash
PYTHONPATH=backend pytest tests/test_jwt_handler.py -v
```

Expected: every test errors with `ModuleNotFoundError: No module named 'app.infrastructure.auth'` (the handler package does not exist yet).

- [ ] **Step 3: Create domain auth package**

Create `backend/app/domain/auth/__init__.py`:

```python
"""Auth domain: role definitions and token claim models.

The domain layer defines what a user identity looks like (UserClaims) and
what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL)
live under infrastructure/auth and never leak into this package.
"""

from .models import UserClaims, UserRole

__all__ = ["UserClaims", "UserRole"]
```

- [ ] **Step 4: Create auth domain models**

Create `backend/app/domain/auth/models.py`:

```python
"""Auth domain models: roles and token claims.

UserRole defines the four roles from PPT Slide 12.
UserClaims is what the JWT decodes to: the identity object passed
through FastAPI dependency injection to route handlers.
"""

from __future__ import annotations

import enum
from dataclasses import dataclass


class UserRole(str, enum.Enum):
    """Access roles mirroring the four-role RBAC matrix from the product spec.

    ADMIN:    full platform access, including system management.
    LEGAL:    knowledge query, document review, compliance checks.
    EHS:      knowledge query, perception/regulatory signals.
    READONLY: knowledge query only.
    """

    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


@dataclass
class UserClaims:
    """Decoded JWT payload representing an authenticated user.

    Instances are created by JWTHandler.decode_token() and injected into
    route handlers via the get_current_user FastAPI dependency.
    """

    # Unique user identifier (UUID string stored in PostgreSQL users table).
    user_id: str
    # Display name used for audit log entries.
    username: str
    # Role determines which resources the user may access.
    role: UserRole
```

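The role claim is described as the enforcement hook for later per-route RBAC. One way that hook could eventually be filled in is a static role-to-permission table consulted by a FastAPI dependency; the sketch below shows only the table and the check, in plain Python. `ROLE_PERMISSIONS`, `is_allowed` and the permission strings are hypothetical, derived from the role descriptions in the `UserRole` docstring rather than from the product spec itself.

```python
"""Hypothetical role-to-permission table for future per-route RBAC checks."""
from __future__ import annotations

import enum


class UserRole(str, enum.Enum):
    # Mirrors app.domain.auth.models.UserRole so this sketch runs standalone.
    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


# Illustrative permission names, one per capability listed in the UserRole docstring.
ROLE_PERMISSIONS: dict[UserRole, frozenset[str]] = {
    UserRole.READONLY: frozenset({"knowledge:query"}),
    UserRole.LEGAL: frozenset({"knowledge:query", "documents:review", "compliance:check"}),
    UserRole.EHS: frozenset({"knowledge:query", "perception:read"}),
}


def is_allowed(role: UserRole, permission: str) -> bool:
    """Return True if the role grants the permission; ADMIN is granted everything."""
    if role is UserRole.ADMIN:
        return True
    return permission in ROLE_PERMISSIONS.get(role, frozenset())


print(is_allowed(UserRole.LEGAL, "documents:review"))     # True
print(is_allowed(UserRole.READONLY, "documents:review"))  # False
```

A route dependency could then call `is_allowed(claims.role, "documents:review")` on the `UserClaims` it receives and raise HTTP 403 when the check fails; keeping the table in one place makes the matrix reviewable against the spec.
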
### Task 10: Create JWT handler infrastructure

**Files:**
- Create: `backend/app/infrastructure/auth/__init__.py`
- Create: `backend/app/infrastructure/auth/jwt_handler.py`

- [ ] **Step 1: Create infrastructure auth package**

Create `backend/app/infrastructure/auth/__init__.py`:

```python
"""Authentication infrastructure: JWT tokens and user credentials.

JWTHandler (jwt_handler.py) creates and validates tokens; PostgresUserStore
(user_store.py, added in Task 11) reads and verifies users. Both are wired
through shared/bootstrap.py and injected into FastAPI dependencies.
"""
```

- [ ] **Step 2: Implement JWTHandler**

Create `backend/app/infrastructure/auth/jwt_handler.py`:

```python
"""JWT access token creation and decoding.

Uses python-jose for HS256 token signing. Token expiry is enforced at
decode time, so expired tokens are rejected even if the signature is valid.
"""

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Any

from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError
from loguru import logger

from app.domain.auth.models import UserClaims, UserRole


class JWTHandler:
    """Create and validate JWT access tokens (HS256 by default).

    A single shared instance is wired by bootstrap.py. Use
    get_jwt_handler() from shared.bootstrap for all token operations.
    """

    def __init__(
        self,
        *,
        secret_key: str,
        algorithm: str = "HS256",
        expire_minutes: int = 480,
    ) -> None:
        """Initialise the handler with signing credentials and token lifetime."""
        self._secret = secret_key
        self._algorithm = algorithm
        self._expire_minutes = expire_minutes

    def create_access_token(
        self,
        *,
        user_id: str,
        username: str,
        role: str,
    ) -> str:
        """Return a signed JWT containing user identity and role claims."""
        # timezone.utc (rather than datetime.UTC) also works on Python < 3.11.
        now = datetime.now(timezone.utc)
        payload: dict[str, Any] = {
            "sub": user_id,
            "username": username,
            "role": role,
            "iat": now,
            "exp": now + timedelta(minutes=self._expire_minutes),
        }
        return jwt.encode(payload, self._secret, algorithm=self._algorithm)

    def decode_token(self, token: str) -> UserClaims:
        """Decode and validate a JWT, returning UserClaims.

        Raises ValueError with a descriptive message on expiry, tampering,
        or any other validation failure so callers do not need to know jose.
        """
        try:
            payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
        except ExpiredSignatureError as exc:
            # Caught first: ExpiredSignatureError is a subclass of JWTError.
            raise ValueError("Token expired") from exc
        except JWTError as exc:
            raise ValueError(f"Invalid token: {exc}") from exc

        user_id = payload.get("sub")
        username = payload.get("username", "")
        role_str = payload.get("role", UserRole.READONLY.value)

        if not user_id:
            raise ValueError("Token missing subject claim")

        try:
            role = UserRole(role_str)
        except ValueError:
            logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
            role = UserRole.READONLY

        return UserClaims(user_id=user_id, username=username, role=role)
```

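`decode_token` delegates all verification to python-jose. To make concrete what "expiry is enforced at decode time" and "signed with a different secret" mean, here is a stdlib-only sketch of HS256 signing and verification: a base64url-encoded header and payload, an HMAC-SHA256 signature over `header.payload`, a constant-time signature check, and only then the `exp` check. It is an educational illustration under simplifying assumptions (fixed header, no `alg` allow-list, no `nbf`/`iat` checks), not a replacement for the library; `encode_hs256` and `decode_hs256` are hypothetical names.

```python
"""Stdlib-only anatomy of an HS256 JWT (educational; the handler uses python-jose)."""
from __future__ import annotations

import base64
import hashlib
import hmac
import json
import time


def _b64url(data: bytes) -> str:
    # JWT segments are unpadded base64url.
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _sign(signing_input: str, secret: str) -> bytes:
    return hmac.new(secret.encode(), signing_input.encode("ascii"), hashlib.sha256).digest()


def encode_hs256(claims: dict, secret: str) -> str:
    header = _b64url(json.dumps({"alg": "HS256", "typ": "JWT"}).encode())
    payload = _b64url(json.dumps(claims).encode())
    return f"{header}.{payload}.{_b64url(_sign(f'{header}.{payload}', secret))}"


def decode_hs256(token: str, secret: str, *, now: float | None = None) -> dict:
    """Verify the signature first, then exp; raise ValueError on any failure."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("Invalid token: expected 3 segments")
    header, payload, signature = parts
    # Constant-time comparison avoids leaking how many bytes matched.
    if not hmac.compare_digest(_sign(f"{header}.{payload}", secret), _b64url_decode(signature)):
        raise ValueError("Invalid token: bad signature")
    claims = json.loads(_b64url_decode(payload))
    current = time.time() if now is None else now
    if "exp" in claims and current >= claims["exp"]:
        raise ValueError("Token expired")
    return claims


token = encode_hs256({"sub": "u1", "role": "admin", "exp": int(time.time()) + 60}, "s3cret")
print(decode_hs256(token, "s3cret")["role"])  # admin
```

Note the order: no claim is trusted until the signature checks out. A real verifier must also reject tokens whose header `alg` differs from the expected one, which is what passing `algorithms=[self._algorithm]` to `jwt.decode` does in `JWTHandler.decode_token`.
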
- [ ] **Step 3: Run JWT tests**

```bash
PYTHONPATH=backend pytest tests/test_jwt_handler.py -v
```

Expected: All 5 tests PASS.

### Task 11: Create PostgreSQL user store and seed script

**Files:**
- Create: `backend/app/infrastructure/auth/user_store.py`
- Create: `scripts/seed_users.py`

- [ ] **Step 1: Implement PostgreSQL user store**

Create `backend/app/infrastructure/auth/user_store.py`:

```python
"""PostgreSQL-backed user store for authentication.

Manages a `users` table with hashed passwords and roles.
Provides lookup by username for the login flow.
Table DDL is auto-applied on first connection.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import psycopg2
import psycopg2.extras
from passlib.context import CryptContext

from app.config.settings import settings


# bcrypt context. passlib's bcrypt handler uses a work factor of 12 by default,
# a reasonable production setting.
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

# DDL executed once to ensure the table exists. gen_random_uuid() is built in
# from PostgreSQL 13 onward, so no pgcrypto extension is needed on PG 15.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS users (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    username VARCHAR(100) UNIQUE NOT NULL,
    hashed_pw TEXT NOT NULL,
    role VARCHAR(50) NOT NULL DEFAULT 'readonly',
    is_active BOOLEAN NOT NULL DEFAULT TRUE,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""


@dataclass
class UserRecord:
    """A single row from the users table."""

    id: str
    username: str
    hashed_pw: str
    role: str
    is_active: bool


class PostgresUserStore:
    """Read and verify users stored in the PostgreSQL users table.

    The connection is opened when the instance is constructed and shared for
    the lifetime of the singleton wired by bootstrap (get_user_store() is
    lru_cached, so construction happens on its first call).
    """

    def __init__(self) -> None:
        """Open the connection and ensure the users table exists."""
        self._conn = psycopg2.connect(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
            cursor_factory=psycopg2.extras.RealDictCursor,
        )
        self._conn.autocommit = True
        self._ensure_table()

    def _ensure_table(self) -> None:
        """Create the users table if it does not already exist."""
        with self._conn.cursor() as cur:
            cur.execute(_CREATE_TABLE_SQL)

    def get_by_username(self, username: str) -> Optional[UserRecord]:
        """Return a UserRecord for the given username, or None if not found."""
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, hashed_pw, role, is_active "
                "FROM users WHERE username = %s",
                (username,),
            )
            row = cur.fetchone()
        if row is None:
            return None
        return UserRecord(
            id=str(row["id"]),
            username=row["username"],
            hashed_pw=row["hashed_pw"],
            role=row["role"],
            is_active=row["is_active"],
        )

    def verify_password(self, plain: str, hashed: str) -> bool:
        """Return True if `plain` matches the stored bcrypt hash."""
        return _PWD_CTX.verify(plain, hashed)

    def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
        """Return the UserRecord if credentials are valid, else None."""
        user = self.get_by_username(username)
        if user is None or not user.is_active:
            return None
        if not self.verify_password(password, user.hashed_pw):
            return None
        return user

    @staticmethod
    def hash_password(plain: str) -> str:
        """Hash a plain-text password with bcrypt."""
        return _PWD_CTX.hash(plain)
```

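For unit tests that should not touch PostgreSQL, a fake with the same `authenticate` contract can be handy. The sketch below is hypothetical (`InMemoryUserStore` is not part of this plan) and, to stay stdlib-only, swaps bcrypt for `hashlib.pbkdf2_hmac`; it is a test double only, not a substitute for the passlib/bcrypt hashing used above. It keeps the rejection rules of `PostgresUserStore.authenticate`: unknown or inactive users and wrong passwords both yield `None`.

```python
import hashlib
import hmac
import os
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class UserRecord:
    # Same fields as the UserRecord dataclass in user_store.py.
    id: str
    username: str
    hashed_pw: str
    role: str
    is_active: bool


def _hash(password: str, salt: bytes) -> str:
    # PBKDF2-HMAC-SHA256 stands in for bcrypt here (test double only).
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, 10_000)
    return salt.hex() + "$" + digest.hex()


def _verify(password: str, stored: str) -> bool:
    salt_hex, _ = stored.split("$", 1)
    candidate = _hash(password, bytes.fromhex(salt_hex))
    return hmac.compare_digest(candidate, stored)


class InMemoryUserStore:
    """Hypothetical fake exposing the same lookup/authenticate contract."""

    def __init__(self) -> None:
        self._users: Dict[str, UserRecord] = {}

    def add(self, username: str, password: str, role: str = "readonly", is_active: bool = True) -> None:
        self._users[username] = UserRecord(
            id=f"user-{len(self._users) + 1}",
            username=username,
            hashed_pw=_hash(password, os.urandom(16)),
            role=role,
            is_active=is_active,
        )

    def get_by_username(self, username: str) -> Optional[UserRecord]:
        return self._users.get(username)

    def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
        # Unknown or inactive user -> None; wrong password -> None.
        user = self.get_by_username(username)
        if user is None or not user.is_active:
            return None
        if not _verify(password, user.hashed_pw):
            return None
        return user
```
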
- [ ] **Step 2: Create the seed script**

Create `scripts/seed_users.py`:

```python
#!/usr/bin/env python
"""Seed demo users into the PostgreSQL users table.

Run from the repo root:
    PYTHONPATH=backend python scripts/seed_users.py

Creates four demo accounts (one per role). Safe to run multiple times:
existing usernames are left unchanged (ON CONFLICT DO NOTHING).
"""

import sys
import os

# Allow running from repo root without installing the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))

import psycopg2
from passlib.context import CryptContext
from app.config.settings import settings

_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

DEMO_USERS = [
    {"username": "admin", "password": "Admin@2026!", "role": "admin"},
    {"username": "legal", "password": "Legal@2026!", "role": "legal"},
    {"username": "ehs", "password": "EHS@2026!", "role": "ehs"},
    {"username": "readonly", "password": "Read@2026!", "role": "readonly"},
]


def seed() -> None:
    """Insert demo users. Skips any username that already exists."""
    conn = psycopg2.connect(
        host=settings.postgres_host,
        port=settings.postgres_port,
        user=settings.postgres_user,
        password=settings.postgres_password,
        dbname=settings.postgres_db,
    )
    conn.autocommit = True

    try:
        # Ensure the table exists (same DDL as user_store.py).
        with conn.cursor() as cur:
            cur.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                    username VARCHAR(100) UNIQUE NOT NULL,
                    hashed_pw TEXT NOT NULL,
                    role VARCHAR(50) NOT NULL DEFAULT 'readonly',
                    is_active BOOLEAN NOT NULL DEFAULT TRUE,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                );
            """)

        inserted = 0
        with conn.cursor() as cur:
            for user in DEMO_USERS:
                hashed = _PWD_CTX.hash(user["password"])
                cur.execute(
                    """
                    INSERT INTO users (username, hashed_pw, role)
                    VALUES (%s, %s, %s)
                    ON CONFLICT (username) DO NOTHING
                    """,
                    (user["username"], hashed, user["role"]),
                )
                # rowcount is 0 when ON CONFLICT skipped the insert.
                if cur.rowcount > 0:
                    print(f"  Created user: {user['username']} (role={user['role']})")
                    inserted += 1
                else:
                    print(f"  Skipped (already exists): {user['username']}")

        print(f"\nDone. {inserted} user(s) created.")
    finally:
        # Close the connection even if a statement fails.
        conn.close()


if __name__ == "__main__":
    seed()
```

- [ ] **Step 3: Seed demo users (requires PostgreSQL running)**

```bash
PYTHONPATH=backend python scripts/seed_users.py
```

Expected:
```
  Created user: admin (role=admin)
  Created user: legal (role=legal)
  Created user: ehs (role=ehs)
  Created user: readonly (role=readonly)

Done. 4 user(s) created.
```

On a second run every account prints `Skipped (already exists): <username>` and the summary reads `Done. 0 user(s) created.`

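The re-run safety rests entirely on `ON CONFLICT (username) DO NOTHING`: a conflicting insert affects zero rows, which is what `cur.rowcount > 0` detects. The sketch below reproduces that behaviour with the stdlib `sqlite3` module so it runs without PostgreSQL. SQLite is a swapped-in engine (the upsert clause needs SQLite 3.24+), the trimmed table and `DEMO_USERS` rows are placeholders, and it counts inserts through `Connection.total_changes` rather than `rowcount`.

```python
import sqlite3
from typing import Dict, List

# Placeholder rows: real values would already be bcrypt hashes.
DEMO_USERS: List[Dict[str, str]] = [
    {"username": "admin", "hashed_pw": "hash-1", "role": "admin"},
    {"username": "legal", "hashed_pw": "hash-2", "role": "legal"},
]

_DDL = """
CREATE TABLE IF NOT EXISTS users (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    username TEXT UNIQUE NOT NULL,
    hashed_pw TEXT NOT NULL,
    role TEXT NOT NULL DEFAULT 'readonly'
)
"""

_INSERT = (
    "INSERT INTO users (username, hashed_pw, role) VALUES (?, ?, ?) "
    "ON CONFLICT (username) DO NOTHING"
)


def seed(conn: sqlite3.Connection, users: List[Dict[str, str]]) -> int:
    """Insert users, skipping existing usernames; return how many were created."""
    conn.execute(_DDL)
    inserted = 0
    for user in users:
        before = conn.total_changes
        conn.execute(_INSERT, (user["username"], user["hashed_pw"], user["role"]))
        # A skipped (conflicting) insert changes no rows, so the counter stays put.
        if conn.total_changes > before:
            inserted += 1
    conn.commit()
    return inserted
```

Running `seed` twice against the same connection creates the rows once and then reports zero, mirroring the `Done. 0 user(s) created.` summary of the real script.
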
### Task 12: Create FastAPI auth dependencies and the token endpoint

**Files:**
- Create: `backend/app/api/dependencies/__init__.py`
- Create: `backend/app/api/dependencies/auth.py`
- Create: `backend/app/api/routes/auth.py`
- Modify: `backend/app/api/routes/__init__.py`
- Modify: `backend/app/config/settings.py`
- Modify: `backend/app/shared/bootstrap.py`
- Create: `tests/test_auth_routes.py`

- [ ] **Step 1: Add auth settings**

In `backend/app/config/settings.py`, add to the `Settings` class:

```python
    # ── Auth ─────────────────────────────────────────────────────────────
    # Change auth_secret_key to a long random string in production.
    # Generate one with: python -c "import secrets; print(secrets.token_hex(32))"
    auth_secret_key: str = Field(
        default="change-me-in-production-must-be-32-or-more-characters-long",
        description="JWT signing secret. MUST be changed in production.",
    )
    auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.")
    auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default: 8 hours).")
    auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).")
```

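The placeholder default makes it easy to ship a guessable signing key by accident. One option is a startup check that refuses unsafe auth settings. The sketch below is hypothetical: the function name, the `is_production` flag (this plan defines no environment setting) and the 32-character minimum are assumptions, while `secrets.token_hex(32)` is the generator quoted in the settings comment and yields 64 hex characters.

```python
import secrets
from typing import List

# Mirrors the placeholder default of Settings.auth_secret_key.
PLACEHOLDER_SECRET = "change-me-in-production-must-be-32-or-more-characters-long"
MIN_SECRET_LENGTH = 32  # assumption: shortest secret we are willing to accept


def auth_settings_problems(secret_key: str, auth_enabled: bool, is_production: bool) -> List[str]:
    """Return human-readable problems; an empty list means the settings look safe."""
    problems: List[str] = []
    if len(secret_key) < MIN_SECRET_LENGTH:
        problems.append(f"auth_secret_key is shorter than {MIN_SECRET_LENGTH} characters")
    if is_production and secret_key == PLACEHOLDER_SECRET:
        problems.append("auth_secret_key still uses the placeholder default")
    if is_production and not auth_enabled:
        problems.append("auth_enabled=False bypasses authentication")
    return problems


if __name__ == "__main__":
    # Same generator as the settings comment: 32 random bytes -> 64 hex chars.
    print(secrets.token_hex(32))
```

Calling such a check once at startup and raising on a non-empty list would make a misconfigured deployment fail fast instead of silently accepting forgeable tokens.
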
- [ ] **Step 2: Wire auth components in bootstrap**

Add to `backend/app/shared/bootstrap.py` (import `lru_cache` from `functools` and `settings` from `app.config.settings` if the module does not already):

```python
@lru_cache
def get_jwt_handler():
    """Return the shared JWTHandler instance."""
    from app.infrastructure.auth.jwt_handler import JWTHandler

    return JWTHandler(
        secret_key=settings.auth_secret_key,
        algorithm=settings.auth_algorithm,
        expire_minutes=settings.auth_token_expire_minutes,
    )


@lru_cache
def get_user_store():
    """Return the PostgreSQL user store (lazy, connects on first call)."""
    from app.infrastructure.auth.user_store import PostgresUserStore

    return PostgresUserStore()
```

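Both factories rely on `functools.lru_cache` to act as lazily created singletons: the body runs on the first call and every later call returns the same object. The stdlib sketch below shows that behaviour with a hypothetical `Handler` class and `CONFIG` dict standing in for `JWTHandler` and `settings`. It also shows `cache_clear()`, which tests can call (for example `get_jwt_handler.cache_clear()`) to force a rebuild.

```python
from functools import lru_cache

# Hypothetical settings stand-in; a dict so it can be changed in place.
CONFIG = {"secret": "first-secret"}
construction_log = []


class Handler:
    """Hypothetical stand-in for JWTHandler / PostgresUserStore."""

    def __init__(self, secret: str) -> None:
        construction_log.append(secret)  # record every real construction
        self.secret = secret


@lru_cache
def get_handler() -> Handler:
    # The body runs only on a cache miss; later calls reuse the same instance.
    return Handler(CONFIG["secret"])
```

The flip side of caching is that a settings change made after the first call is not picked up until `cache_clear()` runs, which matters when tests patch values such as `auth_secret_key`.
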
- [ ] **Step 3: Create the dependencies package**

Create `backend/app/api/dependencies/__init__.py`:

```python
"""FastAPI dependency functions for authentication and authorisation.

Import `get_current_user` or `require_role` from `app.api.dependencies.auth`
into route modules to protect endpoints. Both use the shared JWTHandler
wired through bootstrap.
"""
```

- [ ] **Step 4: Implement auth dependencies**

Create `backend/app/api/dependencies/auth.py`:

```python
"""FastAPI dependencies for JWT authentication.

Usage in a route:

    from app.api.dependencies.auth import get_current_user, require_role
    from app.domain.auth.models import UserClaims, UserRole

    @router.get("/protected")
    async def protected(user: UserClaims = Depends(get_current_user)):
        return {"user": user.username}

    @router.delete("/admin-only")
    async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))):
        ...
"""

from __future__ import annotations

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.config.settings import settings
from app.domain.auth.models import UserClaims, UserRole
from app.shared.bootstrap import get_jwt_handler

# Bearer token scheme: the client sends `Authorization: Bearer <token>`.
# auto_error=False lets get_current_user decide how to handle a missing
# header (a 401 below, or the development bypass when auth is disabled).
_bearer = HTTPBearer(auto_error=False)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
) -> UserClaims:
    """Extract and validate the JWT from the Authorization header.

    Returns the decoded UserClaims on success.
    Raises HTTP 401 when the token is missing, expired, or invalid.
    When auth_enabled=False (development), returns a synthetic admin user.
    """
    if not settings.auth_enabled:
        # Development bypass; never use this in production.
        return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN)

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    try:
        return get_jwt_handler().decode_token(credentials.credentials)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc


def require_role(*roles: UserRole):
    """Return a dependency that enforces one of the given roles.

    Example:
        Depends(require_role(UserRole.ADMIN, UserRole.LEGAL))
    """
    async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims:
        """Verify the user holds one of the required roles."""
        if user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                # .value gives "readonly" rather than the enum's repr.
                detail=f"Role '{user.role.value}' is not permitted. Required: {[r.value for r in roles]}",
            )
        return user
    return _check
```

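`require_role` is a dependency factory: each call captures the allowed roles in a closure and returns a fresh checker that FastAPI invokes per request. The framework-free sketch below isolates that pattern. The local `UserRole` values and `UserClaims` shape are assumptions modelled on the roles seeded in Task 11 (the real definitions come from Task 9), and `PermissionError` stands in for `HTTPException(403)`.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable


class UserRole(str, Enum):
    # Assumed values, matching the roles seeded by scripts/seed_users.py.
    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


@dataclass(frozen=True)
class UserClaims:
    user_id: str
    username: str
    role: UserRole


def require_role(*roles: UserRole) -> Callable[[UserClaims], UserClaims]:
    """Build a checker that only lets the given roles through."""
    allowed = frozenset(roles)

    def _check(user: UserClaims) -> UserClaims:
        if user.role not in allowed:
            # .value keeps the message as "readonly", not "UserRole.READONLY".
            raise PermissionError(
                f"Role '{user.role.value}' is not permitted. "
                f"Required: {sorted(r.value for r in allowed)}"
            )
        return user

    return _check
```

Sorting the role names only keeps this sketch's message stable; the plan's version reports roles in the order they were passed.
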
- [ ] **Step 5: Create the auth router**

Create `backend/app/api/routes/auth.py`:

```python
"""Authentication routes: token issuance and identity lookup.

POST /auth/token: exchange username + password for a JWT.
GET  /auth/me: return the current user's identity (requires token).
"""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel

from app.api.dependencies.auth import get_current_user
from app.config.settings import settings
from app.domain.auth.models import UserClaims
from app.shared.bootstrap import get_jwt_handler, get_user_store


router = APIRouter(prefix="/auth", tags=["认证"])  # tag: "Authentication"


class TokenResponse(BaseModel):
    """JWT token response body."""

    access_token: str
    token_type: str = "bearer"
    expires_in: int  # token lifetime in seconds


@router.post("/token", response_model=TokenResponse)
def login(form: OAuth2PasswordRequestForm = Depends()):
    """Issue a JWT for valid username + password credentials.

    Uses the standard OAuth2 password-grant form fields, the same shape
    Swagger UI's OAuth2 password flow sends. Declared as a plain `def` so
    FastAPI runs it in a threadpool: the psycopg2 query and the bcrypt check
    are blocking calls that would otherwise stall the event loop.
    """
    user_store = get_user_store()
    user = user_store.authenticate(form.username, form.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = get_jwt_handler().create_access_token(
        user_id=user.id,
        username=user.username,
        role=user.role,
    )
    return TokenResponse(
        access_token=token,
        token_type="bearer",
        expires_in=settings.auth_token_expire_minutes * 60,
    )


@router.get("/me")
async def get_me(current_user: UserClaims = Depends(get_current_user)):
    """Return the identity of the currently authenticated user."""
    return {
        "user_id": current_user.user_id,
        "username": current_user.username,
        "role": current_user.role.value,
    }
```

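Because `/auth/token` reads `OAuth2PasswordRequestForm`, clients must send `username` and `password` as `application/x-www-form-urlencoded` fields, not JSON. A stdlib client sketch follows; `build_token_request` and `fetch_token` are hypothetical helpers, and `BASE_URL` assumes the local dev server on port 8000 with the `/api/v1` prefix used by the tests. Note that `urlencode` percent-encodes characters such as `@` and `!` in the demo passwords.

```python
import json
import urllib.request
from urllib.parse import urlencode

BASE_URL = "http://localhost:8000/api/v1"  # assumption: local dev server


def build_token_request(username: str, password: str, base_url: str = BASE_URL) -> urllib.request.Request:
    """Build (but do not send) the OAuth2 password-grant form POST."""
    body = urlencode({"username": username, "password": password}).encode("ascii")
    return urllib.request.Request(
        f"{base_url}/auth/token",
        data=body,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        method="POST",
    )


def fetch_token(username: str, password: str, base_url: str = BASE_URL) -> str:
    """Send the request and return access_token (needs the API to be running)."""
    with urllib.request.urlopen(build_token_request(username, password, base_url)) as resp:
        payload = json.load(resp)
    # expires_in is auth_token_expire_minutes * 60, e.g. 480 -> 28800 seconds.
    return payload["access_token"]
```
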
- [ ] **Step 6: Register the auth router**

In `backend/app/api/routes/__init__.py`, add:

```python
from .auth import router as auth_router
```

And in the `api_router.include_router(...)` block, add:

```python
api_router.include_router(auth_router)
```

Also add `auth_router` to `__all__`.

- [ ] **Step 7: Write auth route tests**

Create `tests/test_auth_routes.py`:

```python
"""Integration tests for the auth routes.

Uses FastAPI TestClient. Does not require a running PostgreSQL: the
user_store dependency is patched out (psycopg2 and passlib must still be
installed because user_store is imported).
"""

import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch

from app.api.main import app
from app.config.settings import settings
from app.infrastructure.auth.user_store import UserRecord

client = TestClient(app, raise_server_exceptions=False)


@pytest.fixture(autouse=True)
def mock_user_store(monkeypatch):
    """Replace the real PostgreSQL user store with a mock and force auth on."""
    # Keep these tests independent of a local AUTH_ENABLED=false override.
    monkeypatch.setattr(settings, "auth_enabled", True)
    mock_store = MagicMock()
    alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True)
    mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None

    with patch("app.api.routes.auth.get_user_store", return_value=mock_store):
        yield mock_store


def test_login_returns_token_for_valid_credentials():
    """POST /auth/token must return an access_token for valid credentials."""
    resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
    assert resp.status_code == 200
    body = resp.json()
    assert "access_token" in body
    assert body["token_type"] == "bearer"


def test_login_returns_401_for_wrong_password():
    """POST /auth/token must return 401 for wrong password."""
    resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"})
    assert resp.status_code == 401


def test_me_returns_user_when_authenticated():
    """GET /auth/me must return user identity when a valid token is provided."""
    # Get a real token first.
    login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
    token = login_resp.json()["access_token"]

    me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
    assert me_resp.status_code == 200
    assert me_resp.json()["username"] == "alice"
    assert me_resp.json()["role"] == "admin"


def test_me_returns_401_without_token():
    """GET /auth/me must return 401 when no token is provided."""
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
```

- [ ] **Step 8: Run auth tests**

```bash
PYTHONPATH=backend pytest tests/test_auth_routes.py -v
```

Expected: All 4 tests PASS.

### Task 13: Add audit middleware and apply auth to routes

**Files:**
- Create: `backend/app/api/middleware/__init__.py`
- Create: `backend/app/api/middleware/audit.py`
- Modify: `backend/app/api/main.py`
- Modify: `backend/app/config/settings.py`
- Modify: `.env.example`
- Modify: `backend/app/api/routes/documents.py`
- Modify: `backend/app/api/routes/rag.py`
- Modify: `backend/app/api/routes/compliance.py`

- [ ] **Step 1: Create the audit middleware**

Create `backend/app/api/middleware/__init__.py`:

```python
"""HTTP middleware for cross-cutting concerns: audit logging."""
```

Create `backend/app/api/middleware/audit.py`:

```python
"""Audit logging middleware.

Logs every API request (except the root, health and docs paths) with method,
path, status code, response time, and the authenticated user identity
(extracted from the JWT when present).
Log lines use a key=value format so they can be parsed by ELK / Loki.
"""

from __future__ import annotations

import time

from fastapi import Request, Response
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware


class AuditMiddleware(BaseHTTPMiddleware):
    """Log all API calls. Skips /, /health and the docs pages to reduce noise."""

    # "/" must be matched exactly: as a prefix it would match every path and
    # silence the whole audit log.
    _SKIP_EXACT = frozenset({"/"})
    # Path prefixes that produce no audit log entry.
    _SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Intercept the request, call the handler, and log the outcome."""
        path = request.url.path
        if path in self._SKIP_EXACT or path.startswith(self._SKIP_PREFIXES):
            return await call_next(request)

        start = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = int((time.perf_counter() - start) * 1000)

        # Best effort: decode the bearer token (if any) to attribute the request.
        # Authorisation itself is enforced by the route dependencies.
        user_id = "anonymous"
        username = "anonymous"
        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            try:
                from app.shared.bootstrap import get_jwt_handler
                claims = get_jwt_handler().decode_token(auth_header[7:])
                user_id = claims.user_id
                username = claims.username
            except Exception:
                # Invalid or expired token: log as anonymous; the route returns 401.
                pass

        logger.info(
            "AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}",
            request.method,
            path,
            response.status_code,
            elapsed_ms,
            user_id,
            username,
        )
        return response
```

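The skip rule is easy to get wrong: every request path starts with `/`, so listing `/` among the prefixes would suppress every audit line. The pure-function sketch below (hypothetical name `should_audit`, same path lists as the middleware) shows the exact-match plus prefix-match split and can be unit-tested without Starlette; `buggy_should_audit` reproduces the prefix-only variant for comparison.

```python
from typing import FrozenSet, Tuple

SKIP_EXACT: FrozenSet[str] = frozenset({"/"})
SKIP_PREFIXES: Tuple[str, ...] = ("/health", "/docs", "/redoc", "/openapi.json")


def should_audit(path: str) -> bool:
    """Return True when a request to `path` should produce an AUDIT log line."""
    if path in SKIP_EXACT:
        return False
    # str.startswith accepts a tuple and is True if any prefix matches.
    return not path.startswith(SKIP_PREFIXES)


def buggy_should_audit(path: str) -> bool:
    """Prefix-only variant with "/" in the list: it never audits anything."""
    return not any(path.startswith(p) for p in SKIP_PREFIXES + ("/",))
```

Prefix matching also skips paths such as `/healthz`; exact matching for those entries is an option if that distinction matters.
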
- [ ] **Step 2: Register the middleware and tighten CORS in `main.py`**

In `backend/app/api/main.py`, replace the `CORSMiddleware` block and add the audit middleware:

```python
from app.api.middleware.audit import AuditMiddleware

# Tighten CORS: only allow the configured frontend origin(s).
# In production, set CORS_ALLOW_ORIGINS in .env to the actual frontend URL.
_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()]
if not _ORIGINS:
    _ORIGINS = ["http://localhost:5173"]  # Vite dev default

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.add_middleware(AuditMiddleware)
```

Also add `cors_allow_origins` to the `Settings` class in `settings.py`:

```python
    cors_allow_origins: str = Field(
        default="http://localhost:5173",
        description="Comma-separated allowed CORS origins. Use * only in development.",
    )
```

And add to `.env.example`:

```ini
# ── CORS ─────────────────────────────────────────────────────────────────────
# Comma-separated frontend origin(s). Never use * in production.
CORS_ALLOW_ORIGINS=http://localhost:5173
```

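The origin parsing in `main.py` tolerates spaces and trailing commas and falls back to the Vite dev origin when the variable is empty. Pulled out into a hypothetical helper, `parse_cors_origins`, the same logic is easy to unit-test; the `example.com` origins below are placeholders.

```python
from typing import List

DEFAULT_ORIGINS: List[str] = ["http://localhost:5173"]  # Vite dev default, as in main.py


def parse_cors_origins(raw: str) -> List[str]:
    """Split a comma-separated CORS_ALLOW_ORIGINS value into a clean list."""
    origins = [o.strip() for o in raw.split(",") if o.strip()]
    # An empty or whitespace-only value falls back to the dev origin.
    return origins or list(DEFAULT_ORIGINS)
```

Returning a copy of the default keeps callers from mutating the shared constant.
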
- [ ] **Step 3: Apply `Depends(get_current_user)` to key routes**

In `backend/app/api/routes/documents.py`, update the `fastapi` import so it includes `Depends`, and add the auth imports:

```python
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
```

Add `current_user: UserClaims = Depends(get_current_user)` as a parameter to `upload_document`, `list_documents`, `get_document_management_list`, and `delete_document`. For example:

```python
@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    doc_id: str | None = Form(None),
    doc_name: str | None = Form(None),
    regulation_type: str | None = Form(None),
    version: str | None = Form(None),
    generate_summary: bool = Form(False),
    sync: bool = Form(False),
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
    ...  # existing body unchanged
```

Apply the same pattern to:
- `list_documents`, `get_document_management_list`, and `delete_document` in `documents.py`
- `rag_chat` in `rag.py`
- `analyze_stream` in `compliance.py`

In `backend/app/api/routes/rag.py`, add:

```python
from fastapi import APIRouter, Depends
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
```

Then add the parameter to `rag_chat`:

```python
@router.post("/chat")
async def rag_chat(
    request: RagChatRequest,
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
    ...  # existing body unchanged
```

In `backend/app/api/routes/compliance.py`, add:

```python
from fastapi import APIRouter, Depends, File, Form, UploadFile
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
```

Then add the parameter to `analyze_stream`:

```python
@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
    ...  # existing body unchanged
```

- [ ] **Step 4: Verify the app starts and auth is enforced**

```bash
PYTHONPATH=backend python -c "from app.api.main import app; print('OK')"
```

Expected: `OK`

With the API server running on `localhost:8000`, check that an unauthenticated request is rejected:

```bash
curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/chat \
  -X POST -H "Content-Type: application/json" -d '{"query":"test"}'
```

Expected: `401`

Check that an authenticated request passes. `/auth/me` requires a valid token, so a `200` confirms the token was accepted. The single quotes around the form data stop interactive bash from treating `!` as history expansion:

```bash
TOKEN=$(curl -s -X POST http://localhost:8000/api/v1/auth/token \
  -d 'username=admin&password=Admin@2026!' | python -c "import sys,json; print(json.load(sys.stdin)['access_token'])")

curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/auth/me \
  -H "Authorization: Bearer $TOKEN"
```

Expected: `200`

- [ ] **Step 5: Run the full test suite**

```bash
PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py tests/test_celery_tasks.py tests/test_redis_conversation_store.py tests/test_jwt_handler.py tests/test_auth_routes.py -v
```

Expected: All tests PASS.

---

## Self-Review Checklist

**Spec coverage:**
- ✅ §3.1 Reranker quick win → Task 1
- ✅ §3.2 Async task-ification (Celery + Redis) → Tasks 2-5
- ✅ §3.3 Session persistence → Tasks 6-7
- ✅ §3.4 Auth/RBAC/Audit → Tasks 8-13
- ✅ AGENTS.md: `api → application → domain ports → infrastructure` respected throughout
- ✅ No new logic in `services/*` or `workflows/*`
- ✅ `shared/bootstrap.py` is the composition root for all new singletons
- ✅ All comments/docstrings in new backend files are in English
- ✅ Desktop-first frontend constraint: no frontend changes in this plan

**Placeholder scan:** None found: all code blocks are complete and runnable.

**Type consistency:**
- `UserClaims` defined in Task 9, used in Tasks 12, 13 ✅
- `UserRole` enum defined in Task 9, used in Task 12 ✅
- `process_document_task` name matches registration in Task 3 and lookup in Task 3 test ✅
- `store_document` defined in Task 3, called in Task 4 ✅
- `_process_document` defined in Task 3, called from Task 3 and Task 4 ✅
- `get_jwt_handler`, `get_user_store`, `get_celery_app` added to bootstrap in Tasks 4, 12 ✅
- `RedisConversationStore` constructor signature `(redis_client, timeout_seconds)` consistent across implementation, tests, and bootstrap ✅