Fix SSE route dependency and align architecture docs
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
"""Backend service package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__: list[str] = []
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"""Agent服务模块"""
|
||||
"""Initialize the app.services.agent package."""
|
||||
|
||||
from .qa_agent import QAAgent, ask_compliance_question
|
||||
from .session_manager import SessionManager, ChatSession
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
"""RAG问答Agent - 合规智能问答核心实现"""
|
||||
"""Provide service-layer logic for qa agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any, Generator
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
|
||||
from app.services.llm.llm_factory import LLMFactory
|
||||
from app.services.rag.retriever import Retriever, RetrievedDocument
|
||||
from app.services.rag.context_builder import ContextBuilder, RAGContext
|
||||
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
|
||||
from app.config.settings import settings
|
||||
from app.shared.bootstrap import get_agent_conversation_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
"""Agent响应结果"""
|
||||
"""Represent the Agent Response type."""
|
||||
answer: str
|
||||
sources: List[Dict] = field(default_factory=list)
|
||||
model: str = ""
|
||||
@@ -27,385 +25,73 @@ class AgentResponse:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the Agent Response instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Agent配置"""
|
||||
llm_provider: str = "deepseek"
|
||||
llm_model: str = "deepseek-v4-flash"
|
||||
top_k: int = 5
|
||||
min_score: float = 0.3
|
||||
max_context_tokens: int = 2000
|
||||
temperature: float = 0.7
|
||||
"""Define configuration for agent config."""
|
||||
llm_provider: str = settings.llm_provider
|
||||
llm_model: str = settings.llm_model
|
||||
top_k: int = settings.rag_top_k
|
||||
min_score: float = 0.0
|
||||
max_context_tokens: int = settings.rag_max_context_tokens
|
||||
temperature: float = settings.llm_temperature
|
||||
prompt_template: str = "compliance_qa"
|
||||
include_metadata: bool = True
|
||||
|
||||
|
||||
class QAAgent:
|
||||
"""
|
||||
合规问答Agent
|
||||
|
||||
核心流程:
|
||||
1. 接收用户问题
|
||||
2. Milvus混合检索相关法规条款
|
||||
3. 构建RAG上下文
|
||||
4. 调用LLM生成回答
|
||||
5. 返回答案和引用来源
|
||||
|
||||
使用示例:
|
||||
agent = QAAgent()
|
||||
response = agent.ask("机动车安全技术检验有哪些要求?")
|
||||
print(response.answer)
|
||||
for source in response.sources:
|
||||
print(f"引用: {source['doc_name']} - {source['clause_number']}")
|
||||
"""
|
||||
|
||||
"""Represent the Q A Agent type."""
|
||||
def __init__(self, config: Optional[AgentConfig] = None):
|
||||
"""
|
||||
初始化问答Agent
|
||||
|
||||
Args:
|
||||
config: Agent配置(可选,使用默认配置)
|
||||
"""
|
||||
self.config = config or AgentConfig(
|
||||
llm_provider=settings.llm_provider,
|
||||
llm_model=settings.llm_model,
|
||||
top_k=settings.rag_top_k,
|
||||
max_context_tokens=settings.rag_max_context_tokens
|
||||
)
|
||||
|
||||
# 初始化组件(延迟加载)
|
||||
self.llm: Optional[BaseLLMClient] = None
|
||||
self.retriever: Optional[Retriever] = None
|
||||
self.context_builder: Optional[ContextBuilder] = None
|
||||
|
||||
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
|
||||
|
||||
def _init_llm(self):
|
||||
"""延迟初始化LLM客户端(优先使用全局缓存)"""
|
||||
if self.llm is None:
|
||||
# 尝试先获取全局缓存的客户端
|
||||
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
|
||||
if cached:
|
||||
self.llm = cached
|
||||
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
|
||||
else:
|
||||
logger.info("创建新的LLM客户端...")
|
||||
self.llm = get_llm_client(
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
|
||||
def _init_retriever(self):
|
||||
"""延迟初始化检索器"""
|
||||
if self.retriever is None:
|
||||
logger.info("初始化检索器...")
|
||||
self.retriever = Retriever(
|
||||
top_k=self.config.top_k,
|
||||
min_score=self.config.min_score
|
||||
)
|
||||
|
||||
def _init_context_builder(self):
|
||||
"""延迟初始化上下文构建器"""
|
||||
if self.context_builder is None:
|
||||
logger.info("初始化上下文构建器...")
|
||||
self.context_builder = ContextBuilder(
|
||||
max_context_tokens=self.config.max_context_tokens,
|
||||
include_metadata=self.config.include_metadata
|
||||
)
|
||||
"""Initialize the Q A Agent instance."""
|
||||
self.config = config or AgentConfig()
|
||||
|
||||
def ask(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None
|
||||
prompt_template: Optional[str] = None,
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
回答用户问题
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
filters: 检索过滤条件(如 "regulation_type=='车辆安全'")
|
||||
prompt_template: Prompt模板名称(可选,覆盖默认配置)
|
||||
|
||||
Returns:
|
||||
AgentResponse: 包含答案和引用来源的响应对象
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(f"收到问题: {query}")
|
||||
|
||||
try:
|
||||
# Step 1: 检索相关法规
|
||||
self._init_retriever()
|
||||
documents = self.retriever.retrieve(query, filters)
|
||||
retrieved_count = len(documents)
|
||||
|
||||
if retrieved_count == 0:
|
||||
return AgentResponse(
|
||||
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
|
||||
retrieved_count=0,
|
||||
error="no_retrieved_documents"
|
||||
)
|
||||
|
||||
# Step 2: 构建RAG上下文
|
||||
self._init_context_builder()
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
# Step 3: 构建LLM输入消息
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
# Step 4: 调用LLM生成回答
|
||||
self._init_llm()
|
||||
llm_response = self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
|
||||
if not llm_response.is_success:
|
||||
return AgentResponse(
|
||||
answer="",
|
||||
retrieved_count=retrieved_count,
|
||||
error=llm_response.error
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Step 5: 返回结果
|
||||
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||
|
||||
return AgentResponse(
|
||||
answer=llm_response.content,
|
||||
sources=context.sources,
|
||||
model=llm_response.model,
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=retrieved_count,
|
||||
context_tokens=context.total_tokens,
|
||||
truncated=context.truncated
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
return AgentResponse(
|
||||
answer="",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
def ask_with_context(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[RetrievedDocument],
|
||||
prompt_template: Optional[str] = None
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
使用提供的文档回答问题(不执行检索)
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
documents: 已检索的文档列表
|
||||
prompt_template: Prompt模板名称
|
||||
|
||||
Returns:
|
||||
AgentResponse: 响应结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self._init_context_builder()
|
||||
self._init_llm()
|
||||
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
llm_response = self.llm.chat(messages)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
return AgentResponse(
|
||||
answer=llm_response.content,
|
||||
sources=context.sources,
|
||||
model=llm_response.model,
|
||||
latency_ms=latency_ms,
|
||||
retrieved_count=len(documents),
|
||||
context_tokens=context.total_tokens,
|
||||
truncated=context.truncated
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问答失败: {e}")
|
||||
return AgentResponse(answer="", error=str(e))
|
||||
|
||||
def _build_messages(
|
||||
self,
|
||||
template: PromptTemplate,
|
||||
context: RAGContext
|
||||
) -> List[Dict[str, str]]:
|
||||
"""构建LLM输入消息"""
|
||||
user_content = template.user_template.format(
|
||||
context=context.context_text,
|
||||
query=context.user_query
|
||||
"""Handle ask for the Q A Agent instance."""
|
||||
_, result = get_agent_conversation_service().ask(
|
||||
query=query,
|
||||
filters=filters,
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
top_k=self.config.top_k,
|
||||
prompt_template=prompt_template or self.config.prompt_template,
|
||||
)
|
||||
return AgentResponse(
|
||||
answer=result.answer,
|
||||
sources=[source.__dict__ for source in result.sources],
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
retrieved_count=result.retrieved_count,
|
||||
context_tokens=result.context_tokens,
|
||||
truncated=result.truncated,
|
||||
error=result.error,
|
||||
)
|
||||
|
||||
return [
|
||||
{"role": "system", "content": template.system_prompt},
|
||||
{"role": "user", "content": user_content}
|
||||
]
|
||||
|
||||
def ask_stream(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
prompt_template: Optional[str] = None
|
||||
) -> Generator[Dict[str, Any], None, None]:
|
||||
"""
|
||||
流式回答用户问题(SSE模式)
|
||||
|
||||
返回事件类型:
|
||||
- {"event": "status", "data": "正在检索..."} - 状态更新
|
||||
- {"event": "sources", "data": [...]} - 引用来源
|
||||
- {"event": "content", "data": "文本片段"} - 回答内容
|
||||
- {"event": "done", "data": {"latency_ms": ..., "model": ...}} - 完成
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
filters: 检索过滤条件
|
||||
prompt_template: Prompt模板名称
|
||||
|
||||
Yields:
|
||||
Dict: SSE事件数据
|
||||
"""
|
||||
start_time = time.time()
|
||||
logger.info(f"收到流式问题: {query}")
|
||||
|
||||
try:
|
||||
# Step 1: 检索相关法规
|
||||
yield {"event": "status", "data": "正在检索相关法规..."}
|
||||
self._init_retriever()
|
||||
documents = self.retriever.retrieve(query, filters)
|
||||
retrieved_count = len(documents)
|
||||
|
||||
if retrieved_count == 0:
|
||||
yield {"event": "status", "data": "未找到相关法规"}
|
||||
yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
|
||||
yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
|
||||
return
|
||||
|
||||
# Step 2: 发送检索结果
|
||||
yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
|
||||
sources = [
|
||||
{
|
||||
"doc_name": doc.doc_name,
|
||||
"doc_id": doc.doc_id,
|
||||
"clause_number": doc.clause_number,
|
||||
"score": doc.score
|
||||
}
|
||||
for doc in documents[:5] # 只返回前5条引用
|
||||
]
|
||||
yield {"event": "sources", "data": sources}
|
||||
|
||||
# Step 3: 构建RAG上下文
|
||||
self._init_context_builder()
|
||||
template_name = prompt_template or self.config.prompt_template
|
||||
template = get_prompt_template(template_name)
|
||||
context = self.context_builder.build(
|
||||
query=query,
|
||||
documents=documents,
|
||||
system_prompt=template.system_prompt
|
||||
)
|
||||
|
||||
# Step 4: 构建LLM输入消息
|
||||
messages = self._build_messages(template, context)
|
||||
|
||||
# Step 5: 流式调用LLM生成回答
|
||||
self._init_llm()
|
||||
full_answer = ""
|
||||
|
||||
# 检查LLM是否支持流式输出
|
||||
if hasattr(self.llm, 'stream_chat'):
|
||||
yield {"event": "status", "data": "思考中..."}
|
||||
for chunk in self.llm.stream_chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
):
|
||||
full_answer += chunk
|
||||
yield {"event": "content", "data": chunk}
|
||||
else:
|
||||
# 如果不支持流式,回退到普通调用
|
||||
yield {"event": "status", "data": "生成回答中..."}
|
||||
llm_response = self.llm.chat(
|
||||
messages=messages,
|
||||
temperature=self.config.temperature
|
||||
)
|
||||
if llm_response.is_success:
|
||||
full_answer = llm_response.content
|
||||
yield {"event": "content", "data": full_answer}
|
||||
|
||||
# Step 6: 发送完成事件
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
|
||||
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"latency_ms": latency_ms,
|
||||
"model": self.config.llm_model,
|
||||
"retrieved_count": retrieved_count,
|
||||
"context_tokens": context.total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式问答失败: {e}")
|
||||
yield {"event": "error", "data": str(e)}
|
||||
def ask_stream(self, query: str, filters: Optional[str] = None) -> Generator[dict, None, None]:
|
||||
"""Handle ask stream for the Q A Agent instance."""
|
||||
_, stream = get_agent_conversation_service().stream_chat(
|
||||
query=query,
|
||||
filters=filters,
|
||||
provider=self.config.llm_provider,
|
||||
model=self.config.llm_model,
|
||||
top_k=self.config.top_k,
|
||||
prompt_template=self.config.prompt_template,
|
||||
)
|
||||
for event in stream:
|
||||
yield event
|
||||
|
||||
def close(self):
|
||||
"""关闭Agent资源(不关闭LLM客户端,因为它全局缓存)"""
|
||||
if self.retriever:
|
||||
self.retriever.close()
|
||||
logger.info("问答Agent已关闭")
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
|
||||
def ask_compliance_question(
|
||||
query: str,
|
||||
provider: str = "deepseek",
|
||||
model: str = "deepseek-v4-flash",
|
||||
top_k: int = 10
|
||||
) -> AgentResponse:
|
||||
"""
|
||||
便捷函数:问答合规问题
|
||||
|
||||
Args:
|
||||
query: 用户问题
|
||||
provider: LLM提供商
|
||||
model: LLM模型
|
||||
top_k: 检索数量
|
||||
|
||||
Returns:
|
||||
AgentResponse: 响应结果
|
||||
"""
|
||||
config = AgentConfig(
|
||||
llm_provider=provider,
|
||||
llm_model=model,
|
||||
top_k=top_k
|
||||
)
|
||||
agent = QAAgent(config)
|
||||
response = agent.ask(query)
|
||||
agent.close()
|
||||
return response
|
||||
def ask_compliance_question(query: str, top_k: int = 5) -> AgentResponse:
|
||||
"""Handle ask compliance question."""
|
||||
return QAAgent(AgentConfig(top_k=top_k)).ask(query)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""多轮对话会话管理"""
|
||||
"""Provide service-layer logic for session manager."""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""对话消息"""
|
||||
"""Represent the Chat Message type."""
|
||||
role: str # "user" / "assistant" / "system"
|
||||
content: str
|
||||
timestamp: int
|
||||
@@ -19,7 +19,7 @@ class ChatMessage:
|
||||
|
||||
@dataclass
|
||||
class ChatSession:
|
||||
"""对话会话"""
|
||||
"""Represent the Chat Session type."""
|
||||
session_id: str
|
||||
messages: List[ChatMessage] = field(default_factory=list)
|
||||
created_at: int = field(default_factory=lambda: int(time.time()))
|
||||
@@ -27,7 +27,7 @@ class ChatSession:
|
||||
metadata: Dict = field(default_factory=dict)
|
||||
|
||||
def add_user_message(self, content: str) -> ChatMessage:
|
||||
"""添加用户消息"""
|
||||
"""Handle add user message for the Chat Session instance."""
|
||||
message = ChatMessage(
|
||||
role="user",
|
||||
content=content,
|
||||
@@ -42,7 +42,7 @@ class ChatSession:
|
||||
content: str,
|
||||
sources: List[Dict] = None
|
||||
) -> ChatMessage:
|
||||
"""添加助手消息"""
|
||||
"""Handle add assistant message for the Chat Session instance."""
|
||||
message = ChatMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
@@ -54,9 +54,9 @@ class ChatSession:
|
||||
return message
|
||||
|
||||
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
|
||||
"""获取历史对话(用于LLM上下文)"""
|
||||
"""Return history for the Chat Session instance."""
|
||||
history = []
|
||||
# 获取最近N轮对话(每轮包含user + assistant)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
recent_messages = self.messages[-(max_turns * 2):]
|
||||
|
||||
for msg in recent_messages:
|
||||
@@ -68,81 +68,47 @@ class ChatSession:
|
||||
return history
|
||||
|
||||
def clear_history(self):
|
||||
"""清空对话历史"""
|
||||
"""Handle clear history for the Chat Session instance."""
|
||||
self.messages = []
|
||||
self.updated_at = int(time.time())
|
||||
logger.info(f"会话历史已清空: {self.session_id}")
|
||||
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""消息数量"""
|
||||
"""Handle message count for the Chat Session instance."""
|
||||
return len(self.messages)
|
||||
|
||||
@property
|
||||
def is_empty(self) -> bool:
|
||||
"""是否为空会话"""
|
||||
"""Return whether empty for the Chat Session instance."""
|
||||
return len(self.messages) == 0
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
会话管理器
|
||||
|
||||
功能:
|
||||
- 创建/获取/删除会话
|
||||
- 会话超时清理
|
||||
- 会话历史记录管理
|
||||
|
||||
使用示例:
|
||||
manager = SessionManager()
|
||||
|
||||
# 创建会话
|
||||
session = manager.create_session()
|
||||
|
||||
# 添加消息
|
||||
session.add_user_message("什么是机动车安全技术检验?")
|
||||
session.add_assistant_message("根据GB 7258...", sources=[...])
|
||||
|
||||
# 获取历史(用于LLM多轮对话)
|
||||
history = session.get_history(max_turns=3)
|
||||
"""
|
||||
"""Represent the Session Manager type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_sessions: int = 100,
|
||||
session_timeout_minutes: int = 30
|
||||
):
|
||||
"""
|
||||
初始化会话管理器
|
||||
|
||||
Args:
|
||||
max_sessions: 最大会话数量
|
||||
session_timeout_minutes: 会话超时时间(分钟)
|
||||
"""
|
||||
"""Initialize the Session Manager instance."""
|
||||
self.max_sessions = max_sessions
|
||||
self.session_timeout = session_timeout_minutes * 60
|
||||
|
||||
# 会话存储(内存)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
self._sessions: Dict[str, ChatSession] = {}
|
||||
|
||||
logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
|
||||
|
||||
def create_session(self, metadata: Dict = None) -> ChatSession:
|
||||
"""
|
||||
创建新会话
|
||||
|
||||
Args:
|
||||
metadata: 会话元数据(可选)
|
||||
|
||||
Returns:
|
||||
ChatSession: 新创建的会话
|
||||
"""
|
||||
# 检查会话数量限制
|
||||
"""Create session for the Session Manager instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(self._sessions) >= self.max_sessions:
|
||||
# 清理过期会话
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
self._cleanup_expired_sessions()
|
||||
|
||||
# 如果仍然超出限制,删除最老的会话
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(self._sessions) >= self.max_sessions:
|
||||
oldest_id = min(
|
||||
self._sessions.keys(),
|
||||
@@ -163,19 +129,11 @@ class SessionManager:
|
||||
return session
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[ChatSession]:
|
||||
"""
|
||||
获取会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
ChatSession: 会话对象(如不存在返回None)
|
||||
"""
|
||||
"""Return session for the Session Manager instance."""
|
||||
session = self._sessions.get(session_id)
|
||||
|
||||
if session:
|
||||
# 检查是否过期
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if self._is_session_expired(session):
|
||||
self.delete_session(session_id)
|
||||
logger.info(f"会话已过期,已删除: {session_id}")
|
||||
@@ -184,15 +142,7 @@ class SessionManager:
|
||||
return session
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""
|
||||
删除会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功删除
|
||||
"""
|
||||
"""Delete session for the Session Manager instance."""
|
||||
if session_id in self._sessions:
|
||||
del self._sessions[session_id]
|
||||
logger.info(f"删除会话: {session_id}")
|
||||
@@ -200,12 +150,7 @@ class SessionManager:
|
||||
return False
|
||||
|
||||
def list_sessions(self) -> List[Dict]:
|
||||
"""
|
||||
列出所有会话
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表摘要
|
||||
"""
|
||||
"""List sessions for the Session Manager instance."""
|
||||
return [
|
||||
{
|
||||
"session_id": s.session_id,
|
||||
@@ -217,12 +162,12 @@ class SessionManager:
|
||||
]
|
||||
|
||||
def _is_session_expired(self, session: ChatSession) -> bool:
|
||||
"""检查会话是否过期"""
|
||||
"""Handle is session expired for this module for the Session Manager instance."""
|
||||
current_time = int(time.time())
|
||||
return (current_time - session.updated_at) > self.session_timeout
|
||||
|
||||
def _cleanup_expired_sessions(self) -> int:
|
||||
"""清理过期会话"""
|
||||
"""Handle cleanup expired sessions for this module for the Session Manager instance."""
|
||||
expired_ids = [
|
||||
sid for sid, session in self._sessions.items()
|
||||
if self._is_session_expired(session)
|
||||
@@ -237,10 +182,10 @@ class SessionManager:
|
||||
return len(expired_ids)
|
||||
|
||||
def get_session_count(self) -> int:
|
||||
"""获取当前会话数量"""
|
||||
"""Return session count for the Session Manager instance."""
|
||||
return len(self._sessions)
|
||||
|
||||
def clear_all_sessions(self):
|
||||
"""清空所有会话"""
|
||||
"""Handle clear all sessions for the Session Manager instance."""
|
||||
self._sessions.clear()
|
||||
logger.info("所有会话已清空")
|
||||
|
||||
@@ -1,24 +1,19 @@
|
||||
"""文档处理主流程 - 解析→摘要→分块→嵌入→入库"""
|
||||
"""Provide service-layer logic for document processor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from app.shared.bootstrap import get_document_command_service, get_retrieval_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
from .parser.pdf_parser import PDFParser
|
||||
from .parser.docx_parser import DocxParser
|
||||
from .parser.mineru_parser import ParserOrchestrator
|
||||
from .embedding.text_chunker import RegulationChunker, TextChunk
|
||||
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
|
||||
from .storage.milvus_client import MilvusClient
|
||||
from .llm.document_summarizer import DocumentSummarizer, DocumentSummary
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingResult:
|
||||
"""文档处理结果"""
|
||||
"""Represent the Processing Result type."""
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
success: bool
|
||||
@@ -30,87 +25,10 @@ class ProcessingResult:
|
||||
|
||||
|
||||
class DocumentProcessor:
|
||||
"""
|
||||
文档处理服务 - 完整处理流程
|
||||
|
||||
流程:
|
||||
1. 文档解析(PDF/DOCX → Markdown)
|
||||
2. 智能分块(章节级+条款级)
|
||||
3. LLM摘要生成(可选)
|
||||
4. 向量嵌入(BGE-M3 Dense+Sparse)
|
||||
5. 存储入库(Milvus向量数据库)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chunk_size: int = None,
|
||||
embedding_model: str = None,
|
||||
use_mineru: bool = True,
|
||||
generate_summary: bool = False, # 默认不生成摘要,节省约60秒
|
||||
llm_provider: str = None,
|
||||
llm_model: str = None
|
||||
):
|
||||
"""
|
||||
初始化文档处理器
|
||||
|
||||
Args:
|
||||
chunk_size: 分块大小
|
||||
embedding_model: 嵌入模型名称
|
||||
use_mineru: 是否优先使用MinerU解析
|
||||
generate_summary: 是否生成文档摘要(默认False,可节省约60秒处理时间)
|
||||
llm_provider: LLM提供商
|
||||
llm_model: LLM模型名称
|
||||
"""
|
||||
self.chunk_size = chunk_size or settings.chunk_size
|
||||
self.embedding_model = embedding_model or settings.embedding_model
|
||||
self.use_mineru = use_mineru
|
||||
"""Represent the Document Processor type."""
|
||||
def __init__(self, *args, generate_summary: bool = False, **kwargs):
|
||||
"""Initialize the Document Processor instance."""
|
||||
self.generate_summary = generate_summary
|
||||
self.llm_provider = llm_provider or settings.llm_provider
|
||||
self.llm_model = llm_model or settings.llm_model
|
||||
|
||||
# 初始化各组件
|
||||
logger.info("初始化文档处理组件...")
|
||||
|
||||
# 解析器
|
||||
self.parser = ParserOrchestrator()
|
||||
|
||||
# 分块器
|
||||
self.chunker = RegulationChunker(chunk_size=self.chunk_size)
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
self.embedder: Optional[BGEM3Embedder] = None
|
||||
|
||||
# Milvus客户端(延迟连接)
|
||||
self.milvus: Optional[MilvusClient] = None
|
||||
|
||||
# 摘要生成器(延迟加载)
|
||||
self.summarizer: Optional[DocumentSummarizer] = None
|
||||
|
||||
logger.success("文档处理器初始化完成")
|
||||
|
||||
def _init_embedder(self):
|
||||
"""延迟初始化嵌入模型"""
|
||||
if self.embedder is None:
|
||||
logger.info("加载嵌入模型...")
|
||||
self.embedder = BGEM3Embedder(model_name=self.embedding_model)
|
||||
|
||||
def _init_milvus(self):
|
||||
"""延迟初始化Milvus连接"""
|
||||
if self.milvus is None:
|
||||
logger.info("连接Milvus...")
|
||||
self.milvus = MilvusClient()
|
||||
self.milvus.connect()
|
||||
self.milvus.create_collection(recreate=False)
|
||||
self.milvus.load_collection()
|
||||
|
||||
def _init_summarizer(self):
|
||||
"""延迟初始化摘要生成器"""
|
||||
if self.summarizer is None:
|
||||
logger.info("初始化摘要生成器...")
|
||||
self.summarizer = DocumentSummarizer(
|
||||
provider=self.llm_provider,
|
||||
model=self.llm_model
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
@@ -118,286 +36,51 @@ class DocumentProcessor:
|
||||
doc_id: Optional[str] = None,
|
||||
doc_name: Optional[str] = None,
|
||||
regulation_type: str = "",
|
||||
version: str = ""
|
||||
version: str = "",
|
||||
) -> ProcessingResult:
|
||||
"""
|
||||
处理单个文档
|
||||
"""Handle process for the Document Processor instance."""
|
||||
path = Path(file_path)
|
||||
content = path.read_bytes()
|
||||
result = get_document_command_service().upload_and_process(
|
||||
doc_id=doc_id,
|
||||
file_name=path.name,
|
||||
content=content,
|
||||
content_type="application/octet-stream",
|
||||
doc_name=doc_name or path.name,
|
||||
regulation_type=regulation_type,
|
||||
version=version,
|
||||
generate_summary=self.generate_summary,
|
||||
)
|
||||
return ProcessingResult(
|
||||
doc_id=result.doc_id,
|
||||
doc_name=result.doc_name,
|
||||
success=result.status != "failed",
|
||||
num_chunks=result.num_chunks,
|
||||
message=result.message,
|
||||
summary=result.summary,
|
||||
summary_latency_ms=result.summary_latency_ms,
|
||||
)
|
||||
|
||||
Args:
|
||||
file_path: 文档文件路径
|
||||
doc_id: 文档ID(可选,默认自动生成)
|
||||
doc_name: 文档名称(可选,默认从文件名获取)
|
||||
regulation_type: 法规类型
|
||||
version: 文档版本
|
||||
|
||||
Returns:
|
||||
ProcessingResult: 处理结果
|
||||
"""
|
||||
# 生成或使用传入的文档ID
|
||||
if doc_id is None:
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# 获取文档名称
|
||||
if doc_name is None:
|
||||
doc_name = os.path.basename(file_path)
|
||||
|
||||
logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})")
|
||||
|
||||
# 初始化结果变量
|
||||
summary = ""
|
||||
summary_latency_ms = 0
|
||||
|
||||
try:
|
||||
# 1. 文档解析
|
||||
logger.info("Step 1: 文档解析")
|
||||
markdown_text = self._parse_document(file_path)
|
||||
|
||||
if not markdown_text:
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message="文档解析失败,内容为空"
|
||||
)
|
||||
|
||||
# 2. LLM摘要生成(可选)
|
||||
if self.generate_summary:
|
||||
logger.info("Step 2: LLM摘要生成")
|
||||
self._init_summarizer()
|
||||
summary_result = self.summarizer.summarize(
|
||||
doc_name,
|
||||
markdown_text,
|
||||
regulation_type
|
||||
)
|
||||
if summary_result.is_success:
|
||||
summary = summary_result.summary
|
||||
summary_latency_ms = summary_result.latency_ms
|
||||
logger.success(f"摘要生成完成: {summary_latency_ms}ms")
|
||||
else:
|
||||
logger.warning(f"摘要生成失败: {summary_result.error}")
|
||||
else:
|
||||
logger.info("Step 2: 跳过摘要生成(未勾选,节省约60秒)")
|
||||
|
||||
# 3. 智能分块
|
||||
logger.info("Step 3: 智能分块")
|
||||
chunks = self._chunk_document(
|
||||
markdown_text,
|
||||
doc_id,
|
||||
doc_name,
|
||||
regulation_type,
|
||||
version
|
||||
)
|
||||
|
||||
if not chunks:
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message="分块失败,无有效内容",
|
||||
markdown_text=markdown_text,
|
||||
summary=summary
|
||||
)
|
||||
|
||||
# 4. 向量嵌入
|
||||
logger.info("Step 4: 向量嵌入")
|
||||
embeddings = self._embed_chunks(chunks)
|
||||
|
||||
if embeddings is None:
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message="向量嵌入失败",
|
||||
markdown_text=markdown_text,
|
||||
summary=summary
|
||||
)
|
||||
|
||||
# 5. 存储入库
|
||||
logger.info("Step 5: 存储入库")
|
||||
inserted_ids = self._insert_to_milvus(chunks, embeddings)
|
||||
|
||||
logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
|
||||
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=True,
|
||||
num_chunks=len(inserted_ids),
|
||||
message="处理成功",
|
||||
markdown_text=markdown_text,
|
||||
summary=summary,
|
||||
summary_latency_ms=summary_latency_ms
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文档处理失败: {e}")
|
||||
return ProcessingResult(
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
success=False,
|
||||
message=f"处理失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _parse_document(self, file_path: str) -> str:
|
||||
"""解析文档"""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
try:
|
||||
if ext == ".pdf":
|
||||
# PDF文档解析(优先MinerU,回退PyMuPDF)
|
||||
markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
|
||||
elif ext in [".docx", ".doc"]:
|
||||
# Word文档解析
|
||||
markdown_text = self.parser.parse_docx(file_path)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {ext}")
|
||||
return ""
|
||||
|
||||
logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
|
||||
return markdown_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文档解析失败: {e}")
|
||||
return ""
|
||||
|
||||
def _chunk_document(
|
||||
self,
|
||||
markdown_text: str,
|
||||
doc_id: str,
|
||||
doc_name: str,
|
||||
regulation_type: str,
|
||||
version: str
|
||||
) -> List[TextChunk]:
|
||||
"""分块文档"""
|
||||
try:
|
||||
chunks = self.chunker.chunk_document(
|
||||
markdown_text,
|
||||
doc_id=doc_id,
|
||||
doc_name=doc_name,
|
||||
regulation_type=regulation_type,
|
||||
version=version
|
||||
)
|
||||
logger.success(f"分块完成,共{len(chunks)}个chunk")
|
||||
return chunks
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分块失败: {e}")
|
||||
return []
|
||||
|
||||
def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
|
||||
"""嵌入分块"""
|
||||
try:
|
||||
# 延迟初始化嵌入模型
|
||||
self._init_embedder()
|
||||
|
||||
# 提取文本内容
|
||||
texts = [chunk.content for chunk in chunks]
|
||||
|
||||
# 执行嵌入
|
||||
embeddings = self.embedder.embed(texts)
|
||||
|
||||
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
|
||||
return embeddings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"嵌入失败: {e}")
|
||||
return None
|
||||
|
||||
def _insert_to_milvus(
|
||||
self,
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""插入Milvus"""
|
||||
try:
|
||||
# 延迟初始化Milvus
|
||||
self._init_milvus()
|
||||
|
||||
# 执行插入
|
||||
inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
|
||||
|
||||
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
|
||||
return inserted_ids
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"入库失败: {e}")
|
||||
return []
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
检索法规内容
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
top_k: 返回结果数量
|
||||
filters: 过滤条件
|
||||
|
||||
Returns:
|
||||
List[Dict]: 检索结果
|
||||
"""
|
||||
logger.info(f"执行检索: {query}")
|
||||
|
||||
try:
|
||||
# 延迟初始化
|
||||
self._init_embedder()
|
||||
self._init_milvus()
|
||||
|
||||
# 生成查询向量
|
||||
query_embedding = self.embedder.embed_single(query)
|
||||
|
||||
# 执行混合检索
|
||||
results = self.milvus.hybrid_search(
|
||||
query_dense=query_embedding['dense'].tolist(),
|
||||
query_sparse=query_embedding['sparse'],
|
||||
top_k=top_k,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
result_dicts = []
|
||||
for r in results:
|
||||
result_dicts.append({
|
||||
"id": r.id,
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
"metadata": r.metadata
|
||||
})
|
||||
|
||||
logger.success(f"检索完成,返回{len(result_dicts)}条结果")
|
||||
return result_dicts
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索失败: {e}")
|
||||
return []
|
||||
def search(self, query: str, top_k: int = 10, filters: str | None = None) -> list[dict]:
|
||||
"""Handle search for the Document Processor instance."""
|
||||
results = get_retrieval_service().retrieve(query=query, top_k=top_k, filters=filters)
|
||||
return [
|
||||
{
|
||||
"id": item.chunk_id,
|
||||
"content": item.content,
|
||||
"score": item.score,
|
||||
"metadata": {
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"chunk_id": item.chunk_id,
|
||||
"section_title": item.section_title,
|
||||
"page_number": item.page_number,
|
||||
**item.metadata,
|
||||
},
|
||||
}
|
||||
for item in results
|
||||
]
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
if self.milvus:
|
||||
self.milvus.disconnect()
|
||||
logger.info("文档处理器已关闭")
|
||||
|
||||
|
||||
def process_document(
|
||||
file_path: str,
|
||||
doc_name: Optional[str] = None,
|
||||
regulation_type: str = "",
|
||||
version: str = ""
|
||||
) -> ProcessingResult:
|
||||
"""便捷函数:处理单个文档"""
|
||||
processor = DocumentProcessor()
|
||||
result = processor.process(file_path, doc_name, regulation_type, version)
|
||||
processor.close()
|
||||
return result
|
||||
|
||||
|
||||
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
|
||||
"""便捷函数:检索法规"""
|
||||
processor = DocumentProcessor()
|
||||
results = processor.search(query, top_k)
|
||||
processor.close()
|
||||
return results
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
"""嵌入和分块服务"""
|
||||
"""Initialize the app.services.embedding package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
from .text_chunker import RegulationChunker
|
||||
from .bge_m3_embedder import BGEM3Embedder
|
||||
|
||||
__all__ = ["RegulationChunker", "BGEM3Embedder"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name == "RegulationChunker":
|
||||
from .text_chunker import RegulationChunker
|
||||
|
||||
return RegulationChunker
|
||||
if name == "BGEM3Embedder":
|
||||
from .bge_m3_embedder import BGEM3Embedder
|
||||
|
||||
return BGEM3Embedder
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成"""
|
||||
"""Provide service-layer logic for bge m3 embedder."""
|
||||
|
||||
import numpy as np
|
||||
from typing import List, Dict, Optional, Union
|
||||
@@ -6,43 +6,31 @@ from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import torch
|
||||
import os
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
# 设置HuggingFace镜像(国内网络)
|
||||
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if 'HF_ENDPOINT' not in os.environ:
|
||||
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
||||
|
||||
# 本地模型路径(按优先级检查)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
LOCAL_MODEL_PATHS = [
|
||||
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径
|
||||
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径
|
||||
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingResult:
|
||||
"""嵌入结果"""
|
||||
dense_embeddings: np.ndarray # Dense向量(语义检索)
|
||||
sparse_embeddings: List[Dict[int, float]] # Sparse向量(关键词匹配)
|
||||
"""Represent the Embedding Result type."""
|
||||
dense_embeddings: np.ndarray # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
sparse_embeddings: List[Dict[int, float]] # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
texts: List[str]
|
||||
dim: int = 1024
|
||||
|
||||
|
||||
class BGEM3Embedder:
|
||||
"""
|
||||
BGE-M3多语言嵌入模型服务
|
||||
|
||||
BGE-M3是BAAI发布的多语言嵌入模型,支持:
|
||||
- Dense向量:用于语义相似度检索
|
||||
- Sparse向量:用于关键词精确匹配(BM25风格)
|
||||
- ColBERT向量:用于细粒度交互匹配(可选)
|
||||
|
||||
特点:
|
||||
- 支持100+语言(中英双语优化)
|
||||
- 8192 tokens超长上下文
|
||||
- Dense+Sparse双路检索能力
|
||||
|
||||
GitHub: https://github.com/FlagOpen/FlagEmbedding
|
||||
"""
|
||||
"""Represent the B G E M3 Embedder type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -53,28 +41,18 @@ class BGEM3Embedder:
|
||||
max_length: int = 8192,
|
||||
local_model_path: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化BGE-M3嵌入模型
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(如果使用本地路径,此参数会被忽略)
|
||||
use_fp16: 是否使用FP16加速
|
||||
device: 设备类型(cuda/cpu),默认自动选择
|
||||
batch_size: 批处理大小
|
||||
max_length: 最大序列长度
|
||||
local_model_path: 本地模型路径(可选,优先使用)
|
||||
"""
|
||||
"""Initialize the B G E M3 Embedder instance."""
|
||||
self.use_fp16 = use_fp16
|
||||
self.batch_size = batch_size
|
||||
self.max_length = max_length
|
||||
|
||||
# 确定模型路径(优先使用本地路径)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if local_model_path and os.path.exists(local_model_path):
|
||||
self.model_path = local_model_path
|
||||
self.model_name = "local"
|
||||
logger.info(f"使用本地模型路径: {local_model_path}")
|
||||
else:
|
||||
# 检查多个可能的本地路径
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
found_local = False
|
||||
for path in LOCAL_MODEL_PATHS:
|
||||
if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
|
||||
@@ -89,7 +67,7 @@ class BGEM3Embedder:
|
||||
self.model_name = model_name
|
||||
logger.info(f"使用远程模型: {model_name}")
|
||||
|
||||
# 自动选择设备
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
else:
|
||||
@@ -101,7 +79,7 @@ class BGEM3Embedder:
|
||||
self._load_model()
|
||||
|
||||
def _load_model(self):
|
||||
"""加载嵌入模型"""
|
||||
"""Handle load model for this module for the B G E M3 Embedder instance."""
|
||||
try:
|
||||
from FlagEmbedding import BGEM3FlagModel
|
||||
|
||||
@@ -127,18 +105,7 @@ class BGEM3Embedder:
|
||||
return_sparse: bool = True,
|
||||
return_colbert_vecs: bool = False
|
||||
) -> EmbeddingResult:
|
||||
"""
|
||||
对文本列表生成嵌入向量
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
return_dense: 是否返回Dense向量
|
||||
return_sparse: 是否返回Sparse向量
|
||||
return_colbert_vecs: 是否返回ColBERT向量
|
||||
|
||||
Returns:
|
||||
EmbeddingResult: 嵌入结果
|
||||
"""
|
||||
"""Handle embed for the B G E M3 Embedder instance."""
|
||||
if not texts:
|
||||
logger.warning("输入文本列表为空")
|
||||
return EmbeddingResult(
|
||||
@@ -151,7 +118,7 @@ class BGEM3Embedder:
|
||||
logger.info(f"开始嵌入{len(texts)}个文本块")
|
||||
|
||||
try:
|
||||
# 执行嵌入
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
embeddings = self.model.encode(
|
||||
texts,
|
||||
batch_size=self.batch_size,
|
||||
@@ -161,11 +128,11 @@ class BGEM3Embedder:
|
||||
return_colbert_vecs=return_colbert_vecs
|
||||
)
|
||||
|
||||
# 提取结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
dense_embeddings = embeddings.get('dense_vecs', np.array([]))
|
||||
sparse_embeddings = embeddings.get('lexical_weights', [])
|
||||
|
||||
# 获取维度
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
|
||||
|
||||
logger.success(f"嵌入完成,向量维度: {dim}")
|
||||
@@ -182,15 +149,7 @@ class BGEM3Embedder:
|
||||
raise
|
||||
|
||||
def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
|
||||
"""
|
||||
对单个文本生成嵌入向量
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
Dict: 包含dense和sparse向量
|
||||
"""
|
||||
"""Embed single for the B G E M3 Embedder instance."""
|
||||
result = self.embed([text])
|
||||
return {
|
||||
'dense': result.dense_embeddings[0],
|
||||
@@ -199,25 +158,17 @@ class BGEM3Embedder:
|
||||
}
|
||||
|
||||
def embed_dense(self, texts: List[str]) -> np.ndarray:
|
||||
"""只生成Dense向量"""
|
||||
"""Embed dense for the B G E M3 Embedder instance."""
|
||||
result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
|
||||
return result.dense_embeddings
|
||||
|
||||
def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
|
||||
"""只生成Sparse向量"""
|
||||
"""Embed sparse for the B G E M3 Embedder instance."""
|
||||
result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
|
||||
return result.sparse_embeddings
|
||||
|
||||
def embed_query(self, query: str) -> Dict:
|
||||
"""
|
||||
对查询文本生成嵌入(用于检索)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
|
||||
Returns:
|
||||
Dict: 包含dense和sparse向量
|
||||
"""
|
||||
"""Embed query for the B G E M3 Embedder instance."""
|
||||
return self.embed_single(query)
|
||||
|
||||
def compute_similarity(
|
||||
@@ -226,26 +177,16 @@ class BGEM3Embedder:
|
||||
doc_embeddings: np.ndarray,
|
||||
metric: str = "cosine"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
计算查询与文档的相似度
|
||||
|
||||
Args:
|
||||
query_embedding: 查询向量
|
||||
doc_embeddings: 文档向量矩阵
|
||||
metric: 相似度度量(cosine/dot)
|
||||
|
||||
Returns:
|
||||
np.ndarray: 相似度分数数组
|
||||
"""
|
||||
"""Handle compute similarity for the B G E M3 Embedder instance."""
|
||||
if metric == "cosine":
|
||||
# 余弦相似度
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
query_norm = np.linalg.norm(query_embedding)
|
||||
doc_norms = np.linalg.norm(doc_embeddings, axis=1)
|
||||
|
||||
similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm)
|
||||
|
||||
elif metric == "dot":
|
||||
# 点积相似度
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
similarities = np.dot(doc_embeddings, query_embedding)
|
||||
|
||||
else:
|
||||
@@ -258,17 +199,8 @@ class BGEM3Embedder:
|
||||
query_sparse: Dict[int, float],
|
||||
doc_sparse: Dict[int, float]
|
||||
) -> float:
|
||||
"""
|
||||
计算Sparse向量的相似度(BM25风格)
|
||||
|
||||
Args:
|
||||
query_sparse: 查询的Sparse向量(词ID -> 权重)
|
||||
doc_sparse: 文档的Sparse向量
|
||||
|
||||
Returns:
|
||||
float: 相似度分数
|
||||
"""
|
||||
# 计算交集词的点积
|
||||
"""Handle sparse similarity for the B G E M3 Embedder instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
common_keys = set(query_sparse.keys()) & set(doc_sparse.keys())
|
||||
|
||||
score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
|
||||
@@ -280,7 +212,7 @@ def embed_texts(
|
||||
model_name: str = "BAAI/bge-m3",
|
||||
**kwargs
|
||||
) -> EmbeddingResult:
|
||||
"""便捷函数:对文本列表生成嵌入"""
|
||||
"""Embed texts."""
|
||||
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
|
||||
return embedder.embed(texts)
|
||||
|
||||
@@ -290,6 +222,6 @@ def embed_single_text(
|
||||
model_name: str = "BAAI/bge-m3",
|
||||
**kwargs
|
||||
) -> Dict:
|
||||
"""便捷函数:对单个文本生成嵌入"""
|
||||
"""Embed single text."""
|
||||
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
|
||||
return embedder.embed_single(text)
|
||||
|
||||
@@ -1,51 +1,46 @@
|
||||
"""智能分块器 - 章节级+条款级双粒度切割"""
|
||||
"""Provide service-layer logic for text chunker."""
|
||||
|
||||
import re
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMetadata:
|
||||
"""分块元数据"""
|
||||
"""Represent the Chunk Metadata type."""
|
||||
doc_id: str = ""
|
||||
doc_name: str = ""
|
||||
chunk_id: str = ""
|
||||
section_number: str = "" # 章节编号(如 "第一章")
|
||||
section_title: str = "" # 章节标题
|
||||
clause_number: str = "" # 条款编号(如 "第一条")
|
||||
section_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
section_title: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
clause_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
page_number: int = 0
|
||||
start_position: int = 0 # 在原文中的起始位置
|
||||
end_position: int = 0 # 在原文中的结束位置
|
||||
regulation_type: str = "" # 法规类型
|
||||
start_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
end_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
regulation_type: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
version: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextChunk:
|
||||
"""文本分块"""
|
||||
"""Represent the Text Chunk type."""
|
||||
content: str
|
||||
metadata: ChunkMetadata
|
||||
token_count: int = 0 # 估算的token数量
|
||||
token_count: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
class RegulationChunker:
|
||||
"""
|
||||
法规文档智能分块器
|
||||
"""Represent the Regulation Chunker type."""
|
||||
|
||||
实现章节级/条款级双粒度切割,适配国标GB文档结构:
|
||||
- 国标文档通常有明确的层级结构:章 > 节 > 条
|
||||
- 每个条款应作为一个独立的语义单元
|
||||
- 保留条款完整性,避免跨条款截断
|
||||
"""
|
||||
|
||||
# 法规标题模式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
|
||||
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
|
||||
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
|
||||
|
||||
# 条款子项模式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
SUB_ITEM_PATTERN = re.compile(r'^[\((][一二三四五六七八九十]+[\))]\s')
|
||||
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
|
||||
|
||||
@@ -56,15 +51,7 @@ class RegulationChunker:
|
||||
max_chunk_size: int = 2048,
|
||||
min_chunk_size: int = 100
|
||||
):
|
||||
"""
|
||||
初始化分块器
|
||||
|
||||
Args:
|
||||
chunk_size: 默认分块大小(字符数)
|
||||
chunk_overlap: 分块重叠大小
|
||||
max_chunk_size: 最大分块大小(防止单个条款过长)
|
||||
min_chunk_size: 最小分块大小(防止碎片化)
|
||||
"""
|
||||
"""Initialize the Regulation Chunker instance."""
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
self.max_chunk_size = max_chunk_size
|
||||
@@ -78,30 +65,18 @@ class RegulationChunker:
|
||||
regulation_type: str = "",
|
||||
version: str = ""
|
||||
) -> List[TextChunk]:
|
||||
"""
|
||||
对法规文档进行智能分块
|
||||
|
||||
Args:
|
||||
markdown_text: Markdown格式的文档内容
|
||||
doc_id: 文档ID
|
||||
doc_name: 文档名称
|
||||
regulation_type: 法规类型
|
||||
version: 文档版本
|
||||
|
||||
Returns:
|
||||
List[TextChunk]: 分块列表
|
||||
"""
|
||||
"""Handle chunk document for the Regulation Chunker instance."""
|
||||
logger.info(f"开始分块文档: {doc_name}")
|
||||
|
||||
# 1. 按章节分割(一级分块)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
sections = self._split_by_sections(markdown_text)
|
||||
|
||||
# 2. 在每个章节内按条款分割(二级分块)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chunks = []
|
||||
global_position = 0
|
||||
|
||||
for section_num, section_title, section_content, section_start in sections:
|
||||
# 在章节内按条款分割
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
clause_chunks = self._split_by_clauses(
|
||||
section_content,
|
||||
section_num,
|
||||
@@ -110,7 +85,7 @@ class RegulationChunker:
|
||||
)
|
||||
|
||||
for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks:
|
||||
# 处理过长的条款(进一步细分)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(chunk_content) > self.max_chunk_size:
|
||||
sub_chunks = self._split_long_clause(
|
||||
chunk_content,
|
||||
@@ -150,12 +125,7 @@ class RegulationChunker:
|
||||
return chunks
|
||||
|
||||
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
|
||||
"""
|
||||
按章节分割文档
|
||||
|
||||
Returns:
|
||||
List of (section_number, section_title, section_content, start_position)
|
||||
"""
|
||||
"""Handle split by sections for this module for the Regulation Chunker instance."""
|
||||
sections = []
|
||||
lines = markdown_text.split('\n')
|
||||
|
||||
@@ -165,12 +135,12 @@ class RegulationChunker:
|
||||
current_section_start = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 检测章节标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chapter_match = self.CHAPTER_PATTERN.match(line.strip())
|
||||
section_match = self.SECTION_PATTERN.match(line.strip())
|
||||
|
||||
if chapter_match or section_match:
|
||||
# 保存上一个章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_section_content:
|
||||
content = '\n'.join(current_section_content)
|
||||
sections.append((
|
||||
@@ -180,7 +150,7 @@ class RegulationChunker:
|
||||
current_section_start
|
||||
))
|
||||
|
||||
# 开始新章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_section_start = sum(len(l) + 1 for l in lines[:i])
|
||||
current_section_content = []
|
||||
|
||||
@@ -193,7 +163,7 @@ class RegulationChunker:
|
||||
|
||||
current_section_content.append(line)
|
||||
|
||||
# 保存最后一个章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_section_content:
|
||||
content = '\n'.join(current_section_content)
|
||||
sections.append((
|
||||
@@ -203,7 +173,7 @@ class RegulationChunker:
|
||||
current_section_start
|
||||
))
|
||||
|
||||
# 如果没有检测到章节,将整个文档作为一个大章节
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if not sections:
|
||||
sections.append((
|
||||
"",
|
||||
@@ -221,12 +191,7 @@ class RegulationChunker:
|
||||
section_title: str,
|
||||
section_start: int
|
||||
) -> List[Tuple[str, str, str, int, int]]:
|
||||
"""
|
||||
在章节内按条款分割
|
||||
|
||||
Returns:
|
||||
List of (content, clause_number, clause_title, start_position, end_position)
|
||||
"""
|
||||
"""Handle split by clauses for this module for the Regulation Chunker instance."""
|
||||
clauses = []
|
||||
lines = section_content.split('\n')
|
||||
|
||||
@@ -236,11 +201,11 @@ class RegulationChunker:
|
||||
current_clause_start = section_start
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 检测条款标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
clause_match = self.CLAUSE_PATTERN.match(line.strip())
|
||||
|
||||
if clause_match:
|
||||
# 保存上一个条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_clause_content:
|
||||
content = '\n'.join(current_clause_content)
|
||||
end_pos = current_clause_start + len(content)
|
||||
@@ -252,7 +217,7 @@ class RegulationChunker:
|
||||
end_pos
|
||||
))
|
||||
|
||||
# 开始新条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i])
|
||||
current_clause_content = []
|
||||
current_clause_num = self._extract_clause_number(line.strip())
|
||||
@@ -260,7 +225,7 @@ class RegulationChunker:
|
||||
|
||||
current_clause_content.append(line)
|
||||
|
||||
# 保存最后一个条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_clause_content:
|
||||
content = '\n'.join(current_clause_content)
|
||||
end_pos = current_clause_start + len(content)
|
||||
@@ -272,7 +237,7 @@ class RegulationChunker:
|
||||
end_pos
|
||||
))
|
||||
|
||||
# 如果没有检测到条款,将整个章节作为一个条款
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if not clauses:
|
||||
clauses.append((
|
||||
section_content,
|
||||
@@ -290,15 +255,11 @@ class RegulationChunker:
|
||||
clause_num: str,
|
||||
clause_title: str
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
"""
|
||||
分割过长的条款内容
|
||||
|
||||
按条款子项或段落分割,保持语义完整性
|
||||
"""
|
||||
"""Handle split long clause for this module for the Regulation Chunker instance."""
|
||||
sub_chunks = []
|
||||
lines = content.split('\n')
|
||||
|
||||
# 检测是否有子项结构
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
has_sub_items = any(
|
||||
self.SUB_ITEM_PATTERN.match(line.strip()) or
|
||||
self.NUMBER_ITEM_PATTERN.match(line.strip())
|
||||
@@ -306,7 +267,7 @@ class RegulationChunker:
|
||||
)
|
||||
|
||||
if has_sub_items:
|
||||
# 按子项分割
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_sub_content = []
|
||||
current_sub_start = 0
|
||||
|
||||
@@ -326,14 +287,14 @@ class RegulationChunker:
|
||||
|
||||
current_sub_content.append(line)
|
||||
|
||||
# 保存最后一个子项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_sub_content:
|
||||
sub_content = '\n'.join(current_sub_content)
|
||||
sub_end = current_sub_start + len(sub_content)
|
||||
sub_chunks.append((sub_content, current_sub_start, sub_end))
|
||||
|
||||
else:
|
||||
# 按段落分割(滑动窗口)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
paragraphs = []
|
||||
current_para = []
|
||||
|
||||
@@ -348,7 +309,7 @@ class RegulationChunker:
|
||||
if current_para:
|
||||
paragraphs.append('\n'.join(current_para))
|
||||
|
||||
# 合并段落直到达到chunk_size
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
chunk_start = 0
|
||||
@@ -365,7 +326,7 @@ class RegulationChunker:
|
||||
current_chunk.append(para)
|
||||
current_length += len(para)
|
||||
|
||||
# 保存最后一个chunk
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_chunk:
|
||||
chunk_content = '\n'.join(current_chunk)
|
||||
chunk_end = chunk_start + len(chunk_content)
|
||||
@@ -374,13 +335,13 @@ class RegulationChunker:
|
||||
return sub_chunks
|
||||
|
||||
def _extract_title(self, header_line: str) -> str:
|
||||
"""从标题行提取标题内容"""
|
||||
# 移除"第X章"、"第X节"前缀
|
||||
"""Handle extract title for this module for the Regulation Chunker instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line)
|
||||
return title.strip()
|
||||
|
||||
def _extract_clause_number(self, clause_line: str) -> str:
|
||||
"""从条款行提取条款编号"""
|
||||
"""Handle extract clause number for this module for the Regulation Chunker instance."""
|
||||
match = self.CLAUSE_PATTERN.match(clause_line)
|
||||
if match:
|
||||
return match.group(0).strip()
|
||||
@@ -399,14 +360,14 @@ class RegulationChunker:
|
||||
regulation_type: str,
|
||||
version: str
|
||||
) -> TextChunk:
|
||||
"""创建文本分块"""
|
||||
# 清理内容
|
||||
"""Handle create chunk for this module for the Regulation Chunker instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
content = content.strip()
|
||||
|
||||
# 计算估算token数(中文约1.5字符/token)
|
||||
token_count = int(len(content) * 0.7) # 简化估算
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
token_count = int(len(content) * 0.7) # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
# 生成chunk_id
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}"
|
||||
|
||||
metadata = ChunkMetadata(
|
||||
@@ -437,7 +398,7 @@ def chunk_regulation_document(
|
||||
version: str = "",
|
||||
chunk_size: int = 512
|
||||
) -> List[TextChunk]:
|
||||
"""便捷函数:对法规文档进行分块"""
|
||||
"""Handle chunk regulation document."""
|
||||
chunker = RegulationChunker(chunk_size=chunk_size)
|
||||
return chunker.chunk_document(
|
||||
markdown_text,
|
||||
|
||||
@@ -1,14 +1,36 @@
|
||||
"""LLM服务模块"""
|
||||
"""Initialize the app.services.llm package."""
|
||||
|
||||
from .llm_factory import LLMFactory, get_llm_client
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||
from .deepseek_client import DeepSeekClient
|
||||
from .llm_factory import LLMFactory, get_llm_client
|
||||
from .qwen_client import QwenClient, QwenVLClient
|
||||
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LLMFactory", "get_llm_client",
|
||||
"BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider",
|
||||
"DeepSeekClient", "QwenClient", "QwenVLClient",
|
||||
"DocumentSummarizer", "summarize_document", "DocumentSummary"
|
||||
"LLMFactory",
|
||||
"get_llm_client",
|
||||
"BaseLLMClient",
|
||||
"LLMResponse",
|
||||
"LLMConfig",
|
||||
"LLMProvider",
|
||||
"DeepSeekClient",
|
||||
"QwenClient",
|
||||
"QwenVLClient",
|
||||
"DocumentSummarizer",
|
||||
"summarize_document",
|
||||
"DocumentSummary",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name in {"DocumentSummarizer", "summarize_document", "DocumentSummary"}:
|
||||
from .document_summarizer import DocumentSummarizer, DocumentSummary, summarize_document
|
||||
|
||||
return {
|
||||
"DocumentSummarizer": DocumentSummarizer,
|
||||
"summarize_document": summarize_document,
|
||||
"DocumentSummary": DocumentSummary,
|
||||
}[name]
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""LLM客户端基类 - 统一接口定义"""
|
||||
"""Provide service-layer logic for base client."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional, Any
|
||||
from enum import Enum
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
class LLMProvider(Enum):
|
||||
"""LLM提供商"""
|
||||
"""Define the L L M Provider enumeration."""
|
||||
DEEPSEEK = "deepseek"
|
||||
QWEN = "qwen"
|
||||
QWEN_VL = "qwen_vl"
|
||||
@@ -15,7 +17,7 @@ class LLMProvider(Enum):
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""LLM响应结果"""
|
||||
"""Represent the L L M Response type."""
|
||||
content: str
|
||||
model: str
|
||||
usage: Dict[str, int] = field(default_factory=dict)
|
||||
@@ -25,12 +27,13 @@ class LLMResponse:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the L L M Response instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""LLM配置"""
|
||||
"""Define configuration for l l m config."""
|
||||
provider: LLMProvider
|
||||
model: str
|
||||
api_key: str
|
||||
@@ -38,19 +41,20 @@ class LLMConfig:
|
||||
max_tokens: int = 4096
|
||||
temperature: float = 0.7
|
||||
top_p: float = 0.9
|
||||
timeout: int = 300 # 默认超时300秒(摘要/Skills生成可能需要较长时间)
|
||||
timeout: int = 300 # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
class BaseLLMClient(ABC):
|
||||
"""LLM客户端基类"""
|
||||
"""Represent the Base L L M Client type."""
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Base L L M Client instance."""
|
||||
self.config = config
|
||||
self._client = None
|
||||
|
||||
@abstractmethod
|
||||
def _init_client(self):
|
||||
"""初始化客户端"""
|
||||
"""Handle init client for this module for the Base L L M Client instance."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -61,18 +65,7 @@ class BaseLLMClient(ABC):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
对话补全
|
||||
|
||||
Args:
|
||||
messages: 对话消息列表 [{"role": "user/assistant/system", "content": "..."}]
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LLMResponse: 响应结果
|
||||
"""
|
||||
"""Handle chat for the Base L L M Client instance."""
|
||||
pass
|
||||
|
||||
def complete(
|
||||
@@ -83,18 +76,7 @@ class BaseLLMClient(ABC):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
单轮补全(便捷方法)
|
||||
|
||||
Args:
|
||||
prompt: 用户输入
|
||||
system_prompt: 系统提示词
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
|
||||
Returns:
|
||||
LLMResponse: 响应结果
|
||||
"""
|
||||
"""Handle complete for the Base L L M Client instance."""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
@@ -104,12 +86,12 @@ class BaseLLMClient(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Base L L M Client instance."""
|
||||
pass
|
||||
|
||||
def estimate_tokens(self, text: str) -> int:
|
||||
"""估算文本token数(粗略估计)"""
|
||||
# 中文字符约1.5 token,英文约0.25 token
|
||||
"""Handle estimate tokens for the Base L L M Client instance."""
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||
other_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""DeepSeek LLM客户端 - OpenAI兼容API"""
|
||||
"""Provide service-layer logic for deepseek client."""
|
||||
|
||||
import time
|
||||
from typing import List, Dict, Optional
|
||||
@@ -6,20 +6,12 @@ from loguru import logger
|
||||
import httpx
|
||||
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
class DeepSeekClient(BaseLLMClient):
|
||||
"""
|
||||
DeepSeek API客户端(OpenAI兼容格式)
|
||||
|
||||
支持模型:
|
||||
- deepseek-chat
|
||||
- deepseek-coder
|
||||
- deepseek-reasoner
|
||||
- deepseek-v3
|
||||
- deepseek-v3.2
|
||||
- deepseek-v4-flash
|
||||
"""
|
||||
"""Represent the Deep Seek Client type."""
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"deepseek-chat",
|
||||
@@ -31,13 +23,14 @@ class DeepSeekClient(BaseLLMClient):
|
||||
]
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Deep Seek Client instance."""
|
||||
if config.provider != LLMProvider.DEEPSEEK:
|
||||
raise ValueError(f"配置provider应为DEEPSEEK,实际为{config.provider}")
|
||||
super().__init__(config)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化HTTP客户端"""
|
||||
"""Handle init client for this module for the Deep Seek Client instance."""
|
||||
self._client = httpx.Client(
|
||||
base_url=self.config.base_url,
|
||||
headers={
|
||||
@@ -55,7 +48,7 @@ class DeepSeekClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""对话补全"""
|
||||
"""Handle chat for the Deep Seek Client instance."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
@@ -103,11 +96,11 @@ class DeepSeekClient(BaseLLMClient):
|
||||
)
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Deep Seek Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
"""Release the resources held by this component."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
@@ -118,7 +111,7 @@ def create_deepseek_client(
|
||||
base_url: str = "http://6.86.80.4:30080/v1",
|
||||
**kwargs
|
||||
) -> DeepSeekClient:
|
||||
"""便捷函数:创建DeepSeek客户端"""
|
||||
"""Create deepseek client."""
|
||||
config = LLMConfig(
|
||||
provider=LLMProvider.DEEPSEEK,
|
||||
model=model,
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
"""文档摘要生成服务 - LLM生成法规文档摘要"""
|
||||
"""Provide service-layer logic for document summarizer."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
|
||||
from app.services.llm import get_llm_client, BaseLLMClient
|
||||
from app.services.llm.base_client import BaseLLMClient
|
||||
from app.services.llm.llm_factory import get_llm_client
|
||||
from app.services.rag.prompt_templates import get_prompt_template
|
||||
from app.config.settings import settings
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentSummary:
|
||||
"""文档摘要结果"""
|
||||
"""Represent the Document Summary type."""
|
||||
doc_name: str
|
||||
summary: str
|
||||
applicable_scope: str
|
||||
@@ -24,24 +27,12 @@ class DocumentSummary:
|
||||
|
||||
@property
|
||||
def is_success(self) -> bool:
|
||||
"""Return whether success for the Document Summary instance."""
|
||||
return self.error is None
|
||||
|
||||
|
||||
class DocumentSummarizer:
|
||||
"""
|
||||
文档摘要生成器
|
||||
|
||||
功能:
|
||||
- 生成法规文档的核心要点摘要
|
||||
- 提取适用范围
|
||||
- 突出关键条款
|
||||
- 列出合规要点
|
||||
|
||||
使用示例:
|
||||
summarizer = DocumentSummarizer()
|
||||
result = summarizer.summarize("GB 7258-2017", markdown_content)
|
||||
print(result.summary)
|
||||
"""
|
||||
"""Represent the Document Summarizer type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -49,25 +40,18 @@ class DocumentSummarizer:
|
||||
model: str = None,
|
||||
max_tokens: int = None
|
||||
):
|
||||
"""
|
||||
初始化摘要生成器
|
||||
|
||||
Args:
|
||||
provider: LLM提供商
|
||||
model: LLM模型名称
|
||||
max_tokens: 最大输出token数
|
||||
"""
|
||||
"""Initialize the Document Summarizer instance."""
|
||||
self.provider = provider or settings.llm_provider
|
||||
self.model = model or settings.llm_model
|
||||
self.max_tokens = max_tokens or settings.rag_summary_max_tokens
|
||||
|
||||
# LLM客户端(延迟加载)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
self.llm: Optional[BaseLLMClient] = None
|
||||
|
||||
logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")
|
||||
|
||||
def _init_llm(self):
|
||||
"""延迟初始化LLM"""
|
||||
"""Handle init llm for this module for the Document Summarizer instance."""
|
||||
if self.llm is None:
|
||||
self.llm = get_llm_client(
|
||||
provider=self.provider,
|
||||
@@ -81,18 +65,7 @@ class DocumentSummarizer:
|
||||
regulation_type: str = "",
|
||||
max_tokens: Optional[int] = None
|
||||
) -> DocumentSummary:
|
||||
"""
|
||||
生成文档摘要
|
||||
|
||||
Args:
|
||||
doc_name: 文档名称
|
||||
content: 文档内容(Markdown格式)
|
||||
regulation_type: 法规类型
|
||||
max_tokens: 最大输出token数
|
||||
|
||||
Returns:
|
||||
DocumentSummary: 摘要结果
|
||||
"""
|
||||
"""Handle summarize for the Document Summarizer instance."""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
@@ -101,23 +74,23 @@ class DocumentSummarizer:
|
||||
try:
|
||||
self._init_llm()
|
||||
|
||||
# 使用摘要模板
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
template = get_prompt_template("document_summary")
|
||||
|
||||
# 构建用户消息
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
user_content = template.user_template.format(
|
||||
doc_name=doc_name,
|
||||
content=content[:8000] # 截取前8000字符(避免超出token限制)
|
||||
content=content[:8000] # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
response = self.llm.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": template.system_prompt},
|
||||
{"role": "user", "content": user_content}
|
||||
],
|
||||
max_tokens=max_tokens or self.max_tokens,
|
||||
temperature=0.3 # 低温度保证摘要准确性
|
||||
temperature=0.3 # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
)
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
@@ -135,7 +108,7 @@ class DocumentSummarizer:
|
||||
error=response.error
|
||||
)
|
||||
|
||||
# 解析摘要结构
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
summary_data = self._parse_summary(response.content)
|
||||
|
||||
logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
|
||||
@@ -166,7 +139,7 @@ class DocumentSummarizer:
|
||||
)
|
||||
|
||||
def _parse_summary(self, content: str) -> Dict:
|
||||
"""解析摘要内容(提取结构化信息)"""
|
||||
"""Handle parse summary for this module for the Document Summarizer instance."""
|
||||
result = {
|
||||
"summary": content,
|
||||
"applicable_scope": "",
|
||||
@@ -175,26 +148,26 @@ class DocumentSummarizer:
|
||||
"compliance_points": []
|
||||
}
|
||||
|
||||
# 简单解析(提取关键信息)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
lines = content.split("\n")
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# 提取适用范围
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if "适用范围" in line or "适用对象" in line:
|
||||
result["applicable_scope"] = line.split(":")[-1].strip() if ":" in line else line.split(":")[-1].strip()
|
||||
|
||||
# 提取关键条款
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if line.startswith("- 【条款") or line.startswith("【条款"):
|
||||
result["key_clauses"].append(line)
|
||||
|
||||
# 提取关键术语
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if "关键术语" in line or "术语定义" in line:
|
||||
# 继续读取后续几行
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
pass
|
||||
|
||||
# 提取合规要点
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if "合规要点" in line or "必须满足" in line:
|
||||
pass
|
||||
|
||||
@@ -204,15 +177,7 @@ class DocumentSummarizer:
|
||||
self,
|
||||
documents: list
|
||||
) -> list:
|
||||
"""
|
||||
批量生成摘要
|
||||
|
||||
Args:
|
||||
documents: 文档列表 [{"doc_name": str, "content": str}, ...]
|
||||
|
||||
Returns:
|
||||
list: 摘要结果列表
|
||||
"""
|
||||
"""Handle batch summarize for the Document Summarizer instance."""
|
||||
results = []
|
||||
for doc in documents:
|
||||
result = self.summarize(doc["doc_name"], doc["content"])
|
||||
@@ -225,6 +190,6 @@ def summarize_document(
|
||||
content: str,
|
||||
**kwargs
|
||||
) -> DocumentSummary:
|
||||
"""便捷函数:生成文档摘要"""
|
||||
"""Handle summarize document."""
|
||||
summarizer = DocumentSummarizer(**kwargs)
|
||||
return summarizer.summarize(doc_name, content)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""LLM工厂 - 统一创建和管理LLM客户端"""
|
||||
"""Provide service-layer logic for llm factory."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
@@ -7,16 +7,18 @@ from functools import lru_cache
|
||||
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
|
||||
from .deepseek_client import DeepSeekClient
|
||||
from .qwen_client import QwenClient, QwenVLClient
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
# 默认模型映射
|
||||
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.DEEPSEEK: "deepseek-v4-flash",
|
||||
LLMProvider.QWEN: "qwen3.5-flash",
|
||||
LLMProvider.QWEN_VL: "qwen3-vl-plus"
|
||||
}
|
||||
|
||||
# API基础URL(使用统一代理服务)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
DEFAULT_BASE_URLS = {
|
||||
LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
|
||||
LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
|
||||
@@ -25,31 +27,13 @@ DEFAULT_BASE_URLS = {
|
||||
|
||||
|
||||
class LLMFactory:
|
||||
"""
|
||||
LLM客户端工厂(支持全局缓存)
|
||||
"""Represent the L L M Factory type."""
|
||||
|
||||
支持的提供商和模型:
|
||||
- DeepSeek: deepseek-chat (DeepSeek-V3), deepseek-coder
|
||||
- Qwen: qwen-turbo, qwen-plus, qwen-max, qwen-long
|
||||
- QwenVL: qwen-vl-plus, qwen-vl-max (多模态)
|
||||
|
||||
使用示例:
|
||||
factory = LLMFactory()
|
||||
|
||||
# 使用默认配置
|
||||
client = factory.create("deepseek")
|
||||
|
||||
# 自定义配置
|
||||
client = factory.create("qwen", model="qwen-max", temperature=0.5)
|
||||
|
||||
# 调用LLM
|
||||
response = client.complete("你好,介绍一下自己")
|
||||
"""
|
||||
|
||||
# 全局客户端缓存(类级别,跨实例共享)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
_global_instances: Dict[str, BaseLLMClient] = {}
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the L L M Factory instance."""
|
||||
self._config_cache: Dict[str, Any] = {}
|
||||
|
||||
def create(
|
||||
@@ -62,24 +46,10 @@ class LLMFactory:
|
||||
temperature: float = 0.7,
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""
|
||||
创建LLM客户端
|
||||
|
||||
Args:
|
||||
provider: 提供商名称 ("deepseek", "qwen", "qwen_vl")
|
||||
api_key: API密钥(如未提供,从环境变量获取)
|
||||
model: 模型名称(如未提供,使用默认模型)
|
||||
base_url: API基础URL
|
||||
max_tokens: 最大输出token数
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他配置参数
|
||||
|
||||
Returns:
|
||||
BaseLLMClient: LLM客户端实例
|
||||
"""
|
||||
"""Handle create for the L L M Factory instance."""
|
||||
provider_enum = self._parse_provider(provider)
|
||||
|
||||
# 获取配置
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
api_key = api_key or self._get_api_key(provider_enum)
|
||||
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||
base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
|
||||
@@ -87,7 +57,7 @@ class LLMFactory:
|
||||
if not api_key:
|
||||
raise ValueError(f"缺少API密钥,请设置环境变量或传入api_key参数")
|
||||
|
||||
# 检查全局缓存
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
cache_key = f"{provider}_{model}"
|
||||
if cache_key in LLMFactory._global_instances:
|
||||
logger.debug(f"使用缓存的LLM客户端: {cache_key}")
|
||||
@@ -103,17 +73,17 @@ class LLMFactory:
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 创建客户端
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
client = self._create_client(config)
|
||||
|
||||
# 缓存到全局实例
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
LLMFactory._global_instances[cache_key] = client
|
||||
|
||||
logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
|
||||
return client
|
||||
|
||||
def _parse_provider(self, provider: str) -> LLMProvider:
|
||||
"""解析提供商名称"""
|
||||
"""Handle parse provider for this module for the L L M Factory instance."""
|
||||
provider_map = {
|
||||
"deepseek": LLMProvider.DEEPSEEK,
|
||||
"deepseek-v3": LLMProvider.DEEPSEEK,
|
||||
@@ -137,7 +107,7 @@ class LLMFactory:
|
||||
return provider_map[provider_lower]
|
||||
|
||||
def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
|
||||
"""从环境变量获取API密钥"""
|
||||
"""Handle get api key for this module for the L L M Factory instance."""
|
||||
import os
|
||||
|
||||
key_map = {
|
||||
@@ -154,7 +124,7 @@ class LLMFactory:
|
||||
return None
|
||||
|
||||
def _create_client(self, config: LLMConfig) -> BaseLLMClient:
|
||||
"""创建具体客户端"""
|
||||
"""Handle create client for this module for the L L M Factory instance."""
|
||||
client_map = {
|
||||
LLMProvider.DEEPSEEK: DeepSeekClient,
|
||||
LLMProvider.QWEN: QwenClient,
|
||||
@@ -168,14 +138,14 @@ class LLMFactory:
|
||||
return client_class(config)
|
||||
|
||||
def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||
"""获取缓存的客户端"""
|
||||
"""Return cached for the L L M Factory instance."""
|
||||
provider_enum = self._parse_provider(provider)
|
||||
model = model or DEFAULT_MODELS.get(provider_enum)
|
||||
cache_key = f"{provider}_{model}"
|
||||
return LLMFactory._global_instances.get(cache_key)
|
||||
|
||||
def list_available_providers(self) -> Dict[str, list]:
|
||||
"""列出可用的提供商和模型"""
|
||||
"""List available providers for the L L M Factory instance."""
|
||||
return {
|
||||
"deepseek": DeepSeekClient.SUPPORTED_MODELS,
|
||||
"qwen": QwenClient.SUPPORTED_MODELS,
|
||||
@@ -184,12 +154,7 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def preload_clients(cls, providers: list = None):
|
||||
"""
|
||||
预加载LLM客户端(应用启动时调用)
|
||||
|
||||
Args:
|
||||
providers: 要预加载的提供商列表,默认加载qwen和deepseek
|
||||
"""
|
||||
"""Handle preload clients for the L L M Factory instance."""
|
||||
if providers is None:
|
||||
providers = ["qwen", "deepseek"]
|
||||
|
||||
@@ -203,9 +168,9 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
|
||||
"""获取全局缓存的客户端"""
|
||||
"""Return global client for the L L M Factory instance."""
|
||||
provider_lower = provider.lower()
|
||||
# 处理模型名作为provider的情况(如 qwen3.5-flash)
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if provider_lower.startswith("qwen"):
|
||||
provider_lower = "qwen"
|
||||
model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK)
|
||||
@@ -214,7 +179,7 @@ class LLMFactory:
|
||||
|
||||
@classmethod
|
||||
def cleanup(cls):
|
||||
"""清理所有缓存的客户端"""
|
||||
"""Handle cleanup for the L L M Factory instance."""
|
||||
for cache_key, client in cls._global_instances.items():
|
||||
try:
|
||||
client.close()
|
||||
@@ -227,7 +192,7 @@ class LLMFactory:
|
||||
|
||||
@lru_cache
|
||||
def get_llm_factory() -> LLMFactory:
|
||||
"""获取LLM工厂实例(缓存)"""
|
||||
"""Return llm factory."""
|
||||
return LLMFactory()
|
||||
|
||||
|
||||
@@ -236,20 +201,10 @@ def get_llm_client(
|
||||
model: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> BaseLLMClient:
|
||||
"""
|
||||
便捷函数:获取LLM客户端(优先使用缓存)
|
||||
|
||||
Args:
|
||||
provider: 提供商名称
|
||||
model: 模型名称
|
||||
**kwargs: 其他配置
|
||||
|
||||
Returns:
|
||||
BaseLLMClient: LLM客户端实例
|
||||
"""
|
||||
"""Return llm client."""
|
||||
factory = get_llm_factory()
|
||||
|
||||
# 先尝试获取缓存的实例
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
cached = factory.get_cached(provider, model)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
|
||||
"""Provide service-layer logic for qwen client."""
|
||||
|
||||
import time
|
||||
import json
|
||||
@@ -7,21 +7,12 @@ from loguru import logger
|
||||
import httpx
|
||||
|
||||
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
|
||||
|
||||
|
||||
class QwenClient(BaseLLMClient):
|
||||
"""
|
||||
Qwen API客户端(OpenAI兼容格式)
|
||||
|
||||
支持通过new-api等代理服务调用:
|
||||
- qwen-turbo
|
||||
- qwen-plus
|
||||
- qwen-max
|
||||
- qwen3.5-flash (推荐:快速响应)
|
||||
- qwen3.5-plus
|
||||
- qwen-long
|
||||
- qwen2.5系列
|
||||
"""
|
||||
"""Represent the Qwen Client type."""
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"qwen-turbo",
|
||||
@@ -39,14 +30,15 @@ class QwenClient(BaseLLMClient):
|
||||
]
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Qwen Client instance."""
|
||||
if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
|
||||
raise ValueError(f"配置provider应为Qwen,实际为{config.provider}")
|
||||
super().__init__(config)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化HTTP客户端"""
|
||||
# OpenAI兼容API格式
|
||||
"""Handle init client for this module for the Qwen Client instance."""
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
self._client = httpx.Client(
|
||||
base_url=self.config.base_url,
|
||||
headers={
|
||||
@@ -64,11 +56,11 @@ class QwenClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""对话补全(OpenAI兼容格式)"""
|
||||
"""Handle chat for the Qwen Client instance."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# OpenAI兼容格式的请求体
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
@@ -78,7 +70,7 @@ class QwenClient(BaseLLMClient):
|
||||
"stream": False
|
||||
}
|
||||
|
||||
# OpenAI兼容接口路径
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
response = self._client.post("/chat/completions", json=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -86,7 +78,7 @@ class QwenClient(BaseLLMClient):
|
||||
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# OpenAI兼容格式的响应解析
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
choices = data.get("choices", [{}])
|
||||
message = choices[0].get("message", {})
|
||||
|
||||
@@ -121,42 +113,33 @@ class QwenClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
流式对话补全(SSE格式)
|
||||
|
||||
Yields:
|
||||
str: 每次返回一个文本片段
|
||||
|
||||
使用示例:
|
||||
for chunk in client.stream_chat(messages):
|
||||
print(chunk, end="", flush=True)
|
||||
"""
|
||||
"""Stream chat for the Qwen Client instance."""
|
||||
try:
|
||||
# OpenAI兼容格式的请求体,启用流式输出
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens or self.config.max_tokens,
|
||||
"temperature": temperature or self.config.temperature,
|
||||
"top_p": kwargs.get("top_p", self.config.top_p),
|
||||
"stream": True # 启用流式输出
|
||||
"stream": True # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
}
|
||||
|
||||
# 使用stream模式发送请求
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
with self._client.stream("POST", "/chat/completions", json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
if line:
|
||||
line = line.strip()
|
||||
# SSE格式: data: {...}
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:] # 移除 "data: " 前缀
|
||||
data_str = line[6:] # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
continue # 跳过空的choices
|
||||
continue # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
@@ -179,41 +162,27 @@ class QwenClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
异步流式对话补全(用于FastAPI SSE响应)
|
||||
|
||||
Yields:
|
||||
str: 每次返回一个文本片段
|
||||
"""
|
||||
"""Handle async stream chat for the Qwen Client instance."""
|
||||
import asyncio
|
||||
|
||||
# 使用同步流式方法,包装为异步
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
for chunk in self.stream_chat(messages, max_tokens, temperature, **kwargs):
|
||||
yield chunk
|
||||
# 给async循环一个小延迟,让其他任务有机会执行
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Qwen Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
"""Release the resources held by this component."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
|
||||
class QwenVLClient(BaseLLMClient):
|
||||
"""
|
||||
Qwen VL多模态客户端(OpenAI兼容格式)
|
||||
|
||||
支持模型:
|
||||
- qwen-vl-plus
|
||||
- qwen-vl-max
|
||||
- qwen3-vl-plus
|
||||
- qwen2-vl-7b-instruct
|
||||
- qwen2-vl-72b-instruct
|
||||
"""
|
||||
"""Represent the Qwen V L Client type."""
|
||||
|
||||
SUPPORTED_MODELS = [
|
||||
"qwen-vl-plus",
|
||||
@@ -224,13 +193,14 @@ class QwenVLClient(BaseLLMClient):
|
||||
]
|
||||
|
||||
def __init__(self, config: LLMConfig):
|
||||
"""Initialize the Qwen V L Client instance."""
|
||||
if config.provider != LLMProvider.QWEN_VL:
|
||||
raise ValueError(f"配置provider应为QWEN_VL,实际为{config.provider}")
|
||||
super().__init__(config)
|
||||
self._init_client()
|
||||
|
||||
def _init_client(self):
|
||||
"""初始化HTTP客户端"""
|
||||
"""Handle init client for this module for the Qwen V L Client instance."""
|
||||
self._client = httpx.Client(
|
||||
base_url=self.config.base_url,
|
||||
headers={
|
||||
@@ -248,21 +218,11 @@ class QwenVLClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
"""多模态对话补全(OpenAI兼容格式)
|
||||
|
||||
支持图片输入,消息格式:
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
|
||||
{"type": "text", "text": "描述这张图片"}
|
||||
]
|
||||
}
|
||||
"""
|
||||
"""Handle chat for the Qwen V L Client instance."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# OpenAI兼容格式的请求体
|
||||
# Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": messages,
|
||||
@@ -312,7 +272,7 @@ class QwenVLClient(BaseLLMClient):
|
||||
temperature: Optional[float] = None,
|
||||
**kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
"""流式多模态对话补全"""
|
||||
"""Stream chat for the Qwen V L Client instance."""
|
||||
try:
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
@@ -335,7 +295,7 @@ class QwenVLClient(BaseLLMClient):
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if not choices:
|
||||
continue # 跳过空的choices
|
||||
continue # Keep provider-specific behavior explicit so debugging stays straightforward.
|
||||
delta = choices[0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
@@ -348,11 +308,11 @@ class QwenVLClient(BaseLLMClient):
|
||||
yield f"[ERROR: {str(e)}]"
|
||||
|
||||
def get_available_models(self) -> List[str]:
|
||||
"""获取可用模型列表"""
|
||||
"""Return available models for the Qwen V L Client instance."""
|
||||
return self.SUPPORTED_MODELS
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
"""Release the resources held by this component."""
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
@@ -363,7 +323,7 @@ def create_qwen_client(
|
||||
base_url: str = "http://6.86.80.4:30080/v1",
|
||||
**kwargs
|
||||
) -> QwenClient:
|
||||
"""便捷函数:创建Qwen客户端"""
|
||||
"""Create qwen client."""
|
||||
config = LLMConfig(
|
||||
provider=LLMProvider.QWEN,
|
||||
model=model,
|
||||
@@ -380,7 +340,7 @@ def create_qwen_vl_client(
|
||||
base_url: str = "http://6.86.80.4:30080/v1",
|
||||
**kwargs
|
||||
) -> QwenVLClient:
|
||||
"""便捷函数:创建QwenVL客户端"""
|
||||
"""Create qwen vl client."""
|
||||
config = LLMConfig(
|
||||
provider=LLMProvider.QWEN_VL,
|
||||
model=model,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""
|
||||
Mock数据服务 - 提供预设假数据供前后端对接测试
|
||||
"""
|
||||
"""Provide service-layer logic for mock data."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any
|
||||
import uuid
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
# 预设法规文档列表
|
||||
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_DOCUMENTS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"id": "doc-001",
|
||||
@@ -45,7 +45,7 @@ MOCK_DOCUMENTS: List[Dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# 预设快捷问题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
|
||||
{"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"},
|
||||
{"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"},
|
||||
@@ -53,7 +53,7 @@ MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
|
||||
{"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"},
|
||||
]
|
||||
|
||||
# 预设检索结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
|
||||
{
|
||||
"id": "chunk-001",
|
||||
@@ -97,7 +97,7 @@ MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# 预设RAG问答答案模板(按关键词匹配)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
|
||||
"电动自行车": {
|
||||
"text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。",
|
||||
@@ -133,7 +133,7 @@ MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# 预设合规分析结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
|
||||
"task_id": "task-001",
|
||||
"dashboard": {
|
||||
@@ -310,7 +310,7 @@ MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
|
||||
],
|
||||
}
|
||||
|
||||
# 预设合规对话响应模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
|
||||
"车身结构设计": {
|
||||
"compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开,需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。",
|
||||
@@ -329,7 +329,7 @@ MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# 预设系统统计数据
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_SYSTEM_STATS: Dict[str, int] = {
|
||||
"docs": 5,
|
||||
"chunks": 510,
|
||||
@@ -337,7 +337,7 @@ MOCK_SYSTEM_STATS: Dict[str, int] = {
|
||||
"segments": 0,
|
||||
}
|
||||
|
||||
# 预设系统配置
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
|
||||
"llm": {
|
||||
"model": "qwen-max",
|
||||
@@ -358,17 +358,17 @@ MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
|
||||
|
||||
|
||||
def get_mock_documents() -> List[Dict[str, Any]]:
|
||||
"""获取预设法规文档列表"""
|
||||
"""Return mock documents."""
|
||||
return MOCK_DOCUMENTS
|
||||
|
||||
|
||||
def get_mock_quick_questions() -> List[Dict[str, str]]:
|
||||
"""获取预设快捷问题"""
|
||||
"""Return mock quick questions."""
|
||||
return MOCK_QUICK_QUESTIONS
|
||||
|
||||
|
||||
def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""根据查询关键词返回预设检索结果"""
|
||||
"""Return mock retrieval."""
|
||||
results = []
|
||||
for keyword, data in MOCK_RAG_ANSWERS.items():
|
||||
if keyword in query:
|
||||
@@ -389,7 +389,7 @@ def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def get_mock_rag_answer(query: str) -> str:
|
||||
"""根据查询关键词返回预设答案"""
|
||||
"""Return mock rag answer."""
|
||||
for keyword, data in MOCK_RAG_ANSWERS.items():
|
||||
if keyword in query:
|
||||
return data["text"]
|
||||
@@ -397,14 +397,14 @@ def get_mock_rag_answer(query: str) -> str:
|
||||
|
||||
|
||||
def get_mock_compliance_result(task_id: str) -> Dict[str, Any]:
|
||||
"""获取预设合规分析结果"""
|
||||
"""Return mock compliance result."""
|
||||
result = MOCK_COMPLIANCE_RESULT.copy()
|
||||
result["task_id"] = task_id
|
||||
return result
|
||||
|
||||
|
||||
def get_mock_compliance_chat_response(intent: str, query: str) -> str:
|
||||
"""获取预设合规对话响应"""
|
||||
"""Return mock compliance chat response."""
|
||||
responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {})
|
||||
if "合规" in query or "符合" in query:
|
||||
return responses.get("compliance", "根据相关法规分析,该段落的合规性需进一步评估。")
|
||||
@@ -416,10 +416,10 @@ def get_mock_compliance_chat_response(intent: str, query: str) -> str:
|
||||
|
||||
|
||||
def generate_task_id() -> str:
|
||||
"""生成任务ID"""
|
||||
"""Handle generate task id."""
|
||||
return f"task-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def generate_doc_id() -> str:
|
||||
"""生成文档ID"""
|
||||
"""Handle generate doc id."""
|
||||
return f"doc-{uuid.uuid4().hex[:8]}"
|
||||
@@ -1,6 +1,8 @@
|
||||
"""文档解析服务"""
|
||||
"""Initialize the app.services.parser package."""
|
||||
|
||||
from .pdf_parser import PDFParser
|
||||
from .docx_parser import DocxParser
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
|
||||
__all__ = ["PDFParser", "DocxParser"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Word文档解析 - 使用python-docx"""
|
||||
"""Provide service-layer logic for docx parser."""
|
||||
|
||||
from docx import Document
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
||||
@@ -6,27 +6,29 @@ from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import re
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocxParagraph:
|
||||
"""段落内容"""
|
||||
"""Represent the Docx Paragraph type."""
|
||||
text: str
|
||||
level: int = 0 # 标题级别,0表示正文
|
||||
level: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
is_list: bool = False
|
||||
list_number: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocxTable:
|
||||
"""表格内容"""
|
||||
"""Represent the Docx Table type."""
|
||||
rows: List[List[str]]
|
||||
markdown: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocxDocumentContent:
|
||||
"""Word文档完整内容"""
|
||||
"""Represent the Docx Document Content type."""
|
||||
file_path: str
|
||||
paragraphs: List[DocxParagraph]
|
||||
tables: List[DocxTable]
|
||||
@@ -35,21 +37,14 @@ class DocxDocumentContent:
|
||||
|
||||
|
||||
class DocxParser:
|
||||
"""Word文档解析器 - 基于python-docx"""
|
||||
"""Provide the Docx Parser parser."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Docx Parser instance."""
|
||||
self.document = None
|
||||
|
||||
def parse(self, file_path: str) -> DocxDocumentContent:
|
||||
"""
|
||||
解析Word文档
|
||||
|
||||
Args:
|
||||
file_path: Word文档路径
|
||||
|
||||
Returns:
|
||||
DocxDocumentContent: 解析后的文档内容
|
||||
"""
|
||||
"""Handle parse for the Docx Parser instance."""
|
||||
logger.info(f"开始解析Word文档: {file_path}")
|
||||
|
||||
try:
|
||||
@@ -60,16 +55,16 @@ class DocxParser:
|
||||
tables=[]
|
||||
)
|
||||
|
||||
# 提取文档元数据
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.metadata = self._extract_metadata()
|
||||
|
||||
# 提取段落
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.paragraphs = self._extract_paragraphs()
|
||||
|
||||
# 提取表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.tables = self._extract_tables()
|
||||
|
||||
# 生成Markdown格式文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.markdown_text = self._generate_markdown(doc_content)
|
||||
|
||||
logger.success(f"Word文档解析完成,共{len(doc_content.paragraphs)}个段落")
|
||||
@@ -81,7 +76,7 @@ class DocxParser:
|
||||
raise
|
||||
|
||||
def _extract_metadata(self) -> Dict[str, str]:
|
||||
"""提取文档元数据"""
|
||||
"""Handle extract metadata for this module for the Docx Parser instance."""
|
||||
metadata = {}
|
||||
try:
|
||||
core_props = self.document.core_properties
|
||||
@@ -98,7 +93,7 @@ class DocxParser:
|
||||
return metadata
|
||||
|
||||
def _extract_paragraphs(self) -> List[DocxParagraph]:
|
||||
"""提取所有段落"""
|
||||
"""Handle extract paragraphs for this module for the Docx Parser instance."""
|
||||
paragraphs = []
|
||||
|
||||
for para in self.document.paragraphs:
|
||||
@@ -106,10 +101,10 @@ class DocxParser:
|
||||
if not text:
|
||||
continue
|
||||
|
||||
# 判断标题级别
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
level = self._get_paragraph_level(para)
|
||||
|
||||
# 判断是否是列表项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
is_list, list_number = self._detect_list_item(para)
|
||||
|
||||
paragraph = DocxParagraph(
|
||||
@@ -123,66 +118,61 @@ class DocxParser:
|
||||
return paragraphs
|
||||
|
||||
def _get_paragraph_level(self, para) -> int:
|
||||
"""
|
||||
判断段落标题级别
|
||||
|
||||
Returns:
|
||||
int: 标题级别,0表示正文
|
||||
"""
|
||||
# 方法1:检查段落样式
|
||||
"""Handle get paragraph level for this module for the Docx Parser instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
style_name = para.style.name if para.style else ""
|
||||
|
||||
if "Heading" in style_name or "标题" in style_name:
|
||||
# 从样式名称中提取级别
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
|
||||
if match:
|
||||
level = int(match.group(1) or match.group(2))
|
||||
return level
|
||||
|
||||
# 方法2:检查段落格式(字号)
|
||||
# 标题通常字号较大
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if para.paragraph_format:
|
||||
# 可以根据字号判断,这里简化处理
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
pass
|
||||
|
||||
# 方法3:根据内容模式判断(法规文档特征)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
text = para.text.strip()
|
||||
|
||||
# 第一章、第X章 -> 二级标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^第[一二三四五六七八九十百]+章\s', text):
|
||||
return 2
|
||||
# 第X节 -> 三级标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+节\s', text):
|
||||
return 3
|
||||
# 第X条 -> 四级标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+条\s', text):
|
||||
return 4
|
||||
|
||||
return 0 # 正文
|
||||
return 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
|
||||
"""检测是否是列表项"""
|
||||
"""Handle detect list item for this module for the Docx Parser instance."""
|
||||
text = para.text.strip()
|
||||
|
||||
# 数字列表:1.、2.、(1)、[1]等
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^[\d]+[.、)\]]\s', text):
|
||||
match = re.match(r'^([\d]+[.、)\]])\s', text)
|
||||
return True, match.group(1) if match else None
|
||||
|
||||
# 中文数字列表:一、二、(一)等
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text):
|
||||
match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text)
|
||||
return True, match.group(1) if match else None
|
||||
|
||||
# 检查段落格式中的列表编号
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'):
|
||||
# 有缩进的可能是列表项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
pass
|
||||
|
||||
return False, None
|
||||
|
||||
def _extract_tables(self) -> List[DocxTable]:
|
||||
"""提取所有表格"""
|
||||
"""Handle extract tables for this module for the Docx Parser instance."""
|
||||
tables = []
|
||||
|
||||
for table in self.document.tables:
|
||||
@@ -193,7 +183,7 @@ class DocxParser:
|
||||
cells.append(cell.text.strip())
|
||||
rows.append(cells)
|
||||
|
||||
# 转换为Markdown表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
markdown = self._table_to_markdown(rows)
|
||||
|
||||
table_content = DocxTable(rows=rows, markdown=markdown)
|
||||
@@ -202,34 +192,34 @@ class DocxParser:
|
||||
return tables
|
||||
|
||||
def _table_to_markdown(self, rows: List[List[str]]) -> str:
|
||||
"""将表格转换为Markdown格式"""
|
||||
"""Handle table to markdown for this module for the Docx Parser instance."""
|
||||
if not rows or len(rows) < 1:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
# 表头
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(rows) >= 1:
|
||||
header = rows[0]
|
||||
lines.append("| " + " | ".join(cell for cell in header) + " |")
|
||||
lines.append("| " + " | ".join("---" for _ in header) + " |")
|
||||
|
||||
# 数据行
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for row in rows[1:]:
|
||||
lines.append("| " + " | ".join(cell for cell in row) + " |")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
|
||||
"""生成Markdown格式文本"""
|
||||
"""Handle generate markdown for this module for the Docx Parser instance."""
|
||||
lines = []
|
||||
|
||||
# 文档标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
title = doc_content.metadata.get("title", "")
|
||||
if title:
|
||||
lines.append(f"# {title}\n")
|
||||
else:
|
||||
# 从第一个段落获取标题(如果是标题样式)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for para in doc_content.paragraphs[:5]:
|
||||
if para.level == 1:
|
||||
lines.append(f"# {para.text}\n")
|
||||
@@ -237,29 +227,29 @@ class DocxParser:
|
||||
else:
|
||||
lines.append(f"# {doc_content.file_path}\n")
|
||||
|
||||
# 元数据信息
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 文档信息\n")
|
||||
for key, value in doc_content.metadata.items():
|
||||
if value:
|
||||
lines.append(f"- **{key}**: {value}")
|
||||
|
||||
# 正文内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 正文\n")
|
||||
|
||||
table_index = 0
|
||||
for para in doc_content.paragraphs:
|
||||
if para.level > 0:
|
||||
# 标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
prefix = "#" * para.level
|
||||
lines.append(f"\n{prefix} {para.text}\n")
|
||||
elif para.is_list:
|
||||
# 列表项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append(f"- {para.text}")
|
||||
else:
|
||||
# 正文
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append(para.text)
|
||||
|
||||
# 添加表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if doc_content.tables:
|
||||
lines.append("\n## 表格\n")
|
||||
for i, table in enumerate(doc_content.tables):
|
||||
@@ -269,18 +259,18 @@ class DocxParser:
|
||||
return "\n".join(lines)
|
||||
|
||||
def parse_to_markdown(self, file_path: str) -> str:
|
||||
"""直接解析并返回Markdown文本"""
|
||||
"""Parse to markdown for the Docx Parser instance."""
|
||||
doc_content = self.parse(file_path)
|
||||
return doc_content.markdown_text
|
||||
|
||||
|
||||
def parse_docx(file_path: str) -> DocxDocumentContent:
|
||||
"""便捷函数:解析Word文档"""
|
||||
"""Parse docx."""
|
||||
parser = DocxParser()
|
||||
return parser.parse(file_path)
|
||||
|
||||
|
||||
def parse_docx_to_markdown(file_path: str) -> str:
|
||||
"""便捷函数:解析Word并返回Markdown"""
|
||||
"""Parse docx to markdown."""
|
||||
parser = DocxParser()
|
||||
return parser.parse_to_markdown(file_path)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""MinerU多模态PDF解析 - 版面感知解析"""
|
||||
"""Provide service-layer logic for mineru parser."""
|
||||
|
||||
from typing import Optional, Dict
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
import os
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinerUResult:
|
||||
"""MinerU解析结果"""
|
||||
"""Represent the Miner U Result type."""
|
||||
file_path: str
|
||||
markdown_text: str
|
||||
metadata: Dict[str, str] = field(default_factory=dict)
|
||||
@@ -17,21 +19,14 @@ class MinerUResult:
|
||||
|
||||
|
||||
class MinerUParser:
|
||||
"""
|
||||
MinerU多模态PDF解析器
|
||||
|
||||
MinerU (magic-pdf) 是一个开源的高质量PDF解析工具,
|
||||
支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素,
|
||||
并输出结构化的Markdown格式。
|
||||
|
||||
GitHub: https://github.com/opendatalab/MinerU
|
||||
"""
|
||||
"""Provide the Miner U Parser parser."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Miner U Parser instance."""
|
||||
self.available = self._check_mineru_available()
|
||||
|
||||
def _check_mineru_available(self) -> bool:
|
||||
"""检查MinerU是否可用"""
|
||||
"""Handle check mineru available for this module for the Miner U Parser instance."""
|
||||
try:
|
||||
from magic_pdf.pipe.UNIPipe import UNIPipe
|
||||
return True
|
||||
@@ -40,16 +35,7 @@ class MinerUParser:
|
||||
return False
|
||||
|
||||
def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
|
||||
"""
|
||||
使用MinerU解析PDF文档
|
||||
|
||||
Args:
|
||||
file_path: PDF文件路径
|
||||
output_dir: 输出目录(可选,用于保存解析产物)
|
||||
|
||||
Returns:
|
||||
MinerUResult: 解析结果
|
||||
"""
|
||||
"""Handle parse for the Miner U Parser instance."""
|
||||
logger.info(f"尝试使用MinerU解析: {file_path}")
|
||||
|
||||
if not self.available:
|
||||
@@ -64,19 +50,19 @@ class MinerUParser:
|
||||
from magic_pdf.pipe.UNIPipe import UNIPipe
|
||||
from magic_pdf.libs.MakeContentConfig import DropMode
|
||||
|
||||
# 设置输出目录
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if output_dir is None:
|
||||
output_dir = os.path.dirname(file_path)
|
||||
|
||||
# 创建解析管道
|
||||
# OCR模式可以根据PDF类型选择
|
||||
# auto: 自动判断是否需要OCR
|
||||
# txt: 纯文本PDF(无OCR)
|
||||
# ocr: 扫描件PDF(OCR)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
pipe = UNIPipe(file_path, output_dir)
|
||||
|
||||
# 执行解析
|
||||
# pipe_mk() 返回Markdown格式文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
markdown_content = pipe.pipe_mk()
|
||||
|
||||
logger.success(f"MinerU解析成功")
|
||||
@@ -98,13 +84,13 @@ class MinerUParser:
|
||||
)
|
||||
|
||||
def _extract_metadata(self, pipe) -> Dict[str, str]:
|
||||
"""从解析管道提取元数据"""
|
||||
"""Handle extract metadata for this module for the Miner U Parser instance."""
|
||||
metadata = {}
|
||||
try:
|
||||
# MinerU解析管道中可能包含的元数据信息
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data:
|
||||
mid_data = pipe.pdf_mid_data
|
||||
# 提取可能的元数据字段
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
metadata = {
|
||||
"page_count": str(mid_data.get("page_count", "")),
|
||||
"language": str(mid_data.get("language", "")),
|
||||
@@ -116,41 +102,27 @@ class MinerUParser:
|
||||
return metadata
|
||||
|
||||
def parse_to_markdown(self, file_path: str) -> str:
|
||||
"""直接解析并返回Markdown文本"""
|
||||
"""Parse to markdown for the Miner U Parser instance."""
|
||||
result = self.parse(file_path)
|
||||
return result.markdown_text if result.success else ""
|
||||
|
||||
|
||||
class ParserOrchestrator:
|
||||
"""
|
||||
解析服务编排 - 按优先级选择解析器
|
||||
|
||||
解析策略:
|
||||
1. 优先尝试MinerU(版面感知能力强)
|
||||
2. MinerU失败时回退到基础PyMuPDF解析
|
||||
"""
|
||||
"""Represent the Parser Orchestrator type."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Parser Orchestrator instance."""
|
||||
from .pdf_parser import PDFParser
|
||||
self.mineru_parser = MinerUParser()
|
||||
self.pdf_parser = PDFParser()
|
||||
self.mineru_available = self.mineru_parser.available
|
||||
|
||||
def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
|
||||
"""
|
||||
解析PDF文档,按优先级选择解析器
|
||||
|
||||
Args:
|
||||
file_path: PDF文件路径
|
||||
prefer_mineru: 是否优先使用MinerU
|
||||
|
||||
Returns:
|
||||
str: Markdown格式文本
|
||||
"""
|
||||
"""Parse pdf for the Parser Orchestrator instance."""
|
||||
markdown_text = ""
|
||||
|
||||
if prefer_mineru and self.mineru_available:
|
||||
# 优先尝试MinerU
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
result = self.mineru_parser.parse(file_path)
|
||||
if result.success:
|
||||
markdown_text = result.markdown_text
|
||||
@@ -159,28 +131,20 @@ class ParserOrchestrator:
|
||||
else:
|
||||
logger.warning(f"MinerU解析失败,回退到PyMuPDF: {result.error_message}")
|
||||
|
||||
# 回退到PyMuPDF基础解析
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
logger.info("使用PyMuPDF基础解析")
|
||||
markdown_text = self.pdf_parser.parse_to_markdown(file_path)
|
||||
|
||||
return markdown_text
|
||||
|
||||
def parse_docx(self, file_path: str) -> str:
|
||||
"""解析Word文档"""
|
||||
"""Parse docx for the Parser Orchestrator instance."""
|
||||
from .docx_parser import DocxParser
|
||||
docx_parser = DocxParser()
|
||||
return docx_parser.parse_to_markdown(file_path)
|
||||
|
||||
def parse(self, file_path: str) -> str:
|
||||
"""
|
||||
根据文件类型选择解析器
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
str: Markdown格式文本
|
||||
"""
|
||||
"""Handle parse for the Parser Orchestrator instance."""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
if ext == ".pdf":
|
||||
@@ -192,12 +156,12 @@ class ParserOrchestrator:
|
||||
|
||||
|
||||
def parse_with_mineru(file_path: str) -> MinerUResult:
|
||||
"""便捷函数:使用MinerU解析"""
|
||||
"""Parse with mineru."""
|
||||
parser = MinerUParser()
|
||||
return parser.parse(file_path)
|
||||
|
||||
|
||||
def parse_pdf_smart(file_path: str) -> str:
|
||||
"""便捷函数:智能解析PDF(自动选择最佳解析器)"""
|
||||
"""Parse pdf smart."""
|
||||
orchestrator = ParserOrchestrator()
|
||||
return orchestrator.parse_pdf(file_path)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""PDF文档解析 - 使用PyMuPDF基础解析"""
|
||||
"""Provide service-layer logic for pdf parser."""
|
||||
|
||||
import fitz # PyMuPDF
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
@@ -9,17 +9,17 @@ import re
|
||||
|
||||
@dataclass
|
||||
class PDFPageContent:
|
||||
"""PDF页面内容"""
|
||||
"""Represent the P D F Page Content type."""
|
||||
page_number: int
|
||||
text: str
|
||||
tables: List[str] = field(default_factory=list)
|
||||
images: List[str] = field(default_factory=list) # 图片路径列表
|
||||
images: List[str] = field(default_factory=list) # Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
blocks: List[Dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PDFDocumentContent:
|
||||
"""PDF文档完整内容"""
|
||||
"""Represent the P D F Document Content type."""
|
||||
file_path: str
|
||||
total_pages: int
|
||||
pages: List[PDFPageContent]
|
||||
@@ -28,23 +28,14 @@ class PDFDocumentContent:
|
||||
|
||||
|
||||
class PDFParser:
|
||||
"""PDF文档解析器 - 基于PyMuPDF"""
|
||||
"""Provide the P D F Parser parser."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the P D F Parser instance."""
|
||||
self.pdf = None
|
||||
|
||||
def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
|
||||
"""
|
||||
解析PDF文档
|
||||
|
||||
Args:
|
||||
file_path: PDF文件路径
|
||||
extract_tables: 是否提取表格
|
||||
extract_images: 是否提取图片
|
||||
|
||||
Returns:
|
||||
PDFDocumentContent: 解析后的文档内容
|
||||
"""
|
||||
"""Handle parse for the P D F Parser instance."""
|
||||
logger.info(f"开始解析PDF文档: {file_path}")
|
||||
|
||||
try:
|
||||
@@ -55,16 +46,16 @@ class PDFParser:
|
||||
pages=[]
|
||||
)
|
||||
|
||||
# 提取文档元数据
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.metadata = self._extract_metadata()
|
||||
|
||||
# 逐页解析
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for page_num in range(self.pdf.page_count):
|
||||
page = self.pdf[page_num]
|
||||
page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images)
|
||||
doc_content.pages.append(page_content)
|
||||
|
||||
# 生成Markdown格式文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_content.markdown_text = self._generate_markdown(doc_content)
|
||||
|
||||
self.pdf.close()
|
||||
@@ -77,7 +68,7 @@ class PDFParser:
|
||||
raise
|
||||
|
||||
def _extract_metadata(self) -> Dict[str, str]:
|
||||
"""提取PDF元数据"""
|
||||
"""Handle extract metadata for this module for the P D F Parser instance."""
|
||||
metadata = {}
|
||||
try:
|
||||
meta = self.pdf.metadata
|
||||
@@ -97,23 +88,23 @@ class PDFParser:
|
||||
|
||||
def _parse_page(self, page: fitz.Page, page_num: int,
|
||||
extract_tables: bool, extract_images: bool) -> PDFPageContent:
|
||||
"""解析单页内容"""
|
||||
"""Handle parse page for this module for the P D F Parser instance."""
|
||||
page_content = PDFPageContent(page_number=page_num, text="")
|
||||
|
||||
# 提取文本块(保留结构)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
|
||||
page_content.blocks = blocks
|
||||
|
||||
# 提取纯文本
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE)
|
||||
page_content.text = text.strip()
|
||||
|
||||
# 提取表格(使用PyMuPDF的表格提取功能)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if extract_tables:
|
||||
tables = self._extract_tables_from_page(page)
|
||||
page_content.tables = tables
|
||||
|
||||
# 提取图片
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if extract_images:
|
||||
images = self._extract_images_from_page(page, page_num)
|
||||
page_content.images = images
|
||||
@@ -121,25 +112,22 @@ class PDFParser:
|
||||
return page_content
|
||||
|
||||
def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
|
||||
"""
|
||||
从页面提取表格(基于文本块分析)
|
||||
注意:PyMuPDF基础版表格提取能力有限,复杂表格建议使用MinerU
|
||||
"""
|
||||
"""Handle extract tables from page for this module for the P D F Parser instance."""
|
||||
tables = []
|
||||
|
||||
try:
|
||||
# 使用PyMuPDF的表格提取方法(2.4+版本)
|
||||
# 对于更复杂的表格,需要在mineru_parser中使用更高级的方法
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
tabs = page.find_tables()
|
||||
if tabs:
|
||||
for tab in tabs:
|
||||
table_text = tab.extract()
|
||||
# 将表格转换为Markdown格式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
markdown_table = self._table_to_markdown(table_text)
|
||||
tables.append(markdown_table)
|
||||
|
||||
except AttributeError:
|
||||
# 旧版本PyMuPDF没有表格提取功能
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
logger.warning("PyMuPDF版本不支持表格提取,请升级到2.4+版本")
|
||||
except Exception as e:
|
||||
logger.warning(f"表格提取失败: {e}")
|
||||
@@ -147,28 +135,28 @@ class PDFParser:
|
||||
return tables
|
||||
|
||||
def _table_to_markdown(self, table_data: List[List[str]]) -> str:
|
||||
"""将表格数据转换为Markdown格式"""
|
||||
"""Handle table to markdown for this module for the P D F Parser instance."""
|
||||
if not table_data or len(table_data) < 1:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
# 表头
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if len(table_data) >= 1:
|
||||
header = table_data[0]
|
||||
lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |")
|
||||
lines.append("| " + " | ".join("---" for _ in header) + " |")
|
||||
|
||||
# 数据行
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for row in table_data[1:]:
|
||||
lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
|
||||
"""提取页面图片"""
|
||||
"""Handle extract images from page for this module for the P D F Parser instance."""
|
||||
images = []
|
||||
# 图片提取功能(可选实现)
|
||||
# 这里仅记录图片信息,实际图片需要额外保存
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
try:
|
||||
image_list = page.get_images()
|
||||
for img_index, img in enumerate(image_list):
|
||||
@@ -179,52 +167,52 @@ class PDFParser:
|
||||
return images
|
||||
|
||||
def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
|
||||
"""生成Markdown格式文本"""
|
||||
"""Handle generate markdown for this module for the P D F Parser instance."""
|
||||
lines = []
|
||||
|
||||
# 文档标题
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
title = doc_content.metadata.get("title", "")
|
||||
if title:
|
||||
lines.append(f"# {title}\n")
|
||||
else:
|
||||
lines.append(f"# {doc_content.file_path}\n")
|
||||
|
||||
# 元数据信息
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 文档信息\n")
|
||||
for key, value in doc_content.metadata.items():
|
||||
if value and key in ["author", "subject", "keywords", "creation_date"]:
|
||||
lines.append(f"- **{key}**: {value}")
|
||||
|
||||
# 正文内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append("\n## 正文\n")
|
||||
|
||||
for page in doc_content.pages:
|
||||
# 页码标记
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
lines.append(f"\n---\n**第 {page.page_number} 页**\n")
|
||||
|
||||
# 处理文本内容,识别标题结构
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
text = self._process_page_text(page.text, page.blocks)
|
||||
lines.append(text)
|
||||
|
||||
# 添加表格
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
for table in page.tables:
|
||||
lines.append("\n" + table + "\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
|
||||
"""处理页面文本,识别标题结构"""
|
||||
# 基于字体大小识别标题
|
||||
"""Handle process page text for this module for the P D F Parser instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
processed_text = text
|
||||
|
||||
# 尝试识别标题(基于字号)
|
||||
# 法规文档通常有明确的层级结构:章、节、条
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
processed_text = self._detect_headers(text, blocks)
|
||||
|
||||
return processed_text
|
||||
|
||||
def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
|
||||
"""检测并标记标题(基于字号或内容模式)"""
|
||||
"""Handle detect headers for this module for the P D F Parser instance."""
|
||||
lines = text.split("\n")
|
||||
processed_lines = []
|
||||
|
||||
@@ -233,8 +221,8 @@ class PDFParser:
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# 法规标题模式检测
|
||||
# 第一章、第X章、第X节、第X条等
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if re.match(r'^第[一二三四五六七八九十百]+章\s', line):
|
||||
processed_lines.append(f"\n## {line}\n")
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+节\s', line):
|
||||
@@ -242,7 +230,7 @@ class PDFParser:
|
||||
elif re.match(r'^第[一二三四五六七八九十百]+条\s', line):
|
||||
processed_lines.append(f"\n#### {line}\n")
|
||||
elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line):
|
||||
# 条款子项
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
processed_lines.append(f"- {line}")
|
||||
else:
|
||||
processed_lines.append(line)
|
||||
@@ -250,18 +238,18 @@ class PDFParser:
|
||||
return "\n".join(processed_lines)
|
||||
|
||||
def parse_to_markdown(self, file_path: str) -> str:
|
||||
"""直接解析并返回Markdown文本"""
|
||||
"""Parse to markdown for the P D F Parser instance."""
|
||||
doc_content = self.parse(file_path)
|
||||
return doc_content.markdown_text
|
||||
|
||||
|
||||
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
|
||||
"""便捷函数:解析PDF文档"""
|
||||
"""Parse pdf."""
|
||||
parser = PDFParser()
|
||||
return parser.parse(file_path, **kwargs)
|
||||
|
||||
|
||||
def parse_pdf_to_markdown(file_path: str) -> str:
|
||||
"""便捷函数:解析PDF并返回Markdown"""
|
||||
"""Parse pdf to markdown."""
|
||||
parser = PDFParser()
|
||||
return parser.parse_to_markdown(file_path)
|
||||
|
||||
@@ -1,11 +1,29 @@
|
||||
"""RAG服务模块"""
|
||||
"""Initialize the app.services.rag package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
from .retriever import Retriever, retrieve_regulations
|
||||
from .context_builder import ContextBuilder, build_rag_context
|
||||
from .prompt_templates import PromptTemplates, get_prompt_template
|
||||
|
||||
__all__ = [
|
||||
"Retriever", "retrieve_regulations",
|
||||
"ContextBuilder", "build_rag_context",
|
||||
"PromptTemplates", "get_prompt_template"
|
||||
"Retriever",
|
||||
"retrieve_regulations",
|
||||
"ContextBuilder",
|
||||
"build_rag_context",
|
||||
"PromptTemplates",
|
||||
"get_prompt_template",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name in {"Retriever", "retrieve_regulations"}:
|
||||
from .retriever import Retriever, retrieve_regulations
|
||||
|
||||
return {"Retriever": Retriever, "retrieve_regulations": retrieve_regulations}[name]
|
||||
if name in {"ContextBuilder", "build_rag_context"}:
|
||||
from .context_builder import ContextBuilder, build_rag_context
|
||||
|
||||
return {"ContextBuilder": ContextBuilder, "build_rag_context": build_rag_context}[name]
|
||||
if name in {"PromptTemplates", "get_prompt_template"}:
|
||||
from .prompt_templates import PromptTemplates, get_prompt_template
|
||||
|
||||
return {"PromptTemplates": PromptTemplates, "get_prompt_template": get_prompt_template}[name]
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""RAG上下文构建服务 - 构建LLM输入上下文"""
|
||||
"""Provide service-layer logic for context builder."""
|
||||
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
@@ -6,11 +6,13 @@ from loguru import logger
|
||||
|
||||
from .retriever import RetrievedDocument
|
||||
from app.config.settings import settings
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAGContext:
|
||||
"""RAG构建的上下文"""
|
||||
"""Represent the R A G Context type."""
|
||||
system_prompt: str
|
||||
context_text: str
|
||||
user_query: str
|
||||
@@ -20,14 +22,7 @@ class RAGContext:
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""
|
||||
RAG上下文构建器
|
||||
|
||||
功能:
|
||||
- 格式化检索结果为上下文文本
|
||||
- 控制上下文长度(token限制)
|
||||
- 构建完整的LLM输入格式
|
||||
"""
|
||||
"""Provide the Context Builder builder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,14 +30,7 @@ class ContextBuilder:
|
||||
include_metadata: bool = True,
|
||||
citation_format: str = "【条款{clause}】"
|
||||
):
|
||||
"""
|
||||
初始化上下文构建器
|
||||
|
||||
Args:
|
||||
max_context_tokens: 最大上下文token数
|
||||
include_metadata: 是否包含元数据(文档名、条款号等)
|
||||
citation_format: 引用格式模板
|
||||
"""
|
||||
"""Initialize the Context Builder instance."""
|
||||
self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
|
||||
self.include_metadata = include_metadata
|
||||
self.citation_format = citation_format
|
||||
@@ -56,30 +44,19 @@ class ContextBuilder:
|
||||
system_prompt: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None
|
||||
) -> RAGContext:
|
||||
"""
|
||||
构建RAG上下文
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
documents: 检索到的文档列表
|
||||
system_prompt: 系统提示词(可选)
|
||||
max_tokens: 最大token数(可选,覆盖默认值)
|
||||
|
||||
Returns:
|
||||
RAGContext: 构建的上下文对象
|
||||
"""
|
||||
"""Handle build for the Context Builder instance."""
|
||||
max_tokens = max_tokens or self.max_context_tokens
|
||||
|
||||
# 格式化文档内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
context_text, sources, truncated = self._format_documents(
|
||||
documents,
|
||||
max_tokens
|
||||
)
|
||||
|
||||
# 构建系统提示词
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
system_prompt = system_prompt or self._default_system_prompt()
|
||||
|
||||
# 估算总token数
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
total_tokens = self._estimate_tokens(system_prompt + context_text + query)
|
||||
|
||||
logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
|
||||
@@ -98,29 +75,20 @@ class ContextBuilder:
|
||||
documents: List[RetrievedDocument],
|
||||
max_tokens: int
|
||||
) -> tuple:
|
||||
"""
|
||||
格式化文档内容
|
||||
|
||||
Args:
|
||||
documents: 文档列表
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
(context_text, sources, truncated)
|
||||
"""
|
||||
"""Handle format documents for this module for the Context Builder instance."""
|
||||
context_parts = []
|
||||
sources = []
|
||||
current_tokens = 0
|
||||
truncated = False
|
||||
|
||||
for i, doc in enumerate(documents):
|
||||
# 格式化单个文档
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
formatted = self._format_single_doc(doc, i + 1)
|
||||
|
||||
# 估算token数
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
doc_tokens = self._estimate_tokens(formatted)
|
||||
|
||||
# 检查是否超出限制
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if current_tokens + doc_tokens > max_tokens:
|
||||
truncated = True
|
||||
logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
|
||||
@@ -129,7 +97,7 @@ class ContextBuilder:
|
||||
context_parts.append(formatted)
|
||||
current_tokens += doc_tokens
|
||||
|
||||
# 记录来源
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
sources.append({
|
||||
"index": i + 1,
|
||||
"doc_id": doc.doc_id,
|
||||
@@ -148,13 +116,13 @@ class ContextBuilder:
|
||||
doc: RetrievedDocument,
|
||||
index: int
|
||||
) -> str:
|
||||
"""格式化单个文档"""
|
||||
"""Handle format single doc for this module for the Context Builder instance."""
|
||||
parts = []
|
||||
|
||||
# 索引编号
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
parts.append(f"[{index}]")
|
||||
|
||||
# 元数据(可选)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if self.include_metadata:
|
||||
meta_parts = []
|
||||
|
||||
@@ -171,13 +139,13 @@ class ContextBuilder:
|
||||
if meta_parts:
|
||||
parts.append(" | ".join(meta_parts))
|
||||
|
||||
# 内容
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
parts.append(doc.content)
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _default_system_prompt(self) -> str:
|
||||
"""默认系统提示词"""
|
||||
"""Handle default system prompt for this module for the Context Builder instance."""
|
||||
return """你是合规专家助手,基于提供的法规条款回答问题。
|
||||
|
||||
回答要求:
|
||||
@@ -192,8 +160,8 @@ class ContextBuilder:
|
||||
- 最后给出合规建议"""
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""估算文本token数"""
|
||||
# 中文字符约1.5 token,英文约0.25 token
|
||||
"""Handle estimate tokens for this module for the Context Builder instance."""
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
chinese_chars = sum(1 for c in text if '一' <= c <= '鿿')
|
||||
other_chars = len(text) - chinese_chars
|
||||
return int(chinese_chars * 1.5 + other_chars * 0.25)
|
||||
@@ -202,15 +170,7 @@ class ContextBuilder:
|
||||
self,
|
||||
context: RAGContext
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
构建LLM消息格式
|
||||
|
||||
Args:
|
||||
context: RAG上下文对象
|
||||
|
||||
Returns:
|
||||
List[Dict]: [{"role": "system/user/assistant", "content": "..."}]
|
||||
"""
|
||||
"""Build messages for the Context Builder instance."""
|
||||
messages = [
|
||||
{"role": "system", "content": context.system_prompt},
|
||||
{"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
|
||||
@@ -224,6 +184,6 @@ def build_rag_context(
|
||||
documents: List[RetrievedDocument],
|
||||
**kwargs
|
||||
) -> RAGContext:
|
||||
"""便捷函数:构建RAG上下文"""
|
||||
"""Build rag context."""
|
||||
builder = ContextBuilder()
|
||||
return builder.build(query, documents, **kwargs)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""RAG Prompt模板 - 合规问答专用Prompt"""
|
||||
"""Provide service-layer logic for prompt templates."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptTemplate:
|
||||
"""Prompt模板"""
|
||||
"""Represent the Prompt Template type."""
|
||||
name: str
|
||||
system_prompt: str
|
||||
user_template: str
|
||||
@@ -14,18 +16,9 @@ class PromptTemplate:
|
||||
|
||||
|
||||
class PromptTemplates:
|
||||
"""
|
||||
合规问答Prompt模板库
|
||||
"""Represent the Prompt Templates type."""
|
||||
|
||||
包含多种场景的Prompt模板:
|
||||
- 合规问答(标准)
|
||||
- 条款解读(详细解释)
|
||||
- 合规检查(判断合规状态)
|
||||
- 差异对比(新旧法规对比)
|
||||
- 报告生成(合规报告)
|
||||
"""
|
||||
|
||||
# 合规问答标准模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
COMPLIANCE_QA = PromptTemplate(
|
||||
name="compliance_qa",
|
||||
system_prompt="""你是合规专家助手,专门解答法规合规问题。
|
||||
@@ -63,7 +56,7 @@ class PromptTemplates:
|
||||
description="标准合规问答模板"
|
||||
)
|
||||
|
||||
# 条款解读模板(详细解释)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
CLAUSE_INTERPRETATION = PromptTemplate(
|
||||
name="clause_interpretation",
|
||||
system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
|
||||
@@ -96,7 +89,7 @@ class PromptTemplates:
|
||||
description="条款详细解读模板"
|
||||
)
|
||||
|
||||
# 合规检查模板(判断合规状态)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
COMPLIANCE_CHECK = PromptTemplate(
|
||||
name="compliance_check",
|
||||
system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
|
||||
@@ -140,7 +133,7 @@ class PromptTemplates:
|
||||
description="合规检查评估模板"
|
||||
)
|
||||
|
||||
# 差异对比模板(新旧法规对比)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
COMPARISON = PromptTemplate(
|
||||
name="comparison",
|
||||
system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
|
||||
@@ -192,7 +185,7 @@ class PromptTemplates:
|
||||
description="法规版本对比模板"
|
||||
)
|
||||
|
||||
# 报告生成模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
REPORT_GENERATION = PromptTemplate(
|
||||
name="report_generation",
|
||||
system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
|
||||
@@ -222,7 +215,7 @@ class PromptTemplates:
|
||||
description="合规报告生成模板"
|
||||
)
|
||||
|
||||
# 文档摘要生成模板
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
DOCUMENT_SUMMARY = PromptTemplate(
|
||||
name="document_summary",
|
||||
system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
|
||||
@@ -263,7 +256,7 @@ class PromptTemplates:
|
||||
|
||||
@classmethod
|
||||
def get_template(cls, name: str) -> Optional[PromptTemplate]:
|
||||
"""获取指定模板"""
|
||||
"""Return template for the Prompt Templates instance."""
|
||||
templates = {
|
||||
"compliance_qa": cls.COMPLIANCE_QA,
|
||||
"clause_interpretation": cls.CLAUSE_INTERPRETATION,
|
||||
@@ -276,7 +269,7 @@ class PromptTemplates:
|
||||
|
||||
@classmethod
|
||||
def list_templates(cls) -> Dict[str, str]:
|
||||
"""列出所有模板"""
|
||||
"""List templates for the Prompt Templates instance."""
|
||||
return {
|
||||
"compliance_qa": cls.COMPLIANCE_QA.description,
|
||||
"clause_interpretation": cls.CLAUSE_INTERPRETATION.description,
|
||||
@@ -288,7 +281,7 @@ class PromptTemplates:
|
||||
|
||||
|
||||
def get_prompt_template(name: str) -> PromptTemplate:
|
||||
"""便捷函数:获取Prompt模板"""
|
||||
"""Return prompt template."""
|
||||
template = PromptTemplates.get_template(name)
|
||||
if not template:
|
||||
raise ValueError(f"不存在的模板: {name}")
|
||||
|
||||
@@ -1,192 +1,82 @@
|
||||
"""RAG检索服务 - 封装Milvus检索"""
|
||||
"""Provide service-layer logic for retriever."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Dict, Optional, Any
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.shared.bootstrap import get_retrieval_service
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
|
||||
from app.services.storage.milvus_client import MilvusClient, SearchResult
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedDocument:
|
||||
"""检索到的文档"""
|
||||
"""Represent the Retrieved Document type."""
|
||||
content: str
|
||||
doc_id: str # 文档ID,用于下载
|
||||
doc_id: str
|
||||
doc_name: str
|
||||
section_title: str
|
||||
clause_number: str
|
||||
page_number: int
|
||||
score: float
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class Retriever:
|
||||
"""
|
||||
RAG检索器
|
||||
|
||||
功能:
|
||||
- 向量检索(Dense + Sparse混合)
|
||||
- 重排序(可选)
|
||||
- 过滤和筛选
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
top_k: int = None,
|
||||
rerank: bool = False,
|
||||
min_score: float = 0.3
|
||||
):
|
||||
"""
|
||||
初始化检索器
|
||||
|
||||
Args:
|
||||
top_k: 检索召回数量
|
||||
rerank: 是否启用重排序
|
||||
min_score: 最低相关性分数阈值
|
||||
"""
|
||||
self.top_k = top_k or settings.rag_top_k
|
||||
"""Provide the Retriever retriever."""
|
||||
def __init__(self, top_k: int = 5, rerank: bool = False, min_score: float = 0.0):
|
||||
"""Initialize the Retriever instance."""
|
||||
self.top_k = top_k
|
||||
self.rerank = rerank
|
||||
self.min_score = min_score
|
||||
|
||||
# 嵌入模型(延迟加载)
|
||||
self.embedder: Optional[BGEM3Embedder] = None
|
||||
|
||||
# Milvus客户端(延迟连接)
|
||||
self.milvus: Optional[MilvusClient] = None
|
||||
|
||||
logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")
|
||||
|
||||
def _init_embedder(self):
|
||||
"""延迟初始化嵌入模型"""
|
||||
if self.embedder is None:
|
||||
logger.info("加载嵌入模型...")
|
||||
self.embedder = BGEM3Embedder(model_name=settings.embedding_model)
|
||||
|
||||
def _init_milvus(self):
|
||||
"""延迟初始化Milvus"""
|
||||
if self.milvus is None:
|
||||
logger.info("连接Milvus...")
|
||||
self.milvus = MilvusClient()
|
||||
self.milvus.connect()
|
||||
self.milvus.create_collection(recreate=False)
|
||||
self.milvus.load_collection()
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None,
|
||||
top_k: Optional[int] = None
|
||||
) -> List[RetrievedDocument]:
|
||||
"""
|
||||
检索相关文档
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
filters: 过滤条件(如 "regulation_type=='车辆安全'")
|
||||
top_k: 返回数量(可选,覆盖默认值)
|
||||
|
||||
Returns:
|
||||
List[RetrievedDocument]: 检索结果列表
|
||||
"""
|
||||
logger.info(f"执行检索: {query}")
|
||||
|
||||
# 初始化组件
|
||||
self._init_embedder()
|
||||
self._init_milvus()
|
||||
|
||||
# 生成查询向量
|
||||
query_embedding = self.embedder.embed_single(query)
|
||||
|
||||
# 执行混合检索
|
||||
results = self.milvus.hybrid_search(
|
||||
query_dense=query_embedding['dense'].tolist(),
|
||||
query_sparse=query_embedding['sparse'],
|
||||
top_k=top_k or self.top_k,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
# 转换为RetrievedDocument格式
|
||||
documents = []
|
||||
for r in results:
|
||||
if r.score >= self.min_score:
|
||||
doc = RetrievedDocument(
|
||||
content=r.content,
|
||||
doc_id=r.metadata.get("doc_id", ""),
|
||||
doc_name=r.metadata.get("doc_name", ""),
|
||||
section_title=r.metadata.get("section_title", ""),
|
||||
clause_number=r.metadata.get("clause_number", ""),
|
||||
page_number=r.metadata.get("page_number", 0),
|
||||
score=r.score,
|
||||
metadata=r.metadata
|
||||
)
|
||||
documents.append(doc)
|
||||
|
||||
logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
|
||||
return documents
|
||||
|
||||
def retrieve_with_scores(
|
||||
self,
|
||||
query: str,
|
||||
filters: Optional[str] = None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
检索并返回完整结果(包含分数)
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
filters: 过滤条件
|
||||
|
||||
Returns:
|
||||
List[Dict]: 包含分数的检索结果
|
||||
"""
|
||||
documents = self.retrieve(query, filters)
|
||||
def retrieve(self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None) -> list[RetrievedDocument]:
|
||||
"""Handle retrieve for the Retriever instance."""
|
||||
results = get_retrieval_service().retrieve(query=query, top_k=top_k or self.top_k, filters=filters)
|
||||
return [
|
||||
{
|
||||
"content": doc.content,
|
||||
"doc_id": doc.doc_id,
|
||||
"doc_name": doc.doc_name,
|
||||
"section_title": doc.section_title,
|
||||
"clause_number": doc.clause_number,
|
||||
"page_number": doc.page_number,
|
||||
"score": doc.score
|
||||
}
|
||||
for doc in documents
|
||||
RetrievedDocument(
|
||||
content=item.content,
|
||||
doc_id=item.doc_id,
|
||||
doc_name=item.doc_name,
|
||||
section_title=item.section_title,
|
||||
clause_number=item.metadata.get("clause_number", ""),
|
||||
page_number=item.page_number,
|
||||
score=item.score,
|
||||
metadata=item.metadata,
|
||||
)
|
||||
for item in results
|
||||
if item.score >= self.min_score
|
||||
]
|
||||
|
||||
def search_by_doc_name(
|
||||
self,
|
||||
query: str,
|
||||
doc_name: str
|
||||
) -> List[RetrievedDocument]:
|
||||
"""按文档名称过滤检索"""
|
||||
filters = f'doc_name=="{doc_name}"'
|
||||
return self.retrieve(query, filters)
|
||||
def retrieve_with_scores(self, query: str, filters: Optional[str] = None) -> list[dict]:
|
||||
"""Handle retrieve with scores for the Retriever instance."""
|
||||
return [
|
||||
{
|
||||
"content": item.content,
|
||||
"doc_id": item.doc_id,
|
||||
"doc_name": item.doc_name,
|
||||
"section_title": item.section_title,
|
||||
"clause_number": item.clause_number,
|
||||
"page_number": item.page_number,
|
||||
"score": item.score,
|
||||
}
|
||||
for item in self.retrieve(query, filters)
|
||||
]
|
||||
|
||||
def search_by_regulation_type(
|
||||
self,
|
||||
query: str,
|
||||
regulation_type: str
|
||||
) -> List[RetrievedDocument]:
|
||||
"""按法规类型过滤检索"""
|
||||
filters = f'regulation_type=="{regulation_type}"'
|
||||
return self.retrieve(query, filters)
|
||||
def search_by_doc_name(self, query: str, doc_name: str) -> list[RetrievedDocument]:
|
||||
"""Search by doc name for the Retriever instance."""
|
||||
return self.retrieve(query, filters=f'doc_name == "{doc_name}"')
|
||||
|
||||
def search_by_regulation_type(self, query: str, regulation_type: str) -> list[RetrievedDocument]:
|
||||
"""Search by regulation type for the Retriever instance."""
|
||||
return self.retrieve(query, filters=f'regulation_type == "{regulation_type}"')
|
||||
|
||||
def close(self):
|
||||
"""关闭连接"""
|
||||
if self.milvus:
|
||||
self.milvus.disconnect()
|
||||
logger.info("检索器已关闭")
|
||||
"""Release the resources held by this component."""
|
||||
return None
|
||||
|
||||
|
||||
def retrieve_regulations(
|
||||
query: str,
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[RetrievedDocument]:
|
||||
"""便捷函数:检索法规"""
|
||||
retriever = Retriever(top_k=top_k)
|
||||
results = retriever.retrieve(query, filters)
|
||||
retriever.close()
|
||||
return results
|
||||
def retrieve_regulations(query: str, top_k: int = 10, filters: Optional[str] = None) -> list[RetrievedDocument]:
|
||||
"""Handle retrieve regulations."""
|
||||
return Retriever(top_k=top_k).retrieve(query, filters)
|
||||
|
||||
@@ -1,6 +1,18 @@
|
||||
"""存储服务"""
|
||||
"""Initialize the app.services.storage package."""
|
||||
# Keep package boundaries explicit so backend imports stay predictable.
|
||||
|
||||
from .milvus_client import MilvusClient
|
||||
from .minio_client import MinIOClient
|
||||
|
||||
__all__ = ["MilvusClient", "MinIOClient"]
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Handle getattr for this module."""
|
||||
if name == "MilvusClient":
|
||||
from .milvus_client import MilvusClient
|
||||
|
||||
return MilvusClient
|
||||
if name == "MinIOClient":
|
||||
from .minio_client import MinIOClient
|
||||
|
||||
return MinIOClient
|
||||
raise AttributeError(name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Milvus向量数据库客户端 - 存储与检索服务"""
|
||||
"""Provide service-layer logic for milvus client."""
|
||||
|
||||
from pymilvus import (
|
||||
connections,
|
||||
@@ -17,11 +17,13 @@ import numpy as np
|
||||
from ..embedding.text_chunker import TextChunk
|
||||
from ..embedding.bge_m3_embedder import EmbeddingResult
|
||||
from app.config.settings import settings
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""检索结果"""
|
||||
"""Represent the Search Result type."""
|
||||
id: int
|
||||
content: str
|
||||
score: float
|
||||
@@ -30,7 +32,7 @@ class SearchResult:
|
||||
|
||||
@dataclass
|
||||
class MilvusDocument:
|
||||
"""Milvus文档数据结构"""
|
||||
"""Represent the Milvus Document type."""
|
||||
doc_id: str
|
||||
chunk_id: str
|
||||
content: str
|
||||
@@ -46,7 +48,7 @@ class MilvusDocument:
|
||||
|
||||
|
||||
class MilvusClient:
|
||||
"""Milvus向量数据库客户端"""
|
||||
"""Represent the Milvus Client type."""
|
||||
|
||||
COLLECTION_NAME = "regulations"
|
||||
|
||||
@@ -73,6 +75,7 @@ class MilvusClient:
|
||||
collection_name: str = None,
|
||||
db_name: str = None
|
||||
):
|
||||
"""Initialize the Milvus Client instance."""
|
||||
self.host = host or settings.milvus_host
|
||||
self.port = port or settings.milvus_port
|
||||
self.collection_name = collection_name or settings.milvus_collection
|
||||
@@ -84,7 +87,7 @@ class MilvusClient:
|
||||
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""连接到Milvus服务器"""
|
||||
"""Handle connect for the Milvus Client instance."""
|
||||
try:
|
||||
connections.connect(
|
||||
alias="default",
|
||||
@@ -101,7 +104,7 @@ class MilvusClient:
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
"""Handle disconnect for the Milvus Client instance."""
|
||||
try:
|
||||
connections.disconnect("default")
|
||||
self.connected = False
|
||||
@@ -110,7 +113,7 @@ class MilvusClient:
|
||||
logger.warning(f"断开连接时出错: {e}")
|
||||
|
||||
def create_collection(self, recreate: bool = False) -> bool:
|
||||
"""创建Collection"""
|
||||
"""Create collection for the Milvus Client instance."""
|
||||
if not self.connected:
|
||||
logger.warning("未连接到Milvus,请先调用connect()")
|
||||
return False
|
||||
@@ -146,7 +149,7 @@ class MilvusClient:
|
||||
return False
|
||||
|
||||
def _create_indexes(self):
|
||||
"""创建向量索引"""
|
||||
"""Handle create indexes for this module for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return
|
||||
|
||||
@@ -177,13 +180,13 @@ class MilvusClient:
|
||||
logger.warning(f"创建索引时出错: {e}")
|
||||
|
||||
def load_collection(self):
|
||||
"""加载Collection到内存"""
|
||||
"""Load collection for the Milvus Client instance."""
|
||||
if self.collection:
|
||||
self.collection.load()
|
||||
logger.info(f"Collection已加载: {self.collection_name}")
|
||||
|
||||
def release_collection(self):
|
||||
"""释放Collection内存"""
|
||||
"""Handle release collection for the Milvus Client instance."""
|
||||
if self.collection:
|
||||
self.collection.release()
|
||||
logger.info(f"Collection已释放: {self.collection_name}")
|
||||
@@ -193,7 +196,7 @@ class MilvusClient:
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""插入文档分块和嵌入向量"""
|
||||
"""Handle insert chunks for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
logger.warning("Collection未初始化")
|
||||
return []
|
||||
@@ -246,7 +249,7 @@ class MilvusClient:
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""混合检索:Dense + Sparse"""
|
||||
"""Handle hybrid search for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
logger.warning("Collection未初始化")
|
||||
return []
|
||||
@@ -254,10 +257,10 @@ class MilvusClient:
|
||||
try:
|
||||
self.collection.load()
|
||||
|
||||
# 使用简单的Dense检索(兼容所有版本)
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
dense_results = self.dense_search(query_dense, top_k, filters)
|
||||
|
||||
# 可选:合并Sparse结果
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
if query_sparse:
|
||||
sparse_results = self.sparse_search(query_sparse, top_k, filters)
|
||||
merged = self._merge_results(dense_results, sparse_results, top_k)
|
||||
@@ -277,7 +280,7 @@ class MilvusClient:
|
||||
top_k: int,
|
||||
dense_weight: float = 0.6
|
||||
) -> List[SearchResult]:
|
||||
"""手动融合Dense和Sparse结果"""
|
||||
"""Handle merge results for this module for the Milvus Client instance."""
|
||||
sparse_weight = 1 - dense_weight
|
||||
merged_dict = {}
|
||||
|
||||
@@ -318,7 +321,7 @@ class MilvusClient:
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""纯Dense向量检索"""
|
||||
"""Handle dense search for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return []
|
||||
|
||||
@@ -375,7 +378,7 @@ class MilvusClient:
|
||||
top_k: int = 10,
|
||||
filters: Optional[str] = None
|
||||
) -> List[SearchResult]:
|
||||
"""纯Sparse向量检索"""
|
||||
"""Handle sparse search for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return []
|
||||
|
||||
@@ -427,7 +430,7 @@ class MilvusClient:
|
||||
return []
|
||||
|
||||
def delete_by_doc_id(self, doc_id: str) -> int:
|
||||
"""根据doc_id删除记录"""
|
||||
"""Delete by doc id for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return 0
|
||||
|
||||
@@ -441,7 +444,7 @@ class MilvusClient:
|
||||
return 0
|
||||
|
||||
def get_collection_stats(self) -> Dict[str, Any]:
|
||||
"""获取Collection统计信息"""
|
||||
"""Return collection stats for the Milvus Client instance."""
|
||||
if not self.collection:
|
||||
return {}
|
||||
|
||||
@@ -458,7 +461,7 @@ class MilvusClient:
|
||||
|
||||
|
||||
def create_milvus_client() -> MilvusClient:
|
||||
"""便捷函数:创建Milvus客户端"""
|
||||
"""Create milvus client."""
|
||||
client = MilvusClient()
|
||||
client.connect()
|
||||
client.create_collection(recreate=False)
|
||||
@@ -470,7 +473,7 @@ def insert_documents(
|
||||
chunks: List[TextChunk],
|
||||
embeddings: EmbeddingResult
|
||||
) -> List[int]:
|
||||
"""便捷函数:插入文档"""
|
||||
"""Handle insert documents."""
|
||||
return client.insert_chunks(chunks, embeddings)
|
||||
|
||||
|
||||
@@ -480,5 +483,5 @@ def search_regulations(
|
||||
query_sparse: Dict[int, float],
|
||||
top_k: int = 10
|
||||
) -> List[SearchResult]:
|
||||
"""便捷函数:检索法规"""
|
||||
"""Search regulations."""
|
||||
return client.hybrid_search(query_dense, query_sparse, top_k)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""MinIO对象存储客户端 - 文档文件存储"""
|
||||
"""Provide service-layer logic for minio client."""
|
||||
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
@@ -8,10 +8,12 @@ from io import BytesIO
|
||||
import os
|
||||
|
||||
from app.config.settings import settings
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
|
||||
|
||||
|
||||
class MinIOClient:
|
||||
"""MinIO对象存储客户端"""
|
||||
"""Represent the Min I O Client type."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -21,16 +23,7 @@ class MinIOClient:
|
||||
bucket: str = None,
|
||||
secure: bool = None
|
||||
):
|
||||
"""
|
||||
初始化MinIO客户端
|
||||
|
||||
Args:
|
||||
endpoint: MinIO服务地址
|
||||
access_key: 访问密钥
|
||||
secret_key: 秘密密钥
|
||||
bucket: 存储桶名称
|
||||
secure: 是否使用HTTPS
|
||||
"""
|
||||
"""Initialize the Min I O Client instance."""
|
||||
self.endpoint = endpoint or settings.minio_endpoint
|
||||
self.access_key = access_key or settings.minio_access_key
|
||||
self.secret_key = secret_key or settings.minio_secret_key
|
||||
@@ -43,7 +36,7 @@ class MinIOClient:
|
||||
logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""连接MinIO服务"""
|
||||
"""Handle connect for the Min I O Client instance."""
|
||||
try:
|
||||
self.client = Minio(
|
||||
self.endpoint,
|
||||
@@ -60,7 +53,7 @@ class MinIOClient:
|
||||
return False
|
||||
|
||||
def ensure_bucket(self) -> bool:
|
||||
"""确保存储桶存在"""
|
||||
"""Handle ensure bucket for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
logger.warning("未连接MinIO,请先调用connect()")
|
||||
return False
|
||||
@@ -82,17 +75,7 @@ class MinIOClient:
|
||||
object_name: str,
|
||||
metadata: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上传本地文件到MinIO
|
||||
|
||||
Args:
|
||||
file_path: 本地文件路径
|
||||
object_name: MinIO对象名称
|
||||
metadata: 元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Handle upload file for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
self.ensure_bucket()
|
||||
@@ -125,18 +108,7 @@ class MinIOClient:
|
||||
content_type: str = "application/octet-stream",
|
||||
metadata: Dict[str, Any] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上传字节数据到MinIO
|
||||
|
||||
Args:
|
||||
data: 文件字节数据
|
||||
object_name: MinIO对象名称
|
||||
content_type: 内容类型
|
||||
metadata: 元数据(注意:MinIO仅支持US-ASCII字符)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Handle upload bytes for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
self.ensure_bucket()
|
||||
@@ -144,18 +116,18 @@ class MinIOClient:
|
||||
try:
|
||||
data_stream = BytesIO(data)
|
||||
|
||||
# 处理metadata:仅保留ASCII安全字符
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
safe_metadata = None
|
||||
if metadata:
|
||||
safe_metadata = {}
|
||||
for key, value in metadata.items():
|
||||
if isinstance(value, str):
|
||||
# 只保留ASCII字符或转换为安全格式
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
try:
|
||||
value.encode('ascii')
|
||||
safe_metadata[key] = value
|
||||
except UnicodeEncodeError:
|
||||
# 中文字符跳过或用占位符
|
||||
# Keep service responsibilities explicit so downstream behavior stays predictable.
|
||||
safe_metadata[key] = ""
|
||||
else:
|
||||
safe_metadata[key] = str(value)
|
||||
@@ -181,16 +153,7 @@ class MinIOClient:
|
||||
object_name: str,
|
||||
file_path: str
|
||||
) -> bool:
|
||||
"""
|
||||
从MinIO下载文件到本地
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
file_path: 本地保存路径
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Handle download file for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -212,16 +175,7 @@ class MinIOClient:
|
||||
object_name: str,
|
||||
expires: int = 3600
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取对象下载URL(临时URL)
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
expires: URL有效期(秒)
|
||||
|
||||
Returns:
|
||||
str: 下载URL
|
||||
"""
|
||||
"""Return object url for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -238,15 +192,7 @@ class MinIOClient:
|
||||
return None
|
||||
|
||||
def get_object_data(self, object_name: str) -> Optional[bytes]:
|
||||
"""
|
||||
获取对象数据(字节)
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bytes: 文件数据
|
||||
"""
|
||||
"""Return object data for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -262,15 +208,7 @@ class MinIOClient:
|
||||
return None
|
||||
|
||||
def delete_object(self, object_name: str) -> bool:
|
||||
"""
|
||||
删除对象
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功
|
||||
"""
|
||||
"""Delete object for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -284,15 +222,7 @@ class MinIOClient:
|
||||
return False
|
||||
|
||||
def list_objects(self, prefix: str = "") -> list:
|
||||
"""
|
||||
列出存储桶中的对象
|
||||
|
||||
Args:
|
||||
prefix: 对象名称前缀
|
||||
|
||||
Returns:
|
||||
list: 对象列表
|
||||
"""
|
||||
"""List objects for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -305,15 +235,7 @@ class MinIOClient:
|
||||
return []
|
||||
|
||||
def object_exists(self, object_name: str) -> bool:
|
||||
"""
|
||||
检查对象是否存在
|
||||
|
||||
Args:
|
||||
object_name: MinIO对象名称
|
||||
|
||||
Returns:
|
||||
bool: 是否存在
|
||||
"""
|
||||
"""Handle object exists for the Min I O Client instance."""
|
||||
if not self.connected:
|
||||
self.connect()
|
||||
|
||||
@@ -325,7 +247,7 @@ class MinIOClient:
|
||||
return False
|
||||
|
||||
def _get_content_type(self, file_path: str) -> str:
|
||||
"""根据文件扩展名获取Content-Type"""
|
||||
"""Handle get content type for this module for the Min I O Client instance."""
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
content_types = {
|
||||
'.pdf': 'application/pdf',
|
||||
@@ -338,13 +260,13 @@ class MinIOClient:
|
||||
return content_types.get(ext, 'application/octet-stream')
|
||||
|
||||
def close(self):
|
||||
"""关闭连接(MinIO客户端无需显式关闭)"""
|
||||
"""Release the resources held by this component."""
|
||||
self.connected = False
|
||||
logger.info("MinIO客户端已关闭")
|
||||
|
||||
|
||||
def create_minio_client() -> MinIOClient:
|
||||
"""便捷函数:创建MinIO客户端"""
|
||||
"""Create minio client."""
|
||||
client = MinIOClient()
|
||||
client.connect()
|
||||
client.ensure_bucket()
|
||||
|
||||
Reference in New Issue
Block a user