1. Add 登录功能
2. 调整字体大小
3. 新增部分功能
This commit is contained in:
5
backend/app/infrastructure/auth/__init__.py
Normal file
5
backend/app/infrastructure/auth/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""JWT token creation and validation infrastructure.
|
||||
|
||||
JWTHandler creates and validates tokens; it is wired through
shared/bootstrap.py and injected into FastAPI dependencies. The
PostgreSQL-backed user store (user_store.py) also lives in this package.
|
||||
"""
|
||||
82
backend/app/infrastructure/auth/jwt_handler.py
Normal file
82
backend/app/infrastructure/auth/jwt_handler.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""JWT access token creation and decoding.
|
||||
|
||||
Uses python-jose for HS256 token signing. Token expiry is enforced at
|
||||
decode time so expired tokens are rejected even if the signature is valid.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.auth.models import UserClaims, UserRole
|
||||
|
||||
|
||||
class JWTHandler:
    """Create and validate signed JWT access tokens (HS256 by default).

    A single shared instance is wired by bootstrap.py. Use
    get_jwt_handler() from shared.bootstrap for all token operations.
    """

    def __init__(
        self,
        *,
        secret_key: str,
        algorithm: str = "HS256",
        expire_minutes: int = 480,
    ) -> None:
        """Initialise the handler with signing credentials and token lifetime.

        Args:
            secret_key: Secret used both to sign and to verify tokens.
            algorithm: Algorithm name handed to python-jose.
            expire_minutes: Lifetime of every issued token, in minutes.
        """
        self._secret = secret_key
        self._algorithm = algorithm
        self._expire_minutes = expire_minutes

    def create_access_token(
        self,
        *,
        user_id: str,
        username: str,
        role: str,
    ) -> str:
        """Return a signed JWT containing user identity and role claims.

        The token carries ``sub``, ``username``, ``role``, ``iat`` and an
        ``exp`` claim set ``expire_minutes`` after the time of issue.
        """
        now = datetime.now(UTC)
        payload: dict[str, Any] = {
            "sub": user_id,
            "username": username,
            "role": role,
            "iat": now,
            "exp": now + timedelta(minutes=self._expire_minutes),
        }
        return jwt.encode(payload, self._secret, algorithm=self._algorithm)

    def decode_token(self, token: str) -> UserClaims:
        """Decode and validate a JWT, returning UserClaims.

        Raises:
            ValueError: ``"Token expired"`` when the token is past its
                expiry, ``"Invalid token: ..."`` for any other decoding or
                signature failure, and ``"Token missing subject claim"``
                when ``sub`` is absent or empty. Callers therefore never
                need to know about python-jose exception types.
        """
        # Imported locally so this method stays self-contained; the expiry
        # case is recognised by exception type rather than by searching the
        # error message, which would break if the library rewords it.
        from jose.exceptions import ExpiredSignatureError

        try:
            payload = jwt.decode(token, self._secret, algorithms=[self._algorithm])
        except ExpiredSignatureError as exc:
            # Must precede JWTError: ExpiredSignatureError is a subclass of it.
            raise ValueError("Token expired") from exc
        except JWTError as exc:
            raise ValueError(f"Invalid token: {exc}") from exc

        user_id = payload.get("sub")
        if not user_id:
            raise ValueError("Token missing subject claim")

        username = payload.get("username", "")
        role_str = payload.get("role", UserRole.READONLY.value)

        try:
            role = UserRole(role_str)
        except ValueError:
            # An unrecognised role is downgraded to the least-privileged
            # role instead of rejecting the token outright.
            logger.warning("Unknown role in token: {}, defaulting to readonly", role_str)
            role = UserRole.READONLY

        return UserClaims(user_id=user_id, username=username, role=role)
|
||||
113
backend/app/infrastructure/auth/user_store.py
Normal file
113
backend/app/infrastructure/auth/user_store.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""PostgreSQL-backed user store for authentication.
|
||||
|
||||
Manages a `users` table with hashed passwords and roles.
|
||||
Provides lookup by username for the login flow.
|
||||
Table DDL is auto-applied on first connection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from loguru import logger
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
# Password hashing context: bcrypt only. No explicit cost is configured here,
# so the library's default work factor applies.
_PWD_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Idempotent DDL (IF NOT EXISTS) that guarantees the users table is present.
_CREATE_TABLE_SQL = """
CREATE TABLE IF NOT EXISTS users (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    username VARCHAR(100) UNIQUE NOT NULL,
    hashed_pw TEXT NOT NULL,
    role VARCHAR(50) NOT NULL DEFAULT 'readonly',
    is_active BOOLEAN NOT NULL DEFAULT TRUE,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
|
||||
|
||||
|
||||
@dataclass
class UserRecord:
    """In-memory copy of one row of the ``users`` table."""

    # Primary key of the row, carried as a string.
    id: str
    # Login name used for lookups.
    username: str
    # Stored password hash (never the plain-text password).
    hashed_pw: str
    # Role name as stored in the table.
    role: str
    # Whether the account is active.
    is_active: bool
|
||||
|
||||
|
||||
class PostgresUserStore:
    """Read and verify users stored in the PostgreSQL users table.

    The connection is opened when the store is constructed and is then
    shared for the lifetime of the instance (a singleton wired by bootstrap).
    """

    def __init__(self) -> None:
        """Connect to PostgreSQL and ensure the users table exists."""
        self._conn = psycopg2.connect(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            dbname=settings.postgres_db,
            # Rows come back as dict-like objects keyed by column name.
            cursor_factory=psycopg2.extras.RealDictCursor,
        )
        # Every statement commits on its own; no explicit transactions are used.
        self._conn.autocommit = True
        self._ensure_table()

    def _ensure_table(self) -> None:
        """Create the users table if it does not already exist.

        Enabling pgcrypto is best-effort: a failure is logged and table
        creation is still attempted.
        """
        with self._conn.cursor() as cur:
            # Enable pgcrypto so gen_random_uuid() is available for UUID primary keys.
            try:
                cur.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
            except psycopg2.Error as exc:
                # Previously swallowed silently; keep going but leave a trace
                # so a later gen_random_uuid() failure can be diagnosed.
                logger.warning("Could not enable pgcrypto extension: {}", exc)
                self._conn.rollback()
            cur.execute(_CREATE_TABLE_SQL)

    def get_by_username(self, username: str) -> Optional[UserRecord]:
        """Return a UserRecord for the given username, or None if not found."""
        with self._conn.cursor() as cur:
            # Parameterised query: the username is never interpolated into SQL.
            cur.execute(
                "SELECT id, username, hashed_pw, role, is_active "
                "FROM users WHERE username = %s",
                (username,),
            )
            row = cur.fetchone()

        if row is None:
            return None
        return UserRecord(
            id=str(row["id"]),
            username=row["username"],
            hashed_pw=row["hashed_pw"],
            role=row["role"],
            is_active=row["is_active"],
        )

    def verify_password(self, plain: str, hashed: str) -> bool:
        """Return True if `plain` matches the stored bcrypt hash."""
        return _PWD_CTX.verify(plain, hashed)

    def authenticate(self, username: str, password: str) -> Optional[UserRecord]:
        """Return the UserRecord if credentials are valid, else None.

        None is returned when the user does not exist, is inactive, or the
        password does not match.
        """
        user = self.get_by_username(username)
        if user is None or not user.is_active:
            # Spend roughly the time of a real hash check so the response
            # time does not reveal whether the account exists or is active.
            _PWD_CTX.dummy_verify()
            return None
        if not self.verify_password(password, user.hashed_pw):
            return None
        return user

    @staticmethod
    def hash_password(plain: str) -> str:
        """Hash a plain-text password with bcrypt."""
        return _PWD_CTX.hash(plain)
|
||||
169
backend/app/infrastructure/session/redis_conversation_store.py
Normal file
169
backend/app/infrastructure/session/redis_conversation_store.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Redis-backed conversation store for persistent chat sessions.
|
||||
|
||||
Sessions are stored as JSON strings under the key `session:{session_id}`.
|
||||
The Redis TTL is refreshed on every write so active sessions stay alive.
|
||||
On expiry, `get_session` returns None — callers should create a new session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.domain.conversation import ConversationMessage, ConversationSession, ConversationStore
|
||||
|
||||
|
||||
class RedisConversationStore(ConversationStore):
    """Store conversation sessions in Redis with automatic TTL expiry.

    Each session is serialised as a JSON object at key ``session:{session_id}``.
    The TTL is reset on every write so sessions stay alive as long as they are active.
    """

    # Prefix for all session keys to avoid collisions with other Redis consumers.
    _PREFIX = "session:"

    def __init__(self, *, redis_client: Any, timeout_seconds: int = 1800) -> None:
        """Initialise the store with an existing Redis client and a TTL in seconds."""
        self._redis = redis_client
        self._ttl = timeout_seconds

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _key(self, session_id: str) -> str:
        """Build the Redis key for a session."""
        return f"{self._PREFIX}{session_id}"

    def _serialise(self, session: ConversationSession) -> str:
        """Serialise a ConversationSession to a JSON string."""
        return json.dumps(
            {
                "session_id": session.session_id,
                "created_at": session.created_at,
                "updated_at": session.updated_at,
                "metadata": session.metadata,
                "messages": [
                    {
                        "role": msg.role,
                        "content": msg.content,
                        "timestamp": msg.timestamp,
                        "sources": msg.sources,
                    }
                    for msg in session.messages
                ],
            },
            # Keep non-ASCII text (e.g. Chinese) readable instead of \u-escaped.
            ensure_ascii=False,
        )

    def _deserialise(self, raw: bytes | str) -> ConversationSession:
        """Deserialise a JSON string back into a ConversationSession.

        Missing optional fields fall back to defaults (``0`` for timestamps,
        empty containers for metadata, messages and sources).
        """
        data = json.loads(raw)
        messages = [
            ConversationMessage(
                role=m["role"],
                content=m["content"],
                timestamp=m["timestamp"],
                sources=m.get("sources", []),
            )
            for m in data.get("messages", [])
        ]
        session = ConversationSession(
            session_id=data["session_id"],
            created_at=data.get("created_at", 0),
            updated_at=data.get("updated_at", 0),
            metadata=data.get("metadata", {}),
        )
        session.messages = messages
        return session

    def _save(self, session: ConversationSession) -> None:
        """Persist a session to Redis and refresh its TTL."""
        self._redis.setex(self._key(session.session_id), self._ttl, self._serialise(session))

    # ------------------------------------------------------------------
    # ConversationStore protocol
    # ------------------------------------------------------------------

    def create_session(self, metadata: dict | None = None) -> ConversationSession:
        """Create a new empty session and persist it immediately."""
        now = int(time.time())
        session = ConversationSession(
            # NOTE(review): only the first 8 hex characters of the UUID are
            # kept, so IDs carry 32 bits; a collision would overwrite an
            # existing session. Confirm this length is intentional.
            session_id=str(uuid.uuid4())[:8],
            created_at=now,
            updated_at=now,
            metadata=metadata or {},
        )
        self._save(session)
        return session

    def get_session(self, session_id: str) -> ConversationSession | None:
        """Return a session by ID, or None if it does not exist or has expired.

        A stored value that cannot be deserialised is logged and also
        reported as None.
        """
        raw = self._redis.get(self._key(session_id))
        if raw is None:
            return None
        try:
            return self._deserialise(raw)
        except Exception as exc:
            # Broad on purpose: the model constructors may raise their own
            # validation errors. The cause is logged so corruption is traceable.
            logger.warning("Failed to deserialise session: {} ({})", session_id, exc)
            return None

    def save_message(
        self,
        session_id: str,
        *,
        role: str,
        content: str,
        sources: list[dict] | None = None,
    ) -> ConversationSession | None:
        """Append a message to a session and refresh its TTL.

        Returns the updated session, or None when the session does not
        exist (or has expired).
        """
        session = self.get_session(session_id)
        if session is None:
            return None

        # One clock reading so the message timestamp and the session's
        # updated_at can never disagree.
        now = int(time.time())
        session.messages.append(
            ConversationMessage(
                role=role,
                content=content,
                timestamp=now,
                sources=sources or [],
            )
        )
        session.updated_at = now
        self._save(session)
        return session

    def delete_session(self, session_id: str) -> bool:
        """Delete a session. Returns True if it existed, False otherwise."""
        deleted = self._redis.delete(self._key(session_id))
        return bool(deleted)

    def list_sessions(self) -> list[dict]:
        """Return summary dicts for all live sessions visible in this Redis DB.

        Keys are enumerated incrementally with SCAN (``scan_iter``) rather
        than KEYS, so a large keyspace does not block the Redis server.
        Entries that expire mid-iteration or cannot be parsed are skipped.
        """
        pattern = f"{self._PREFIX}*"
        summaries: list[dict] = []
        seen: set = set()
        for key in self._redis.scan_iter(match=pattern):
            # SCAN may yield the same key more than once.
            if key in seen:
                continue
            seen.add(key)

            raw = self._redis.get(key)
            if raw is None:
                # Expired between the scan and the read.
                continue
            try:
                data = json.loads(raw)
                summaries.append(
                    {
                        "session_id": data["session_id"],
                        "message_count": len(data.get("messages", [])),
                        "created_at": data.get("created_at", 0),
                        "updated_at": data.get("updated_at", 0),
                    }
                )
            except (ValueError, KeyError, TypeError, AttributeError) as exc:
                logger.warning("Skipping unreadable session entry {}: {}", key, exc)
                continue
        return summaries
|
||||
5
backend/app/infrastructure/tasks/__init__.py
Normal file
5
backend/app/infrastructure/tasks/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Celery task definitions for background processing.
|
||||
|
||||
This package exposes the shared Celery application instance and all
|
||||
registered task functions used by API routes to enqueue work.
|
||||
"""
|
||||
45
backend/app/infrastructure/tasks/celery_app.py
Normal file
45
backend/app/infrastructure/tasks/celery_app.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Shared Celery application instance for background task processing.
|
||||
|
||||
All workers and enqueueing call sites import `celery_app` from this module
|
||||
so the broker/backend configuration stays in one place.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
def _redis_url() -> str:
    """Return a Redis connection URL from application settings.

    When a password is configured it is percent-encoded, so reserved URL
    characters in it (``@``, ``:``, ``/``, ``#``, ...) cannot corrupt the
    URL. Passwords made only of unreserved characters yield the same URL
    as before.
    """
    from urllib.parse import quote

    location = f"{settings.redis_host}:{settings.redis_port}/{settings.redis_db}"
    if settings.redis_password:
        # NOTE(review): assumes the setting holds the raw password, not an
        # already percent-encoded one — confirm against the settings module.
        encoded_password = quote(str(settings.redis_password), safe="")
        return f"redis://:{encoded_password}@{location}"
    return f"redis://{location}"
|
||||
|
||||
|
||||
# The broker and the result backend are both built from the same settings.
_BROKER = _redis_url()
_BACKEND = _redis_url()

# Shared application object; the task module is listed in `include`.
celery_app = Celery(
    "compliance_hub",
    include=["app.infrastructure.tasks.document_tasks"],
    backend=_BACKEND,
    broker=_BROKER,
)

celery_app.conf.update(
    # JSON only, for both task messages and results.
    accept_content=["json"],
    task_serializer="json",
    result_serializer="json",
    # All timestamps are handled in UTC.
    enable_utc=True,
    timezone="UTC",
    # Acknowledge task only after successful execution to avoid data loss.
    task_acks_late=True,
    task_reject_on_worker_lost=True,
    # Keep results for 1 hour for status polling.
    result_expires=3600,
)
|
||||
73
backend/app/infrastructure/tasks/document_tasks.py
Normal file
73
backend/app/infrastructure/tasks/document_tasks.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Celery tasks for document processing.
|
||||
|
||||
Each task is a thin wrapper that retrieves the already-stored document
|
||||
binary and delegates to DocumentCommandService._process_document.
|
||||
The task does not accept raw file bytes — it reads them from the binary
|
||||
store using the doc_id, so the Celery message payload stays small.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.infrastructure.tasks.celery_app import celery_app
|
||||
|
||||
|
||||
@celery_app.task(
    name="app.infrastructure.tasks.document_tasks.process_document_task",
    bind=True,
    max_retries=3,
    default_retry_delay=30,
    acks_late=True,
)
def process_document_task(
    self,
    doc_id: str,
    file_name: str,
    doc_name: str,
    regulation_type: str,
    version: str,
    generate_summary: bool,
    run_id: str | None = None,
) -> dict:
    """Process a previously stored document in the background.

    The file content is not part of the task message: it is read from the
    binary store via the document record looked up by ``doc_id``, and then
    handed to ``DocumentCommandService._process_document``.

    Returns a dict with ``doc_id``, ``status`` and ``num_chunks`` taken from
    the processing result. Any exception is logged and turned into a retry
    (at most 3 retries, 30 seconds apart, per the task options).
    """
    # Imported here rather than at module level so every worker process
    # builds its own bootstrap singletons.
    from app.shared.bootstrap import get_document_command_service, get_document_query_service

    logger.info("process_document_task started: doc_id={}", doc_id)
    try:
        command_service = get_document_command_service()
        record = get_document_query_service().get(doc_id)
        if not record:
            raise ValueError(f"Document record not found: {doc_id}")

        # Fetch the stored bytes by object name instead of shipping them
        # inside the task message.
        file_bytes = command_service.binary_store.read(record.object_name)

        outcome = command_service._process_document(
            doc_id=doc_id,
            file_name=file_name,
            final_doc_name=doc_name,
            content=file_bytes,
            regulation_type=regulation_type,
            version=version,
            generate_summary=generate_summary,
            run_id=run_id,
        )
        logger.info(
            "process_document_task completed: doc_id={} status={} chunks={}",
            doc_id,
            outcome.status,
            outcome.num_chunks,
        )
        return {
            "doc_id": outcome.doc_id,
            "status": outcome.status,
            "num_chunks": outcome.num_chunks,
        }
    except Exception as exc:
        logger.exception("process_document_task failed: doc_id={}", doc_id)
        # Every failure is retried; permanent errors (bad file, parse
        # failure) simply use up the retries and then surface as failed.
        raise self.retry(exc=exc)
|
||||
Reference in New Issue
Block a user