114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
from typing import TypedDict, List
|
|
from langgraph.graph import StateGraph, END
|
|
|
|
|
|
class _RagInput(TypedDict):
    """Keys that must be supplied when the RAG workflow is started."""

    query: str  # the user's question


class RagState(_RagInput, total=False):
    """State dict passed between the nodes of the RAG workflow.

    Only ``query`` is required up front. Every other key is written by one of
    the nodes as the graph runs (embed -> retrieve -> build_context ->
    generate), so those keys are declared non-required (``total=False``).
    This lets callers build the initial state as ``{"query": ...}`` without
    upsetting a type checker.
    """

    query_embedding: List[float]  # written by embed_query
    retrieved_docs: List[dict]    # written by retrieve_docs
    context: str                  # written by build_context
    answer: str                   # written by generate_answer
    sources: List[dict]           # written by build_context
|
|
|
|
|
|
def embed_query(state: RagState) -> dict:
    """Turn the user's query into a vector.

    Args:
        state: Workflow state; only ``state["query"]`` is read.

    Returns:
        Partial state update ``{"query_embedding": ...}`` holding whatever
        ``embedding_service.embed_single`` returned for the query text.
    """
    from app.services import embedding_service

    query_text = state["query"]
    return {"query_embedding": embedding_service.embed_single(query_text)}
|
|
|
|
|
|
def retrieve_docs(state: RagState) -> dict:
    """Run a vector search for the query embedding.

    Asks ``milvus_service.search`` for ``settings.vector_top_k`` hits, then
    keeps only the first ``settings.final_top_k`` of them.

    Args:
        state: Workflow state; only ``state["query_embedding"]`` is read.

    Returns:
        Partial state update ``{"retrieved_docs": [...]}``.
    """
    from app.services import milvus_service
    from app.core.config import settings

    # NOTE(review): presumably the search result is ordered best-first, so the
    # slice keeps the top hits — confirm against milvus_service.search.
    candidates = milvus_service.search(state["query_embedding"], top_k=settings.vector_top_k)
    keep = settings.final_top_k
    return {"retrieved_docs": candidates[:keep]}
|
|
|
|
|
|
def build_context(state: "RagState") -> dict:
    """Format the retrieved docs into prompt context plus a citation list.

    Each doc in ``state["retrieved_docs"]`` must provide ``doc_name`` and
    ``content``; ``clause_id`` is optional.

    Returns:
        Partial state update with:

        - ``context``: one section per doc, in retrieval order, separated by a
          blank line. Each section is a header line ``【doc_name - clause_id】``
          (or just ``【doc_name】`` when the doc has no clause id) followed by
          the doc's content.
        - ``sources``: ``{"name": doc_name, "clause": clause_id}`` per doc,
          where ``clause`` is None when the doc has no ``clause_id`` key.
    """
    context_parts = []
    sources = []
    for doc in state["retrieved_docs"]:
        name = doc["doc_name"]
        clause = doc.get("clause_id")
        # Drop the " - clause" suffix when there is no usable clause id, rather
        # than feeding "【name - 】" or "【name - None】" into the LLM prompt.
        if clause is None or clause == "":
            header = f"【{name}】"
        else:
            header = f"【{name} - {clause}】"
        context_parts.append(f"{header}\n{doc['content']}")
        sources.append({"name": name, "clause": clause})

    return {"context": "\n\n".join(context_parts), "sources": sources}
|
|
|
|
|
|
def generate_answer(state: RagState) -> dict:
    """Ask the LLM to answer the query from the assembled regulation context.

    The prompt instructs the model to answer concisely and cite the relevant
    regulation clauses. The streamed chunks from ``llm_service.generate_stream``
    are concatenated into one string.

    Args:
        state: Workflow state; reads ``state["context"]`` and ``state["query"]``.

    Returns:
        Partial state update ``{"answer": <full generated text>}``.
    """
    from app.services import llm_service

    prompt = (
        "请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。\n"
        "\n"
        "法规内容:\n"
        f"{state['context']}\n"
        "\n"
        f"用户问题:{state['query']}\n"
        "\n"
        "请给出准确、简洁的回答,并引用相关法规条款。"
    )

    # join() is linear in the total length; repeated `+=` on str is only
    # amortized-linear thanks to a CPython-specific optimization.
    answer = "".join(llm_service.generate_stream(prompt))
    return {"answer": answer}
|
|
|
|
|
|
def _build_rag_graph() -> StateGraph:
    """Wire the RAG nodes into a linear graph: embed -> retrieve -> build_context -> generate -> END."""
    graph = StateGraph(RagState)

    steps = [
        ("embed", embed_query),
        ("retrieve", retrieve_docs),
        ("build_context", build_context),
        ("generate", generate_answer),
    ]
    for node_name, node_fn in steps:
        graph.add_node(node_name, node_fn)

    node_names = [node_name for node_name, _ in steps]
    graph.set_entry_point(node_names[0])
    # Chain each node to the next one, then terminate after the last.
    for upstream, downstream in zip(node_names, node_names[1:]):
        graph.add_edge(upstream, downstream)
    graph.add_edge(node_names[-1], END)
    return graph


# Build the workflow graph and its compiled, runnable form.
rag_graph = _build_rag_graph()
rag_workflow = rag_graph.compile()
|
|
|
|
|
|
async def run_rag_workflow(query: str) -> RagState:
    """Run the compiled RAG graph for ``query`` and return the final state.

    Args:
        query: The user's question; it is the only key in the initial state.

    Returns:
        The final graph state, including ``answer`` and ``sources`` plus the
        intermediate keys written by the earlier nodes.
    """
    initial_state: RagState = {"query": query}
    # The nodes make blocking embedding / vector-search / LLM calls. Calling the
    # synchronous `invoke` from inside this coroutine would run all of them on
    # the event-loop thread and stall every other request. `ainvoke` lets
    # LangGraph run synchronous nodes in an executor instead.
    return await rag_workflow.ainvoke(initial_state)
|
|
|
|
|
|
def stream_rag_workflow(query: str):
    """Run the RAG pipeline and stream the LLM answer as it is produced.

    The retrieval steps (embed_query, retrieve_docs, build_context) are called
    directly rather than through ``rag_workflow``. Only the generation step is
    streamed.

    Args:
        query: The user's question.

    Yields:
        ``{"type": "chunk", "text": ...}`` for every piece of text from
        ``llm_service.generate_stream``, followed by one final
        ``{"type": "done", "sources": [...]}`` carrying the sources that
        build_context collected.
    """
    from app.services import llm_service

    # Retrieval phase: apply each step in order, merging its partial update.
    state: RagState = {"query": query}
    for step in (embed_query, retrieve_docs, build_context):
        state.update(step(state))

    # Generation phase: the same prompt text generate_answer uses, but streamed.
    prompt = (
        "请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。\n"
        "\n"
        "法规内容:\n"
        f"{state['context']}\n"
        "\n"
        f"用户问题:{state['query']}\n"
        "\n"
        "请给出准确、简洁的回答,并引用相关法规条款。"
    )
    for piece in llm_service.generate_stream(prompt):
        yield {"type": "chunk", "text": piece}

    yield {"type": "done", "sources": state["sources"]}