# catonline_ai/vw-agentic-rag/service/config.py
import yaml
import os
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings


class OpenAIConfig(BaseModel):
    base_url: str = "https://api.openai.com/v1"
    api_key: str
    model: str = "gpt-4o"


class AzureConfig(BaseModel):
    base_url: str
    api_key: str
    deployment: str
    api_version: str = "2024-02-01"


class EmbeddingConfig(BaseModel):
    base_url: str
    api_key: str
    model: str
    dimension: int
    api_version: Optional[str]


class IndexConfig(BaseModel):
    standard_regulation_index: str
    chunk_index: str
    chunk_user_manual_index: str


class RetrievalConfig(BaseModel):
    endpoint: str
    api_key: str
    api_version: str
    semantic_configuration: str
    embedding: EmbeddingConfig
    index: IndexConfig


class PostgreSQLConfig(BaseModel):
    host: str
    port: int = 5432
    database: str
    username: str
    password: str
    ttl_days: int = 7


class RedisConfig(BaseModel):
    host: str
    port: int = 6379
    password: str
    use_ssl: bool = True
    db: int = 0
    ttl_days: int = 7


class AppLoggingConfig(BaseModel):
    level: str = "INFO"


class AppConfig(BaseModel):
    name: str = "agentic-rag"
    memory_ttl_days: int = 7
    max_tool_rounds: int = 3  # Maximum allowed tool calling rounds
    max_tool_rounds_user_manual: int = 3  # Maximum allowed tool calling rounds for the user manual agent
    cors_origins: list[str] = Field(default_factory=lambda: ["*"])
    logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
    # Service configuration
    host: str = "0.0.0.0"
    port: int = 8000


class SearchConfig(BaseModel):
    """Search index configuration"""
    standard_regulation_index: str = ""
    chunk_index: str = ""
    chunk_user_manual_index: str = ""


class CitationConfig(BaseModel):
    """Citation link configuration"""
    base_url: str = ""  # Default empty string


class LLMParametersConfig(BaseModel):
    """LLM parameters configuration"""
    temperature: Optional[float] = None
    max_context_length: int = 96000  # Maximum context length for conversation history (in tokens)
    max_output_tokens: Optional[int] = None  # Optional limit for LLM output tokens (None = no limit)


class LLMPromptsConfig(BaseModel):
    """LLM prompts configuration"""
    agent_system_prompt: str
    synthesis_system_prompt: Optional[str] = None
    synthesis_user_prompt: Optional[str] = None
    intent_recognition_prompt: Optional[str] = None
    user_manual_prompt: Optional[str] = None


class LLMPromptConfig(BaseModel):
    """LLM prompt configuration from llm_prompt.yaml"""
    parameters: LLMParametersConfig = Field(default_factory=LLMParametersConfig)
    prompts: LLMPromptsConfig


class LLMRagConfig(BaseModel):
    """Legacy LLM RAG configuration for backward compatibility"""
    temperature: Optional[float] = None
    max_context_length: int = 96000  # Maximum context length for conversation history (in tokens)
    max_output_tokens: Optional[int] = None  # Optional limit for LLM output tokens (None = no limit)
    # Legacy prompts for backward compatibility
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    # New autonomous agent prompts
    agent_system_prompt: Optional[str] = None
    synthesis_system_prompt: Optional[str] = None
    synthesis_user_prompt: Optional[str] = None


class LLMConfig(BaseModel):
    rag: LLMRagConfig


class LoggingConfig(BaseModel):
    level: str = "INFO"
    format: str = "json"


class Config(BaseSettings):
    provider: str = "openai"
    openai: Optional[OpenAIConfig] = None
    azure: Optional[AzureConfig] = None
    retrieval: RetrievalConfig
    postgresql: PostgreSQLConfig
    redis: Optional[RedisConfig] = None
    app: AppConfig = Field(default_factory=AppConfig)
    search: SearchConfig = Field(default_factory=SearchConfig)
    citation: CitationConfig = Field(default_factory=CitationConfig)
    llm: Optional[LLMConfig] = None
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    # New LLM prompt configuration
    llm_prompt: Optional[LLMPromptConfig] = None

    @classmethod
    def from_yaml(cls, config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> "Config":
        """Load configuration from YAML files with environment variable substitution"""
        # Load main config
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)
        # Substitute environment variables
        yaml_data = cls._substitute_env_vars(yaml_data)

        # Load LLM prompt config if it exists
        llm_prompt_data = None
        if os.path.exists(llm_prompt_path):
            with open(llm_prompt_path, 'r', encoding='utf-8') as f:
                llm_prompt_data = yaml.safe_load(f)
            llm_prompt_data = cls._substitute_env_vars(llm_prompt_data)
            yaml_data['llm_prompt'] = llm_prompt_data

        return cls(**yaml_data)
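    # Minimal sketch of a config.yaml accepted by from_yaml(), derived from the
    # models above. Only `retrieval` and `postgresql` are required; the concrete
    # values and environment variable names below are hypothetical placeholders,
    # not the project's real settings:
    #
    #   provider: azure
    #   azure:
    #     base_url: ${AZURE_OPENAI_ENDPOINT}
    #     api_key: ${AZURE_OPENAI_API_KEY}
    #     deployment: gpt-4o
    #     api_version: "2024-02-01"
    #   retrieval:
    #     endpoint: ${SEARCH_ENDPOINT}
    #     api_key: ${SEARCH_API_KEY}
    #     api_version: "2024-07-01"
    #     semantic_configuration: default
    #     embedding:
    #       base_url: ${EMBEDDING_ENDPOINT}
    #       api_key: ${EMBEDDING_API_KEY}
    #       model: text-embedding-3-large
    #       dimension: 3072
    #       api_version: "2024-02-01"
    #     index:
    #       standard_regulation_index: regulations
    #       chunk_index: chunks
    #       chunk_user_manual_index: user-manual-chunks
    #   postgresql:
    #     host: ${POSTGRES_HOST}
    #     port: ${POSTGRES_PORT:-5432}
    #     database: agentic_rag
    #     username: ${POSTGRES_USER}
    #     password: ${POSTGRES_PASSWORD}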
    @classmethod
    def _substitute_env_vars(cls, data: Any) -> Any:
        """Recursively substitute ${VAR} and ${VAR:-default} patterns with environment variables"""
        if isinstance(data, dict):
            return {k: cls._substitute_env_vars(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [cls._substitute_env_vars(item) for item in data]
        elif isinstance(data, str):
            # Handle ${VAR} and ${VAR:-default} patterns
            if data.startswith("${") and data.endswith("}"):
                env_spec = data[2:-1]
                if ":-" in env_spec:
                    var_name, default_value = env_spec.split(":-", 1)
                    return os.getenv(var_name, default_value)
                else:
                    return os.getenv(env_spec, data)  # Return original if env var not found
            return data
        else:
            return data
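    # Note: substitution only applies to string values that are exactly a
    # ${...} expression; mixed strings such as "prefix-${VAR}" pass through
    # unchanged. Illustrative behavior (variable names are examples only):
    #   "${REDIS_PASSWORD}"    -> value of REDIS_PASSWORD, or the literal
    #                             "${REDIS_PASSWORD}" if the variable is unset
    #   "${REDIS_PORT:-6379}"  -> value of REDIS_PORT, or "6379" if unset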
    def get_llm_config(self) -> Dict[str, Any]:
        """Get LLM configuration based on provider"""
        base_config = {}

        # Get temperature and max_output_tokens from llm_prompt config first, fallback to legacy llm.rag config
        if self.llm_prompt and self.llm_prompt.parameters:
            # Only add temperature if explicitly set (not None)
            if self.llm_prompt.parameters.temperature is not None:
                base_config["temperature"] = self.llm_prompt.parameters.temperature
            # Only add max_output_tokens if explicitly set (not None)
            if self.llm_prompt.parameters.max_output_tokens is not None:
                base_config["max_tokens"] = self.llm_prompt.parameters.max_output_tokens
        elif self.llm and self.llm.rag:
            # Only add temperature if explicitly set (not None)
            if hasattr(self.llm.rag, 'temperature') and self.llm.rag.temperature is not None:
                base_config["temperature"] = self.llm.rag.temperature
            # Only add max_output_tokens if explicitly set (not None)
            if self.llm.rag.max_output_tokens is not None:
                base_config["max_tokens"] = self.llm.rag.max_output_tokens

        if self.provider == "openai" and self.openai:
            return {
                **base_config,
                "provider": "openai",
                "base_url": self.openai.base_url,
                "api_key": self.openai.api_key,
                "model": self.openai.model,
            }
        elif self.provider == "azure" and self.azure:
            return {
                **base_config,
                "provider": "azure",
                "base_url": self.azure.base_url,
                "api_key": self.azure.api_key,
                "deployment": self.azure.deployment,
                "api_version": self.azure.api_version,
            }
        else:
            raise ValueError(f"Invalid provider '{self.provider}' or missing configuration")
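    # Shape of the returned mapping for the "openai" provider (values shown are
    # placeholders; "temperature" and "max_tokens" appear only when configured):
    #   {
    #       "provider": "openai",
    #       "base_url": "https://api.openai.com/v1",
    #       "api_key": "...",
    #       "model": "gpt-4o",
    #       "temperature": 0.2,
    #       "max_tokens": 4096,
    #   }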
    def get_rag_prompts(self) -> Dict[str, str]:
        """Get RAG prompts configuration - prioritize llm_prompt.yaml over legacy config"""
        # Use new llm_prompt config if available
        if self.llm_prompt and self.llm_prompt.prompts:
            return {
                "system_prompt": self.llm_prompt.prompts.agent_system_prompt,
                "user_prompt": "{{user_query}}",  # Default template
                "agent_system_prompt": self.llm_prompt.prompts.agent_system_prompt,
                "synthesis_system_prompt": self.llm_prompt.prompts.synthesis_system_prompt or "You are a helpful assistant.",
                "synthesis_user_prompt": self.llm_prompt.prompts.synthesis_user_prompt or "{{user_query}}",
                "intent_recognition_prompt": self.llm_prompt.prompts.intent_recognition_prompt or "",
                "user_manual_prompt": self.llm_prompt.prompts.user_manual_prompt or "",
            }

        # Fallback to legacy llm.rag config
        if self.llm and self.llm.rag:
            return {
                "system_prompt": self.llm.rag.system_prompt or "You are a helpful assistant.",
                "user_prompt": self.llm.rag.user_prompt or "{{user_query}}",
                "agent_system_prompt": self.llm.rag.agent_system_prompt or "You are a helpful assistant.",
                "synthesis_system_prompt": self.llm.rag.synthesis_system_prompt or "You are a helpful assistant.",
                "synthesis_user_prompt": self.llm.rag.synthesis_user_prompt or "{{user_query}}",
                "intent_recognition_prompt": "",
                "user_manual_prompt": "",
            }

        # Default fallback
        return {
            "system_prompt": "You are a helpful assistant.",
            "user_prompt": "{{user_query}}",
            "agent_system_prompt": "You are a helpful assistant.",
            "synthesis_system_prompt": "You are a helpful assistant.",
            "synthesis_user_prompt": "{{user_query}}",
            "intent_recognition_prompt": "",
            "user_manual_prompt": "",
        }

    def get_max_context_length(self) -> int:
        """Get maximum context length for conversation history"""
        # Use new llm_prompt config if available
        if self.llm_prompt and self.llm_prompt.parameters:
            return self.llm_prompt.parameters.max_context_length
        # Fallback to legacy llm.rag config
        if self.llm and self.llm.rag:
            return self.llm.rag.max_context_length
        # Default fallback
        return 96000


# Global config instance
config: Optional[Config] = None


def load_config(config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> Config:
    """Load and return the global configuration"""
    global config
    config = Config.from_yaml(config_path, llm_prompt_path)
    return config


def get_config() -> Config:
    """Get the current configuration instance"""
    if config is None:
        raise RuntimeError("Configuration not loaded. Call load_config() first.")
    return config
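

# Typical startup usage (illustrative sketch; the import path `config` and the
# file locations are assumptions, not part of this module):
#
#   from config import load_config, get_config
#
#   load_config("config.yaml", "llm_prompt.yaml")  # once, at application startup
#   cfg = get_config()                             # anywhere afterwards
#   llm_settings = cfg.get_llm_config()
#   prompts = cfg.get_rag_prompts()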