AIRegulation-DocAnalysis/docs/superpowers/plans/2026-06-05-phase1-production-foundation.md
wangwei 9fea9c6a53 1. Add login feature
2. Adjust font size
3. Add some new features
2026-06-05 18:00:31 +08:00


Phase 1: Production Foundation Implementation Plan

For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.

Goal: Transform the existing synchronous RAG system into a production-ready backend with async task processing, persistent sessions, and JWT authentication — turning "can run" into "can operate."

Architecture: Four independent workstreams executed in priority order: (A) enable the already-wired reranker, (B) introduce Celery workers backed by the already-configured Redis to make document processing non-blocking, (C) replace the in-memory conversation store with a Redis-backed implementation, and (D) add JWT authentication with role claims and audit logging across all API routes.

Tech Stack: FastAPI · Celery 5 · Redis 7 · PostgreSQL 15 · python-jose[cryptography] · passlib[bcrypt] · fakeredis (test) — all running on the existing docker-compose.yml infrastructure.


Scope Note

These four workstreams are largely independent and can be worked in parallel by separate developers. Auth (Group D) should be integrated last because it modifies every route. The recommended solo sequence: A → B → C → D.

Not in scope (YAGNI): Fine-grained per-endpoint RBAC enforcement (the role claim is added to JWT as a future enforcement point), user self-registration, perception crawling, knowledge graph, EHS module.


File Map

Group A — Reranker Quick Win

Action File
Modify .env.example — document reranker env vars
Create tests/test_reranker_bootstrap.py

Group B — Async Task-ification

Action File
Create backend/app/infrastructure/tasks/__init__.py
Create backend/app/infrastructure/tasks/celery_app.py
Create backend/app/infrastructure/tasks/document_tasks.py
Modify backend/app/application/documents/services.py — extract _process_document
Modify backend/app/api/routes/documents.py — enqueue vs inline
Modify backend/app/shared/bootstrap.py — expose get_celery_app()
Modify dev.sh — add worker start target
Create tests/test_celery_tasks.py

Group C — Session Persistence

Action File
Create backend/app/infrastructure/session/redis_conversation_store.py
Modify backend/app/config/settings.py — add session_backend field
Modify backend/app/shared/bootstrap.py — switch store on setting
Modify pyproject.toml — add fakeredis dev dep
Create tests/test_redis_conversation_store.py

Group D — JWT Auth + RBAC + Audit

Action File
Modify requirements.txt — add python-jose, passlib[bcrypt]
Modify pyproject.toml — same
Create backend/app/domain/auth/__init__.py
Create backend/app/domain/auth/models.py
Create backend/app/infrastructure/auth/__init__.py
Create backend/app/infrastructure/auth/jwt_handler.py
Create backend/app/infrastructure/auth/user_store.py
Create backend/app/api/dependencies/__init__.py
Create backend/app/api/dependencies/auth.py
Create backend/app/api/middleware/__init__.py
Create backend/app/api/middleware/audit.py
Create backend/app/api/routes/auth.py
Modify backend/app/api/routes/__init__.py — include auth router
Modify backend/app/config/settings.py — add auth fields
Modify backend/app/shared/bootstrap.py — wire auth components
Modify backend/app/api/main.py — restrict CORS, add audit middleware
Modify backend/app/api/routes/documents.py — require auth
Modify backend/app/api/routes/rag.py — require auth
Modify backend/app/api/routes/compliance.py — require auth
Create scripts/seed_users.py
Create tests/test_jwt_handler.py
Create tests/test_auth_routes.py

Group A — Reranker Quick Win

Estimated time: 30 minutes. No new code, just configuration and a verification test.

Why: settings.reranker_enabled defaults to False in backend/app/config/settings.py:113. The entire OpenAICompatibleReranker infrastructure at backend/app/infrastructure/vectorstore/cross_encoder_reranker.py and the wiring in backend/app/shared/bootstrap.py:199-203 is already complete. Enabling it requires setting two env vars and verifying the wiring works.

Task 1: Document reranker env vars and write a bootstrap verification test

Files:

  • Modify: .env.example

  • Create: tests/test_reranker_bootstrap.py

  • Step 1: Add reranker vars to .env.example

Open .env.example. Locate the existing embedding/LLM block and append:

# ── Reranker (Cross-Encoder) ──────────────────────────────────────────────────
# Set RERANKER_ENABLED=true and point to a TEI or Cohere-compatible rerank API.
# Recommended model: BAAI/bge-reranker-v2-m3 (lighter, multilingual) or
#                    BAAI/bge-reranker-v2.5-gemma2-lightweight (heavier, higher quality).
# The endpoint must expose POST /rerank (TEI style) or POST /v1/rerank (Cohere style).
RERANKER_ENABLED=true
RERANKER_BASE_URL=http://6.86.80.4:30080/v1
RERANKER_MODEL=BAAI/bge-reranker-v2-m3
RERANKER_API_KEY=
RERANKER_TOP_K=5
  • Step 2: Write the verification test

Create tests/test_reranker_bootstrap.py:

"""Verify that bootstrap correctly wires the reranker when the setting is enabled."""

from unittest.mock import patch

import pytest


def test_get_reranker_returns_none_when_disabled():
    """get_reranker() must return None when reranker_enabled is False."""
    with patch("app.shared.bootstrap._build_binary_store"), \
         patch("app.shared.bootstrap._build_vector_index"):
        # Import fresh by clearing lru_cache between tests
        from app.shared import bootstrap
        bootstrap.get_reranker.cache_clear()

        with patch("app.config.settings.settings") as mock_settings:
            mock_settings.reranker_enabled = False
            mock_settings.reranker_base_url = ""
            result = bootstrap.get_reranker()

        bootstrap.get_reranker.cache_clear()
        assert result is None


def test_get_reranker_returns_instance_when_enabled():
    """get_reranker() must return an OpenAICompatibleReranker when enabled."""
    from app.shared import bootstrap
    bootstrap.get_reranker.cache_clear()

    with patch("app.config.settings.settings") as mock_settings:
        mock_settings.reranker_enabled = True
        mock_settings.reranker_base_url = "http://localhost:8082"
        mock_settings.reranker_model = "BAAI/bge-reranker-v2-m3"
        mock_settings.reranker_api_key = ""
        mock_settings.reranker_top_k = 5
        result = bootstrap.get_reranker()

    bootstrap.get_reranker.cache_clear()
    from app.infrastructure.vectorstore.cross_encoder_reranker import OpenAICompatibleReranker
    assert isinstance(result, OpenAICompatibleReranker)
  • Step 3: Run the test to verify the existing wiring
PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py -v

Expected: Both tests PASS immediately (the bootstrap logic already exists). If they pass, the wiring is confirmed correct — proceed.
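The tests above call cache_clear() before and after each lookup, which suggests get_reranker() is memoised with functools.lru_cache. The sketch below illustrates why that matters; FakeSettings and get_component are hypothetical stand-ins, not names from the codebase.

```python
from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache


@dataclass
class FakeSettings:
    """Hypothetical stand-in for the application's settings object."""

    reranker_enabled: bool = False


settings = FakeSettings()


@lru_cache
def get_component() -> str | None:
    """Hypothetical cached factory: built once, then served from the cache."""
    return "reranker" if settings.reranker_enabled else None


first = get_component()          # built while the flag is False
settings.reranker_enabled = True
stale = get_component()          # still the cached value from the first call
get_component.cache_clear()
fresh = get_component()          # rebuilt against the current settings
```

Without the cache_clear() call, a test that flips a setting would keep seeing whatever an earlier test had cached.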

  • Step 4: Enable reranker in your local .env

Add these two lines to your local .env (not .env.example, which is already updated):

RERANKER_ENABLED=true
RERANKER_BASE_URL=http://6.86.80.4:30080/v1
  • Step 5: Verify GET /api/v1/status/health reports reranker enabled

Start the backend and call:

curl http://localhost:8000/api/v1/status/health | python -m json.tool

Expected fragment in response:

"reranker": {
  "enabled": true,
  "model": "BAAI/bge-reranker-v2-m3"
}

Group A complete. The reranker is now active for all KnowledgeRetrievalService.retrieve() calls.
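As an illustration of what the reranker contributes at retrieval time, the sketch below reorders candidate passages using a rerank response. It assumes a TEI-style response shape (a list of objects with `index` and `score`); the function name apply_rerank and the sample passages are hypothetical, and the real OpenAICompatibleReranker may differ.

```python
from __future__ import annotations


def apply_rerank(candidates: list[str], scores: list[dict], top_k: int) -> list[str]:
    """Return the top_k candidates ordered by descending rerank score.

    `scores` is assumed to look like [{"index": 0, "score": 0.12}, ...],
    where `index` points back into `candidates`.
    """
    ranked = sorted(scores, key=lambda item: item["score"], reverse=True)
    return [candidates[item["index"]] for item in ranked[:top_k]]


# Hypothetical candidates returned by the vector search stage.
candidates = ["fire safety clause", "noise limits", "chemical storage rules"]
# Hypothetical rerank response for one query.
response = [
    {"index": 0, "score": 0.12},
    {"index": 1, "score": 0.03},
    {"index": 2, "score": 0.91},
]
top = apply_rerank(candidates, response, top_k=2)
```

Here RERANKER_TOP_K plays the role of top_k: the vector search can over-fetch, and only the highest-scoring passages are passed on.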


Group B — Async Task-ification

Estimated time: 3-5 days.

Why: upload_document at backend/app/api/routes/documents.py:34 runs the full parse→embed→index pipeline synchronously in the HTTP request. For large GB (Guobiao, Chinese national) standards this can take 15+ minutes (the Aliyun timeout is 900 s per settings.py:49). Celery 5 and Redis are already in requirements.txt and docker-compose.yml; only the wiring is missing.

Strategy: Split DocumentCommandService.upload_and_process into a fast store_document (sync, returns immediately) and a background _process_document (called from a Celery task). The existing DocumentProcessingStore already records per-stage status events, so it becomes the progress tracker. The HTTP route returns {doc_id, status: "stored"} immediately; the client polls GET /documents/status/{doc_id} (already exists) for progress.
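The status progression implied by this strategy can be sketched as a small transition table. The enum below is a hypothetical mirror of DocumentStatus: the "stored" and "failed" string values appear in the plan, while the other values and the helper names are assumptions for illustration.

```python
from __future__ import annotations

from enum import Enum


class Status(str, Enum):
    """Hypothetical mirror of DocumentStatus; some string values are assumed."""

    PENDING = "pending"
    STORED = "stored"
    PARSED = "parsed"
    INDEXED = "indexed"
    FAILED = "failed"


# Each stage may advance to the next stage or fail; INDEXED and FAILED are terminal.
ALLOWED: dict[Status, set[Status]] = {
    Status.PENDING: {Status.STORED, Status.FAILED},
    Status.STORED: {Status.PARSED, Status.FAILED},
    Status.PARSED: {Status.INDEXED, Status.FAILED},
    Status.INDEXED: set(),
    Status.FAILED: set(),
}


def can_transition(current: Status, target: Status) -> bool:
    """Return True when moving from `current` to `target` is a legal step."""
    return target in ALLOWED[current]


def is_terminal(status: Status) -> bool:
    """Return True when no further transitions are expected (stop polling)."""
    return not ALLOWED[status]
```

The HTTP request covers PENDING to STORED; everything after STORED happens in the worker, and a polling client stops once it sees a terminal status.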

Task 2: Create the Celery application

Files:

  • Create: backend/app/infrastructure/tasks/__init__.py

  • Create: backend/app/infrastructure/tasks/celery_app.py

  • Create: tests/test_celery_tasks.py (first test only)

  • Step 1: Write a failing test for Celery app configuration

Create tests/test_celery_tasks.py:

"""Tests for Celery task infrastructure.

These tests verify task registration and Celery app configuration without
starting a real worker or connecting to Redis.
"""

import pytest


def test_celery_app_uses_redis_broker():
    """Celery broker URL must be a Redis URL built from settings."""
    from app.infrastructure.tasks.celery_app import celery_app
    assert celery_app.conf.broker_url.startswith("redis://")


def test_celery_app_uses_redis_backend():
    """Celery result backend must be Redis."""
    from app.infrastructure.tasks.celery_app import celery_app
    assert celery_app.conf.result_backend.startswith("redis://")


def test_celery_app_has_json_serializer():
    """Task serializer must be JSON for portability."""
    from app.infrastructure.tasks.celery_app import celery_app
    assert celery_app.conf.task_serializer == "json"
  • Step 2: Run test to confirm it fails
PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_celery_app_uses_redis_broker -v

Expected: ModuleNotFoundError: No module named 'app.infrastructure.tasks'

  • Step 3: Create the tasks package

Create backend/app/infrastructure/tasks/__init__.py:

"""Celery task definitions for background processing.

This package exposes the shared Celery application instance and all
registered task functions used by API routes to enqueue work.
"""
  • Step 4: Create the Celery app factory

Create backend/app/infrastructure/tasks/celery_app.py:

"""Shared Celery application instance for background task processing.

All workers and enqueueing call sites import `celery_app` from this module
so the broker/backend configuration stays in one place.
"""

from __future__ import annotations

from celery import Celery

from app.config.settings import settings

# Build Redis URL from the same settings used by the rest of the backend.
# The password segment is omitted when redis_password is empty so that
# local development against a password-free Redis instance works out of the box.
def _redis_url() -> str:
    """Return a Redis connection URL from application settings."""
    if settings.redis_password:
        return (
            f"redis://:{settings.redis_password}@"
            f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
        )
    return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"


_BROKER = _redis_url()
_BACKEND = _redis_url()

celery_app = Celery(
    "compliance_hub",
    broker=_BROKER,
    backend=_BACKEND,
    include=["app.infrastructure.tasks.document_tasks"],
)

celery_app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    timezone="UTC",
    enable_utc=True,
    # Acknowledge messages only after the task finishes, and requeue the task if
    # the worker process is lost mid-run. Retry limits are set per task.
    task_acks_late=True,
    task_reject_on_worker_lost=True,
    # Keep results for 1 hour for status polling.
    result_expires=3600,
)
  • Step 5: Run the Celery app tests
PYTHONPATH=backend pytest tests/test_celery_tasks.py -v

Expected: test_celery_app_uses_redis_broker PASS, test_celery_app_uses_redis_backend PASS, test_celery_app_has_json_serializer PASS.
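One detail worth checking in _redis_url(): the password is interpolated into the URL as-is, so a password containing characters such as `@` or `/` would produce an ambiguous URL. A minimal sketch of a percent-encoding variant follows; build_redis_url is a hypothetical helper, not part of the plan's code, and whether the project needs it depends on the passwords actually in use.

```python
from __future__ import annotations

from urllib.parse import quote


def build_redis_url(host: str, port: int, db: int, password: str = "") -> str:
    """Build a redis:// URL, percent-encoding the password when present."""
    if password:
        # safe="" also encodes "/" so that it cannot be read as a path separator.
        return f"redis://:{quote(password, safe='')}@{host}:{port}/{db}"
    return f"redis://{host}:{port}/{db}"


plain = build_redis_url("localhost", 6379, 0)
secured = build_redis_url("redis", 6379, 1, password="p@ss/word")
```

The unauthenticated branch is unchanged from the original helper, so local development against a password-free Redis behaves the same.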

Task 3: Split DocumentCommandService and create the document processing task

Files:

  • Modify: backend/app/application/documents/services.py

  • Create: backend/app/infrastructure/tasks/document_tasks.py

  • Step 1: Add task registration test to tests/test_celery_tasks.py

Append to tests/test_celery_tasks.py:

def test_process_document_task_is_registered():
    """process_document_task must be discoverable in the Celery task registry."""
    # Importing document_tasks triggers task registration via the @task decorator.
    import app.infrastructure.tasks.document_tasks  # noqa: F401
    from app.infrastructure.tasks.celery_app import celery_app

    registered = list(celery_app.tasks.keys())
    assert any("process_document_task" in name for name in registered), (
        f"process_document_task not found in {registered}"
    )
  • Step 2: Run test to confirm it fails
PYTHONPATH=backend pytest tests/test_celery_tasks.py::test_process_document_task_is_registered -v

Expected: ModuleNotFoundError: No module named 'app.infrastructure.tasks.document_tasks'

  • Step 3: Extract _process_document from DocumentCommandService

In backend/app/application/documents/services.py, locate upload_and_process (line 233). Add a new method _process_document that contains the parse→embed→index logic, and refactor upload_and_process to call it after the store step.

Find the end of the _safe_mark_run_stored call block (around line 299) and extract everything from the suffix = os.path.splitext(...) line to the end of the try block into a new method:

    def _process_document(
        self,
        *,
        doc_id: str,
        file_name: str,
        final_doc_name: str,
        content: bytes,
        regulation_type: str,
        version: str,
        generate_summary: bool,
        run_id: str | None = None,
    ) -> DocumentProcessResult:
        """Run parse → chunk → embed → index for a document that is already stored.

        This method is called both synchronously (from upload_and_process) and
        asynchronously (from the Celery document_tasks worker). All side-effects
        write through DocumentProcessingStore so callers can poll progress.
        """
        current_status = DocumentStatus.STORED
        current_stage = "parse"
        temp_path = ""
        try:
            suffix = os.path.splitext(file_name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(content)
                temp_path = temp_file.name

            parsed_document = self.parser.parse(
                file_path=temp_path,
                doc_id=doc_id,
                doc_name=final_doc_name,
            )
            self._safe_mark_run_parsed(doc_id=doc_id, run_id=run_id, parsed_document=parsed_document)

            artifact_keys: dict[str, str] = {}
            try:
                artifact_keys = self._save_parse_artifacts(doc_id=doc_id, parsed_document=parsed_document)
            except Exception:
                logger.warning("Parse artifact binary persistence failed for doc_id={}", doc_id)

            self.document_repository.update_status(
                doc_id,
                DocumentStatus.PARSED,
                parser_name=parsed_document.parser_name,
                metadata={
                    "parser_backend": parsed_document.parser_name,
                    "parse_task_id": parsed_document.metadata.get("task_id", ""),
                    "layout_count": parsed_document.metadata.get("layout_count", len(parsed_document.raw_layouts)),
                    "structure_node_count": len(parsed_document.structure_nodes),
                    "semantic_block_count": len(parsed_document.semantic_blocks),
                    "vector_chunk_count": len(parsed_document.vector_chunks),
                    "artifact_keys": artifact_keys,
                    "processing_stage": "parsed",
                },
            )
            current_status = DocumentStatus.PARSED
            current_stage = "embed"
            self._safe_replace_processing_artifacts(doc_id=doc_id, run_id=run_id, artifact_keys=artifact_keys)
            self._safe_append_status_event(
                doc_id=doc_id, run_id=run_id,
                from_status=DocumentStatus.STORED.value, to_status=DocumentStatus.PARSED.value,
                stage="parse", message="Document parsed", metadata={"artifact_count": len(artifact_keys)},
            )
            if self.parse_artifact_store:
                try:
                    self.parse_artifact_store.save(doc_id, parsed_document.structure_nodes, parsed_document.semantic_blocks)
                except Exception:
                    logger.warning("ParseArtifactStore.save failed for doc_id={}", doc_id)

            chunks = self.chunk_builder.build(
                parsed_document=parsed_document,
                regulation_type=regulation_type,
                version=version,
            )
            if not chunks:
                raise ValueError("Parsing finished but produced no indexable chunks")

            vectors = self.embedding_provider.embed_texts([chunk.embedding_text for chunk in chunks])
            current_stage = "index"
            inserted = self.vector_index.upsert(chunks, vectors)
            if inserted != len(chunks):
                logger.warning("Milvus upsert count mismatched: inserted={}, chunks={}", inserted, len(chunks))

            health = self.vector_index.health()
            index_name = health.get("collection_name", "")
            self.document_repository.update_status(
                doc_id, DocumentStatus.INDEXED,
                chunk_count=len(chunks), summary="", summary_latency_ms=0,
                index_name=index_name,
                metadata={"index_collection": index_name, "processing_stage": "indexed"},
            )
            self._safe_mark_run_indexed(doc_id=doc_id, run_id=run_id, chunk_count=len(chunks), index_name=index_name)
            self._safe_append_status_event(
                doc_id=doc_id, run_id=run_id,
                from_status=DocumentStatus.PARSED.value, to_status=DocumentStatus.INDEXED.value,
                stage="index", message="Document indexed",
                metadata={"chunk_count": len(chunks), "index_name": index_name},
            )
            stored = self.document_repository.get(doc_id)
            return DocumentProcessResult(
                doc_id=doc_id, doc_name=final_doc_name,
                status=(stored.status.value if stored else DocumentStatus.INDEXED.value),
                message="Processed successfully", num_chunks=len(chunks),
                summary=stored.summary if stored else "",
                summary_latency_ms=stored.summary_latency_ms if stored else 0,
            )
        except Exception as exc:
            logger.exception("Document processing failed: doc_id={}", doc_id)
            self.document_repository.update_status(
                doc_id, DocumentStatus.FAILED, error_message=str(exc),
                metadata={"failure_reason": str(exc), "processing_stage": "failed", "failure_stage": current_stage},
            )
            self._safe_mark_run_failed(doc_id=doc_id, run_id=run_id, failure_stage=current_stage, error_message=str(exc))
            self._safe_append_status_event(
                doc_id=doc_id, run_id=run_id,
                from_status=current_status.value, to_status=DocumentStatus.FAILED.value,
                stage=current_stage, message=str(exc),
            )
            return DocumentProcessResult(
                doc_id=doc_id, doc_name=final_doc_name,
                status=DocumentStatus.FAILED.value, message=f"Document processing failed: {exc}",
            )
        finally:
            if temp_path and os.path.exists(temp_path):
                try:
                    os.remove(temp_path)
                except OSError:
                    logger.warning("Failed to clean up temporary file: {}", temp_path)

Then, in upload_and_process, replace everything from suffix = os.path.splitext(file_name)[1] through the end of the try block (before the except) with a call to _process_document:

            # Delegate parse → embed → index to the shared processing method.
            # This same method is invoked by the Celery worker for async processing.
            return self._process_document(
                doc_id=doc_id,
                file_name=file_name,
                final_doc_name=final_doc_name,
                content=content,
                regulation_type=regulation_type,
                version=version,
                generate_summary=generate_summary,
                run_id=run_id,
            )

Also add a new method store_document for the fast sync-only path:

    def store_document(
        self,
        *,
        doc_id: str | None = None,
        file_name: str,
        content: bytes,
        content_type: str,
        doc_name: str | None,
        regulation_type: str,
        version: str,
        generate_summary: bool,
    ) -> tuple[str, str | None]:
        """Store the binary file and create the Document record.

        Returns (doc_id, run_id). Does NOT parse, embed, or index.
        This is the fast synchronous first step; processing is enqueued separately.
        """
        doc_id = doc_id or str(uuid.uuid4())[:8]
        final_doc_name = doc_name or file_name
        object_name = f"{doc_id}/{file_name}"

        document = Document(
            doc_id=doc_id,
            doc_name=final_doc_name,
            file_name=file_name,
            object_name=object_name,
            content_type=content_type,
            size_bytes=len(content),
            regulation_type=regulation_type,
            version=version,
            metadata={"generate_summary": generate_summary},
        )
        self.document_repository.create(document)
        run_id = self._safe_create_processing_run(
            doc_id=doc_id, trigger_type="upload", generate_summary=generate_summary
        )
        self.binary_store.save(
            object_name=object_name, data=content,
            content_type=content_type, metadata={"doc_id": doc_id},
        )
        self.document_repository.update_status(doc_id, DocumentStatus.STORED)
        self._safe_mark_run_stored(doc_id=doc_id, run_id=run_id)
        self._safe_append_status_event(
            doc_id=doc_id, run_id=run_id,
            from_status=DocumentStatus.PENDING.value, to_status=DocumentStatus.STORED.value,
            stage="store", message="Source file stored",
        )
        return doc_id, run_id
  • Step 4: Create the Celery document task

Create backend/app/infrastructure/tasks/document_tasks.py:

"""Celery tasks for document processing.

Each task is a thin wrapper that retrieves the already-stored document
binary and delegates to DocumentCommandService._process_document.
The task does not accept raw file bytes — it reads them from the binary
store using the doc_id, so the payload stays small.
"""

from __future__ import annotations

from loguru import logger

from app.infrastructure.tasks.celery_app import celery_app


@celery_app.task(
    name="app.infrastructure.tasks.document_tasks.process_document_task",
    bind=True,
    max_retries=3,
    default_retry_delay=30,
    acks_late=True,
)
def process_document_task(
    self,
    doc_id: str,
    file_name: str,
    doc_name: str,
    regulation_type: str,
    version: str,
    generate_summary: bool,
    run_id: str | None = None,
) -> dict:
    """Parse, embed, and index a document that has already been stored.

    The task reads the file binary from MinIO using doc_id so the Celery
    message stays small. Retries up to 3 times with a 30-second delay on
    transient infrastructure errors.
    """
    # Import lazily inside the task so that importing this module (for task
    # registration) does not build the bootstrap singletons; each worker
    # process initialises its own on first use.
    from app.shared.bootstrap import get_document_command_service, get_document_query_service

    logger.info("process_document_task started: doc_id={}", doc_id)
    try:
        # Read the stored binary from MinIO.
        svc = get_document_command_service()
        doc = get_document_query_service().get(doc_id)
        if not doc:
            raise ValueError(f"Document record not found: {doc_id}")

        content = svc.binary_store.read(doc.object_name)

        result = svc._process_document(
            doc_id=doc_id,
            file_name=file_name,
            final_doc_name=doc_name,
            content=content,
            regulation_type=regulation_type,
            version=version,
            generate_summary=generate_summary,
            run_id=run_id,
        )
        logger.info(
            "process_document_task completed: doc_id={} status={} chunks={}",
            doc_id, result.status, result.num_chunks,
        )
        return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks}

    except Exception as exc:
        logger.exception("process_document_task failed: doc_id={}", doc_id)
        # Only errors raised outside _process_document (missing record, binary-store
        # read failure) reach this handler and are retried. _process_document catches
        # its own failures, marks the document FAILED, and returns without raising.
        raise self.retry(exc=exc)
  • Step 5: Run all Celery task tests
PYTHONPATH=backend pytest tests/test_celery_tasks.py -v

Expected: All 4 tests PASS.
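The task is configured with max_retries=3 and default_retry_delay=30, which gives a fixed delay between attempts rather than exponential backoff. The sketch below contrasts the two schedules and shows one way to separate transient from permanent errors; should_retry, delay_schedule, and the choice of transient exception types are assumptions for illustration, not part of the plan.

```python
from __future__ import annotations

# Assumed set of "transient" errors: network and timeout failures.
TRANSIENT_ERRORS = (ConnectionError, TimeoutError)


def should_retry(exc: Exception, attempt: int, max_retries: int = 3) -> bool:
    """Retry only transient errors, and only while attempts remain.

    `attempt` counts retries already made (0 on the first failure).
    """
    return attempt < max_retries and isinstance(exc, TRANSIENT_ERRORS)


def delay_schedule(base: int, retries: int, exponential: bool = False) -> list[int]:
    """Return the delay in seconds before each retry."""
    if exponential:
        return [base * (2 ** n) for n in range(retries)]
    return [base] * retries


fixed = delay_schedule(30, 3)                       # matches default_retry_delay=30
backoff = delay_schedule(30, 3, exponential=True)   # doubling alternative
```

If exponential backoff is wanted, it would have to be configured explicitly on the task; the settings shown in this plan do not enable it.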

Task 4: Update the upload route to support async enqueueing

Files:

  • Modify: backend/app/api/routes/documents.py

  • Modify: backend/app/shared/bootstrap.py

  • Step 1: Add get_celery_app to bootstrap

In backend/app/shared/bootstrap.py, add at the bottom (before preload_runtime_dependencies):

@lru_cache
def get_celery_app():
    """Return the shared Celery application instance."""
    # Import here to avoid importing Celery at module load time when workers
    # are not configured (e.g., during tests that mock bootstrap).
    from app.infrastructure.tasks.celery_app import celery_app
    return celery_app
  • Step 2: Update the upload route

Replace the upload_document handler in backend/app/api/routes/documents.py with:

@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(..., description="Document file to upload"),
    doc_id: str | None = Form(None, description="Client-preallocated document ID; generated automatically if omitted"),
    doc_name: str | None = Form(None, description="Document name"),
    regulation_type: str | None = Form(None, description="Regulation type"),
    version: str | None = Form(None, description="Document version"),
    generate_summary: bool = Form(False, description="Whether to generate a summary"),
    sync: bool = Form(False, description="Process synchronously (for demos/tests; asynchronous by default)"),
):
    """Upload a document and enqueue it for background processing.

    By default the route stores the binary immediately and enqueues parse→embed→index
    as a Celery task, returning {doc_id, status:'stored'} within seconds.
    Pass sync=true to process inline (useful for demos when no worker is running).
    Poll GET /documents/status/{doc_id} for processing progress.
    """
    content = await file.read()
    if not file.filename:
        raise HTTPException(status_code=400, detail="File name must not be empty")
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    try:
        svc = get_document_command_service()

        if sync:
            # Synchronous fallback: full inline processing (demo / no-worker mode).
            result = svc.upload_and_process(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
        else:
            # Async default: store binary, enqueue processing task.
            stored_doc_id, run_id = svc.store_document(
                doc_id=doc_id,
                file_name=file.filename,
                content=content,
                content_type=file.content_type or "application/octet-stream",
                doc_name=doc_name,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
            )
            from app.infrastructure.tasks.document_tasks import process_document_task
            process_document_task.delay(
                doc_id=stored_doc_id,
                file_name=file.filename,
                doc_name=doc_name or file.filename,
                regulation_type=regulation_type or "",
                version=version or "",
                generate_summary=generate_summary,
                run_id=run_id,
            )
            result = DocumentProcessResult(
                doc_id=stored_doc_id,
                doc_name=doc_name or file.filename,
                status="stored",
                message="File stored; processing in the background. Poll GET /documents/status/{doc_id} for progress.",
            )

        if result.status == "failed":
            raise HTTPException(status_code=500, detail=result.message)
        return _document_response(result)

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Document upload failed")
        raise HTTPException(status_code=500, detail=str(exc))

Also add the import at the top of documents.py:

from app.application.documents import DocumentProcessResult
  • Step 3: Verify the app still starts
PYTHONPATH=backend python -c "from app.api.main import app; print('OK')"

Expected: OK with no import errors.

Task 5: Add worker startup to dev scripts

Files:

  • Modify: dev.sh

  • Step 1: Add worker target to dev.sh

Open dev.sh. Find the start command section and add a worker subcommand. Locate the section that handles start api and add after it:

    start_worker() {
        # Start the Celery document processing worker.
        # Reads PYTHONPATH and env from the same root .env the API uses.
        export PYTHONPATH=backend
        celery -A app.infrastructure.tasks.celery_app worker \
            --loglevel=info \
            --concurrency=2 \
            --queues=celery \
            "$@"
    }

    start_beat() {
        # Start the Celery Beat scheduler for periodic tasks (future: perception crawl).
        export PYTHONPATH=backend
        celery -A app.infrastructure.tasks.celery_app beat \
            --loglevel=info \
            "$@"
    }

And in the argument dispatch block, add:

    "worker")
        start_worker "${@:3}"
        ;;
    "beat")
        start_beat "${@:3}"
        ;;
  • Step 2: Verify the worker starts (requires Redis running)
./dev.sh start worker

Expected: the worker banner lists app.infrastructure.tasks.document_tasks.process_document_task under [tasks]. With the worker running, PYTHONPATH=backend celery -A app.infrastructure.tasks.celery_app inspect ping from a second shell should return a pong reply.

Group B complete. Document uploads now return as soon as the binary is stored, typically within seconds, instead of waiting for the full pipeline. Workers process in the background. GET /documents/status/{doc_id} reports live progress via DocumentProcessingStore.
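On the client side, the upload-then-poll flow could look roughly like the sketch below. It takes the status lookup as a callable so it stays independent of any HTTP library; poll_until_done and its parameters are hypothetical, and the "indexed" status string is assumed from DocumentStatus.INDEXED ("failed" appears in the route code).

```python
from __future__ import annotations

import time
from typing import Callable

# Statuses after which no further progress is expected ("indexed" is assumed).
TERMINAL = {"indexed", "failed"}


def poll_until_done(
    fetch_status: Callable[[], str],
    interval: float = 2.0,
    timeout: float = 900.0,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Call fetch_status() until it returns a terminal status or time runs out."""
    deadline = clock() + timeout
    while True:
        status = fetch_status()
        if status in TERMINAL:
            return status
        if clock() >= deadline:
            raise TimeoutError(f"document still '{status}' after {timeout} s")
        sleep(interval)


# Scripted stand-in for successive GET /documents/status/{doc_id} responses.
scripted = iter(["stored", "parsed", "indexed"])
final = poll_until_done(lambda: next(scripted), interval=0.0)
```

In a real client, fetch_status would issue the GET request and extract the status field from the JSON body.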


Group C — Session Persistence

Estimated time: 2-3 days.

Why: InMemoryConversationStore at backend/app/infrastructure/session/in_memory_conversation_store.py stores all chat sessions in a plain Python dict. Every backend restart loses all active conversations. Redis 7 is already running; persisting sessions there requires only a new adapter and a settings switch.
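The core idea of such an adapter can be sketched as JSON-serialising each session under its own key and refreshing a TTL on every write. The sketch below is one possible shape under stated assumptions: sessions are plain dicts rather than ConversationSession objects, the "session:" key prefix is invented, and TinyKV is an in-memory stand-in that mimics only the setex/get/delete commands (it does not expire keys).

```python
from __future__ import annotations

import json
import uuid


class TinyKV:
    """In-memory stand-in for the three Redis commands used below."""

    def __init__(self) -> None:
        self.data: dict[str, bytes] = {}
        self.ttl: dict[str, int] = {}

    def setex(self, key: str, seconds: int, value: str) -> None:
        self.data[key] = value.encode("utf-8")
        self.ttl[key] = seconds  # recorded only; expiry is not simulated

    def get(self, key: str) -> bytes | None:
        return self.data.get(key)

    def delete(self, key: str) -> int:
        return int(self.data.pop(key, None) is not None)


class SketchSessionStore:
    """Hypothetical session store: one JSON blob per session, TTL refreshed on write."""

    def __init__(self, client: TinyKV, timeout_seconds: int = 1800) -> None:
        self.client = client
        self.timeout_seconds = timeout_seconds

    def _key(self, session_id: str) -> str:
        return f"session:{session_id}"  # key prefix is an assumption

    def _write(self, session: dict) -> None:
        self.client.setex(
            self._key(session["session_id"]), self.timeout_seconds, json.dumps(session)
        )

    def create_session(self) -> dict:
        session = {"session_id": uuid.uuid4().hex, "messages": []}
        self._write(session)
        return session

    def get_session(self, session_id: str) -> dict | None:
        raw = self.client.get(self._key(session_id))
        return json.loads(raw) if raw is not None else None

    def save_message(self, session_id: str, role: str, content: str) -> dict | None:
        session = self.get_session(session_id)
        if session is None:
            return None
        session["messages"].append({"role": role, "content": content})
        self._write(session)  # read-modify-write; not atomic across processes
        return session


store = SketchSessionStore(TinyKV(), timeout_seconds=1800)
session = store.create_session()
store.save_message(session["session_id"], "user", "Hello")
reloaded = store.get_session(session["session_id"])
```

Because every message rewrites the whole blob, concurrent writers to the same session could overwrite each other; a production adapter might need a Redis list or a transaction instead.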

Task 6: Add fakeredis dev dependency and create the Redis store

Files:

  • Modify: pyproject.toml

  • Create: backend/app/infrastructure/session/redis_conversation_store.py

  • Create: tests/test_redis_conversation_store.py

  • Step 1: Add fakeredis to dev dependencies

In pyproject.toml, find the [dependency-groups] section and update:

[dependency-groups]
dev = ["pytest>=7.0.0", "pytest-asyncio>=0.21.0", "isort>=8.0.1", "fakeredis>=2.0.0"]

Install it:

uv sync
  • Step 2: Write failing tests for RedisConversationStore

Create tests/test_redis_conversation_store.py:

"""Tests for RedisConversationStore.

Uses fakeredis so no real Redis connection is required.
All tests follow the same ConversationStore contract as InMemoryConversationStore.
"""

import pytest
import fakeredis


@pytest.fixture
def redis_client():
    """Return an in-process fake Redis client."""
    return fakeredis.FakeRedis()


@pytest.fixture
def store(redis_client):
    """Return a RedisConversationStore backed by fake Redis."""
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore
    return RedisConversationStore(redis_client=redis_client, timeout_seconds=1800)


def test_create_session_returns_session_with_id(store):
    """create_session() must return a ConversationSession with a non-empty session_id."""
    session = store.create_session()
    assert session.session_id
    assert len(session.session_id) > 0


def test_get_session_returns_same_session(store):
    """get_session() must return the previously created session."""
    session = store.create_session()
    fetched = store.get_session(session.session_id)
    assert fetched is not None
    assert fetched.session_id == session.session_id


def test_get_session_returns_none_for_unknown_id(store):
    """get_session() must return None when the session_id does not exist."""
    result = store.get_session("nonexistent-id")
    assert result is None


def test_save_message_appends_to_session(store):
    """save_message() must append a message and return the updated session."""
    session = store.create_session()
    updated = store.save_message(session.session_id, role="user", content="Hello")
    assert updated is not None
    assert len(updated.messages) == 1
    assert updated.messages[0].role == "user"
    assert updated.messages[0].content == "Hello"


def test_save_message_persists_across_lookups(store):
    """Messages saved to a session must be visible in subsequent get_session calls."""
    session = store.create_session()
    store.save_message(session.session_id, role="user", content="test")
    fetched = store.get_session(session.session_id)
    assert fetched is not None
    assert len(fetched.messages) == 1


def test_delete_session_removes_it(store):
    """delete_session() must return True and remove the session."""
    session = store.create_session()
    result = store.delete_session(session.session_id)
    assert result is True
    assert store.get_session(session.session_id) is None


def test_delete_session_returns_false_for_unknown(store):
    """delete_session() must return False when the session does not exist."""
    result = store.delete_session("ghost-id")
    assert result is False


def test_list_sessions_includes_created_session(store):
    """list_sessions() must include all active sessions."""
    session = store.create_session()
    sessions = store.list_sessions()
    ids = [s["session_id"] for s in sessions]
    assert session.session_id in ids


def test_session_expires_after_ttl(redis_client):
    """Sessions must disappear after the TTL expires."""
    from app.infrastructure.session.redis_conversation_store import RedisConversationStore
    # Use a 1-second TTL for this test.
    store = RedisConversationStore(redis_client=redis_client, timeout_seconds=1)
    session = store.create_session()
    # Manually expire the key to simulate TTL expiry.
    redis_client.expire(f"session:{session.session_id}", 0)
    assert store.get_session(session.session_id) is None
  • Step 3: Run tests to confirm they fail
PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v

Expected: ModuleNotFoundError: No module named 'app.infrastructure.session.redis_conversation_store'

  • Step 4: Implement RedisConversationStore

Create backend/app/infrastructure/session/redis_conversation_store.py:

"""Redis-backed conversation store for persistent chat sessions.

Sessions are stored as JSON strings under the key `session:{session_id}`.
The Redis TTL is refreshed on every write so active sessions stay alive.
On expiry, `get_session` returns None — callers should create a new session.
"""

from __future__ import annotations

import json
import time
import uuid
from typing import Any

from loguru import logger

from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore


class RedisConversationStore(ConversationStore):
    """Store conversation sessions in Redis with automatic TTL expiry.

    Each session is serialised as a JSON object at key ``session:{session_id}``.
    The TTL is reset on every write so sessions stay alive as long as they are active.
    """

    # Prefix for all session keys to avoid collisions with other Redis consumers.
    _PREFIX = "session:"

    def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
        """Initialise the store with an existing Redis client and a TTL in seconds."""
        self._redis = redis_client
        self._ttl = timeout_seconds

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _key(self, session_id: str) -> str:
        """Build the Redis key for a session."""
        return f"{self._PREFIX}{session_id}"

    def _serialise(self, session: ConversationSession) -> str:
        """Serialise a ConversationSession to a JSON string."""
        return json.dumps(
            {
                "session_id": session.session_id,
                "created_at": session.created_at,
                "updated_at": session.updated_at,
                "metadata": session.metadata,
                "messages": [
                    {
                        "role": msg.role,
                        "content": msg.content,
                        "timestamp": msg.timestamp,
                        "sources": msg.sources,
                    }
                    for msg in session.messages
                ],
            },
            ensure_ascii=False,
        )

    def _deserialise(self, raw: bytes | str) -> ConversationSession:
        """Deserialise a JSON string back into a ConversationSession."""
        data = json.loads(raw)
        messages = [
            ConversationMessage(
                role=m["role"],
                content=m["content"],
                timestamp=m["timestamp"],
                sources=m.get("sources", []),
            )
            for m in data.get("messages", [])
        ]
        session = ConversationSession(
            session_id=data["session_id"],
            created_at=data.get("created_at", 0),
            updated_at=data.get("updated_at", 0),
            metadata=data.get("metadata", {}),
        )
        session.messages = messages
        return session

    def _save(self, session: ConversationSession) -> None:
        """Persist a session to Redis and refresh its TTL."""
        self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))

    # ------------------------------------------------------------------
    # ConversationStore protocol
    # ------------------------------------------------------------------

    def create_session(self, metadata: dict | None = None) -> ConversationSession:
        """Create a new empty session and persist it immediately."""
        now = int(time.time())
        session = ConversationSession(
            session_id=str(uuid.uuid4())[:8],
            created_at=now,
            updated_at=now,
            metadata=metadata or {},
        )
        self._save(session)
        return session

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Return a session by ID, or None if it does not exist or has expired."""
        raw = self._redis.get(self._key(session_id))
        if raw is None:
            return None
        try:
            return self._deserialise(raw)
        except Exception:
            logger.warning("Failed to deserialise session: {}", session_id)
            return None

    def save_message(
        self,
        session_id: str,
        *,
        role: str,
        content: str,
        sources: list[dict] | None = None,
    ) -> ConversationSession | None:
        """Append a message to a session and refresh its TTL."""
        session = self.get_session(session_id)
        if session is None:
            return None
        session.messages.append(
            ConversationMessage(
                role=role,
                content=content,
                timestamp=int(time.time()),
                sources=sources or [],
            )
        )
        session.updated_at = int(time.time())
        self._save(session)
        return session

    def delete_session(self, session_id: str) -> bool:
        """Delete a session. Returns True if it existed, False otherwise."""
        deleted = self._redis.delete(self._key(session_id))
        return bool(deleted)

    def list_sessions(self) -> list[dict]:
        """Return summary dicts for all live sessions visible in this Redis DB.

        Note: KEYS is used here for simplicity; in a large production deployment
        replace with SCAN for non-blocking iteration.
        """
        pattern = f"{self._PREFIX}*"
        keys = self._redis.keys(pattern)
        result = []
        for key in keys:
            raw = self._redis.get(key)
            if raw is None:
                continue
            try:
                data = json.loads(raw)
                result.append(
                    {
                        "session_id": data["session_id"],
                        "message_count": len(data.get("messages", [])),
                        "created_at": data.get("created_at", 0),
                        "updated_at": data.get("updated_at", 0),
                    }
                )
            except Exception:
                continue
        return result
  • Step 5: Run tests
PYTHONPATH=backend pytest tests/test_redis_conversation_store.py -v

Expected: All 9 tests PASS.
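The list_sessions docstring notes that KEYS should give way to SCAN in larger deployments. A minimal sketch of that cursor loop follows, assuming the client exposes redis-py's scan(cursor, match, count) method, which returns a (next_cursor, keys) pair. PagedFakeClient and iter_session_keys are hypothetical names used only for illustration; the fake stands in for a real client so the snippet runs without Redis.

```python
from __future__ import annotations

import fnmatch
from typing import Iterator


class PagedFakeClient:
    """Hypothetical stand-in that mimics the (cursor, keys) contract of SCAN."""

    def __init__(self, keys: list[str]) -> None:
        self._keys = list(keys)

    def scan(self, cursor: int = 0, match: str | None = None, count: int | None = None):
        """Return one page of keys plus the next cursor (0 means iteration is done)."""
        step = count or 10
        page = self._keys[cursor : cursor + step]
        next_cursor = cursor + step
        if next_cursor >= len(self._keys):
            next_cursor = 0
        if match is not None:
            # Like Redis, apply the pattern after the page has been fetched.
            page = [k for k in page if fnmatch.fnmatchcase(k, match)]
        return next_cursor, page


def iter_session_keys(client, prefix: str = "session:", batch: int = 100) -> Iterator[str]:
    """Yield each key under `prefix` once, walking the keyspace with SCAN."""
    seen: set = set()
    cursor = 0
    while True:
        cursor, keys = client.scan(cursor=cursor, match=f"{prefix}*", count=batch)
        for key in keys:
            # SCAN may return a key more than once, so de-duplicate.
            if key not in seen:
                seen.add(key)
                yield key
        if cursor == 0:
            break


fake = PagedFakeClient([f"session:{i}" for i in range(5)] + ["celery:task"])
found = sorted(iter_session_keys(fake, batch=2))
print(found)
```

With a real client, redis-py's scan_iter(match=...) wraps the same loop, so list_sessions could iterate over that instead of calling keys().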

Task 7: Add session_backend setting and switch bootstrap

Files:

  • Modify: backend/app/config/settings.py

  • Modify: backend/app/shared/bootstrap.py

  • Step 1: Add session_backend setting

In backend/app/config/settings.py, find the session block (around line 125) and add:

    session_backend: str = Field(
        default="memory",
        description="Session storage backend (memory | redis). The redis option requires a reachable Redis instance.",
    )
  • Step 2: Update get_conversation_store in bootstrap

Replace the existing get_conversation_store function in backend/app/shared/bootstrap.py:

@lru_cache
def get_conversation_store() -> ConversationStore:
    """Return the active conversation store based on settings.

    When session_backend='redis', sessions survive backend restarts.
    When session_backend='memory' (default), sessions are process-local.
    """
    if settings.session_backend == "redis":
        import redis as redis_lib
        from app.infrastructure.session.redis_conversation_store import RedisConversationStore

        # Build the Redis client from the same connection settings as Celery.
        kwargs: dict = {
            "host": settings.redis_host,
            "port": settings.redis_port,
            "db": settings.redis_db,
            "decode_responses": False,
        }
        if settings.redis_password:
            kwargs["password"] = settings.redis_password

        redis_client = redis_lib.Redis(**kwargs)
        return RedisConversationStore(
            redis_client=redis_client,
            timeout_seconds=settings.session_timeout_minutes * 60,
        )
    return InMemoryConversationStore(
        max_sessions=settings.session_max_sessions,
        timeout_minutes=settings.session_timeout_minutes,
    )

Also import ConversationStore at the top of bootstrap.py so the return annotation resolves:

from app.domain.conversation import ConversationStore
  • Step 3: Add to .env.example

Append to .env.example:

# ── Session backend ───────────────────────────────────────────────────────────
# Set SESSION_BACKEND=redis for persistent sessions (requires Redis).
# Use SESSION_BACKEND=memory for development without Redis.
SESSION_BACKEND=redis
  • Step 4: Verify app starts with the new setting
PYTHONPATH=backend python -c "
from app.config.settings import settings
from app.shared.bootstrap import get_conversation_store
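get_conversation_store.cache_clear()
store = get_conversation_store()
print('session_backend:', settings.session_backend)
print('store:', type(store).__name__)
print('OK')
"

Expected: prints session_backend: memory (or redis if set in .env), then the matching store class name (InMemoryConversationStore or RedisConversationStore), then OK.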

Group C complete. Set SESSION_BACKEND=redis in .env to persist sessions across restarts.
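One detail worth keeping in mind when switching backends: get_conversation_store is wrapped in lru_cache, so the backend is chosen once per process and later changes to settings are ignored until the cache is cleared. The sketch below reproduces that behaviour with a stand-in settings object; FakeSettings and get_store_name are hypothetical names, not part of the codebase.

```python
from functools import lru_cache


class FakeSettings:
    """Hypothetical stand-in for app.config.settings.settings."""

    session_backend = "memory"


settings = FakeSettings()


@lru_cache
def get_store_name() -> str:
    """Pretend factory: the backend is read from settings on the first call only."""
    return f"{settings.session_backend}-store"


first = get_store_name()            # built from session_backend="memory"
settings.session_backend = "redis"  # simulate a changed setting
stale = get_store_name()            # still the cached result
get_store_name.cache_clear()        # what the verification snippet does
fresh = get_store_name()            # rebuilt from the current setting

print(first, stale, fresh)
```

In practice this suggests that changing SESSION_BACKEND takes effect only after a process restart, or after an explicit cache_clear() in tests.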


Group D — JWT Authentication + RBAC + Audit Logging

Estimated time: 3-5 days.

Why: backend/app/api/main.py:49 has allow_origins=["*"]. No route requires any credential. PPT Slide 12 defines a four-role permission matrix; Slide 9 lists "RBAC" as a key challenge response.

Approach for MVP: JWT tokens carry a role claim (admin/legal/ehs/readonly). All API routes require a valid token. Fine-grained per-route RBAC enforcement is a follow-up; the role field is the enforcement hook. An audit middleware logs every authenticated API call. Users are stored in PostgreSQL.

Task 8: Add auth dependencies

Files:

  • Modify: requirements.txt

  • Modify: pyproject.toml

  • Step 1: Add python-jose and passlib to requirements.txt

In requirements.txt, add after the existing tool libs block:

# Authentication
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
  • Step 2: Add to pyproject.toml dependencies list

In pyproject.toml under dependencies, add:

    "python-jose[cryptography]>=3.3.0",
    "passlib[bcrypt]>=1.7.4",
  • Step 3: Install
uv sync

Expected: No errors.

Task 9: Create auth domain models

Files:

  • Create: backend/app/domain/auth/__init__.py

  • Create: backend/app/domain/auth/models.py

  • Create: tests/test_jwt_handler.py

  • Step 1: Write failing test

Create tests/test_jwt_handler.py:

"""Tests for JWTHandler — token creation and decoding.

These tests do not require a running server or database.
"""

import time
import pytest

SECRET = "test-secret-key-minimum-32-characters-long"


@pytest.fixture
def handler():
    """Return a JWTHandler configured with a test secret."""
    from app.infrastructure.auth.jwt_handler import JWTHandler
    return JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=30)


def test_create_token_returns_string(handler):
    """create_access_token must return a non-empty string."""
    token = handler.create_access_token(user_id="u1", username="alice", role="admin")
    assert isinstance(token, str)
    assert len(token) > 20


def test_decode_token_returns_correct_claims(handler):
    """decode_token must return UserClaims matching the input."""
    token = handler.create_access_token(user_id="u1", username="alice", role="admin")
    claims = handler.decode_token(token)
    assert claims.user_id == "u1"
    assert claims.username == "alice"
    assert claims.role == "admin"


def test_decode_expired_token_raises(handler):
    """decode_token must raise ValueError on an expired token."""
    from app.infrastructure.auth.jwt_handler import JWTHandler
    # Create a token that expires in 0 minutes (immediately expired).
    short_handler = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=0)
    token = short_handler.create_access_token(user_id="u2", username="bob", role="readonly")
    time.sleep(1)
    with pytest.raises(ValueError, match="expired"):
        short_handler.decode_token(token)


def test_decode_invalid_token_raises(handler):
    """decode_token must raise ValueError for a malformed token."""
    with pytest.raises(ValueError):
        handler.decode_token("not.a.valid.jwt.token")


def test_decode_wrong_secret_raises():
    """decode_token must raise ValueError when signed with a different secret."""
    from app.infrastructure.auth.jwt_handler import JWTHandler
    creator = JWTHandler(secret_key=SECRET, algorithm="HS256", expire_minutes=60)
    verifier = JWTHandler(secret_key="wrong-secret-key-also-minimum-32-chars", algorithm="HS256", expire_minutes=60)
    token = creator.create_access_token(user_id="u3", username="carol", role="legal")
    with pytest.raises(ValueError):
        verifier.decode_token(token)
  • Step 2: Run test to confirm failure
PYTHONPATH=backend pytest tests/test_jwt_handler.py -v

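Expected: ModuleNotFoundError: No module named 'app.infrastructure.auth' (the tests import JWTHandler, which is not implemented until Task 10, so they keep failing until then).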

  • Step 3: Create domain auth package

Create backend/app/domain/auth/__init__.py:

"""Auth domain: role definitions and token claim models.

The domain layer defines what a user identity looks like (UserClaims) and
what roles exist (UserRole). Infrastructure details (JWT, bcrypt, PostgreSQL)
live under infrastructure/auth and never leak into this package.
"""

from .models import UserClaims, UserRole

__all__ = ["UserClaims", "UserRole"]
  • Step 4: Create auth domain models

Create backend/app/domain/auth/models.py:

"""Auth domain models: roles and token claims.

UserRole defines the four roles from PPT Slide 12.
UserClaims is what the JWT decodes to — it is the identity object passed
through FastAPI dependency injection to route handlers.
"""

from __future__ import annotations

import enum
from dataclasses import dataclass


class UserRole(str, enum.Enum):
    """Access roles mirroring the four-role RBAC matrix from the product spec.

    ADMIN      — full platform access including system management.
    LEGAL      — knowledge query, document review, compliance checks.
    EHS        — knowledge query, perception/regulatory signals.
    READONLY   — knowledge query only.
    """

    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


@dataclass
class UserClaims:
    """Decoded JWT payload representing an authenticated user.

    Instances are created by JWTHandler.decode_token() and injected into
    route handlers via the get_current_user FastAPI dependency.
    """

    # Unique user identifier (UUID string stored in PostgreSQL users table).
    user_id: str
    # Display name used for audit log entries.
    username: str
    # Role determines which resources the user may access.
    role: UserRole
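Because UserRole mixes in str, its members compare equal to plain strings, and the constructor doubles as a validator for raw claim values. The following sketch illustrates both properties; parse_role is a hypothetical helper introduced here for illustration, mirroring the read-only fallback the plan applies when decoding tokens.

```python
import enum


class UserRole(str, enum.Enum):
    """Same four roles as in the domain model above."""

    ADMIN = "admin"
    LEGAL = "legal"
    EHS = "ehs"
    READONLY = "readonly"


def parse_role(raw: str) -> UserRole:
    """Hypothetical helper: map a raw claim to a role, falling back to READONLY."""
    try:
        # Enum lookup by value raises ValueError for unknown strings.
        return UserRole(raw)
    except ValueError:
        return UserRole.READONLY


known = parse_role("legal")
unknown = parse_role("superuser")
print(known.value, unknown.value)
```

One caveat: on Python 3.11 and later, formatting a member in an f-string renders the member name (UserRole.ADMIN) instead of the value, so .value is the safer choice for log lines and error messages.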

Task 10: Create JWT handler infrastructure

Files:

  • Create: backend/app/infrastructure/auth/__init__.py

  • Create: backend/app/infrastructure/auth/jwt_handler.py

  • Step 1: Create infrastructure auth package

Create backend/app/infrastructure/auth/__init__.py:

"""JWT token creation and validation infrastructure.

JWTHandler is the only component in this package. It is wired through
shared/bootstrap.py and injected into FastAPI dependencies.
"""
  • Step 2: Implement JWTHandler

Create backend/app/infrastructure/auth/jwt_handler.py:

"""JWT access token creation and decoding.

Uses python-jose for HS256 token signing. Token expiry is enforced at
decode time so expired tokens are rejected even if the signature is valid.
"""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from typing import Any

from jose import JWTError, jwt
from loguru import logger

from app.domain.auth.models import UserClaims, UserRole


class JWTHandler:
    """Create and validate HS256 JWT access tokens.

    A single shared instance is wired by bootstrap.py. Use
    get_jwt_handler() from shared.bootstrap for all token operations.
    """

    def __init__(
        self,
        *,
        secret_key: str,
        algorithm: str = "HS256",
        expire_minutes: int = 480,
    ) -> None:
        """Initialise the handler with signing credentials and token lifetime."""
        self._secret = secret_key
        self._algorithm = algorithm
        self._expire_minutes = expire_minutes

    def create_access_token(
        self,
        *,
        user_id: str,
        username: str,
        role: str,
    ) -> str:
        """Return a signed JWT containing user identity and role claims."""
        now = datetime.now(UTC)
        payload: dict[str, Any] = {
            "sub": user_id,
            "username": username,
            "role": role,
            "iat": now,
            "exp": now + timedelta(minutes=self._expire_minutes),
        }
        return jwt.encode(payload, self._secret, algorithm=self._algorithm)

    def decode_token(self, token: str) -> UserClaims:
        """Decode and validate a JWT, returning UserClaims.

        Raises ValueError with a descriptive message on expiry, tampering,
        or any other validation failure so callers do not need to know jose.
        """
        try:
            payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
        except JWTError as exc:
            msg = str(exc).lower()
            if "expired" in msg:
                raise ValueError("Token expired") from exc
            raise ValueError(f"Invalid token: {exc}") from exc

        user_id = payload.get("sub")
        username = payload.get("username", "")
        role_str = payload.get("role", UserRole.READONLY.value)

        if not user_id:
            raise ValueError("Token missing subject claim")

        try:
            role = UserRole(role_str)
        except ValueError:
            logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
            role = UserRole.READONLY

        return UserClaims(user_id=user_id, username=username, role=role)
  • Step 3: Run JWT tests
PYTHONPATH=backend pytest tests/test_jwt_handler.py -v

Expected: All 5 tests PASS.
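To make the token layout concrete, here is a stdlib-only sketch of what an HS256 token with the same claims (sub, username, role, iat, exp) looks like when signed and verified by hand. It is an illustration of the format, not a replacement for python-jose, and sign_hs256 and verify_hs256 are hypothetical names. Like decode_token, the verifier signals every failure as ValueError.

```python
from __future__ import annotations

import base64
import hashlib
import hmac
import json
import time


def _b64url(data: bytes) -> str:
    """Base64url-encode without padding, as JWT segments require."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(text: str) -> bytes:
    """Restore the stripped padding, then decode."""
    return base64.urlsafe_b64decode(text + "=" * (-len(text) % 4))


def sign_hs256(payload: dict, secret: str) -> str:
    """Return header.payload.signature, signed with HMAC-SHA256."""
    header = {"alg": "HS256", "typ": "JWT"}
    signing_input = ".".join(
        _b64url(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, payload)
    )
    sig = hmac.new(secret.encode("utf-8"), signing_input.encode("ascii"), hashlib.sha256).digest()
    return f"{signing_input}.{_b64url(sig)}"


def verify_hs256(token: str, secret: str, now: int | None = None) -> dict:
    """Check signature, algorithm, and expiry; return the payload or raise ValueError."""
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
    except ValueError as exc:
        raise ValueError("Invalid token: expected three segments") from exc
    expected = hmac.new(
        secret.encode("utf-8"), f"{header_b64}.{payload_b64}".encode("ascii"), hashlib.sha256
    ).digest()
    # Constant-time comparison avoids leaking how many bytes matched.
    if not hmac.compare_digest(expected, _b64url_decode(sig_b64)):
        raise ValueError("Invalid token: signature mismatch")
    if json.loads(_b64url_decode(header_b64)).get("alg") != "HS256":
        raise ValueError("Invalid token: unexpected algorithm")
    payload = json.loads(_b64url_decode(payload_b64))
    current = int(time.time()) if now is None else now
    if payload.get("exp", 0) < current:
        raise ValueError("Token expired")
    return payload


issued = 1_700_000_000
token = sign_hs256(
    {"sub": "u1", "username": "alice", "role": "admin", "iat": issued, "exp": issued + 60},
    "s" * 32,
)
claims = verify_hs256(token, "s" * 32, now=issued)
print(claims["sub"], claims["role"])
```

The payload is only encoded, not encrypted: anyone holding the token can read the role claim, which is why the signature check has to happen before any claim is trusted.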

Task 11: Create PostgreSQL user store and seed script

Files:

  • Create: backend/app/infrastructure/auth/user_store.py

  • Create: scripts/seed_users.py

  • Step 1: Implement PostgreSQL user store

Create backend/app/infrastructure/auth/user_store.py:

"""PostgreSQL-backed user store for authentication.

Manages a `users` table with hashed passwords and roles.
Provides lookup by username for the login flow.
Table DDL is auto-applied on first connection.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import psycopg2
import psycopg2.extras
from loguru import logger
from passlib.context import CryptContext

from app.config.settings import settings


# bcrypt context. No cost factor is set here, so passlib's default bcrypt cost (rounds=12) applies.
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

# DDL executed once to ensure the table exists.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS users (
    id          UUID        PRIMARY KEY DEFAULT gen_random_uuid(),
    username    VARCHAR(100) UNIQUE NOT NULL,
    hashed_pw   TEXT         NOT NULL,
    role        VARCHAR(50)  NOT NULL DEFAULT 'readonly',
    is_active   BOOLEAN      NOT NULL DEFAULT TRUE,
    created_at  TIMESTAMPTZ  NOT NULL DEFAULT NOW()
);
"""


@dataclass
class UserRecord:
    """A single row from the users table."""

    id: str
    username: str
    hashed_pw: str
    role: str
    is_active: bool


class PostgresUserStore:
    """Read and verify users stored in the PostgreSQL users table.

    The connection is opened when the store is constructed and is shared
    for the lifetime of the singleton instance wired by bootstrap.
    """

    def __init__(self) -> None:
        """Initialise the store and ensure the users table exists."""
        self._conn = psycopg2.connect(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
            cursor_factory=psycopg2.extras.RealDictCursor,
        )
        self._conn.autocommit = True
        self._ensure_table()

    def _ensure_table(self) -> None:
        """Create the users table if it does not already exist."""
        with self._conn.cursor() as cur:
            cur.execute(_CREATE_TABLE_SQL)

    def get_by_username(self, username: str) -> Optional[UserRecord]:
        """Return a UserRecord for the given username, or None if not found."""
        with self._conn.cursor() as cur:
            cur.execute(
                "SELECT id, username, hashed_pw, role, is_active "
                "FROM users WHERE username = %s",
                (username,),
            )
            row = cur.fetchone()
        if row is None:
            return None
        return UserRecord(
            id=str(row["id"]),
            username=row["username"],
            hashed_pw=row["hashed_pw"],
            role=row["role"],
            is_active=row["is_active"],
        )

    def verify_password(self, plain: str, hashed: str) -> bool:
        """Return True if `plain` matches the stored bcrypt hash."""
        return _PWD_CTX.verify(plain, hashed)

    def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
        """Return the UserRecord if credentials are valid, else None."""
        user = self.get_by_username(username)
        if user is None or not user.is_active:
            return None
        if not self.verify_password(password, user.hashed_pw):
            return None
        return user

    @staticmethod
    def hash_password(plain: str) -> str:
        """Hash a plain-text password with bcrypt."""
        return _PWD_CTX.hash(plain)
  • Step 2: Create the seed script

Create scripts/seed_users.py:

#!/usr/bin/env python
"""Seed demo users into the PostgreSQL users table.

Run from the repo root:
    PYTHONPATH=backend python scripts/seed_users.py

Creates four demo accounts (one per role). Safe to run multiple times
— existing usernames are left unchanged (ON CONFLICT DO NOTHING).
"""

import sys
import os

# Allow running from repo root without installing the package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))

import psycopg2
from passlib.context import CryptContext
from app.config.settings import settings

_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

DEMO_USERS = [
    {"username": "admin",    "password": "Admin@2026!",   "role": "admin"},
    {"username": "legal",    "password": "Legal@2026!",   "role": "legal"},
    {"username": "ehs",      "password": "EHS@2026!",     "role": "ehs"},
    {"username": "readonly", "password": "Read@2026!",    "role": "readonly"},
]


def seed() -> None:
    """Insert demo users. Skips any username that already exists."""
    conn = psycopg2.connect(
        host=settings.postgres_host,
        port=settings.postgres_port,
        user=settings.postgres_user,
        password=settings.postgres_password,
        dbname=settings.postgres_db,
    )
    conn.autocommit = True

    # Ensure the table exists (same DDL as user_store.py).
    with conn.cursor() as cur:
        cur.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id          UUID        PRIMARY KEY DEFAULT gen_random_uuid(),
                username    VARCHAR(100) UNIQUE NOT NULL,
                hashed_pw   TEXT         NOT NULL,
                role        VARCHAR(50)  NOT NULL DEFAULT 'readonly',
                is_active   BOOLEAN      NOT NULL DEFAULT TRUE,
                created_at  TIMESTAMPTZ  NOT NULL DEFAULT NOW()
            );
        """)

    inserted = 0
    with conn.cursor() as cur:
        for user in DEMO_USERS:
            hashed = _PWD_CTX.hash(user["password"])
            cur.execute(
                """
                INSERT INTO users (username, hashed_pw, role)
                VALUES (%s, %s, %s)
                ON CONFLICT (username) DO NOTHING
                """,
                (user["username"], hashed, user["role"]),
            )
            if cur.rowcount > 0:
                print(f"  Created user: {user['username']} (role={user['role']})")
                inserted += 1
            else:
                print(f"  Skipped (already exists): {user['username']}")

    print(f"\nDone. {inserted} user(s) created.")
    conn.close()


if __name__ == "__main__":
    seed()
  • Step 3: Seed demo users (requires PostgreSQL running)
PYTHONPATH=backend python scripts/seed_users.py

Expected:

  Created user: admin (role=admin)
  Created user: legal (role=legal)
  Created user: ehs (role=ehs)
  Created user: readonly (role=readonly)

Done. 4 user(s) created.
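Running the seed script a second time should report four skipped users and zero created, because ON CONFLICT DO NOTHING leaves rowcount at 0 for existing usernames. The sketch below reproduces that pattern with the standard-library sqlite3 module as a stand-in for PostgreSQL (SQLite 3.24 or newer accepts the same ON CONFLICT clause); the reduced table layout and the seed helper are simplifications made for illustration.

```python
import sqlite3

DEMO = [("admin", "admin"), ("legal", "legal"), ("ehs", "ehs"), ("readonly", "readonly")]


def seed(conn: sqlite3.Connection, users: list) -> int:
    """Insert users, skipping existing usernames; return how many rows were created."""
    inserted = 0
    for username, role in users:
        cur = conn.execute(
            "INSERT INTO users (username, role) VALUES (?, ?) "
            "ON CONFLICT (username) DO NOTHING",
            (username, role),
        )
        # rowcount is 1 for a new row and 0 when the conflict clause skipped it.
        if cur.rowcount > 0:
            inserted += 1
    return inserted


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (username TEXT UNIQUE NOT NULL, role TEXT NOT NULL)")

first_run = seed(conn, DEMO)
second_run = seed(conn, DEMO)
total = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
print(first_run, second_run, total)
```

Note that a skipped user also keeps its original password hash, so changing a demo password in the script has no effect on accounts that already exist.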

Task 12: Create FastAPI auth dependencies and the token endpoint

Files:

  • Create: backend/app/api/dependencies/__init__.py

  • Create: backend/app/api/dependencies/auth.py

  • Create: backend/app/api/routes/auth.py

  • Modify: backend/app/config/settings.py

  • Modify: backend/app/shared/bootstrap.py

  • Step 1: Add auth settings

In backend/app/config/settings.py, add to the Settings class:

    # ── Auth ─────────────────────────────────────────────────────────────
    # Change auth_secret_key to a long random string in production.
    # Generate one with: python -c "import secrets; print(secrets.token_hex(32))"
    auth_secret_key: str = Field(
        default="change-me-in-production-must-be-32-or-more-characters-long",
        description="JWT signing secret. MUST be changed in production.",
    )
    auth_algorithm: str = Field(default="HS256", description="JWT signing algorithm.")
    auth_token_expire_minutes: int = Field(default=480, description="JWT TTL in minutes (default: 8 hours).")
    auth_enabled: bool = Field(default=True, description="Set False to bypass auth (development only).")
  • Step 2: Wire auth components in bootstrap

Add to backend/app/shared/bootstrap.py:

@lru_cache
def get_jwt_handler():
    """Return the shared JWTHandler instance."""
    from app.infrastructure.auth.jwt_handler import JWTHandler
    return JWTHandler(
        secret_key=settings.auth_secret_key,
        algorithm=settings.auth_algorithm,
        expire_minutes=settings.auth_token_expire_minutes,
    )


@lru_cache
def get_user_store():
    """Return the PostgreSQL user store (lazy, connects on first call)."""
    from app.infrastructure.auth.user_store import PostgresUserStore
    return PostgresUserStore()
  • Step 3: Create the dependencies package

Create backend/app/api/dependencies/__init__.py:

"""FastAPI dependency functions for authentication and authorisation.

Import `get_current_user` or `require_role` into route modules to protect
endpoints. Both use the shared JWTHandler wired through bootstrap.
"""
  • Step 4: Implement auth dependencies

Create backend/app/api/dependencies/auth.py:

"""FastAPI dependencies for JWT authentication.

Usage in a route:
    from app.api.dependencies.auth import get_current_user, require_role
    from app.domain.auth.models import UserRole

    @router.get("/protected")
    async def protected(user: UserClaims = Depends(get_current_user)):
        return {"user": user.username}

    @router.delete("/admin-only")
    async def admin_only(user: UserClaims = Depends(require_role(UserRole.ADMIN))):
        ...
"""

from __future__ import annotations

from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.config.settings import settings
from app.domain.auth.models import UserClaims, UserRole
from app.shared.bootstrap import get_jwt_handler

# Use Bearer token scheme — client sends `Authorization: Bearer <token>`.
_bearer = HTTPBearer(auto_error=False)


async def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(_bearer),
) -> UserClaims:
    """Extract and validate the JWT from the Authorization header.

    Returns the decoded UserClaims on success.
    Raises HTTP 401 when the token is missing, expired, or invalid.
    When auth_enabled=False (development), returns a synthetic admin user.
    """
    if not settings.auth_enabled:
        # Development bypass — never enable this in production.
        return UserClaims(user_id="dev", username="dev-admin", role=UserRole.ADMIN)

    if credentials is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    try:
        return get_jwt_handler().decode_token(credentials.credentials)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc


def require_role(*roles: UserRole):
    """Return a dependency that enforces one of the given roles.

    Example:
        Depends(require_role(UserRole.ADMIN, UserRole.LEGAL))
    """
    async def _check(user: UserClaims = Depends(get_current_user)) -> UserClaims:
        """Verify the user holds one of the required roles."""
        if user.role not in roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Role '{user.role.value}' is not permitted. Required: {[r.value for r in roles]}",
            )
        return user
    return _check
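
The closure that `require_role` builds can be exercised without FastAPI. The sketch below is a stdlib-only stand-in, not the project's code: `Role`, `Claims`, `Forbidden`, and `make_role_gate` are hypothetical names that mirror `UserRole`, `UserClaims`, the 403 `HTTPException`, and `require_role`, and the `VIEWER` role is assumed purely for illustration.

```python
from dataclasses import dataclass
from enum import Enum


class Role(str, Enum):
    """Hypothetical stand-in for UserRole."""

    ADMIN = "admin"
    LEGAL = "legal"
    VIEWER = "viewer"  # assumed role, used only to show a rejection


@dataclass(frozen=True)
class Claims:
    """Hypothetical stand-in for UserClaims."""

    user_id: str
    username: str
    role: Role


class Forbidden(Exception):
    """Stand-in for HTTPException(status_code=403)."""


def make_role_gate(*roles: Role):
    """Return a checker that accepts only users holding one of `roles`."""
    allowed = frozenset(roles)

    def check(user: Claims) -> Claims:
        # Same decision as the real dependency: membership test, then 403.
        if user.role not in allowed:
            raise Forbidden(
                f"Role '{user.role.value}' is not permitted. "
                f"Required: {sorted(r.value for r in allowed)}"
            )
        return user

    return check
```

The factory is called once per route at import time, and the returned checker runs on every request. That is why `Depends(require_role(...))` takes the result of a call rather than the function itself.
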
  • Step 5: Create the auth router

Create backend/app/api/routes/auth.py:

"""Authentication routes: token issuance and current-user lookup.

POST /auth/token  — exchange username + password for a JWT.
GET  /auth/me     — return the current user's identity (requires token).
"""

from __future__ import annotations

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel

from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims
from app.shared.bootstrap import get_jwt_handler, get_user_store


router = APIRouter(prefix="/auth", tags=["认证"])


class TokenResponse(BaseModel):
    """JWT token response body."""

    access_token: str
    token_type: str = "bearer"
    expires_in: int


@router.post("/token", response_model=TokenResponse)
async def login(form: OAuth2PasswordRequestForm = Depends()):
    """Issue a JWT for valid username + password credentials.

    Uses standard OAuth2 password grant form fields so this endpoint
    is compatible with Swagger UI's Authorize button.
    """
    user_store = get_user_store()
    user = user_store.authenticate(form.username, form.password)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    from app.config.settings import settings
    token = get_jwt_handler().create_access_token(
        user_id=user.id,
        username=user.username,
        role=user.role,
    )
    return TokenResponse(
        access_token=token,
        token_type="bearer",
        expires_in=settings.auth_token_expire_minutes * 60,
    )


@router.get("/me")
async def get_me(current_user: UserClaims = Depends(get_current_user)):
    """Return the identity of the currently authenticated user."""
    return {
        "user_id": current_user.user_id,
        "username": current_user.username,
        "role": current_user.role.value,
    }
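
When debugging the token flow it can help to look at what `/auth/token` actually put in the token. A JWT in compact form is three base64url segments separated by dots, so the payload can be read with the standard library alone. The sketch below does not verify the signature and must never be used for authorization. `peek_jwt_payload` is a hypothetical helper, and the claim names in the sample token are assumptions; the real names depend on how `create_access_token` builds the payload.

```python
import base64
import json


def _b64url_decode(segment: str) -> bytes:
    """Decode a base64url segment, restoring the padding that JWTs strip."""
    padding = "=" * (-len(segment) % 4)
    return base64.urlsafe_b64decode(segment + padding)


def _b64url_encode(data: dict) -> str:
    """Encode a dict as an unpadded base64url JSON segment."""
    raw = json.dumps(data).encode()
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()


def peek_jwt_payload(token: str) -> dict:
    """Return the UNVERIFIED payload of a compact JWT (debugging only)."""
    parts = token.split(".")
    if len(parts) != 3:
        raise ValueError("expected header.payload.signature")
    return json.loads(_b64url_decode(parts[1]))


# Build an unsigned sample token; claim names here are illustrative only.
sample = ".".join(
    [
        _b64url_encode({"alg": "HS256", "typ": "JWT"}),
        _b64url_encode({"sub": "uuid-1", "username": "alice", "role": "admin"}),
        "signature-not-checked",
    ]
)
print(peek_jwt_payload(sample))
```
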
  • Step 6: Register the auth router

In backend/app/api/routes/__init__.py, add:

from .auth import router as auth_router

And in the api_router.include_router(...) block, add:

api_router.include_router(auth_router)

Also add auth_router to __all__.

  • Step 7: Write auth route tests

Create tests/test_auth_routes.py:

"""Integration tests for the auth routes.

Uses FastAPI TestClient. Does not require a running PostgreSQL — patches
the user_store dependency.
"""

import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch

from app.api.main import app
from app.infrastructure.auth.user_store import UserRecord

client = TestClient(app, raise_server_exceptions=False)


@pytest.fixture(autouse=True)
def mock_user_store():
    """Replace the real PostgreSQL user store with a mock."""
    mock_store = MagicMock()
    alice = UserRecord(id="uuid-1", username="alice", hashed_pw="hashed", role="admin", is_active=True)
    mock_store.authenticate.side_effect = lambda u, p: alice if (u == "alice" and p == "correct") else None

    with patch("app.api.routes.auth.get_user_store", return_value=mock_store):
        yield mock_store


def test_login_returns_token_for_valid_credentials():
    """POST /auth/token must return an access_token for valid credentials."""
    resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
    assert resp.status_code == 200
    body = resp.json()
    assert "access_token" in body
    assert body["token_type"] == "bearer"


def test_login_returns_401_for_wrong_password():
    """POST /auth/token must return 401 for wrong password."""
    resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "wrong"})
    assert resp.status_code == 401


def test_me_returns_user_when_authenticated():
    """GET /auth/me must return user identity when a valid token is provided."""
    # Get a real token first.
    login_resp = client.post("/api/v1/auth/token", data={"username": "alice", "password": "correct"})
    token = login_resp.json()["access_token"]

    me_resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
    assert me_resp.status_code == 200
    assert me_resp.json()["username"] == "alice"
    assert me_resp.json()["role"] == "admin"


def test_me_returns_401_without_token():
    """GET /auth/me must return 401 when no token is provided."""
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
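  • Step 8: Run auth tests

PYTHONPATH=backend pytest tests/test_auth_routes.py -v

Expected: All 4 tests PASS. test_me_returns_401_without_token only passes when auth_enabled is true: with the development bypass on, get_current_user returns the synthetic admin and the request gets 200 instead of 401.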

Task 13: Add audit middleware and apply auth to routes

Files:

  • Create: backend/app/api/middleware/__init__.py

  • Create: backend/app/api/middleware/audit.py

  • Modify: backend/app/api/main.py

  • Modify: backend/app/api/routes/documents.py

  • Modify: backend/app/api/routes/rag.py

  • Modify: backend/app/api/routes/compliance.py

  • Step 1: Create the audit middleware

Create backend/app/api/middleware/__init__.py:

"""HTTP middleware for cross-cutting concerns: audit logging."""

Create backend/app/api/middleware/audit.py:

"""Audit logging middleware.

Logs every API request with method, path, status code, response time,
and the authenticated user identity (extracted from the JWT when present).
Log lines are structured so they can be ingested by ELK / Loki.
"""

from __future__ import annotations

import time

from fastapi import Request, Response
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware


class AuditMiddleware(BaseHTTPMiddleware):
    """Log all API calls. Skips health and API-docs paths to reduce noise."""

    # Paths that produce no audit log entry. The bare root is matched exactly:
    # used as a prefix, "/" would match every path and disable auditing.
    _SKIP_EXACT = ("/",)
    _SKIP_PREFIXES = ("/health", "/docs", "/redoc", "/openapi.json")

    async def dispatch(self, request: Request, call_next) -> Response:
        """Intercept the request, call the handler, and log the outcome."""
        path = request.url.path
        if path in self._SKIP_EXACT or path.startswith(self._SKIP_PREFIXES):
            return await call_next(request)

        start = time.perf_counter()
        response = await call_next(request)
        elapsed_ms = int((time.perf_counter() - start) * 1000)

        # Best-effort: decode the bearer token to attach user identity to the
        # log line. Access control stays in the auth dependencies; a missing,
        # expired, or invalid token is simply logged as "anonymous".
        user_id = "anonymous"
        username = "anonymous"
        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            try:
                from app.shared.bootstrap import get_jwt_handler
                claims = get_jwt_handler().decode_token(auth_header[7:])
                user_id = claims.user_id
                username = claims.username
            except Exception:
                # Audit logging must never break the response.
                pass

        logger.info(
            "AUDIT method={} path={} status={} elapsed_ms={} user_id={} username={}",
            request.method,
            path,
            response.status_code,
            elapsed_ms,
            user_id,
            username,
        )
        return response
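
Because the audit line is a flat run of `key=value` pairs, a downstream consumer can split it without a custom grammar. This is a minimal sketch, assuming no value contains a space (true for the fields above unless a username has one). `parse_audit_line` is a hypothetical helper, not part of the plan, and the timestamp prefix in the sample is only an illustration of what a log sink might prepend.

```python
def parse_audit_line(line: str) -> dict[str, str]:
    """Parse '... AUDIT k=v k=v ...' into a dict of string fields."""
    # Everything before the AUDIT marker (timestamp, level) is ignored.
    _, marker, rest = line.partition("AUDIT ")
    if not marker:
        raise ValueError("not an audit line")
    fields: dict[str, str] = {}
    for token in rest.split():
        key, sep, value = token.partition("=")
        if sep:  # skip stray tokens that are not key=value
            fields[key] = value
    return fields


sample_line = (
    "2026-06-05 18:00:31 | INFO | AUDIT method=POST path=/api/v1/rag/chat "
    "status=200 elapsed_ms=42 user_id=uuid-1 username=alice"
)
record = parse_audit_line(sample_line)
print(record["path"], record["status"], record["username"])
```
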
  • Step 2: Register the middleware and tighten CORS in main.py

In backend/app/api/main.py, replace the CORSMiddleware block and add the audit middleware:

from app.api.middleware.audit import AuditMiddleware

# Tighten CORS — only allow the configured frontend origin.
# In production, set CORS_ALLOW_ORIGINS in .env to the actual frontend URL.
_ORIGINS = [o.strip() for o in settings.cors_allow_origins.split(",") if o.strip()]
if not _ORIGINS:
    _ORIGINS = ["http://localhost:5173"]  # Vite dev default

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.add_middleware(AuditMiddleware)

Also add cors_allow_origins to settings.py:

    cors_allow_origins: str = Field(
        default="http://localhost:5173",
        description="Comma-separated allowed CORS origins. Use * only in development.",
    )

And add to .env.example:

# ── CORS ─────────────────────────────────────────────────────────────────────
# Comma-separated frontend origin(s). Never use * in production.
CORS_ALLOW_ORIGINS=http://localhost:5173
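
The comma-splitting expression used in main.py is easy to check in isolation. The sketch below wraps the same logic in a hypothetical `parse_origins` helper to show how surrounding whitespace, trailing commas, and an empty setting are handled; the example.com origins are placeholders.

```python
DEFAULT_ORIGINS = ["http://localhost:5173"]  # Vite dev default, as in main.py


def parse_origins(raw: str) -> list[str]:
    """Split a comma-separated origin string, dropping blanks and whitespace."""
    origins = [o.strip() for o in raw.split(",") if o.strip()]
    # Fall back to the dev default when the setting is empty or all blanks.
    return origins or list(DEFAULT_ORIGINS)


print(parse_origins("https://app.example.com , https://admin.example.com,"))
print(parse_origins(""))
```
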
  • Step 3: Apply Depends(get_current_user) to key routes

In backend/app/api/routes/documents.py, add to imports:

from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims

Add current_user: UserClaims = Depends(get_current_user) as a parameter to upload_document, list_documents, get_document_management_list, and delete_document:

@router.post("/upload", response_model=DocumentUploadResponse)
async def upload_document(
    file: UploadFile = File(...),
    doc_id: str | None = Form(None),
    doc_name: str | None = Form(None),
    regulation_type: str | None = Form(None),
    version: str | None = Form(None),
    generate_summary: bool = Form(False),
    sync: bool = Form(False),
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):

Apply the same pattern to:

  • list_documents in documents.py
  • get_document_management_list in documents.py
  • delete_document in documents.py
  • rag_chat in rag.py
  • analyze_stream in compliance.py

In backend/app/api/routes/rag.py, add:

from fastapi import APIRouter, Depends
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims

@router.post("/chat")
async def rag_chat(
    request: RagChatRequest,
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):

In backend/app/api/routes/compliance.py, add:

from fastapi import APIRouter, Depends, File, Form, UploadFile
from app.api.dependencies.auth import get_current_user
from app.domain.auth.models import UserClaims

@router.post("/analyze-stream")
async def analyze_stream(
    text: Optional[str] = Form(None),
    doc_id: Optional[str] = Form(None),
    file: Optional[UploadFile] = File(None),
    domains: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    current_user: UserClaims = Depends(get_current_user),  # ← add this
):
  • Step 4: Verify app starts and auth is enforced

PYTHONPATH=backend python -c "from app.api.main import app; print('OK')"

Expected: OK

With the API server running and auth_enabled set to true, test that an unauthenticated request is rejected:

curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/chat \
  -X POST -H "Content-Type: application/json" -d '{"query":"test"}'

Expected: 401

Test that an authenticated request passes:

TOKEN=$(curl -s -X POST http://localhost:8000/api/v1/auth/token \
  -d "username=admin&password=Admin@2026!" | python -c "import sys,json; print(json.load(sys.stdin)['access_token'])")

curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/rag/quick-questions \
  -H "Authorization: Bearer $TOKEN"

Expected: 200

  • Step 5: Run the full test suite

PYTHONPATH=backend pytest tests/test_reranker_bootstrap.py tests/test_celery_tasks.py tests/test_redis_conversation_store.py tests/test_jwt_handler.py tests/test_auth_routes.py -v

Expected: All tests PASS.


Self-Review Checklist

Spec coverage:

  • §3.1 Reranker quick win → Task 1
  • §3.2 Async task-ification (Celery + Redis) → Tasks 2-5
  • §3.3 Session persistence → Tasks 6-7
  • §3.4 Auth/RBAC/Audit → Tasks 8-13
  • AGENTS.md: api → application → domain ports → infrastructure respected throughout
  • No new logic in services/* or workflows/*
  • shared/bootstrap.py is the composition root for all new singletons
  • All comments/docstrings in new backend files are in English
  • Desktop-first frontend constraint: no frontend changes in this plan

Placeholder scan: None found — all code blocks are complete and runnable.

Type consistency:

  • UserClaims defined in Task 9, used in Tasks 12, 13
  • UserRole enum defined in Task 9, used in Task 12
  • process_document_task name matches registration in Task 3 and lookup in Task 3 test
  • store_document defined in Task 3, called in Task 4
  • _process_document defined in Task 3, called from Task 3 and Task 4
  • get_jwt_handler, get_user_store, get_celery_app added to bootstrap in Tasks 4, 12
  • RedisConversationStore constructor signature (redis_client, timeout_seconds) consistent across implementation, tests, and bootstrap