Files
AIRegulation-DocAnalysis/backend/app/services/agent/session_manager.py

192 lines
6.5 KiB
Python
Raw Permalink Normal View History

"""Provide service-layer logic for session manager."""
2026-05-14 15:07:34 +08:00
import time
import uuid
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class ChatMessage:
    """One entry in a chat session's transcript.

    ``role``, ``content`` and ``timestamp`` are required constructor
    arguments (in that order). ``sources`` and ``metadata`` are optional and
    default to a fresh, per-instance empty list / dict.
    """

    # Who produced the message: "user", "assistant" or "system".
    role: str
    # The message text.
    content: str
    # Creation time as Unix time in whole seconds.
    timestamp: int
    # Source / citation records attached to the message; empty unless supplied.
    sources: List[Dict] = field(default_factory=list)
    # Free-form extra data for the message; empty unless supplied.
    metadata: Dict = field(default_factory=dict)
@dataclass
class ChatSession:
    """Message history and timestamps for a single chat conversation.

    ``created_at`` and ``updated_at`` are integer Unix seconds. ``updated_at``
    is refreshed whenever the message list changes (append or clear).
    """

    # Identifier of this session, assigned by whoever creates it.
    session_id: str
    # Messages in insertion (chronological) order.
    messages: List[ChatMessage] = field(default_factory=list)
    created_at: int = field(default_factory=lambda: int(time.time()))
    updated_at: int = field(default_factory=lambda: int(time.time()))
    # Free-form session metadata; empty unless supplied.
    metadata: Dict = field(default_factory=dict)

    def _append(self, message: ChatMessage) -> ChatMessage:
        """Store *message* and mark the session as updated at its timestamp."""
        self.messages.append(message)
        # Reuse the message's timestamp so the two values always agree; two
        # separate time.time() calls could straddle a second boundary.
        self.updated_at = message.timestamp
        return message

    def add_user_message(self, content: str) -> ChatMessage:
        """Append a user message to the session.

        Args:
            content: The user's message text.

        Returns:
            The newly created message (also appended to ``messages``).
        """
        return self._append(
            ChatMessage(role="user", content=content, timestamp=int(time.time()))
        )

    def add_assistant_message(
        self,
        content: str,
        sources: Optional[List[Dict]] = None
    ) -> ChatMessage:
        """Append an assistant message to the session.

        Args:
            content: The assistant's reply text.
            sources: Optional source records to attach; ``None`` or an empty
                list both result in an empty ``sources`` list.

        Returns:
            The newly created message (also appended to ``messages``).
        """
        return self._append(
            ChatMessage(
                role="assistant",
                content=content,
                timestamp=int(time.time()),
                sources=sources or [],
            )
        )

    def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
        """Return the most recent messages as role/content dicts.

        Args:
            max_turns: Number of turns to include. Each turn is counted as two
                messages, so at most ``2 * max_turns`` messages are returned.
                Values <= 0 return an empty list.

        Returns:
            ``{"role": ..., "content": ...}`` dicts in chronological order.
            Message sources and metadata are not included.
        """
        if max_turns <= 0:
            # Without this guard, messages[-0:] would return the whole
            # transcript, and a negative value would slice from the front.
            return []
        recent = self.messages[-(max_turns * 2):]
        return [{"role": msg.role, "content": msg.content} for msg in recent]

    def clear_history(self):
        """Remove all messages and refresh ``updated_at``."""
        self.messages = []
        self.updated_at = int(time.time())
        logger.info(f"会话历史已清空: {self.session_id}")

    @property
    def message_count(self) -> int:
        """Number of messages currently stored in the session."""
        return len(self.messages)

    @property
    def is_empty(self) -> bool:
        """Whether the session has no messages."""
        return not self.messages
class SessionManager:
    """In-memory registry of ChatSession objects with idle expiry and a size cap.

    Sessions are kept in a plain dict owned by this instance. A session whose
    ``updated_at`` is older than the configured timeout counts as expired and
    is removed when it is encountered. When the registry is full, expired
    sessions are purged first, and then the least recently updated sessions
    are evicted.
    """

    def __init__(
        self,
        max_sessions: int = 100,
        session_timeout_minutes: int = 30
    ):
        """Initialize an empty session registry.

        Args:
            max_sessions: Maximum number of stored sessions; creating a session
                when this many exist evicts others first.
            session_timeout_minutes: Idle time after which a session expires.
        """
        self.max_sessions = max_sessions
        # Stored in seconds because session timestamps are integer Unix seconds.
        self.session_timeout = session_timeout_minutes * 60
        # session_id -> ChatSession; lives only in this object's memory.
        self._sessions: Dict[str, ChatSession] = {}
        logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")

    def _new_session_id(self) -> str:
        """Return an 8-character session id that is not currently in use."""
        # 8 hex characters carry only 32 random bits, so collisions are
        # possible. Without this check a collision would silently replace an
        # existing session in _sessions.
        while True:
            candidate = str(uuid.uuid4())[:8]
            if candidate not in self._sessions:
                return candidate

    def create_session(self, metadata: Optional[Dict] = None) -> ChatSession:
        """Create, register and return a new session.

        If the registry is full, expired sessions are purged first. If it is
        still full, the least recently updated sessions are evicted until
        there is room.

        Args:
            metadata: Optional metadata stored on the new session.

        Returns:
            The newly created session.
        """
        if len(self._sessions) >= self.max_sessions:
            # Reclaim expired sessions before resorting to eviction.
            self._cleanup_expired_sessions()
        # Evict by updated_at, matching the idle-based expiry rule, so an old
        # but active session is not dropped ahead of an idle one. The
        # emptiness check keeps a non-positive max_sessions from calling min()
        # on an empty dict.
        while self._sessions and len(self._sessions) >= self.max_sessions:
            oldest_id = min(
                self._sessions,
                key=lambda sid: self._sessions[sid].updated_at
            )
            self.delete_session(oldest_id)
            logger.warning(f"删除最老会话以腾出空间: {oldest_id}")

        session_id = self._new_session_id()
        session = ChatSession(
            session_id=session_id,
            metadata=metadata or {}
        )
        self._sessions[session_id] = session
        logger.info(f"创建新会话: {session_id}")
        return session

    def get_session(self, session_id: str) -> Optional[ChatSession]:
        """Return the session for *session_id*, or ``None``.

        An expired session is deleted and reported as ``None``. Looking up a
        session does not refresh its ``updated_at``.
        """
        session = self._sessions.get(session_id)
        if session is not None and self._is_session_expired(session):
            self.delete_session(session_id)
            logger.info(f"会话已过期,已删除: {session_id}")
            return None
        return session

    def delete_session(self, session_id: str) -> bool:
        """Remove *session_id*; return ``True`` if it existed."""
        if self._sessions.pop(session_id, None) is None:
            return False
        logger.info(f"删除会话: {session_id}")
        return True

    def list_sessions(self) -> List[Dict]:
        """Return a summary dict for every live (non-expired) session."""
        # Purge first so the listing matches what get_session() would return.
        self._cleanup_expired_sessions()
        return [
            {
                "session_id": s.session_id,
                "message_count": s.message_count,
                "created_at": s.created_at,
                "updated_at": s.updated_at
            }
            for s in self._sessions.values()
        ]

    def _is_session_expired(self, session: ChatSession) -> bool:
        """Whether *session* has been idle longer than the configured timeout."""
        current_time = int(time.time())
        return (current_time - session.updated_at) > self.session_timeout

    def _cleanup_expired_sessions(self) -> int:
        """Delete all expired sessions and return how many were removed."""
        # Collect the ids first; deleting while iterating the dict would fail.
        expired_ids = [
            sid for sid, session in self._sessions.items()
            if self._is_session_expired(session)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        if expired_ids:
            logger.info(f"清理过期会话: {len(expired_ids)}")
        return len(expired_ids)

    def get_session_count(self) -> int:
        """Return the number of live (non-expired) sessions."""
        # Purge first so the count agrees with list_sessions().
        self._cleanup_expired_sessions()
        return len(self._sessions)

    def clear_all_sessions(self):
        """Remove every session from the registry."""
        self._sessions.clear()
        logger.info("所有会话已清空")