import logging
from typing import Any, AsyncIterator, Dict, List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from .config import get_config

logger = logging.getLogger(__name__)


class LLMClient:
    """Wrapper for OpenAI/Azure OpenAI clients with streaming and function-calling support."""

    def __init__(self):
        self.config = get_config()
        self.llm = self._create_llm()
        self.llm_with_tools = None

    def _create_llm(self) -> ChatOpenAI | AzureChatOpenAI:
        """Create the LLM client for the configured provider."""
        llm_config = self.config.get_llm_config()

        if llm_config["provider"] == "openai":
            params = {
                "base_url": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "model": llm_config["model"],
                "streaming": True,
            }
            # Only add temperature if explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return ChatOpenAI(**params)
        elif llm_config["provider"] == "azure":
            params = {
                "azure_endpoint": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "azure_deployment": llm_config["deployment"],
                "api_version": llm_config["api_version"],
                "streaming": True,
            }
            # Only add temperature if explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return AzureChatOpenAI(**params)
        else:
            raise ValueError(f"Unsupported provider: {llm_config['provider']}")
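
    # Expected shape of get_llm_config(), inferred from the keys read above
    # (a sketch, not a verified schema; "temperature" is optional):
    #
    #     {"provider": "openai", "base_url": "...", "api_key": "...",
    #      "model": "...", "temperature": 0.0}
    #     {"provider": "azure", "base_url": "...", "api_key": "...",
    #      "deployment": "...", "api_version": "..."}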

    def bind_tools(self, tools: List[Dict[str, Any]], force_tool_choice: bool = False):
        """Bind tools to the LLM for function calling."""
        if force_tool_choice:
            # tool_choice="required" forces tool calling (needed for DeepSeek)
            self.llm_with_tools = self.llm.bind_tools(tools, tool_choice="required")
        else:
            self.llm_with_tools = self.llm.bind_tools(tools)
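
    # A sketch of the tool format bind_tools() accepts: the names below are
    # illustrative, not taken from this project. LangChain's bind_tools()
    # supports OpenAI-style function schemas such as
    #
    #     {
    #         "type": "function",
    #         "function": {
    #             "name": "get_weather",
    #             "description": "Look up the current weather for a city.",
    #             "parameters": {
    #                 "type": "object",
    #                 "properties": {"city": {"type": "string"}},
    #                 "required": ["city"],
    #             },
    #         },
    #     }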

    async def astream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Stream LLM response tokens."""
        try:
            async for chunk in self.llm.astream(messages):
                # Chunk content may be a string or a list of content blocks;
                # only plain string tokens are yielded.
                if chunk.content and isinstance(chunk.content, str):
                    yield chunk.content
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    async def ainvoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Get a complete LLM response."""
        try:
            response = await self.llm.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            # Convert any other message type to AIMessage
            return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM invoke error: {e}")
            raise

    async def ainvoke_with_tools(self, messages: list[BaseMessage]) -> AIMessage:
        """Get an LLM response with tool-calling capability."""
        # Check outside the try block so a missing bind_tools() call is not
        # logged as an LLM error.
        if self.llm_with_tools is None:
            raise ValueError("Tools not bound to LLM. Call bind_tools() first.")
        try:
            response = await self.llm_with_tools.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM with tools invoke error: {e}")
            raise

    def create_messages(self, system_prompt: str, user_prompt: str) -> list[BaseMessage]:
        """Create a message list for the LLM."""
        messages: list[BaseMessage] = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=user_prompt))
        return messages
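

if __name__ == "__main__":
    # Minimal usage sketch, not part of the public API. It assumes the
    # package config can resolve provider credentials (e.g. an API key in
    # the environment) and that the module is run with `python -m` so the
    # relative import of .config works.
    import asyncio

    async def _demo() -> None:
        client = LLMClient()
        messages = client.create_messages(
            system_prompt="You are a concise assistant.",
            user_prompt="In one sentence, what does an LLM client wrapper do?",
        )
        # Stream tokens as they arrive...
        async for token in client.astream(messages):
            print(token, end="", flush=True)
        print()
        # ...or request the complete response in one call.
        reply = await client.ainvoke(messages)
        print(reply.content)

    asyncio.run(_demo())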