Files
AIRegulation-DocAnalysis/backend/app/workflows/rag_workflow.py

119 lines
3.4 KiB
Python
Raw Normal View History

"""LangGraph RAG workflow: embed the query, retrieve regulation chunks, build a prompt context, and generate an answer that cites clauses."""
2026-05-14 15:07:34 +08:00
from typing import TypedDict, List
from langgraph.graph import StateGraph, END
# Keep workflow state definitions compact so transitions stay easy to audit.
2026-05-14 15:07:34 +08:00
class RagState(TypedDict):
    """Shared state passed between the nodes of the RAG workflow.

    Each node returns a partial dict whose keys are merged into this state.

    NOTE(review): callers seed the state with only ``query``
    (``{"query": query}``) while this TypedDict is total, so a strict type
    checker will reject those literals; consider ``total=False`` or
    ``NotRequired`` for the keys filled in by later nodes — confirm LangGraph
    compatibility first.
    """

    # The user's question text; interpolated into the LLM prompt.
    query: str
    # Embedding of ``query``; written by ``embed_query``.
    query_embedding: List[float]
    # Vector-search hits truncated to ``settings.final_top_k``; written by
    # ``retrieve_docs``. Each hit is read for ``doc_name``, ``content`` and an
    # optional ``clause_id``.
    retrieved_docs: List[dict]
    # Hits rendered as prompt text; written by ``build_context``.
    context: str
    # Full LLM answer text; written by ``generate_answer``.
    answer: str
    # One ``{"name", "clause"}`` citation per hit; written by ``build_context``.
    sources: List[dict]
def embed_query(state: RagState) -> dict:
    """Compute the embedding vector of the user's query.

    Args:
        state: Workflow state; only ``state["query"]`` is read.

    Returns:
        Partial state update ``{"query_embedding": <vector>}`` holding whatever
        ``embedding_service.embed_single`` returns for the query text.
    """
    # Function-local import: the service module is loaded on first call,
    # not when this workflow module is imported.
    from app.services import embedding_service

    query_text = state["query"]
    vector = embedding_service.embed_single(query_text)
    return {"query_embedding": vector}
def retrieve_docs(state: RagState) -> dict:
    """Search the vector store with the query embedding.

    Asks ``milvus_service.search`` for ``settings.vector_top_k`` hits and keeps
    only the first ``settings.final_top_k`` of them.

    Args:
        state: Workflow state; only ``state["query_embedding"]`` is read.

    Returns:
        Partial state update ``{"retrieved_docs": <hits>}``.
    """
    from app.services import milvus_service
    from app.core.config import settings

    hits = milvus_service.search(
        state["query_embedding"],
        top_k=settings.vector_top_k,
    )
    # NOTE(review): nothing reorders the hits before truncation, so this
    # relies on search() returning them best-first — confirm.
    limit = settings.final_top_k
    return {"retrieved_docs": hits[:limit]}
def build_context(state: RagState) -> dict:
    """Render the retrieved hits into prompt context and a citation list.

    Args:
        state: Workflow state; only ``state["retrieved_docs"]`` is read. Each
            hit must provide ``doc_name`` and ``content``; ``clause_id`` is
            optional and may be missing, ``None`` or empty.

    Returns:
        Partial state update with:

        - ``context``: one section per hit, ``"<doc_name> - <clause_id>\\n<content>"``
          (or ``"<doc_name>\\n<content>"`` when the hit has no clause id),
          sections separated by a blank line; ``""`` when there are no hits.
        - ``sources``: ``{"name": doc_name, "clause": clause_id-or-None}`` per
          hit, in retrieval order.
    """
    context_parts = []
    sources = []
    for doc in state["retrieved_docs"]:
        clause = doc.get("clause_id")
        # Only label a section with a clause when there is one; otherwise the
        # prompt would contain a dangling "name - " or a literal "name - None",
        # which the LLM could echo back as a bogus citation.
        if clause not in (None, ""):
            header = f"{doc['doc_name']} - {clause}"
        else:
            header = f"{doc['doc_name']}"
        context_parts.append(f"{header}\n{doc['content']}")
        sources.append({
            "name": doc["doc_name"],
            "clause": clause,
        })
    return {"context": "\n\n".join(context_parts), "sources": sources}
def generate_answer(state: RagState) -> dict:
    """Ask the LLM to answer the query from the assembled regulation context.

    Args:
        state: Workflow state; reads ``state["context"]`` and ``state["query"]``.

    Returns:
        Partial state update ``{"answer": <text>}``: the concatenation of every
        chunk yielded by ``llm_service.generate_stream`` (``""`` if none).
    """
    from app.services import llm_service
    prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
法规内容
{state['context']}
用户问题{state['query']}
请给出准确简洁的回答并引用相关法规条款"""
    # Join the streamed pieces once instead of growing a string with "+=",
    # which can be quadratic in the number of chunks.
    return {"answer": "".join(llm_service.generate_stream(prompt))}
def _build_rag_graph() -> StateGraph:
    """Assemble the linear pipeline embed -> retrieve -> build_context -> generate -> END."""
    graph = StateGraph(RagState)
    steps = (
        ("embed", embed_query),
        ("retrieve", retrieve_docs),
        ("build_context", build_context),
        ("generate", generate_answer),
    )
    for node_name, node_fn in steps:
        graph.add_node(node_name, node_fn)
    graph.set_entry_point(steps[0][0])
    # Chain every step to the next one, then terminate after the last step.
    for (src, _), (dst, _) in zip(steps, steps[1:]):
        graph.add_edge(src, dst)
    graph.add_edge(steps[-1][0], END)
    return graph


rag_graph = _build_rag_graph()
rag_workflow = rag_graph.compile()
async def run_rag_workflow(query: str) -> RagState:
    """Run the compiled RAG graph end to end for ``query``.

    Args:
        query: The user's question.

    Returns:
        The final graph state, which includes the ``answer`` and ``sources``
        keys written by the graph's nodes.
    """
    initial_state: RagState = {"query": query}
    # invoke() would execute every node (embedding, vector search, LLM call)
    # synchronously on the event-loop thread, stalling every other coroutine
    # for the whole request. ainvoke() lets LangGraph schedule the sync nodes
    # in an executor instead.
    return await rag_workflow.ainvoke(initial_state)
def stream_rag_workflow(query: str):
    """Run the RAG pipeline and stream the answer piece by piece.

    The compiled graph is not used here: the embed, retrieve and
    build-context nodes are called directly so the LLM output can be
    forwarded as it is produced instead of after generation finishes.

    Args:
        query: The user's question.

    Yields:
        ``{"type": "chunk", "text": <piece>}`` for every piece produced by
        ``llm_service.generate_stream``, then one final
        ``{"type": "done", "sources": <citations>}`` event.
    """
    from app.services import llm_service

    state: RagState = {"query": query}
    # Apply the pre-generation nodes in pipeline order, merging each update.
    for step in (embed_query, retrieve_docs, build_context):
        state.update(step(state))

    prompt = f"""请根据以下法规内容回答用户问题,并在回答中标注引用的法规条款。
法规内容
{state['context']}
用户问题{state['query']}
请给出准确简洁的回答并引用相关法规条款"""
    for piece in llm_service.generate_stream(prompt):
        yield {"type": "chunk", "text": piece}
    yield {"type": "done", "sources": state["sources"]}