Files
AIRegulation-DocAnalysis/backend/app/services/agent/session_manager.py
2026-05-14 15:07:34 +08:00

247 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# src/services/agent/session_manager.py
"""多轮对话会话管理"""
import time
import uuid
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class ChatMessage:
    """A single message within a conversation.

    Fields:
        role: Author of the message: "user", "assistant" or "system".
        content: The message text.
        timestamp: Creation time as an int (the helpers in this module pass
            ``int(time.time())``, i.e. Unix seconds).
        sources: Reference/source dicts attached to the message; the dict
            schema is defined by callers. Defaults to an empty list.
        metadata: Free-form extra data. Defaults to an empty dict.
    """

    role: str  # one of "user" / "assistant" / "system"
    content: str
    timestamp: int
    # default_factory hands each instance its own container, so messages
    # never share a mutable default.
    sources: List[Dict] = field(default_factory=list)
    metadata: Dict = field(default_factory=dict)


@dataclass
class ChatSession:
    """A multi-turn conversation: its id, message history and metadata.

    ``created_at`` and ``updated_at`` default to ``int(time.time())`` (Unix
    seconds). ``updated_at`` is refreshed whenever messages are added or the
    history is cleared.
    """

    session_id: str
    messages: List[ChatMessage] = field(default_factory=list)
    created_at: int = field(default_factory=lambda: int(time.time()))
    updated_at: int = field(default_factory=lambda: int(time.time()))
    metadata: Dict = field(default_factory=dict)

    def add_user_message(self, content: str) -> ChatMessage:
        """Append a user message to the history.

        Args:
            content: The user's message text.

        Returns:
            The newly created ChatMessage.
        """
        # Read the clock once so the message timestamp and the session's
        # updated_at cannot straddle a second boundary.
        now = int(time.time())
        message = ChatMessage(role="user", content=content, timestamp=now)
        self.messages.append(message)
        self.updated_at = now
        return message

    def add_assistant_message(
        self,
        content: str,
        sources: Optional[List[Dict]] = None,
    ) -> ChatMessage:
        """Append an assistant message to the history.

        Args:
            content: The assistant's reply text.
            sources: Optional reference/source dicts backing the reply;
                ``None`` (or an empty list) stores an empty list.

        Returns:
            The newly created ChatMessage.
        """
        now = int(time.time())
        message = ChatMessage(
            role="assistant",
            content=content,
            timestamp=now,
            sources=sources or [],
        )
        self.messages.append(message)
        self.updated_at = now
        return message

    def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
        """Return recent messages formatted as LLM chat context.

        A turn is counted as two messages (user + assistant), so at most
        ``2 * max_turns`` trailing messages are returned, oldest first. The
        window counts messages, not validated pairs, so with an unbalanced
        history it may start with an assistant message.

        Args:
            max_turns: Number of recent turns to include. Values <= 0 yield
                an empty history.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts.
        """
        # Without this guard max_turns == 0 slices messages[-0:], which is
        # messages[0:], i.e. the *entire* history. Negative values would
        # likewise produce a nonsensical suffix.
        if max_turns <= 0:
            return []
        recent_messages = self.messages[-(max_turns * 2):]
        return [{"role": msg.role, "content": msg.content} for msg in recent_messages]

    def clear_history(self):
        """Drop all messages and refresh ``updated_at``."""
        self.messages = []
        self.updated_at = int(time.time())
        logger.info(f"会话历史已清空: {self.session_id}")

    @property
    def message_count(self) -> int:
        """Number of messages currently in the history."""
        return len(self.messages)

    @property
    def is_empty(self) -> bool:
        """True when the session has no messages."""
        return not self.messages


class SessionManager:
    """
    In-memory conversation session manager.

    Features:
    - create / get / delete sessions
    - idle-timeout cleanup of expired sessions
    - a capacity limit that evicts the least recently active session

    Sessions are kept only in a plain dict in this process's memory.

    Example:
        manager = SessionManager()
        # create a session
        session = manager.create_session()
        # add messages
        session.add_user_message("What is a motor vehicle safety inspection?")
        session.add_assistant_message("According to GB 7258...", sources=[...])
        # fetch history for multi-turn LLM calls
        history = session.get_history(max_turns=3)
    """

    def __init__(
        self,
        max_sessions: int = 100,
        session_timeout_minutes: int = 30
    ):
        """
        Initialize the session manager.

        Args:
            max_sessions: Maximum number of sessions kept in memory.
            session_timeout_minutes: Idle time in minutes, measured from a
                session's ``updated_at``, after which it counts as expired.
        """
        self.max_sessions = max_sessions
        self.session_timeout = session_timeout_minutes * 60  # seconds
        # Session storage (in memory), keyed by session_id.
        self._sessions: Dict[str, ChatSession] = {}
        logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")

    def create_session(self, metadata: Optional[Dict] = None) -> ChatSession:
        """
        Create and register a new session.

        If the store is full, expired sessions are purged first. If it is
        still full, the least recently active sessions are evicted.

        Args:
            metadata: Optional session metadata.

        Returns:
            ChatSession: The newly created session.
        """
        self._ensure_capacity()
        session_id = self._new_session_id()
        session = ChatSession(
            session_id=session_id,
            metadata=metadata or {}
        )
        self._sessions[session_id] = session
        logger.info(f"创建新会话: {session_id}")
        return session

    def _ensure_capacity(self) -> None:
        """Make room for one more session when the store is at its limit."""
        if len(self._sessions) < self.max_sessions:
            return
        # Cheap pass first: drop sessions that have already idled out.
        self._cleanup_expired_sessions()
        # Still full: evict by least recent activity (updated_at), matching
        # the idle-based expiry rule, so an active long-lived session is not
        # dropped ahead of idle ones. The emptiness check keeps min() from
        # raising ValueError when max_sessions <= 0.
        while self._sessions and len(self._sessions) >= self.max_sessions:
            oldest_id = min(
                self._sessions,
                key=lambda sid: self._sessions[sid].updated_at
            )
            self.delete_session(oldest_id)
            logger.warning(f"删除最老会话以腾出空间: {oldest_id}")

    def _new_session_id(self) -> str:
        """Return an 8-character session id not currently in use."""
        # A UUID4 truncated to 8 hex chars carries only 32 random bits, so
        # collisions are possible. Retry instead of silently overwriting a
        # live session.
        # NOTE(review): if session ids are exposed to clients, 32 bits is
        # guessable; consider a full uuid4 or secrets.token_urlsafe().
        while True:
            candidate = str(uuid.uuid4())[:8]
            if candidate not in self._sessions:
                return candidate

    def get_session(self, session_id: str) -> Optional[ChatSession]:
        """
        Look up a session by id.

        An expired session is deleted on access and reported as missing.
        Looking a session up does not reset its idle timer.

        Args:
            session_id: Session id.

        Returns:
            ChatSession: The session, or None if it does not exist or has
            expired.
        """
        session = self._sessions.get(session_id)
        if session is not None and self._is_session_expired(session):
            self.delete_session(session_id)
            logger.info(f"会话已过期,已删除: {session_id}")
            return None
        return session

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a session.

        Args:
            session_id: Session id.

        Returns:
            bool: True if a session was removed, False if none existed.
        """
        if self._sessions.pop(session_id, None) is None:
            return False
        logger.info(f"删除会话: {session_id}")
        return True

    def list_sessions(self) -> List[Dict]:
        """
        Summarize all stored sessions (including any not yet purged as expired).

        Returns:
            List[Dict]: One summary dict per session.
        """
        return [
            {
                "session_id": s.session_id,
                "message_count": s.message_count,
                "created_at": s.created_at,
                "updated_at": s.updated_at
            }
            for s in self._sessions.values()
        ]

    def _is_session_expired(self, session: ChatSession) -> bool:
        """True when the session has been idle longer than session_timeout seconds."""
        current_time = int(time.time())
        return (current_time - session.updated_at) > self.session_timeout

    def _cleanup_expired_sessions(self) -> int:
        """Delete every expired session and return how many were removed."""
        # Collect ids first; deleting while iterating the dict would raise.
        expired_ids = [
            sid for sid, session in self._sessions.items()
            if self._is_session_expired(session)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        if expired_ids:
            logger.info(f"清理过期会话: {len(expired_ids)}")
        return len(expired_ids)

    def get_session_count(self) -> int:
        """Return the number of sessions currently stored."""
        return len(self._sessions)

    def clear_all_sessions(self):
        """Remove every session."""
        self._sessions.clear()
        logger.info("所有会话已清空")