99 lines
3.2 KiB
Python
99 lines
3.2 KiB
Python
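"""
Tool definitions and schemas for the Agentic RAG system.

This module contains the retrieval tool implementations, the ``tools``
registry that lists them, and helpers that build OpenAI function-calling
schemas and a name-to-tool lookup from that registry.
"""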
|
|
|
|
import logging
from typing import Any, Dict, List, Optional

from langchain_core.tools import tool

from ..retrieval.retrieval import AgenticRetrieval
|
|
|
|
# Module-scoped logger; the retrieval tools report their failures through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
# Tool Definitions using @tool decorator (following LangGraph best practices).
# NOTE: @tool turns each function's docstring into the tool description the
# LLM sees, so those docstrings are runtime text; extra documentation for the
# tools lives in comments instead.
@tool
async def retrieve_standard_regulation(query: str) -> Dict[str, Any]:
    """Search for attributes/metadata of China standards and regulations in automobile/manufacturing industry"""
    # On success returns {"tool_name", "results_count", "results", "took_ms"}.
    # Any exception -- from creating/entering/exiting the AgenticRetrieval
    # context or from the retrieval call itself -- is logged with its
    # traceback and returned as an error payload instead of propagating, so
    # the agent always receives a regular tool result.
    try:
        async with AgenticRetrieval() as retrieval:
            result = await retrieval.retrieve_standard_regulation(query=query)
            return {
                "tool_name": "retrieve_standard_regulation",
                "results_count": len(result.results),
                "results": result.results,  # Already dict objects, no need for model_dump()
                "took_ms": result.took_ms,
            }
    except Exception as e:
        logger.exception("Retrieval error: %s", e)
        return {
            "tool_name": "retrieve_standard_regulation",
            "error": str(e),
            "results_count": 0,
            "results": [],
        }
|
|
|
|
|
|
@tool
async def retrieve_doc_chunk_standard_regulation(query: str) -> Dict[str, Any]:
    """Search for detailed document content chunks of China standards and regulations in automobile/manufacturing industry"""
    # NOTE: @tool uses the docstring above as the tool description sent to the
    # LLM, so it is kept byte-identical; documentation lives in comments.
    #
    # On success returns {"tool_name", "results_count", "results", "took_ms"}.
    # Any exception -- from creating/entering/exiting the AgenticRetrieval
    # context or from the retrieval call itself -- is logged with its
    # traceback and returned as an error payload instead of propagating.
    try:
        async with AgenticRetrieval() as retrieval:
            result = await retrieval.retrieve_doc_chunk_standard_regulation(query=query)
            return {
                "tool_name": "retrieve_doc_chunk_standard_regulation",
                "results_count": len(result.results),
                "results": result.results,  # Already dict objects, no need for model_dump()
                "took_ms": result.took_ms,
            }
    except Exception as e:
        logger.exception("Doc chunk retrieval error: %s", e)
        return {
            "tool_name": "retrieve_doc_chunk_standard_regulation",
            "error": str(e),
            "results_count": 0,
            "results": [],
        }
|
|
|
|
|
|
# Registry of every tool exposed to the agent, in declaration order.
# The schema and name-lookup helpers in this module read from this list.
tools = [
    retrieve_standard_regulation,
    retrieve_doc_chunk_standard_regulation,
]
|
|
|
|
|
|
def get_tool_schemas(tool_list: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
    """
    Generate tool schemas for LLM function calling.

    Every tool is described with a single required string parameter ``query``,
    matching the signature of the retrieval tools defined in this module.

    Args:
        tool_list: Tools to describe; each item must expose ``name`` and
            ``description`` attributes. Defaults to the module-level ``tools``
            registry when omitted (``None``).

    Returns:
        List of tool schemas in OpenAI function calling format, one per tool,
        in the same order as the input tools.
    """
    # Resolved at call time; the registry is only read, never modified, so
    # repeated calls always return the same schemas.
    selected = tools if tool_list is None else tool_list
    # A fresh dict tree is built per tool, so callers may mutate one schema
    # without affecting another.
    return [
        {
            "type": "function",
            "function": {
                "name": t.name,
                "description": t.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query for retrieving relevant information",
                        }
                    },
                    "required": ["query"],
                },
            },
        }
        for t in selected
    ]
|
|
|
|
|
|
def get_tools_by_name(tool_list: Optional[List[Any]] = None) -> Dict[str, Any]:
    """
    Create a mapping of tool names to tool functions.

    Args:
        tool_list: Tools to index; each item must expose a ``name`` attribute.
            Defaults to the module-level ``tools`` registry when omitted
            (``None``).

    Returns:
        Dictionary mapping each tool's ``name`` to the tool object. If two
        tools share a name, the one appearing later in the list wins.
    """
    selected = tools if tool_list is None else tool_list
    return {t.name: t for t in selected}
|