init
vw-agentic-rag/service/config.py (new file, 297 lines)
@@ -0,0 +1,297 @@
import yaml
import os
from typing import Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings


class OpenAIConfig(BaseModel):
    base_url: str = "https://api.openai.com/v1"
    api_key: str
    model: str = "gpt-4o"


class AzureConfig(BaseModel):
    base_url: str
    api_key: str
    deployment: str
    api_version: str = "2024-02-01"


class EmbeddingConfig(BaseModel):
    base_url: str
    api_key: str
    model: str
    dimension: int
    api_version: Optional[str] = None


class IndexConfig(BaseModel):
    standard_regulation_index: str
    chunk_index: str
    chunk_user_manual_index: str


class RetrievalConfig(BaseModel):
    endpoint: str
    api_key: str
    api_version: str
    semantic_configuration: str
    embedding: EmbeddingConfig
    index: IndexConfig


class PostgreSQLConfig(BaseModel):
    host: str
    port: int = 5432
    database: str
    username: str
    password: str
    ttl_days: int = 7


class RedisConfig(BaseModel):
    host: str
    port: int = 6379
    password: str
    use_ssl: bool = True
    db: int = 0
    ttl_days: int = 7


class AppLoggingConfig(BaseModel):
    level: str = "INFO"


class AppConfig(BaseModel):
    name: str = "agentic-rag"
    memory_ttl_days: int = 7
    max_tool_rounds: int = 3  # Maximum allowed tool calling rounds
    max_tool_rounds_user_manual: int = 3  # Maximum allowed tool calling rounds for user manual agent
    cors_origins: list[str] = Field(default_factory=lambda: ["*"])
    logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
    # Service configuration
    host: str = "0.0.0.0"
    port: int = 8000


class SearchConfig(BaseModel):
    """Search index configuration"""
    standard_regulation_index: str = ""
    chunk_index: str = ""
    chunk_user_manual_index: str = ""


class CitationConfig(BaseModel):
    """Citation link configuration"""
    base_url: str = ""  # Default empty string


class LLMParametersConfig(BaseModel):
    """LLM parameters configuration"""
    temperature: Optional[float] = None
    max_context_length: int = 96000  # Maximum context length for conversation history (in tokens)
    max_output_tokens: Optional[int] = None  # Optional limit for LLM output tokens (None = no limit)


class LLMPromptsConfig(BaseModel):
    """LLM prompts configuration"""
    agent_system_prompt: str
    synthesis_system_prompt: Optional[str] = None
    synthesis_user_prompt: Optional[str] = None
    intent_recognition_prompt: Optional[str] = None
    user_manual_prompt: Optional[str] = None


class LLMPromptConfig(BaseModel):
    """LLM prompt configuration from llm_prompt.yaml"""
    parameters: LLMParametersConfig = Field(default_factory=LLMParametersConfig)
    prompts: LLMPromptsConfig
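
# For reference: LLMPromptConfig above implies an llm_prompt.yaml shaped roughly like
# the sketch below. The key names mirror the model fields; the prompt texts and
# parameter values are illustrative placeholders, not values from this repo.
#
#   parameters:
#     temperature: 0.2
#     max_context_length: 96000
#     max_output_tokens: 4096
#   prompts:
#     agent_system_prompt: "You are a retrieval agent..."   # required
#     synthesis_system_prompt: "Write the final answer..."  # optional
#     intent_recognition_prompt: "Classify the request..."  # optional
#     user_manual_prompt: "Answer from the user manual..."  # optional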


class LLMRagConfig(BaseModel):
    """Legacy LLM RAG configuration for backward compatibility"""
    temperature: Optional[float] = None
    max_context_length: int = 96000  # Maximum context length for conversation history (in tokens)
    max_output_tokens: Optional[int] = None  # Optional limit for LLM output tokens (None = no limit)
    # Legacy prompts for backward compatibility
    system_prompt: Optional[str] = None
    user_prompt: Optional[str] = None
    # New autonomous agent prompts
    agent_system_prompt: Optional[str] = None
    synthesis_system_prompt: Optional[str] = None
    synthesis_user_prompt: Optional[str] = None


class LLMConfig(BaseModel):
    rag: LLMRagConfig


class LoggingConfig(BaseModel):
    level: str = "INFO"
    format: str = "json"


class Config(BaseSettings):
    provider: str = "openai"
    openai: Optional[OpenAIConfig] = None
    azure: Optional[AzureConfig] = None
    retrieval: RetrievalConfig
    postgresql: PostgreSQLConfig
    redis: Optional[RedisConfig] = None
    app: AppConfig = Field(default_factory=AppConfig)
    search: SearchConfig = Field(default_factory=SearchConfig)
    citation: CitationConfig = Field(default_factory=CitationConfig)
    llm: Optional[LLMConfig] = None
    logging: LoggingConfig = Field(default_factory=LoggingConfig)

    # New LLM prompt configuration
    llm_prompt: Optional[LLMPromptConfig] = None
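
    # For reference: from_yaml() below passes the parsed YAML directly into this model,
    # so config.yaml is expected to mirror the field names declared above. A minimal,
    # purely illustrative sketch (hosts, keys, and values are placeholders):
    #
    #   provider: openai
    #   openai:
    #     api_key: ${OPENAI_API_KEY}
    #     model: gpt-4o
    #   retrieval:
    #     endpoint: ${SEARCH_ENDPOINT}
    #     api_key: ${SEARCH_API_KEY}
    #     api_version: ${SEARCH_API_VERSION}
    #     semantic_configuration: default
    #     embedding: { ... }   # EmbeddingConfig fields
    #     index: { ... }       # IndexConfig fields
    #   postgresql:
    #     host: ${PG_HOST:-localhost}
    #     database: rag
    #     username: rag
    #     password: ${PG_PASSWORD}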

    @classmethod
    def from_yaml(cls, config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> "Config":
        """Load configuration from YAML files with environment variable substitution"""
        # Load main config
        with open(config_path, 'r', encoding='utf-8') as f:
            yaml_data = yaml.safe_load(f)

        # Substitute environment variables
        yaml_data = cls._substitute_env_vars(yaml_data)

        # Load LLM prompt config if it exists
        llm_prompt_data = None
        if os.path.exists(llm_prompt_path):
            with open(llm_prompt_path, 'r', encoding='utf-8') as f:
                llm_prompt_data = yaml.safe_load(f)
            llm_prompt_data = cls._substitute_env_vars(llm_prompt_data)
            yaml_data['llm_prompt'] = llm_prompt_data

        return cls(**yaml_data)

    @classmethod
    def _substitute_env_vars(cls, data: Any) -> Any:
        """Recursively substitute ${VAR} and ${VAR:-default} patterns with environment variables"""
        if isinstance(data, dict):
            return {k: cls._substitute_env_vars(v) for k, v in data.items()}
        elif isinstance(data, list):
            return [cls._substitute_env_vars(item) for item in data]
        elif isinstance(data, str):
            # Handle ${VAR} and ${VAR:-default} patterns
            if data.startswith("${") and data.endswith("}"):
                env_spec = data[2:-1]
                if ":-" in env_spec:
                    var_name, default_value = env_spec.split(":-", 1)
                    return os.getenv(var_name, default_value)
                else:
                    return os.getenv(env_spec, data)  # Return original if env var not found
            return data
        else:
            return data
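
    # Substitution behavior of the helper above, by example:
    #   "${PG_HOST}"         -> value of PG_HOST, or the literal "${PG_HOST}" if unset
    #   "${PG_PORT:-5432}"   -> value of PG_PORT, or "5432" if unset
    #   "host-${PG_HOST}"    -> left as-is; only strings that are exactly "${...}" are substituted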

    def get_llm_config(self) -> Dict[str, Any]:
        """Get LLM configuration based on provider"""
        base_config = {}

        # Get temperature and max_output_tokens from llm_prompt config first, fallback to legacy llm.rag config
        if self.llm_prompt and self.llm_prompt.parameters:
            # Only add temperature if explicitly set (not None)
            if self.llm_prompt.parameters.temperature is not None:
                base_config["temperature"] = self.llm_prompt.parameters.temperature
            # Only add max_output_tokens if explicitly set (not None)
            if self.llm_prompt.parameters.max_output_tokens is not None:
                base_config["max_tokens"] = self.llm_prompt.parameters.max_output_tokens
        elif self.llm and self.llm.rag:
            # Only add temperature if explicitly set (not None)
            if hasattr(self.llm.rag, 'temperature') and self.llm.rag.temperature is not None:
                base_config["temperature"] = self.llm.rag.temperature
            # Only add max_output_tokens if explicitly set (not None)
            if self.llm.rag.max_output_tokens is not None:
                base_config["max_tokens"] = self.llm.rag.max_output_tokens

        if self.provider == "openai" and self.openai:
            return {
                **base_config,
                "provider": "openai",
                "base_url": self.openai.base_url,
                "api_key": self.openai.api_key,
                "model": self.openai.model,
            }
        elif self.provider == "azure" and self.azure:
            return {
                **base_config,
                "provider": "azure",
                "base_url": self.azure.base_url,
                "api_key": self.azure.api_key,
                "deployment": self.azure.deployment,
                "api_version": self.azure.api_version,
            }
        else:
            raise ValueError(f"Invalid provider '{self.provider}' or missing configuration")
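
    # Shape of the mapping returned above for the "openai" provider, as an illustration
    # (placeholder values; "temperature"/"max_tokens" appear only when configured):
    #   {"provider": "openai", "base_url": "https://api.openai.com/v1",
    #    "api_key": "sk-...", "model": "gpt-4o", "temperature": 0.2}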

    def get_rag_prompts(self) -> Dict[str, str]:
        """Get RAG prompts configuration - prioritize llm_prompt.yaml over legacy config"""
        # Use new llm_prompt config if available
        if self.llm_prompt and self.llm_prompt.prompts:
            return {
                "system_prompt": self.llm_prompt.prompts.agent_system_prompt,
                "user_prompt": "{{user_query}}",  # Default template
                "agent_system_prompt": self.llm_prompt.prompts.agent_system_prompt,
                "synthesis_system_prompt": self.llm_prompt.prompts.synthesis_system_prompt or "You are a helpful assistant.",
                "synthesis_user_prompt": self.llm_prompt.prompts.synthesis_user_prompt or "{{user_query}}",
                "intent_recognition_prompt": self.llm_prompt.prompts.intent_recognition_prompt or "",
                "user_manual_prompt": self.llm_prompt.prompts.user_manual_prompt or "",
            }

        # Fallback to legacy llm.rag config
        if self.llm and self.llm.rag:
            return {
                "system_prompt": self.llm.rag.system_prompt or "You are a helpful assistant.",
                "user_prompt": self.llm.rag.user_prompt or "{{user_query}}",
                "agent_system_prompt": self.llm.rag.agent_system_prompt or "You are a helpful assistant.",
                "synthesis_system_prompt": self.llm.rag.synthesis_system_prompt or "You are a helpful assistant.",
                "synthesis_user_prompt": self.llm.rag.synthesis_user_prompt or "{{user_query}}",
                "intent_recognition_prompt": "",
                "user_manual_prompt": "",
            }

        # Default fallback
        return {
            "system_prompt": "You are a helpful assistant.",
            "user_prompt": "{{user_query}}",
            "agent_system_prompt": "You are a helpful assistant.",
            "synthesis_system_prompt": "You are a helpful assistant.",
            "synthesis_user_prompt": "{{user_query}}",
            "intent_recognition_prompt": "",
            "user_manual_prompt": "",
        }

    def get_max_context_length(self) -> int:
        """Get maximum context length for conversation history"""
        # Use new llm_prompt config if available
        if self.llm_prompt and self.llm_prompt.parameters:
            return self.llm_prompt.parameters.max_context_length

        # Fallback to legacy llm.rag config
        if self.llm and self.llm.rag:
            return self.llm.rag.max_context_length

        # Default fallback
        return 96000


# Global config instance
config: Optional[Config] = None


def load_config(config_path: str = "config.yaml", llm_prompt_path: str = "llm_prompt.yaml") -> Config:
    """Load and return the global configuration"""
    global config
    config = Config.from_yaml(config_path, llm_prompt_path)
    return config


def get_config() -> Config:
    """Get the current configuration instance"""
    if config is None:
        raise RuntimeError("Configuration not loaded. Call load_config() first.")
    return config
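

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service wiring: assumes a config.yaml
    # (and optionally an llm_prompt.yaml) exists in the working directory.
    cfg = load_config()
    print("provider:", cfg.get_llm_config()["provider"])
    print("prompt keys:", sorted(get_config().get_rag_prompts().keys()))
    print("max context length:", cfg.get_max_context_length())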