catonline_ai/vw-agentic-rag/service/llm_client.py
from typing import AsyncIterator, Dict, Any, List, Optional
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
import logging
from .config import get_config
logger = logging.getLogger(__name__)

class LLMClient:
    """Wrapper for OpenAI/Azure OpenAI clients with streaming and function calling support"""

    def __init__(self):
        self.config = get_config()
        self.llm = self._create_llm()
        self.llm_with_tools = None

    def _create_llm(self) -> ChatOpenAI | AzureChatOpenAI:
        """Create LLM client based on configuration"""
        llm_config = self.config.get_llm_config()
        if llm_config["provider"] == "openai":
            # Create base parameters
            params = {
                "base_url": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "model": llm_config["model"],
                "streaming": True,
            }
            # Only add temperature if explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return ChatOpenAI(**params)
        elif llm_config["provider"] == "azure":
            # Create base parameters
            params = {
                "azure_endpoint": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "azure_deployment": llm_config["deployment"],
                "api_version": llm_config["api_version"],
                "streaming": True,
            }
            # Only add temperature if explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return AzureChatOpenAI(**params)
        else:
            raise ValueError(f"Unsupported provider: {llm_config['provider']}")
    def bind_tools(self, tools: List[Dict[str, Any]], force_tool_choice: bool = False):
        """Bind tools to LLM for function calling"""
        if force_tool_choice:
            # Use tool_choice="required" to force tool calling for DeepSeek
            self.llm_with_tools = self.llm.bind_tools(tools, tool_choice="required")
        else:
            self.llm_with_tools = self.llm.bind_tools(tools)

    async def astream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Stream LLM response tokens"""
        try:
            async for chunk in self.llm.astream(messages):
                if chunk.content and isinstance(chunk.content, str):
                    yield chunk.content
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    async def ainvoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Get complete LLM response"""
        try:
            response = await self.llm.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            else:
                # Convert to AIMessage if needed
                return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM invoke error: {e}")
            raise

    async def ainvoke_with_tools(self, messages: list[BaseMessage]) -> AIMessage:
        """Get LLM response with tool calling capability"""
        try:
            if not self.llm_with_tools:
                raise ValueError("Tools not bound to LLM. Call bind_tools() first.")
            response = await self.llm_with_tools.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            else:
                return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM with tools invoke error: {e}")
            raise

    def create_messages(self, system_prompt: str, user_prompt: str) -> list[BaseMessage]:
        """Create message list for LLM"""
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=user_prompt))
        return messages
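
# --- Example usage (illustrative sketch, not part of the original module) ---
# Assumes get_config() resolves to valid provider credentials; the tool schema
# and prompts below are hypothetical placeholders, not calls defined elsewhere
# in the service.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = LLMClient()
        messages = client.create_messages(
            system_prompt="You are a concise assistant.",
            user_prompt="Summarize retrieval-augmented generation in one sentence.",
        )

        # Stream tokens as they arrive
        async for token in client.astream(messages):
            print(token, end="", flush=True)
        print()

        # Function calling: bind an OpenAI-style tool schema, then invoke
        search_tool = {
            "type": "function",
            "function": {
                "name": "search_documents",
                "description": "Search the knowledge base for relevant passages.",
                "parameters": {
                    "type": "object",
                    "properties": {"query": {"type": "string"}},
                    "required": ["query"],
                },
            },
        }
        client.bind_tools([search_tool], force_tool_choice=True)
        response = await client.ainvoke_with_tools(messages)
        print(response.tool_calls)

    asyncio.run(_demo())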