from typing import AsyncIterator, Dict, Any, List, Optional

from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool
import logging

from .config import get_config

logger = logging.getLogger(__name__)


class LLMClient:
    """Wrapper for OpenAI/Azure OpenAI clients with streaming and function calling support"""

    def __init__(self):
        self.config = get_config()
        self.llm = self._create_llm()
        self.llm_with_tools = None

    def _create_llm(self) -> ChatOpenAI | AzureChatOpenAI:
        """Create LLM client based on configuration"""
        llm_config = self.config.get_llm_config()

        if llm_config["provider"] == "openai":
            # Create base parameters
            params = {
                "base_url": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "model": llm_config["model"],
                "streaming": True,
            }
            # Only add temperature if explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return ChatOpenAI(**params)
        elif llm_config["provider"] == "azure":
            # Create base parameters
            params = {
                "azure_endpoint": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "azure_deployment": llm_config["deployment"],
                "api_version": llm_config["api_version"],
                "streaming": True,
            }
            # Only add temperature if explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return AzureChatOpenAI(**params)
        else:
            raise ValueError(f"Unsupported provider: {llm_config['provider']}")

    def bind_tools(self, tools: List[Dict[str, Any]], force_tool_choice: bool = False):
        """Bind tools to LLM for function calling"""
        if force_tool_choice:
            # Use tool_choice="required" to force tool calling for DeepSeek
            self.llm_with_tools = self.llm.bind_tools(tools, tool_choice="required")
        else:
            self.llm_with_tools = self.llm.bind_tools(tools)

    async def astream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Stream LLM response tokens"""
        try:
            async for chunk in self.llm.astream(messages):
                if chunk.content and isinstance(chunk.content, str):
                    yield chunk.content
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    async def ainvoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Get complete LLM response"""
        try:
            response = await self.llm.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            else:
                # Convert to AIMessage if needed
                return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM invoke error: {e}")
            raise

    async def ainvoke_with_tools(self, messages: list[BaseMessage]) -> AIMessage:
        """Get LLM response with tool calling capability"""
        try:
            if not self.llm_with_tools:
                raise ValueError("Tools not bound to LLM. Call bind_tools() first.")
            response = await self.llm_with_tools.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            else:
                return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM with tools invoke error: {e}")
            raise

    def create_messages(self, system_prompt: str, user_prompt: str) -> list[BaseMessage]:
        """Create message list for LLM"""
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=user_prompt))
        return messages
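

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# It shows the intended call sequence: build messages, stream a plain reply,
# then bind an OpenAI-style tool schema and invoke with tool calling.
# Assumptions: get_config() resolves valid provider credentials, and the
# package is run as a module (e.g. `python -m <package>.llm_client`) so the
# relative import of .config works. The "get_weather" tool is a hypothetical
# placeholder schema, not something defined elsewhere in this codebase.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = LLMClient()
        messages = client.create_messages(
            system_prompt="You are a helpful assistant.",
            user_prompt="Briefly explain what streaming responses are.",
        )

        # Stream tokens as they arrive from the model.
        async for token in client.astream(messages):
            print(token, end="", flush=True)
        print()

        # Function calling: bind an OpenAI-format tool schema, then invoke.
        client.bind_tools(
            [
                {
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "description": "Get the current weather for a city.",
                        "parameters": {
                            "type": "object",
                            "properties": {"city": {"type": "string"}},
                            "required": ["city"],
                        },
                    },
                }
            ]
        )
        response = await client.ainvoke_with_tools(
            client.create_messages("You are a helpful assistant.", "Weather in Paris?")
        )
        # The AIMessage carries any requested tool calls in .tool_calls.
        print(response.tool_calls)

    asyncio.run(_demo())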