"""RAG workflow: embed a query, retrieve documents, build a context, generate an answer."""

from typing import Iterator, List, TypedDict

from langgraph.graph import StateGraph, END


class RagState(TypedDict):
    """State passed between the nodes of the RAG graph.

    Only ``query`` is supplied when the workflow starts; each node returns a
    partial dict that fills in the keys it is responsible for.
    """

    query: str
    query_embedding: List[float]
    retrieved_docs: List[dict]
    context: str
    answer: str
    sources: List[dict]


def _build_prompt(context: str, query: str) -> str:
    """Render the LLM prompt shared by the batch and streaming paths."""
    return f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。

法规内容:
{context}

用户问题:{query}

请给出准确、简洁的回答,并引用相关法规条款。"""


def embed_query(state: RagState) -> dict:
    """Convert the query text into an embedding vector.

    Returns:
        A partial state update with the ``query_embedding`` key.
    """
    # Imported lazily, as in the original module layout.
    from app.services import embedding_service

    embedding = embedding_service.embed_single(state["query"])
    return {"query_embedding": embedding}


def retrieve_docs(state: RagState) -> dict:
    """Run a vector search for the query embedding.

    Searches with ``settings.vector_top_k`` and keeps the first
    ``settings.final_top_k`` results.

    Returns:
        A partial state update with the ``retrieved_docs`` key.
    """
    from app.services import milvus_service
    from app.core.config import settings

    docs = milvus_service.search(
        state["query_embedding"],
        top_k=settings.vector_top_k,
    )
    return {"retrieved_docs": docs[:settings.final_top_k]}


def build_context(state: RagState) -> dict:
    """Assemble the prompt context and the source list from retrieved docs.

    Each document must provide ``doc_name`` and ``content``; ``clause_id`` is
    optional (rendered as an empty string in the context header and reported
    as ``None`` in the sources).

    Returns:
        A partial state update with the ``context`` and ``sources`` keys.
    """
    docs = state["retrieved_docs"]
    # One "【name - clause】\ncontent" section per document, blank-line separated.
    context = "\n\n".join(
        f"【{doc['doc_name']} - {doc.get('clause_id', '')}】\n{doc['content']}"
        for doc in docs
    )
    sources = [
        {
            "name": doc["doc_name"],
            "clause": doc.get("clause_id"),
        }
        for doc in docs
    ]
    return {"context": context, "sources": sources}


def generate_answer(state: RagState) -> dict:
    """Generate the full answer by draining the LLM's streaming output.

    Returns:
        A partial state update with the ``answer`` key.
    """
    from app.services import llm_service

    prompt = _build_prompt(state["context"], state["query"])
    # join() concatenates in one pass instead of repeated string +=.
    answer = "".join(llm_service.generate_stream(prompt))
    return {"answer": answer}


# Build the workflow graph: embed -> retrieve -> build_context -> generate.
rag_graph = StateGraph(RagState)
rag_graph.add_node("embed", embed_query)
rag_graph.add_node("retrieve", retrieve_docs)
rag_graph.add_node("build_context", build_context)
rag_graph.add_node("generate", generate_answer)

rag_graph.set_entry_point("embed")
rag_graph.add_edge("embed", "retrieve")
rag_graph.add_edge("retrieve", "build_context")
rag_graph.add_edge("build_context", "generate")
rag_graph.add_edge("generate", END)

rag_workflow = rag_graph.compile()


async def run_rag_workflow(query: str) -> RagState:
    """Run the complete RAG workflow for ``query`` and return the final state."""
    initial_state: RagState = {"query": query}
    # Use the graph's async entry point: calling the synchronous invoke()
    # from a coroutine would block the event loop for the whole run.
    result = await rag_workflow.ainvoke(initial_state)
    return result


def stream_rag_workflow(query: str) -> Iterator[dict]:
    """Run retrieval eagerly, then stream the generated answer.

    Yields:
        ``{"type": "chunk", "text": ...}`` for every chunk produced by the
        LLM, followed by a single ``{"type": "done", "sources": ...}`` event.
    """
    from app.services import llm_service

    # Retrieval phase: reuse the graph's node functions directly.
    state: RagState = {"query": query}
    for step in (embed_query, retrieve_docs, build_context):
        state.update(step(state))

    # Streaming generation phase.
    prompt = _build_prompt(state["context"], state["query"])
    for chunk in llm_service.generate_stream(prompt):
        yield {"type": "chunk", "text": chunk}

    yield {"type": "done", "sources": state["sources"]}