# src/services/agent/session_manager.py """多轮对话会话管理""" import time import uuid from typing import Dict, List, Optional from dataclasses import dataclass, field from loguru import logger @dataclass class ChatMessage: """对话消息""" role: str # "user" / "assistant" / "system" content: str timestamp: int sources: List[Dict] = field(default_factory=list) metadata: Dict = field(default_factory=dict) @dataclass class ChatSession: """对话会话""" session_id: str messages: List[ChatMessage] = field(default_factory=list) created_at: int = field(default_factory=lambda: int(time.time())) updated_at: int = field(default_factory=lambda: int(time.time())) metadata: Dict = field(default_factory=dict) def add_user_message(self, content: str) -> ChatMessage: """添加用户消息""" message = ChatMessage( role="user", content=content, timestamp=int(time.time()) ) self.messages.append(message) self.updated_at = int(time.time()) return message def add_assistant_message( self, content: str, sources: List[Dict] = None ) -> ChatMessage: """添加助手消息""" message = ChatMessage( role="assistant", content=content, timestamp=int(time.time()), sources=sources or [] ) self.messages.append(message) self.updated_at = int(time.time()) return message def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]: """获取历史对话(用于LLM上下文)""" history = [] # 获取最近N轮对话(每轮包含user + assistant) recent_messages = self.messages[-(max_turns * 2):] for msg in recent_messages: history.append({ "role": msg.role, "content": msg.content }) return history def clear_history(self): """清空对话历史""" self.messages = [] self.updated_at = int(time.time()) logger.info(f"会话历史已清空: {self.session_id}") @property def message_count(self) -> int: """消息数量""" return len(self.messages) @property def is_empty(self) -> bool: """是否为空会话""" return len(self.messages) == 0 class SessionManager: """ 会话管理器 功能: - 创建/获取/删除会话 - 会话超时清理 - 会话历史记录管理 使用示例: manager = SessionManager() # 创建会话 session = manager.create_session() # 添加消息 session.add_user_message("什么是机动车安全技术检验?") session.add_assistant_message("根据GB 7258...", sources=[...]) # 获取历史(用于LLM多轮对话) history = session.get_history(max_turns=3) """ def __init__( self, max_sessions: int = 100, session_timeout_minutes: int = 30 ): """ 初始化会话管理器 Args: max_sessions: 最大会话数量 session_timeout_minutes: 会话超时时间(分钟) """ self.max_sessions = max_sessions self.session_timeout = session_timeout_minutes * 60 # 会话存储(内存) self._sessions: Dict[str, ChatSession] = {} logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min") def create_session(self, metadata: Dict = None) -> ChatSession: """ 创建新会话 Args: metadata: 会话元数据(可选) Returns: ChatSession: 新创建的会话 """ # 检查会话数量限制 if len(self._sessions) >= self.max_sessions: # 清理过期会话 self._cleanup_expired_sessions() # 如果仍然超出限制,删除最老的会话 if len(self._sessions) >= self.max_sessions: oldest_id = min( self._sessions.keys(), key=lambda x: self._sessions[x].created_at ) self.delete_session(oldest_id) logger.warning(f"删除最老会话以腾出空间: {oldest_id}") session_id = str(uuid.uuid4())[:8] session = ChatSession( session_id=session_id, metadata=metadata or {} ) self._sessions[session_id] = session logger.info(f"创建新会话: {session_id}") return session def get_session(self, session_id: str) -> Optional[ChatSession]: """ 获取会话 Args: session_id: 会话ID Returns: ChatSession: 会话对象(如不存在返回None) """ session = self._sessions.get(session_id) if session: # 检查是否过期 if self._is_session_expired(session): self.delete_session(session_id) logger.info(f"会话已过期,已删除: {session_id}") return None return session def delete_session(self, session_id: str) -> bool: """ 删除会话 Args: session_id: 会话ID Returns: bool: 是否成功删除 """ if session_id in self._sessions: del self._sessions[session_id] logger.info(f"删除会话: {session_id}") return True return False def list_sessions(self) -> List[Dict]: """ 列出所有会话 Returns: List[Dict]: 会话列表摘要 """ return [ { "session_id": s.session_id, "message_count": s.message_count, "created_at": s.created_at, "updated_at": s.updated_at } for s in self._sessions.values() ] def _is_session_expired(self, session: ChatSession) -> bool: """检查会话是否过期""" current_time = int(time.time()) return (current_time - session.updated_at) > self.session_timeout def _cleanup_expired_sessions(self) -> int: """清理过期会话""" expired_ids = [ sid for sid, session in self._sessions.items() if self._is_session_expired(session) ] for sid in expired_ids: self.delete_session(sid) if expired_ids: logger.info(f"清理过期会话: {len(expired_ids)}个") return len(expired_ids) def get_session_count(self) -> int: """获取当前会话数量""" return len(self._sessions) def clear_all_sessions(self): """清空所有会话""" self._sessions.clear() logger.info("所有会话已清空")