This commit is contained in:
2026-05-14 15:07:34 +08:00
parent c2a398930d
commit 10d04c4083
179 changed files with 24073 additions and 1243 deletions

View File

@@ -0,0 +1,3 @@
"""Backend service package."""
__all__: list[str] = []

View File

@@ -0,0 +1,7 @@
# src/services/agent/__init__.py
"""Agent服务模块"""
from .qa_agent import QAAgent, ask_compliance_question
from .session_manager import SessionManager, ChatSession
__all__ = ["QAAgent", "ask_compliance_question", "SessionManager", "ChatSession"]

View File

@@ -0,0 +1,412 @@
# src/services/agent/qa_agent.py
"""RAG问答Agent - 合规智能问答核心实现"""
import time
from typing import List, Dict, Optional, Any, Generator
from dataclasses import dataclass, field
from loguru import logger
from app.services.llm import get_llm_client, BaseLLMClient, LLMResponse
from app.services.llm.llm_factory import LLMFactory
from app.services.rag.retriever import Retriever, RetrievedDocument
from app.services.rag.context_builder import ContextBuilder, RAGContext
from app.services.rag.prompt_templates import get_prompt_template, PromptTemplate
from app.config.settings import settings
@dataclass
class AgentResponse:
    """Result returned by the QA agent for a single question.

    A response is considered successful when ``error`` is ``None``.
    """

    # Generated answer text (empty string when the call failed).
    answer: str
    # Citation records for the documents used to build the context.
    sources: List[Dict] = field(default_factory=list)
    # Name of the model that produced the answer.
    model: str = ""
    # Wall-clock latency of the whole request, in milliseconds.
    latency_ms: int = 0
    # Number of documents returned by retrieval.
    retrieved_count: int = 0
    # Token count reported by the context builder.
    context_tokens: int = 0
    # Whether the context builder reported truncation.
    truncated: bool = False
    # Error description; ``None`` means the request succeeded.
    error: Optional[str] = None

    @property
    def is_success(self) -> bool:
        """Return ``True`` when no error was recorded."""
        failed = self.error is not None
        return not failed
@dataclass
class AgentConfig:
    """Tunable settings for the QA agent."""

    # LLM backend selection.
    llm_provider: str = "deepseek"
    llm_model: str = "deepseek-v4-flash"
    # Retrieval: number of documents requested and minimum score passed to the retriever.
    top_k: int = 5
    min_score: float = 0.3
    # Token budget handed to the context builder.
    max_context_tokens: int = 2000
    # Sampling temperature passed to the LLM.
    temperature: float = 0.7
    # Name of the prompt template used when the caller does not override it.
    prompt_template: str = "compliance_qa"
    # Forwarded to the context builder's ``include_metadata`` option.
    include_metadata: bool = True
class QAAgent:
    """
    Compliance question-answering agent (retrieval-augmented generation).

    Pipeline:
        1. Receive the user question.
        2. Retrieve related regulation clauses through ``Retriever``.
        3. Build the RAG context with ``ContextBuilder``.
        4. Call the LLM to generate the answer.
        5. Return the answer together with its cited sources.

    Example:
        agent = QAAgent()
        response = agent.ask("What are the motor vehicle safety inspection requirements?")
        print(response.answer)
        for source in response.sources:
            print(source["doc_name"], source["clause_number"])
    """

    def __init__(self, config: Optional[AgentConfig] = None):
        """
        Create the agent; heavy components are created lazily on first use.

        Args:
            config: Agent configuration. When omitted, a configuration is
                built from the application ``settings``.
        """
        self.config = config or AgentConfig(
            llm_provider=settings.llm_provider,
            llm_model=settings.llm_model,
            top_k=settings.rag_top_k,
            max_context_tokens=settings.rag_max_context_tokens,
        )
        # Lazily initialised collaborators (see the _init_* methods).
        self.llm: Optional[BaseLLMClient] = None
        self.retriever: Optional[Retriever] = None
        self.context_builder: Optional[ContextBuilder] = None
        logger.info(f"问答Agent初始化: provider={self.config.llm_provider}, model={self.config.llm_model}")

    # ------------------------------------------------------------------
    # Lazy initialisation
    # ------------------------------------------------------------------
    def _init_llm(self):
        """Create the LLM client on first use, preferring the globally cached one."""
        if self.llm is not None:
            return
        cached = LLMFactory.get_global_client(self.config.llm_provider, self.config.llm_model)
        if cached:
            self.llm = cached
            logger.debug(f"使用全局缓存的LLM客户端: {self.config.llm_provider} - {self.config.llm_model}")
            return
        logger.info("创建新的LLM客户端...")
        self.llm = get_llm_client(
            provider=self.config.llm_provider,
            model=self.config.llm_model,
            temperature=self.config.temperature,
        )

    def _init_retriever(self):
        """Create the retriever on first use."""
        if self.retriever is None:
            logger.info("初始化检索器...")
            self.retriever = Retriever(
                top_k=self.config.top_k,
                min_score=self.config.min_score,
            )

    def _init_context_builder(self):
        """Create the context builder on first use."""
        if self.context_builder is None:
            logger.info("初始化上下文构建器...")
            self.context_builder = ContextBuilder(
                max_context_tokens=self.config.max_context_tokens,
                include_metadata=self.config.include_metadata,
            )

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _elapsed_ms(start_time: float) -> int:
        """Return whole milliseconds elapsed since ``start_time`` (a ``time.time()`` value)."""
        return int((time.time() - start_time) * 1000)

    def _prepare_prompt(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str],
    ):
        """Build the RAG context and the chat messages for ``query``.

        Returns:
            A ``(context, messages)`` tuple.
        """
        self._init_context_builder()
        template = get_prompt_template(prompt_template or self.config.prompt_template)
        context = self.context_builder.build(
            query=query,
            documents=documents,
            system_prompt=template.system_prompt,
        )
        return context, self._build_messages(template, context)

    def _generate_answer(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str],
        start_time: float,
    ) -> AgentResponse:
        """Run context building plus the LLM call and wrap the outcome.

        An unsuccessful LLM response is reported through ``AgentResponse.error``
        instead of being returned as an (empty) successful answer.
        """
        context, messages = self._prepare_prompt(query, documents, prompt_template)
        self._init_llm()
        llm_response = self.llm.chat(
            messages=messages,
            temperature=self.config.temperature,
        )
        if not llm_response.is_success:
            return AgentResponse(
                answer="",
                retrieved_count=len(documents),
                error=llm_response.error,
            )
        return AgentResponse(
            answer=llm_response.content,
            sources=context.sources,
            model=llm_response.model,
            latency_ms=self._elapsed_ms(start_time),
            retrieved_count=len(documents),
            context_tokens=context.total_tokens,
            truncated=context.truncated,
        )

    def _build_messages(
        self,
        template: PromptTemplate,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """Render the system and user messages sent to the LLM."""
        user_content = template.user_template.format(
            context=context.context_text,
            query=context.user_query,
        )
        return [
            {"role": "system", "content": template.system_prompt},
            {"role": "user", "content": user_content},
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def ask(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a user question.

        Args:
            query: The user question.
            filters: Retrieval filter expression (e.g. "regulation_type=='...'").
            prompt_template: Prompt template name; overrides the configured default.

        Returns:
            AgentResponse with the answer and cited sources. Failures are
            reported through ``AgentResponse.error``; no exception is raised.
        """
        start_time = time.time()
        logger.info(f"收到问题: {query}")
        try:
            # Step 1: retrieve related regulations.
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)
            if retrieved_count == 0:
                return AgentResponse(
                    answer="抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问,或提供更具体的法规名称。",
                    retrieved_count=0,
                    error="no_retrieved_documents",
                )
            # Steps 2-4: build context, build messages, call the LLM.
            response = self._generate_answer(query, documents, prompt_template, start_time)
            if response.is_success:
                logger.success(f"问答完成: {response.latency_ms}ms, {retrieved_count}条引用")
            return response
        except Exception as e:
            # Top-level boundary: convert any failure into an error response.
            logger.error(f"问答失败: {e}")
            return AgentResponse(answer="", error=str(e))

    def ask_with_context(
        self,
        query: str,
        documents: List[RetrievedDocument],
        prompt_template: Optional[str] = None
    ) -> AgentResponse:
        """
        Answer a question from caller-supplied documents (no retrieval).

        Behaves like :meth:`ask` after the retrieval step: the configured
        temperature is used and an unsuccessful LLM response is surfaced
        through ``AgentResponse.error``.

        Args:
            query: The user question.
            documents: Already retrieved documents.
            prompt_template: Prompt template name.

        Returns:
            AgentResponse describing the outcome.
        """
        start_time = time.time()
        try:
            return self._generate_answer(query, documents, prompt_template, start_time)
        except Exception as e:
            logger.error(f"问答失败: {e}")
            return AgentResponse(answer="", error=str(e))

    def ask_stream(
        self,
        query: str,
        filters: Optional[str] = None,
        prompt_template: Optional[str] = None
    ) -> Generator[Dict[str, Any], None, None]:
        """
        Answer a question as a stream of events (intended for SSE).

        Event shapes:
            - {"event": "status", "data": "<text>"}   progress update
            - {"event": "sources", "data": [...]}     cited sources (at most 5)
            - {"event": "content", "data": "<text>"}  answer fragment
            - {"event": "done", "data": {...}}        completion statistics
            - {"event": "error", "data": "<text>"}    failure; no "done" follows

        Args:
            query: The user question.
            filters: Retrieval filter expression.
            prompt_template: Prompt template name.

        Yields:
            Dict: one event per item.
        """
        start_time = time.time()
        logger.info(f"收到流式问题: {query}")
        try:
            # Step 1: retrieve related regulations.
            yield {"event": "status", "data": "正在检索相关法规..."}
            self._init_retriever()
            documents = self.retriever.retrieve(query, filters)
            retrieved_count = len(documents)
            if retrieved_count == 0:
                yield {"event": "status", "data": "未找到相关法规"}
                yield {"event": "content", "data": "抱歉,未找到与您的问题相关的法规条款。请尝试用不同的关键词重新提问。"}
                yield {
                    "event": "done",
                    "data": {"latency_ms": self._elapsed_ms(start_time), "retrieved_count": 0},
                }
                return

            # Step 2: publish the retrieval result (only the first 5 citations).
            yield {"event": "status", "data": f"找到{retrieved_count}条相关法规,正在生成回答..."}
            sources = [
                {
                    "doc_name": doc.doc_name,
                    "doc_id": doc.doc_id,
                    "clause_number": doc.clause_number,
                    "score": doc.score,
                }
                for doc in documents[:5]
            ]
            yield {"event": "sources", "data": sources}

            # Steps 3-4: build the context and the LLM messages.
            context, messages = self._prepare_prompt(query, documents, prompt_template)

            # Step 5: generate the answer, streaming when the client supports it.
            self._init_llm()
            if hasattr(self.llm, 'stream_chat'):
                yield {"event": "status", "data": "思考中..."}
                for chunk in self.llm.stream_chat(
                    messages=messages,
                    temperature=self.config.temperature,
                ):
                    yield {"event": "content", "data": chunk}
            else:
                # Fallback: a single blocking call emitted as one content event.
                yield {"event": "status", "data": "生成回答中..."}
                llm_response = self.llm.chat(
                    messages=messages,
                    temperature=self.config.temperature,
                )
                if not llm_response.is_success:
                    # Surface the failure instead of finishing with no content.
                    logger.error(f"流式问答失败: {llm_response.error}")
                    yield {"event": "error", "data": str(llm_response.error)}
                    return
                yield {"event": "content", "data": llm_response.content}

            # Step 6: completion event.
            latency_ms = self._elapsed_ms(start_time)
            logger.success(f"流式问答完成: {latency_ms}ms, {retrieved_count}条引用")
            yield {
                "event": "done",
                "data": {
                    "latency_ms": latency_ms,
                    "model": self.config.llm_model,
                    "retrieved_count": retrieved_count,
                    "context_tokens": context.total_tokens,
                },
            }
        except Exception as e:
            logger.error(f"流式问答失败: {e}")
            yield {"event": "error", "data": str(e)}

    def close(self):
        """Release agent resources.

        Only the retriever is closed; the LLM client is left open because it
        may be the globally cached instance shared with other agents.
        """
        if self.retriever:
            self.retriever.close()
        logger.info("问答Agent已关闭")
def ask_compliance_question(
    query: str,
    provider: str = "deepseek",
    model: str = "deepseek-v4-flash",
    top_k: int = 10
) -> AgentResponse:
    """
    Convenience wrapper: answer one compliance question with a throwaway agent.

    Args:
        query: The user question.
        provider: LLM provider name.
        model: LLM model name.
        top_k: Number of documents to retrieve.

    Returns:
        AgentResponse: the agent's response.
    """
    one_shot_agent = QAAgent(
        AgentConfig(llm_provider=provider, llm_model=model, top_k=top_k)
    )
    result = one_shot_agent.ask(query)
    one_shot_agent.close()
    return result

View File

@@ -0,0 +1,247 @@
# src/services/agent/session_manager.py
"""多轮对话会话管理"""
import time
import uuid
from typing import Dict, List, Optional
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class ChatMessage:
    """A single message within a chat session."""

    # Speaker role: "user", "assistant" or "system".
    role: str
    # Message text.
    content: str
    # Creation time as integer seconds (callers in this file use int(time.time())).
    timestamp: int
    # Citation records attached to the message (used for assistant replies).
    sources: List[Dict] = field(default_factory=list)
    # Free-form extra data.
    metadata: Dict = field(default_factory=dict)
@dataclass
class ChatSession:
    """A multi-turn chat session holding its ordered message history."""

    session_id: str
    messages: List["ChatMessage"] = field(default_factory=list)
    # Creation / last-modification times as integer Unix seconds.
    created_at: int = field(default_factory=lambda: int(time.time()))
    updated_at: int = field(default_factory=lambda: int(time.time()))
    metadata: Dict = field(default_factory=dict)

    def _append(self, message: "ChatMessage") -> "ChatMessage":
        """Store ``message`` and refresh ``updated_at``."""
        self.messages.append(message)
        self.updated_at = int(time.time())
        return message

    def add_user_message(self, content: str) -> "ChatMessage":
        """Append a user message and return it."""
        return self._append(
            ChatMessage(role="user", content=content, timestamp=int(time.time()))
        )

    def add_assistant_message(
        self,
        content: str,
        sources: Optional[List[Dict]] = None
    ) -> "ChatMessage":
        """Append an assistant message (optionally with citations) and return it."""
        return self._append(
            ChatMessage(
                role="assistant",
                content=content,
                timestamp=int(time.time()),
                sources=sources or [],
            )
        )

    def get_history(self, max_turns: int = 5) -> List[Dict[str, str]]:
        """Return the most recent messages as ``{"role", "content"}`` dicts.

        Args:
            max_turns: Number of turns to keep; one turn is counted as two
                messages (user + assistant). Values <= 0 yield an empty list.

        Returns:
            Up to ``max_turns * 2`` most recent messages, oldest first.
        """
        if max_turns <= 0:
            # Guard: messages[-0:] would otherwise return the WHOLE history.
            return []
        return [
            {"role": msg.role, "content": msg.content}
            for msg in self.messages[-(max_turns * 2):]
        ]

    def clear_history(self):
        """Drop all messages and refresh ``updated_at``."""
        self.messages = []
        self.updated_at = int(time.time())
        logger.info(f"会话历史已清空: {self.session_id}")

    @property
    def message_count(self) -> int:
        """Number of messages stored in the session."""
        return len(self.messages)

    @property
    def is_empty(self) -> bool:
        """Whether the session holds no messages."""
        return not self.messages
class SessionManager:
    """
    In-memory chat session manager.

    Responsibilities:
        - create / fetch / delete sessions
        - drop sessions that have been idle longer than the timeout
        - evict the oldest session when the session limit is reached

    Example:
        manager = SessionManager()
        session = manager.create_session()
        session.add_user_message("...")
        session.add_assistant_message("...", sources=[...])
        history = session.get_history(max_turns=3)
    """

    def __init__(
        self,
        max_sessions: int = 100,
        session_timeout_minutes: int = 30
    ):
        """
        Args:
            max_sessions: Maximum number of sessions kept in memory.
            session_timeout_minutes: Idle time (minutes) after which a session expires.
        """
        self.max_sessions = max_sessions
        # Stored in seconds to match the integer-second timestamps on sessions.
        self.session_timeout = session_timeout_minutes * 60
        # Session store (memory only; nothing is persisted).
        self._sessions: Dict[str, ChatSession] = {}
        logger.info(f"会话管理器初始化: max_sessions={max_sessions}, timeout={session_timeout_minutes}min")

    def _new_session_id(self) -> str:
        """Return an 8-character id that is not currently in use.

        The id is a truncated UUID4, so collisions are possible; retry until
        the id is free rather than silently overwriting a live session.
        """
        while True:
            candidate = str(uuid.uuid4())[:8]
            if candidate not in self._sessions:
                return candidate

    def create_session(self, metadata: Optional[Dict] = None) -> ChatSession:
        """
        Create and register a new session.

        When the session limit is reached, expired sessions are removed first;
        if that is not enough, the session with the smallest ``created_at`` is
        evicted.

        Args:
            metadata: Optional session metadata.

        Returns:
            ChatSession: the newly created session.
        """
        if len(self._sessions) >= self.max_sessions:
            self._cleanup_expired_sessions()
            # The emptiness check avoids min() raising ValueError on an empty
            # store, which happens when max_sessions <= 0.
            if self._sessions and len(self._sessions) >= self.max_sessions:
                oldest_id = min(
                    self._sessions,
                    key=lambda sid: self._sessions[sid].created_at,
                )
                self.delete_session(oldest_id)
                logger.warning(f"删除最老会话以腾出空间: {oldest_id}")

        session_id = self._new_session_id()
        session = ChatSession(session_id=session_id, metadata=metadata or {})
        self._sessions[session_id] = session
        logger.info(f"创建新会话: {session_id}")
        return session

    def get_session(self, session_id: str) -> Optional[ChatSession]:
        """
        Look up a session by id.

        An expired session is deleted as a side effect and reported as missing.

        Args:
            session_id: Session id.

        Returns:
            The session, or ``None`` when it does not exist or has expired.
        """
        session = self._sessions.get(session_id)
        if session and self._is_session_expired(session):
            self.delete_session(session_id)
            logger.info(f"会话已过期,已删除: {session_id}")
            return None
        return session

    def delete_session(self, session_id: str) -> bool:
        """
        Remove a session.

        Args:
            session_id: Session id.

        Returns:
            bool: ``True`` if a session was removed, ``False`` if the id was unknown.
        """
        if session_id not in self._sessions:
            return False
        del self._sessions[session_id]
        logger.info(f"删除会话: {session_id}")
        return True

    def list_sessions(self) -> List[Dict]:
        """
        Summarise every stored session.

        Returns:
            List[Dict]: one dict per session with ``session_id``,
            ``message_count``, ``created_at`` and ``updated_at``.
        """
        return [
            {
                "session_id": s.session_id,
                "message_count": s.message_count,
                "created_at": s.created_at,
                "updated_at": s.updated_at,
            }
            for s in self._sessions.values()
        ]

    def _is_session_expired(self, session: ChatSession) -> bool:
        """Return whether ``session`` has been idle longer than the timeout."""
        idle_seconds = int(time.time()) - session.updated_at
        return idle_seconds > self.session_timeout

    def _cleanup_expired_sessions(self) -> int:
        """Delete all expired sessions and return how many were removed."""
        # Collect ids first: the dict must not change size while iterated.
        expired_ids = [
            sid for sid, session in self._sessions.items()
            if self._is_session_expired(session)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        if expired_ids:
            logger.info(f"清理过期会话: {len(expired_ids)}")
        return len(expired_ids)

    def get_session_count(self) -> int:
        """Return the number of sessions currently stored."""
        return len(self._sessions)

    def clear_all_sessions(self):
        """Remove every session."""
        self._sessions.clear()
        logger.info("所有会话已清空")

View File

@@ -0,0 +1,404 @@
# src/services/document_processor.py
"""文档处理主流程 - 解析→摘要→分块→嵌入→入库"""
import os
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
import uuid
from .parser.pdf_parser import PDFParser
from .parser.docx_parser import DocxParser
from .parser.mineru_parser import ParserOrchestrator
from .embedding.text_chunker import RegulationChunker, TextChunk
from .embedding.bge_m3_embedder import BGEM3Embedder, EmbeddingResult
from .storage.milvus_client import MilvusClient
from .llm.document_summarizer import DocumentSummarizer, DocumentSummary
from app.config.settings import settings
@dataclass
class ProcessingResult:
    """Outcome of processing one document."""

    # Identity of the processed document.
    doc_id: str
    doc_name: str
    # Whether the pipeline reported success.
    success: bool
    # Number of records reported as inserted.
    num_chunks: int = 0
    # Human-readable status or error message.
    message: str = ""
    # Parsed Markdown text (empty when parsing did not succeed).
    markdown_text: str = ""
    # Generated summary and the time spent producing it (milliseconds).
    summary: str = ""
    summary_latency_ms: int = 0
class DocumentProcessor:
    """
    Document processing service covering the whole ingestion pipeline.

    Steps:
        1. Parse the document (PDF/DOCX -> Markdown).
        2. Optionally generate an LLM summary.
        3. Chunk the text with ``RegulationChunker``.
        4. Embed the chunks with ``BGEM3Embedder``.
        5. Store chunks and embeddings through ``MilvusClient``.
    """

    def __init__(
        self,
        chunk_size: Optional[int] = None,
        embedding_model: Optional[str] = None,
        use_mineru: bool = True,
        generate_summary: bool = False,  # off by default; the author notes it saves ~60 s
        llm_provider: Optional[str] = None,
        llm_model: Optional[str] = None
    ):
        """
        Args:
            chunk_size: Chunk size; falls back to ``settings.chunk_size``.
            embedding_model: Embedding model name; falls back to ``settings.embedding_model``.
            use_mineru: Whether PDF parsing should prefer MinerU.
            generate_summary: Whether to generate a document summary.
            llm_provider: LLM provider; falls back to ``settings.llm_provider``.
            llm_model: LLM model name; falls back to ``settings.llm_model``.
        """
        self.chunk_size = chunk_size or settings.chunk_size
        self.embedding_model = embedding_model or settings.embedding_model
        self.use_mineru = use_mineru
        self.generate_summary = generate_summary
        self.llm_provider = llm_provider or settings.llm_provider
        self.llm_model = llm_model or settings.llm_model

        logger.info("初始化文档处理组件...")
        # Parser and chunker are created eagerly.
        self.parser = ParserOrchestrator()
        self.chunker = RegulationChunker(chunk_size=self.chunk_size)
        # Embedder, Milvus client and summarizer are created on first use.
        self.embedder: Optional[BGEM3Embedder] = None
        self.milvus: Optional[MilvusClient] = None
        self.summarizer: Optional[DocumentSummarizer] = None
        logger.success("文档处理器初始化完成")

    # ------------------------------------------------------------------
    # Lazy initialisation
    # ------------------------------------------------------------------
    def _init_embedder(self):
        """Load the embedding model on first use."""
        if self.embedder is None:
            logger.info("加载嵌入模型...")
            self.embedder = BGEM3Embedder(model_name=self.embedding_model)

    def _init_milvus(self):
        """Connect to Milvus on first use and make sure the collection is loaded."""
        if self.milvus is None:
            logger.info("连接Milvus...")
            self.milvus = MilvusClient()
            self.milvus.connect()
            self.milvus.create_collection(recreate=False)
            self.milvus.load_collection()

    def _init_summarizer(self):
        """Create the summarizer on first use."""
        if self.summarizer is None:
            logger.info("初始化摘要生成器...")
            self.summarizer = DocumentSummarizer(
                provider=self.llm_provider,
                model=self.llm_model,
            )

    # ------------------------------------------------------------------
    # Pipeline
    # ------------------------------------------------------------------
    def process(
        self,
        file_path: str,
        doc_id: Optional[str] = None,
        doc_name: Optional[str] = None,
        regulation_type: str = "",
        version: str = ""
    ) -> ProcessingResult:
        """
        Process a single document end to end.

        Args:
            file_path: Path of the document file.
            doc_id: Document id (optional, generated when omitted).
            doc_name: Document name (optional, defaults to the file name).
            regulation_type: Regulation type.
            version: Document version.

        Returns:
            ProcessingResult: ``success`` is ``True`` only when records were
            actually stored; every failure is reported through ``message``.
        """
        if doc_id is None:
            doc_id = str(uuid.uuid4())[:8]
        if doc_name is None:
            doc_name = os.path.basename(file_path)
        logger.info(f"开始处理文档: {doc_name} (ID: {doc_id})")

        def failure(message: str, **extra) -> ProcessingResult:
            """Build a failed result for this document."""
            return ProcessingResult(
                doc_id=doc_id,
                doc_name=doc_name,
                success=False,
                message=message,
                **extra,
            )

        summary = ""
        summary_latency_ms = 0
        try:
            # 1. Parse.
            logger.info("Step 1: 文档解析")
            markdown_text = self._parse_document(file_path)
            if not markdown_text:
                return failure("文档解析失败,内容为空")

            # 2. Summary (optional); a failed summary does not abort the pipeline.
            if self.generate_summary:
                logger.info("Step 2: LLM摘要生成")
                self._init_summarizer()
                summary_result = self.summarizer.summarize(
                    doc_name,
                    markdown_text,
                    regulation_type,
                )
                if summary_result.is_success:
                    summary = summary_result.summary
                    summary_latency_ms = summary_result.latency_ms
                    logger.success(f"摘要生成完成: {summary_latency_ms}ms")
                else:
                    logger.warning(f"摘要生成失败: {summary_result.error}")
            else:
                logger.info("Step 2: 跳过摘要生成未勾选节省约60秒")

            # 3. Chunk.
            logger.info("Step 3: 智能分块")
            chunks = self._chunk_document(
                markdown_text,
                doc_id,
                doc_name,
                regulation_type,
                version,
            )
            if not chunks:
                return failure(
                    "分块失败,无有效内容",
                    markdown_text=markdown_text,
                    summary=summary,
                )

            # 4. Embed.
            logger.info("Step 4: 向量嵌入")
            embeddings = self._embed_chunks(chunks)
            if embeddings is None:
                return failure(
                    "向量嵌入失败",
                    markdown_text=markdown_text,
                    summary=summary,
                )

            # 5. Store. _insert_to_milvus swallows errors and returns [] on
            # failure, so an empty result must not be reported as success.
            logger.info("Step 5: 存储入库")
            inserted_ids = self._insert_to_milvus(chunks, embeddings)
            if not inserted_ids:
                return failure(
                    "存储入库失败",
                    markdown_text=markdown_text,
                    summary=summary,
                )

            logger.success(f"文档处理完成: {doc_name}, 共{len(inserted_ids)}条记录")
            return ProcessingResult(
                doc_id=doc_id,
                doc_name=doc_name,
                success=True,
                num_chunks=len(inserted_ids),
                message="处理成功",
                markdown_text=markdown_text,
                summary=summary,
                summary_latency_ms=summary_latency_ms,
            )
        except Exception as e:
            # Top-level boundary: report the failure instead of raising.
            logger.error(f"文档处理失败: {e}")
            return failure(f"处理失败: {e}")

    def _parse_document(self, file_path: str) -> str:
        """Parse ``file_path`` to Markdown; return "" for unsupported types or on error."""
        ext = os.path.splitext(file_path)[1].lower()
        try:
            if ext == ".pdf":
                # PDF: the orchestrator decides between MinerU and its fallback.
                markdown_text = self.parser.parse_pdf(file_path, prefer_mineru=self.use_mineru)
            elif ext in (".docx", ".doc"):
                # Word documents.
                markdown_text = self.parser.parse_docx(file_path)
            else:
                logger.warning(f"不支持的文件类型: {ext}")
                return ""
            logger.success(f"文档解析完成,内容长度: {len(markdown_text)}字符")
            return markdown_text
        except Exception as e:
            logger.error(f"文档解析失败: {e}")
            return ""

    def _chunk_document(
        self,
        markdown_text: str,
        doc_id: str,
        doc_name: str,
        regulation_type: str,
        version: str
    ) -> List[TextChunk]:
        """Chunk the Markdown text; return [] when chunking raises."""
        try:
            chunks = self.chunker.chunk_document(
                markdown_text,
                doc_id=doc_id,
                doc_name=doc_name,
                regulation_type=regulation_type,
                version=version,
            )
            logger.success(f"分块完成,共{len(chunks)}个chunk")
            return chunks
        except Exception as e:
            logger.error(f"分块失败: {e}")
            return []

    def _embed_chunks(self, chunks: List[TextChunk]) -> Optional[EmbeddingResult]:
        """Embed the chunk contents; return ``None`` when embedding raises."""
        try:
            self._init_embedder()
            texts = [chunk.content for chunk in chunks]
            embeddings = self.embedder.embed(texts)
            logger.success(f"嵌入完成,向量数: {len(embeddings.dense_embeddings)}")
            return embeddings
        except Exception as e:
            logger.error(f"嵌入失败: {e}")
            return None

    def _insert_to_milvus(
        self,
        chunks: List[TextChunk],
        embeddings: EmbeddingResult
    ) -> List[int]:
        """Insert chunks with their embeddings; return [] when insertion raises."""
        try:
            self._init_milvus()
            inserted_ids = self.milvus.insert_chunks(chunks, embeddings)
            logger.success(f"入库完成,共{len(inserted_ids)}条记录")
            return inserted_ids
        except Exception as e:
            logger.error(f"入库失败: {e}")
            return []

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------
    def search(
        self,
        query: str,
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[Dict]:
        """
        Search stored regulation content.

        Args:
            query: Query text.
            top_k: Number of results to return.
            filters: Filter expression.

        Returns:
            List[Dict]: dicts with ``id``, ``content``, ``score`` and
            ``metadata``; an empty list when the search fails.
        """
        logger.info(f"执行检索: {query}")
        try:
            self._init_embedder()
            self._init_milvus()
            # Embed the query, then run the hybrid (dense + sparse) search.
            query_embedding = self.embedder.embed_single(query)
            results = self.milvus.hybrid_search(
                query_dense=query_embedding['dense'].tolist(),
                query_sparse=query_embedding['sparse'],
                top_k=top_k,
                filters=filters,
            )
            result_dicts = [
                {
                    "id": r.id,
                    "content": r.content,
                    "score": r.score,
                    "metadata": r.metadata,
                }
                for r in results
            ]
            logger.success(f"检索完成,返回{len(result_dicts)}条结果")
            return result_dicts
        except Exception as e:
            logger.error(f"检索失败: {e}")
            return []

    def close(self):
        """Disconnect from Milvus if a connection was opened."""
        if self.milvus:
            self.milvus.disconnect()
        logger.info("文档处理器已关闭")
def process_document(
    file_path: str,
    doc_name: Optional[str] = None,
    regulation_type: str = "",
    version: str = ""
) -> ProcessingResult:
    """Convenience wrapper: process a single document with a throwaway processor.

    Args:
        file_path: Path of the document file.
        doc_name: Document name (optional, defaults to the file name).
        regulation_type: Regulation type.
        version: Document version.

    Returns:
        ProcessingResult: the processing outcome.
    """
    processor = DocumentProcessor()
    try:
        # Keyword arguments are required here: the second positional parameter
        # of DocumentProcessor.process is doc_id, not doc_name, so positional
        # passing would shift every argument by one slot.
        return processor.process(
            file_path,
            doc_name=doc_name,
            regulation_type=regulation_type,
            version=version,
        )
    finally:
        # Always release the Milvus connection, even if process() raises.
        processor.close()
def search_regulations(query: str, top_k: int = 10) -> List[Dict]:
    """Convenience wrapper: run one regulation search with a throwaway processor.

    Args:
        query: Query text.
        top_k: Number of results to return.

    Returns:
        List[Dict]: the search results.
    """
    one_shot = DocumentProcessor()
    hits = one_shot.search(query=query, top_k=top_k)
    one_shot.close()
    return hits

View File

@@ -0,0 +1,7 @@
# src/services/embedding/__init__.py
"""嵌入和分块服务"""
from .text_chunker import RegulationChunker
from .bge_m3_embedder import BGEM3Embedder
__all__ = ["RegulationChunker", "BGEM3Embedder"]

View File

@@ -0,0 +1,296 @@
# src/services/embedding/bge_m3_embedder.py
"""BGE-M3嵌入服务 - Dense+Sparse双路向量生成"""
import numpy as np
from typing import List, Dict, Optional, Union
from dataclasses import dataclass, field
from loguru import logger
import torch
import os
# Point HuggingFace downloads at a mirror (the original comment says this is
# for networks in mainland China) unless the caller already chose an endpoint.
os.environ.setdefault('HF_ENDPOINT', 'https://hf-mirror.com')

# Local model directories, checked in priority order:
#   1. ModelScope download location
#   2. HuggingFace local cache location
LOCAL_MODEL_PATHS = [
    os.path.expanduser(candidate)
    for candidate in (
        "~/.cache/modelscope/Xorbits/bge-m3",
        "~/.cache/huggingface/hub/models--BAAI--bge-m3/snapshots/main",
    )
]
@dataclass
class EmbeddingResult:
    """Embeddings produced for a batch of texts."""

    # Dense vectors (used for semantic retrieval).
    dense_embeddings: np.ndarray
    # Sparse vectors, one dict per text (used for keyword matching).
    sparse_embeddings: List[Dict[int, float]]
    # The input texts, in the same order as the vectors.
    texts: List[str]
    # Dense vector dimension.
    dim: int = 1024
class BGEM3Embedder:
    """
    Embedding service built on the BGE-M3 model (loaded via FlagEmbedding).

    The wrapper exposes two of the model's outputs:
        - dense vectors, used for semantic similarity search
        - sparse (lexical weight) vectors, used for keyword matching
    ColBERT vectors can be requested from the model but are not returned.

    GitHub: https://github.com/FlagOpen/FlagEmbedding
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        use_fp16: bool = True,
        device: Optional[str] = None,
        batch_size: int = 12,
        max_length: int = 8192,
        local_model_path: Optional[str] = None
    ):
        """
        Resolve the model location, pick a device and load the model.

        Args:
            model_name: Remote model name (ignored when a local copy is found).
            use_fp16: Whether to load the model in FP16.
            device: "cuda" / "cpu"; auto-detected when ``None``.
            batch_size: Batch size used by ``embed``.
            max_length: Maximum sequence length used by ``embed``.
            local_model_path: Explicit local model directory (takes priority).

        Raises:
            ImportError: FlagEmbedding is not installed.
            Exception: any error raised while loading the model.
        """
        self.use_fp16 = use_fp16
        self.batch_size = batch_size
        self.max_length = max_length

        # A local copy always wins over downloading the remote model.
        local_dir = self._find_local_model(local_model_path)
        if local_dir is not None:
            self.model_path = local_dir
            self.model_name = "local"
            logger.info(f"使用本地模型路径: {local_dir}")
        else:
            self.model_path = model_name
            self.model_name = model_name
            logger.info(f"使用远程模型: {model_name}")

        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        logger.info(f"初始化BGE-M3模型, 设备: {self.device}")

        self.model = None
        self._load_model()

    @staticmethod
    def _find_local_model(local_model_path: Optional[str]) -> Optional[str]:
        """Return the first usable local model directory, or ``None``.

        The explicit path only has to exist; the well-known cache locations
        must additionally contain a ``config.json``.
        """
        if local_model_path and os.path.exists(local_model_path):
            return local_model_path
        for path in LOCAL_MODEL_PATHS:
            if os.path.exists(path) and os.path.exists(os.path.join(path, "config.json")):
                return path
        return None

    def _load_model(self):
        """Load the FlagEmbedding model; log and re-raise on failure."""
        try:
            # Imported lazily so the module can be imported without FlagEmbedding.
            from FlagEmbedding import BGEM3FlagModel
            self.model = BGEM3FlagModel(
                self.model_path,
                use_fp16=self.use_fp16,
                device=self.device,
            )
            logger.success("BGE-M3模型加载成功")
        except ImportError:
            logger.warning("FlagEmbedding库未安装请运行: pip install FlagEmbedding")
            raise
        except Exception as e:
            logger.error(f"模型加载失败: {e}")
            raise

    def embed(
        self,
        texts: List[str],
        return_dense: bool = True,
        return_sparse: bool = True,
        return_colbert_vecs: bool = False
    ) -> "EmbeddingResult":
        """
        Embed a list of texts.

        Args:
            texts: Texts to embed.
            return_dense: Whether to compute dense vectors.
            return_sparse: Whether to compute sparse vectors.
            return_colbert_vecs: Whether to ask the model for ColBERT vectors.

        Returns:
            EmbeddingResult. An output that was not requested is returned
            empty (``np.array([])`` / ``[]``); ``dim`` falls back to 1024 when
            no dense vectors are available, and is 0 for empty input.

        Raises:
            Exception: any error raised by the model is logged and re-raised.
        """
        if not texts:
            logger.warning("输入文本列表为空")
            return EmbeddingResult(
                dense_embeddings=np.array([]),
                sparse_embeddings=[],
                texts=[],
                dim=0,
            )
        logger.info(f"开始嵌入{len(texts)}个文本块")
        try:
            embeddings = self.model.encode(
                texts,
                batch_size=self.batch_size,
                max_length=self.max_length,
                return_dense=return_dense,
                return_sparse=return_sparse,
                return_colbert_vecs=return_colbert_vecs,
            )
            # The key may be present with value None when an output was not
            # requested, in which case dict.get's default would not apply.
            dense_embeddings = embeddings.get('dense_vecs')
            if dense_embeddings is None:
                dense_embeddings = np.array([])
            sparse_embeddings = embeddings.get('lexical_weights')
            if sparse_embeddings is None:
                sparse_embeddings = []

            dim = dense_embeddings.shape[1] if len(dense_embeddings) > 0 else 1024
            logger.success(f"嵌入完成,向量维度: {dim}")
            return EmbeddingResult(
                dense_embeddings=dense_embeddings,
                sparse_embeddings=sparse_embeddings,
                texts=texts,
                dim=dim,
            )
        except Exception as e:
            logger.error(f"嵌入失败: {e}")
            raise

    def embed_single(self, text: str) -> Dict[str, Union[np.ndarray, Dict]]:
        """
        Embed one text.

        Args:
            text: Input text.

        Returns:
            Dict with keys ``dense`` (vector), ``sparse`` (dict, ``{}`` when
            unavailable) and ``dim``.
        """
        result = self.embed([text])
        return {
            'dense': result.dense_embeddings[0],
            'sparse': result.sparse_embeddings[0] if result.sparse_embeddings else {},
            'dim': result.dim,
        }

    def embed_dense(self, texts: List[str]) -> np.ndarray:
        """Return only the dense vectors for ``texts``."""
        result = self.embed(texts, return_sparse=False, return_colbert_vecs=False)
        return result.dense_embeddings

    def embed_sparse(self, texts: List[str]) -> List[Dict[int, float]]:
        """Return only the sparse vectors for ``texts``."""
        result = self.embed(texts, return_dense=False, return_colbert_vecs=False)
        return result.sparse_embeddings

    def embed_query(self, query: str) -> Dict:
        """
        Embed a query text for retrieval (same output as ``embed_single``).

        Args:
            query: Query text.

        Returns:
            Dict with the dense and sparse vectors.
        """
        return self.embed_single(query)

    def compute_similarity(
        self,
        query_embedding: np.ndarray,
        doc_embeddings: np.ndarray,
        metric: str = "cosine"
    ) -> np.ndarray:
        """
        Score documents against a query.

        Args:
            query_embedding: Query vector.
            doc_embeddings: Document matrix, one row per document.
            metric: "cosine" or "dot".

        Returns:
            np.ndarray: one score per document. With "cosine", a zero-norm
            query or document scores 0.0 instead of NaN/inf.

        Raises:
            ValueError: ``metric`` is not supported.
        """
        if metric == "cosine":
            dots = np.dot(doc_embeddings, query_embedding)
            norms = np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
            # Divide only where the denominator is non-zero; elsewhere keep 0.0.
            similarities = np.divide(
                dots,
                norms,
                out=np.zeros(np.shape(dots), dtype=np.result_type(dots, norms)),
                where=norms != 0,
            )
        elif metric == "dot":
            similarities = np.dot(doc_embeddings, query_embedding)
        else:
            raise ValueError(f"不支持的相似度度量: {metric}")
        return similarities

    def sparse_similarity(
        self,
        query_sparse: Dict[int, float],
        doc_sparse: Dict[int, float]
    ) -> float:
        """
        Score two sparse vectors by the dot product over their shared keys.

        Args:
            query_sparse: Query sparse vector (token id -> weight).
            doc_sparse: Document sparse vector.

        Returns:
            float: similarity score (0 when no key is shared).
        """
        common_keys = query_sparse.keys() & doc_sparse.keys()
        return sum(query_sparse[k] * doc_sparse[k] for k in common_keys)
def embed_texts(
    texts: List[str],
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> EmbeddingResult:
    """Convenience wrapper: embed ``texts`` with a freshly created embedder.

    Extra keyword arguments are forwarded to ``BGEM3Embedder``.
    """
    return BGEM3Embedder(model_name=model_name, **kwargs).embed(texts)
def embed_single_text(
    text: str,
    model_name: str = "BAAI/bge-m3",
    **kwargs
) -> Dict:
    """Convenience wrapper: embed one text with a freshly created embedder.

    Extra keyword arguments are forwarded to ``BGEM3Embedder``.
    """
    return BGEM3Embedder(model_name=model_name, **kwargs).embed_single(text)

View File

@@ -0,0 +1,449 @@
# src/services/embedding/text_chunker.py
"""智能分块器 - 章节级+条款级双粒度切割"""
import re
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from loguru import logger
@dataclass
class ChunkMetadata:
    """Metadata attached to one text chunk."""

    # Owning document.
    doc_id: str = ""
    doc_name: str = ""
    # Identifier of the chunk itself.
    chunk_id: str = ""
    # Chapter/section number and title.
    section_number: str = ""
    section_title: str = ""
    # Clause number.
    clause_number: str = ""
    page_number: int = 0
    # Start / end position of the chunk in the original text.
    start_position: int = 0
    end_position: int = 0
    # Regulation type and document version.
    regulation_type: str = ""
    version: str = ""
@dataclass
class TextChunk:
    """A chunk of document text together with its metadata."""

    # Chunk text.
    content: str
    # Where the chunk comes from (document, section, clause, offsets).
    metadata: ChunkMetadata
    # Estimated number of tokens in ``content``.
    token_count: int = 0
class RegulationChunker:
"""
Smart chunker for regulation documents.
Splits at two granularities, chapter/section level and clause level, and is
intended for the structure of national-standard (GB) documents:
- such documents usually have an explicit hierarchy: chapter > section > clause
- each clause should be treated as an independent semantic unit
- clauses are kept whole; a chunk should not cut across clauses
"""
# Heading patterns (applied to stripped lines): chapter, section and clause
# headings written with Chinese numerals.
CHAPTER_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+章\s+[^\n]+')
SECTION_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+节\s+[^\n]+')
CLAUSE_PATTERN = re.compile(r'^第[一二三四五六七八九十百]+条\s')
# Sub-item patterns inside a clause: parenthesised Chinese numerals and
# "1." / "1、" style numbered items.
# NOTE(review): SUB_ITEM_PATTERN only matches ASCII parentheses; confirm whether
# full-width parentheses should be accepted as well.
SUB_ITEM_PATTERN = re.compile(r'^[\(][一二三四五六七八九十]+[\)]\s')
NUMBER_ITEM_PATTERN = re.compile(r'^[\d]+[\.、]\s')
def __init__(
self,
chunk_size: int = 512,
chunk_overlap: int = 50,
max_chunk_size: int = 2048,
min_chunk_size: int = 100
):
"""
初始化分块器
Args:
chunk_size: 默认分块大小(字符数)
chunk_overlap: 分块重叠大小
max_chunk_size: 最大分块大小(防止单个条款过长)
min_chunk_size: 最小分块大小(防止碎片化)
"""
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.max_chunk_size = max_chunk_size
self.min_chunk_size = min_chunk_size
def chunk_document(
    self,
    markdown_text: str,
    doc_id: str = "",
    doc_name: str = "",
    regulation_type: str = "",
    version: str = ""
) -> List[TextChunk]:
    """
    Chunk a regulation document.

    The text is split into sections first and into clauses within each
    section; a clause longer than ``max_chunk_size`` is split further.

    Args:
        markdown_text: Document content in Markdown.
        doc_id: Document id.
        doc_name: Document name.
        regulation_type: Regulation type.
        version: Document version.

    Returns:
        List[TextChunk]: the chunks in document order.
    """
    logger.info(f"开始分块文档: {doc_name}")
    chunks = []
    # Level 1: sections. Level 2: clauses inside each section.
    for section_num, section_title, section_content, section_start in self._split_by_sections(markdown_text):
        clause_chunks = self._split_by_clauses(
            section_content,
            section_num,
            section_title,
            section_start,
        )
        for clause_text, clause_num, clause_title, start_pos, end_pos in clause_chunks:
            if len(clause_text) > self.max_chunk_size:
                # Over-long clause: split it and shift the relative offsets
                # by the clause's own start position.
                pieces = [
                    (sub_text, sub_start + start_pos, sub_end + start_pos)
                    for sub_text, sub_start, sub_end in self._split_long_clause(
                        clause_text,
                        clause_num,
                        clause_title,
                    )
                ]
            else:
                pieces = [(clause_text, start_pos, end_pos)]
            for piece_text, piece_start, piece_end in pieces:
                chunks.append(
                    self._create_chunk(
                        piece_text,
                        doc_id,
                        doc_name,
                        section_num,
                        section_title,
                        clause_num,
                        piece_start,
                        piece_end,
                        regulation_type,
                        version,
                    )
                )
    logger.success(f"分块完成,共{len(chunks)}个chunk")
    return chunks
def _split_by_sections(self, markdown_text: str) -> List[Tuple[str, str, str, int]]:
"""
按章节分割文档
Returns:
List of (section_number, section_title, section_content, start_position)
"""
sections = []
lines = markdown_text.split('\n')
current_section_num = ""
current_section_title = ""
current_section_content = []
current_section_start = 0
for i, line in enumerate(lines):
# 检测章节标题
chapter_match = self.CHAPTER_PATTERN.match(line.strip())
section_match = self.SECTION_PATTERN.match(line.strip())
if chapter_match or section_match:
# 保存上一个章节
if current_section_content:
content = '\n'.join(current_section_content)
sections.append((
current_section_num,
current_section_title,
content,
current_section_start
))
# 开始新章节
current_section_start = sum(len(l) + 1 for l in lines[:i])
current_section_content = []
if chapter_match:
current_section_num = line.strip()
current_section_title = self._extract_title(line.strip())
else:
current_section_num = line.strip()
current_section_title = self._extract_title(line.strip())
current_section_content.append(line)
# 保存最后一个章节
if current_section_content:
content = '\n'.join(current_section_content)
sections.append((
current_section_num,
current_section_title,
content,
current_section_start
))
# 如果没有检测到章节,将整个文档作为一个大章节
if not sections:
sections.append((
"",
"全文",
markdown_text,
0
))
return sections
def _split_by_clauses(
    self,
    section_content: str,
    section_num: str,
    section_title: str,
    section_start: int
) -> List[Tuple[str, str, str, int, int]]:
    """
    Split one section into clauses.

    A line whose stripped text matches ``CLAUSE_PATTERN`` opens a new clause,
    and that line is the first line of the clause content. Text before the
    first clause heading is returned as a leading clause with an empty number
    and title.

    Args:
        section_content: Text of the section.
        section_num: Section number (accepted for interface symmetry; unused).
        section_title: Section title, used as the clause title when the
            section contains no clause heading.
        section_start: Character offset of the section in the document.

    Returns:
        List of (content, clause_number, clause_title, start_position,
        end_position) with positions expressed as document offsets. When no
        clause heading matches, the whole section is returned as one clause
        titled with ``section_title``.
    """
    clauses: List[Tuple[str, str, str, int, int]] = []
    clause_num = ""
    clause_title = ""
    buffered: List[str] = []
    clause_start = section_start
    clause_seen = False
    # Running document offset of the current line (avoids re-summing the
    # lengths of all earlier lines for every clause heading).
    offset = section_start

    for line in section_content.split('\n'):
        heading = line.strip()
        if self.CLAUSE_PATTERN.match(heading):
            clause_seen = True
            # Flush the clause collected so far.
            if buffered:
                body = '\n'.join(buffered)
                clauses.append((body, clause_num, clause_title, clause_start, clause_start + len(body)))
            clause_start = offset
            buffered = []
            clause_num = self._extract_clause_number(heading)
            clause_title = heading
        buffered.append(line)
        offset += len(line) + 1  # +1 for the newline removed by split

    # The loop always leaves one anonymous clause behind, so the documented
    # "whole section as a single clause" fallback must be selected explicitly.
    if not clause_seen:
        return [(
            section_content,
            "",
            section_title,
            section_start,
            section_start + len(section_content)
        )]

    body = '\n'.join(buffered)
    clauses.append((body, clause_num, clause_title, clause_start, clause_start + len(body)))
    return clauses
def _split_long_clause(
    self,
    content: str,
    clause_num: str,
    clause_title: str
) -> List[Tuple[str, int, int]]:
    """
    Split an over-long clause into smaller pieces.

    If any line matches ``SUB_ITEM_PATTERN`` or ``NUMBER_ITEM_PATTERN`` the
    clause is cut at item boundaries; a piece shorter than ``min_chunk_size``
    is not emitted on its own but keeps absorbing the following items.
    Otherwise the clause is cut into blank-line separated paragraphs, which
    are packed together until adding the next paragraph would exceed
    ``chunk_size``.

    Args:
        content: Clause text.
        clause_num: Clause number (unused here; kept for the interface).
        clause_title: Clause title (unused here; kept for the interface).

    Returns:
        List of (text, start, end), where start/end are character offsets
        relative to the beginning of ``content``. In paragraph mode the span
        covers the source text from the first to the last packed paragraph,
        so it can be longer than ``text`` when blank lines were dropped.
    """
    lines = content.split('\n')

    def is_item(line: str) -> bool:
        stripped = line.strip()
        return bool(
            self.SUB_ITEM_PATTERN.match(stripped) or
            self.NUMBER_ITEM_PATTERN.match(stripped)
        )

    sub_chunks: List[Tuple[str, int, int]] = []

    if any(is_item(line) for line in lines):
        # Item mode: cut in front of every item marker.
        pending: List[str] = []
        pending_start = 0
        offset = 0  # running offset of the current line inside content
        for line in lines:
            if pending and is_item(line):
                piece = '\n'.join(pending)
                # Pieces that are too small stay pending and merge with the
                # next item instead of becoming a tiny chunk.
                if len(piece) >= self.min_chunk_size:
                    sub_chunks.append((piece, pending_start, pending_start + len(piece)))
                    pending = []
                    pending_start = offset
            pending.append(line)
            offset += len(line) + 1
        if pending:
            piece = '\n'.join(pending)
            sub_chunks.append((piece, pending_start, pending_start + len(piece)))
        return sub_chunks

    # Paragraph mode: collect (text, start, end) for every paragraph so the
    # reported positions are real offsets rather than accumulated lengths.
    paragraphs: List[Tuple[str, int, int]] = []
    para_lines: List[str] = []
    para_start = 0
    offset = 0
    for line in lines:
        if line.strip():
            if not para_lines:
                para_start = offset
            para_lines.append(line)
        elif para_lines:
            text = '\n'.join(para_lines)
            paragraphs.append((text, para_start, para_start + len(text)))
            para_lines = []
        offset += len(line) + 1
    if para_lines:
        text = '\n'.join(para_lines)
        paragraphs.append((text, para_start, para_start + len(text)))

    def emit(group: List[Tuple[str, int, int]]) -> None:
        sub_chunks.append((
            '\n'.join(text for text, _, _ in group),
            group[0][1],
            group[-1][2]
        ))

    # Pack paragraphs until the size budget would be exceeded.
    group: List[Tuple[str, int, int]] = []
    group_length = 0
    for para in paragraphs:
        if group and group_length + len(para[0]) > self.chunk_size:
            emit(group)
            group = []
            group_length = 0
        group.append(para)
        group_length += len(para[0])
    if group:
        emit(group)

    return sub_chunks
def _extract_title(self, header_line: str) -> str:
    """Return the heading text with a leading "第X章 " / "第X节 " marker removed.

    The marker is only removed when it uses Chinese numerals and is followed
    by whitespace; any other heading is returned stripped but otherwise
    unchanged.
    """
    marker = re.compile(r'^第[一二三四五六七八九十百]+[章节]\s+')
    without_marker = marker.sub('', header_line)
    return without_marker.strip()
def _extract_clause_number(self, clause_line: str) -> str:
    """Return the stripped text matched by ``CLAUSE_PATTERN`` at the start of the line.

    An empty string is returned when the line does not match.
    """
    found = self.CLAUSE_PATTERN.match(clause_line)
    return found.group(0).strip() if found else ""
def _create_chunk(
    self,
    content: str,
    doc_id: str,
    doc_name: str,
    section_num: str,
    section_title: str,
    clause_num: str,
    start_pos: int,
    end_pos: int,
    regulation_type: str,
    version: str
) -> TextChunk:
    """Assemble a TextChunk and its metadata for one piece of text.

    The content is stripped of surrounding whitespace, the token count is a
    rough estimate of 0.7 tokens per character, and the chunk id is built
    from the document id, section number, clause number and start offset.
    """
    text = content.strip()
    return TextChunk(
        content=text,
        metadata=ChunkMetadata(
            doc_id=doc_id,
            doc_name=doc_name,
            chunk_id=f"{doc_id}_{section_num}_{clause_num}_{start_pos}",
            section_number=section_num,
            section_title=section_title,
            clause_number=clause_num,
            start_position=start_pos,
            end_position=end_pos,
            regulation_type=regulation_type,
            version=version
        ),
        # Simplified estimate, not a real tokenizer count.
        token_count=int(len(text) * 0.7)
    )
def chunk_regulation_document(
    markdown_text: str,
    doc_id: str = "",
    doc_name: str = "",
    regulation_type: str = "",
    version: str = "",
    chunk_size: int = 512
) -> List[TextChunk]:
    """Convenience wrapper: chunk a regulation document with a one-off RegulationChunker.

    A new chunker is created with the given ``chunk_size`` and the remaining
    arguments are forwarded positionally to its ``chunk_document`` method.
    """
    return RegulationChunker(chunk_size=chunk_size).chunk_document(
        markdown_text,
        doc_id,
        doc_name,
        regulation_type,
        version
    )

View File

@@ -0,0 +1,15 @@
# src/services/llm/__init__.py
"""LLM服务模块"""
from .llm_factory import LLMFactory, get_llm_client
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient
from .document_summarizer import DocumentSummarizer, summarize_document, DocumentSummary
# Public names re-exported by this package, grouped by the module they come from.
__all__ = [
    "LLMFactory", "get_llm_client",
    "BaseLLMClient", "LLMResponse", "LLMConfig", "LLMProvider",
    "DeepSeekClient", "QwenClient", "QwenVLClient",
    "DocumentSummarizer", "summarize_document", "DocumentSummary"
]

View File

@@ -0,0 +1,116 @@
# src/services/llm/base_client.py
"""LLM客户端基类 - 统一接口定义"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Any
from enum import Enum
class LLMProvider(Enum):
    """LLM providers known to this package."""
    DEEPSEEK = "deepseek"
    QWEN = "qwen"
    QWEN_VL = "qwen_vl"
@dataclass
class LLMResponse:
    """Result of one LLM call, successful or failed."""
    content: str  # generated text
    model: str  # model name associated with the response
    usage: Dict[str, int] = field(default_factory=dict)  # usage counters
    finish_reason: str = "stop"
    latency_ms: int = 0  # call latency in milliseconds
    error: Optional[str] = None  # error description; None means success

    @property
    def is_success(self) -> bool:
        """Return True when no error was recorded."""
        return self.error is None
@dataclass
class LLMConfig:
    """Connection and sampling settings for one LLM client."""
    provider: LLMProvider
    model: str
    api_key: str
    base_url: str
    max_tokens: int = 4096
    temperature: float = 0.7
    top_p: float = 0.9
    timeout: int = 300  # default timeout of 300 s; summary/Skills generation may take a long time
class BaseLLMClient(ABC):
    """Abstract base class that defines the common LLM client interface."""

    def __init__(self, config: LLMConfig):
        """Store the configuration; subclasses create the transport client."""
        self.config = config
        self._client = None

    @abstractmethod
    def _init_client(self):
        """Create the underlying client object."""
        pass

    @abstractmethod
    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """
        Run a chat completion.

        Args:
            messages: Conversation messages,
                [{"role": "user/assistant/system", "content": "..."}].
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Additional provider-specific options.

        Returns:
            LLMResponse: The call result.
        """
        pass

    def complete(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """
        Single-turn completion (convenience wrapper around ``chat``).

        Args:
            prompt: User input.
            system_prompt: Optional system prompt, sent first when non-empty.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.

        Returns:
            LLMResponse: The call result.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        return self.chat(messages, max_tokens, temperature, **kwargs)

    @abstractmethod
    def get_available_models(self) -> List[str]:
        """Return the list of model names this client supports."""
        pass

    def estimate_tokens(self, text: str) -> int:
        """Roughly estimate the token count of ``text``.

        Characters in the CJK Unified Ideographs block count as 1.5 tokens,
        every other character as 0.25 tokens.
        """
        # The range is written with escapes so the bounds stay visible and
        # cannot be lost to an encoding problem; an empty lower bound would
        # classify every character (ASCII included) as Chinese.
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)

View File

@@ -0,0 +1,130 @@
# src/services/llm/deepseek_client.py
"""DeepSeek LLM客户端 - OpenAI兼容API"""
import time
from typing import List, Dict, Optional
from loguru import logger
import httpx
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
class DeepSeekClient(BaseLLMClient):
    """
    DeepSeek API client (OpenAI-compatible format).

    Supported models:
    - deepseek-chat
    - deepseek-coder
    - deepseek-reasoner
    - deepseek-v3
    - deepseek-v3.2
    - deepseek-v4-flash
    """

    SUPPORTED_MODELS = [
        "deepseek-chat",
        "deepseek-coder",
        "deepseek-reasoner",
        "deepseek-v3",
        "deepseek-v3.2",
        "deepseek-v4-flash"
    ]

    def __init__(self, config: LLMConfig):
        """Validate that the config targets DeepSeek and open the HTTP client.

        Raises:
            ValueError: If ``config.provider`` is not ``LLMProvider.DEEPSEEK``.
        """
        if config.provider != LLMProvider.DEEPSEEK:
            raise ValueError(f"配置provider应为DEEPSEEK实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the HTTP client with the bearer token and configured timeout."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            },
            timeout=self.config.timeout
        )
        logger.info(f"DeepSeek客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion.

        Failures are never raised: HTTP errors, transport errors and malformed
        responses are logged and returned as an ``LLMResponse`` whose
        ``error`` field is set.

        Args:
            messages: Conversation messages in OpenAI chat format.
            max_tokens: Output token limit; falls back to the configured value.
            temperature: Sampling temperature; falls back to the configured
                value only when omitted, so an explicit 0.0 is honoured.
            **kwargs: ``top_p`` is read from here when present.
        """
        start_time = time.time()
        try:
            payload = {
                "model": self.config.model,
                "messages": messages,
                "max_tokens": max_tokens or self.config.max_tokens,
                # `is None` instead of `or`: 0.0 is a valid temperature.
                "temperature": self.config.temperature if temperature is None else temperature,
                "top_p": kwargs.get("top_p", self.config.top_p),
                "stream": False
            }
            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()
            data = response.json()
            latency_ms = int((time.time() - start_time) * 1000)

            choices = data.get("choices")
            if not choices:
                # Report a missing/empty choices list explicitly instead of
                # relying on an IndexError being caught further down.
                logger.error("DeepSeek API响应缺少choices")
                return LLMResponse(
                    content="",
                    model=data.get("model", self.config.model),
                    latency_ms=latency_ms,
                    error="API响应缺少choices"
                )

            first_choice = choices[0]
            message = first_choice.get("message") or {}
            return LLMResponse(
                # content may be null in the JSON; normalise to an empty string
                content=message.get("content") or "",
                model=data.get("model", self.config.model),
                usage=data.get("usage") or {},
                finish_reason=first_choice.get("finish_reason") or "stop",
                latency_ms=latency_ms
            )
        except httpx.HTTPStatusError as e:
            logger.error(f"DeepSeek API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )
        except Exception as e:
            logger.error(f"DeepSeek调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def get_available_models(self) -> List[str]:
        """Return a copy of the supported model names."""
        # A copy keeps callers from mutating the shared class-level list.
        return list(self.SUPPORTED_MODELS)

    def close(self):
        """Close the underlying HTTP client if one was created."""
        if self._client:
            self._client.close()
def create_deepseek_client(
    api_key: str,
    model: str = "deepseek-v4-flash",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> DeepSeekClient:
    """Convenience helper: build a DeepSeekClient from plain arguments.

    Extra keyword arguments are forwarded to ``LLMConfig``.
    """
    return DeepSeekClient(
        LLMConfig(
            provider=LLMProvider.DEEPSEEK,
            model=model,
            api_key=api_key,
            base_url=base_url,
            **kwargs
        )
    )

View File

@@ -0,0 +1,231 @@
# src/services/llm/document_summarizer.py
"""文档摘要生成服务 - LLM生成法规文档摘要"""
from typing import Dict, Optional
from dataclasses import dataclass
from loguru import logger
from app.services.llm import get_llm_client, BaseLLMClient
from app.services.rag.prompt_templates import get_prompt_template
from app.config.settings import settings
@dataclass
class DocumentSummary:
    """Result of summarising one document."""
    doc_name: str
    summary: str  # summary text
    applicable_scope: str
    key_clauses: list
    key_terms: list
    compliance_points: list
    model: str
    latency_ms: int  # elapsed time in milliseconds
    error: Optional[str] = None  # error description; None means success

    @property
    def is_success(self) -> bool:
        """Return True when no error was recorded."""
        return self.error is None
class DocumentSummarizer:
    """
    Generate summaries of regulation documents with an LLM.

    The summary prompt comes from the "document_summary" template. From the
    model output the summarizer extracts the applicable scope and the key
    clause lines; key terms and compliance points are not parsed yet and are
    returned as empty lists.

    Example:
        summarizer = DocumentSummarizer()
        result = summarizer.summarize("GB 7258-2017", markdown_content)
        print(result.summary)
    """

    def __init__(
        self,
        provider: str = None,
        model: str = None,
        max_tokens: int = None
    ):
        """
        Initialise the summarizer.

        Args:
            provider: LLM provider; defaults to ``settings.llm_provider``.
            model: LLM model name; defaults to ``settings.llm_model``.
            max_tokens: Output token limit; defaults to
                ``settings.rag_summary_max_tokens``.
        """
        self.provider = provider or settings.llm_provider
        self.model = model or settings.llm_model
        self.max_tokens = max_tokens or settings.rag_summary_max_tokens
        # The LLM client is created lazily on first use.
        self.llm: Optional[BaseLLMClient] = None
        logger.info(f"摘要生成器初始化: provider={self.provider}, model={self.model}")

    def _init_llm(self):
        """Create the LLM client on first use."""
        if self.llm is None:
            self.llm = get_llm_client(
                provider=self.provider,
                model=self.model
            )

    def _failed_summary(self, doc_name: str, error: str, latency_ms: int) -> DocumentSummary:
        """Build the empty result returned for any failed summarisation."""
        return DocumentSummary(
            doc_name=doc_name,
            summary="",
            applicable_scope="",
            key_clauses=[],
            key_terms=[],
            compliance_points=[],
            model=self.model,
            latency_ms=latency_ms,
            error=error
        )

    def summarize(
        self,
        doc_name: str,
        content: str,
        regulation_type: str = "",
        max_tokens: Optional[int] = None
    ) -> DocumentSummary:
        """
        Generate the summary of one document.

        Errors are not raised; they are logged and returned in the ``error``
        field of the result.

        Args:
            doc_name: Document name.
            content: Document content (Markdown). Only the first 8000
                characters are sent to the model.
            regulation_type: Regulation type (currently not used).
            max_tokens: Output token limit for this call.

        Returns:
            DocumentSummary: The summary result.
        """
        import time
        start_time = time.time()
        logger.info(f"生成文档摘要: {doc_name}")
        try:
            self._init_llm()
            template = get_prompt_template("document_summary")
            user_content = template.user_template.format(
                doc_name=doc_name,
                content=content[:8000]  # truncate to stay within the token limit
            )
            response = self.llm.chat(
                messages=[
                    {"role": "system", "content": template.system_prompt},
                    {"role": "user", "content": user_content}
                ],
                max_tokens=max_tokens or self.max_tokens,
                temperature=0.3  # low temperature keeps the summary factual
            )
            latency_ms = int((time.time() - start_time) * 1000)
            if not response.is_success:
                return self._failed_summary(doc_name, response.error, latency_ms)

            summary_data = self._parse_summary(response.content)
            logger.success(f"摘要生成完成: {doc_name}, {latency_ms}ms")
            return DocumentSummary(
                doc_name=doc_name,
                summary=summary_data.get("summary", response.content),
                applicable_scope=summary_data.get("applicable_scope", ""),
                key_clauses=summary_data.get("key_clauses", []),
                key_terms=summary_data.get("key_terms", []),
                compliance_points=summary_data.get("compliance_points", []),
                model=response.model,
                latency_ms=latency_ms
            )
        except Exception as e:
            logger.error(f"摘要生成失败: {e}")
            # Report the time actually spent before the failure.
            elapsed_ms = int((time.time() - start_time) * 1000)
            return self._failed_summary(doc_name, str(e), elapsed_ms)

    def _parse_summary(self, content: str) -> Dict:
        """Extract structured fields from the model's summary text.

        Returns a dict with keys ``summary`` (the full text),
        ``applicable_scope`` (text after the last colon of the last line that
        mentions the scope), ``key_clauses`` (lines starting with a clause
        marker), and ``key_terms`` / ``compliance_points`` (always empty).
        """
        result = {
            "summary": content,
            "applicable_scope": "",
            "key_clauses": [],
            "key_terms": [],
            "compliance_points": []
        }
        for raw_line in content.split("\n"):
            line = raw_line.strip()
            if "适用范围" in line or "适用对象" in line:
                # Prefer the full-width colon (U+FF1A), else the ASCII colon.
                # The separator is spelled as an escape so it can never
                # degrade into an empty string, which str.split rejects.
                separator = "\uff1a" if "\uff1a" in line else ":"
                result["applicable_scope"] = line.split(separator)[-1].strip()
            if line.startswith(("- 【条款", "【条款")):
                result["key_clauses"].append(line)
            # TODO: key terms and compliance points are not extracted yet.
        return result

    def batch_summarize(
        self,
        documents: list
    ) -> list:
        """
        Summarise several documents one after another.

        Args:
            documents: [{"doc_name": str, "content": str}, ...]

        Returns:
            list: One DocumentSummary per input document, in input order.
        """
        return [self.summarize(doc["doc_name"], doc["content"]) for doc in documents]
def summarize_document(
    doc_name: str,
    content: str,
    **kwargs
) -> DocumentSummary:
    """Convenience helper: summarise one document with a one-off summarizer.

    Keyword arguments are passed to the ``DocumentSummarizer`` constructor.
    """
    return DocumentSummarizer(**kwargs).summarize(doc_name, content)

View File

@@ -0,0 +1,258 @@
# src/services/llm/llm_factory.py
"""LLM工厂 - 统一创建和管理LLM客户端"""
from typing import Optional, Dict, Any
from loguru import logger
from functools import lru_cache
from .base_client import BaseLLMClient, LLMConfig, LLMProvider, LLMResponse
from .deepseek_client import DeepSeekClient
from .qwen_client import QwenClient, QwenVLClient
# Default model used for each provider when the caller does not name one.
DEFAULT_MODELS = {
    LLMProvider.DEEPSEEK: "deepseek-v4-flash",
    LLMProvider.QWEN: "qwen3.5-flash",
    LLMProvider.QWEN_VL: "qwen3-vl-plus"
}
# Default API base URL per provider (all routed through one unified proxy service).
# NOTE(review): these are plain-HTTP endpoints on a fixed IP address - confirm
# that this is intended for every deployment environment.
DEFAULT_BASE_URLS = {
    LLMProvider.DEEPSEEK: "http://6.86.80.4:30080/v1",
    LLMProvider.QWEN: "http://6.86.80.4:30080/v1",
    LLMProvider.QWEN_VL: "http://6.86.80.4:30080/v1"
}
class LLMFactory:
    """
    Factory for LLM clients with a process-wide instance cache.

    Clients are cached in a class-level dictionary shared by all factory
    instances. The cache key is the canonical provider plus the model name,
    so aliases and different spellings of a provider ("qwen", "Qwen",
    "qwen-plus") share one entry. A cached client is returned as-is: a
    different ``api_key``, ``base_url`` or sampling setting passed to a later
    ``create`` call for the same provider/model is ignored.

    Example:
        factory = LLMFactory()
        client = factory.create("deepseek")
        client = factory.create("qwen", model="qwen-max", temperature=0.5)
        response = client.complete("Hello")
    """

    # Global client cache (class level, shared across instances).
    _global_instances: Dict[str, BaseLLMClient] = {}

    def __init__(self):
        self._config_cache: Dict[str, Any] = {}

    @staticmethod
    def _cache_key(provider_enum: LLMProvider, model: Optional[str]) -> str:
        """Build the cache key from the canonical provider value and model."""
        return f"{provider_enum.value}_{model}"

    @staticmethod
    def _resolve_provider(provider: str) -> LLMProvider:
        """Map a provider name or alias (case-insensitive) to its enum member.

        Raises:
            ValueError: If the name is not a known provider or alias.
        """
        provider_map = {
            "deepseek": LLMProvider.DEEPSEEK,
            "deepseek-v3": LLMProvider.DEEPSEEK,
            "deepseek_chat": LLMProvider.DEEPSEEK,
            "qwen": LLMProvider.QWEN,
            "qwen-turbo": LLMProvider.QWEN,
            "qwen-plus": LLMProvider.QWEN,
            "qwen-max": LLMProvider.QWEN,
            "qwen3.5-flash": LLMProvider.QWEN,
            "qwen3.5-plus": LLMProvider.QWEN,
            "qwen_vl": LLMProvider.QWEN_VL,
            "qwen-vl": LLMProvider.QWEN_VL,
            "qwen-vl-plus": LLMProvider.QWEN_VL,
            "qwen-vl-max": LLMProvider.QWEN_VL
        }
        provider_lower = provider.lower()
        if provider_lower not in provider_map:
            raise ValueError(f"不支持的提供商: {provider},支持的: {list(provider_map.keys())}")
        return provider_map[provider_lower]

    def create(
        self,
        provider: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        **kwargs
    ) -> BaseLLMClient:
        """
        Create an LLM client, or return the cached one for this provider/model.

        Args:
            provider: Provider name or alias ("deepseek", "qwen", "qwen_vl", ...).
            api_key: API key; read from environment variables when omitted.
            model: Model name; the provider default is used when omitted.
            base_url: API base URL; the provider default is used when omitted.
            max_tokens: Maximum number of output tokens.
            temperature: Sampling temperature.
            **kwargs: Additional ``LLMConfig`` fields.

        Returns:
            BaseLLMClient: The client instance.

        Raises:
            ValueError: If the provider is unknown or no API key is available.
        """
        provider_enum = self._parse_provider(provider)
        api_key = api_key or self._get_api_key(provider_enum)
        model = model or DEFAULT_MODELS.get(provider_enum)
        base_url = base_url or DEFAULT_BASE_URLS.get(provider_enum)
        if not api_key:
            raise ValueError("缺少API密钥请设置环境变量或传入api_key参数")

        # Keyed on the canonical provider so every alias hits the same entry.
        cache_key = self._cache_key(provider_enum, model)
        cached = LLMFactory._global_instances.get(cache_key)
        if cached is not None:
            logger.debug(f"使用缓存的LLM客户端: {cache_key}")
            return cached

        config = LLMConfig(
            provider=provider_enum,
            model=model,
            api_key=api_key,
            base_url=base_url,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs
        )
        client = self._create_client(config)
        LLMFactory._global_instances[cache_key] = client
        logger.info(f"LLM客户端创建成功并缓存: {provider} - {model}")
        return client

    def _parse_provider(self, provider: str) -> LLMProvider:
        """Resolve a provider name or alias; raises ValueError when unknown."""
        return self._resolve_provider(provider)

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Return the first non-empty API key found in the provider's environment variables."""
        import os
        key_map = {
            LLMProvider.DEEPSEEK: ["DEEPSEEK_API_KEY", "OPENAI_API_KEY"],
            LLMProvider.QWEN: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"],
            LLMProvider.QWEN_VL: ["QWEN_API_KEY", "DASHSCOPE_API_KEY", "ALIBABA_API_KEY"]
        }
        for key_name in key_map.get(provider, []):
            api_key = os.getenv(key_name)
            if api_key:
                return api_key
        return None

    def _create_client(self, config: LLMConfig) -> BaseLLMClient:
        """Instantiate the client class that belongs to ``config.provider``.

        Raises:
            ValueError: If no client class is registered for the provider.
        """
        client_map = {
            LLMProvider.DEEPSEEK: DeepSeekClient,
            LLMProvider.QWEN: QwenClient,
            LLMProvider.QWEN_VL: QwenVLClient
        }
        client_class = client_map.get(config.provider)
        if not client_class:
            raise ValueError(f"不支持的提供商: {config.provider}")
        return client_class(config)

    def get_cached(self, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the cached client for provider/model, or None.

        Raises:
            ValueError: If the provider is unknown.
        """
        provider_enum = self._parse_provider(provider)
        model = model or DEFAULT_MODELS.get(provider_enum)
        return LLMFactory._global_instances.get(self._cache_key(provider_enum, model))

    def list_available_providers(self) -> Dict[str, list]:
        """List the supported models of every provider."""
        return {
            "deepseek": DeepSeekClient.SUPPORTED_MODELS,
            "qwen": QwenClient.SUPPORTED_MODELS,
            "qwen_vl": QwenVLClient.SUPPORTED_MODELS
        }

    @classmethod
    def preload_clients(cls, providers: list = None):
        """
        Create and cache clients ahead of time (call at application start-up).

        A provider that fails to load is logged as a warning and skipped.

        Args:
            providers: Providers to preload; defaults to qwen and deepseek.
        """
        if providers is None:
            providers = ["qwen", "deepseek"]
        factory = cls()
        for provider in providers:
            try:
                factory.create(provider)
                logger.success(f"预加载LLM客户端成功: {provider}")
            except Exception as e:
                logger.warning(f"预加载LLM客户端失败: {provider} - {e}")

    @classmethod
    def get_global_client(cls, provider: str, model: Optional[str] = None) -> Optional[BaseLLMClient]:
        """Return the globally cached client for provider/model, or None.

        Unlike ``get_cached`` this never raises for an unknown provider.
        """
        try:
            provider_enum = cls._resolve_provider(provider)
        except ValueError:
            # Legacy fallback: a name that is not a registered alias but
            # starts with "qwen" (e.g. a bare model name) is treated as QWEN.
            if not provider.lower().startswith("qwen"):
                return None
            provider_enum = LLMProvider.QWEN
        model = model or DEFAULT_MODELS.get(provider_enum)
        return cls._global_instances.get(cls._cache_key(provider_enum, model))

    @classmethod
    def cleanup(cls):
        """Close every cached client and empty the cache.

        A client that fails to close is logged and does not stop the cleanup.
        """
        for cache_key, client in list(cls._global_instances.items()):
            try:
                client.close()
                logger.debug(f"关闭LLM客户端: {cache_key}")
            except Exception as e:
                logger.warning(f"关闭LLM客户端失败: {cache_key} - {e}")
        cls._global_instances.clear()
        logger.info("所有LLM客户端已清理")
@lru_cache
def get_llm_factory() -> LLMFactory:
    """Return the LLMFactory instance; lru_cache makes repeated calls return the same object."""
    return LLMFactory()
def get_llm_client(
    provider: str = "qwen",
    model: Optional[str] = None,
    **kwargs
) -> BaseLLMClient:
    """
    Convenience helper: get an LLM client, preferring a cached instance.

    Args:
        provider: Provider name.
        model: Model name.
        **kwargs: Extra options forwarded to ``LLMFactory.create`` when a new
            client has to be built.

    Returns:
        BaseLLMClient: The client instance.
    """
    factory = get_llm_factory()
    # Only fall through to create() when nothing usable is cached.
    return factory.get_cached(provider, model) or factory.create(provider, model=model, **kwargs)

View File

@@ -0,0 +1,392 @@
# src/services/llm/qwen_client.py
"""Qwen LLM客户端 - 支持OpenAI兼容API格式"""
import time
import json
from typing import List, Dict, Optional, Generator, AsyncGenerator
from loguru import logger
import httpx
from .base_client import BaseLLMClient, LLMResponse, LLMConfig, LLMProvider
class QwenClient(BaseLLMClient):
    """
    Qwen API client (OpenAI-compatible format).

    Can be used through proxy services such as new-api. Models:
    - qwen-turbo
    - qwen-plus
    - qwen-max
    - qwen3.5-flash (recommended: fast responses)
    - qwen3.5-plus
    - qwen-long
    - qwen2.5 series
    """

    SUPPORTED_MODELS = [
        "qwen-turbo",
        "qwen-plus",
        "qwen-max",
        "qwen-max-longcontext",
        "qwen-long",
        "qwen3.5-flash",
        "qwen3.5-plus",
        "qwen3-plus",
        "qwen2.5-72b-instruct",
        "qwen2.5-32b-instruct",
        "qwen2.5-14b-instruct",
        "qwen2.5-7b-instruct"
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and open the HTTP client.

        Raises:
            ValueError: If ``config.provider`` is neither QWEN nor QWEN_VL.
        """
        if config.provider not in [LLMProvider.QWEN, LLMProvider.QWEN_VL]:
            raise ValueError(f"配置provider应为Qwen实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the HTTP client for the OpenAI-compatible endpoint."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            },
            timeout=self.config.timeout
        )
        logger.info(f"Qwen客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int],
        temperature: Optional[float],
        stream: bool,
        **kwargs
    ) -> Dict:
        """Build the OpenAI-compatible request body shared by chat and stream_chat."""
        return {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": max_tokens or self.config.max_tokens,
            # `is None` instead of `or`: 0.0 is a valid temperature.
            "temperature": self.config.temperature if temperature is None else temperature,
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": stream
        }

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming chat completion.

        Failures are never raised: HTTP errors, transport errors and malformed
        responses are logged and returned as an ``LLMResponse`` whose
        ``error`` field is set.
        """
        start_time = time.time()
        try:
            payload = self._build_payload(messages, max_tokens, temperature, False, **kwargs)
            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()
            data = response.json()
            latency_ms = int((time.time() - start_time) * 1000)

            choices = data.get("choices")
            if not choices:
                # Report a missing/empty choices list explicitly instead of
                # relying on an IndexError being caught further down.
                logger.error("Qwen API响应缺少choices")
                return LLMResponse(
                    content="",
                    model=data.get("model", self.config.model),
                    latency_ms=latency_ms,
                    error="API响应缺少choices"
                )

            first_choice = choices[0]
            message = first_choice.get("message") or {}
            return LLMResponse(
                # content may be null in the JSON; normalise to an empty string
                content=message.get("content") or "",
                model=data.get("model", self.config.model),
                usage=data.get("usage") or {},
                finish_reason=first_choice.get("finish_reason") or "stop",
                latency_ms=latency_ms
            )
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )
        except Exception as e:
            logger.error(f"Qwen调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> Generator[str, None, None]:
        """
        Streaming chat completion (server-sent events).

        Yields:
            str: One text fragment at a time. On failure a single
            "[ERROR: ...]" fragment is yielded instead of raising.

        Example:
            for chunk in client.stream_chat(messages):
                print(chunk, end="", flush=True)
        """
        try:
            payload = self._build_payload(messages, max_tokens, temperature, True, **kwargs)
            with self._client.stream("POST", "/chat/completions", json=payload) as response:
                # Without this check an HTTP error body is read as a stream
                # with no "data:" lines and the caller silently gets nothing.
                response.raise_for_status()
                for raw_line in response.iter_lines():
                    line = raw_line.strip()
                    # SSE data lines are "data:" followed by an optional space.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    choices = data.get("choices") or []
                    if not choices:
                        continue  # skip events without choices
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content") or ""
                    if content:
                        yield content
        except httpx.HTTPStatusError as e:
            logger.error(f"Qwen流式API错误: {e.response.status_code}")
            yield f"[ERROR: API返回错误 {e.response.status_code}]"
        except Exception as e:
            logger.error(f"Qwen流式调用失败: {e}")
            yield f"[ERROR: {str(e)}]"

    async def async_stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """
        Asynchronous streaming chat completion (for FastAPI SSE responses).

        Wraps the synchronous ``stream_chat`` generator and advances it in the
        event loop's default thread pool, so waiting for the next fragment
        does not block other tasks.

        Yields:
            str: One text fragment at a time.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        chunks = self.stream_chat(messages, max_tokens, temperature, **kwargs)
        exhausted = object()  # sentinel so StopIteration never crosses the thread boundary
        try:
            while True:
                chunk = await loop.run_in_executor(None, next, chunks, exhausted)
                if chunk is exhausted:
                    break
                yield chunk
        finally:
            try:
                # Release the HTTP stream when the consumer stops early.
                chunks.close()
            except ValueError:
                # The generator is still running in a worker thread (the
                # consumer was cancelled mid-read); it cannot be closed from
                # here and is left to finish on its own.
                pass

    def get_available_models(self) -> List[str]:
        """Return a copy of the supported model names."""
        return list(self.SUPPORTED_MODELS)

    def close(self):
        """Close the underlying HTTP client if one was created."""
        if self._client:
            self._client.close()
class QwenVLClient(BaseLLMClient):
    """
    Qwen VL multimodal client (OpenAI-compatible format).

    Supported models:
    - qwen-vl-plus
    - qwen-vl-max
    - qwen3-vl-plus
    - qwen2-vl-7b-instruct
    - qwen2-vl-72b-instruct
    """

    SUPPORTED_MODELS = [
        "qwen-vl-plus",
        "qwen-vl-max",
        "qwen3-vl-plus",
        "qwen2-vl-7b-instruct",
        "qwen2-vl-72b-instruct"
    ]

    def __init__(self, config: LLMConfig):
        """Validate the provider and open the HTTP client.

        Raises:
            ValueError: If ``config.provider`` is not ``LLMProvider.QWEN_VL``.
        """
        if config.provider != LLMProvider.QWEN_VL:
            raise ValueError(f"配置provider应为QWEN_VL实际为{config.provider}")
        super().__init__(config)
        self._init_client()

    def _init_client(self):
        """Create the HTTP client for the OpenAI-compatible endpoint."""
        self._client = httpx.Client(
            base_url=self.config.base_url,
            headers={
                "Authorization": f"Bearer {self.config.api_key}",
                "Content-Type": "application/json"
            },
            timeout=self.config.timeout
        )
        logger.info(f"QwenVL客户端初始化完成: {self.config.base_url} - {self.config.model}")

    def _build_payload(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int],
        temperature: Optional[float],
        stream: bool,
        **kwargs
    ) -> Dict:
        """Build the OpenAI-compatible request body shared by chat and stream_chat."""
        return {
            "model": self.config.model,
            "messages": messages,
            "max_tokens": max_tokens or self.config.max_tokens,
            # `is None` instead of `or`: 0.0 is a valid temperature.
            "temperature": self.config.temperature if temperature is None else temperature,
            "top_p": kwargs.get("top_p", self.config.top_p),
            "stream": stream
        }

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> LLMResponse:
        """Run a non-streaming multimodal chat completion.

        Image input is supported; a message then carries a list as content:
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
                    {"type": "text", "text": "Describe this image"}
                ]
            }

        Failures are never raised: they are logged and returned as an
        ``LLMResponse`` whose ``error`` field is set.
        """
        start_time = time.time()
        try:
            payload = self._build_payload(messages, max_tokens, temperature, False, **kwargs)
            response = self._client.post("/chat/completions", json=payload)
            response.raise_for_status()
            data = response.json()
            latency_ms = int((time.time() - start_time) * 1000)

            choices = data.get("choices")
            if not choices:
                # Report a missing/empty choices list explicitly instead of
                # relying on an IndexError being caught further down.
                logger.error("QwenVL API响应缺少choices")
                return LLMResponse(
                    content="",
                    model=data.get("model", self.config.model),
                    latency_ms=latency_ms,
                    error="API响应缺少choices"
                )

            first_choice = choices[0]
            message = first_choice.get("message") or {}
            return LLMResponse(
                # content may be null in the JSON; normalise to an empty string
                content=message.get("content") or "",
                model=data.get("model", self.config.model),
                usage=data.get("usage") or {},
                finish_reason=first_choice.get("finish_reason") or "stop",
                latency_ms=latency_ms
            )
        except httpx.HTTPStatusError as e:
            logger.error(f"QwenVL API错误: {e.response.status_code} - {e.response.text}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=f"API错误: {e.response.status_code} - {e.response.text[:200]}"
            )
        except Exception as e:
            logger.error(f"QwenVL调用失败: {e}")
            return LLMResponse(
                content="",
                model=self.config.model,
                error=str(e)
            )

    def stream_chat(
        self,
        messages: List[Dict[str, str]],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        **kwargs
    ) -> Generator[str, None, None]:
        """Streaming multimodal chat completion (server-sent events).

        Yields:
            str: One text fragment at a time. On failure a single
            "[ERROR: ...]" fragment is yielded instead of raising.
        """
        try:
            payload = self._build_payload(messages, max_tokens, temperature, True, **kwargs)
            with self._client.stream("POST", "/chat/completions", json=payload) as response:
                # Without this check an HTTP error body is read as a stream
                # with no "data:" lines and the caller silently gets nothing.
                response.raise_for_status()
                for raw_line in response.iter_lines():
                    line = raw_line.strip()
                    # SSE data lines are "data:" followed by an optional space.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    choices = data.get("choices") or []
                    if not choices:
                        continue  # skip events without choices
                    delta = choices[0].get("delta") or {}
                    content = delta.get("content") or ""
                    if content:
                        yield content
        except httpx.HTTPStatusError as e:
            logger.error(f"QwenVL流式API错误: {e.response.status_code}")
            yield f"[ERROR: API返回错误 {e.response.status_code}]"
        except Exception as e:
            logger.error(f"QwenVL流式调用失败: {e}")
            yield f"[ERROR: {str(e)}]"

    def get_available_models(self) -> List[str]:
        """Return a copy of the supported model names."""
        return list(self.SUPPORTED_MODELS)

    def close(self):
        """Close the underlying HTTP client if one was created."""
        if self._client:
            self._client.close()
def create_qwen_client(
    api_key: str,
    model: str = "qwen3.5-flash",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> QwenClient:
    """Convenience helper: build a QwenClient from plain arguments.

    Extra keyword arguments are forwarded to ``LLMConfig``.
    """
    return QwenClient(
        LLMConfig(
            provider=LLMProvider.QWEN,
            model=model,
            api_key=api_key,
            base_url=base_url,
            **kwargs
        )
    )
def create_qwen_vl_client(
    api_key: str,
    model: str = "qwen3-vl-plus",
    base_url: str = "http://6.86.80.4:30080/v1",
    **kwargs
) -> QwenVLClient:
    """Convenience helper: build a QwenVLClient from plain arguments.

    Extra keyword arguments are forwarded to ``LLMConfig``.
    """
    return QwenVLClient(
        LLMConfig(
            provider=LLMProvider.QWEN_VL,
            model=model,
            api_key=api_key,
            base_url=base_url,
            **kwargs
        )
    )

View File

@@ -0,0 +1,425 @@
"""
Mock数据服务 - 提供预设假数据供前后端对接测试
"""
from datetime import datetime
from typing import Dict, List, Any
import uuid
# Preset list of regulation documents served by the mock API.
# Each entry carries an id, a file name, a chunk count, an index status and a
# naive (timezone-less) creation datetime.
MOCK_DOCUMENTS: List[Dict[str, Any]] = [
    {
        "id": "doc-001",
        "name": "道路交通安全法.pdf",
        "chunks": 156,
        "status": "indexed",
        "created_at": datetime(2026, 5, 10, 10, 0, 0),
    },
    {
        "id": "doc-002",
        "name": "机动车登记规定.docx",
        "chunks": 89,
        "status": "indexed",
        "created_at": datetime(2026, 5, 10, 11, 0, 0),
    },
    {
        "id": "doc-003",
        "name": "电动自行车规范.pdf",
        "chunks": 42,
        "status": "indexed",
        "created_at": datetime(2026, 5, 10, 12, 0, 0),
    },
    {
        "id": "doc-004",
        "name": "GB 38031-2020 电动汽车安全要求.pdf",
        "chunks": 128,
        "status": "indexed",
        "created_at": datetime(2026, 5, 10, 13, 0, 0),
    },
    {
        "id": "doc-005",
        "name": "C-NCAP管理规则(2021版).pdf",
        "chunks": 95,
        "status": "indexed",
        "created_at": datetime(2026, 5, 10, 14, 0, 0),
    },
]
# Preset quick questions; each entry has an id, the question text and a category.
MOCK_QUICK_QUESTIONS: List[Dict[str, str]] = [
    {"id": "q1", "question": "电动自行车需要上牌照吗?", "category": "车辆登记"},
    {"id": "q2", "question": "新能源汽车有哪些补贴政策?", "category": "新能源"},
    {"id": "q3", "question": "车辆年检的规定是什么?", "category": "年检"},
    {"id": "q4", "question": "驾驶证过期了怎么处理?", "category": "驾驶证"},
]
# 预设检索结果
MOCK_RETRIEVAL_RESULTS: List[Dict[str, Any]] = [
{
"id": "chunk-001",
"score": 0.95,
"preview": "根据《道路交通安全法》第十八条规定,电动自行车经公安机关交通管理部门登记后,方可上道路行驶...",
"doc_name": "道路交通安全法",
"clause": "第十八条",
"content": "根据《道路交通安全法》第十八条规定,电动自行车经公安机关交通管理部门登记后,方可上道路行驶。电动自行车应当符合国家标准,最高设计车速不超过二十五公里每小时,整车质量不超过五十五千克。",
},
{
"id": "chunk-002",
"score": 0.88,
"preview": "电动自行车需符合GB17761-2018国家标准包括最高车速、整车质量、脚踏骑行能力等要求...",
"doc_name": "电动自行车规范",
"clause": "第4条",
"content": "电动自行车需符合GB17761-2018国家标准。主要技术要求包括最高设计车速不超过25km/h整车质量不超过55kg具有脚踏骑行能力蓄电池标称电压不超过48V电动机额定连续输出功率不超过400W。",
},
{
"id": "chunk-003",
"score": 0.82,
"preview": "机动车登记规定:初次申领机动车号牌、行驶证的,机动车所有人应当向住所地的车辆管理所申请注册登记...",
"doc_name": "机动车登记规定",
"clause": "第5条",
"content": "机动车登记规定:初次申领机动车号牌、行驶证的,机动车所有人应当向住所地的车辆管理所申请注册登记。申请注册登记的,应当提交机动车所有人的身份证明、购车发票等机动车来历证明、机动车整车出厂合格证明或者进口机动车进口凭证。",
},
{
"id": "chunk-004",
"score": 0.75,
"preview": "驾驶电动自行车上道路行驶,应当佩戴安全头盔,遵守道路交通安全法律法规...",
"doc_name": "道路交通安全法",
"clause": "第76条",
"content": "驾驶电动自行车上道路行驶,应当佩戴安全头盔,遵守道路交通安全法律法规。电动自行车不得逆向行驶,不得在机动车道内行驶,最高车速不得超过规定的限速。",
},
{
"id": "chunk-005",
"score": 0.68,
"preview": "电动汽车动力电池安全要求电池系统发生热失控后应在5分钟内不起火不爆炸...",
"doc_name": "GB 38031-2020",
"clause": "第7条",
"content": "电动汽车动力电池安全要求GB 38031-2020电池系统发生热失控后应在5分钟内不起火不爆炸为乘员预留逃生时间。电池包需通过针刺、过充、短路等安全测试。",
},
]
# 预设RAG问答答案模板按关键词匹配
MOCK_RAG_ANSWERS: Dict[str, Dict[str, Any]] = {
"电动自行车": {
"text": "根据《道路交通安全法》及相关规范,电动自行车上路需满足以下条件:\n\n1. 符合国家标准 GB17761-2018\n2. 经公安机关交通管理部门登记\n3. 最高设计车速不超过 25km/h\n4. 整车质量不超过 55kg\n5. 具有脚踏骑行能力\n6. 蓄电池标称电压不超过 48V\n\n行驶时还需佩戴安全头盔,不得逆向行驶或在机动车道内行驶。",
"retrieval_ids": ["chunk-001", "chunk-002", "chunk-004"],
},
"驾驶证": {
"text": "驾驶证申请流程如下:\n\n1. 到驾校报名并参加培训\n2. 通过科目一(理论考试)\n3. 通过科目二(场地驾驶技能考试)\n4. 通过科目三(道路驾驶技能考试)\n5. 通过科目四(安全文明驾驶常识考试)\n6. 领取驾驶证\n\n初次申领需到住所地车辆管理所申请注册登记。",
"retrieval_ids": ["chunk-003"],
},
"超速": {
"text": "超速处罚标准(根据《道路交通安全法》):\n\n- 超速10%以下:警告\n- 超速10%-20%罚款50-200元\n- 超速20%-50%罚款200-500元记3-6分\n- 超速50%以上罚款500-2000元记12分可吊销驾驶证\n\n机动车驾驶人违反道路交通安全法律、法规将处警告或二十元以上二百元以下罚款。",
"retrieval_ids": ["chunk-001"],
},
"年检": {
"text": "车辆年检规定:\n\n- 小型私家车6年内免检每2年申领标志6-10年每2年检验10年以上每年检验\n- 车辆需携带行驶证、交强险保单\n- 检验项目:灯光、制动、排放等\n\n机动车所有人的住所迁出车辆管理所管辖区域的,需在登记证书上签注变更事项。",
"retrieval_ids": ["chunk-003"],
},
"电池": {
"text": "电动汽车电池安全标准GB 38031-2020\n\n1. 热失控要求电池系统发生热失控后应在5分钟内不起火不爆炸为乘员预留逃生时间\n2. 电池包需通过针刺、过充、短路等安全测试\n3. 充电系统应具备过充保护功能当电池SOC达到100%时应自动停止充电\n4. 充电接口应符合GB/T 18487.1标准要求\n\n以上要求确保电动汽车的整车安全性。",
"retrieval_ids": ["chunk-005"],
},
"碰撞": {
"text": "正面碰撞测试要求C-NCAP管理规则\n\n1. 正面100%重叠刚性壁障碰撞试验\n2. 碰撞速度50km/h\n3. 试验后要求:\n - 车门应能打开\n - 燃油系统无泄漏\n - 座椅及安全带功能正常\n\n此测试用于评估车辆在正面碰撞事故中对乘员的保护能力。",
"retrieval_ids": [],
},
"AEB": {
"text": "AEB自动紧急制动系统测试标准\n\n1. 系统应在检测到前方障碍物时主动减速或停车\n2. 测试场景分为三种:\n - 目标车静止\n - 目标车移动\n - 目标车制动\n3. AEB功能是C-NCAP评分的重要加分项\n\n该系统对提升车辆主动安全性能具有重要意义。",
"retrieval_ids": [],
},
"高速公路": {
"text": "高速公路安全距离规定:\n\n1. 车速超过100km/h时与同车道前车保持100米以上距离\n2. 车速低于100km/h时距离可适当缩短\n3. 执行紧急任务的警车、消防车、救护车、工程救险车不受行驶速度限制\n\n保持安全距离是预防追尾事故的关键措施。",
"retrieval_ids": [],
},
}
# 预设合规分析结果
MOCK_COMPLIANCE_RESULT: Dict[str, Any] = {
"task_id": "task-001",
"dashboard": {
"score": 78,
"high_risk_count": 2,
"medium_risk_count": 1,
"low_risk_count": 0,
"need_fix_segments": 3,
"status": "warning",
"status_label": "需优化",
},
"segments": [
{
"id": 1,
"index": 1,
"intent": "车身结构设计",
"start_pos": 45,
"end_pos": 230,
"content": "车身采用高强度钢铝混合结构A柱和B柱使用热成型钢板厚度2.5mm。车顶结构设计满足GB 26112-2010抗压强度要求正面碰撞能量吸收区域采用渐进式变形设计确保碰撞时能量有效分散。",
"risk_level": "high",
"regulations": [
{
"id": 1,
"name": "GB 26112-2010",
"clause": "第4.2条",
"score": 0.95,
"match_keyword": "车顶抗压强度",
"category": "high",
"full_content": "车顶结构应能承受相当于车辆整备质量1.5倍的载荷,载荷分布应均匀,试验后车顶变形量不超过规定值。",
},
{
"id": 2,
"name": "C-NCAP管理规则",
"clause": "第3.1条",
"score": 0.88,
"match_keyword": "正面碰撞",
"category": "high",
"full_content": "正面碰撞试验速度为50km/h碰撞后车门应能打开燃油系统无泄漏座椅及安全带功能正常。",
},
{
"id": 3,
"name": "GB 11551-2014",
"clause": "第5条",
"score": 0.72,
"match_keyword": "碰撞能量吸收",
"category": "medium",
"full_content": "车辆正面碰撞时应有效保护乘员,碰撞能量应通过车身结构合理分散。",
},
{
"id": 4,
"name": "机动车安全技术条件",
"clause": "第12条",
"score": 0.58,
"match_keyword": "A柱强度",
"category": "medium",
"full_content": "A柱应具备足够的抗变形能力材料强度应符合相关标准要求。",
},
],
},
{
"id": 2,
"index": 2,
"intent": "动力系统配置",
"start_pos": 298,
"end_pos": 425,
"content": "搭载永磁同步电机最大功率150kW峰值扭矩310Nm。电池组采用三元锂离子电池容量75kWh能量密度180Wh/kg。充电接口支持快充30分钟充至80%和慢充8小时充满符合GB/T 18487.1-2015标准。",
"risk_level": "medium",
"regulations": [
{
"id": 5,
"name": "GB/T 18487.1-2015",
"clause": "第6条",
"score": 0.94,
"match_keyword": "充电接口标准",
"category": "high",
"full_content": "电动汽车传导充电接口应符合GB/T 18487.1标准要求,充电系统应具备过充保护功能。",
},
{
"id": 6,
"name": "GB/T 31484-2015",
"clause": "第4条",
"score": 0.85,
"match_keyword": "电池能量密度",
"category": "high",
"full_content": "动力电池能量密度不低于120Wh/kg电池系统需通过热失控测试。",
},
{
"id": 7,
"name": "新能源汽车生产企业准入",
"clause": "第8条",
"score": 0.65,
"match_keyword": "电机功率",
"category": "medium",
"full_content": "驱动电机应符合相关技术标准,功率参数应在规定范围内。",
},
{
"id": 8,
"name": "电动汽车安全要求",
"clause": "第7条",
"score": 0.45,
"match_keyword": "充电时间",
"category": "low",
"full_content": "充电系统应具备过充保护功能当电池SOC达到100%时应自动停止充电。",
},
],
},
{
"id": 3,
"index": 3,
"intent": "安全配置设计",
"start_pos": 570,
"end_pos": 725,
"content": "配备6个安全气囊前排双气囊、侧气囊、侧气帘采用预紧式安全带。ABS系统采用博世第9代ESP具备碰撞预警功能FCW和自动紧急制动AEB。方向盘集成驾驶员疲劳监测摄像头。",
"risk_level": "low",
"regulations": [
{
"id": 9,
"name": "GB 27887-2011",
"clause": "第5条",
"score": 0.92,
"match_keyword": "安全气囊",
"category": "high",
"full_content": "乘用车应配备驾驶员和乘客安全气囊,气囊系统应符合相关技术标准。",
},
{
"id": 10,
"name": "GB/T 26991-2011",
"clause": "第3条",
"score": 0.78,
"match_keyword": "ABS系统",
"category": "medium",
"full_content": "车辆应配备防抱死制动系统,系统性能应符合相关标准要求。",
},
{
"id": 11,
"name": "C-NCAP管理规则",
"clause": "第4.2条",
"score": 0.71,
"match_keyword": "AEB自动制动",
"category": "medium",
"full_content": "主动安全配置评分包含AEB功能AEB系统应能有效检测障碍物并主动减速。",
},
{
"id": 12,
"name": "机动车运行安全技术条件",
"clause": "第15条",
"score": 0.38,
"match_keyword": "疲劳监测",
"category": "low",
"full_content": "建议配备驾驶员状态监测系统,及时发现驾驶员疲劳或分心状态。",
},
],
},
],
"priority_actions": [
{
"regulation": "GB 26112-2010 第4.2条",
"issue": "缺少车顶抗压强度测试数据",
"suggestion": "补充车顶抗压强度具体测试数据确保满足1.5倍整备质量载荷要求",
"severity": "high",
},
{
"regulation": "GB/T 31484-2015 第4条",
"issue": "缺少电池热失控测试报告",
"suggestion": "补充电池热失控测试报告验证5分钟内不起火不爆炸",
"severity": "high",
},
{
"regulation": "C-NCAP管理规则 第3.1条",
"issue": "缺少碰撞后车门开启性能数据",
"suggestion": "提供碰撞后车门开启性能测试数据",
"severity": "medium",
},
],
}
# 预设合规对话响应模板
MOCK_COMPLIANCE_CHAT_RESPONSES: Dict[str, Dict[str, str]] = {
"车身结构设计": {
"compliance": "根据当前分析,车身结构设计部分存在以下合规问题:\n\n1. GB 26112-2010要求车顶承受1.5倍整备质量载荷,目前设计声明满足要求但缺少测试数据\n2. C-NCAP正面碰撞后车门应能打开需提供碰撞测试报告\n\n建议补充相关测试数据以提升合规评分。",
"interpretation": "GB 26112-2010 第4.2条具体要求解读:\n\n车顶抗压强度测试是车辆被动安全的重要指标。该标准要求车顶结构能够承受相当于车辆整备质量1.5倍的均匀分布载荷,试验后车顶变形量不得超过规定限值。\n\n热成型钢板22MnB5材料抗拉强度约1500-1650 MPa理论上能满足要求但需通过实际测试验证。",
"suggestion": "针对车身结构设计的修改建议:\n\n1. 补充车顶抗压强度测试报告\n2. 提供A柱材料认证证书\n3. 完善正面碰撞能量吸收设计说明\n4. 添加碰撞后车门开启性能数据\n\n这些补充材料可有效提升合规评分。",
},
"动力系统配置": {
"compliance": "动力系统配置整体合规性良好,主要检查点:\n\n1. 电池能量密度180Wh/kg超过最低要求120Wh/kg ✓\n2. 充电接口符合GB/T 18487.1标准 ✓\n3. 快充30分钟充至80%符合行业标准 ✓\n\n需补充电池热失控测试报告。",
"interpretation": "GB/T 31484-2015对动力电池的要求解读\n\n1. 能量密度不低于120Wh/kg您的设计180Wh/kg满足要求\n2. 循环寿命不少于1000次循环后容量保持率≥80%\n3. 安全测试:需通过针刺、过充、短路等测试\n\n建议补充循环寿命测试数据。",
"suggestion": "动力系统配置改进建议:\n\n1. 补充电池热失控测试报告\n2. 提供循环寿命测试数据\n3. 添加充电系统过充保护功能说明\n4. 完善电池管理系统BMS技术文档",
},
"安全配置设计": {
"compliance": "安全配置设计合规性评估:\n\n1. 安全气囊配置满足GB 27887-2011要求 ✓\n2. ABS/ESP系统符合标准 ✓\n3. AEB功能是C-NCAP加分项 ✓\n\n驾驶员疲劳监测是建议配置,不强制要求。",
"interpretation": "C-NCAP主动安全评分规则解读\n\nAEB自动紧急制动系统是C-NCAP评分的重要加分项最高可获得额外加分。测试场景包括\n- 目标车静止场景\n- 目标车移动场景\n- 目标车制动场景\n\n建议完善AEB系统测试数据以获取更高评分。",
"suggestion": "安全配置优化建议:\n\n1. 提供AEB系统测试数据\n2. 补充FCW预警功能测试报告\n3. 添加安全气囊展开时间数据\n4. 完善驾驶员疲劳监测系统说明(如有)",
},
}
# Preset system statistics.
# "docs" and "chunks" agree with MOCK_DOCUMENTS: 5 entries whose chunk counts
# (156 + 89 + 42 + 128 + 95) sum to 510.
MOCK_SYSTEM_STATS: Dict[str, int] = {
    "docs": 5,
    "chunks": 510,
    "vectors": 510,
    "segments": 0,
}
# Preset system configuration, grouped into LLM, embedding, Milvus connection
# and retrieval sections.
MOCK_SYSTEM_CONFIG: Dict[str, Any] = {
    "llm": {
        "model": "qwen-max",
    },
    "embedding": {
        "model": "text-embedding-v3",
        "dimension": 1536,
    },
    "milvus": {
        "host": "localhost",
        "port": 19530,
    },
    "retrieval": {
        "vector_top_k": 10,
        "final_top_k": 5,
    },
}
def get_mock_documents() -> List[Dict[str, Any]]:
    """Return the preset regulation document list (the shared module-level object)."""
    documents = MOCK_DOCUMENTS
    return documents
def get_mock_quick_questions() -> List[Dict[str, str]]:
    """Return the preset quick-question list (the shared module-level object)."""
    questions = MOCK_QUICK_QUESTIONS
    return questions
def get_mock_retrieval(query: str, top_k: int = 5) -> List[Dict[str, Any]]:
    """Return preset retrieval hits for the keywords contained in ``query``.

    Every keyword of ``MOCK_RAG_ANSWERS`` that occurs in ``query`` contributes
    the chunks named by its ``retrieval_ids``. A chunk referenced by several
    matching keywords is returned once, at its first position.

    Args:
        query: User query text.
        top_k: Maximum number of results; values below 0 are treated as 0.

    Returns:
        Matched chunks projected to id/score/preview/doc_name/clause. When no
        keyword yields a chunk, the first ``top_k`` entries of
        ``MOCK_RETRIEVAL_RESULTS`` are returned unprojected.
    """
    limit = max(top_k, 0)

    # Index chunks by id once; the first occurrence wins, as a linear scan would.
    chunks_by_id: Dict[str, Dict[str, Any]] = {}
    for item in MOCK_RETRIEVAL_RESULTS:
        chunks_by_id.setdefault(item["id"], item)

    results = []
    seen_ids = set()
    for keyword, data in MOCK_RAG_ANSWERS.items():
        if keyword not in query:
            continue
        for retrieval_id in data.get("retrieval_ids", []):
            item = chunks_by_id.get(retrieval_id)
            if item is None or retrieval_id in seen_ids:
                continue
            seen_ids.add(retrieval_id)
            results.append({
                "id": item["id"],
                "score": item["score"],
                "preview": item["preview"],
                "doc_name": item["doc_name"],
                "clause": item["clause"],
            })
    if not results:
        results = MOCK_RETRIEVAL_RESULTS[:limit]
    return results[:limit]
def get_mock_rag_answer(query: str) -> str:
    """Return the preset answer of the first ``MOCK_RAG_ANSWERS`` keyword found in ``query``.

    Keywords are tried in the dict's insertion order. A fixed fallback message
    is returned when none of them occurs in the query.
    """
    for keyword in MOCK_RAG_ANSWERS:
        if keyword in query:
            return MOCK_RAG_ANSWERS[keyword]["text"]
    fallback = "抱歉,暂未找到与您问题直接相关的法规内容。请尝试更具体的问题,或联系交通管理部门获取详细信息。\n\n您可以尝试询问电动自行车、驾驶证、超速处罚、年检、电池安全、碰撞测试、AEB系统、高速公路规则等话题。"
    return fallback
def get_mock_compliance_result(task_id: str) -> Dict[str, Any]:
    """Return an independent copy of the preset compliance result for ``task_id``.

    A deep copy is used because the template contains nested dicts and lists
    (``dashboard``, ``segments``, ``priority_actions``). A shallow copy would
    share them, so a caller editing the result would alter the template.

    Args:
        task_id: Identifier written into the returned result's ``task_id`` field.

    Returns:
        A deep copy of ``MOCK_COMPLIANCE_RESULT`` with ``task_id`` replaced.
    """
    import copy  # local import: keeps this block self-contained

    result = copy.deepcopy(MOCK_COMPLIANCE_RESULT)
    result["task_id"] = task_id
    return result
def get_mock_compliance_chat_response(intent: str, query: str) -> str:
    """Return the preset compliance-chat reply for a segment intent and a user query.

    The query is tested against three trigger-word groups in order (compliance,
    interpretation, suggestion). The first group with a hit selects which
    template of the intent is returned, with a generic fallback when the
    intent has no such template. If no group matches, a generic reply naming
    the intent is returned.
    """
    responses = MOCK_COMPLIANCE_CHAT_RESPONSES.get(intent, {})
    # (trigger words, template key, fallback text) — order is significant.
    routes = (
        (("合规", "符合"), "compliance", "根据相关法规分析,该段落的合规性需进一步评估。"),
        (("解读", "什么", "如何"), "interpretation", "法规要求详细解读如下..."),
        (("修改", "建议", "完善"), "suggestion", "建议进行以下修改以提升合规性..."),
    )
    for triggers, key, fallback in routes:
        if any(word in query for word in triggers):
            return responses.get(key, fallback)
    return f"关于您的问题,{intent}部分涉及多条相关法规。您可以进一步询问合规性评估或修改建议。"
def generate_task_id() -> str:
    """Create a task id: ``task-`` followed by the first 8 hex digits of a random UUID4."""
    suffix = uuid.uuid4().hex[:8]
    return "task-" + suffix
def generate_doc_id() -> str:
    """Create a document id: ``doc-`` followed by the first 8 hex digits of a random UUID4."""
    suffix = uuid.uuid4().hex[:8]
    return "doc-" + suffix

View File

@@ -0,0 +1,7 @@
# src/services/parser/__init__.py
"""文档解析服务"""
from .pdf_parser import PDFParser
from .docx_parser import DocxParser
__all__ = ["PDFParser", "DocxParser"]

View File

@@ -0,0 +1,287 @@
# src/services/parser/docx_parser.py
"""Word文档解析 - 使用python-docx"""
from docx import Document
from docx.enum.text import WD_ALIGN_PARAGRAPH
from typing import List, Dict, Optional
from dataclasses import dataclass, field
from loguru import logger
import re
@dataclass
class DocxParagraph:
    """One non-empty paragraph extracted from a Word document."""
    text: str
    level: int = 0  # heading level; 0 means body text
    is_list: bool = False  # True when the text starts with a list marker
    list_number: Optional[str] = None  # the matched list marker, if any
@dataclass
class DocxTable:
    """One table extracted from a Word document."""
    rows: List[List[str]]  # cell texts, one inner list per table row
    markdown: str = ""  # Markdown rendering of `rows`
@dataclass
class DocxDocumentContent:
    """Complete parsed content of a Word document."""
    file_path: str
    paragraphs: List[DocxParagraph]
    tables: List[DocxTable]
    metadata: Dict[str, str] = field(default_factory=dict)  # core document properties
    markdown_text: str = ""  # Markdown rendering of the whole document
class DocxParser:
    """Word document parser built on python-docx.

    Extracts core metadata, non-empty paragraphs (with heading level and list
    detection) and tables, and renders them as Markdown.
    """

    def __init__(self):
        self.document = None

    def parse(self, file_path: str) -> DocxDocumentContent:
        """Parse a Word document.

        Args:
            file_path: Path of the Word document.

        Returns:
            DocxDocumentContent: The parsed document content.

        Raises:
            Exception: Any error raised while opening or parsing is logged and
                re-raised unchanged.
        """
        logger.info(f"开始解析Word文档: {file_path}")
        try:
            self.document = Document(file_path)
            doc_content = DocxDocumentContent(
                file_path=file_path,
                paragraphs=[],
                tables=[]
            )
            doc_content.metadata = self._extract_metadata()
            doc_content.paragraphs = self._extract_paragraphs()
            doc_content.tables = self._extract_tables()
            doc_content.markdown_text = self._generate_markdown(doc_content)
            logger.success(f"Word文档解析完成{len(doc_content.paragraphs)}个段落")
            return doc_content
        except Exception as e:
            logger.error(f"Word文档解析失败: {e}")
            raise

    def _extract_metadata(self) -> Dict[str, str]:
        """Read the core properties; returns {} (with a warning) on failure."""
        metadata = {}
        try:
            core_props = self.document.core_properties
            metadata = {
                "title": core_props.title or "",
                "author": core_props.author or "",
                "subject": core_props.subject or "",
                "keywords": core_props.keywords or "",
                "created": str(core_props.created) if core_props.created else "",
                "modified": str(core_props.modified) if core_props.modified else "",
            }
        except Exception as e:
            logger.warning(f"提取元数据失败: {e}")
        return metadata

    def _extract_paragraphs(self) -> List[DocxParagraph]:
        """Collect every non-empty paragraph with its heading level and list info."""
        paragraphs = []
        for para in self.document.paragraphs:
            text = para.text.strip()
            if not text:
                continue
            is_list, list_number = self._detect_list_item(para)
            paragraphs.append(DocxParagraph(
                text=text,
                level=self._get_paragraph_level(para),
                is_list=is_list,
                list_number=list_number
            ))
        return paragraphs

    def _get_paragraph_level(self, para) -> int:
        """Determine the heading level of a paragraph.

        The style name is checked first ("Heading N" / "标题 N"). Otherwise the
        text is matched against the chapter / section / article patterns of
        regulation documents.

        Returns:
            int: Heading level; 0 means body text.
        """
        # A style may exist without a name; treat that as "no style".
        style_name = (para.style.name or "") if para.style else ""
        if "Heading" in style_name or "标题" in style_name:
            match = re.search(r'Heading\s*(\d)|标题\s*(\d)', style_name)
            if match:
                return int(match.group(1) or match.group(2))

        text = para.text.strip()
        # Chapter -> level 2, section -> level 3, article -> level 4.
        content_patterns = (
            (r'^第[一二三四五六七八九十百]+章\s', 2),
            (r'^第[一二三四五六七八九十百]+节\s', 3),
            (r'^第[一二三四五六七八九十百]+条\s', 4),
        )
        for pattern, level in content_patterns:
            if re.match(pattern, text):
                return level
        return 0  # body text

    def _detect_list_item(self, para) -> tuple[bool, Optional[str]]:
        """Detect a leading list marker.

        Returns:
            ``(True, marker)`` when the text starts with an Arabic or Chinese
            numeral marker followed by whitespace, else ``(False, None)``.
        """
        text = para.text.strip()
        marker_patterns = (
            r'^([\d]+[.、)\]])\s',                 # Arabic numerals: "1.", "2、", "3)", "4]"
            r'^([一二三四五六七八九十]+[、.)])\s',   # Chinese numerals: "一、", "二."
        )
        for pattern in marker_patterns:
            match = re.match(pattern, text)
            if match:
                return True, match.group(1)
        return False, None

    def _extract_tables(self) -> List[DocxTable]:
        """Extract every table as rows of stripped cell text plus Markdown."""
        tables = []
        for table in self.document.tables:
            rows = [[cell.text.strip() for cell in row.cells] for row in table.rows]
            tables.append(DocxTable(rows=rows, markdown=self._table_to_markdown(rows)))
        return tables

    @staticmethod
    def _escape_cell(text: str) -> str:
        """Make cell text safe for a single Markdown table cell."""
        # A raw "|" would start a new column and a line break would end the row.
        return text.replace("|", "\\|").replace("\r\n", " ").replace("\n", " ").replace("\r", " ")

    def _table_to_markdown(self, rows: List[List[str]]) -> str:
        """Render rows as a Markdown table; the first row is used as the header.

        Returns:
            The Markdown table, or "" when ``rows`` is empty.
        """
        if not rows:
            return ""
        header = rows[0]
        lines = [
            "| " + " | ".join(self._escape_cell(cell) for cell in header) + " |",
            "| " + " | ".join("---" for _ in header) + " |",
        ]
        for row in rows[1:]:
            lines.append("| " + " | ".join(self._escape_cell(cell) for cell in row) + " |")
        return "\n".join(lines)

    def _generate_markdown(self, doc_content: DocxDocumentContent) -> str:
        """Render the parsed document as Markdown text.

        Layout: title, metadata list, body paragraphs, then all tables.
        """
        lines = []

        # Title: metadata title, else the first level-1 heading among the
        # first five paragraphs, else the file path.
        title = doc_content.metadata.get("title", "")
        if title:
            lines.append(f"# {title}\n")
        else:
            heading = next(
                (para.text for para in doc_content.paragraphs[:5] if para.level == 1),
                None,
            )
            if heading is None:
                heading = doc_content.file_path
            lines.append(f"# {heading}\n")

        lines.append("\n## 文档信息\n")
        for key, value in doc_content.metadata.items():
            if value:
                lines.append(f"- **{key}**: {value}")

        lines.append("\n## 正文\n")
        for para in doc_content.paragraphs:
            if para.level > 0:
                prefix = "#" * para.level
                lines.append(f"\n{prefix} {para.text}\n")
            elif para.is_list:
                lines.append(f"- {para.text}")
            else:
                lines.append(para.text)

        if doc_content.tables:
            lines.append("\n## 表格\n")
            for i, table in enumerate(doc_content.tables):
                lines.append(f"\n### 表格 {i + 1}\n")
                lines.append(table.markdown + "\n")
        return "\n".join(lines)

    def parse_to_markdown(self, file_path: str) -> str:
        """Parse the document and return only its Markdown text."""
        return self.parse(file_path).markdown_text
def parse_docx(file_path: str) -> DocxDocumentContent:
    """Convenience wrapper: parse a Word document with a fresh ``DocxParser``."""
    return DocxParser().parse(file_path)
def parse_docx_to_markdown(file_path: str) -> str:
    """Convenience wrapper: parse a Word document and return its Markdown text."""
    return DocxParser().parse_to_markdown(file_path)

View File

@@ -0,0 +1,204 @@
# src/services/parser/mineru_parser.py
"""MinerU多模态PDF解析 - 版面感知解析"""
from typing import Optional, Dict
from dataclasses import dataclass, field
from loguru import logger
import os
@dataclass
class MinerUResult:
    """Outcome of one MinerU parse attempt."""
    file_path: str
    markdown_text: str  # parsed Markdown; "" when the parse failed
    metadata: Dict[str, str] = field(default_factory=dict)
    success: bool = True
    error_message: str = ""  # failure reason when success is False
class MinerUParser:
    """Multimodal PDF parser backed by MinerU (magic-pdf).

    MinerU is an open-source, layout-aware PDF parser that produces structured
    Markdown. The ``magic_pdf`` package is imported lazily, and ``available``
    records whether that import succeeded at construction time.

    GitHub: https://github.com/opendatalab/MinerU
    """

    def __init__(self):
        self.available = self._check_mineru_available()

    def _check_mineru_available(self) -> bool:
        """Return True when ``magic_pdf`` can be imported, else warn and return False."""
        try:
            from magic_pdf.pipe.UNIPipe import UNIPipe
        except ImportError:
            logger.warning("MinerU (magic-pdf) 未安装,请运行: pip install magic-pdf[full]")
            return False
        return True

    def parse(self, file_path: str, output_dir: Optional[str] = None) -> MinerUResult:
        """Parse a PDF with MinerU.

        Args:
            file_path: Path of the PDF file.
            output_dir: Output directory passed to the pipeline; defaults to
                the directory containing ``file_path``.

        Returns:
            MinerUResult: ``success=False`` with an error message when MinerU
            is unavailable or any exception occurs. This method never raises.
        """
        logger.info(f"尝试使用MinerU解析: {file_path}")
        if not self.available:
            return MinerUResult(
                file_path=file_path,
                markdown_text="",
                success=False,
                error_message="MinerU未安装"
            )
        try:
            from magic_pdf.pipe.UNIPipe import UNIPipe
            from magic_pdf.libs.MakeContentConfig import DropMode

            target_dir = os.path.dirname(file_path) if output_dir is None else output_dir
            # NOTE(review): confirm that the installed magic-pdf version accepts
            # (file_path, output_dir) as UNIPipe constructor arguments.
            pipe = UNIPipe(file_path, target_dir)
            markdown_content = pipe.pipe_mk()
            logger.success("MinerU解析成功")
            extracted_meta = self._extract_metadata(pipe)
            return MinerUResult(
                file_path=file_path,
                markdown_text=markdown_content,
                metadata=extracted_meta,
                success=True
            )
        except Exception as e:
            logger.error(f"MinerU解析失败: {e}")
            return MinerUResult(
                file_path=file_path,
                markdown_text="",
                success=False,
                error_message=str(e)
            )

    def _extract_metadata(self, pipe) -> Dict[str, str]:
        """Read page_count / language / is_scanned from ``pipe.pdf_mid_data``.

        Returns {} when the attribute is missing or empty, or when reading it
        fails (a warning is logged in that case).
        """
        metadata = {}
        try:
            mid_data = getattr(pipe, 'pdf_mid_data', None)
            if mid_data:
                wanted = ("page_count", "language", "is_scanned")
                metadata = {name: str(mid_data.get(name, "")) for name in wanted}
        except Exception as e:
            logger.warning(f"提取MinerU元数据失败: {e}")
        return metadata

    def parse_to_markdown(self, file_path: str) -> str:
        """Parse and return the Markdown text, or "" when parsing failed."""
        outcome = self.parse(file_path)
        if outcome.success:
            return outcome.markdown_text
        return ""
class ParserOrchestrator:
    """Parser orchestration: pick a parser by priority and file type.

    PDF strategy:
    1. Try MinerU first when it is preferred and available.
    2. Fall back to the basic PyMuPDF parser when MinerU is skipped or fails.
    """

    def __init__(self):
        from .pdf_parser import PDFParser
        self.mineru_parser = MinerUParser()
        self.pdf_parser = PDFParser()
        self.mineru_available = self.mineru_parser.available

    def parse_pdf(self, file_path: str, prefer_mineru: bool = True) -> str:
        """Parse a PDF, choosing the parser by priority.

        Args:
            file_path: Path of the PDF file.
            prefer_mineru: Whether MinerU should be tried first.

        Returns:
            str: Markdown text.
        """
        use_mineru = prefer_mineru and self.mineru_available
        if use_mineru:
            result = self.mineru_parser.parse(file_path)
            if not result.success:
                logger.warning(f"MinerU解析失败回退到PyMuPDF: {result.error_message}")
            else:
                logger.info("使用MinerU解析成功")
                return result.markdown_text
        # Fallback: basic PyMuPDF parsing.
        logger.info("使用PyMuPDF基础解析")
        return self.pdf_parser.parse_to_markdown(file_path)

    def parse_docx(self, file_path: str) -> str:
        """Parse a Word document and return its Markdown text."""
        from .docx_parser import DocxParser
        return DocxParser().parse_to_markdown(file_path)

    def parse(self, file_path: str) -> str:
        """Dispatch to a parser based on the lower-cased file extension.

        Args:
            file_path: Path of the file.

        Returns:
            str: Markdown text.

        Raises:
            ValueError: If the extension is not .pdf, .docx or .doc.
        """
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".pdf":
            return self.parse_pdf(file_path)
        # NOTE(review): ".doc" is routed to the python-docx based parser —
        # confirm that legacy binary .doc files are really expected to work.
        if ext in (".docx", ".doc"):
            return self.parse_docx(file_path)
        raise ValueError(f"不支持的文件类型: {ext}")
def parse_with_mineru(file_path: str) -> MinerUResult:
    """Convenience wrapper: parse a PDF with a fresh ``MinerUParser``."""
    return MinerUParser().parse(file_path)
def parse_pdf_smart(file_path: str) -> str:
    """Convenience wrapper: parse a PDF via ``ParserOrchestrator`` (MinerU first, PyMuPDF fallback)."""
    return ParserOrchestrator().parse_pdf(file_path)

View File

@@ -0,0 +1,268 @@
# src/services/parser/pdf_parser.py
"""PDF文档解析 - 使用PyMuPDF基础解析"""
import fitz # PyMuPDF
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass, field
from loguru import logger
import re
@dataclass
class PDFPageContent:
    """Content extracted from a single PDF page."""
    page_number: int
    text: str
    tables: List[str] = field(default_factory=list)  # tables rendered as Markdown
    images: List[str] = field(default_factory=list)  # image identifiers / paths
    blocks: List[Dict] = field(default_factory=list)  # raw text blocks of the page
@dataclass
class PDFDocumentContent:
    """Complete parsed content of a PDF document."""
    file_path: str
    total_pages: int
    pages: List[PDFPageContent]
    metadata: Dict[str, str] = field(default_factory=dict)
    markdown_text: str = ""  # Markdown rendering of the whole document
class PDFParser:
    """PDF document parser built on PyMuPDF (``fitz``).

    Extracts metadata, per-page text, tables and image identifiers, and
    renders the document as Markdown with regulation-style heading detection.
    """

    def __init__(self):
        self.pdf = None

    def parse(self, file_path: str, extract_tables: bool = True, extract_images: bool = False) -> PDFDocumentContent:
        """Parse a PDF document.

        Args:
            file_path: Path of the PDF file.
            extract_tables: Whether to extract tables.
            extract_images: Whether to record image identifiers.

        Returns:
            PDFDocumentContent: The parsed document content.

        Raises:
            Exception: Any error raised while opening or parsing is logged and
                re-raised unchanged.
        """
        logger.info(f"开始解析PDF文档: {file_path}")
        try:
            self.pdf = fitz.open(file_path)
            try:
                doc_content = PDFDocumentContent(
                    file_path=file_path,
                    total_pages=self.pdf.page_count,
                    pages=[]
                )
                doc_content.metadata = self._extract_metadata()
                for page_num in range(self.pdf.page_count):
                    page = self.pdf[page_num]
                    doc_content.pages.append(
                        self._parse_page(page, page_num + 1, extract_tables, extract_images)
                    )
                doc_content.markdown_text = self._generate_markdown(doc_content)
            finally:
                # Close the document even when a page fails to parse, so the
                # file handle is not leaked.
                self.pdf.close()
            logger.success(f"PDF解析完成{doc_content.total_pages}")
            return doc_content
        except Exception as e:
            logger.error(f"PDF解析失败: {e}")
            raise

    def _extract_metadata(self) -> Dict[str, str]:
        """Read the PDF metadata; returns {} (with a warning) on failure."""
        metadata = {}
        try:
            meta = self.pdf.metadata
            metadata = {
                "title": meta.get("title", ""),
                "author": meta.get("author", ""),
                "subject": meta.get("subject", ""),
                "keywords": meta.get("keywords", ""),
                "creator": meta.get("creator", ""),
                "producer": meta.get("producer", ""),
                "creation_date": meta.get("creationDate", ""),
                "mod_date": meta.get("modDate", ""),
            }
        except Exception as e:
            logger.warning(f"提取元数据失败: {e}")
        return metadata

    def _parse_page(self, page: fitz.Page, page_num: int,
                    extract_tables: bool, extract_images: bool) -> PDFPageContent:
        """Parse one page: structured blocks, plain text, optional tables and images."""
        page_content = PDFPageContent(page_number=page_num, text="")
        # Structured text blocks.
        page_content.blocks = page.get_text("dict", flags=fitz.TEXT_PRESERVE_WHITESPACE)["blocks"]
        # Plain text.
        page_content.text = page.get_text("text", flags=fitz.TEXT_PRESERVE_WHITESPACE).strip()
        if extract_tables:
            page_content.tables = self._extract_tables_from_page(page)
        if extract_images:
            page_content.images = self._extract_images_from_page(page, page_num)
        return page_content

    def _extract_tables_from_page(self, page: fitz.Page) -> List[str]:
        """Extract the page's tables as Markdown strings.

        Uses ``page.find_tables()``. Failures are logged as warnings and the
        tables collected so far are returned.
        """
        tables = []
        try:
            # Page.find_tables() was added in PyMuPDF 1.23.0.
            tabs = page.find_tables()
            if tabs:
                for tab in tabs:
                    tables.append(self._table_to_markdown(tab.extract()))
        except AttributeError:
            # Older PyMuPDF releases have no table finder.
            logger.warning("PyMuPDF版本不支持表格提取请升级到1.23.0+版本")
        except Exception as e:
            logger.warning(f"表格提取失败: {e}")
        return tables

    @staticmethod
    def _cell_text(cell) -> str:
        """Convert a raw table cell to text that is safe inside a Markdown cell."""
        if cell is None:
            # str(None) would put the literal text "None" into the table.
            return ""
        text = str(cell).strip()
        # A raw "|" would start a new column and a line break would end the row.
        return text.replace("|", "\\|").replace("\r\n", " ").replace("\n", " ").replace("\r", " ")

    def _table_to_markdown(self, table_data: List[List[str]]) -> str:
        """Render table rows as Markdown; the first row is used as the header.

        Returns:
            The Markdown table, or "" when ``table_data`` is empty.
        """
        if not table_data:
            return ""
        header = table_data[0]
        lines = [
            "| " + " | ".join(self._cell_text(cell) for cell in header) + " |",
            "| " + " | ".join("---" for _ in header) + " |",
        ]
        for row in table_data[1:]:
            lines.append("| " + " | ".join(self._cell_text(cell) for cell in row) + " |")
        return "\n".join(lines)

    def _extract_images_from_page(self, page: fitz.Page, page_num: int) -> List[str]:
        """Record an identifier per image on the page (image data is not saved)."""
        images = []
        try:
            for img_index, img in enumerate(page.get_images()):
                xref = img[0]
                images.append(f"image_p{page_num}_i{img_index}_xref{xref}")
        except Exception as e:
            logger.warning(f"图片提取失败: {e}")
        return images

    def _generate_markdown(self, doc_content: PDFDocumentContent) -> str:
        """Render the parsed document as Markdown text.

        Layout: title (metadata title or file path), selected metadata, then
        each page's processed text followed by its tables.
        """
        lines = []
        title = doc_content.metadata.get("title", "")
        lines.append(f"# {title}\n" if title else f"# {doc_content.file_path}\n")

        lines.append("\n## 文档信息\n")
        for key, value in doc_content.metadata.items():
            if value and key in ["author", "subject", "keywords", "creation_date"]:
                lines.append(f"- **{key}**: {value}")

        lines.append("\n## 正文\n")
        for page in doc_content.pages:
            # Page marker.
            lines.append(f"\n---\n**第 {page.page_number} 页**\n")
            lines.append(self._process_page_text(page.text, page.blocks))
            for table in page.tables:
                lines.append("\n" + table + "\n")
        return "\n".join(lines)

    def _process_page_text(self, text: str, blocks: List[Dict]) -> str:
        """Post-process page text; currently this only marks headings."""
        return self._detect_headers(text, blocks)

    def _detect_headers(self, text: str, blocks: List[Dict]) -> str:
        """Mark regulation headings and sub-items using content patterns.

        Blank lines are dropped. Chapter / section / article lines become
        level 2 / 3 / 4 Markdown headings, and lines that start with a Chinese
        numeral marker become list items. ``blocks`` is accepted but not used.
        """
        rules = (
            (r'^第[一二三四五六七八九十百]+章\s', "\n## {}\n"),
            (r'^第[一二三四五六七八九十百]+节\s', "\n### {}\n"),
            (r'^第[一二三四五六七八九十百]+条\s', "\n#### {}\n"),
            (r'^[一二三四五六七八九十]+\s*[、.]', "- {}"),  # clause sub-item
        )
        processed_lines = []
        for raw_line in text.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            for pattern, template in rules:
                if re.match(pattern, line):
                    # Concatenate rather than str.format so braces in the line
                    # are kept verbatim.
                    prefix, suffix = template.split("{}")
                    processed_lines.append(prefix + line + suffix)
                    break
            else:
                processed_lines.append(line)
        return "\n".join(processed_lines)

    def parse_to_markdown(self, file_path: str) -> str:
        """Parse the document and return only its Markdown text."""
        return self.parse(file_path).markdown_text
def parse_pdf(file_path: str, **kwargs) -> PDFDocumentContent:
    """Convenience wrapper: parse a PDF with a fresh ``PDFParser``.

    ``**kwargs`` are forwarded unchanged to ``PDFParser.parse``.
    """
    return PDFParser().parse(file_path, **kwargs)
def parse_pdf_to_markdown(file_path: str) -> str:
    """Convenience wrapper: parse a PDF and return its Markdown text."""
    return PDFParser().parse_to_markdown(file_path)

View File

@@ -0,0 +1,12 @@
# src/services/rag/__init__.py
"""RAG服务模块"""
from .retriever import Retriever, retrieve_regulations
from .context_builder import ContextBuilder, build_rag_context
from .prompt_templates import PromptTemplates, get_prompt_template
__all__ = [
"Retriever", "retrieve_regulations",
"ContextBuilder", "build_rag_context",
"PromptTemplates", "get_prompt_template"
]

View File

@@ -0,0 +1,230 @@
# src/services/rag/context_builder.py
"""RAG上下文构建服务 - 构建LLM输入上下文"""
from typing import List, Dict, Optional
from dataclasses import dataclass
from loguru import logger
from .retriever import RetrievedDocument
from app.config.settings import settings
@dataclass
class RAGContext:
    """Context assembled for one RAG request."""
    system_prompt: str
    context_text: str  # formatted retrieved documents
    user_query: str
    total_tokens: int  # estimated tokens of prompt + context + query
    sources: List[Dict]  # one entry per document included in context_text
    truncated: bool = False  # True when documents were dropped for the token limit
class ContextBuilder:
    """RAG context builder.

    Responsibilities:
    - format retrieved documents into context text
    - keep the context within a token budget
    - build the chat messages sent to the LLM
    """

    def __init__(
        self,
        max_context_tokens: Optional[int] = None,
        include_metadata: bool = True,
        citation_format: str = "【条款{clause}】"
    ):
        """Initialise the builder.

        Args:
            max_context_tokens: Maximum context tokens; falls back to
                ``settings.rag_max_context_tokens`` when not given.
            include_metadata: Whether to include document name, section title
                and clause number in the formatted context.
            citation_format: Template for clause citations; must contain
                ``{clause}``. The default matches the 【条款X】 form that the
                system prompt asks the model to cite.
        """
        self.max_context_tokens = max_context_tokens or settings.rag_max_context_tokens
        self.include_metadata = include_metadata
        self.citation_format = citation_format
        logger.info(f"上下文构建器初始化: max_tokens={self.max_context_tokens}")

    def build(
        self,
        query: str,
        documents: List[RetrievedDocument],
        system_prompt: Optional[str] = None,
        max_tokens: Optional[int] = None
    ) -> RAGContext:
        """Build the RAG context for a query.

        Args:
            query: User query.
            documents: Retrieved documents, in the order they should appear.
            system_prompt: Optional system prompt; a default is used otherwise.
            max_tokens: Optional override of the builder's token budget.

        Returns:
            RAGContext: The assembled context.
        """
        max_tokens = max_tokens or self.max_context_tokens
        context_text, sources, truncated = self._format_documents(
            documents,
            max_tokens
        )
        system_prompt = system_prompt or self._default_system_prompt()
        total_tokens = self._estimate_tokens(system_prompt + context_text + query)
        logger.info(f"上下文构建完成: {len(documents)}条文档, {total_tokens}tokens, truncated={truncated}")
        return RAGContext(
            system_prompt=system_prompt,
            context_text=context_text,
            user_query=query,
            total_tokens=total_tokens,
            sources=sources,
            truncated=truncated
        )

    def _format_documents(
        self,
        documents: List[RetrievedDocument],
        max_tokens: int
    ) -> tuple:
        """Format documents until the token budget would be exceeded.

        Args:
            documents: Document list.
            max_tokens: Token budget for the formatted context.

        Returns:
            (context_text, sources, truncated)
        """
        context_parts = []
        sources = []
        current_tokens = 0
        truncated = False
        for index, doc in enumerate(documents, start=1):
            formatted = self._format_single_doc(doc, index)
            doc_tokens = self._estimate_tokens(formatted)
            # Stop at the first document that does not fit; later ones are dropped.
            if current_tokens + doc_tokens > max_tokens:
                truncated = True
                logger.warning(f"上下文截断: 已达到{max_tokens}tokens限制")
                break
            context_parts.append(formatted)
            current_tokens += doc_tokens
            sources.append({
                "index": index,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            })
        return "\n\n".join(context_parts), sources, truncated

    def _format_single_doc(
        self,
        doc: RetrievedDocument,
        index: int
    ) -> str:
        """Format one document as "[index]", an optional metadata line, then the content."""
        parts = [f"[{index}]"]
        if self.include_metadata:
            meta_parts = []
            if doc.doc_name:
                meta_parts.append(f"文档: {doc.doc_name}")
            if doc.section_title:
                meta_parts.append(f"章节: {doc.section_title}")
            if doc.clause_number:
                meta_parts.append(self.citation_format.format(clause=doc.clause_number))
            if meta_parts:
                parts.append(" | ".join(meta_parts))
        parts.append(doc.content)
        return "\n".join(parts)

    def _default_system_prompt(self) -> str:
        """Return the default system prompt."""
        return """你是合规专家助手,基于提供的法规条款回答问题。
回答要求:
1. 直接回答问题必须引用具体条款编号如【条款5.2.1】)
2. 如引用的条款不完整,说明需要进一步查阅原文
3. 给出明确的合规建议和操作指导
4. 如果检索内容不足以回答问题,如实说明
回答格式:
- 先给出直接结论
- 然后引用支撑条款
- 最后给出合规建议"""

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Heuristic: characters in U+4E00..U+9FFF count as 1.5 tokens each and
        every other character as 0.25 tokens.
        """
        # The lower bound must be U+4E00: an empty-string bound compares below
        # every character, so ASCII would be counted as Chinese.
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)

    def build_messages(
        self,
        context: RAGContext
    ) -> List[Dict[str, str]]:
        """Build the LLM chat messages.

        Args:
            context: RAG context object.

        Returns:
            List[Dict]: A system message holding the system prompt, followed by
            a user message that embeds the context text and the query.
        """
        return [
            {"role": "system", "content": context.system_prompt},
            {"role": "user", "content": f"参考以下法规条款回答问题。\n\n{context.context_text}\n\n问题:{context.user_query}"}
        ]
def build_rag_context(
    query: str,
    documents: List[RetrievedDocument],
    **kwargs
) -> RAGContext:
    """Convenience wrapper: build a RAG context with a default ``ContextBuilder``.

    ``**kwargs`` are forwarded unchanged to ``ContextBuilder.build``.
    """
    return ContextBuilder().build(query, documents, **kwargs)

View File

@@ -0,0 +1,296 @@
# src/services/rag/prompt_templates.py
"""RAG Prompt模板 - 合规问答专用Prompt"""
from typing import Dict, Optional
from dataclasses import dataclass
@dataclass
class PromptTemplate:
    """A named prompt: a system prompt plus a user-message template.

    Attributes:
        name: Identifier of the template.
        system_prompt: Text sent as the system message.
        user_template: Text of the user message.
        description: Short human-readable summary of the template.
    """

    name: str
    system_prompt: str
    user_template: str
    description: str
class PromptTemplates:
    """Library of prompt templates for compliance question answering.

    Scenarios covered:
    - compliance Q&A (standard)
    - clause interpretation (detailed explanation)
    - compliance check (judging compliance status)
    - comparison (old vs. new regulation versions)
    - report generation (compliance reports)
    - document summary

    Templates are looked up by registry key through ``get_template``. The
    report template is registered under the key ``"report"`` while its own
    ``name`` is ``"report_generation"``; both spellings are accepted.
    """

    # Standard compliance Q&A template
    COMPLIANCE_QA = PromptTemplate(
        name="compliance_qa",
        system_prompt="""你是合规专家助手,专门解答法规合规问题。
角色定位:
- 深谙国家法规标准GB标准、行业标准
- 熟悉车辆安全、数据安全、EHS等领域合规要求
- 提供专业、准确、可操作的合规建议
回答规范:
1. 必须引用具体条款编号如【条款5.2.1】)
2. 优先引用高相关性条款score ≥ 0.5
3. 如条款内容不完整,明确提示需要查阅原文
4. 给出明确的合规结论和建议
5. 如检索内容不足以回答,如实说明
回答格式:
【结论】直接给出合规判断或答案
【条款依据】
- 【条款X.X.X】简要内容摘要相关性: 高/中)
- ...
【合规建议】
1. 具体操作建议
2. 需要注意的风险点
3. 后续行动建议""",
        user_template="""请根据以下法规条款回答问题。
【法规条款】
{context}
【用户问题】
{query}""",
        description="标准合规问答模板"
    )

    # Clause interpretation template (detailed explanation)
    CLAUSE_INTERPRETATION = PromptTemplate(
        name="clause_interpretation",
        system_prompt="""你是法规解读专家,负责详细解释法规条款的含义和应用。
解读要求:
1. 逐句解释条款原文的含义
2. 说明条款的目的和背景
3. 举例说明条款的实际应用场景
4. 解释常见的误解和注意事项
解读格式:
【条款原文】完整引用条款
【逐句解读】
- "原文句1":解读含义
- "原文句2":解读含义
...
【应用场景】
具体举例说明该条款在实际工作中如何应用
【注意事项】
常见误解、执行难点、合规风险点""",
        user_template="""请解读以下法规条款:
条款编号:{clause_number}
条款内容:{content}
用户关注点:{query}""",
        description="条款详细解读模板"
    )

    # Compliance check template (judging compliance status).
    # NOTE(review): this system prompt contains literal ``{...}`` text, so it
    # is presumably sent verbatim rather than passed through ``str.format``.
    COMPLIANCE_CHECK = PromptTemplate(
        name="compliance_check",
        system_prompt="""你是合规检查专家,负责评估企业行为或产品的合规状态。
检查流程:
1. 理解企业行为/产品描述
2. 识别相关的法规条款
3. 逐条对照检查合规状态
4. 给出综合合规结论和整改建议
合规状态分类:
- ✅ 符合:完全满足法规要求
- ⚠️ 需评估:需要进一步核实或补充材料
- ❌ 不符合:明确违反法规要求
- ❓ 无适用条款:检索内容不足以判断
检查格式:
【合规检查报告】
一、检查对象
{描述企业行为/产品}
二、条款对照检查
| 条款编号 | 要求摘要 | 检查状态 | 说明 |
|--------|---------|---------|------|
| 【条款X.X.X】 | ... | ✅/⚠️/❌/❓ | ... |
三、综合结论
合规等级A/B/C/D完全合规/基本合规/部分合规/不合规)
四、整改建议(如需要)
1. ...
2. ...""",
        user_template="""请对以下企业行为进行合规检查:
【行为/产品描述】
{query}
【相关法规条款】
{context}""",
        description="合规检查评估模板"
    )

    # Comparison template (old vs. new regulation versions)
    COMPARISON = PromptTemplate(
        name="comparison",
        system_prompt="""你是法规变更分析专家,负责对比新旧法规版本的差异。
对比任务:
1. 识别新旧版本的条款差异
2. 分类差异类型(新增/修改/删除)
3. 分析差异的影响范围
4. 给出企业应对建议
差异分类:
- 🆕 新增条款:原版本不存在
- 🔄 修改条款:内容有实质性变更
- ❌ 删除条款:原条款被移除
- ⚖️ 调整条款:仅格式/编号调整,实质内容不变
对比格式:
【法规变更对比分析】
一、变更概述
- 旧版本:{version_old}
- 新版本:{version_new}
- 变更条款数:{count}
二、差异明细
| 类型 | 条款编号 | 旧版本内容 | 新版本内容 | 变化要点 |
|-----|---------|-----------|-----------|---------|
| 🆕 | X.X.X | - | ... | 新增要求... |
三、影响分析
- 高影响:{条款列表}
- 中影响:{条款列表}
- 低影响:{条款列表}
四、应对建议
1. 立即整改项
2. 逐步调整项
3. 持续关注项""",
        user_template="""请对比分析以下法规差异:
【用户问题】
{query}
【旧版本条款】
{context_old}
【新版本条款】
{context_new}""",
        description="法规版本对比模板"
    )

    # Report generation template
    REPORT_GENERATION = PromptTemplate(
        name="report_generation",
        system_prompt="""你是合规报告撰写专家,负责生成结构化的合规分析报告。
报告要求:
1. 结构清晰、逻辑严谨
2. 数据准确、引用规范
3. 结论明确、建议可操作
4. 语言专业、表达简洁
报告结构:
1. 概述(背景、范围)
2. 分析内容(主体分析)
3. 发现问题(合规差距)
4. 整改建议(具体措施)
5. 附录(引用条款原文)""",
        user_template="""请生成以下合规报告:
【报告主题】
{query}
【分析依据】
{context}
【报告要求】
{requirements}""",
        description="合规报告生成模板"
    )

    # Document summary template
    DOCUMENT_SUMMARY = PromptTemplate(
        name="document_summary",
        system_prompt="""你是法规文档摘要专家,负责生成法规文档的核心要点摘要。
摘要要求:
1. 精炼核心内容不超过1024字
2. 突出关键合规要求和条款编号
3. 说明适用范围和生效条件
4. 列出重要定义和术语解释
摘要格式:
【法规名称】{doc_name}
【适用范围】{适用范围描述}
【核心条款摘要】
- 【条款X.X.X】{关键要求}(重要度:高)
- ...
【关键术语】
- 术语1定义解释
- ...
【合规要点】
1. 必须满足的核心要求
2. 需要特别注意的条款""",
        user_template="""请生成以下法规文档的摘要:
【文档名称】
{doc_name}
【文档内容】
{content}
请生成不超过1024字的摘要。""",
        description="文档摘要生成模板"
    )

    @classmethod
    def _registry(cls) -> Dict[str, PromptTemplate]:
        """Map every public lookup key to its template (single source of truth)."""
        return {
            "compliance_qa": cls.COMPLIANCE_QA,
            "clause_interpretation": cls.CLAUSE_INTERPRETATION,
            "compliance_check": cls.COMPLIANCE_CHECK,
            "comparison": cls.COMPARISON,
            "report": cls.REPORT_GENERATION,
            "document_summary": cls.DOCUMENT_SUMMARY,
        }

    @classmethod
    def get_template(cls, name: str) -> Optional[PromptTemplate]:
        """Look up a template by registry key or by its own ``name``.

        Args:
            name: Registry key (e.g. ``"report"``) or a template's ``name``
                attribute (e.g. ``"report_generation"``).

        Returns:
            The matching template, or ``None`` when nothing matches.
        """
        registry = cls._registry()
        template = registry.get(name)
        if template is None:
            # The report template is keyed as "report" but named
            # "report_generation"; accept the template's own name as well.
            template = next(
                (tpl for tpl in registry.values() if tpl.name == name),
                None,
            )
        return template

    @classmethod
    def list_templates(cls) -> Dict[str, str]:
        """Return a mapping of registry key to template description."""
        return {key: tpl.description for key, tpl in cls._registry().items()}
def get_prompt_template(name: str) -> PromptTemplate:
    """Convenience lookup that raises instead of returning ``None``.

    Args:
        name: Template name passed to ``PromptTemplates.get_template``.

    Returns:
        The matching template.

    Raises:
        ValueError: If no template is found for ``name``.
    """
    found = PromptTemplates.get_template(name)
    if found:
        return found
    raise ValueError(f"不存在的模板: {name}")

View File

@@ -0,0 +1,193 @@
# src/services/rag/retriever.py
"""RAG检索服务 - 封装Milvus检索"""
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
from app.services.embedding.bge_m3_embedder import BGEM3Embedder
from app.services.storage.milvus_client import MilvusClient, SearchResult
from app.config.settings import settings
@dataclass
class RetrievedDocument:
    """One document chunk returned by a retrieval call.

    Attributes:
        content: Text of the chunk.
        doc_id: Document ID (used for downloading the document).
        doc_name: Name of the source document.
        section_title: Title of the section the chunk belongs to.
        clause_number: Clause number of the chunk.
        page_number: Page number of the chunk.
        score: Relevance score assigned by the search.
        metadata: Additional fields; defaults to a fresh empty dict.
    """

    content: str
    doc_id: str
    doc_name: str
    section_title: str
    clause_number: str
    page_number: int
    score: float
    metadata: Dict[str, Any] = field(default_factory=dict)
class Retriever:
    """RAG retriever backed by Milvus hybrid (dense + sparse) search.

    Features:
    - embeds the query and runs a hybrid vector search
    - drops hits whose score is below ``min_score``
    - optional filter expressions (by document name or regulation type)

    The embedding model and the Milvus connection are created lazily on the
    first ``retrieve`` call. The ``rerank`` flag is stored and logged but is
    not consulted anywhere in this class.
    """

    def __init__(
        self,
        top_k: int = None,
        rerank: bool = False,
        min_score: float = 0.3
    ):
        """Configure the retriever without touching the model or Milvus.

        Args:
            top_k: Number of hits to request; a falsy value falls back to
                ``settings.rag_top_k``.
            rerank: Whether re-ranking is requested.
            min_score: Minimum relevance score a hit needs to be returned.
        """
        self.top_k = top_k or settings.rag_top_k
        self.rerank = rerank
        self.min_score = min_score
        # Embedding model (loaded lazily)
        self.embedder: Optional[BGEM3Embedder] = None
        # Milvus client (connected lazily)
        self.milvus: Optional[MilvusClient] = None
        logger.info(f"检索器初始化: top_k={self.top_k}, rerank={self.rerank}")

    @staticmethod
    def _quote_filter_value(value: str) -> str:
        """Return ``value`` as a double-quoted literal for a filter expression.

        Backslashes and double quotes are escaped so a caller-supplied value
        cannot terminate the literal and append its own conditions.
        NOTE(review): assumes the deployed Milvus version honours backslash
        escapes inside string literals - confirm.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def _init_embedder(self):
        """Load the embedding model on first use."""
        if self.embedder is None:
            logger.info("加载嵌入模型...")
            self.embedder = BGEM3Embedder(model_name=settings.embedding_model)

    def _init_milvus(self):
        """Connect to Milvus on first use.

        ``self.milvus`` is only assigned once the connection and the
        collection are ready. After a failure it stays ``None`` so the next
        call retries instead of reusing a dead client.
        """
        if self.milvus is not None:
            return
        logger.info("连接Milvus...")
        client = MilvusClient()
        if not client.connect():
            logger.warning("Milvus不可用,下次检索时将重试连接")
            return
        if not client.create_collection(recreate=False):
            logger.warning("Milvus Collection不可用,下次检索时将重试")
            client.disconnect()
            return
        client.load_collection()
        self.milvus = client

    def retrieve(
        self,
        query: str,
        filters: Optional[str] = None,
        top_k: Optional[int] = None
    ) -> List[RetrievedDocument]:
        """Retrieve documents relevant to ``query``.

        Args:
            query: Query text.
            filters: Filter expression, e.g. ``regulation_type=="车辆安全"``.
            top_k: Number of hits to request; overrides the instance default
                when truthy.

        Returns:
            Hits whose score is at least ``min_score``, in the order the
            search returned them. Empty when Milvus is unavailable.
        """
        logger.info(f"执行检索: {query}")
        # Initialise the lazy components
        self._init_embedder()
        self._init_milvus()
        if self.milvus is None:
            # Connection failed; skip the embedding work and report no hits.
            return []
        # Embed the query
        query_embedding = self.embedder.embed_single(query)
        # Run the hybrid search
        results = self.milvus.hybrid_search(
            query_dense=query_embedding['dense'].tolist(),
            query_sparse=query_embedding['sparse'],
            top_k=top_k or self.top_k,
            filters=filters
        )
        # Convert to RetrievedDocument, dropping hits below the threshold
        documents = [
            RetrievedDocument(
                content=r.content,
                doc_id=r.metadata.get("doc_id", ""),
                doc_name=r.metadata.get("doc_name", ""),
                section_title=r.metadata.get("section_title", ""),
                clause_number=r.metadata.get("clause_number", ""),
                page_number=r.metadata.get("page_number", 0),
                score=r.score,
                metadata=r.metadata
            )
            for r in results
            if r.score >= self.min_score
        ]
        logger.success(f"检索完成,返回{len(documents)}条结果(阈值过滤后)")
        return documents

    def retrieve_with_scores(
        self,
        query: str,
        filters: Optional[str] = None
    ) -> List[Dict]:
        """Retrieve and return the hits as plain dictionaries.

        Args:
            query: Query text.
            filters: Optional filter expression.

        Returns:
            One dict per hit with the keys ``content``, ``doc_id``,
            ``doc_name``, ``section_title``, ``clause_number``,
            ``page_number`` and ``score``.
        """
        documents = self.retrieve(query, filters)
        return [
            {
                "content": doc.content,
                "doc_id": doc.doc_id,
                "doc_name": doc.doc_name,
                "section_title": doc.section_title,
                "clause_number": doc.clause_number,
                "page_number": doc.page_number,
                "score": doc.score
            }
            for doc in documents
        ]

    def search_by_doc_name(
        self,
        query: str,
        doc_name: str
    ) -> List[RetrievedDocument]:
        """Retrieve only from the document called ``doc_name``."""
        filters = f"doc_name=={self._quote_filter_value(doc_name)}"
        return self.retrieve(query, filters)

    def search_by_regulation_type(
        self,
        query: str,
        regulation_type: str
    ) -> List[RetrievedDocument]:
        """Retrieve only from documents of the given regulation type."""
        filters = f"regulation_type=={self._quote_filter_value(regulation_type)}"
        return self.retrieve(query, filters)

    def close(self):
        """Disconnect from Milvus, if connected."""
        if self.milvus:
            self.milvus.disconnect()
            # Forget the client so a later retrieve() reconnects rather than
            # searching through a closed connection.
            self.milvus = None
        logger.info("检索器已关闭")
def retrieve_regulations(
    query: str,
    top_k: int = 10,
    filters: Optional[str] = None
) -> List[RetrievedDocument]:
    """Convenience function: run one retrieval with a throwaway ``Retriever``.

    Args:
        query: Query text.
        top_k: Number of hits to request.
        filters: Optional filter expression.

    Returns:
        The retrieved documents.
    """
    retriever = Retriever(top_k=top_k)
    try:
        return retriever.retrieve(query, filters)
    finally:
        # Release the Milvus connection even when retrieval raises.
        retriever.close()

View File

@@ -0,0 +1,7 @@
# src/services/storage/__init__.py
"""存储服务"""
from .milvus_client import MilvusClient
from .minio_client import MinIOClient
__all__ = ["MilvusClient", "MinIOClient"]

View File

@@ -0,0 +1,485 @@
# src/services/storage/milvus_client.py
"""Milvus向量数据库客户端 - 存储与检索服务"""
from pymilvus import (
connections,
Collection,
FieldSchema,
CollectionSchema,
DataType,
utility
)
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from loguru import logger
import time
import numpy as np
from ..embedding.text_chunker import TextChunk
from ..embedding.bge_m3_embedder import EmbeddingResult
from app.config.settings import settings
@dataclass
class SearchResult:
    """A single hit returned by a Milvus search.

    Attributes:
        id: Primary key of the hit.
        content: Stored chunk text.
        score: Score reported by the search (or the fused score after merging).
        metadata: Remaining scalar fields of the hit; defaults to a fresh
            empty dict.
    """

    id: int
    content: str
    score: float
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MilvusDocument:
    """Record layout of one regulation chunk as stored in Milvus.

    Attributes:
        doc_id: ID of the source document.
        chunk_id: ID of the chunk.
        content: Chunk text.
        dense_vector: Dense embedding of the chunk.
        sparse_vector: Sparse embedding as an index-to-weight mapping.
        doc_name: Name of the source document.
        section_title: Title of the containing section.
        clause_number: Clause number of the chunk.
        page_number: Page number of the chunk.
        regulation_type: Regulation category.
        version: Document version.
        create_time: Creation time as an integer timestamp.
    """

    doc_id: str
    chunk_id: str
    content: str
    dense_vector: List[float]
    sparse_vector: Dict[int, float]
    doc_name: str
    section_title: str
    clause_number: str
    page_number: int
    regulation_type: str
    version: str
    create_time: int
class MilvusClient:
    """Client for storing and searching regulation chunks in Milvus.

    Wraps connection handling, collection and index creation, chunk
    insertion, dense / sparse / hybrid vector search and deletion. Most
    methods log failures and return an empty value (``False``, ``[]``, ``0``
    or ``{}``) instead of raising.
    """

    COLLECTION_NAME = "regulations"

    SCHEMA_FIELDS = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="doc_id", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=128),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=8192),
        FieldSchema(name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=1024),
        FieldSchema(name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR),
        FieldSchema(name="doc_name", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="section_title", dtype=DataType.VARCHAR, max_length=512),
        FieldSchema(name="clause_number", dtype=DataType.VARCHAR, max_length=64),
        FieldSchema(name="page_number", dtype=DataType.INT64),
        FieldSchema(name="regulation_type", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="version", dtype=DataType.VARCHAR, max_length=32),
        FieldSchema(name="create_time", dtype=DataType.INT64),
    ]

    # Scalar fields requested with every hit; shared by dense and sparse search.
    _OUTPUT_FIELDS = (
        "doc_id", "chunk_id", "content",
        "doc_name", "section_title", "clause_number",
        "page_number", "regulation_type", "version",
    )

    def __init__(
        self,
        host: str = None,
        port: int = None,
        collection_name: str = None,
        db_name: str = None
    ):
        """Store connection settings; no network access happens here.

        Each argument falls back to the matching ``settings.milvus_*`` value
        when it is falsy.
        """
        self.host = host or settings.milvus_host
        self.port = port or settings.milvus_port
        self.collection_name = collection_name or settings.milvus_collection
        self.db_name = db_name or settings.milvus_db_name
        self.collection: Optional[Collection] = None
        self.connected = False
        logger.info(f"Milvus客户端配置: {self.host}:{self.port}, Collection: {self.collection_name}")

    @staticmethod
    def _quote(value: str) -> str:
        """Return ``value`` as a double-quoted literal for a filter expression.

        Backslashes and double quotes are escaped so the value cannot end the
        literal early.
        NOTE(review): assumes the deployed Milvus version honours backslash
        escapes inside string literals - confirm.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def connect(self) -> bool:
        """Connect to the Milvus server under the ``default`` alias.

        Returns:
            ``True`` on success, ``False`` if connecting raised.
        """
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                db_name=self.db_name
            )
            self.connected = True
            logger.success(f"Milvus连接成功: {self.host}:{self.port}")
            return True
        except Exception as e:
            logger.error(f"Milvus连接失败: {e}")
            self.connected = False
            return False

    def disconnect(self):
        """Drop the ``default`` connection; errors are logged, not raised."""
        try:
            connections.disconnect("default")
            self.connected = False
            logger.info("Milvus连接已断开")
        except Exception as e:
            logger.warning(f"断开连接时出错: {e}")

    def create_collection(self, recreate: bool = False) -> bool:
        """Open the collection, creating it (with indexes) when needed.

        Args:
            recreate: Drop an existing collection and build a new one.

        Returns:
            ``True`` when ``self.collection`` is ready, ``False`` when not
            connected or when an error occurred.
        """
        if not self.connected:
            logger.warning("未连接到Milvus请先调用connect()")
            return False
        try:
            if utility.has_collection(self.collection_name):
                if recreate:
                    logger.info(f"删除已存在的Collection: {self.collection_name}")
                    utility.drop_collection(self.collection_name)
                else:
                    logger.info(f"Collection已存在: {self.collection_name}")
                    self.collection = Collection(self.collection_name)
                    return True
            schema = CollectionSchema(
                fields=self.SCHEMA_FIELDS,
                description="法规文档向量存储",
                enable_dynamic_field=True
            )
            self.collection = Collection(
                name=self.collection_name,
                schema=schema
            )
            self._create_indexes()
            logger.success(f"Collection创建成功: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Collection创建失败: {e}")
            return False

    def _create_indexes(self):
        """Create the dense (COSINE, IVF_FLAT) and sparse (IP) vector indexes."""
        if not self.collection:
            return
        try:
            dense_index_params = {
                "metric_type": "COSINE",
                "index_type": "IVF_FLAT",
                "params": {"nlist": 128}
            }
            self.collection.create_index(
                field_name="dense_vector",
                index_params=dense_index_params
            )
            sparse_index_params = {
                "metric_type": "IP",
                "index_type": "SPARSE_INVERTED_INDEX",
                "params": {"drop_ratio_build": 0.2}
            }
            self.collection.create_index(
                field_name="sparse_vector",
                index_params=sparse_index_params
            )
            logger.success("向量索引创建成功")
        except Exception as e:
            logger.warning(f"创建索引时出错: {e}")

    def load_collection(self):
        """Load the collection into memory; no-op without a collection."""
        if self.collection:
            self.collection.load()
            logger.info(f"Collection已加载: {self.collection_name}")

    def release_collection(self):
        """Release the collection from memory; no-op without a collection."""
        if self.collection:
            self.collection.release()
            logger.info(f"Collection已释放: {self.collection_name}")

    def insert_chunks(
        self,
        chunks: List[TextChunk],
        embeddings: EmbeddingResult
    ) -> List[int]:
        """Insert text chunks together with their embeddings.

        Args:
            chunks: Chunks to store; ``content`` and ``metadata`` are read.
            embeddings: Embedding result; ``texts``, ``dense_embeddings`` and
                ``sparse_embeddings`` are read and paired with ``chunks`` by
                position.

        Returns:
            Primary keys of the inserted rows, or ``[]`` when the collection
            is missing, the counts differ, or the insert fails.
        """
        if not self.collection:
            logger.warning("Collection未初始化")
            return []
        if len(chunks) != len(embeddings.texts):
            logger.warning("Chunks数量与嵌入数量不匹配")
            return []
        logger.info(f"准备插入{len(chunks)}个文档分块")
        try:
            # One timestamp for the whole batch.
            current_time = int(time.time())
            data = [
                {
                    "doc_id": chunk.metadata.doc_id,
                    "chunk_id": chunk.metadata.chunk_id,
                    "content": chunk.content,
                    "dense_vector": dense_emb.tolist(),
                    "sparse_vector": sparse_emb,
                    "doc_name": chunk.metadata.doc_name,
                    "section_title": chunk.metadata.section_title,
                    "clause_number": chunk.metadata.clause_number,
                    "page_number": chunk.metadata.page_number,
                    "regulation_type": chunk.metadata.regulation_type,
                    "version": chunk.metadata.version,
                    "create_time": current_time
                }
                for chunk, dense_emb, sparse_emb in zip(
                    chunks,
                    embeddings.dense_embeddings,
                    embeddings.sparse_embeddings
                )
            ]
            result = self.collection.insert(data)
            self.collection.flush()
            logger.success(f"插入完成,共{len(result.primary_keys)}条记录")
            return result.primary_keys
        except Exception as e:
            logger.error(f"插入数据失败: {e}")
            return []

    def hybrid_search(
        self,
        query_dense: List[float],
        query_sparse: Dict[int, float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Run a dense search and, when a sparse query is given, fuse in a sparse search.

        Args:
            query_dense: Dense query vector.
            query_sparse: Sparse query vector; a falsy value skips the sparse
                search and returns the dense hits as they are.
            top_k: Maximum number of hits.
            filters: Optional filter expression applied to both searches.

        Returns:
            Fused hits (or dense-only hits); ``[]`` on error.
        """
        if not self.collection:
            logger.warning("Collection未初始化")
            return []
        try:
            self.collection.load()
            # Two separate searches merged by hand (compatible with all versions)
            dense_results = self.dense_search(query_dense, top_k, filters)
            if query_sparse:
                sparse_results = self.sparse_search(query_sparse, top_k, filters)
                merged = self._merge_results(dense_results, sparse_results, top_k)
                logger.success(f"混合检索完成,返回{len(merged)}条结果")
                return merged
            return dense_results
        except Exception as e:
            logger.error(f"混合检索失败: {e}")
            return []

    def _merge_results(
        self,
        dense_results: List[SearchResult],
        sparse_results: List[SearchResult],
        top_k: int,
        dense_weight: float = 0.6
    ) -> List[SearchResult]:
        """Fuse dense and sparse hits by a weighted sum of their scores.

        A hit found by only one search contributes 0 for the other one. When
        both searches return the same id, the content and metadata of the
        dense hit are kept.

        NOTE(review): dense scores use COSINE and sparse scores use IP (see
        the search parameters); they are added without normalisation -
        confirm this is intended.

        Returns:
            At most ``top_k`` hits sorted by fused score, highest first.
        """
        sparse_weight = 1 - dense_weight
        # id -> [representative hit, weighted dense score, weighted sparse score]
        fused = {}
        for hit in dense_results:
            fused[hit.id] = [hit, hit.score * dense_weight, 0]
        for hit in sparse_results:
            if hit.id in fused:
                fused[hit.id][2] = hit.score * sparse_weight
            else:
                fused[hit.id] = [hit, 0, hit.score * sparse_weight]
        final_results = [
            SearchResult(
                id=hit.id,
                content=hit.content,
                score=dense_score + sparse_score,
                metadata=hit.metadata
            )
            for hit, dense_score, sparse_score in fused.values()
        ]
        final_results.sort(key=lambda x: x.score, reverse=True)
        return final_results[:top_k]

    @staticmethod
    def _hit_to_result(hit) -> SearchResult:
        """Convert one search hit into a ``SearchResult``."""
        entity = hit.entity
        return SearchResult(
            id=hit.id,
            content=entity.get("content", ""),
            score=hit.score,
            metadata={
                "doc_id": entity.get("doc_id", ""),
                "chunk_id": entity.get("chunk_id", ""),
                "doc_name": entity.get("doc_name", ""),
                "section_title": entity.get("section_title", ""),
                "clause_number": entity.get("clause_number", ""),
                "page_number": entity.get("page_number", 0),
                "regulation_type": entity.get("regulation_type", ""),
                "version": entity.get("version", ""),
            }
        )

    def _search_vectors(
        self,
        query_vector,
        anns_field: str,
        search_params: Dict[str, Any],
        top_k: int,
        filters: Optional[str]
    ) -> List[SearchResult]:
        """Search one vector field and convert the hits; exceptions propagate."""
        self.collection.load()
        results = self.collection.search(
            data=[query_vector],
            anns_field=anns_field,
            param=search_params,
            limit=top_k,
            # Collection.search takes the boolean expression as ``expr``;
            # passing it as ``filter`` leaves the search unfiltered.
            expr=filters,
            output_fields=list(self._OUTPUT_FIELDS)
        )
        return [self._hit_to_result(hit) for hits in results for hit in hits]

    def dense_search(
        self,
        query_dense: List[float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Search the dense vector field only.

        Returns:
            Up to ``top_k`` hits; ``[]`` without a collection or on error.
        """
        if not self.collection:
            return []
        try:
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 16}
            }
            return self._search_vectors(
                query_dense, "dense_vector", search_params, top_k, filters
            )
        except Exception as e:
            logger.error(f"Dense检索失败: {e}")
            return []

    def sparse_search(
        self,
        query_sparse: Dict[int, float],
        top_k: int = 10,
        filters: Optional[str] = None
    ) -> List[SearchResult]:
        """Search the sparse vector field only.

        Returns:
            Up to ``top_k`` hits; ``[]`` without a collection or on error.
        """
        if not self.collection:
            return []
        try:
            search_params = {
                "metric_type": "IP",
                "params": {"drop_ratio_search": 0.2}
            }
            return self._search_vectors(
                query_sparse, "sparse_vector", search_params, top_k, filters
            )
        except Exception as e:
            logger.error(f"Sparse检索失败: {e}")
            return []

    def delete_by_doc_id(self, doc_id: str) -> int:
        """Delete every row that belongs to ``doc_id``.

        Returns:
            Number of deleted rows as reported by Milvus; ``0`` without a
            collection or on error.
        """
        if not self.collection:
            return 0
        try:
            expr = f"doc_id=={self._quote(doc_id)}"
            result = self.collection.delete(expr)
            # Prefer the reported delete count; fall back to the primary-key
            # list when the attribute is not available.
            deleted = getattr(result, "delete_count", None)
            if deleted is None:
                deleted = len(result.primary_keys)
            logger.info(f"删除记录: doc_id={doc_id}, 数量={deleted}")
            return deleted
        except Exception as e:
            logger.error(f"删除失败: {e}")
            return 0

    def get_collection_stats(self) -> Dict[str, Any]:
        """Return the name, entity count and description of the collection.

        Returns:
            A dict with ``name``, ``num_entities`` and ``description``;
            ``{}`` without a collection or on error.
        """
        if not self.collection:
            return {}
        try:
            return {
                "name": self.collection_name,
                "num_entities": self.collection.num_entities,
                "description": self.collection.description,
            }
        except Exception as e:
            logger.warning(f"获取统计信息失败: {e}")
            return {}
def create_milvus_client() -> MilvusClient:
    """Convenience function: build a default client, connect, open the collection.

    The results of ``connect()`` and ``create_collection()`` are not checked;
    the client is returned either way.
    """
    milvus = MilvusClient()
    milvus.connect()
    milvus.create_collection(recreate=False)
    return milvus
def insert_documents(
    client: MilvusClient,
    chunks: List[TextChunk],
    embeddings: EmbeddingResult
) -> List[int]:
    """Convenience function: insert chunks through ``client.insert_chunks``.

    Args:
        client: Client that performs the insert.
        chunks: Chunks to store.
        embeddings: Embeddings matching ``chunks``.

    Returns:
        Whatever ``client.insert_chunks`` returns.
    """
    inserted_ids = client.insert_chunks(chunks, embeddings)
    return inserted_ids
def search_regulations(
    client: MilvusClient,
    query_dense: List[float],
    query_sparse: Dict[int, float],
    top_k: int = 10
) -> List[SearchResult]:
    """Convenience function: run ``client.hybrid_search`` without filters.

    Args:
        client: Client that performs the search.
        query_dense: Dense query vector.
        query_sparse: Sparse query vector.
        top_k: Maximum number of hits.

    Returns:
        Whatever ``client.hybrid_search`` returns.
    """
    hits = client.hybrid_search(query_dense, query_sparse, top_k)
    return hits

View File

@@ -0,0 +1,352 @@
# src/services/storage/minio_client.py
"""MinIO对象存储客户端 - 文档文件存储"""
from minio import Minio
from minio.error import S3Error
from typing import Optional, Dict, Any
from loguru import logger
from io import BytesIO
import os
from app.config.settings import settings
class MinIOClient:
    """Client for storing document files in a MinIO bucket.

    Methods that talk to the server log ``S3Error`` failures and return a
    falsy value (``False``, ``None`` or ``[]``) instead of raising.
    """

    # File extension -> Content-Type used by upload_file.
    _CONTENT_TYPES = {
        '.pdf': 'application/pdf',
        '.doc': 'application/msword',
        '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
        '.txt': 'text/plain',
        '.json': 'application/json',
        '.xml': 'application/xml',
    }

    def __init__(
        self,
        endpoint: str = None,
        access_key: str = None,
        secret_key: str = None,
        bucket: str = None,
        secure: bool = None
    ):
        """Store connection settings; the client itself is built in ``connect``.

        Args:
            endpoint: MinIO service address.
            access_key: Access key.
            secret_key: Secret key.
            bucket: Bucket name.
            secure: Whether to use HTTPS. ``None`` means "use
                ``settings.minio_secure``"; an explicit ``False`` is honoured.

        The string arguments fall back to the matching ``settings.minio_*``
        value when falsy.
        """
        self.endpoint = endpoint or settings.minio_endpoint
        self.access_key = access_key or settings.minio_access_key
        self.secret_key = secret_key or settings.minio_secret_key
        self.bucket = bucket or settings.minio_bucket
        # ``secure or settings...`` would silently override an explicit False.
        self.secure = settings.minio_secure if secure is None else secure
        self.client: Optional[Minio] = None
        self.connected = False
        logger.info(f"MinIO客户端配置: {self.endpoint}, bucket={self.bucket}")

    @staticmethod
    def _ascii_safe_metadata(metadata: Optional[Dict[str, Any]]) -> Optional[Dict[str, str]]:
        """Return object metadata restricted to ASCII-safe string values.

        String values containing non-ASCII characters are replaced by ``""``;
        non-string values are converted with ``str``. A falsy input yields
        ``None``.
        """
        if not metadata:
            return None
        safe = {}
        for key, value in metadata.items():
            if isinstance(value, str):
                safe[key] = value if value.isascii() else ""
            else:
                safe[key] = str(value)
        return safe

    def connect(self) -> bool:
        """Build the underlying ``Minio`` client.

        Returns:
            ``True`` on success, ``False`` if construction raised.
        """
        try:
            self.client = Minio(
                self.endpoint,
                access_key=self.access_key,
                secret_key=self.secret_key,
                secure=self.secure
            )
            self.connected = True
            logger.success(f"MinIO连接成功: {self.endpoint}")
            return True
        except Exception as e:
            logger.error(f"MinIO连接失败: {e}")
            self.connected = False
            return False

    def ensure_bucket(self) -> bool:
        """Create the configured bucket when it does not exist yet.

        Returns:
            ``True`` when the bucket exists afterwards, ``False`` when not
            connected or on ``S3Error``.
        """
        if not self.connected:
            logger.warning("未连接MinIO请先调用connect()")
            return False
        try:
            if not self.client.bucket_exists(self.bucket):
                self.client.make_bucket(self.bucket)
                logger.success(f"创建存储桶: {self.bucket}")
            else:
                logger.info(f"存储桶已存在: {self.bucket}")
            return True
        except S3Error as e:
            logger.error(f"存储桶操作失败: {e}")
            return False

    def upload_file(
        self,
        file_path: str,
        object_name: str,
        metadata: Dict[str, Any] = None
    ) -> bool:
        """Upload a local file to MinIO.

        Args:
            file_path: Path of the local file.
            object_name: Name of the object in the bucket.
            metadata: Optional object metadata; sanitised to ASCII-safe
                strings in the same way as in ``upload_bytes``.

        Returns:
            ``True`` on success, ``False`` on ``S3Error``.
        """
        if not self.connected:
            self.connect()
            self.ensure_bucket()
        try:
            file_size = os.stat(file_path).st_size
            content_type = self._get_content_type(file_path)
            with open(file_path, 'rb') as f:
                self.client.put_object(
                    self.bucket,
                    object_name,
                    f,
                    file_size,
                    content_type=content_type,
                    metadata=self._ascii_safe_metadata(metadata)
                )
            logger.success(f"文件上传成功: {object_name}, 大小={file_size}")
            return True
        except S3Error as e:
            logger.error(f"文件上传失败: {e}")
            return False

    def upload_bytes(
        self,
        data: bytes,
        object_name: str,
        content_type: str = "application/octet-stream",
        metadata: Dict[str, Any] = None
    ) -> bool:
        """Upload in-memory bytes to MinIO.

        Args:
            data: File content.
            object_name: Name of the object in the bucket.
            content_type: Content type stored with the object.
            metadata: Optional object metadata (MinIO only supports US-ASCII
                here, so non-ASCII string values are blanked).

        Returns:
            ``True`` on success, ``False`` on ``S3Error``.
        """
        if not self.connected:
            self.connect()
            self.ensure_bucket()
        try:
            self.client.put_object(
                self.bucket,
                object_name,
                BytesIO(data),
                len(data),
                content_type=content_type,
                metadata=self._ascii_safe_metadata(metadata)
            )
            logger.success(f"数据上传成功: {object_name}, 大小={len(data)}")
            return True
        except S3Error as e:
            logger.error(f"数据上传失败: {e}")
            return False

    def download_file(
        self,
        object_name: str,
        file_path: str
    ) -> bool:
        """Download an object to a local file.

        Args:
            object_name: Name of the object in the bucket.
            file_path: Local path to write to.

        Returns:
            ``True`` on success, ``False`` on ``S3Error``.
        """
        if not self.connected:
            self.connect()
        try:
            self.client.fget_object(
                self.bucket,
                object_name,
                file_path
            )
            logger.success(f"文件下载成功: {object_name} -> {file_path}")
            return True
        except S3Error as e:
            logger.error(f"文件下载失败: {e}")
            return False

    def get_object_url(
        self,
        object_name: str,
        expires: int = 3600
    ) -> Optional[str]:
        """Create a temporary (presigned) download URL for an object.

        Args:
            object_name: Name of the object in the bucket.
            expires: Validity of the URL in seconds; a ``datetime.timedelta``
                is accepted as well.

        Returns:
            The URL, or ``None`` on ``S3Error``.
        """
        from datetime import timedelta

        if not self.connected:
            self.connect()
        # The SDK expects the validity as a timedelta, not as a bare integer.
        validity = expires if isinstance(expires, timedelta) else timedelta(seconds=expires)
        try:
            return self.client.presigned_get_object(
                self.bucket,
                object_name,
                expires=validity
            )
        except S3Error as e:
            logger.error(f"获取URL失败: {e}")
            return None

    def get_object_data(self, object_name: str) -> Optional[bytes]:
        """Read an object fully into memory.

        Args:
            object_name: Name of the object in the bucket.

        Returns:
            The object's bytes, or ``None`` on ``S3Error``.
        """
        if not self.connected:
            self.connect()
        response = None
        try:
            response = self.client.get_object(self.bucket, object_name)
            return response.read()
        except S3Error as e:
            logger.error(f"获取对象数据失败: {e}")
            return None
        finally:
            # Give the connection back even when read() raises.
            if response is not None:
                response.close()
                response.release_conn()

    def delete_object(self, object_name: str) -> bool:
        """Delete an object.

        Args:
            object_name: Name of the object in the bucket.

        Returns:
            ``True`` on success, ``False`` on ``S3Error``.
        """
        if not self.connected:
            self.connect()
        try:
            self.client.remove_object(self.bucket, object_name)
            logger.info(f"对象删除成功: {object_name}")
            return True
        except S3Error as e:
            logger.error(f"对象删除失败: {e}")
            return False

    def list_objects(self, prefix: str = "") -> list:
        """List object names in the bucket.

        Args:
            prefix: Only objects whose name starts with this prefix.

        Returns:
            The object names, or ``[]`` on ``S3Error``.
        """
        if not self.connected:
            self.connect()
        try:
            objects = self.client.list_objects(self.bucket, prefix=prefix)
            return [obj.object_name for obj in objects]
        except S3Error as e:
            logger.error(f"列出对象失败: {e}")
            return []

    def object_exists(self, object_name: str) -> bool:
        """Check whether an object exists.

        Any ``S3Error`` raised by the stat call is treated as "does not exist".

        Args:
            object_name: Name of the object in the bucket.

        Returns:
            ``True`` when the stat call succeeds, ``False`` otherwise.
        """
        if not self.connected:
            self.connect()
        try:
            self.client.stat_object(self.bucket, object_name)
            return True
        except S3Error:
            return False

    def _get_content_type(self, file_path: str) -> str:
        """Map a file extension (case-insensitive) to a Content-Type.

        Unknown extensions yield ``application/octet-stream``.
        """
        ext = os.path.splitext(file_path)[1].lower()
        return self._CONTENT_TYPES.get(ext, 'application/octet-stream')

    def close(self):
        """Mark the client as closed (nothing is closed explicitly here)."""
        self.connected = False
        logger.info("MinIO客户端已关闭")
def create_minio_client() -> MinIOClient:
    """Convenience function: build a default client, connect, ensure the bucket.

    The results of ``connect()`` and ``ensure_bucket()`` are not checked; the
    client is returned either way.
    """
    storage = MinIOClient()
    storage.connect()
    storage.ensure_bucket()
    return storage