1. Add 登陆功能

2. 调整字体大小
3. 新增部分功能
This commit is contained in:
2026-06-05 18:00:31 +08:00
parent 06e0967128
commit 9fea9c6a53
58 changed files with 5028 additions and 322 deletions

View File

@@ -0,0 +1,5 @@
"""JWT token creation and validation infrastructure.
JWTHandler is the only component in this package. It is wired through
shared/bootstrap.py and injected into FastAPI dependencies.
"""

View File

@@ -0,0 +1,82 @@
"""JWT access token creation and decoding.
Uses python-jose for HS256 token signing. Token expiry is enforced at
decode time so expired tokens are rejected even if the signature is valid.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Any
from jose import JWTError, jwt
from loguru import logger
from app.domain.auth.models import UserClaims, UserRole
class JWTHandler:
"""Create and validate HS256 JWT access tokens.
A single shared instance is wired by bootstrap.py. Use
get_jwt_handler() from shared.bootstrap for all token operations.
"""
def __init__(
self,
*,
secret_key: str,
algorithm: str = "HS256",
expire_minutes: int = 480,
) -> None:
"""Initialise the handler with signing credentials and token lifetime."""
self._secret = secret_key
self._algorithm = algorithm
self._expire_minutes = expire_minutes
def create_access_token(
self,
*,
user_id: str,
username: str,
role: str,
) -> str:
"""Return a signed JWT containing user identity and role claims."""
now = datetime.now(UTC)
payload: dict[str, Any] = {
"sub": user_id,
"username": username,
"role": role,
"iat": now,
"exp": now + timedelta(minutes=self._expire_minutes),
}
return jwt.encode(payload, self._secret, algorithm=self._algorithm)
def decode_token(self, token: str) -> UserClaims:
"""Decode and validate a JWT, returning UserClaims.
Raises ValueError with a descriptive message on expiry, tampering,
or any other validation failure so callers do not need to know jose.
"""
try:
payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
except JWTError as exc:
msg = str(exc).lower()
if "expired" in msg:
raise ValueError("Token expired") from exc
raise ValueError(f"Invalid token: {exc}") from exc
user_id = payload.get("sub")
username = payload.get("username", "")
role_str = payload.get("role", UserRole.READONLY.value)
if not user_id:
raise ValueError("Token missing subject claim")
try:
role = UserRole(role_str)
except ValueError:
logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
role = UserRole.READONLY
return UserClaims(user_id=user_id, username=username, role=role)

View File

@@ -0,0 +1,113 @@
"""PostgreSQL-backed user store for authentication.
Manages a `users` table with hashed passwords and roles.
Provides lookup by username for the login flow.
Table DDL is auto-applied on first connection.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import psycopg2
import psycopg2.extras
from loguru import logger
from passlib.context import CryptContext
from app.config.settings import settings
# bcrypt context — work factor 12 is a good production default.
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")
# DDL executed once to ensure the table exists.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
username VARCHAR(100) UNIQUE NOT NULL,
hashed_pw TEXT NOT NULL,
role VARCHAR(50) NOT NULL DEFAULT 'readonly',
is_active BOOLEAN NOT NULL DEFAULT TRUE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
@dataclass
class UserRecord:
"""A single row from the users table."""
id: str
username: str
hashed_pw: str
role: str
is_active: bool
class PostgresUserStore:
"""Read and verify users stored in the PostgreSQL users table.
The connection is opened on first use and shared for the lifetime
of the singleton instance wired by bootstrap.
"""
def __init__(self) -> None:
"""Initialise the store and ensure the users table exists."""
self._conn = psycopg2.connect(
host=settings.postgres_host,
port=settings.postgres_port,
user=settings.postgres_user,
password=settings.postgres_password,
dbname=settings.postgres_db,
cursor_factory=psycopg2.extras.RealDictCursor,
)
self._conn.autocommit = True
self._ensure_table()
def _ensure_table(self) -> None:
"""Create the users table if it does not already exist."""
with self._conn.cursor() as cur:
# Enable pgcrypto so gen_random_uuid() is available for UUID primary keys.
try:
cur.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
except Exception:
self._conn.rollback()
cur.execute(_CREATE_TABLE_SQL)
def get_by_username(self, username: str) -> Optional[UserRecord]:
"""Return a UserRecord for the given username, or None if not found."""
with self._conn.cursor() as cur:
cur.execute(
"SELECT id, username, hashed_pw, role, is_active "
"FROM users WHERE username = %s",
(username,),
)
row = cur.fetchone()
if row is None:
return None
return UserRecord(
id=str(row["id"]),
username=row["username"],
hashed_pw=row["hashed_pw"],
role=row["role"],
is_active=row["is_active"],
)
def verify_password(self, plain: str, hashed: str) -> bool:
"""Return True if `plain` matches the stored bcrypt hash."""
return _PWD_CTX.verify(plain, hashed)
def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
"""Return the UserRecord if credentials are valid, else None."""
user = self.get_by_username(username)
if user is None or not user.is_active:
return None
if not self.verify_password(password, user.hashed_pw):
return None
return user
@staticmethod
def hash_password(plain: str) -> str:
"""Hash a plain-text password with bcrypt."""
return _PWD_CTX.hash(plain)

View File

@@ -0,0 +1,169 @@
"""Redis-backed conversation store for persistent chat sessions.
Sessions are stored as JSON strings under the key `session:{session_id}`.
The Redis TTL is refreshed on every write so active sessions stay alive.
On expiry, `get_session` returns None — callers should create a new session.
"""
from __future__ import annotations
import json
import time
import uuid
from typing import Any
from loguru import logger
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
class RedisConversationStore(ConversationStore):
"""Store conversation sessions in Redis with automatic TTL expiry.
Each session is serialised as a JSON object at key ``session:{session_id}``.
The TTL is reset on every write so sessions stay alive as long as they are active.
"""
# Prefix for all session keys to avoid collisions with other Redis consumers.
_PREFIX = "session:"
def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
"""Initialise the store with an existing Redis client and a TTL in seconds."""
self._redis = redis_client
self._ttl = timeout_seconds
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _key(self, session_id: str) -> str:
"""Build the Redis key for a session."""
return f"{self._PREFIX}{session_id}"
def _serialise(self, session: ConversationSession) -> str:
"""Serialise a ConversationSession to a JSON string."""
return json.dumps(
{
"session_id": session.session_id,
"created_at": session.created_at,
"updated_at": session.updated_at,
"metadata": session.metadata,
"messages": [
{
"role": msg.role,
"content": msg.content,
"timestamp": msg.timestamp,
"sources": msg.sources,
}
for msg in session.messages
],
},
ensure_ascii=False,
)
def _deserialise(self, raw: bytes | str) -> ConversationSession:
"""Deserialise a JSON string back into a ConversationSession."""
data = json.loads(raw)
messages = [
ConversationMessage(
role=m["role"],
content=m["content"],
timestamp=m["timestamp"],
sources=m.get("sources", []),
)
for m in data.get("messages", [])
]
session = ConversationSession(
session_id=data["session_id"],
created_at=data.get("created_at", 0),
updated_at=data.get("updated_at", 0),
metadata=data.get("metadata", {}),
)
session.messages = messages
return session
def _save(self, session: ConversationSession) -> None:
"""Persist a session to Redis and refresh its TTL."""
self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))
# ------------------------------------------------------------------
# ConversationStore protocol
# ------------------------------------------------------------------
def create_session(self, metadata: dict | None = None) -> ConversationSession:
"""Create a new empty session and persist it immediately."""
now = int(time.time())
session = ConversationSession(
session_id=str(uuid.uuid4())[:8],
created_at=now,
updated_at=now,
metadata=metadata or {},
)
self._save(session)
return session
def get_session(self, session_id: str) -> ConversationSession | None:
"""Return a session by ID, or None if it does not exist or has expired."""
raw = self._redis.get(self._key(session_id))
if raw is None:
return None
try:
return self._deserialise(raw)
except Exception:
logger.warning("Failed to deserialise session: {}", session_id)
return None
def save_message(
self,
session_id: str,
*,
role: str,
content: str,
sources: list[dict] | None = None,
) -> ConversationSession | None:
"""Append a message to a session and refresh its TTL."""
session = self.get_session(session_id)
if session is None:
return None
session.messages.append(
ConversationMessage(
role=role,
content=content,
timestamp=int(time.time()),
sources=sources or [],
)
)
session.updated_at = int(time.time())
self._save(session)
return session
def delete_session(self, session_id: str) -> bool:
"""Delete a session. Returns True if it existed, False otherwise."""
deleted = self._redis.delete(self._key(session_id))
return bool(deleted)
def list_sessions(self) -> list[dict]:
"""Return summary dicts for all live sessions visible in this Redis DB.
Note: KEYS is used for simplicity; replace with SCAN for large deployments.
"""
pattern = f"{self._PREFIX}*"
keys = self._redis.keys(pattern)
result = []
for key in keys:
raw = self._redis.get(key)
if raw is None:
continue
try:
data = json.loads(raw)
result.append(
{
"session_id": data["session_id"],
"message_count": len(data.get("messages", [])),
"created_at": data.get("created_at", 0),
"updated_at": data.get("updated_at", 0),
}
)
except Exception:
continue
return result

View File

@@ -0,0 +1,5 @@
"""Celery task definitions for background processing.
This package exposes the shared Celery application instance and all
registered task functions used by API routes to enqueue work.
"""

View File

@@ -0,0 +1,45 @@
"""Shared Celery application instance for background task processing.
All workers and enqueueing call sites import `celery_app` from this module
so the broker/backend configuration stays in one place.
"""
from __future__ import annotations
from celery import Celery
from app.config.settings import settings
def _redis_url() -> str:
"""Return a Redis connection URL from application settings."""
if settings.redis_password:
return (
f"redis://:{settings.redis_password}@"
f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
)
return f"redis://{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
_BROKER = _redis_url()
_BACKEND = _redis_url()
celery_app = Celery(
"compliance_hub",
broker=_BROKER,
backend=_BACKEND,
include=["app.infrastructure.tasks.document_tasks"],
)
celery_app.conf.update(
task_serializer="json",
result_serializer="json",
accept_content=["json"],
timezone="UTC",
enable_utc=True,
# Acknowledge task only after successful execution to avoid data loss.
task_acks_late=True,
task_reject_on_worker_lost=True,
# Keep results for 1 hour for status polling.
result_expires=3600,
)

View File

@@ -0,0 +1,73 @@
"""Celery tasks for document processing.
Each task is a thin wrapper that retrieves the already-stored document
binary and delegates to DocumentCommandService._process_document.
The task does not accept raw file bytes — it reads them from the binary
store using the doc_id, so the Celery message payload stays small.
"""
from __future__ import annotations
from loguru import logger
from app.infrastructure.tasks.celery_app import celery_app
@celery_app.task(
name="app.infrastructure.tasks.document_tasks.process_document_task",
bind=True,
max_retries=3,
default_retry_delay=30,
acks_late=True,
)
def process_document_task(
self,
doc_id: str,
file_name: str,
doc_name: str,
regulation_type: str,
version: str,
generate_summary: bool,
run_id: str | None = None,
) -> dict:
"""Parse, embed, and index a document that has already been stored.
The task reads the file binary from MinIO using doc_id so the Celery
message stays small. Retries up to 3 times with a 30-second delay on
transient infrastructure errors.
"""
# Import inside the task function to avoid pickling issues and to ensure
# that each worker process initialises its own bootstrap singletons.
from app.shared.bootstrap import get_document_command_service, get_document_query_service
logger.info("process_document_task started: doc_id={}", doc_id)
try:
svc = get_document_command_service()
doc = get_document_query_service().get(doc_id)
if not doc:
raise ValueError(f"Document record not found: {doc_id}")
# Read the stored binary from MinIO — avoids passing raw bytes in the task message.
content = svc.binary_store.read(doc.object_name)
result = svc._process_document(
doc_id=doc_id,
file_name=file_name,
final_doc_name=doc_name,
content=content,
regulation_type=regulation_type,
version=version,
generate_summary=generate_summary,
run_id=run_id,
)
logger.info(
"process_document_task completed: doc_id={} status={} chunks={}",
doc_id, result.status, result.num_chunks,
)
return {"doc_id": result.doc_id, "status": result.status, "num_chunks": result.num_chunks}
except Exception as exc:
logger.exception("process_document_task failed: doc_id={}", doc_id)
# Retry on transient errors; permanent errors (bad file, parse failure)
# will exhaust retries and leave the document in FAILED state.
raise self.retry(exc=exc)