Fix SSE route dependency and align architecture docs

This commit is contained in:
ash66
2026-05-18 16:32:42 +08:00
parent 86b9ac806a
commit 3f69cad404
149 changed files with 4786 additions and 5957 deletions

View File

@@ -1,3 +1,5 @@
"""Backend service package."""
# Keep package boundaries explicit so backend imports stay predictable.
__all__: list[str] = []

View File

@@ -1,6 +1,8 @@
"""Agent服务模块"""
"""Initialize the app.services.agent package."""
from .qa_agent import QAAgent, ask_compliance_question
from .session_manager import SessionManager, ChatSession
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]

View File

@@ -1,21 +1,19 @@
"""RAG问答Agent - 合规智能问答核心实现"""
"""Provide service-layer logic for qa agent."""
from __future__ import annotations
import time
from typing import List, Dict, Optional, Any, Generator
from dataclasses import dataclass, field
from loguru import logger
from typing import Dict, Generator, List, Optional
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
from app.services.llm.llm_factory import LLMFactory
from app.services.rag.retriever import Retriever, RetrievedDocument
from app.services.rag.context_builder import ContextBuilder, RAGContext
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
from app.config.settings import settings
from app.shared.bootstrap import get_agent_conversation_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class AgentResponse:
"""Agent响应结果"""
"""Represent the Agent Response type."""
answer: str
sources: List[Dict] = field(default_factory=list)
model: str = ""
@@ -27,385 +25,73 @@ class AgentResponse:
@property
def is_success(self) -> bool:
"""Return whether success for the Agent Response instance."""
return self.error is None
@dataclass
class AgentConfig:
"""Agent配置"""
llm_provider: str = "deepseek"
llm_model: str = "deepseek-v4-flash"
top_k: int = 5
min_score: float = 0.3
max_context_tokens: int = 2000
temperature: float = 0.7
"""Define configuration for agent config."""
llm_provider: str = settings.llm_provider
llm_model: str = settings.llm_model
top_k: int = settings.rag_top_k
min_score: float = 0.0
max_context_tokens: int = settings.rag_max_context_tokens
temperature: float = settings.llm_temperature
prompt_template: str = "compliance_qa"
include_metadata: bool = True
class QAAgent:
"""
合规问答Agent
核心流程:
1. 接收用户问题
2. Milvus混合检索相关法规条款
3. 构建RAG上下文
4. 调用LLM生成回答
5. 返回答案和引用来源
使用示例:
agent = QAAgent()
response = agent.ask("机动车安全技术检验有哪些要求?")
print(response.answer)
for source in response.sources:
print(f"引用: {source['doc_name']} - {source['clause_number']}")
"""
"""Represent the Q A Agent type."""
def __init__(self, config: Optional[AgentConfig] = None):
"""
初始化问答Agent
Args:
config: Agent配置可选使用默认配置
"""
self.config = config or AgentConfig(
llm_provider=settings.llm_provider,
llm_model=settings.llm_model,
top_k=settings.rag_top_k,
max_context_tokens=settings.rag_max_context_tokens
)
# 初始化组件(延迟加载)
self.llm: Optional[BaseLLMClient] = None
self.retriever: Optional[Retriever] = None
self.context_builder: Optional[ContextBuilder] = None
logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")
def _init_llm(self):
"""延迟初始化LLM客户端优先使用全局缓存"""
if self.llm is None:
# 尝试先获取全局缓存的客户端
cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
if cached:
self.llm = cached
logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
else:
logger.info("创建新的LLM客户端...")
self.llm = get_llm_client(
provider=self.config.llm_provider,
model=self.config.llm_model,
temperature=self.config.temperature
)
def _init_retriever(self):
"""延迟初始化检索器"""
if self.retriever is None:
logger.info("初始化检索器...")
self.retriever = Retriever(
top_k=self.config.top_k,
min_score=self.config.min_score
)
def _init_context_builder(self):
"""延迟初始化上下文构建器"""
if self.context_builder is None:
logger.info("初始化上下文构建器...")
self.context_builder = ContextBuilder(
max_context_tokens=self.config.max_context_tokens,
include_metadata=self.config.include_metadata
)
"""Initialize the Q A Agent instance."""
self.config = config or AgentConfig()
def ask(
self,
query: str,
filters: Optional[str] = None,
prompt_template: Optional[str] = None
prompt_template: Optional[str] = None,
) -> AgentResponse:
"""
回答用户问题
Args:
query: 用户问题
filters: 检索过滤条件(如 "regulation_type=='车辆安全'"
prompt_template: Prompt模板名称可选覆盖默认配置
Returns:
AgentResponse: 包含答案和引用来源的响应对象
"""
start_time = time.time()
logger.info(f"收到问题: {query}")
try:
# Step 1: 检索相关法规
self._init_retriever()
documents = self.retriever.retrieve(query, filters)
retrieved_count = len(documents)
if retrieved_count == 0:
return AgentResponse(
answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
retrieved_count=0,
error="no_retrieved_documents"
)
# Step 2: 构建RAG上下文
self._init_context_builder()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
# Step 3: 构建LLM输入消息
messages = self._build_messages(template, context)
# Step 4: 调用LLM生成回答
self._init_llm()
llm_response = self.llm.chat(
messages=messages,
temperature=self.config.temperature
)
if not llm_response.is_success:
return AgentResponse(
answer="",
retrieved_count=retrieved_count,
error=llm_response.error
)
latency_ms = int((time.time() - start_time) * 1000)
# Step 5: 返回结果
logger.success(f"问答完成: {latency_ms}ms, {retrieved_count}条引用")
return AgentResponse(
answer=llm_response.content,
sources=context.sources,
model=llm_response.model,
latency_ms=latency_ms,
retrieved_count=retrieved_count,
context_tokens=context.total_tokens,
truncated=context.truncated
)
except Exception as e:
logger.error(f"问答失败: {e}")
return AgentResponse(
answer="",
error=str(e)
)
def ask_with_context(
self,
query: str,
documents: List[RetrievedDocument],
prompt_template: Optional[str] = None
) -> AgentResponse:
"""
使用提供的文档回答问题(不执行检索)
Args:
query: 用户问题
documents: 已检索的文档列表
prompt_template: Prompt模板名称
Returns:
AgentResponse: 响应结果
"""
start_time = time.time()
try:
self._init_context_builder()
self._init_llm()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
messages = self._build_messages(template, context)
llm_response = self.llm.chat(messages)
latency_ms = int((time.time() - start_time) * 1000)
return AgentResponse(
answer=llm_response.content,
sources=context.sources,
model=llm_response.model,
latency_ms=latency_ms,
retrieved_count=len(documents),
context_tokens=context.total_tokens,
truncated=context.truncated
)
except Exception as e:
logger.error(f"问答失败: {e}")
return AgentResponse(answer="", error=str(e))
def _build_messages(
self,
template: PromptTemplate,
context: RAGContext
) -> List[Dict[str, str]]:
"""构建LLM输入消息"""
user_content = template.user_template.format(
context=context.context_text,
query=context.user_query
"""Handle ask for the Q A Agent instance."""
_, result = get_agent_conversation_service().ask(
query=query,
filters=filters,
provider=self.config.llm_provider,
model=self.config.llm_model,
top_k=self.config.top_k,
prompt_template=prompt_template or self.config.prompt_template,
)
return AgentResponse(
answer=result.answer,
sources=[source.__dict__ for source in result.sources],
model=result.model,
latency_ms=result.latency_ms,
retrieved_count=result.retrieved_count,
context_tokens=result.context_tokens,
truncated=result.truncated,
error=result.error,
)
return [
{"role": "system", "content": template.system_prompt},
{"role": "user", "content": user_content}
]
def ask_stream(
self,
query: str,
filters: Optional[str] = None,
prompt_template: Optional[str] = None
) -> Generator[Dict[str, Any], None, None]:
"""
流式回答用户问题SSE模式
返回事件类型:
- {"event": "status", "data": "正在检索..."} - 状态更新
- {"event": "sources", "data": [...]} - 引用来源
- {"event": "content", "data": "文本片段"} - 回答内容
- {"event": "done", "data": {"latency_ms": ..., "model": ...}} - 完成
Args:
query: 用户问题
filters: 检索过滤条件
prompt_template: Prompt模板名称
Yields:
Dict: SSE事件数据
"""
start_time = time.time()
logger.info(f"收到流式问题: {query}")
try:
# Step 1: 检索相关法规
yield {"event": "status", "data": "正在检索相关法规..."}
self._init_retriever()
documents = self.retriever.retrieve(query, filters)
retrieved_count = len(documents)
if retrieved_count == 0:
yield {"event": "status", "data": "未找到相关法规"}
yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
yield {"event": "done", "data": {"latency_ms": 0, "retrieved_count": 0}}
return
# Step 2: 发送检索结果
yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
sources = [
{
"doc_name": doc.doc_name,
"doc_id": doc.doc_id,
"clause_number": doc.clause_number,
"score": doc.score
}
for doc in documents[:5] # 只返回前5条引用
]
yield {"event": "sources", "data": sources}
# Step 3: 构建RAG上下文
self._init_context_builder()
template_name = prompt_template or self.config.prompt_template
template = get_prompt_template(template_name)
context = self.context_builder.build(
query=query,
documents=documents,
system_prompt=template.system_prompt
)
# Step 4: 构建LLM输入消息
messages = self._build_messages(template, context)
# Step 5: 流式调用LLM生成回答
self._init_llm()
full_answer = ""
# 检查LLM是否支持流式输出
if hasattr(self.llm, 'stream_chat'):
yield {"event": "status", "data": "思考中..."}
for chunk in self.llm.stream_chat(
messages=messages,
temperature=self.config.temperature
):
full_answer += chunk
yield {"event": "content", "data": chunk}
else:
# 如果不支持流式,回退到普通调用
yield {"event": "status", "data": "生成回答中..."}
llm_response = self.llm.chat(
messages=messages,
temperature=self.config.temperature
)
if llm_response.is_success:
full_answer = llm_response.content
yield {"event": "content", "data": full_answer}
# Step 6: 发送完成事件
latency_ms = int((time.time() - start_time) * 1000)
logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
yield {
"event": "done",
"data": {
"latency_ms": latency_ms,
"model": self.config.llm_model,
"retrieved_count": retrieved_count,
"context_tokens": context.total_tokens
}
}
except Exception as e:
logger.error(f"流式问答失败: {e}")
yield {"event": "error", "data": str(e)}
def ask_stream(self, query: str, filters: Optional[str] = None) -> Generator[dict, None, None]:
"""Handle ask stream for the Q A Agent instance."""
_, stream = get_agent_conversation_service().stream_chat(
query=query,
filters=filters,
provider=self.config.llm_provider,
model=self.config.llm_model,
top_k=self.config.top_k,
prompt_template=self.config.prompt_template,
)
for event in stream:
yield event
def close(self):
"""关闭Agent资源不关闭LLM客户端因为它全局缓存"""
if self.retriever:
self.retriever.close()
logger.info("问答Agent已关闭")
"""Release the resources held by this component."""
return None
def ask_compliance_question(
query: str,
provider: str = "deepseek",
model: str = "deepseek-v4-flash",
top_k: int = 10
) -> AgentResponse:
"""
便捷函数:问答合规问题
Args:
query: 用户问题
provider: LLM提供商
model: LLM模型
top_k: 检索数量
Returns:
AgentResponse: 响应结果
"""
config = AgentConfig(
llm_provider=provider,
llm_model=model,
top_k=top_k
)
agent = QAAgent(config)
response = agent.ask(query)
agent.close()
return response
def ask_compliance_question(query: str, top_k: int = 5) -> AgentResponse:
"""Handle ask compliance question."""
return QAAgent(AgentConfig(top_k=top_k)).ask(query)

View File

@@ -1,4 +1,4 @@
"""多轮对话会话管理"""
"""Provide service-layer logic for session manager."""
import time
import uuid
@@ -9,7 +9,7 @@ from loguru import logger
@dataclass
class ChatMessage:
"""对话消息"""
"""Represent the Chat Message type."""
role: str # "user" / "assistant" / "system"
content: str
timestamp: int
@@ -19,7 +19,7 @@ class ChatMessage:
@dataclass
class ChatSession:
"""对话会话"""
"""Represent the Chat Session type."""
session_id: str
messages: List[ChatMessage] = field(default_factory=list)
created_at: int = field(default_factory=lambda: int(time.time()))
@@ -27,7 +27,7 @@ class ChatSession:
metadata: Dict = field(default_factory=dict)
def add_user_message(self, content: str) -> ChatMessage:
"""添加用户消息"""
"""Handle add user message for the Chat Session instance."""
message = ChatMessage(
role="user",
content=content,
@@ -42,7 +42,7 @@ class ChatSession:
content: str,
sources: List[Dict] = None
) -> ChatMessage:
"""添加助手消息"""
"""Handle add assistant message for the Chat Session instance."""
message = ChatMessage(
role="assistant",
content=content,
@@ -54,9 +54,9 @@ class ChatSession:
return message
def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
"""获取历史对话用于LLM上下文"""
"""Return history for the Chat Session instance."""
history = []
# 获取最近N轮对话每轮包含user + assistant
# Keep service responsibilities explicit so downstream behavior stays predictable.
recent_messages = self.messages[-(max_turns * 2):]
for msg in recent_messages:
@@ -68,81 +68,47 @@ class ChatSession:
return history
def clear_history(self):
"""清空对话历史"""
"""Handle clear history for the Chat Session instance."""
self.messages = []
self.updated_at = int(time.time())
logger.info(f"会话历史已清空: {self.session_id}")
@property
def message_count(self) -> int:
"""消息数量"""
"""Handle message count for the Chat Session instance."""
return len(self.messages)
@property
def is_empty(self) -> bool:
"""是否为空会话"""
"""Return whether empty for the Chat Session instance."""
return len(self.messages) == 0
class SessionManager:
"""
会话管理器
功能:
- 创建/获取/删除会话
- 会话超时清理
- 会话历史记录管理
使用示例:
manager = SessionManager()
# 创建会话
session = manager.create_session()
# 添加消息
session.add_user_message("什么是机动车安全技术检验?")
session.add_assistant_message("根据GB 7258...", sources=[...])
# 获取历史用于LLM多轮对话
history = session.get_history(max_turns=3)
"""
"""Represent the Session Manager type."""
def __init__(
self,
max_sessions: int = 100,
session_timeout_minutes: int = 30
):
"""
初始化会话管理器
Args:
max_sessions: 最大会话数量
session_timeout_minutes: 会话超时时间(分钟)
"""
"""Initialize the Session Manager instance."""
self.max_sessions = max_sessions
self.session_timeout = session_timeout_minutes * 60
# 会话存储(内存)
# Keep service responsibilities explicit so downstream behavior stays predictable.
self._sessions: Dict[str, ChatSession] = {}
logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")
def create_session(self, metadata: Dict = None) -> ChatSession:
"""
创建新会话
Args:
metadata: 会话元数据(可选)
Returns:
ChatSession: 新创建的会话
"""
# 检查会话数量限制
"""Create session for the Session Manager instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(self._sessions) >= self.max_sessions:
# 清理过期会话
# Keep service responsibilities explicit so downstream behavior stays predictable.
self._cleanup_expired_sessions()
# 如果仍然超出限制,删除最老的会话
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(self._sessions) >= self.max_sessions:
oldest_id = min(
self._sessions.keys(),
@@ -163,19 +129,11 @@ class SessionManager:
return session
def get_session(self, session_id: str) -> Optional[ChatSession]:
"""
获取会话
Args:
session_id: 会话ID
Returns:
ChatSession: 会话对象如不存在返回None
"""
"""Return session for the Session Manager instance."""
session = self._sessions.get(session_id)
if session:
# 检查是否过期
# Keep service responsibilities explicit so downstream behavior stays predictable.
if self._is_session_expired(session):
self.delete_session(session_id)
logger.info(f"会话已过期,已删除: {session_id}")
@@ -184,15 +142,7 @@ class SessionManager:
return session
def delete_session(self, session_id: str) -> bool:
"""
删除会话
Args:
session_id: 会话ID
Returns:
bool: 是否成功删除
"""
"""Delete session for the Session Manager instance."""
if session_id in self._sessions:
del self._sessions[session_id]
logger.info(f"删除会话: {session_id}")
@@ -200,12 +150,7 @@ class SessionManager:
return False
def list_sessions(self) -> List[Dict]:
"""
列出所有会话
Returns:
List[Dict]: 会话列表摘要
"""
"""List sessions for the Session Manager instance."""
return [
{
"session_id": s.session_id,
@@ -217,12 +162,12 @@ class SessionManager:
]
def _is_session_expired(self, session: ChatSession) -> bool:
"""检查会话是否过期"""
"""Handle is session expired for this module for the Session Manager instance."""
current_time = int(time.time())
return (current_time - session.updated_at) > self.session_timeout
def _cleanup_expired_sessions(self) -> int:
"""清理过期会话"""
"""Handle cleanup expired sessions for this module for the Session Manager instance."""
expired_ids = [
sid for sid, session in self._sessions.items()
if self._is_session_expired(session)
@@ -237,10 +182,10 @@ class SessionManager:
return len(expired_ids)
def get_session_count(self) -> int:
"""获取当前会话数量"""
"""Return session count for the Session Manager instance."""
return len(self._sessions)
def clear_all_sessions(self):
"""清空所有会话"""
"""Handle clear all sessions for the Session Manager instance."""
self._sessions.clear()
logger.info("所有会话已清空")

View File

@@ -1,24 +1,19 @@
"""文档处理主流程 - 解析→摘要→分块→嵌入→入库"""
"""Provide service-layer logic for document processor."""
from __future__ import annotations
import os
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
import uuid
from pathlib import Path
from typing import Optional
from app.shared.bootstrap import get_document_command_service, get_retrieval_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
from .parser.pdf_parser import PDFParser
from .parser.docx_parser import DocxParser
from .parser.mineru_parser import ParserOrchestrator
from .embedding.text_chunker import RegulationChunker, TextChunk
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
from .storage.milvus_client import MilvusClient
from .llm.document_summarizer import DocumentSummarizer, DocumentSummary
from app.config.settings import settings
@dataclass
class ProcessingResult:
"""文档处理结果"""
"""Represent the Processing Result type."""
doc_id: str
doc_name: str
success: bool
@@ -30,87 +25,10 @@ class ProcessingResult:
class DocumentProcessor:
"""
文档处理服务 - 完整处理流程
流程:
1. 文档解析PDF/DOCX → Markdown
2. 智能分块(章节级+条款级)
3. LLM摘要生成可选
4. 向量嵌入BGE-M3 Dense+Sparse
5. 存储入库Milvus向量数据库
"""
def __init__(
self,
chunk_size: int = None,
embedding_model: str = None,
use_mineru: bool = True,
generate_summary: bool = False, # 默认不生成摘要节省约60秒
llm_provider: str = None,
llm_model: str = None
):
"""
初始化文档处理器
Args:
chunk_size: 分块大小
embedding_model: 嵌入模型名称
use_mineru: 是否优先使用MinerU解析
generate_summary: 是否生成文档摘要默认False可节省约60秒处理时间
llm_provider: LLM提供商
llm_model: LLM模型名称
"""
self.chunk_size = chunk_size or settings.chunk_size
self.embedding_model = embedding_model or settings.embedding_model
self.use_mineru = use_mineru
"""Represent the Document Processor type."""
def __init__(self, *args, generate_summary: bool = False, **kwargs):
"""Initialize the Document Processor instance."""
self.generate_summary = generate_summary
self.llm_provider = llm_provider or settings.llm_provider
self.llm_model = llm_model or settings.llm_model
# 初始化各组件
logger.info("初始化文档处理组件...")
# 解析器
self.parser = ParserOrchestrator()
# 分块器
self.chunker = RegulationChunker(chunk_size=self.chunk_size)
# 嵌入模型(延迟加载)
self.embedder: Optional[BGEM3Embedder] = None
# Milvus客户端延迟连接
self.milvus: Optional[MilvusClient] = None
# 摘要生成器(延迟加载)
self.summarizer: Optional[DocumentSummarizer] = None
logger.success("文档处理器初始化完成")
def _init_embedder(self):
"""延迟初始化嵌入模型"""
if self.embedder is None:
logger.info("加载嵌入模型...")
self.embedder = BGEM3Embedder(model_name=self.embedding_model)
def _init_milvus(self):
"""延迟初始化Milvus连接"""
if self.milvus is None:
logger.info("连接Milvus...")
self.milvus = MilvusClient()
self.milvus.connect()
self.milvus.create_collection(recreate=False)
self.milvus.load_collection()
def _init_summarizer(self):
"""延迟初始化摘要生成器"""
if self.summarizer is None:
logger.info("初始化摘要生成器...")
self.summarizer = DocumentSummarizer(
provider=self.llm_provider,
model=self.llm_model
)
def process(
self,
@@ -118,286 +36,51 @@ class DocumentProcessor:
doc_id: Optional[str] = None,
doc_name: Optional[str] = None,
regulation_type: str = "",
version: str = ""
version: str = "",
) -> ProcessingResult:
"""
处理单个文档
"""Handle process for the Document Processor instance."""
path = Path(file_path)
content = path.read_bytes()
result = get_document_command_service().upload_and_process(
doc_id=doc_id,
file_name=path.name,
content=content,
content_type="application/octet-stream",
doc_name=doc_name or path.name,
regulation_type=regulation_type,
version=version,
generate_summary=self.generate_summary,
)
return ProcessingResult(
doc_id=result.doc_id,
doc_name=result.doc_name,
success=result.status != "failed",
num_chunks=result.num_chunks,
message=result.message,
summary=result.summary,
summary_latency_ms=result.summary_latency_ms,
)
Args:
file_path: 文档文件路径
doc_id: 文档ID可选默认自动生成
doc_name: 文档名称(可选,默认从文件名获取)
regulation_type: 法规类型
version: 文档版本
Returns:
ProcessingResult: 处理结果
"""
# 生成或使用传入的文档ID
if doc_id is None:
doc_id = str(uuid.uuid4())[:8]
# 获取文档名称
if doc_name is None:
doc_name = os.path.basename(file_path)
logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})")
# 初始化结果变量
summary = ""
summary_latency_ms = 0
try:
# 1. 文档解析
logger.info("Step 1: 文档解析")
markdown_text = self._parse_document(file_path)
if not markdown_text:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="文档解析失败,内容为空"
)
# 2. LLM摘要生成可选
if self.generate_summary:
logger.info("Step 2: LLM摘要生成")
self._init_summarizer()
summary_result = self.summarizer.summarize(
doc_name,
markdown_text,
regulation_type
)
if summary_result.is_success:
summary = summary_result.summary
summary_latency_ms = summary_result.latency_ms
logger.success(f"摘要生成完成: {summary_latency_ms}ms")
else:
logger.warning(f"摘要生成失败: {summary_result.error}")
else:
logger.info("Step 2: 跳过摘要生成未勾选节省约60秒")
# 3. 智能分块
logger.info("Step 3: 智能分块")
chunks = self._chunk_document(
markdown_text,
doc_id,
doc_name,
regulation_type,
version
)
if not chunks:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="分块失败,无有效内容",
markdown_text=markdown_text,
summary=summary
)
# 4. 向量嵌入
logger.info("Step 4: 向量嵌入")
embeddings = self._embed_chunks(chunks)
if embeddings is None:
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message="向量嵌入失败",
markdown_text=markdown_text,
summary=summary
)
# 5. 存储入库
logger.info("Step 5: 存储入库")
inserted_ids = self._insert_to_milvus(chunks, embeddings)
logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=True,
num_chunks=len(inserted_ids),
message="处理成功",
markdown_text=markdown_text,
summary=summary,
summary_latency_ms=summary_latency_ms
)
except Exception as e:
logger.error(f"文档处理失败: {e}")
return ProcessingResult(
doc_id=doc_id,
doc_name=doc_name,
success=False,
message=f"处理失败: {str(e)}"
)
def _parse_document(self, file_path: str) -> str:
"""解析文档"""
ext = os.path.splitext(file_path)[1].lower()
try:
if ext == ".pdf":
# PDF文档解析优先MinerU回退PyMuPDF
markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
elif ext in [".docx", ".doc"]:
# Word文档解析
markdown_text = self.parser.parse_docx(file_path)
else:
logger.warning(f"不支持的文件类型: {ext}")
return ""
logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
return markdown_text
except Exception as e:
logger.error(f"文档解析失败: {e}")
return ""
def _chunk_document(
self,
markdown_text: str,
doc_id: str,
doc_name: str,
regulation_type: str,
version: str
) -> List[TextChunk]:
"""分块文档"""
try:
chunks = self.chunker.chunk_document(
markdown_text,
doc_id=doc_id,
doc_name=doc_name,
regulation_type=regulation_type,
version=version
)
logger.success(f"分块完成,共{len(chunks)}个chunk")
return chunks
except Exception as e:
logger.error(f"分块失败: {e}")
return []
def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
"""嵌入分块"""
try:
# 延迟初始化嵌入模型
self._init_embedder()
# 提取文本内容
texts = [chunk.content for chunk in chunks]
# 执行嵌入
embeddings = self.embedder.embed(texts)
logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
return embeddings
except Exception as e:
logger.error(f"嵌入失败: {e}")
return None
def _insert_to_milvus(
self,
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""插入Milvus"""
try:
# 延迟初始化Milvus
self._init_milvus()
# 执行插入
inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
logger.success(f"入库完成,共{len(inserted_ids)}条记录")
return inserted_ids
except Exception as e:
logger.error(f"入库失败: {e}")
return []
def search(
self,
query: str,
top_k: int = 10,
filters: Optional[str] = None
) -> List[Dict]:
"""
检索法规内容
Args:
query: 查询文本
top_k: 返回结果数量
filters: 过滤条件
Returns:
List[Dict]: 检索结果
"""
logger.info(f"执行检索: {query}")
try:
# 延迟初始化
self._init_embedder()
self._init_milvus()
# 生成查询向量
query_embedding = self.embedder.embed_single(query)
# 执行混合检索
results = self.milvus.hybrid_search(
query_dense=query_embedding['dense'].tolist(),
query_sparse=query_embedding['sparse'],
top_k=top_k,
filters=filters
)
# 转换为字典格式
result_dicts = []
for r in results:
result_dicts.append({
"id": r.id,
"content": r.content,
"score": r.score,
"metadata": r.metadata
})
logger.success(f"检索完成,返回{len(result_dicts)}条结果")
return result_dicts
except Exception as e:
logger.error(f"检索失败: {e}")
return []
def search(self, query: str, top_k: int = 10, filters: str | None = None) -> list[dict]:
"""Handle search for the Document Processor instance."""
results = get_retrieval_service().retrieve(query=query, top_k=top_k, filters=filters)
return [
{
"id": item.chunk_id,
"content": item.content,
"score": item.score,
"metadata": {
"doc_id": item.doc_id,
"doc_name": item.doc_name,
"chunk_id": item.chunk_id,
"section_title": item.section_title,
"page_number": item.page_number,
**item.metadata,
},
}
for item in results
]
def close(self):
"""关闭连接"""
if self.milvus:
self.milvus.disconnect()
logger.info("文档处理器已关闭")
def process_document(
file_path: str,
doc_name: Optional[str] = None,
regulation_type: str = "",
version: str = ""
) -> ProcessingResult:
"""便捷函数:处理单个文档"""
processor = DocumentProcessor()
result = processor.process(file_path, doc_name, regulation_type, version)
processor.close()
return result
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
"""便捷函数:检索法规"""
processor = DocumentProcessor()
results = processor.search(query, top_k)
processor.close()
return results
"""Release the resources held by this component."""
return None

View File

@@ -1,6 +1,18 @@
"""嵌入和分块服务"""
"""Initialize the app.services.embedding package."""
# Keep package boundaries explicit so backend imports stay predictable.
from .text_chunker import RegulationChunker
from .bge_m3_embedder import BGEM3Embedder
__all__ = ["RegulationChunker", "BGEM3Embedder"]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name == "RegulationChunker":
from .text_chunker import RegulationChunker
return RegulationChunker
if name == "BGEM3Embedder":
from .bge_m3_embedder import BGEM3Embedder
return BGEM3Embedder
raise AttributeError(name)

View File

@@ -1,4 +1,4 @@
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成"""
"""Provide service-layer logic for bge m3 embedder."""
import numpy as np
from typing import List, Dict, Optional, Union
@@ -6,43 +6,31 @@ from dataclasses import dataclass, field
from loguru import logger
import torch
import os
# Keep service responsibilities explicit so downstream behavior stays predictable.
# 设置HuggingFace镜像国内网络
# Keep service responsibilities explicit so downstream behavior stays predictable.
if 'HF_ENDPOINT' not in os.environ:
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# 本地模型路径(按优先级检查)
# Keep service responsibilities explicit so downstream behavior stays predictable.
LOCAL_MODEL_PATHS = [
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # ModelScope下载路径
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # HuggingFace本地路径
os.path.expanduser("~/.cache/modelscope/Xorbits/bge-m3"), # Keep service responsibilities explicit so downstream behavior stays predictable.
os.path.expanduser("~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main"), # Keep service responsibilities explicit so downstream behavior stays predictable.
]
@dataclass
class EmbeddingResult:
"""嵌入结果"""
dense_embeddings: np.ndarray # Dense向量语义检索
sparse_embeddings: List[Dict[int, float]] # Sparse向量关键词匹配
"""Represent the Embedding Result type."""
dense_embeddings: np.ndarray # Keep service responsibilities explicit so downstream behavior stays predictable.
sparse_embeddings: List[Dict[int, float]] # Keep service responsibilities explicit so downstream behavior stays predictable.
texts: List[str]
dim: int = 1024
class BGEM3Embedder:
"""
BGE-M3多语言嵌入模型服务
BGE-M3是BAAI发布的多语言嵌入模型支持
- Dense向量用于语义相似度检索
- Sparse向量用于关键词精确匹配BM25风格
- ColBERT向量用于细粒度交互匹配可选
特点:
- 支持100+语言(中英双语优化)
- 8192 tokens超长上下文
- Dense+Sparse双路检索能力
GitHub: https://github.com/FlagOpen/FlagEmbedding
"""
"""Represent the B G E M3 Embedder type."""
def __init__(
self,
@@ -53,28 +41,18 @@ class BGEM3Embedder:
max_length: int = 8192,
local_model_path: Optional[str] = None
):
"""
初始化BGE-M3嵌入模型
Args:
model_name: 模型名称(如果使用本地路径,此参数会被忽略)
use_fp16: 是否使用FP16加速
device: 设备类型cuda/cpu默认自动选择
batch_size: 批处理大小
max_length: 最大序列长度
local_model_path: 本地模型路径(可选,优先使用)
"""
"""Initialize the B G E M3 Embedder instance."""
self.use_fp16 = use_fp16
self.batch_size = batch_size
self.max_length = max_length
# 确定模型路径(优先使用本地路径)
# Keep service responsibilities explicit so downstream behavior stays predictable.
if local_model_path and os.path.exists(local_model_path):
self.model_path = local_model_path
self.model_name = "local"
logger.info(f"使用本地模型路径: {local_model_path}")
else:
# 检查多个可能的本地路径
# Keep service responsibilities explicit so downstream behavior stays predictable.
found_local = False
for path in LOCAL_MODEL_PATHS:
if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
@@ -89,7 +67,7 @@ class BGEM3Embedder:
self.model_name = model_name
logger.info(f"使用远程模型: {model_name}")
# 自动选择设备
# Keep service responsibilities explicit so downstream behavior stays predictable.
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
@@ -101,7 +79,7 @@ class BGEM3Embedder:
self._load_model()
def _load_model(self):
"""加载嵌入模型"""
"""Handle load model for this module for the B G E M3 Embedder instance."""
try:
from FlagEmbedding import BGEM3FlagModel
@@ -127,18 +105,7 @@ class BGEM3Embedder:
return_sparse: bool = True,
return_colbert_vecs: bool = False
) -> EmbeddingResult:
"""
对文本列表生成嵌入向量
Args:
texts: 文本列表
return_dense: 是否返回Dense向量
return_sparse: 是否返回Sparse向量
return_colbert_vecs: 是否返回ColBERT向量
Returns:
EmbeddingResult: 嵌入结果
"""
"""Handle embed for the B G E M3 Embedder instance."""
if not texts:
logger.warning("输入文本列表为空")
return EmbeddingResult(
@@ -151,7 +118,7 @@ class BGEM3Embedder:
logger.info(f"开始嵌入{len(texts)}个文本块")
try:
# 执行嵌入
# Keep service responsibilities explicit so downstream behavior stays predictable.
embeddings = self.model.encode(
texts,
batch_size=self.batch_size,
@@ -161,11 +128,11 @@ class BGEM3Embedder:
return_colbert_vecs=return_colbert_vecs
)
# 提取结果
# Keep service responsibilities explicit so downstream behavior stays predictable.
dense_embeddings = embeddings.get('dense_vecs', np.array([]))
sparse_embeddings = embeddings.get('lexical_weights', [])
# 获取维度
# Keep service responsibilities explicit so downstream behavior stays predictable.
dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
logger.success(f"嵌入完成,向量维度: {dim}")
@@ -182,15 +149,7 @@ class BGEM3Embedder:
raise
def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
"""
对单个文本生成嵌入向量
Args:
text: 输入文本
Returns:
Dict: 包含dense和sparse向量
"""
"""Embed single for the B G E M3 Embedder instance."""
result = self.embed([text])
return {
'dense': result.dense_embeddings[0],
@@ -199,25 +158,17 @@ class BGEM3Embedder:
}
def embed_dense(self, texts: List[str]) -> np.ndarray:
"""只生成Dense向量"""
"""Embed dense for the B G E M3 Embedder instance."""
result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
return result.dense_embeddings
def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
"""只生成Sparse向量"""
"""Embed sparse for the B G E M3 Embedder instance."""
result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
return result.sparse_embeddings
def embed_query(self, query: str) -> Dict:
"""
对查询文本生成嵌入(用于检索)
Args:
query: 查询文本
Returns:
Dict: 包含dense和sparse向量
"""
"""Embed query for the B G E M3 Embedder instance."""
return self.embed_single(query)
def compute_similarity(
@@ -226,26 +177,16 @@ class BGEM3Embedder:
doc_embeddings: np.ndarray,
metric: str = "cosine"
) -> np.ndarray:
"""
计算查询与文档的相似度
Args:
query_embedding: 查询向量
doc_embeddings: 文档向量矩阵
metric: 相似度度量cosine/dot
Returns:
np.ndarray: 相似度分数数组
"""
"""Handle compute similarity for the B G E M3 Embedder instance."""
if metric == "cosine":
# 余弦相似度
# Keep service responsibilities explicit so downstream behavior stays predictable.
query_norm = np.linalg.norm(query_embedding)
doc_norms = np.linalg.norm(doc_embeddings, axis=1)
similarities = np.dot(doc_embeddings, query_embedding) / (doc_norms * query_norm)
elif metric == "dot":
# 点积相似度
# Keep service responsibilities explicit so downstream behavior stays predictable.
similarities = np.dot(doc_embeddings, query_embedding)
else:
@@ -258,17 +199,8 @@ class BGEM3Embedder:
query_sparse: Dict[int, float],
doc_sparse: Dict[int, float]
) -> float:
"""
计算Sparse向量的相似度BM25风格
Args:
query_sparse: 查询的Sparse向量词ID -> 权重)
doc_sparse: 文档的Sparse向量
Returns:
float: 相似度分数
"""
# 计算交集词的点积
"""Handle sparse similarity for the B G E M3 Embedder instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
common_keys = set(query_sparse.keys()) & set(doc_sparse.keys())
score = sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
@@ -280,7 +212,7 @@ def embed_texts(
model_name: str = "BAAI/bge-m3",
**kwargs
) -> EmbeddingResult:
"""便捷函数:对文本列表生成嵌入"""
"""Embed texts."""
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
return embedder.embed(texts)
@@ -290,6 +222,6 @@ def embed_single_text(
model_name: str = "BAAI/bge-m3",
**kwargs
) -> Dict:
"""便捷函数:对单个文本生成嵌入"""
"""Embed single text."""
embedder = BGEM3Embedder(model_name=model_name, **kwargs)
return embedder.embed_single(text)

View File

@@ -1,51 +1,46 @@
"""智能分块器 - 章节级+条款级双粒度切割"""
"""Provide service-layer logic for text chunker."""
import re
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from loguru import logger
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class ChunkMetadata:
"""分块元数据"""
"""Represent the Chunk Metadata type."""
doc_id: str = ""
doc_name: str = ""
chunk_id: str = ""
section_number: str = "" # 章节编号(如 "第一章"
section_title: str = "" # 章节标题
clause_number: str = "" # 条款编号(如 "第一条"
section_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
section_title: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
clause_number: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
page_number: int = 0
start_position: int = 0 # 在原文中的起始位置
end_position: int = 0 # 在原文中的结束位置
regulation_type: str = "" # 法规类型
start_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
end_position: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
regulation_type: str = "" # Keep service responsibilities explicit so downstream behavior stays predictable.
version: str = ""
@dataclass
class TextChunk:
"""文本分块"""
"""Represent the Text Chunk type."""
content: str
metadata: ChunkMetadata
token_count: int = 0 # 估算的token数量
token_count: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
class RegulationChunker:
"""
法规文档智能分块器
"""Represent the Regulation Chunker type."""
实现章节级/条款级双粒度切割适配国标GB文档结构
- 国标文档通常有明确的层级结构:章 > 节 > 条
- 每个条款应作为一个独立的语义单元
- 保留条款完整性,避免跨条款截断
"""
# 法规标题模式
# Keep service responsibilities explicit so downstream behavior stays predictable.
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
# 条款子项模式
# Keep service responsibilities explicit so downstream behavior stays predictable.
SUB_ITEM_PATTERN = re.compile(r'^[\(][一二三四五六七八九十]+[\)]\s')
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
@@ -56,15 +51,7 @@ class RegulationChunker:
max_chunk_size: int = 2048,
min_chunk_size: int = 100
):
"""
初始化分块器
Args:
chunk_size: 默认分块大小(字符数)
chunk_overlap: 分块重叠大小
max_chunk_size: 最大分块大小(防止单个条款过长)
min_chunk_size: 最小分块大小(防止碎片化)
"""
"""Initialize the Regulation Chunker instance."""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.max_chunk_size = max_chunk_size
@@ -78,30 +65,18 @@ class RegulationChunker:
regulation_type: str = "",
version: str = ""
) -> List[TextChunk]:
"""
对法规文档进行智能分块
Args:
markdown_text: Markdown格式的文档内容
doc_id: 文档ID
doc_name: 文档名称
regulation_type: 法规类型
version: 文档版本
Returns:
List[TextChunk]: 分块列表
"""
"""Handle chunk document for the Regulation Chunker instance."""
logger.info(f"开始分块文档: {doc_name}")
# 1. 按章节分割(一级分块)
# Keep service responsibilities explicit so downstream behavior stays predictable.
sections = self._split_by_sections(markdown_text)
# 2. 在每个章节内按条款分割(二级分块)
# Keep service responsibilities explicit so downstream behavior stays predictable.
chunks = []
global_position = 0
for section_num, section_title, section_content, section_start in sections:
# 在章节内按条款分割
# Keep service responsibilities explicit so downstream behavior stays predictable.
clause_chunks = self._split_by_clauses(
section_content,
section_num,
@@ -110,7 +85,7 @@ class RegulationChunker:
)
for chunk_content, clause_num, clause_title, start_pos, end_pos in clause_chunks:
# 处理过长的条款(进一步细分)
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(chunk_content) > self.max_chunk_size:
sub_chunks = self._split_long_clause(
chunk_content,
@@ -150,12 +125,7 @@ class RegulationChunker:
return chunks
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
"""
按章节分割文档
Returns:
List of (section_number, section_title, section_content, start_position)
"""
"""Handle split by sections for this module for the Regulation Chunker instance."""
sections = []
lines = markdown_text.split('\n')
@@ -165,12 +135,12 @@ class RegulationChunker:
current_section_start = 0
for i, line in enumerate(lines):
# 检测章节标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
chapter_match = self.CHAPTER_PATTERN.match(line.strip())
section_match = self.SECTION_PATTERN.match(line.strip())
if chapter_match or section_match:
# 保存上一个章节
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_section_content:
content = '\n'.join(current_section_content)
sections.append((
@@ -180,7 +150,7 @@ class RegulationChunker:
current_section_start
))
# 开始新章节
# Keep service responsibilities explicit so downstream behavior stays predictable.
current_section_start = sum(len(l) + 1 for l in lines[:i])
current_section_content = []
@@ -193,7 +163,7 @@ class RegulationChunker:
current_section_content.append(line)
# 保存最后一个章节
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_section_content:
content = '\n'.join(current_section_content)
sections.append((
@@ -203,7 +173,7 @@ class RegulationChunker:
current_section_start
))
# 如果没有检测到章节,将整个文档作为一个大章节
# Keep service responsibilities explicit so downstream behavior stays predictable.
if not sections:
sections.append((
"",
@@ -221,12 +191,7 @@ class RegulationChunker:
section_title: str,
section_start: int
) -> List[Tuple[str, str, str, int, int]]:
"""
在章节内按条款分割
Returns:
List of (content, clause_number, clause_title, start_position, end_position)
"""
"""Handle split by clauses for this module for the Regulation Chunker instance."""
clauses = []
lines = section_content.split('\n')
@@ -236,11 +201,11 @@ class RegulationChunker:
current_clause_start = section_start
for i, line in enumerate(lines):
# 检测条款标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
clause_match = self.CLAUSE_PATTERN.match(line.strip())
if clause_match:
# 保存上一个条款
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_clause_content:
content = '\n'.join(current_clause_content)
end_pos = current_clause_start + len(content)
@@ -252,7 +217,7 @@ class RegulationChunker:
end_pos
))
# 开始新条款
# Keep service responsibilities explicit so downstream behavior stays predictable.
current_clause_start = section_start + sum(len(l) + 1 for l in lines[:i])
current_clause_content = []
current_clause_num = self._extract_clause_number(line.strip())
@@ -260,7 +225,7 @@ class RegulationChunker:
current_clause_content.append(line)
# 保存最后一个条款
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_clause_content:
content = '\n'.join(current_clause_content)
end_pos = current_clause_start + len(content)
@@ -272,7 +237,7 @@ class RegulationChunker:
end_pos
))
# 如果没有检测到条款,将整个章节作为一个条款
# Keep service responsibilities explicit so downstream behavior stays predictable.
if not clauses:
clauses.append((
section_content,
@@ -290,15 +255,11 @@ class RegulationChunker:
clause_num: str,
clause_title: str
) -> List[Tuple[str, int, int]]:
"""
分割过长的条款内容
按条款子项或段落分割,保持语义完整性
"""
"""Handle split long clause for this module for the Regulation Chunker instance."""
sub_chunks = []
lines = content.split('\n')
# 检测是否有子项结构
# Keep service responsibilities explicit so downstream behavior stays predictable.
has_sub_items = any(
self.SUB_ITEM_PATTERN.match(line.strip()) or
self.NUMBER_ITEM_PATTERN.match(line.strip())
@@ -306,7 +267,7 @@ class RegulationChunker:
)
if has_sub_items:
# 按子项分割
# Keep service responsibilities explicit so downstream behavior stays predictable.
current_sub_content = []
current_sub_start = 0
@@ -326,14 +287,14 @@ class RegulationChunker:
current_sub_content.append(line)
# 保存最后一个子项
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_sub_content:
sub_content = '\n'.join(current_sub_content)
sub_end = current_sub_start + len(sub_content)
sub_chunks.append((sub_content, current_sub_start, sub_end))
else:
# 按段落分割(滑动窗口)
# Keep service responsibilities explicit so downstream behavior stays predictable.
paragraphs = []
current_para = []
@@ -348,7 +309,7 @@ class RegulationChunker:
if current_para:
paragraphs.append('\n'.join(current_para))
# 合并段落直到达到chunk_size
# Keep service responsibilities explicit so downstream behavior stays predictable.
current_chunk = []
current_length = 0
chunk_start = 0
@@ -365,7 +326,7 @@ class RegulationChunker:
current_chunk.append(para)
current_length += len(para)
# 保存最后一个chunk
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_chunk:
chunk_content = '\n'.join(current_chunk)
chunk_end = chunk_start + len(chunk_content)
@@ -374,13 +335,13 @@ class RegulationChunker:
return sub_chunks
def _extract_title(self, header_line: str) -> str:
"""从标题行提取标题内容"""
# 移除"第X章"、"第X节"前缀
"""Handle extract title for this module for the Regulation Chunker instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
title = re.sub(r'^第[一二三四五六七八九十百]+[章节]\s+', '', header_line)
return title.strip()
def _extract_clause_number(self, clause_line: str) -> str:
"""从条款行提取条款编号"""
"""Handle extract clause number for this module for the Regulation Chunker instance."""
match = self.CLAUSE_PATTERN.match(clause_line)
if match:
return match.group(0).strip()
@@ -399,14 +360,14 @@ class RegulationChunker:
regulation_type: str,
version: str
) -> TextChunk:
"""创建文本分块"""
# 清理内容
"""Handle create chunk for this module for the Regulation Chunker instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
content = content.strip()
# 计算估算token数中文约1.5字符/token
token_count = int(len(content) * 0.7) # 简化估算
# Keep service responsibilities explicit so downstream behavior stays predictable.
token_count = int(len(content) * 0.7) # Keep service responsibilities explicit so downstream behavior stays predictable.
# 生成chunk_id
# Keep service responsibilities explicit so downstream behavior stays predictable.
chunk_id = f"{doc_id}_{section_num}_{clause_num}_{start_pos}"
metadata = ChunkMetadata(
@@ -437,7 +398,7 @@ def chunk_regulation_document(
version: str = "",
chunk_size: int = 512
) -> List[TextChunk]:
"""便捷函数:对法规文档进行分块"""
"""Handle chunk regulation document."""
chunker = RegulationChunker(chunk_size=chunk_size)
return chunker.chunk_document(
markdown_text,

View File

@@ -1,14 +1,36 @@
"""LLM服务模块"""
"""Initialize the app.services.llm package."""
from .llm_factory import LLMFactory, get_llm_client
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient
from .llm_factory import LLMFactory, get_llm_client
from .qwen_client import QwenClient, QwenVLClient
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = [
"LLMFactory", "get_llm_client",
"BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider",
"DeepSeekClient", "QwenClient", "QwenVLClient",
"DocumentSummarizer", "summarize_document", "DocumentSummary"
"LLMFactory",
"get_llm_client",
"BaseLLMClient",
"LLMResponse",
"LLMConfig",
"LLMProvider",
"DeepSeekClient",
"QwenClient",
"QwenVLClient",
"DocumentSummarizer",
"summarize_document",
"DocumentSummary",
]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name in {"DocumentSummarizer", "summarize_document", "DocumentSummary"}:
from .document_summarizer import DocumentSummarizer, DocumentSummary, summarize_document
return {
"DocumentSummarizer": DocumentSummarizer,
"summarize_document": summarize_document,
"DocumentSummary": DocumentSummary,
}[name]
raise AttributeError(name)

View File

@@ -1,13 +1,15 @@
"""LLM客户端基类 - 统一接口定义"""
"""Provide service-layer logic for base client."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from enum import Enum
# Keep provider-specific behavior explicit so debugging stays straightforward.
class LLMProvider(Enum):
"""LLM提供商"""
"""Define the L L M Provider enumeration."""
DEEPSEEK = "deepseek"
QWEN = "qwen"
QWEN_VL = "qwen_vl"
@@ -15,7 +17,7 @@ class LLMProvider(Enum):
@dataclass
class LLMResponse:
"""LLM响应结果"""
"""Represent the L L M Response type."""
content: str
model: str
usage: Dict[str, int] = field(default_factory=dict)
@@ -25,12 +27,13 @@ class LLMResponse:
@property
def is_success(self) -> bool:
"""Return whether success for the L L M Response instance."""
return self.error is None
@dataclass
class LLMConfig:
"""LLM配置"""
"""Define configuration for l l m config."""
provider: LLMProvider
model: str
api_key: str
@@ -38,19 +41,20 @@ class LLMConfig:
max_tokens: int = 4096
temperature: float = 0.7
top_p: float = 0.9
timeout: int = 300 # 默认超时300秒摘要/Skills生成可能需要较长时间
timeout: int = 300 # Keep provider-specific behavior explicit so debugging stays straightforward.
class BaseLLMClient(ABC):
"""LLM客户端基类"""
"""Represent the Base L L M Client type."""
def __init__(self, config: LLMConfig):
"""Initialize the Base L L M Client instance."""
self.config = config
self._client = None
@abstractmethod
def _init_client(self):
"""初始化客户端"""
"""Handle init client for this module for the Base L L M Client instance."""
pass
@abstractmethod
@@ -61,18 +65,7 @@ class BaseLLMClient(ABC):
temperature: Optional[float] = None,
**kwargs
) -> LLMResponse:
"""
对话补全
Args:
messages: 对话消息列表 [{"role": "user/assistant/system", "content": "..."}]
max_tokens: 最大输出token数
temperature: 温度参数
**kwargs: 其他参数
Returns:
LLMResponse: 响应结果
"""
"""Handle chat for the Base L L M Client instance."""
pass
def complete(
@@ -83,18 +76,7 @@ class BaseLLMClient(ABC):
temperature: Optional[float] = None,
**kwargs
) -> LLMResponse:
"""
单轮补全(便捷方法)
Args:
prompt: 用户输入
system_prompt: 系统提示词
max_tokens: 最大输出token数
temperature: 温度参数
Returns:
LLMResponse: 响应结果
"""
"""Handle complete for the Base L L M Client instance."""
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
@@ -104,12 +86,12 @@ class BaseLLMClient(ABC):
@abstractmethod
def get_available_models(self) -> List[str]:
"""获取可用模型列表"""
"""Return available models for the Base L L M Client instance."""
pass
def estimate_tokens(self, text: str) -> int:
"""估算文本token数粗略估计"""
# 中文字符约1.5 token英文约0.25 token
"""Handle estimate tokens for the Base L L M Client instance."""
# Keep provider-specific behavior explicit so debugging stays straightforward.
chinese_chars = sum(1 for c in text if '' <= c <= '鿿')
other_chars = len(text) - chinese_chars
return int(chinese_chars * 1.5 + other_chars * 0.25)

View File

@@ -1,4 +1,4 @@
"""DeepSeek LLM客户端 - OpenAI兼容API"""
"""Provide service-layer logic for deepseek client."""
import time
from typing import List, Dict, Optional
@@ -6,20 +6,12 @@ from loguru import logger
import httpx
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
# Keep provider-specific behavior explicit so debugging stays straightforward.
class DeepSeekClient(BaseLLMClient):
"""
DeepSeek API客户端OpenAI兼容格式
支持模型:
- deepseek-chat
- deepseek-coder
- deepseek-reasoner
- deepseek-v3
- deepseek-v3.2
- deepseek-v4-flash
"""
"""Represent the Deep Seek Client type."""
SUPPORTED_MODELS = [
"deepseek-chat",
@@ -31,13 +23,14 @@ class DeepSeekClient(BaseLLMClient):
]
def __init__(self, config: LLMConfig):
"""Initialize the Deep Seek Client instance."""
if config.provider != LLMProvider.DEEPSEEK:
raise ValueError(f"配置provider应为DEEPSEEK实际为{config.provider}")
super().__init__(config)
self._init_client()
def _init_client(self):
"""初始化HTTP客户端"""
"""Handle init client for this module for the Deep Seek Client instance."""
self._client = httpx.Client(
base_url=self.config.base_url,
headers={
@@ -55,7 +48,7 @@ class DeepSeekClient(BaseLLMClient):
temperature: Optional[float] = None,
**kwargs
) -> LLMResponse:
"""对话补全"""
"""Handle chat for the Deep Seek Client instance."""
start_time = time.time()
try:
@@ -103,11 +96,11 @@ class DeepSeekClient(BaseLLMClient):
)
def get_available_models(self) -> List[str]:
"""获取可用模型列表"""
"""Return available models for the Deep Seek Client instance."""
return self.SUPPORTED_MODELS
def close(self):
"""关闭客户端"""
"""Release the resources held by this component."""
if self._client:
self._client.close()
@@ -118,7 +111,7 @@ def create_deepseek_client(
base_url: str = "http://6.86.80.4:30080/v1",
**kwargs
) -> DeepSeekClient:
"""便捷函数创建DeepSeek客户端"""
"""Create deepseek client."""
config = LLMConfig(
provider=LLMProvider.DEEPSEEK,
model=model,

View File

@@ -1,17 +1,20 @@
"""文档摘要生成服务 - LLM生成法规文档摘要"""
"""Provide service-layer logic for document summarizer."""
from typing import Dict, Optional
from dataclasses import dataclass
from loguru import logger
from app.services.llm import get_llm_client, BaseLLMClient
from app.services.llm.base_client import BaseLLMClient
from app.services.llm.llm_factory import get_llm_client
from app.services.rag.prompt_templates import get_prompt_template
from app.config.settings import settings
# Keep provider-specific behavior explicit so debugging stays straightforward.
@dataclass
class DocumentSummary:
"""文档摘要结果"""
"""Represent the Document Summary type."""
doc_name: str
summary: str
applicable_scope: str
@@ -24,24 +27,12 @@ class DocumentSummary:
@property
def is_success(self) -> bool:
"""Return whether success for the Document Summary instance."""
return self.error is None
class DocumentSummarizer:
"""
文档摘要生成器
功能:
- 生成法规文档的核心要点摘要
- 提取适用范围
- 突出关键条款
- 列出合规要点
使用示例:
summarizer = DocumentSummarizer()
result = summarizer.summarize("GB 7258-2017", markdown_content)
print(result.summary)
"""
"""Represent the Document Summarizer type."""
def __init__(
self,
@@ -49,25 +40,18 @@ class DocumentSummarizer:
model: str = None,
max_tokens: int = None
):
"""
初始化摘要生成器
Args:
provider: LLM提供商
model: LLM模型名称
max_tokens: 最大输出token数
"""
"""Initialize the Document Summarizer instance."""
self.provider = provider or settings.llm_provider
self.model = model or settings.llm_model
self.max_tokens = max_tokens or settings.rag_summary_max_tokens
# LLM客户端延迟加载
# Keep provider-specific behavior explicit so debugging stays straightforward.
self.llm: Optional[BaseLLMClient] = None
logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")
def _init_llm(self):
"""延迟初始化LLM"""
"""Handle init llm for this module for the Document Summarizer instance."""
if self.llm is None:
self.llm = get_llm_client(
provider=self.provider,
@@ -81,18 +65,7 @@ class DocumentSummarizer:
regulation_type: str = "",
max_tokens: Optional[int] = None
) -> DocumentSummary:
"""
生成文档摘要
Args:
doc_name: 文档名称
content: 文档内容Markdown格式
regulation_type: 法规类型
max_tokens: 最大输出token数
Returns:
DocumentSummary: 摘要结果
"""
"""Handle summarize for the Document Summarizer instance."""
import time
start_time = time.time()
@@ -101,23 +74,23 @@ class DocumentSummarizer:
try:
self._init_llm()
# 使用摘要模板
# Keep provider-specific behavior explicit so debugging stays straightforward.
template = get_prompt_template("document_summary")
# 构建用户消息
# Keep provider-specific behavior explicit so debugging stays straightforward.
user_content = template.user_template.format(
doc_name=doc_name,
content=content[:8000] # 截取前8000字符避免超出token限制
content=content[:8000] # Keep provider-specific behavior explicit so debugging stays straightforward.
)
# 调用LLM
# Keep provider-specific behavior explicit so debugging stays straightforward.
response = self.llm.chat(
messages=[
{"role": "system", "content": template.system_prompt},
{"role": "user", "content": user_content}
],
max_tokens=max_tokens or self.max_tokens,
temperature=0.3 # 低温度保证摘要准确性
temperature=0.3 # Keep provider-specific behavior explicit so debugging stays straightforward.
)
latency_ms = int((time.time() - start_time) * 1000)
@@ -135,7 +108,7 @@ class DocumentSummarizer:
error=response.error
)
# 解析摘要结构
# Keep provider-specific behavior explicit so debugging stays straightforward.
summary_data = self._parse_summary(response.content)
logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
@@ -166,7 +139,7 @@ class DocumentSummarizer:
)
def _parse_summary(self, content: str) -> Dict:
"""解析摘要内容(提取结构化信息)"""
"""Handle parse summary for this module for the Document Summarizer instance."""
result = {
"summary": content,
"applicable_scope": "",
@@ -175,26 +148,26 @@ class DocumentSummarizer:
"compliance_points": []
}
# 简单解析(提取关键信息)
# Keep provider-specific behavior explicit so debugging stays straightforward.
lines = content.split("\n")
for line in lines:
line = line.strip()
# 提取适用范围
# Keep provider-specific behavior explicit so debugging stays straightforward.
if "适用范围" in line or "适用对象" in line:
result["applicable_scope"] = line.split("")[-1].strip() if "" in line else line.split(":")[-1].strip()
# 提取关键条款
# Keep provider-specific behavior explicit so debugging stays straightforward.
if line.startswith("- 【条款") or line.startswith("【条款"):
result["key_clauses"].append(line)
# 提取关键术语
# Keep provider-specific behavior explicit so debugging stays straightforward.
if "关键术语" in line or "术语定义" in line:
# 继续读取后续几行
# Keep provider-specific behavior explicit so debugging stays straightforward.
pass
# 提取合规要点
# Keep provider-specific behavior explicit so debugging stays straightforward.
if "合规要点" in line or "必须满足" in line:
pass
@@ -204,15 +177,7 @@ class DocumentSummarizer:
self,
documents: list
) -> list:
"""
批量生成摘要
Args:
documents: 文档列表 [{"doc_name": str, "content": str}, ...]
Returns:
list: 摘要结果列表
"""
"""Handle batch summarize for the Document Summarizer instance."""
results = []
for doc in documents:
result = self.summarize(doc["doc_name"], doc["content"])
@@ -225,6 +190,6 @@ def summarize_document(
content: str,
**kwargs
) -> DocumentSummary:
"""便捷函数:生成文档摘要"""
"""Handle summarize document."""
summarizer = DocumentSummarizer(**kwargs)
return summarizer.summarize(doc_name, content)

View File

@@ -1,4 +1,4 @@
"""LLM工厂 - 统一创建和管理LLM客户端"""
"""Provide service-layer logic for llm factory."""
from typing import Optional, Dict, Any
from loguru import logger
@@ -7,16 +7,18 @@ from functools import lru_cache
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient
# Keep provider-specific behavior explicit so debugging stays straightforward.
# 默认模型映射
# Keep provider-specific behavior explicit so debugging stays straightforward.
DEFAULT_MODELS = {
LLMProvider.DEEPSEEK: "deepseek-v4-flash",
LLMProvider.QWEN: "qwen3.5-flash",
LLMProvider.QWEN_VL: "qwen3-vl-plus"
}
# API基础URL使用统一代理服务
# Keep provider-specific behavior explicit so debugging stays straightforward.
DEFAULT_BASE_URLS = {
LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
@@ -25,31 +27,13 @@ DEFAULT_BASE_URLS = {
class LLMFactory:
"""
LLM客户端工厂支持全局缓存
"""Represent the L L M Factory type."""
支持的提供商和模型:
- DeepSeek: deepseek-chat (DeepSeek-V3), deepseek-coder
- Qwen: qwen-turbo, qwen-plus, qwen-max, qwen-long
- QwenVL: qwen-vl-plus, qwen-vl-max (多模态)
使用示例:
factory = LLMFactory()
# 使用默认配置
client = factory.create("deepseek")
# 自定义配置
client = factory.create("qwen", model="qwen-max", temperature=0.5)
# 调用LLM
response = client.complete("你好,介绍一下自己")
"""
# 全局客户端缓存(类级别,跨实例共享)
# Keep provider-specific behavior explicit so debugging stays straightforward.
_global_instances: Dict[str, BaseLLMClient] = {}
def __init__(self):
"""Initialize the L L M Factory instance."""
self._config_cache: Dict[str, Any] = {}
def create(
@@ -62,24 +46,10 @@ class LLMFactory:
temperature: float = 0.7,
**kwargs
) -> BaseLLMClient:
"""
创建LLM客户端
Args:
provider: 提供商名称 ("deepseek", "qwen", "qwen_vl")
api_key: API密钥如未提供从环境变量获取
model: 模型名称(如未提供,使用默认模型)
base_url: API基础URL
max_tokens: 最大输出token数
temperature: 温度参数
**kwargs: 其他配置参数
Returns:
BaseLLMClient: LLM客户端实例
"""
"""Handle create for the L L M Factory instance."""
provider_enum = self._parse_provider(provider)
# 获取配置
# Keep provider-specific behavior explicit so debugging stays straightforward.
api_key = api_key or self._get_api_key(provider_enum)
model = model or DEFAULT_MODELS.get(provider_enum)
base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
@@ -87,7 +57,7 @@ class LLMFactory:
if not api_key:
raise ValueError(f"缺少API密钥请设置环境变量或传入api_key参数")
# 检查全局缓存
# Keep provider-specific behavior explicit so debugging stays straightforward.
cache_key = f"{provider}_{model}"
if cache_key in LLMFactory._global_instances:
logger.debug(f"使用缓存的LLM客户端: {cache_key}")
@@ -103,17 +73,17 @@ class LLMFactory:
**kwargs
)
# 创建客户端
# Keep provider-specific behavior explicit so debugging stays straightforward.
client = self._create_client(config)
# 缓存到全局实例
# Keep provider-specific behavior explicit so debugging stays straightforward.
LLMFactory._global_instances[cache_key] = client
logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
return client
def _parse_provider(self, provider: str) -> LLMProvider:
"""解析提供商名称"""
"""Handle parse provider for this module for the L L M Factory instance."""
provider_map = {
"deepseek": LLMProvider.DEEPSEEK,
"deepseek-v3": LLMProvider.DEEPSEEK,
@@ -137,7 +107,7 @@ class LLMFactory:
return provider_map[provider_lower]
def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
"""从环境变量获取API密钥"""
"""Handle get api key for this module for the L L M Factory instance."""
import os
key_map = {
@@ -154,7 +124,7 @@ class LLMFactory:
return None
def _create_client(self, config: LLMConfig) -> BaseLLMClient:
"""创建具体客户端"""
"""Handle create client for this module for the L L M Factory instance."""
client_map = {
LLMProvider.DEEPSEEK: DeepSeekClient,
LLMProvider.QWEN: QwenClient,
@@ -168,14 +138,14 @@ class LLMFactory:
return client_class(config)
def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
"""获取缓存的客户端"""
"""Return cached for the L L M Factory instance."""
provider_enum = self._parse_provider(provider)
model = model or DEFAULT_MODELS.get(provider_enum)
cache_key = f"{provider}_{model}"
return LLMFactory._global_instances.get(cache_key)
def list_available_providers(self) -> Dict[str, list]:
"""列出可用的提供商和模型"""
"""List available providers for the L L M Factory instance."""
return {
"deepseek": DeepSeekClient.SUPPORTED_MODELS,
"qwen": QwenClient.SUPPORTED_MODELS,
@@ -184,12 +154,7 @@ class LLMFactory:
@classmethod
def preload_clients(cls, providers: list = None):
"""
预加载LLM客户端应用启动时调用
Args:
providers: 要预加载的提供商列表默认加载qwen和deepseek
"""
"""Handle preload clients for the L L M Factory instance."""
if providers is None:
providers = ["qwen", "deepseek"]
@@ -203,9 +168,9 @@ class LLMFactory:
@classmethod
def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
"""获取全局缓存的客户端"""
"""Return global client for the L L M Factory instance."""
provider_lower = provider.lower()
# 处理模型名作为provider的情况(如 qwen3.5-flash
# Keep provider-specific behavior explicit so debugging stays straightforward.
if provider_lower.startswith("qwen"):
provider_lower = "qwen"
model = model or DEFAULT_MODELS.get(LLMProvider.QWEN if provider_lower == "qwen" else LLMProvider.DEEPSEEK)
@@ -214,7 +179,7 @@ class LLMFactory:
@classmethod
def cleanup(cls):
"""清理所有缓存的客户端"""
"""Handle cleanup for the L L M Factory instance."""
for cache_key, client in cls._global_instances.items():
try:
client.close()
@@ -227,7 +192,7 @@ class LLMFactory:
@lru_cache
def get_llm_factory() -> LLMFactory:
"""获取LLM工厂实例缓存"""
"""Return llm factory."""
return LLMFactory()
@@ -236,20 +201,10 @@ def get_llm_client(
model: Optional[str] = None,
**kwargs
) -> BaseLLMClient:
"""
便捷函数获取LLM客户端优先使用缓存
Args:
provider: 提供商名称
model: 模型名称
**kwargs: 其他配置
Returns:
BaseLLMClient: LLM客户端实例
"""
"""Return llm client."""
factory = get_llm_factory()
# 先尝试获取缓存的实例
# Keep provider-specific behavior explicit so debugging stays straightforward.
cached = factory.get_cached(provider, model)
if cached:
return cached

View File

@@ -1,4 +1,4 @@
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
"""Provide service-layer logic for qwen client."""
import time
import json
@@ -7,21 +7,12 @@ from loguru import logger
import httpx
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
# Keep provider-specific behavior explicit so debugging stays straightforward.
class QwenClient(BaseLLMClient):
"""
Qwen API客户端OpenAI兼容格式
支持通过new-api等代理服务调用
- qwen-turbo
- qwen-plus
- qwen-max
- qwen3.5-flash (推荐:快速响应)
- qwen3.5-plus
- qwen-long
- qwen2.5系列
"""
"""Represent the Qwen Client type."""
SUPPORTED_MODELS = [
"qwen-turbo",
@@ -39,14 +30,15 @@ class QwenClient(BaseLLMClient):
]
def __init__(self, config: LLMConfig):
"""Initialize the Qwen Client instance."""
if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
raise ValueError(f"配置provider应为Qwen实际为{config.provider}")
super().__init__(config)
self._init_client()
def _init_client(self):
"""初始化HTTP客户端"""
# OpenAI兼容API格式
"""Handle init client for this module for the Qwen Client instance."""
# Keep provider-specific behavior explicit so debugging stays straightforward.
self._client = httpx.Client(
base_url=self.config.base_url,
headers={
@@ -64,11 +56,11 @@ class QwenClient(BaseLLMClient):
temperature: Optional[float] = None,
**kwargs
) -> LLMResponse:
"""对话补全OpenAI兼容格式"""
"""Handle chat for the Qwen Client instance."""
start_time = time.time()
try:
# OpenAI兼容格式的请求体
# Keep provider-specific behavior explicit so debugging stays straightforward.
payload = {
"model": self.config.model,
"messages": messages,
@@ -78,7 +70,7 @@ class QwenClient(BaseLLMClient):
"stream": False
}
# OpenAI兼容接口路径
# Keep provider-specific behavior explicit so debugging stays straightforward.
response = self._client.post("/chat/completions", json=payload)
response.raise_for_status()
@@ -86,7 +78,7 @@ class QwenClient(BaseLLMClient):
latency_ms = int((time.time() - start_time) * 1000)
# OpenAI兼容格式的响应解析
# Keep provider-specific behavior explicit so debugging stays straightforward.
choices = data.get("choices", [{}])
message = choices[0].get("message", {})
@@ -121,42 +113,33 @@ class QwenClient(BaseLLMClient):
temperature: Optional[float] = None,
**kwargs
) -> Generator[str, None, None]:
"""
流式对话补全SSE格式
Yields:
str: 每次返回一个文本片段
使用示例:
for chunk in client.stream_chat(messages):
print(chunk, end="", flush=True)
"""
"""Stream chat for the Qwen Client instance."""
try:
# OpenAI兼容格式的请求体启用流式输出
# Keep provider-specific behavior explicit so debugging stays straightforward.
payload = {
"model": self.config.model,
"messages": messages,
"max_tokens": max_tokens or self.config.max_tokens,
"temperature": temperature or self.config.temperature,
"top_p": kwargs.get("top_p", self.config.top_p),
"stream": True # 启用流式输出
"stream": True # Keep provider-specific behavior explicit so debugging stays straightforward.
}
# 使用stream模式发送请求
# Keep provider-specific behavior explicit so debugging stays straightforward.
with self._client.stream("POST", "/chat/completions", json=payload) as response:
for line in response.iter_lines():
if line:
line = line.strip()
# SSE格式: data: {...}
# Keep provider-specific behavior explicit so debugging stays straightforward.
if line.startswith("data: "):
data_str = line[6:] # 移除 "data: " 前缀
data_str = line[6:] # Keep provider-specific behavior explicit so debugging stays straightforward.
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
choices = data.get("choices", [])
if not choices:
continue # 跳过空的choices
continue # Keep provider-specific behavior explicit so debugging stays straightforward.
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
@@ -179,41 +162,27 @@ class QwenClient(BaseLLMClient):
temperature: Optional[float] = None,
**kwargs
) -> AsyncGenerator[str, None]:
"""
异步流式对话补全用于FastAPI SSE响应
Yields:
str: 每次返回一个文本片段
"""
"""Handle async stream chat for the Qwen Client instance."""
import asyncio
# 使用同步流式方法,包装为异步
# Keep provider-specific behavior explicit so debugging stays straightforward.
for chunk in self.stream_chat(messages, max_tokens, temperature, **kwargs):
yield chunk
# 给async循环一个小延迟让其他任务有机会执行
# Keep provider-specific behavior explicit so debugging stays straightforward.
await asyncio.sleep(0)
def get_available_models(self) -> List[str]:
"""获取可用模型列表"""
"""Return available models for the Qwen Client instance."""
return self.SUPPORTED_MODELS
def close(self):
"""关闭客户端"""
"""Release the resources held by this component."""
if self._client:
self._client.close()
class QwenVLClient(BaseLLMClient):
"""
Qwen VL多模态客户端OpenAI兼容格式
支持模型:
- qwen-vl-plus
- qwen-vl-max
- qwen3-vl-plus
- qwen2-vl-7b-instruct
- qwen2-vl-72b-instruct
"""
"""Represent the Qwen V L Client type."""
SUPPORTED_MODELS = [
"qwen-vl-plus",
@@ -224,13 +193,14 @@ class QwenVLClient(BaseLLMClient):
]
def __init__(self, config: LLMConfig):
"""Initialize the Qwen V L Client instance."""
if config.provider != LLMProvider.QWEN_VL:
raise ValueError(f"配置provider应为QWEN_VL实际为{config.provider}")
super().__init__(config)
self._init_client()
def _init_client(self):
"""初始化HTTP客户端"""
"""Handle init client for this module for the Qwen V L Client instance."""
self._client = httpx.Client(
base_url=self.config.base_url,
headers={
@@ -248,21 +218,11 @@ class QwenVLClient(BaseLLMClient):
temperature: Optional[float] = None,
**kwargs
) -> LLMResponse:
"""多模态对话补全OpenAI兼容格式
支持图片输入,消息格式:
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
{"type": "text", "text": "描述这张图片"}
]
}
"""
"""Handle chat for the Qwen V L Client instance."""
start_time = time.time()
try:
# OpenAI兼容格式的请求体
# Keep provider-specific behavior explicit so debugging stays straightforward.
payload = {
"model": self.config.model,
"messages": messages,
@@ -312,7 +272,7 @@ class QwenVLClient(BaseLLMClient):
temperature: Optional[float] = None,
**kwargs
) -> Generator[str, None, None]:
"""流式多模态对话补全"""
"""Stream chat for the Qwen V L Client instance."""
try:
payload = {
"model": self.config.model,
@@ -335,7 +295,7 @@ class QwenVLClient(BaseLLMClient):
data = json.loads(data_str)
choices = data.get("choices", [])
if not choices:
continue # 跳过空的choices
continue # Keep provider-specific behavior explicit so debugging stays straightforward.
delta = choices[0].get("delta", {})
content = delta.get("content", "")
if content:
@@ -348,11 +308,11 @@ class QwenVLClient(BaseLLMClient):
yield f"[ERROR: {str(e)}]"
def get_available_models(self) -> List[str]:
"""获取可用模型列表"""
"""Return available models for the Qwen V L Client instance."""
return self.SUPPORTED_MODELS
def close(self):
"""关闭客户端"""
"""Release the resources held by this component."""
if self._client:
self._client.close()
@@ -363,7 +323,7 @@ def create_qwen_client(
base_url: str = "http://6.86.80.4:30080/v1",
**kwargs
) -> QwenClient:
"""便捷函数创建Qwen客户端"""
"""Create qwen client."""
config = LLMConfig(
provider=LLMProvider.QWEN,
model=model,
@@ -380,7 +340,7 @@ def create_qwen_vl_client(
base_url: str = "http://6.86.80.4:30080/v1",
**kwargs
) -> QwenVLClient:
"""便捷函数创建QwenVL客户端"""
"""Create qwen vl client."""
config = LLMConfig(
provider=LLMProvider.QWEN_VL,
model=model,

View File

@@ -1,12 +1,12 @@
"""
Mock数据服务 - 提供预设假数据供前后端对接测试
"""
"""Provide service-layer logic for mock data."""
from datetime import datetime
from typing import Dict, List, Any
import uuid
# Keep service responsibilities explicit so downstream behavior stays predictable.
# 预设法规文档列表
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_DOCUMENTS: List[Dict[str, Any]] = [
{
"id": "doc-001",
@@ -45,7 +45,7 @@ MOCK_DOCUMENTS: List[Dict[str, Any]] = [
},
]
# 预设快捷问题
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
{"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"},
{"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"},
@@ -53,7 +53,7 @@ MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
{"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"},
]
# 预设检索结果
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
{
"id": "chunk-001",
@@ -97,7 +97,7 @@ MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
},
]
# 预设RAG问答答案模板按关键词匹配
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
"电动自行车": {
"text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。",
@@ -133,7 +133,7 @@ MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
},
}
# 预设合规分析结果
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
"task_id": "task-001",
"dashboard": {
@@ -310,7 +310,7 @@ MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
],
}
# 预设合规对话响应模板
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
"车身结构设计": {
"compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。",
@@ -329,7 +329,7 @@ MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
},
}
# 预设系统统计数据
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_SYSTEM_STATS: Dict[str, int] = {
"docs": 5,
"chunks": 510,
@@ -337,7 +337,7 @@ MOCK_SYSTEM_STATS: Dict[str, int] = {
"segments": 0,
}
# 预设系统配置
# Keep service responsibilities explicit so downstream behavior stays predictable.
MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
"llm": {
"model": "qwen-max",
@@ -358,17 +358,17 @@ MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
def get_mock_documents() -> List[Dict[str, Any]]:
"""获取预设法规文档列表"""
"""Return mock documents."""
return MOCK_DOCUMENTS
def get_mock_quick_questions() -> List[Dict[str, str]]:
"""获取预设快捷问题"""
"""Return mock quick questions."""
return MOCK_QUICK_QUESTIONS
def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
"""根据查询关键词返回预设检索结果"""
"""Return mock retrieval."""
results = []
for keyword, data in MOCK_RAG_ANSWERS.items():
if keyword in query:
@@ -389,7 +389,7 @@ def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
def get_mock_rag_answer(query: str) -> str:
"""根据查询关键词返回预设答案"""
"""Return mock rag answer."""
for keyword, data in MOCK_RAG_ANSWERS.items():
if keyword in query:
return data["text"]
@@ -397,14 +397,14 @@ def get_mock_rag_answer(query: str) -> str:
def get_mock_compliance_result(task_id: str) -> Dict[str, Any]:
"""获取预设合规分析结果"""
"""Return mock compliance result."""
result = MOCK_COMPLIANCE_RESULT.copy()
result["task_id"] = task_id
return result
def get_mock_compliance_chat_response(intent: str, query: str) -> str:
"""获取预设合规对话响应"""
"""Return mock compliance chat response."""
responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {})
if "合规" in query or "符合" in query:
return responses.get("compliance", "根据相关法规分析,该段落的合规性需进一步评估。")
@@ -416,10 +416,10 @@ def get_mock_compliance_chat_response(intent: str, query: str) -> str:
def generate_task_id() -> str:
"""生成任务ID"""
"""Handle generate task id."""
return f"task-{uuid.uuid4().hex[:8]}"
def generate_doc_id() -> str:
"""生成文档ID"""
"""Handle generate doc id."""
return f"doc-{uuid.uuid4().hex[:8]}"

View File

@@ -1,6 +1,8 @@
"""文档解析服务"""
"""Initialize the app.services.parser package."""
from .pdf_parser import PDFParser
from .docx_parser import DocxParser
# Keep package boundaries explicit so backend imports stay predictable.
__all__ = ["PDFParser", "DocxParser"]

View File

@@ -1,4 +1,4 @@
"""Word文档解析 - 使用python-docx"""
"""Provide service-layer logic for docx parser."""
from docx import Document
from docx.enum.text import WD_ALIGN_PARAGRAPH
@@ -6,27 +6,29 @@ from typing import List, Dict, Optional
from dataclasses import dataclass, field
from loguru import logger
import re
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class DocxParagraph:
"""段落内容"""
"""Represent the Docx Paragraph type."""
text: str
level: int = 0 # 标题级别0表示正文
level: int = 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
is_list: bool = False
list_number: Optional[str] = None
@dataclass
class DocxTable:
"""表格内容"""
"""Represent the Docx Table type."""
rows: List[List[str]]
markdown: str = ""
@dataclass
class DocxDocumentContent:
"""Word文档完整内容"""
"""Represent the Docx Document Content type."""
file_path: str
paragraphs: List[DocxParagraph]
tables: List[DocxTable]
@@ -35,21 +37,14 @@ class DocxDocumentContent:
class DocxParser:
"""Word文档解析器 - 基于python-docx"""
"""Provide the Docx Parser parser."""
def __init__(self):
"""Initialize the Docx Parser instance."""
self.document = None
def parse(self, file_path: str) -> DocxDocumentContent:
"""
解析Word文档
Args:
file_path: Word文档路径
Returns:
DocxDocumentContent: 解析后的文档内容
"""
"""Handle parse for the Docx Parser instance."""
logger.info(f"开始解析Word文档: {file_path}")
try:
@@ -60,16 +55,16 @@ class DocxParser:
tables=[]
)
# 提取文档元数据
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.metadata = self._extract_metadata()
# 提取段落
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.paragraphs = self._extract_paragraphs()
# 提取表格
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.tables = self._extract_tables()
# 生成Markdown格式文本
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.markdown_text = self._generate_markdown(doc_content)
logger.success(f"Word文档解析完成{len(doc_content.paragraphs)}个段落")
@@ -81,7 +76,7 @@ class DocxParser:
raise
def _extract_metadata(self) -> Dict[str, str]:
"""提取文档元数据"""
"""Handle extract metadata for this module for the Docx Parser instance."""
metadata = {}
try:
core_props = self.document.core_properties
@@ -98,7 +93,7 @@ class DocxParser:
return metadata
def _extract_paragraphs(self) -> List[DocxParagraph]:
"""提取所有段落"""
"""Handle extract paragraphs for this module for the Docx Parser instance."""
paragraphs = []
for para in self.document.paragraphs:
@@ -106,10 +101,10 @@ class DocxParser:
if not text:
continue
# 判断标题级别
# Keep service responsibilities explicit so downstream behavior stays predictable.
level = self._get_paragraph_level(para)
# 判断是否是列表项
# Keep service responsibilities explicit so downstream behavior stays predictable.
is_list, list_number = self._detect_list_item(para)
paragraph = DocxParagraph(
@@ -123,66 +118,61 @@ class DocxParser:
return paragraphs
def _get_paragraph_level(self, para) -> int:
"""
判断段落标题级别
Returns:
int: 标题级别0表示正文
"""
# 方法1检查段落样式
"""Handle get paragraph level for this module for the Docx Parser instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
style_name = para.style.name if para.style else ""
if "Heading" in style_name or "标题" in style_name:
# 从样式名称中提取级别
# Keep service responsibilities explicit so downstream behavior stays predictable.
match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
if match:
level = int(match.group(1) or match.group(2))
return level
# 方法2检查段落格式字号
# 标题通常字号较大
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
if para.paragraph_format:
# 可以根据字号判断,这里简化处理
# Keep service responsibilities explicit so downstream behavior stays predictable.
pass
# 方法3根据内容模式判断法规文档特征
# Keep service responsibilities explicit so downstream behavior stays predictable.
text = para.text.strip()
# 第一章、第X章 -> 二级标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^第[一二三四五六七八九十百]+章\s', text):
return 2
# 第X节 -> 三级标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
elif re.match(r'^第[一二三四五六七八九十百]+节\s', text):
return 3
# 第X条 -> 四级标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
elif re.match(r'^第[一二三四五六七八九十百]+条\s', text):
return 4
return 0 # 正文
return 0 # Keep service responsibilities explicit so downstream behavior stays predictable.
def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
"""检测是否是列表项"""
"""Handle detect list item for this module for the Docx Parser instance."""
text = para.text.strip()
# 数字列表1.、2.、1、[1]等
# Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^[\d]+[.、)\]]\s', text):
match = re.match(r'^([\d]+[.、)\]])\s', text)
return True, match.group(1) if match else None
# 中文数字列表:一、二、(一)等
# Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^[一二三四五六七八九十]+[、.)]\s', text):
match = re.match(r'^([一二三四五六七八九十]+[、.)])\s', text)
return True, match.group(1) if match else None
# 检查段落格式中的列表编号
# Keep service responsibilities explicit so downstream behavior stays predictable.
if para.paragraph_format and hasattr(para.paragraph_format, 'left_indent'):
# 有缩进的可能是列表项
# Keep service responsibilities explicit so downstream behavior stays predictable.
pass
return False, None
def _extract_tables(self) -> List[DocxTable]:
"""提取所有表格"""
"""Handle extract tables for this module for the Docx Parser instance."""
tables = []
for table in self.document.tables:
@@ -193,7 +183,7 @@ class DocxParser:
cells.append(cell.text.strip())
rows.append(cells)
# 转换为Markdown表格
# Keep service responsibilities explicit so downstream behavior stays predictable.
markdown = self._table_to_markdown(rows)
table_content = DocxTable(rows=rows, markdown=markdown)
@@ -202,34 +192,34 @@ class DocxParser:
return tables
def _table_to_markdown(self, rows: List[List[str]]) -> str:
"""将表格转换为Markdown格式"""
"""Handle table to markdown for this module for the Docx Parser instance."""
if not rows or len(rows) < 1:
return ""
lines = []
# 表头
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(rows) >= 1:
header = rows[0]
lines.append("| " + " | ".join(cell for cell in header) + " |")
lines.append("| " + " | ".join("---" for _ in header) + " |")
# 数据行
# Keep service responsibilities explicit so downstream behavior stays predictable.
for row in rows[1:]:
lines.append("| " + " | ".join(cell for cell in row) + " |")
return "\n".join(lines)
def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
"""生成Markdown格式文本"""
"""Handle generate markdown for this module for the Docx Parser instance."""
lines = []
# 文档标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
title = doc_content.metadata.get("title", "")
if title:
lines.append(f"# {title}\n")
else:
# 从第一个段落获取标题(如果是标题样式)
# Keep service responsibilities explicit so downstream behavior stays predictable.
for para in doc_content.paragraphs[:5]:
if para.level == 1:
lines.append(f"# {para.text}\n")
@@ -237,29 +227,29 @@ class DocxParser:
else:
lines.append(f"# {doc_content.file_path}\n")
# 元数据信息
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 文档信息\n")
for key, value in doc_content.metadata.items():
if value:
lines.append(f"- **{key}**: {value}")
# 正文内容
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 正文\n")
table_index = 0
for para in doc_content.paragraphs:
if para.level > 0:
# 标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
prefix = "#" * para.level
lines.append(f"\n{prefix} {para.text}\n")
elif para.is_list:
# 列表项
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append(f"- {para.text}")
else:
# 正文
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append(para.text)
# 添加表格
# Keep service responsibilities explicit so downstream behavior stays predictable.
if doc_content.tables:
lines.append("\n## 表格\n")
for i, table in enumerate(doc_content.tables):
@@ -269,18 +259,18 @@ class DocxParser:
return "\n".join(lines)
def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本"""
"""Parse to markdown for the Docx Parser instance."""
doc_content = self.parse(file_path)
return doc_content.markdown_text
def parse_docx(file_path: str) -> DocxDocumentContent:
"""便捷函数解析Word文档"""
"""Parse docx."""
parser = DocxParser()
return parser.parse(file_path)
def parse_docx_to_markdown(file_path: str) -> str:
"""便捷函数解析Word并返回Markdown"""
"""Parse docx to markdown."""
parser = DocxParser()
return parser.parse_to_markdown(file_path)

View File

@@ -1,14 +1,16 @@
"""MinerU多模态PDF解析 - 版面感知解析"""
"""Provide service-layer logic for mineru parser."""
from typing import Optional, Dict
from dataclasses import dataclass, field
from loguru import logger
import os
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class MinerUResult:
"""MinerU解析结果"""
"""Represent the Miner U Result type."""
file_path: str
markdown_text: str
metadata: Dict[str, str] = field(default_factory=dict)
@@ -17,21 +19,14 @@ class MinerUResult:
class MinerUParser:
"""
MinerU多模态PDF解析器
MinerU (magic-pdf) 是一个开源的高质量PDF解析工具
支持版面感知解析,能够识别文档中的标题、正文、表格、图片等元素,
并输出结构化的Markdown格式。
GitHub: https://github.com/opendatalab/MinerU
"""
"""Provide the Miner U Parser parser."""
def __init__(self):
"""Initialize the Miner U Parser instance."""
self.available = self._check_mineru_available()
def _check_mineru_available(self) -> bool:
"""检查MinerU是否可用"""
"""Handle check mineru available for this module for the Miner U Parser instance."""
try:
from magic_pdf.pipe.UNIPipe import UNIPipe
return True
@@ -40,16 +35,7 @@ class MinerUParser:
return False
def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
"""
使用MinerU解析PDF文档
Args:
file_path: PDF文件路径
output_dir: 输出目录(可选,用于保存解析产物)
Returns:
MinerUResult: 解析结果
"""
"""Handle parse for the Miner U Parser instance."""
logger.info(f"尝试使用MinerU解析: {file_path}")
if not self.available:
@@ -64,19 +50,19 @@ class MinerUParser:
from magic_pdf.pipe.UNIPipe import UNIPipe
from magic_pdf.libs.MakeContentConfig import DropMode
# 设置输出目录
# Keep service responsibilities explicit so downstream behavior stays predictable.
if output_dir is None:
output_dir = os.path.dirname(file_path)
# 创建解析管道
# OCR模式可以根据PDF类型选择
# auto: 自动判断是否需要OCR
# txt: 纯文本PDF无OCR
# ocr: 扫描件PDFOCR
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
pipe = UNIPipe(file_path, output_dir)
# 执行解析
# pipe_mk() 返回Markdown格式文本
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
markdown_content = pipe.pipe_mk()
logger.success(f"MinerU解析成功")
@@ -98,13 +84,13 @@ class MinerUParser:
)
def _extract_metadata(self, pipe) -> Dict[str, str]:
"""从解析管道提取元数据"""
"""Handle extract metadata for this module for the Miner U Parser instance."""
metadata = {}
try:
# MinerU解析管道中可能包含的元数据信息
# Keep service responsibilities explicit so downstream behavior stays predictable.
if hasattr(pipe, 'pdf_mid_data') and pipe.pdf_mid_data:
mid_data = pipe.pdf_mid_data
# 提取可能的元数据字段
# Keep service responsibilities explicit so downstream behavior stays predictable.
metadata = {
"page_count": str(mid_data.get("page_count", "")),
"language": str(mid_data.get("language", "")),
@@ -116,41 +102,27 @@ class MinerUParser:
return metadata
def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本"""
"""Parse to markdown for the Miner U Parser instance."""
result = self.parse(file_path)
return result.markdown_text if result.success else ""
class ParserOrchestrator:
"""
解析服务编排 - 按优先级选择解析器
解析策略:
1. 优先尝试MinerU版面感知能力强
2. MinerU失败时回退到基础PyMuPDF解析
"""
"""Represent the Parser Orchestrator type."""
def __init__(self):
"""Initialize the Parser Orchestrator instance."""
from .pdf_parser import PDFParser
self.mineru_parser = MinerUParser()
self.pdf_parser = PDFParser()
self.mineru_available = self.mineru_parser.available
def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
"""
解析PDF文档按优先级选择解析器
Args:
file_path: PDF文件路径
prefer_mineru: 是否优先使用MinerU
Returns:
str: Markdown格式文本
"""
"""Parse pdf for the Parser Orchestrator instance."""
markdown_text = ""
if prefer_mineru and self.mineru_available:
# 优先尝试MinerU
# Keep service responsibilities explicit so downstream behavior stays predictable.
result = self.mineru_parser.parse(file_path)
if result.success:
markdown_text = result.markdown_text
@@ -159,28 +131,20 @@ class ParserOrchestrator:
else:
logger.warning(f"MinerU解析失败回退到PyMuPDF: {result.error_message}")
# 回退到PyMuPDF基础解析
# Keep service responsibilities explicit so downstream behavior stays predictable.
logger.info("使用PyMuPDF基础解析")
markdown_text = self.pdf_parser.parse_to_markdown(file_path)
return markdown_text
def parse_docx(self, file_path: str) -> str:
"""解析Word文档"""
"""Parse docx for the Parser Orchestrator instance."""
from .docx_parser import DocxParser
docx_parser = DocxParser()
return docx_parser.parse_to_markdown(file_path)
def parse(self, file_path: str) -> str:
"""
根据文件类型选择解析器
Args:
file_path: 文件路径
Returns:
str: Markdown格式文本
"""
"""Handle parse for the Parser Orchestrator instance."""
ext = os.path.splitext(file_path)[1].lower()
if ext == ".pdf":
@@ -192,12 +156,12 @@ class ParserOrchestrator:
def parse_with_mineru(file_path: str) -> MinerUResult:
"""便捷函数使用MinerU解析"""
"""Parse with mineru."""
parser = MinerUParser()
return parser.parse(file_path)
def parse_pdf_smart(file_path: str) -> str:
"""便捷函数智能解析PDF自动选择最佳解析器"""
"""Parse pdf smart."""
orchestrator = ParserOrchestrator()
return orchestrator.parse_pdf(file_path)

View File

@@ -1,4 +1,4 @@
"""PDF文档解析 - 使用PyMuPDF基础解析"""
"""Provide service-layer logic for pdf parser."""
import fitz # PyMuPDF
from typing import List, Dict, Optional, Tuple
@@ -9,17 +9,17 @@ import re
@dataclass
class PDFPageContent:
"""PDF页面内容"""
"""Represent the P D F Page Content type."""
page_number: int
text: str
tables: List[str] = field(default_factory=list)
images: List[str] = field(default_factory=list) # 图片路径列表
images: List[str] = field(default_factory=list) # Keep service responsibilities explicit so downstream behavior stays predictable.
blocks: List[Dict] = field(default_factory=list)
@dataclass
class PDFDocumentContent:
"""PDF文档完整内容"""
"""Represent the P D F Document Content type."""
file_path: str
total_pages: int
pages: List[PDFPageContent]
@@ -28,23 +28,14 @@ class PDFDocumentContent:
class PDFParser:
"""PDF文档解析器 - 基于PyMuPDF"""
"""Provide the P D F Parser parser."""
def __init__(self):
"""Initialize the P D F Parser instance."""
self.pdf = None
def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
"""
解析PDF文档
Args:
file_path: PDF文件路径
extract_tables: 是否提取表格
extract_images: 是否提取图片
Returns:
PDFDocumentContent: 解析后的文档内容
"""
"""Handle parse for the P D F Parser instance."""
logger.info(f"开始解析PDF文档: {file_path}")
try:
@@ -55,16 +46,16 @@ class PDFParser:
pages=[]
)
# 提取文档元数据
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.metadata = self._extract_metadata()
# 逐页解析
# Keep service responsibilities explicit so downstream behavior stays predictable.
for page_num in range(self.pdf.page_count):
page = self.pdf[page_num]
page_content = self._parse_page(page, page_num + 1, extract_tables, extract_images)
doc_content.pages.append(page_content)
# 生成Markdown格式文本
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_content.markdown_text = self._generate_markdown(doc_content)
self.pdf.close()
@@ -77,7 +68,7 @@ class PDFParser:
raise
def _extract_metadata(self) -> Dict[str, str]:
"""提取PDF元数据"""
"""Handle extract metadata for this module for the P D F Parser instance."""
metadata = {}
try:
meta = self.pdf.metadata
@@ -97,23 +88,23 @@ class PDFParser:
def _parse_page(self, page: fitz.Page, page_num: int,
extract_tables: bool, extract_images: bool) -> PDFPageContent:
"""解析单页内容"""
"""Handle parse page for this module for the P D F Parser instance."""
page_content = PDFPageContent(page_number=page_num, text="")
# 提取文本块(保留结构)
# Keep service responsibilities explicit so downstream behavior stays predictable.
blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
page_content.blocks = blocks
# 提取纯文本
# Keep service responsibilities explicit so downstream behavior stays predictable.
text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE)
page_content.text = text.strip()
# 提取表格使用PyMuPDF的表格提取功能
# Keep service responsibilities explicit so downstream behavior stays predictable.
if extract_tables:
tables = self._extract_tables_from_page(page)
page_content.tables = tables
# 提取图片
# Keep service responsibilities explicit so downstream behavior stays predictable.
if extract_images:
images = self._extract_images_from_page(page, page_num)
page_content.images = images
@@ -121,25 +112,22 @@ class PDFParser:
return page_content
def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
"""
从页面提取表格(基于文本块分析)
注意PyMuPDF基础版表格提取能力有限复杂表格建议使用MinerU
"""
"""Handle extract tables from page for this module for the P D F Parser instance."""
tables = []
try:
# 使用PyMuPDF的表格提取方法2.4+版本)
# 对于更复杂的表格需要在mineru_parser中使用更高级的方法
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
tabs = page.find_tables()
if tabs:
for tab in tabs:
table_text = tab.extract()
# 将表格转换为Markdown格式
# Keep service responsibilities explicit so downstream behavior stays predictable.
markdown_table = self._table_to_markdown(table_text)
tables.append(markdown_table)
except AttributeError:
# 旧版本PyMuPDF没有表格提取功能
# Keep service responsibilities explicit so downstream behavior stays predictable.
logger.warning("PyMuPDF版本不支持表格提取请升级到2.4+版本")
except Exception as e:
logger.warning(f"表格提取失败: {e}")
@@ -147,28 +135,28 @@ class PDFParser:
return tables
def _table_to_markdown(self, table_data: List[List[str]]) -> str:
"""将表格数据转换为Markdown格式"""
"""Handle table to markdown for this module for the P D F Parser instance."""
if not table_data or len(table_data) < 1:
return ""
lines = []
# 表头
# Keep service responsibilities explicit so downstream behavior stays predictable.
if len(table_data) >= 1:
header = table_data[0]
lines.append("| " + " | ".join(str(cell).strip() for cell in header) + " |")
lines.append("| " + " | ".join("---" for _ in header) + " |")
# 数据行
# Keep service responsibilities explicit so downstream behavior stays predictable.
for row in table_data[1:]:
lines.append("| " + " | ".join(str(cell).strip() for cell in row) + " |")
return "\n".join(lines)
def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
"""提取页面图片"""
"""Handle extract images from page for this module for the P D F Parser instance."""
images = []
# 图片提取功能(可选实现)
# 这里仅记录图片信息,实际图片需要额外保存
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
try:
image_list = page.get_images()
for img_index, img in enumerate(image_list):
@@ -179,52 +167,52 @@ class PDFParser:
return images
def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
"""生成Markdown格式文本"""
"""Handle generate markdown for this module for the P D F Parser instance."""
lines = []
# 文档标题
# Keep service responsibilities explicit so downstream behavior stays predictable.
title = doc_content.metadata.get("title", "")
if title:
lines.append(f"# {title}\n")
else:
lines.append(f"# {doc_content.file_path}\n")
# 元数据信息
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 文档信息\n")
for key, value in doc_content.metadata.items():
if value and key in ["author", "subject", "keywords", "creation_date"]:
lines.append(f"- **{key}**: {value}")
# 正文内容
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append("\n## 正文\n")
for page in doc_content.pages:
# 页码标记
# Keep service responsibilities explicit so downstream behavior stays predictable.
lines.append(f"\n---\n**第 {page.page_number} 页**\n")
# 处理文本内容,识别标题结构
# Keep service responsibilities explicit so downstream behavior stays predictable.
text = self._process_page_text(page.text, page.blocks)
lines.append(text)
# 添加表格
# Keep service responsibilities explicit so downstream behavior stays predictable.
for table in page.tables:
lines.append("\n" + table + "\n")
return "\n".join(lines)
def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
"""处理页面文本,识别标题结构"""
# 基于字体大小识别标题
"""Handle process page text for this module for the P D F Parser instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
processed_text = text
# 尝试识别标题(基于字号)
# 法规文档通常有明确的层级结构:章、节、条
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
processed_text = self._detect_headers(text, blocks)
return processed_text
def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
"""检测并标记标题(基于字号或内容模式)"""
"""Handle detect headers for this module for the P D F Parser instance."""
lines = text.split("\n")
processed_lines = []
@@ -233,8 +221,8 @@ class PDFParser:
if not line:
continue
# 法规标题模式检测
# 第一章、第X章、第X节、第X条等
# Keep service responsibilities explicit so downstream behavior stays predictable.
# Keep service responsibilities explicit so downstream behavior stays predictable.
if re.match(r'^第[一二三四五六七八九十百]+章\s', line):
processed_lines.append(f"\n## {line}\n")
elif re.match(r'^第[一二三四五六七八九十百]+节\s', line):
@@ -242,7 +230,7 @@ class PDFParser:
elif re.match(r'^第[一二三四五六七八九十百]+条\s', line):
processed_lines.append(f"\n#### {line}\n")
elif re.match(r'^[一二三四五六七八九十]+\s*[、.]', line):
# 条款子项
# Keep service responsibilities explicit so downstream behavior stays predictable.
processed_lines.append(f"- {line}")
else:
processed_lines.append(line)
@@ -250,18 +238,18 @@ class PDFParser:
return "\n".join(processed_lines)
def parse_to_markdown(self, file_path: str) -> str:
"""直接解析并返回Markdown文本"""
"""Parse to markdown for the P D F Parser instance."""
doc_content = self.parse(file_path)
return doc_content.markdown_text
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
"""便捷函数解析PDF文档"""
"""Parse pdf."""
parser = PDFParser()
return parser.parse(file_path, **kwargs)
def parse_pdf_to_markdown(file_path: str) -> str:
"""便捷函数解析PDF并返回Markdown"""
"""Parse pdf to markdown."""
parser = PDFParser()
return parser.parse_to_markdown(file_path)

View File

@@ -1,11 +1,29 @@
"""RAG服务模块"""
"""Initialize the app.services.rag package."""
# Keep package boundaries explicit so backend imports stay predictable.
from .retriever import Retriever, retrieve_regulations
from .context_builder import ContextBuilder, build_rag_context
from .prompt_templates import PromptTemplates, get_prompt_template
__all__ = [
"Retriever", "retrieve_regulations",
"ContextBuilder", "build_rag_context",
"PromptTemplates", "get_prompt_template"
"Retriever",
"retrieve_regulations",
"ContextBuilder",
"build_rag_context",
"PromptTemplates",
"get_prompt_template",
]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name in {"Retriever", "retrieve_regulations"}:
from .retriever import Retriever, retrieve_regulations
return {"Retriever": Retriever, "retrieve_regulations": retrieve_regulations}[name]
if name in {"ContextBuilder", "build_rag_context"}:
from .context_builder import ContextBuilder, build_rag_context
return {"ContextBuilder": ContextBuilder, "build_rag_context": build_rag_context}[name]
if name in {"PromptTemplates", "get_prompt_template"}:
from .prompt_templates import PromptTemplates, get_prompt_template
return {"PromptTemplates": PromptTemplates, "get_prompt_template": get_prompt_template}[name]
raise AttributeError(name)

View File

@@ -1,4 +1,4 @@
"""RAG上下文构建服务 - 构建LLM输入上下文"""
"""Provide service-layer logic for context builder."""
from typing import List, Dict, Optional
from dataclasses import dataclass
@@ -6,11 +6,13 @@ from loguru import logger
from .retriever import RetrievedDocument
from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class RAGContext:
"""RAG构建的上下文"""
"""Represent the R A G Context type."""
system_prompt: str
context_text: str
user_query: str
@@ -20,14 +22,7 @@ class RAGContext:
class ContextBuilder:
"""
RAG上下文构建器
功能:
- 格式化检索结果为上下文文本
- 控制上下文长度token限制
- 构建完整的LLM输入格式
"""
"""Provide the Context Builder builder."""
def __init__(
self,
@@ -35,14 +30,7 @@ class ContextBuilder:
include_metadata: bool = True,
citation_format: str = "【条款{clause}"
):
"""
初始化上下文构建器
Args:
max_context_tokens: 最大上下文token数
include_metadata: 是否包含元数据(文档名、条款号等)
citation_format: 引用格式模板
"""
"""Initialize the Context Builder instance."""
self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
self.include_metadata = include_metadata
self.citation_format = citation_format
@@ -56,30 +44,19 @@ class ContextBuilder:
system_prompt: Optional[str] = None,
max_tokens: Optional[int] = None
) -> RAGContext:
"""
构建RAG上下文
Args:
query: 用户查询
documents: 检索到的文档列表
system_prompt: 系统提示词(可选)
max_tokens: 最大token数可选覆盖默认值
Returns:
RAGContext: 构建的上下文对象
"""
"""Handle build for the Context Builder instance."""
max_tokens = max_tokens or self.max_context_tokens
# 格式化文档内容
# Keep service responsibilities explicit so downstream behavior stays predictable.
context_text, sources, truncated = self._format_documents(
documents,
max_tokens
)
# 构建系统提示词
# Keep service responsibilities explicit so downstream behavior stays predictable.
system_prompt = system_prompt or self._default_system_prompt()
# 估算总token数
# Keep service responsibilities explicit so downstream behavior stays predictable.
total_tokens = self._estimate_tokens(system_prompt + context_text + query)
logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
@@ -98,29 +75,20 @@ class ContextBuilder:
documents: List[RetrievedDocument],
max_tokens: int
) -> tuple:
"""
格式化文档内容
Args:
documents: 文档列表
max_tokens: 最大token数
Returns:
(context_text, sources, truncated)
"""
"""Handle format documents for this module for the Context Builder instance."""
context_parts = []
sources = []
current_tokens = 0
truncated = False
for i, doc in enumerate(documents):
# 格式化单个文档
# Keep service responsibilities explicit so downstream behavior stays predictable.
formatted = self._format_single_doc(doc, i + 1)
# 估算token数
# Keep service responsibilities explicit so downstream behavior stays predictable.
doc_tokens = self._estimate_tokens(formatted)
# 检查是否超出限制
# Keep service responsibilities explicit so downstream behavior stays predictable.
if current_tokens + doc_tokens > max_tokens:
truncated = True
logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
@@ -129,7 +97,7 @@ class ContextBuilder:
context_parts.append(formatted)
current_tokens += doc_tokens
# 记录来源
# Keep service responsibilities explicit so downstream behavior stays predictable.
sources.append({
"index": i + 1,
"doc_id": doc.doc_id,
@@ -148,13 +116,13 @@ class ContextBuilder:
doc: RetrievedDocument,
index: int
) -> str:
"""格式化单个文档"""
"""Handle format single doc for this module for the Context Builder instance."""
parts = []
# 索引编号
# Keep service responsibilities explicit so downstream behavior stays predictable.
parts.append(f"[{index}]")
# 元数据(可选)
# Keep service responsibilities explicit so downstream behavior stays predictable.
if self.include_metadata:
meta_parts = []
@@ -171,13 +139,13 @@ class ContextBuilder:
if meta_parts:
parts.append(" | ".join(meta_parts))
# 内容
# Keep service responsibilities explicit so downstream behavior stays predictable.
parts.append(doc.content)
return "\n".join(parts)
def _default_system_prompt(self) -> str:
"""默认系统提示词"""
"""Handle default system prompt for this module for the Context Builder instance."""
return """你是合规专家助手,基于提供的法规条款回答问题。
回答要求:
@@ -192,8 +160,8 @@ class ContextBuilder:
- 最后给出合规建议"""
def _estimate_tokens(self, text: str) -> int:
"""估算文本token数"""
# 中文字符约1.5 token英文约0.25 token
"""Handle estimate tokens for this module for the Context Builder instance."""
# Keep service responsibilities explicit so downstream behavior stays predictable.
chinese_chars = sum(1 for c in text if '' <= c <= '鿿')
other_chars = len(text) - chinese_chars
return int(chinese_chars * 1.5 + other_chars * 0.25)
@@ -202,15 +170,7 @@ class ContextBuilder:
self,
context: RAGContext
) -> List[Dict[str, str]]:
"""
构建LLM消息格式
Args:
context: RAG上下文对象
Returns:
List[Dict]: [{"role": "system/user/assistant", "content": "..."}]
"""
"""Build messages for the Context Builder instance."""
messages = [
{"role": "system", "content": context.system_prompt},
{"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
@@ -224,6 +184,6 @@ def build_rag_context(
documents: List[RetrievedDocument],
**kwargs
) -> RAGContext:
"""便捷函数构建RAG上下文"""
"""Build rag context."""
builder = ContextBuilder()
return builder.build(query, documents, **kwargs)

View File

@@ -1,12 +1,14 @@
"""RAG Prompt模板 - 合规问答专用Prompt"""
"""Provide service-layer logic for prompt templates."""
from typing import Dict, Optional
from dataclasses import dataclass
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class PromptTemplate:
"""Prompt模板"""
"""Represent the Prompt Template type."""
name: str
system_prompt: str
user_template: str
@@ -14,18 +16,9 @@ class PromptTemplate:
class PromptTemplates:
"""
合规问答Prompt模板库
"""Represent the Prompt Templates type."""
包含多种场景的Prompt模板
- 合规问答(标准)
- 条款解读(详细解释)
- 合规检查(判断合规状态)
- 差异对比(新旧法规对比)
- 报告生成(合规报告)
"""
# 合规问答标准模板
# Keep service responsibilities explicit so downstream behavior stays predictable.
COMPLIANCE_QA = PromptTemplate(
name="compliance_qa",
system_prompt="""你是合规专家助手,专门解答法规合规问题。
@@ -63,7 +56,7 @@ class PromptTemplates:
description="标准合规问答模板"
)
# 条款解读模板(详细解释)
# Keep service responsibilities explicit so downstream behavior stays predictable.
CLAUSE_INTERPRETATION = PromptTemplate(
name="clause_interpretation",
system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
@@ -96,7 +89,7 @@ class PromptTemplates:
description="条款详细解读模板"
)
# 合规检查模板(判断合规状态)
# Keep service responsibilities explicit so downstream behavior stays predictable.
COMPLIANCE_CHECK = PromptTemplate(
name="compliance_check",
system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
@@ -140,7 +133,7 @@ class PromptTemplates:
description="合规检查评估模板"
)
# 差异对比模板(新旧法规对比)
# Keep service responsibilities explicit so downstream behavior stays predictable.
COMPARISON = PromptTemplate(
name="comparison",
system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
@@ -192,7 +185,7 @@ class PromptTemplates:
description="法规版本对比模板"
)
# 报告生成模板
# Keep service responsibilities explicit so downstream behavior stays predictable.
REPORT_GENERATION = PromptTemplate(
name="report_generation",
system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
@@ -222,7 +215,7 @@ class PromptTemplates:
description="合规报告生成模板"
)
# 文档摘要生成模板
# Keep service responsibilities explicit so downstream behavior stays predictable.
DOCUMENT_SUMMARY = PromptTemplate(
name="document_summary",
system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
@@ -263,7 +256,7 @@ class PromptTemplates:
@classmethod
def get_template(cls, name: str) -> Optional[PromptTemplate]:
"""获取指定模板"""
"""Return template for the Prompt Templates instance."""
templates = {
"compliance_qa": cls.COMPLIANCE_QA,
"clause_interpretation": cls.CLAUSE_INTERPRETATION,
@@ -276,7 +269,7 @@ class PromptTemplates:
@classmethod
def list_templates(cls) -> Dict[str, str]:
"""列出所有模板"""
"""List templates for the Prompt Templates instance."""
return {
"compliance_qa": cls.COMPLIANCE_QA.description,
"clause_interpretation": cls.CLAUSE_INTERPRETATION.description,
@@ -288,7 +281,7 @@ class PromptTemplates:
def get_prompt_template(name: str) -> PromptTemplate:
"""便捷函数获取Prompt模板"""
"""Return prompt template."""
template = PromptTemplates.get_template(name)
if not template:
raise ValueError(f"不存在的模板: {name}")

View File

@@ -1,192 +1,82 @@
"""RAG检索服务 - 封装Milvus检索"""
"""Provide service-layer logic for retriever."""
from __future__ import annotations
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
from typing import Any, Optional
from app.shared.bootstrap import get_retrieval_service
# Keep service responsibilities explicit so downstream behavior stays predictable.
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
from app.services.storage.milvus_client import MilvusClient, SearchResult
from app.config.settings import settings
@dataclass
class RetrievedDocument:
"""检索到的文档"""
"""Represent the Retrieved Document type."""
content: str
doc_id: str # 文档ID用于下载
doc_id: str
doc_name: str
section_title: str
clause_number: str
page_number: int
score: float
metadata: Dict[str, Any] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
class Retriever:
"""
RAG检索器
功能:
- 向量检索Dense + Sparse混合
- 重排序(可选)
- 过滤和筛选
"""
def __init__(
self,
top_k: int = None,
rerank: bool = False,
min_score: float = 0.3
):
"""
初始化检索器
Args:
top_k: 检索召回数量
rerank: 是否启用重排序
min_score: 最低相关性分数阈值
"""
self.top_k = top_k or settings.rag_top_k
"""Provide the Retriever retriever."""
def __init__(self, top_k: int = 5, rerank: bool = False, min_score: float = 0.0):
"""Initialize the Retriever instance."""
self.top_k = top_k
self.rerank = rerank
self.min_score = min_score
# 嵌入模型(延迟加载)
self.embedder: Optional[BGEM3Embedder] = None
# Milvus客户端延迟连接
self.milvus: Optional[MilvusClient] = None
logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")
def _init_embedder(self):
"""延迟初始化嵌入模型"""
if self.embedder is None:
logger.info("加载嵌入模型...")
self.embedder = BGEM3Embedder(model_name=settings.embedding_model)
def _init_milvus(self):
"""延迟初始化Milvus"""
if self.milvus is None:
logger.info("连接Milvus...")
self.milvus = MilvusClient()
self.milvus.connect()
self.milvus.create_collection(recreate=False)
self.milvus.load_collection()
def retrieve(
self,
query: str,
filters: Optional[str] = None,
top_k: Optional[int] = None
) -> List[RetrievedDocument]:
"""
检索相关文档
Args:
query: 查询文本
filters: 过滤条件(如 "regulation_type=='车辆安全'"
top_k: 返回数量(可选,覆盖默认值)
Returns:
List[RetrievedDocument]: 检索结果列表
"""
logger.info(f"执行检索: {query}")
# 初始化组件
self._init_embedder()
self._init_milvus()
# 生成查询向量
query_embedding = self.embedder.embed_single(query)
# 执行混合检索
results = self.milvus.hybrid_search(
query_dense=query_embedding['dense'].tolist(),
query_sparse=query_embedding['sparse'],
top_k=top_k or self.top_k,
filters=filters
)
# 转换为RetrievedDocument格式
documents = []
for r in results:
if r.score >= self.min_score:
doc = RetrievedDocument(
content=r.content,
doc_id=r.metadata.get("doc_id", ""),
doc_name=r.metadata.get("doc_name", ""),
section_title=r.metadata.get("section_title", ""),
clause_number=r.metadata.get("clause_number", ""),
page_number=r.metadata.get("page_number", 0),
score=r.score,
metadata=r.metadata
)
documents.append(doc)
logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
return documents
def retrieve_with_scores(
self,
query: str,
filters: Optional[str] = None
) -> List[Dict]:
"""
检索并返回完整结果(包含分数)
Args:
query: 查询文本
filters: 过滤条件
Returns:
List[Dict]: 包含分数的检索结果
"""
documents = self.retrieve(query, filters)
def retrieve(self, query: str, filters: Optional[str] = None, top_k: Optional[int] = None) -> list[RetrievedDocument]:
"""Handle retrieve for the Retriever instance."""
results = get_retrieval_service().retrieve(query=query, top_k=top_k or self.top_k, filters=filters)
return [
{
"content": doc.content,
"doc_id": doc.doc_id,
"doc_name": doc.doc_name,
"section_title": doc.section_title,
"clause_number": doc.clause_number,
"page_number": doc.page_number,
"score": doc.score
}
for doc in documents
RetrievedDocument(
content=item.content,
doc_id=item.doc_id,
doc_name=item.doc_name,
section_title=item.section_title,
clause_number=item.metadata.get("clause_number", ""),
page_number=item.page_number,
score=item.score,
metadata=item.metadata,
)
for item in results
if item.score >= self.min_score
]
def search_by_doc_name(
self,
query: str,
doc_name: str
) -> List[RetrievedDocument]:
"""按文档名称过滤检索"""
filters = f'doc_name=="{doc_name}"'
return self.retrieve(query, filters)
def retrieve_with_scores(self, query: str, filters: Optional[str] = None) -> list[dict]:
"""Handle retrieve with scores for the Retriever instance."""
return [
{
"content": item.content,
"doc_id": item.doc_id,
"doc_name": item.doc_name,
"section_title": item.section_title,
"clause_number": item.clause_number,
"page_number": item.page_number,
"score": item.score,
}
for item in self.retrieve(query, filters)
]
def search_by_regulation_type(
self,
query: str,
regulation_type: str
) -> List[RetrievedDocument]:
"""按法规类型过滤检索"""
filters = f'regulation_type=="{regulation_type}"'
return self.retrieve(query, filters)
def search_by_doc_name(self, query: str, doc_name: str) -> list[RetrievedDocument]:
"""Search by doc name for the Retriever instance."""
return self.retrieve(query, filters=f'doc_name == "{doc_name}"')
def search_by_regulation_type(self, query: str, regulation_type: str) -> list[RetrievedDocument]:
"""Search by regulation type for the Retriever instance."""
return self.retrieve(query, filters=f'regulation_type == "{regulation_type}"')
def close(self):
"""关闭连接"""
if self.milvus:
self.milvus.disconnect()
logger.info("检索器已关闭")
"""Release the resources held by this component."""
return None
def retrieve_regulations(
query: str,
top_k: int = 10,
filters: Optional[str] = None
) -> List[RetrievedDocument]:
"""便捷函数:检索法规"""
retriever = Retriever(top_k=top_k)
results = retriever.retrieve(query, filters)
retriever.close()
return results
def retrieve_regulations(query: str, top_k: int = 10, filters: Optional[str] = None) -> list[RetrievedDocument]:
"""Handle retrieve regulations."""
return Retriever(top_k=top_k).retrieve(query, filters)

View File

@@ -1,6 +1,18 @@
"""存储服务"""
"""Initialize the app.services.storage package."""
# Keep package boundaries explicit so backend imports stay predictable.
from .milvus_client import MilvusClient
from .minio_client import MinIOClient
__all__ = ["MilvusClient", "MinIOClient"]
def __getattr__(name: str):
"""Handle getattr for this module."""
if name == "MilvusClient":
from .milvus_client import MilvusClient
return MilvusClient
if name == "MinIOClient":
from .minio_client import MinIOClient
return MinIOClient
raise AttributeError(name)

View File

@@ -1,4 +1,4 @@
"""Milvus向量数据库客户端 - 存储与检索服务"""
"""Provide service-layer logic for milvus client."""
from pymilvus import (
connections,
@@ -17,11 +17,13 @@ import numpy as np
from ..embedding.text_chunker import TextChunk
from ..embedding.bge_m3_embedder import EmbeddingResult
from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
@dataclass
class SearchResult:
"""检索结果"""
"""Represent the Search Result type."""
id: int
content: str
score: float
@@ -30,7 +32,7 @@ class SearchResult:
@dataclass
class MilvusDocument:
"""Milvus文档数据结构"""
"""Represent the Milvus Document type."""
doc_id: str
chunk_id: str
content: str
@@ -46,7 +48,7 @@ class MilvusDocument:
class MilvusClient:
"""Milvus向量数据库客户端"""
"""Represent the Milvus Client type."""
COLLECTION_NAME = "regulations"
@@ -73,6 +75,7 @@ class MilvusClient:
collection_name: str = None,
db_name: str = None
):
"""Initialize the Milvus Client instance."""
self.host = host or settings.milvus_host
self.port = port or settings.milvus_port
self.collection_name = collection_name or settings.milvus_collection
@@ -84,7 +87,7 @@ class MilvusClient:
logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")
def connect(self) -> bool:
"""连接到Milvus服务器"""
"""Handle connect for the Milvus Client instance."""
try:
connections.connect(
alias="default",
@@ -101,7 +104,7 @@ class MilvusClient:
return False
def disconnect(self):
"""断开连接"""
"""Handle disconnect for the Milvus Client instance."""
try:
connections.disconnect("default")
self.connected = False
@@ -110,7 +113,7 @@ class MilvusClient:
logger.warning(f"断开连接时出错: {e}")
def create_collection(self, recreate: bool = False) -> bool:
"""创建Collection"""
"""Create collection for the Milvus Client instance."""
if not self.connected:
logger.warning("未连接到Milvus请先调用connect()")
return False
@@ -146,7 +149,7 @@ class MilvusClient:
return False
def _create_indexes(self):
"""创建向量索引"""
"""Handle create indexes for this module for the Milvus Client instance."""
if not self.collection:
return
@@ -177,13 +180,13 @@ class MilvusClient:
logger.warning(f"创建索引时出错: {e}")
def load_collection(self):
"""加载Collection到内存"""
"""Load collection for the Milvus Client instance."""
if self.collection:
self.collection.load()
logger.info(f"Collection已加载: {self.collection_name}")
def release_collection(self):
"""释放Collection内存"""
"""Handle release collection for the Milvus Client instance."""
if self.collection:
self.collection.release()
logger.info(f"Collection已释放: {self.collection_name}")
@@ -193,7 +196,7 @@ class MilvusClient:
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""插入文档分块和嵌入向量"""
"""Handle insert chunks for the Milvus Client instance."""
if not self.collection:
logger.warning("Collection未初始化")
return []
@@ -246,7 +249,7 @@ class MilvusClient:
top_k: int = 10,
filters: Optional[str] = None
) -> List[SearchResult]:
"""混合检索Dense + Sparse"""
"""Handle hybrid search for the Milvus Client instance."""
if not self.collection:
logger.warning("Collection未初始化")
return []
@@ -254,10 +257,10 @@ class MilvusClient:
try:
self.collection.load()
# 使用简单的Dense检索兼容所有版本
# Keep service responsibilities explicit so downstream behavior stays predictable.
dense_results = self.dense_search(query_dense, top_k, filters)
# 可选合并Sparse结果
# Keep service responsibilities explicit so downstream behavior stays predictable.
if query_sparse:
sparse_results = self.sparse_search(query_sparse, top_k, filters)
merged = self._merge_results(dense_results, sparse_results, top_k)
@@ -277,7 +280,7 @@ class MilvusClient:
top_k: int,
dense_weight: float = 0.6
) -> List[SearchResult]:
"""手动融合Dense和Sparse结果"""
"""Handle merge results for this module for the Milvus Client instance."""
sparse_weight = 1 - dense_weight
merged_dict = {}
@@ -318,7 +321,7 @@ class MilvusClient:
top_k: int = 10,
filters: Optional[str] = None
) -> List[SearchResult]:
"""纯Dense向量检索"""
"""Handle dense search for the Milvus Client instance."""
if not self.collection:
return []
@@ -375,7 +378,7 @@ class MilvusClient:
top_k: int = 10,
filters: Optional[str] = None
) -> List[SearchResult]:
"""纯Sparse向量检索"""
"""Handle sparse search for the Milvus Client instance."""
if not self.collection:
return []
@@ -427,7 +430,7 @@ class MilvusClient:
return []
def delete_by_doc_id(self, doc_id: str) -> int:
"""根据doc_id删除记录"""
"""Delete by doc id for the Milvus Client instance."""
if not self.collection:
return 0
@@ -441,7 +444,7 @@ class MilvusClient:
return 0
def get_collection_stats(self) -> Dict[str, Any]:
"""获取Collection统计信息"""
"""Return collection stats for the Milvus Client instance."""
if not self.collection:
return {}
@@ -458,7 +461,7 @@ class MilvusClient:
def create_milvus_client() -> MilvusClient:
"""便捷函数创建Milvus客户端"""
"""Create milvus client."""
client = MilvusClient()
client.connect()
client.create_collection(recreate=False)
@@ -470,7 +473,7 @@ def insert_documents(
chunks: List[TextChunk],
embeddings: EmbeddingResult
) -> List[int]:
"""便捷函数:插入文档"""
"""Handle insert documents."""
return client.insert_chunks(chunks, embeddings)
@@ -480,5 +483,5 @@ def search_regulations(
query_sparse: Dict[int, float],
top_k: int = 10
) -> List[SearchResult]:
"""便捷函数:检索法规"""
"""Search regulations."""
return client.hybrid_search(query_dense, query_sparse, top_k)

View File

@@ -1,4 +1,4 @@
"""MinIO对象存储客户端 - 文档文件存储"""
"""Provide service-layer logic for minio client."""
from minio import Minio
from minio.error import S3Error
@@ -8,10 +8,12 @@ from io import BytesIO
import os
from app.config.settings import settings
# Keep service responsibilities explicit so downstream behavior stays predictable.
class MinIOClient:
"""MinIO对象存储客户端"""
"""Represent the Min I O Client type."""
def __init__(
self,
@@ -21,16 +23,7 @@ class MinIOClient:
bucket: str = None,
secure: bool = None
):
"""
初始化MinIO客户端
Args:
endpoint: MinIO服务地址
access_key: 访问密钥
secret_key: 秘密密钥
bucket: 存储桶名称
secure: 是否使用HTTPS
"""
"""Initialize the Min I O Client instance."""
self.endpoint = endpoint or settings.minio_endpoint
self.access_key = access_key or settings.minio_access_key
self.secret_key = secret_key or settings.minio_secret_key
@@ -43,7 +36,7 @@ class MinIOClient:
logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")
def connect(self) -> bool:
"""连接MinIO服务"""
"""Handle connect for the Min I O Client instance."""
try:
self.client = Minio(
self.endpoint,
@@ -60,7 +53,7 @@ class MinIOClient:
return False
def ensure_bucket(self) -> bool:
"""确保存储桶存在"""
"""Handle ensure bucket for the Min I O Client instance."""
if not self.connected:
logger.warning("未连接MinIO请先调用connect()")
return False
@@ -82,17 +75,7 @@ class MinIOClient:
object_name: str,
metadata: Dict[str, Any] = None
) -> bool:
"""
上传本地文件到MinIO
Args:
file_path: 本地文件路径
object_name: MinIO对象名称
metadata: 元数据
Returns:
bool: 是否成功
"""
"""Handle upload file for the Min I O Client instance."""
if not self.connected:
self.connect()
self.ensure_bucket()
@@ -125,18 +108,7 @@ class MinIOClient:
content_type: str = "application/octet-stream",
metadata: Dict[str, Any] = None
) -> bool:
"""
上传字节数据到MinIO
Args:
data: 文件字节数据
object_name: MinIO对象名称
content_type: 内容类型
metadata: 元数据注意MinIO仅支持US-ASCII字符
Returns:
bool: 是否成功
"""
"""Handle upload bytes for the Min I O Client instance."""
if not self.connected:
self.connect()
self.ensure_bucket()
@@ -144,18 +116,18 @@ class MinIOClient:
try:
data_stream = BytesIO(data)
# 处理metadata仅保留ASCII安全字符
# Keep service responsibilities explicit so downstream behavior stays predictable.
safe_metadata = None
if metadata:
safe_metadata = {}
for key, value in metadata.items():
if isinstance(value, str):
# 只保留ASCII字符或转换为安全格式
# Keep service responsibilities explicit so downstream behavior stays predictable.
try:
value.encode('ascii')
safe_metadata[key] = value
except UnicodeEncodeError:
# 中文字符跳过或用占位符
# Keep service responsibilities explicit so downstream behavior stays predictable.
safe_metadata[key] = ""
else:
safe_metadata[key] = str(value)
@@ -181,16 +153,7 @@ class MinIOClient:
object_name: str,
file_path: str
) -> bool:
"""
从MinIO下载文件到本地
Args:
object_name: MinIO对象名称
file_path: 本地保存路径
Returns:
bool: 是否成功
"""
"""Handle download file for the Min I O Client instance."""
if not self.connected:
self.connect()
@@ -212,16 +175,7 @@ class MinIOClient:
object_name: str,
expires: int = 3600
) -> Optional[str]:
"""
获取对象下载URL临时URL
Args:
object_name: MinIO对象名称
expires: URL有效期
Returns:
str: 下载URL
"""
"""Return object url for the Min I O Client instance."""
if not self.connected:
self.connect()
@@ -238,15 +192,7 @@ class MinIOClient:
return None
def get_object_data(self, object_name: str) -> Optional[bytes]:
"""
获取对象数据(字节)
Args:
object_name: MinIO对象名称
Returns:
bytes: 文件数据
"""
"""Return object data for the Min I O Client instance."""
if not self.connected:
self.connect()
@@ -262,15 +208,7 @@ class MinIOClient:
return None
def delete_object(self, object_name: str) -> bool:
"""
删除对象
Args:
object_name: MinIO对象名称
Returns:
bool: 是否成功
"""
"""Delete object for the Min I O Client instance."""
if not self.connected:
self.connect()
@@ -284,15 +222,7 @@ class MinIOClient:
return False
def list_objects(self, prefix: str = "") -> list:
"""
列出存储桶中的对象
Args:
prefix: 对象名称前缀
Returns:
list: 对象列表
"""
"""List objects for the Min I O Client instance."""
if not self.connected:
self.connect()
@@ -305,15 +235,7 @@ class MinIOClient:
return []
def object_exists(self, object_name: str) -> bool:
"""
检查对象是否存在
Args:
object_name: MinIO对象名称
Returns:
bool: 是否存在
"""
"""Handle object exists for the Min I O Client instance."""
if not self.connected:
self.connect()
@@ -325,7 +247,7 @@ class MinIOClient:
return False
def _get_content_type(self, file_path: str) -> str:
"""根据文件扩展名获取Content-Type"""
"""Handle get content type for this module for the Min I O Client instance."""
ext = os.path.splitext(file_path)[1].lower()
content_types = {
'.pdf': 'application/pdf',
@@ -338,13 +260,13 @@ class MinIOClient:
return content_types.get(ext, 'application/octet-stream')
def close(self):
"""关闭连接MinIO客户端无需显式关闭"""
"""Release the resources held by this component."""
self.connected = False
logger.info("MinIO客户端已关闭")
def create_minio_client() -> MinIOClient:
"""便捷函数创建MinIO客户端"""
"""Create minio client."""
client = MinIOClient()
client.connect()
client.ensure_bucket()