v0.21.1-fastapi
This commit is contained in:
@@ -313,9 +313,75 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
raise
|
||||
|
||||
|
||||
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.max_retries = kwargs.pop("max_retries", 5)
|
||||
self.retry_delay = kwargs.pop("retry_delay", 1)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, params=None, commit=True):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().execute_sql(sql, params, commit)
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
# PostgreSQL specific error codes
|
||||
# 57P01: admin_shutdown
|
||||
# 57P02: crash_shutdown
|
||||
# 57P03: cannot_connect_now
|
||||
# 08006: connection_failure
|
||||
# 08003: connection_does_not_exist
|
||||
# 08000: connection_exception
|
||||
error_messages = ['connection', 'server closed', 'connection refused',
|
||||
'no connection to the server', 'terminating connection']
|
||||
|
||||
should_retry = any(msg in str(e).lower() for msg in error_messages)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"PostgreSQL connection issue (attempt {attempt+1}/{self.max_retries}): {e}"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
logging.error(f"PostgreSQL execution failure: {e}")
|
||||
raise
|
||||
return None
|
||||
|
||||
def _handle_connection_loss(self):
|
||||
try:
|
||||
self.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.connect()
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to reconnect to PostgreSQL: {e}")
|
||||
time.sleep(0.1)
|
||||
self.connect()
|
||||
|
||||
def begin(self):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return super().begin()
|
||||
except (OperationalError, InterfaceError) as e:
|
||||
error_messages = ['connection', 'server closed', 'connection refused',
|
||||
'no connection to the server', 'terminating connection']
|
||||
|
||||
should_retry = any(msg in str(e).lower() for msg in error_messages)
|
||||
|
||||
if should_retry and attempt < self.max_retries:
|
||||
logging.warning(
|
||||
f"PostgreSQL connection lost during transaction (attempt {attempt+1}/{self.max_retries})"
|
||||
)
|
||||
self._handle_connection_loss()
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
class PooledDatabase(Enum):
|
||||
MYSQL = RetryingPooledMySQLDatabase
|
||||
POSTGRES = PooledPostgresqlDatabase
|
||||
POSTGRES = RetryingPooledPostgresqlDatabase
|
||||
|
||||
|
||||
class DatabaseMigrator(Enum):
|
||||
@@ -329,12 +395,11 @@ class BaseDataBase:
|
||||
database_config = settings.DATABASE.copy()
|
||||
db_name = database_config.pop("name")
|
||||
|
||||
# pool_config = {
|
||||
# 'max_retries': 5,
|
||||
# 'retry_delay': 1,
|
||||
# }
|
||||
# database_config.update(pool_config)
|
||||
|
||||
pool_config = {
|
||||
'max_retries': 5,
|
||||
'retry_delay': 1,
|
||||
}
|
||||
database_config.update(pool_config)
|
||||
self.database_connection = PooledDatabase[settings.DATABASE_TYPE.upper()].value(
|
||||
db_name, **database_config
|
||||
)
|
||||
@@ -642,7 +707,7 @@ class TenantLLM(DataBaseModel):
|
||||
llm_factory = CharField(max_length=128, null=False, help_text="LLM factory name", index=True)
|
||||
model_type = CharField(max_length=128, null=True, help_text="LLM, Text Embedding, Image2Text, ASR", index=True)
|
||||
llm_name = CharField(max_length=128, null=True, help_text="LLM name", default="", index=True)
|
||||
api_key = CharField(max_length=2048, null=True, help_text="API KEY", index=True)
|
||||
api_key = TextField(null=True, help_text="API KEY")
|
||||
api_base = CharField(max_length=255, null=True, help_text="API Base")
|
||||
max_tokens = IntegerField(default=8192, index=True)
|
||||
used_tokens = IntegerField(default=0, index=True)
|
||||
@@ -1143,4 +1208,8 @@ def migrate_db():
|
||||
migrate(migrator.add_column("knowledgebase", "mindmap_task_finish_at", CharField(null=True)))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(migrator.alter_column_type("tenant_llm", "api_key", TextField(null=True, help_text="API KEY")))
|
||||
except Exception:
|
||||
pass
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
Reference in New Issue
Block a user