init
vw-agentic-rag/service/llm_client.py (Normal file, 103 lines added)
@@ -0,0 +1,103 @@
import logging
from typing import Any, AsyncIterator

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from .config import get_config

logger = logging.getLogger(__name__)


class LLMClient:
    """Wrapper for OpenAI/Azure OpenAI clients with streaming and function-calling support."""

    def __init__(self):
        self.config = get_config()
        self.llm = self._create_llm()
        self.llm_with_tools = None

    def _create_llm(self) -> ChatOpenAI | AzureChatOpenAI:
        """Create an LLM client based on configuration."""
        llm_config = self.config.get_llm_config()

        if llm_config["provider"] == "openai":
            # Build base parameters
            params = {
                "base_url": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "model": llm_config["model"],
                "streaming": True,
            }
            # Only add temperature if it is explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return ChatOpenAI(**params)
        elif llm_config["provider"] == "azure":
            # Build base parameters
            params = {
                "azure_endpoint": llm_config["base_url"],
                "api_key": llm_config["api_key"],
                "azure_deployment": llm_config["deployment"],
                "api_version": llm_config["api_version"],
                "streaming": True,
            }
            # Only add temperature if it is explicitly set
            if "temperature" in llm_config:
                params["temperature"] = llm_config["temperature"]
            return AzureChatOpenAI(**params)
        else:
            raise ValueError(f"Unsupported provider: {llm_config['provider']}")
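    # The keys consumed above come from .config.get_llm_config(). The assumed
    # shape (illustrative, not verified against .config) is, for example:
    #   {"provider": "openai", "base_url": "https://api.openai.com/v1",
    #    "api_key": "...", "model": "...", "temperature": 0}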

    def bind_tools(self, tools: list[dict[str, Any]], force_tool_choice: bool = False) -> None:
        """Bind tools to the LLM for function calling."""
        if force_tool_choice:
            # tool_choice="required" forces a tool call (needed for DeepSeek)
            self.llm_with_tools = self.llm.bind_tools(tools, tool_choice="required")
        else:
            self.llm_with_tools = self.llm.bind_tools(tools)
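    # A sketch of the OpenAI-style tool schema that bind_tools() accepts; the
    # tool name and parameters below are illustrative only:
    #
    #   search_tool = {
    #       "type": "function",
    #       "function": {
    #           "name": "search_documents",
    #           "description": "Search the document index for relevant passages",
    #           "parameters": {
    #               "type": "object",
    #               "properties": {"query": {"type": "string"}},
    #               "required": ["query"],
    #           },
    #       },
    #   }
    #   client.bind_tools([search_tool], force_tool_choice=True)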

    async def astream(self, messages: list[BaseMessage]) -> AsyncIterator[str]:
        """Stream LLM response tokens."""
        try:
            async for chunk in self.llm.astream(messages):
                if chunk.content and isinstance(chunk.content, str):
                    yield chunk.content
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise
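    # Consuming the stream from a caller (illustrative):
    #
    #   async for token in client.astream(messages):
    #       print(token, end="", flush=True)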

    async def ainvoke(self, messages: list[BaseMessage]) -> AIMessage:
        """Get a complete LLM response."""
        try:
            response = await self.llm.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            else:
                # Convert to AIMessage if needed
                return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM invoke error: {e}")
            raise

    async def ainvoke_with_tools(self, messages: list[BaseMessage]) -> AIMessage:
        """Get an LLM response with tool-calling capability."""
        # Guard sits outside the try block so a missing binding raises cleanly
        # instead of being logged as an invoke error
        if not self.llm_with_tools:
            raise ValueError("Tools not bound to LLM. Call bind_tools() first.")
        try:
            response = await self.llm_with_tools.ainvoke(messages)
            if isinstance(response, AIMessage):
                return response
            else:
                # Convert to AIMessage if needed
                return AIMessage(content=str(response.content) if response.content else "")
        except Exception as e:
            logger.error(f"LLM with tools invoke error: {e}")
            raise
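    # How a caller might act on tool calls in the returned AIMessage (LangChain
    # exposes parsed calls on .tool_calls). The dispatch() helper is
    # hypothetical; ToolMessage comes from langchain_core.messages:
    #
    #   response = await client.ainvoke_with_tools(messages)
    #   for call in response.tool_calls:
    #       result = dispatch(call["name"], call["args"])
    #       messages.append(ToolMessage(content=str(result), tool_call_id=call["id"]))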

    def create_messages(self, system_prompt: str, user_prompt: str) -> list[BaseMessage]:
        """Create a message list for the LLM."""
        messages = []
        if system_prompt:
            messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=user_prompt))
        return messages
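

# Minimal end-to-end sketch (illustrative; assumes the service config resolves
# to a reachable provider, and the prompts below are placeholders):
#
#   import asyncio
#
#   async def main() -> None:
#       client = LLMClient()
#       messages = client.create_messages(
#           system_prompt="You are a helpful assistant.",
#           user_prompt="What is retrieval-augmented generation?",
#       )
#       reply = await client.ainvoke(messages)
#       print(reply.content)
#
#   asyncio.run(main())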