2026-05-14 15:07:34 +08:00
|
|
|
|
"""多轮对话会话管理"""
|
|
|
|
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class ChatMessage:
|
|
|
|
|
|
"""对话消息"""
|
|
|
|
|
|
role: str # "user" / "assistant" / "system"
|
|
|
|
|
|
content: str
|
|
|
|
|
|
timestamp: int
|
|
|
|
|
|
sources: List[Dict] = field(default_factory=list)
|
|
|
|
|
|
metadata: Dict = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
|
class ChatSession:
|
|
|
|
|
|
"""对话会话"""
|
|
|
|
|
|
session_id: str
|
|
|
|
|
|
messages: List[ChatMessage] = field(default_factory=list)
|
|
|
|
|
|
created_at: int = field(default_factory=lambda: int(time.time()))
|
|
|
|
|
|
updated_at: int = field(default_factory=lambda: int(time.time()))
|
|
|
|
|
|
metadata: Dict = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
def add_user_message(self, content: str) -> ChatMessage:
|
|
|
|
|
|
"""添加用户消息"""
|
|
|
|
|
|
message = ChatMessage(
|
|
|
|
|
|
role="user",
|
|
|
|
|
|
content=content,
|
|
|
|
|
|
timestamp=int(time.time())
|
|
|
|
|
|
)
|
|
|
|
|
|
self.messages.append(message)
|
|
|
|
|
|
self.updated_at = int(time.time())
|
|
|
|
|
|
return message
|
|
|
|
|
|
|
|
|
|
|
|
def add_assistant_message(
|
|
|
|
|
|
self,
|
|
|
|
|
|
content: str,
|
|
|
|
|
|
sources: List[Dict] = None
|
|
|
|
|
|
) -> ChatMessage:
|
|
|
|
|
|
"""添加助手消息"""
|
|
|
|
|
|
message = ChatMessage(
|
|
|
|
|
|
role="assistant",
|
|
|
|
|
|
content=content,
|
|
|
|
|
|
timestamp=int(time.time()),
|
|
|
|
|
|
sources=sources or []
|
|
|
|
|
|
)
|
|
|
|
|
|
self.messages.append(message)
|
|
|
|
|
|
self.updated_at = int(time.time())
|
|
|
|
|
|
return message
|
|
|
|
|
|
|
|
|
|
|
|
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
|
|
|
|
|
|
"""获取历史对话(用于LLM上下文)"""
|
|
|
|
|
|
history = []
|
|
|
|
|
|
# 获取最近N轮对话(每轮包含user + assistant)
|
|
|
|
|
|
recent_messages = self.messages[-(max_turns * 2):]
|
|
|
|
|
|
|
|
|
|
|
|
for msg in recent_messages:
|
|
|
|
|
|
history.append({
|
|
|
|
|
|
"role": msg.role,
|
|
|
|
|
|
"content": msg.content
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
return history
|
|
|
|
|
|
|
|
|
|
|
|
def clear_history(self):
|
|
|
|
|
|
"""清空对话历史"""
|
|
|
|
|
|
self.messages = []
|
|
|
|
|
|
self.updated_at = int(time.time())
|
|
|
|
|
|
logger.info(f"会话历史已清空: {self.session_id}")
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
def message_count(self) -> int:
|
|
|
|
|
|
"""消息数量"""
|
|
|
|
|
|
return len(self.messages)
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
|
def is_empty(self) -> bool:
|
|
|
|
|
|
"""是否为空会话"""
|
|
|
|
|
|
return len(self.messages) == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SessionManager:
|
|
|
|
|
|
"""
|
|
|
|
|
|
会话管理器
|
|
|
|
|
|
|
|
|
|
|
|
功能:
|
|
|
|
|
|
- 创建/获取/删除会话
|
|
|
|
|
|
- 会话超时清理
|
|
|
|
|
|
- 会话历史记录管理
|
|
|
|
|
|
|
|
|
|
|
|
使用示例:
|
|
|
|
|
|
manager = SessionManager()
|
|
|
|
|
|
|
|
|
|
|
|
# 创建会话
|
|
|
|
|
|
session = manager.create_session()
|
|
|
|
|
|
|
|
|
|
|
|
# 添加消息
|
|
|
|
|
|
session.add_user_message("什么是机动车安全技术检验?")
|
|
|
|
|
|
session.add_assistant_message("根据GB 7258...", sources=[...])
|
|
|
|
|
|
|
|
|
|
|
|
# 获取历史(用于LLM多轮对话)
|
|
|
|
|
|
history = session.get_history(max_turns=3)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
self,
|
|
|
|
|
|
max_sessions: int = 100,
|
|
|
|
|
|
session_timeout_minutes: int = 30
|
|
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
初始化会话管理器
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
max_sessions: 最大会话数量
|
|
|
|
|
|
session_timeout_minutes: 会话超时时间(分钟)
|
|
|
|
|
|
"""
|
|
|
|
|
|
self.max_sessions = max_sessions
|
|
|
|
|
|
self.session_timeout = session_timeout_minutes * 60
|
|
|
|
|
|
|
|
|
|
|
|
# 会话存储(内存)
|
|
|
|
|
|
self._sessions: Dict[str, ChatSession] = {}
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
|
|
|
|
|
|
|
|
|
|
|
|
def create_session(self, metadata: Dict = None) -> ChatSession:
|
|
|
|
|
|
"""
|
|
|
|
|
|
创建新会话
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
metadata: 会话元数据(可选)
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
ChatSession: 新创建的会话
|
|
|
|
|
|
"""
|
|
|
|
|
|
# 检查会话数量限制
|
|
|
|
|
|
if len(self._sessions) >= self.max_sessions:
|
|
|
|
|
|
# 清理过期会话
|
|
|
|
|
|
self._cleanup_expired_sessions()
|
|
|
|
|
|
|
|
|
|
|
|
# 如果仍然超出限制,删除最老的会话
|
|
|
|
|
|
if len(self._sessions) >= self.max_sessions:
|
|
|
|
|
|
oldest_id = min(
|
|
|
|
|
|
self._sessions.keys(),
|
|
|
|
|
|
key=lambda x: self._sessions[x].created_at
|
|
|
|
|
|
)
|
|
|
|
|
|
self.delete_session(oldest_id)
|
|
|
|
|
|
logger.warning(f"删除最老会话以腾出空间: {oldest_id}")
|
|
|
|
|
|
|
|
|
|
|
|
session_id = str(uuid.uuid4())[:8]
|
|
|
|
|
|
session = ChatSession(
|
|
|
|
|
|
session_id=session_id,
|
|
|
|
|
|
metadata=metadata or {}
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
self._sessions[session_id] = session
|
|
|
|
|
|
logger.info(f"创建新会话: {session_id}")
|
|
|
|
|
|
|
|
|
|
|
|
return session
|
|
|
|
|
|
|
|
|
|
|
|
def get_session(self, session_id: str) -> Optional[ChatSession]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
获取会话
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
session_id: 会话ID
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
ChatSession: 会话对象(如不存在返回None)
|
|
|
|
|
|
"""
|
|
|
|
|
|
session = self._sessions.get(session_id)
|
|
|
|
|
|
|
|
|
|
|
|
if session:
|
|
|
|
|
|
# 检查是否过期
|
|
|
|
|
|
if self._is_session_expired(session):
|
|
|
|
|
|
self.delete_session(session_id)
|
|
|
|
|
|
logger.info(f"会话已过期,已删除: {session_id}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
return session
|
|
|
|
|
|
|
|
|
|
|
|
def delete_session(self, session_id: str) -> bool:
|
|
|
|
|
|
"""
|
|
|
|
|
|
删除会话
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
session_id: 会话ID
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
bool: 是否成功删除
|
|
|
|
|
|
"""
|
|
|
|
|
|
if session_id in self._sessions:
|
|
|
|
|
|
del self._sessions[session_id]
|
|
|
|
|
|
logger.info(f"删除会话: {session_id}")
|
|
|
|
|
|
return True
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
def list_sessions(self) -> List[Dict]:
|
|
|
|
|
|
"""
|
|
|
|
|
|
列出所有会话
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
List[Dict]: 会话列表摘要
|
|
|
|
|
|
"""
|
|
|
|
|
|
return [
|
|
|
|
|
|
{
|
|
|
|
|
|
"session_id": s.session_id,
|
|
|
|
|
|
"message_count": s.message_count,
|
|
|
|
|
|
"created_at": s.created_at,
|
|
|
|
|
|
"updated_at": s.updated_at
|
|
|
|
|
|
}
|
|
|
|
|
|
for s in self._sessions.values()
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
def _is_session_expired(self, session: ChatSession) -> bool:
|
|
|
|
|
|
"""检查会话是否过期"""
|
|
|
|
|
|
current_time = int(time.time())
|
|
|
|
|
|
return (current_time - session.updated_at) > self.session_timeout
|
|
|
|
|
|
|
|
|
|
|
|
def _cleanup_expired_sessions(self) -> int:
|
|
|
|
|
|
"""清理过期会话"""
|
|
|
|
|
|
expired_ids = [
|
|
|
|
|
|
sid for sid, session in self._sessions.items()
|
|
|
|
|
|
if self._is_session_expired(session)
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
for sid in expired_ids:
|
|
|
|
|
|
self.delete_session(sid)
|
|
|
|
|
|
|
|
|
|
|
|
if expired_ids:
|
|
|
|
|
|
logger.info(f"清理过期会话: {len(expired_ids)}个")
|
|
|
|
|
|
|
|
|
|
|
|
return len(expired_ids)
|
|
|
|
|
|
|
|
|
|
|
|
def get_session_count(self) -> int:
|
|
|
|
|
|
"""获取当前会话数量"""
|
|
|
|
|
|
return len(self._sessions)
|
|
|
|
|
|
|
|
|
|
|
|
def clear_all_sessions(self):
|
|
|
|
|
|
"""清空所有会话"""
|
|
|
|
|
|
self._sessions.clear()
|
2026-05-14 18:09:15 +08:00
|
|
|
|
logger.info("所有会话已清空")
|