first commit

This commit is contained in:
2026-04-23 09:58:47 +08:00
commit 448e078d99
49 changed files with 5188 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
import uuid
import logging
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from langchain.schema import HumanMessage, SystemMessage
from ..core.llm import get_llm, COMPLIANCE_CHECK_PROMPT
from ..services.rag import hybrid_search
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that groups the compliance-review endpoints under /api/compliance.
router = APIRouter(prefix="/api/compliance", tags=["合规审查"])
# Request body for POST /api/compliance/check.
# (Documented with comments rather than a docstring so the generated schema is unchanged.)
class ComplianceCheckRequest(BaseModel):
    # Content to review; check_compliance also uses it as the retrieval query.
    query: str
    # Regulation domains requested by the caller.
    regulation_domains: list[str] = ["vehicle_safety"]
    # Number of regulation chunks requested from the search and kept for analysis.
    top_k: int = 5
# Response body returned by POST /api/compliance/check.
class ComplianceCheckResponse(BaseModel):
    # Risk level label; check_compliance emits "unknown", "low", "medium", "high" or "critical".
    risk_level: str
    # Numeric risk score paired with risk_level.
    risk_score: float
    # List of finding dicts (either an "issue" entry or an "analysis" entry).
    findings: list[dict]
    # Human-readable remediation hints.
    recommendations: list[str]
    # Regulation excerpts (content + score) the analysis was based on.
    sources: list[dict]
def _parse_risk_assessment(analysis: str) -> tuple[str, float]:
    """Extract ``(risk_level, risk_score)`` from the LLM's free-text analysis.

    The prompt asks the model for a "整体风险等级" line and a "风险分数" line, so
    those are parsed first. Only when no level line is present does the parser
    fall back to keyword spotting, and then only on whole English words or on
    explicit "...风险" phrases: a bare substring test would match the
    "严重程度" label the prompt itself requests, or words such as "allow".
    """
    import re  # function-scope import: this module does not import re at the top

    default_scores = {"low": 20.0, "medium": 50.0, "high": 70.0, "critical": 90.0}
    chinese_levels = {"低": "low", "中": "medium", "高": "high", "严重": "critical"}

    declared = re.search(
        r"风险等级\W*(critical|high|medium|low|严重|高|中|低)",
        analysis,
        re.IGNORECASE,
    )
    if declared:
        token = declared.group(1).lower()
        level = chinese_levels.get(token, token)
    else:
        lowered = analysis.lower()

        def has_word(word: str) -> bool:
            # ASCII-letter lookarounds instead of \b, because CJK characters
            # count as word characters and would hide the boundary.
            return re.search(rf"(?<![a-z]){word}(?![a-z])", lowered) is not None

        if has_word("critical") or "严重风险" in analysis:
            level = "critical"
        elif has_word("high") or "高风险" in analysis:
            level = "high"
        elif has_word("low") or "低风险" in analysis:
            level = "low"
        else:
            level = "medium"

    score = default_scores[level]
    stated = re.search(r"风险分数\W*(\d+(?:\.\d+)?)", analysis)
    if stated:
        # Clamp to the 0-100 range the prompt asks for.
        score = min(100.0, max(0.0, float(stated.group(1))))
    return level, score


@router.post("/check", response_model=ComplianceCheckResponse)
async def check_compliance(req: ComplianceCheckRequest):
    """
    对输入内容进行合规性检查,与法规库比对后给出风险评估。
    """
    # NOTE: FastAPI publishes the docstring above as the endpoint description,
    # so it is kept verbatim.
    #
    # Retrieve candidate regulation chunks. hybrid_search is not given any
    # domain argument, so running it once per requested domain only repeated
    # the identical query; one call yields the same de-duplicated result.
    # An empty domain list still means "search nothing".
    all_chunks = []
    if req.regulation_domains:
        all_chunks = await hybrid_search(
            req.query,
            collection_name="regulation_chunks",
            top_k=req.top_k,
        )

    # De-duplicate by chunk id, best score first.
    seen = set()
    unique_chunks = []
    for c in sorted(all_chunks, key=lambda x: x["score"], reverse=True):
        if c["id"] not in seen:
            seen.add(c["id"])
            unique_chunks.append(c)
    top_chunks = unique_chunks[:req.top_k]

    if not top_chunks:
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0,
            findings=[{"issue": "未找到相关法规,请先上传法规文档"}],
            recommendations=["上传相关法规文档到知识库后重试"],
            sources=[],
        )

    sources = [{"content": c["content"][:200], "score": c["score"]} for c in top_chunks]

    # Build the regulation context (each chunk truncated to 500 characters).
    regulations_text = "\n\n".join(
        f"[{i+1}] {c['content'][:500]}" for i, c in enumerate(top_chunks)
    )
    prompt = COMPLIANCE_CHECK_PROMPT.format(
        content=req.query,
        regulations=regulations_text,
    )
    llm = get_llm(temperature=0.0)
    try:
        response = await llm.ainvoke([HumanMessage(content=prompt)])
        analysis = response.content
    except Exception as e:
        logger.error(f"LLM 合规分析失败:{e}")
        # A failed analysis must not be scored: the error text used to be fed
        # through the keyword parser and reported as a genuine risk level.
        return ComplianceCheckResponse(
            risk_level="unknown",
            risk_score=0,
            findings=[{"analysis": f"LLM 分析失败:{e}"}],
            recommendations=["合规分析暂时不可用,请稍后重试"],
            sources=sources,
        )

    risk_level, risk_score = _parse_risk_assessment(analysis)
    return ComplianceCheckResponse(
        risk_level=risk_level,
        risk_score=risk_score,
        findings=[{"analysis": analysis}],
        recommendations=["请参考上述分析进行整改"],
        sources=sources,
    )

View File

@@ -0,0 +1,114 @@
import uuid
import logging
from pathlib import Path
from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from ..core.deps import get_db
from ..models.db import Workspace, File as FileRecord, Task
from ..services.rag import hybrid_search, rerank, generate_answer
from ..worker import process_file_task
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that groups the knowledge-base endpoints under /api/kb.
router = APIRouter(prefix="/api/kb", tags=["知识库"])
# Directory where uploaded files are stored; created at import time if missing.
UPLOAD_DIR = Path("/app/data/uploads")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
# Request body for POST /api/kb/workspaces.
class WorkspaceCreate(BaseModel):
    # Display name of the workspace.
    name: str
    # Optional free-text description.
    description: str = ""
    # Domain label stored on the workspace row.
    domain: str = "general"
# Request body shared by POST /api/kb/qa and POST /api/kb/knowledge/retrieval.
class QARequest(BaseModel):
    # Natural-language question / search query.
    query: str
    # Optional workspace id forwarded to hybrid_search as a filter.
    workspace_id: str | None = None
    # Number of chunks to return (the QA endpoint retrieves twice as many before re-ranking).
    top_k: int = 5
    # When False, the QA endpoint removes "sources" from its response.
    return_sources: bool = True
@router.post("/workspaces")
async def create_workspace(req: WorkspaceCreate, db: AsyncSession = Depends(get_db)):
    # Stage a new Workspace row and flush it before reading its id back
    # for the response.
    workspace = Workspace(
        name=req.name,
        description=req.description,
        domain=req.domain,
    )
    db.add(workspace)
    await db.flush()
    created = {
        "id": str(workspace.id),
        "name": workspace.name,
        "domain": workspace.domain,
    }
    return created
@router.post("/files/upload")
async def upload_file(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    workspace_id: str = Form(default=""),
    db: AsyncSession = Depends(get_db),
):
    # Store an uploaded file, record it with a pending Task, and queue the
    # Celery job that parses and vectorizes it.
    #
    # Returns {"file_id", "task_id", "status"}; responds 400 when workspace_id
    # is non-empty but not a valid UUID.

    # Validate first, so a malformed id cannot leave an orphaned file on disk
    # (it used to surface as an unhandled ValueError after the write).
    try:
        workspace_uuid = uuid.UUID(workspace_id) if workspace_id else None
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="workspace_id 不是合法的 UUID") from exc

    content = await file.read()
    file_id = str(uuid.uuid4())
    # Only the extension of the client-supplied name is reused; the stored
    # file name is the generated UUID.
    suffix = Path(file.filename or "doc").suffix
    save_path = UPLOAD_DIR / f"{file_id}{suffix}"
    save_path.write_bytes(content)

    file_record = FileRecord(
        id=uuid.UUID(file_id),
        filename=f"{file_id}{suffix}",
        original_name=file.filename or "unknown",
        file_type=suffix.lstrip("."),
        file_size=len(content),
        storage_path=str(save_path),
        workspace_id=workspace_uuid,
        status="uploaded",
    )
    db.add(file_record)
    task = Task(
        task_type="parse_and_vectorize",
        status="pending",
        file_id=uuid.UUID(file_id),
        payload={"workspace_id": workspace_id},
    )
    db.add(task)
    await db.flush()

    # Commit BEFORE dispatching. get_db only commits after the endpoint
    # returns, so the worker could otherwise start first, find no File row and
    # return without processing anything.
    await db.commit()

    # Queue the Celery task and remember its id on the Task row (persisted by
    # the dependency's commit at the end of the request).
    celery_task = process_file_task.delay(file_id, str(task.id), workspace_id)
    task.celery_task_id = celery_task.id
    await db.flush()
    return {"file_id": file_id, "task_id": str(task.id), "status": "processing"}
@router.get("/tasks/{task_id}")
async def get_task(task_id: str, db: AsyncSession = Depends(get_db)):
    # Return status, progress and error details for one task.
    # Responds 404 when the id is malformed or no such task exists.
    try:
        task_uuid = uuid.UUID(task_id)
    except ValueError as exc:
        # A malformed id cannot name an existing task; report it as not found
        # instead of letting the ValueError become a 500.
        raise HTTPException(status_code=404, detail="任务不存在") from exc

    result = await db.execute(select(Task).where(Task.id == task_uuid))
    task = result.scalar_one_or_none()
    if not task:
        raise HTTPException(status_code=404, detail="任务不存在")
    return {
        "task_id": str(task.id),
        "status": task.status,
        "progress": task.progress,
        "file_id": str(task.file_id) if task.file_id else None,
        "error_msg": task.error_msg,
        "completed_at": task.completed_at.isoformat() if task.completed_at else None,
    }
@router.post("/qa")
async def qa(req: QARequest):
    # Retrieve twice the requested number of chunks, keep the best top_k,
    # then generate an answer from them.
    candidates = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k * 2,
    )
    best_chunks = await rerank(req.query, candidates, top_k=req.top_k)
    answer = await generate_answer(req.query, best_chunks)
    if req.return_sources:
        return answer
    # Caller asked for the answer only.
    answer.pop("sources", None)
    return answer
@router.post("/knowledge/retrieval")
async def retrieval(req: QARequest):
    # Retrieval only: return the matching chunks and how many were found.
    found = await hybrid_search(
        req.query,
        workspace_id=req.workspace_id,
        top_k=req.top_k,
    )
    total = len(found)
    return {"chunks": found, "total": total}

View File

@@ -0,0 +1,111 @@
import uuid
import logging
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, desc
from ..core.deps import get_db
from ..models.db import RegulationSource, RegulationUpdate
from ..worker import fetch_regulation_source
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that groups the regulation-monitoring endpoints under /api/regulation.
router = APIRouter(prefix="/api/regulation", tags=["法规监控"])
# Request body for POST /api/regulation/sources.
class SourceCreate(BaseModel):
    # Display name of the monitored source.
    name: str
    # URL to fetch.
    url: str
    # Regulation domain label stored on the source row.
    domain: str = "vehicle_safety"
    # Fetch interval; presumably seconds (86400 = one day) — TODO confirm against the scheduler.
    fetch_interval: int = 86400
    # Free-form fetch options stored alongside the source.
    fetch_config: dict = {}
# Request body for POST /api/regulation/subscribe.
class SubscribeRequest(BaseModel):
    # Display name of the subscription.
    name: str
    channel: str  # email / webhook / feishu / dingtalk
    # Delivery target; presumably an address or URL matching the channel — verify against the push worker.
    target: str
    # Regulation domains the subscriber is interested in.
    domains: list[str] = []
    # Minimum importance label; not read by the subscribe endpoint shown here.
    importance_min: str = "normal"
@router.post("/sources")
async def create_source(req: SourceCreate, db: AsyncSession = Depends(get_db)):
    # Register a new monitored regulation source and echo its key fields.
    new_source = RegulationSource(
        name=req.name,
        url=req.url,
        domain=req.domain,
        fetch_interval=req.fetch_interval,
        fetch_config=req.fetch_config,
    )
    db.add(new_source)
    # Flush before reading the id back for the response.
    await db.flush()
    summary = {
        "id": str(new_source.id),
        "name": new_source.name,
        "url": new_source.url,
        "domain": new_source.domain,
    }
    summary["status"] = "active"
    return summary
@router.get("/sources")
async def list_sources(db: AsyncSession = Depends(get_db)):
    # List every source whose is_active flag is set.
    active_filter = RegulationSource.is_active == True
    rows = await db.execute(select(RegulationSource).where(active_filter))
    listing = []
    for src in rows.scalars().all():
        listing.append(
            {"id": str(src.id), "name": src.name, "url": src.url, "domain": src.domain}
        )
    return listing
@router.post("/sources/{source_id}/fetch")
async def manual_fetch(source_id: str, db: AsyncSession = Depends(get_db)):
    """手动触发某个监控源的抓取(测试用)"""
    # NOTE: FastAPI publishes the docstring above as the endpoint description,
    # so it is kept verbatim.
    try:
        source_uuid = uuid.UUID(source_id)
    except ValueError as exc:
        # A malformed id cannot name an existing source; report it as not
        # found instead of letting the ValueError become a 500.
        raise HTTPException(status_code=404, detail="监控源不存在") from exc

    result = await db.execute(
        select(RegulationSource).where(RegulationSource.id == source_uuid)
    )
    source = result.scalar_one_or_none()
    if not source:
        raise HTTPException(status_code=404, detail="监控源不存在")
    # Queue the fetch on the Celery worker and report the queued task id.
    task = fetch_regulation_source.delay(source_id)
    return {"task_id": task.id, "status": "queued", "source_id": source_id}
@router.get("/updates")
async def get_updates(
    domain: str | None = None,
    limit: int = 20,
    offset: int = 0,
    db: AsyncSession = Depends(get_db),
):
    # List regulation updates, newest fetch first, with limit/offset paging.
    # When `domain` is given, only updates whose source belongs to that
    # domain are returned.
    query = select(RegulationUpdate).order_by(desc(RegulationUpdate.fetched_at))
    if domain:
        # The domain lives on the source row, not on the update, so the filter
        # needs a join. It was previously accepted but silently ignored.
        query = query.join(
            RegulationSource, RegulationUpdate.source_id == RegulationSource.id
        ).where(RegulationSource.domain == domain)
    result = await db.execute(query.limit(limit).offset(offset))
    updates = result.scalars().all()
    return {
        "updates": [
            {
                "id": str(u.id),
                "title": u.title,
                "url": u.url,
                "change_type": u.change_type,
                "summary": u.summary,
                "importance": u.importance,
                "fetched_at": u.fetched_at.isoformat() if u.fetched_at else None,
            }
            for u in updates
        ]
    }
@router.post("/subscribe")
async def subscribe(req: SubscribeRequest, db: AsyncSession = Depends(get_db)):
    # Simplified version: nothing is persisted here. The endpoint echoes the
    # subscription back with a freshly generated id; per the original note,
    # push delivery is implemented in the push worker.
    # (The unused function-scope import of Workspace was removed.)
    return {
        "id": str(uuid.uuid4()),
        "name": req.name,
        "channel": req.channel,
        "domains": req.domains,
        "status": "active",
    }

View File

@@ -0,0 +1,37 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
# Application settings loaded by pydantic-settings from the environment and
# the ".env" file; unknown keys are ignored (extra="ignore").
class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")
    # Application
    app_env: str = "development"
    log_level: str = "INFO"
    # NOTE(review): placeholder secret — must be overridden outside local development.
    api_secret_key: str = "change_this_key"
    # Databases
    # NOTE(review): the default URLs below embed credentials — override them via the environment.
    database_url: str = "postgresql+asyncpg://compliance:compliance123@postgres:5432/compliance_db"
    redis_url: str = "redis://:redis123@redis:6379/0"
    # Milvus
    milvus_host: str = "milvus"
    milvus_port: int = 19530
    # Neo4j
    neo4j_uri: str = "bolt://neo4j:7687"
    neo4j_user: str = "neo4j"
    neo4j_password: str = "neo4j123"
    # AI services
    embedding_service_url: str = "http://embedding-service:8010"
    mcp_server_url: str = "http://mcp-server:8011"
    # LLM
    llm_provider: str = "deepseek" # deepseek / qwen
    deepseek_api_key: str = ""
    deepseek_model: str = "deepseek-chat"
    dashscope_api_key: str = ""
    qwen_model: str = "qwen-plus"
# Settings instance created at import time.
settings = Settings()

View File

@@ -0,0 +1,54 @@
from functools import lru_cache
from typing import AsyncGenerator
import httpx
from neo4j import AsyncGraphDatabase
from pymilvus import connections, Collection
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from .config import settings
# ── PostgreSQL ──────────────────────────────────
# Async SQLAlchemy engine: 10 pooled connections plus up to 20 overflow.
engine = create_async_engine(settings.database_url, pool_size=10, max_overflow=20)
# Session factory; expire_on_commit=False keeps loaded attributes readable after commit.
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session for one unit of work.

    The session is committed once the consumer finishes without error; if an
    exception escapes (from the consumer or from the commit), the session is
    rolled back and the exception is re-raised.
    """
    async with AsyncSessionLocal() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
# ── Milvus ──────────────────────────────────────
def get_milvus_collection(name: str) -> Collection:
    """Connect to the configured Milvus server and return the collection *name*."""
    connections.connect(
        host=settings.milvus_host,
        port=settings.milvus_port,
    )
    collection = Collection(name)
    return collection
# ── Neo4j ───────────────────────────────────────
_neo4j_driver = None


def get_neo4j():
    """Return the process-wide Neo4j async driver, creating it on first use."""
    global _neo4j_driver
    if _neo4j_driver is not None:
        return _neo4j_driver
    credentials = (settings.neo4j_user, settings.neo4j_password)
    _neo4j_driver = AsyncGraphDatabase.driver(
        settings.neo4j_uri,
        auth=credentials,
    )
    return _neo4j_driver
# ── HTTP 客户端(复用连接池)────────────────────
_http_client = None


def get_http_client() -> httpx.AsyncClient:
    """Return the shared httpx.AsyncClient (120 s timeout), creating it on first use."""
    global _http_client
    if _http_client is not None:
        return _http_client
    _http_client = httpx.AsyncClient(timeout=120.0)
    return _http_client

View File

@@ -0,0 +1,56 @@
from langchain_openai import ChatOpenAI
from tenacity import retry, stop_after_attempt, wait_exponential
from .config import settings
def get_llm(temperature: float = 0.1) -> ChatOpenAI:
    """Build a ChatOpenAI client for the configured provider.

    Both supported providers ("deepseek" and "qwen") are reached through
    ChatOpenAI with a provider-specific base URL, model and API key.

    Raises:
        ValueError: if settings.llm_provider is neither "deepseek" nor "qwen".
    """
    provider = settings.llm_provider
    if provider == "deepseek":
        model_name = settings.deepseek_model
        key = settings.deepseek_api_key
        endpoint = "https://api.deepseek.com/v1"
    elif provider == "qwen":
        model_name = settings.qwen_model
        key = settings.dashscope_api_key
        endpoint = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    else:
        raise ValueError(f"不支持的 LLM 提供商:{settings.llm_provider}")
    return ChatOpenAI(
        model=model_name,
        api_key=key,
        base_url=endpoint,
        temperature=temperature,
        max_retries=3,
        timeout=120,
    )
# System prompt for RAG answers: answer only from the supplied references,
# cite a source for every key statement, say so when the references are
# insufficient, and quote numeric requirements exactly.
# NOTE(review): the parenthesis opened in rule 2 is never closed — confirm against the original file.
RAG_SYSTEM_PROMPT = """你是一位专业的汽车行业合规专家具备深厚的法规知识GB标准、UN-ECE、ISO 45001、IATF 16949等
回答规则:
1. 仅基于提供的参考文献回答,不添加不在文献中的信息
2. 每个关键陈述必须标注来源(格式:[来源文件名第X页]
3. 如果参考文献不足以回答问题,明确说明
4. 使用专业但清晰的语言,适合工程师和法务人员阅读
5. 对于数值要求(如绝缘电阻值、时间限制等),精确引用原文"""
# Prompt template for compliance review. Placeholders: {content} (text under
# review) and {regulations} (retrieved regulation excerpts). It asks for an
# overall risk level (low/medium/high/critical), a 0-100 risk score, the
# individual issues (description, violated clause, severity) and remediation advice.
COMPLIANCE_CHECK_PROMPT = """你是一位专业的汽车合规审查专家。
请对以下内容进行合规性评估:
【待审查内容】
{content}
【相关法规要求】
{regulations}
请按以下格式输出:
1. 整体风险等级:[low/medium/high/critical]
2. 风险分数:[0-100]
3. 发现的合规问题(逐条列出):
- 问题描述
- 违反的具体法规条款
- 严重程度
4. 整改建议(具体可操作)"""

View File

@@ -0,0 +1,84 @@
import logging
import time
import structlog
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from prometheus_fastapi_instrumentator import Instrumentator
from .api import kb, compliance, regulation
from .core.config import settings
# Structured logging: drop events below the configured log level
# (falls back to INFO when the configured name is not a logging level).
structlog.configure(
    wrapper_class=structlog.make_filtering_bound_logger(
        getattr(logging, settings.log_level.upper(), logging.INFO)
    )
)
logger = structlog.get_logger()
# FastAPI application with interactive docs at /docs and /redoc.
app = FastAPI(
    title="AI合规智能中枢 API",
    description="面向车企与工厂的全链路合规智能平台",
    version="0.1.0",
    docs_url="/docs",
    redoc_url="/redoc",
)
# CORS: every origin is allowed only when app_env is "development";
# in any other environment the allowed-origin list is empty.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"] if settings.app_env == "development" else [],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Prometheus metrics: instrument the app and expose the metrics endpoint.
Instrumentator().instrument(app).expose(app)
# Register the API routers.
app.include_router(kb.router)
app.include_router(compliance.router)
app.include_router(regulation.router)
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log method, path, status code and duration (ms) of every request."""
    # perf_counter is monotonic; time.time() follows the wall clock and can
    # jump (NTP adjustments), yielding negative or inflated durations.
    start = time.perf_counter()
    response = await call_next(request)
    duration_ms = int((time.perf_counter() - start) * 1000)
    logger.info(
        "request",
        method=request.method,
        path=request.url.path,
        status=response.status_code,
        duration_ms=duration_ms,
    )
    return response
@app.get("/health")
async def health():
    """健康检查(含依赖服务检测)"""
    # NOTE: FastAPI publishes the docstring above as the endpoint description,
    # so it is kept verbatim.
    import httpx
    from .core.config import settings

    # Probe each dependency's /health endpoint with a 5 s timeout:
    # 200 -> "ok", any other status -> "degraded", any error -> "unavailable".
    dependencies = (
        ("embedding", settings.embedding_service_url),
        ("mcp", settings.mcp_server_url),
    )
    service_states = {}
    for label, base_url in dependencies:
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                reply = await client.get(f"{base_url}/health")
                service_states[label] = "ok" if reply.status_code == 200 else "degraded"
        except Exception:
            service_states[label] = "unavailable"
    return {"status": "ok", "services": service_states}

View File

@@ -0,0 +1,113 @@
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Integer, BigInteger, Boolean, Text, ARRAY, Numeric
from sqlalchemy import DateTime, ForeignKey, func
from sqlalchemy.dialects.postgresql import UUID, JSONB, INET
from sqlalchemy.orm import DeclarativeBase, relationship
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models defined in this module."""
    pass
class Workspace(Base):
    """Workspace row that groups uploaded files (table ``workspaces``)."""
    __tablename__ = "workspaces"
    # Primary key generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False)
    description = Column(Text)
    domain = Column(String(100))
    created_by = Column(String(255))
    # Timestamps are filled by the database (now()); updated_at is refreshed on UPDATE.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
    # Files belonging to this workspace (mirror of File.workspace).
    files = relationship("File", back_populates="workspace")
class File(Base):
    """Uploaded document and its processing state (table ``files``)."""
    __tablename__ = "files"
    # Primary key generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Owning workspace; rows are removed when the workspace is deleted.
    workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"))
    filename = Column(String(500), nullable=False)
    original_name = Column(String(500), nullable=False)
    file_type = Column(String(50))
    file_size = Column(BigInteger)
    storage_path = Column(Text)
    parsed_path = Column(Text)
    # Processing status string; starts as "uploaded".
    status = Column(String(50), default="uploaded")
    error_msg = Column(Text)
    # `metadata` is a reserved attribute name in the Declarative API (it holds
    # the MetaData object), so declaring it raises InvalidRequestError when
    # the class is defined. The attribute is exposed as `file_metadata` while
    # the database column keeps the name "metadata". `default=dict` gives each
    # row its own dict instead of one shared mutable default.
    file_metadata = Column("metadata", JSONB, default=dict)
    # Timestamps are filled by the database (now()); updated_at is refreshed on UPDATE.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
    workspace = relationship("Workspace", back_populates="files")
    tasks = relationship("Task", back_populates="file")
class Task(Base):
    """Background job record tracking status and progress (table ``tasks``)."""
    __tablename__ = "tasks"
    # Primary key generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    task_type = Column(String(100), nullable=False)
    # Status string; starts as "pending".
    status = Column(String(50), default="pending")
    # Callable default: each row gets its own dict rather than one shared
    # mutable object.
    payload = Column(JSONB, default=dict)
    result = Column(JSONB)
    error_msg = Column(Text)
    # Progress value; starts at 0.
    progress = Column(Integer, default=0)
    file_id = Column(UUID(as_uuid=True), ForeignKey("files.id"))
    # Id of the Celery task queued for this job, when one was dispatched.
    celery_task_id = Column(String(255))
    # Timestamps are filled by the database (now()); updated_at is refreshed on UPDATE.
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
    completed_at = Column(DateTime(timezone=True))
    file = relationship("File", back_populates="tasks")
class ComplianceReport(Base):
    """Stored result of a compliance review (table ``compliance_reports``)."""
    __tablename__ = "compliance_reports"
    # Primary key generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    file_id = Column(UUID(as_uuid=True), ForeignKey("files.id"))
    regulation_domains = Column(ARRAY(Text))
    overall_risk_level = Column(String(20))
    # Up to five digits with two decimals.
    risk_score = Column(Numeric(5, 2))
    # Callable defaults: each row gets its own list rather than one shared
    # mutable object.
    findings = Column(JSONB, default=list)
    recommendations = Column(JSONB, default=list)
    report_markdown = Column(Text)
    llm_model = Column(String(100))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
class RegulationSource(Base):
    """Monitored regulation source (table ``regulation_sources``)."""
    __tablename__ = "regulation_sources"
    # Primary key generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    name = Column(String(255), nullable=False)
    url = Column(Text, nullable=False)
    source_type = Column(String(50), default="webpage")
    domain = Column(String(100))
    # Fetch interval; presumably seconds (86400 = one day) — TODO confirm.
    fetch_interval = Column(Integer, default=86400)
    is_active = Column(Boolean, default=True)
    last_fetched_at = Column(DateTime(timezone=True))
    # Hash of the last fetched content, used to detect changes.
    last_hash = Column(String(64))
    # Callable default: each row gets its own dict rather than one shared
    # mutable object.
    fetch_config = Column(JSONB, default=dict)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
class RegulationUpdate(Base):
    """Change detected on a monitored source (table ``regulation_updates``)."""
    __tablename__ = "regulation_updates"
    # Primary key generated client-side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Source on which the change was detected.
    source_id = Column(UUID(as_uuid=True), ForeignKey("regulation_sources.id"))
    title = Column(String(500))
    url = Column(Text)
    change_type = Column(String(50))
    summary = Column(Text)
    raw_content = Column(Text)
    diff_content = Column(Text)
    # Notification flag; starts as False.
    is_notified = Column(Boolean, default=False)
    importance = Column(String(20), default="normal")
    # Filled by the database (now()) when the row is inserted.
    fetched_at = Column(DateTime(timezone=True), server_default=func.now())
    published_at = Column(DateTime(timezone=True))

View File

@@ -0,0 +1,21 @@
import httpx
from tenacity import retry, stop_after_attempt, wait_exponential
from ..core.config import settings
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=10))
async def embed_texts(texts: list[str], batch_size: int = 12) -> dict:
    """POST *texts* to the embedding service and return its decoded JSON reply.

    The call is attempted up to three times with exponential back-off.
    A non-success HTTP status raises via ``raise_for_status``.
    """
    endpoint = f"{settings.embedding_service_url}/embed"
    request_body = {"texts": texts, "batch_size": batch_size}
    async with httpx.AsyncClient(timeout=120.0) as http:
        reply = await http.post(endpoint, json=request_body)
        reply.raise_for_status()
        return reply.json()
async def embed_single(text: str) -> list[float]:
    """Embed one text and return the first entry of the reply's "dense" list."""
    embeddings = await embed_texts([text], batch_size=1)
    dense_vectors = embeddings["dense"]
    return dense_vectors[0]

View File

@@ -0,0 +1,65 @@
import logging
from ..core.deps import get_neo4j
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def create_regulation_node(regulation: dict) -> str:
    """Create or update a Regulation node in Neo4j, keyed by ``id``.

    Reads "id", "title", "domain", "version" and "code" from *regulation*
    (missing text fields default to ""). Returns the node id, or None when
    the query yields no record.
    """
    cypher = """
MERGE (r:Regulation {id: $id})
SET r.title = $title,
r.domain = $domain,
r.version = $version,
r.code = $code
RETURN r.id as id
"""
    params = {
        "id": regulation.get("id"),
        "title": regulation.get("title", ""),
        "domain": regulation.get("domain", ""),
        "version": regulation.get("version", ""),
        "code": regulation.get("code", ""),
    }
    driver = get_neo4j()
    async with driver.session() as session:
        outcome = await session.run(cypher, **params)
        row = await outcome.single()
        if row:
            return row["id"]
        return None
async def create_clause_node(clause: dict, regulation_id: str) -> str:
    """Create or update a Clause node and link it to its Regulation.

    The clause content is truncated to 2000 characters. Returns the clause
    id, or None when the query yields no record.
    """
    cypher = """
MATCH (r:Regulation {id: $reg_id})
MERGE (c:Clause {id: $id})
SET c.number = $number,
c.content = $content
MERGE (r)-[:CONTAINS]->(c)
RETURN c.id as id
"""
    params = {
        "reg_id": regulation_id,
        "id": clause.get("id"),
        "number": clause.get("number", ""),
        "content": clause.get("content", "")[:2000],
    }
    driver = get_neo4j()
    async with driver.session() as session:
        outcome = await session.run(cypher, **params)
        row = await outcome.single()
        if row:
            return row["id"]
        return None
async def search_related_regulations(domain: str, limit: int = 10) -> list[dict]:
    """Return up to *limit* regulations of *domain* as dicts (id, title, code, version)."""
    cypher = """
MATCH (r:Regulation {domain: $domain})
RETURN r.id as id, r.title as title, r.code as code, r.version as version
LIMIT $limit
"""
    driver = get_neo4j()
    async with driver.session() as session:
        outcome = await session.run(cypher, domain=domain, limit=limit)
        regulations = []
        async for row in outcome:
            regulations.append(dict(row))
        return regulations

View File

@@ -0,0 +1,59 @@
import hashlib
import logging
import httpx
from bs4 import BeautifulSoup
from datetime import datetime, timezone
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
async def fetch_url(url: str, timeout: int = 30) -> str | None:
    """Download *url* (following redirects) and return the body text.

    Any failure — network error, timeout or non-success status — is logged
    as a warning and reported as None.
    """
    request_headers = {"User-Agent": "Mozilla/5.0 (compliance-monitor/1.0)"}
    try:
        async with httpx.AsyncClient(
            timeout=timeout,
            headers=request_headers,
            follow_redirects=True,
        ) as http:
            reply = await http.get(url)
            reply.raise_for_status()
            return reply.text
    except Exception as e:
        logger.warning(f"抓取 {url} 失败:{e}")
        return None
def extract_text(html: str) -> str:
    """Return the text of *html* with script, style, nav, footer and header elements removed."""
    removed_tags = ["script", "style", "nav", "footer", "header"]
    document = BeautifulSoup(html, "html.parser")
    for element in document(removed_tags):
        element.decompose()
    plain = document.get_text(separator="\n", strip=True)
    return plain
def compute_hash(content: str) -> str:
    """Return the hexadecimal MD5 digest of *content* encoded as UTF-8."""
    encoded = content.encode("utf-8")
    digest = hashlib.md5(encoded)
    return digest.hexdigest()
async def check_source_for_updates(source: dict) -> dict | None:
    """Fetch a monitored source and report whether its text changed.

    Returns None when the page could not be fetched or its text hash equals
    ``source["last_hash"]``. Otherwise returns a dict with the source id, the
    extracted text (first 50000 characters), the new hash and a UTC timestamp.
    """
    page = await fetch_url(source["url"])
    if not page:
        return None

    page_text = extract_text(page)
    digest = compute_hash(page_text)
    unchanged = source.get("last_hash") == digest
    if unchanged:
        logger.info(f"监控源 {source['name']} 无变化")
        return None

    fetched_at = datetime.now(timezone.utc).isoformat()
    return {
        "source_id": source["id"],
        "raw_content": page_text[:50000],  # keep at most 50000 characters
        "new_hash": digest,
        "fetched_at": fetched_at,
    }

View File

@@ -0,0 +1,43 @@
import httpx
from tenacity import retry, stop_after_attempt, wait_exponential
from ..core.config import settings
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=30))
async def parse_document(file_content: bytes, filename: str) -> dict:
    """Upload a document to the mcp-server parser and return its decoded JSON reply.

    The call is attempted up to three times with exponential back-off.
    A non-success HTTP status raises via ``raise_for_status``.
    """
    endpoint = f"{settings.mcp_server_url}/parse-document"
    upload = {"file": (filename, file_content, "application/octet-stream")}
    async with httpx.AsyncClient(timeout=300.0) as http:
        reply = await http.post(endpoint, files=upload)
        reply.raise_for_status()
        return reply.json()
def chunk_text(text: str, chunk_size: int = 512, overlap: int = 64) -> list[dict]:
    """Split *text* into overlapping chunks sized by an estimated token count.

    Sizes are given in tokens and converted to characters with a rough
    estimate of two characters per token (the original note targets Chinese
    text). A chunk that does not reach the end of the text is cut at the last
    paragraph break, line break, full stop or space found in its second half,
    when there is one.

    Args:
        text: Text to split.
        chunk_size: Target chunk size in tokens.
        overlap: Overlap between consecutive chunks in tokens.

    Returns:
        A list of dicts with "idx" (running index), "content" (stripped chunk
        text), "start" and "end" (character offsets into *text*).
    """
    chars_per_chunk = chunk_size * 2  # roughly 2 characters per token
    chars_overlap = overlap * 2
    # Preferred cut points, strongest boundary first. The list previously held
    # an empty string where the Chinese full stop belongs; "" always "matches"
    # at `end`, so no boundary cut ever happened and the entries after it were
    # unreachable.
    separators = ("\n\n", "\n", "。", ".", " ")
    text_len = len(text)
    chunks = []
    start = 0
    idx = 0
    while start < text_len:
        end = min(start + chars_per_chunk, text_len)
        # Try to end the chunk on a natural boundary in its second half.
        if end < text_len:
            earliest_cut = start + chars_per_chunk // 2
            for sep in separators:
                pos = text.rfind(sep, start, end)
                if pos > earliest_cut:
                    end = pos + len(sep)
                    break
        piece = text[start:end].strip()
        if piece:
            chunks.append({"idx": idx, "content": piece, "start": start, "end": end})
            idx += 1
        if end >= text_len:
            # The tail is covered. Without this stop, `start` crept forward one
            # character at a time and emitted a redundant suffix chunk each step.
            break
        start = max(start + 1, end - chars_overlap)
    return chunks

View File

@@ -0,0 +1,92 @@
import logging
from langchain.schema import HumanMessage, SystemMessage
from pymilvus import connections, Collection
from .embed import embed_single, embed_texts
from ..core.llm import get_llm, RAG_SYSTEM_PROMPT
from ..core.config import settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _get_collection(name: str) -> Collection:
    """Connect to the configured Milvus server and return the collection *name*."""
    connections.connect(
        host=settings.milvus_host,
        port=settings.milvus_port,
    )
    collection = Collection(name)
    return collection
async def hybrid_search(
    query: str,
    collection_name: str = "regulation_chunks",
    top_k: int = 10,
    workspace_id: str | None = None,
) -> list[dict]:
    """Search a Milvus collection for the chunks closest to *query*.

    Despite the name, only a dense-vector (COSINE) search is performed; the
    original note describes it as a simplified version without BM25 fusion.

    Args:
        query: Text to embed and search for.
        collection_name: Milvus collection to search.
        top_k: Maximum number of hits.
        workspace_id: When given, restricts hits to that workspace.

    Returns:
        A list of dicts with "id", "content", "score", "file_id",
        "chunk_idx" and "metadata".
    """
    query_vec = await embed_single(query)
    col = _get_collection(collection_name)

    expr = None
    if workspace_id:
        # workspace_id arrives from request bodies. Escape backslashes and
        # double quotes so the value cannot terminate the string literal and
        # inject extra conditions into the filter expression.
        safe_id = workspace_id.replace("\\", "\\\\").replace('"', '\\"')
        expr = f'workspace_id == "{safe_id}"'

    results = col.search(
        data=[query_vec],
        anns_field="dense_vec",
        param={"metric_type": "COSINE", "params": {"ef": 100}},
        limit=top_k,
        expr=expr,
        output_fields=["content", "metadata", "file_id", "chunk_idx"],
    )
    # One query vector was sent, but iterate generically over all result sets.
    chunks = []
    for hits in results:
        for hit in hits:
            chunks.append({
                "id": hit.id,
                "content": hit.entity.get("content", ""),
                "score": float(hit.score),
                "file_id": hit.entity.get("file_id", ""),
                "chunk_idx": hit.entity.get("chunk_idx", 0),
                "metadata": hit.entity.get("metadata", {}),
            })
    return chunks
async def rerank(query: str, chunks: list[dict], top_k: int = 5) -> list[dict]:
    """Return the *top_k* chunks ordered by descending "score".

    Simplified re-ranking: *query* is not used and the input list is left
    untouched (the original note suggests a cross-encoder for production).
    """
    ordered = list(chunks)
    ordered.sort(key=lambda item: item["score"], reverse=True)
    return ordered[:top_k]
async def generate_answer(query: str, chunks: list[dict]) -> dict:
    """Ask the LLM to answer *query* from the retrieved *chunks*.

    Returns a dict with "answer" and "sources". With no chunks, a fixed
    "nothing found" answer is returned without calling the LLM. If the LLM
    call fails, the answer holds the error text plus an excerpt of the first
    chunk.
    """
    if not chunks:
        return {"answer": "未找到相关法规内容,请上传相关法规文档后重试。", "sources": []}

    # Build the RAG context: one labelled block per chunk.
    labelled_blocks = []
    for position, item in enumerate(chunks, start=1):
        info = item.get("metadata", {})
        header = f"[来源 {position}{info.get('filename', '未知文件')},第 {info.get('page', '?')} 页]"
        labelled_blocks.append(f"{header}\n{item['content']}")
    context = "\n\n---\n\n".join(labelled_blocks)
    user_prompt = f"参考文献:\n\n{context}\n\n问题:{query}\n\n请基于以上参考文献回答,并标注来源。"

    llm = get_llm(temperature=0.1)
    conversation = [
        SystemMessage(content=RAG_SYSTEM_PROMPT),
        HumanMessage(content=user_prompt),
    ]
    try:
        reply = await llm.ainvoke(conversation)
        answer = reply.content
    except Exception as e:
        logger.error(f"LLM 生成失败:{e}")
        answer = f"LLM 生成失败:{e}。检索到的相关内容:{chunks[0]['content'][:200]}..."

    sources = []
    for item in chunks:
        sources.append(
            {
                "content": item["content"][:300],
                "file_id": item.get("file_id", ""),
                "chunk_idx": item.get("chunk_idx", 0),
                "score": item.get("score", 0),
                "metadata": item.get("metadata", {}),
            }
        )
    return {"answer": answer, "sources": sources}

View File

@@ -0,0 +1,212 @@
import uuid
import logging
from datetime import datetime, timezone
from celery import Celery
from celery.schedules import crontab
from .core.config import settings
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Celery application: the configured Redis URL serves as both broker and result backend.
celery_app = Celery(
    "compliance",
    broker=settings.redis_url,
    backend=settings.redis_url,
)
celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="Asia/Shanghai",
    # Each task family is routed to its own queue.
    task_routes={
        "app.worker.process_file_task": {"queue": "parse"},
        "app.worker.fetch_regulation_source": {"queue": "monitor"},
        "app.worker.send_notifications": {"queue": "push"},
    },
    # Beat schedule: trigger all monitors daily at 02:00 (in the timezone set above).
    beat_schedule={
        "daily-regulation-monitor": {
            "task": "app.worker.run_all_monitors",
            "schedule": crontab(hour=2, minute=0),
        },
    },
)
# ── 文件处理任务(解析 + 向量化)────────────────
@celery_app.task(name="app.worker.process_file_task", bind=True, max_retries=3)
def process_file_task(self, file_id: str, task_id: str, workspace_id: str):
    """Celery entry point: parse a document and store its vectors in Milvus.

    Runs the async implementation to completion on a fresh event loop.
    """
    import asyncio

    job = _process_file(file_id, task_id, workspace_id)
    asyncio.run(job)
async def _process_file(file_id: str, task_id: str, workspace_id: str):
    """Parse a stored file, chunk it, embed the chunks and insert them into Milvus.

    File and Task status/progress are committed as the work advances:
    parsing (10) -> parsed (40) -> per-batch progress up to 100 -> completed.
    On any error both records are marked "failed" with the error text and the
    exception is re-raised. Returns early (after logging) when the file
    record does not exist.
    """
    from pathlib import Path
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal, get_milvus_collection
    from .models.db import File, Task
    from .services.parse import parse_document, chunk_text
    from .services.embed import embed_texts

    async with AsyncSessionLocal() as db:
        # Look up the file record.
        result = await db.execute(select(File).where(File.id == uuid.UUID(file_id)))
        file_record = result.scalar_one_or_none()
        if not file_record:
            logger.error(f"文件 {file_id} 不存在")
            return
        task_result = await db.execute(select(Task).where(Task.id == uuid.UUID(task_id)))
        task = task_result.scalar_one_or_none()
        try:
            # Mark the work as started.
            file_record.status = "parsing"
            if task:
                task.status = "running"
                task.progress = 10
            await db.commit()

            # Step 1: parse the document into Markdown.
            file_content = Path(file_record.storage_path).read_bytes()
            parse_result = await parse_document(file_content, file_record.original_name)
            markdown = parse_result.get("markdown", "")
            if not markdown.strip():
                raise ValueError("文档解析结果为空")
            file_record.status = "parsed"
            if task:
                task.progress = 40
            await db.commit()

            # Step 2: split into chunks.
            chunks = chunk_text(markdown, chunk_size=512, overlap=64)
            logger.info(f"文件 {file_id} 分割为 {len(chunks)}")

            # Step 3: embed and insert, batch by batch.
            batch_size = 16
            col = get_milvus_collection("regulation_chunks")
            for i in range(0, len(chunks), batch_size):
                batch = chunks[i:i + batch_size]
                texts = [c["content"] for c in batch]
                embed_result = await embed_texts(texts, batch_size=batch_size)
                dense_vecs = embed_result["dense"]
                # Column-ordered data for the insert: id, file id, workspace id,
                # chunk index, content, dense vector, metadata.
                entities = [
                    [f"{file_id}_{c['idx']}" for c in batch],
                    [file_id] * len(batch),
                    [workspace_id] * len(batch),
                    [c["idx"] for c in batch],
                    [c["content"] for c in batch],
                    dense_vecs,
                    [{"filename": file_record.original_name, "page": c.get("page", 0)} for c in batch],
                ]
                col.insert(entities)
                if task:
                    # Count chunks actually processed: `i + batch_size`
                    # overshoots on the final partial batch and pushed the
                    # reported progress past 100.
                    done = min(i + batch_size, len(chunks))
                    task.progress = 40 + int(60 * done / len(chunks))
                    await db.commit()
            col.flush()

            # Done.
            file_record.status = "vectorized"
            if task:
                task.status = "completed"
                task.progress = 100
                task.completed_at = datetime.now(timezone.utc)
            await db.commit()
            logger.info(f"文件 {file_id} 处理完成")
        except Exception as e:
            logger.error(f"文件 {file_id} 处理失败:{e}")
            file_record.status = "failed"
            file_record.error_msg = str(e)
            if task:
                task.status = "failed"
                task.error_msg = str(e)
            await db.commit()
            raise
# ── 法规监控任务 ────────────────────────────────
@celery_app.task(name="app.worker.run_all_monitors")
def run_all_monitors():
    """Scheduled Celery entry point: queue a fetch for every active source."""
    import asyncio

    job = _run_all_monitors()
    asyncio.run(job)
async def _run_all_monitors():
    """Load all active regulation sources and queue one fetch task per source."""
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource

    active_sources = select(RegulationSource).where(RegulationSource.is_active == True)
    async with AsyncSessionLocal() as db:
        rows = await db.execute(active_sources)
        for monitored in rows.scalars().all():
            fetch_regulation_source.delay(str(monitored.id))
            logger.info(f"触发监控源抓取:{monitored.name}")
@celery_app.task(name="app.worker.fetch_regulation_source", bind=True, max_retries=2)
def fetch_regulation_source(self, source_id: str):
    """Celery entry point: check one monitored source for changes."""
    import asyncio

    job = _fetch_source(source_id)
    asyncio.run(job)
async def _fetch_source(source_id: str):
    """Check one source for changes and record an update when its content changed.

    ``last_fetched_at`` is refreshed on every run. When a change is detected,
    the stored hash is replaced and a RegulationUpdate row holding the new
    content is added. Returns silently when the source does not exist.
    """
    # The unused `import hashlib` was dropped: hashing happens inside
    # check_source_for_updates.
    from sqlalchemy import select
    from .core.deps import AsyncSessionLocal
    from .models.db import RegulationSource, RegulationUpdate
    from .services.monitor import check_source_for_updates

    async with AsyncSessionLocal() as db:
        result = await db.execute(
            select(RegulationSource).where(RegulationSource.id == uuid.UUID(source_id))
        )
        source = result.scalar_one_or_none()
        if not source:
            return
        source_dict = {
            "id": str(source.id),
            "name": source.name,
            "url": source.url,
            "last_hash": source.last_hash,
        }
        update_data = await check_source_for_updates(source_dict)
        if update_data:
            logger.info(f"检测到变更:{source.name}")
            source.last_hash = update_data["new_hash"]
            update = RegulationUpdate(
                source_id=uuid.UUID(source_id),
                change_type="updated",
                raw_content=update_data["raw_content"][:50000],
                importance="normal",
            )
            db.add(update)
        # Both outcomes refresh the timestamp and commit, so this is done once
        # here instead of being duplicated in each branch.
        source.last_fetched_at = datetime.now(timezone.utc)
        await db.commit()
# Placeholder push task: it only logs that it ran; delivery is not implemented yet.
@celery_app.task(name="app.worker.send_notifications")
def send_notifications():
    logger.info("推送通知任务执行(待实现)")
# Exported alias of the Celery app for use by the FastAPI side.
worker = celery_app