[feature]Implement session cache for Doris connections (#44)
* [feature]Implement session cache for Doris connections This PR introduces a `DorisSessionCache` to cache and reuse `DorisConnection` objects in memory. This helps to reduce the overhead of creating new connections, especially for frequently used system sessions like "query" and "system", and avoid not calling release_connection leads to `Connection acquisition timed out` when the number of connection pools reaches the maximum value. The PR #34 fixed the issue when calling the tool `exec_query`, but in the codebase, a large number of other tools directly using get_connection ("query") to get connection object but without calling the release_connection method will cause the connection to fail to be obtained after a certain number of times. Key changes: - Added `DorisSessionCache` class to manage the lifecycle of cached sessions. - The cache is configurable to store system sessions, user sessions, or both. By default, only system sessions are cached. - Integrated the session cache into `DorisConnectionManager`. - `get_connection` now checks the cache before creating a new connection. - `release_connection` removes the connection from the cache. * Add tests
This commit is contained in:
@@ -516,7 +516,7 @@ class SQLAnalyzer:
|
|||||||
try:
|
try:
|
||||||
# Switch to specified database/catalog if provided
|
# Switch to specified database/catalog if provided
|
||||||
if catalog_name:
|
if catalog_name:
|
||||||
await connection.execute(f"USE `{catalog_name}`")
|
await connection.execute(f"SWITCH `{catalog_name}`")
|
||||||
if db_name:
|
if db_name:
|
||||||
await connection.execute(f"USE `{db_name}`")
|
await connection.execute(f"USE `{db_name}`")
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,7 @@ import time
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, Optional
|
||||||
import random
|
|
||||||
|
|
||||||
import aiomysql
|
import aiomysql
|
||||||
from aiomysql import Connection, Pool
|
from aiomysql import Connection, Pool
|
||||||
@@ -191,6 +190,51 @@ class DorisConnection:
|
|||||||
logging.error(f"Error occurred while closing connection: {e}")
|
logging.error(f"Error occurred while closing connection: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class DorisSessionCache:
|
||||||
|
"""Doris database session cache
|
||||||
|
|
||||||
|
Save doris session in memory and get session by session id.
|
||||||
|
Provide cache_system_session/cache_user_session to specify whether to save system/user type sessions.
|
||||||
|
By default, only session_id is "query" or "system" will be saved.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, connection_manager=None, cache_system_session=True, cache_user_session=False):
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self.cached = {}
|
||||||
|
self.connection_manager = connection_manager
|
||||||
|
self.cache_system_session = cache_system_session
|
||||||
|
self.cache_user_session = cache_user_session
|
||||||
|
self.logger.info(f"Session Cache initialized, save system session: {self.cache_system_session}, save user session: {self.cache_user_session}")
|
||||||
|
|
||||||
|
def save(self, connection: DorisConnection):
|
||||||
|
if self._should_cache(connection.session_id):
|
||||||
|
self.cached[connection.session_id] = connection
|
||||||
|
|
||||||
|
def get(self, session_id: str) -> Optional[DorisConnection]:
|
||||||
|
self.logger.debug(f"Use cached connection: {session_id}")
|
||||||
|
return self.cached.get(session_id)
|
||||||
|
|
||||||
|
def remove(self, session_id):
|
||||||
|
if session_id in self.cached:
|
||||||
|
del self.cached[session_id]
|
||||||
|
self.logger.debug(f"Removed session {session_id} from cache.")
|
||||||
|
else:
|
||||||
|
if self._should_cache(session_id):
|
||||||
|
self.logger.warning(f"Session {session_id} is not existed.")
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
if self.connection_manager:
|
||||||
|
for k, v in self.cached.items():
|
||||||
|
self.connection_manager.release_connection(k, v)
|
||||||
|
self.cached = {}
|
||||||
|
|
||||||
|
def _is_system_session(self, session_id) -> bool:
|
||||||
|
return session_id in ["query", "system"]
|
||||||
|
|
||||||
|
def _should_cache(self, session_id):
|
||||||
|
return (self.cache_system_session and self._is_system_session(session_id)) or (self.cache_user_session and not self._is_system_session(session_id))
|
||||||
|
|
||||||
|
|
||||||
class DorisConnectionManager:
|
class DorisConnectionManager:
|
||||||
"""Doris database connection manager - Enhanced Strategy
|
"""Doris database connection manager - Enhanced Strategy
|
||||||
|
|
||||||
@@ -198,11 +242,13 @@ class DorisConnectionManager:
|
|||||||
Implements connection pool health monitoring and proactive cleanup
|
Implements connection pool health monitoring and proactive cleanup
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, config, security_manager=None):
|
def __init__(self, config, security_manager=None):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.pool: Pool | None = None
|
self.pool: Pool | None = None
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
self.security_manager = security_manager
|
self.security_manager = security_manager
|
||||||
|
self.session_cache = DorisSessionCache(self)
|
||||||
|
|
||||||
# Connection pool state management
|
# Connection pool state management
|
||||||
self.pool_recovering = False
|
self.pool_recovering = False
|
||||||
@@ -521,8 +567,13 @@ class DorisConnectionManager:
|
|||||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||||
"""🔧 FIX: Simplified connection acquisition without double locking
|
"""🔧 FIX: Simplified connection acquisition without double locking
|
||||||
|
|
||||||
Uses only semaphore to prevent too many concurrent acquisitions
|
Uses only semaphore to prevent too many concurrent acquisitions.
|
||||||
|
If the connection is successfully obtained, it will be added to the connection pool cache.
|
||||||
"""
|
"""
|
||||||
|
cached_conn = self.session_cache.get(session_id)
|
||||||
|
if cached_conn:
|
||||||
|
return cached_conn
|
||||||
|
|
||||||
# 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking)
|
# 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking)
|
||||||
async with self._connection_semaphore:
|
async with self._connection_semaphore:
|
||||||
try:
|
try:
|
||||||
@@ -582,6 +633,8 @@ class DorisConnectionManager:
|
|||||||
raise RuntimeError("Acquired connection is already closed")
|
raise RuntimeError("Acquired connection is already closed")
|
||||||
|
|
||||||
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
|
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
|
||||||
|
|
||||||
|
self.session_cache.save(doris_conn)
|
||||||
return doris_conn
|
return doris_conn
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -590,6 +643,13 @@ class DorisConnectionManager:
|
|||||||
|
|
||||||
async def release_connection(self, session_id: str, connection: DorisConnection):
|
async def release_connection(self, session_id: str, connection: DorisConnection):
|
||||||
"""🔧 FIX: Release connection back to pool with proper error handling"""
|
"""🔧 FIX: Release connection back to pool with proper error handling"""
|
||||||
|
cached_conn = self.session_cache.get(session_id)
|
||||||
|
if cached_conn:
|
||||||
|
self.session_cache.remove(session_id)
|
||||||
|
if not (cached_conn is connection):
|
||||||
|
self.logger.warning("Invalid connection")
|
||||||
|
connection = cached_conn
|
||||||
|
|
||||||
if not connection or not connection.connection:
|
if not connection or not connection.connection:
|
||||||
self.logger.debug(f"No connection to release for session {session_id}")
|
self.logger.debug(f"No connection to release for session {session_id}")
|
||||||
return
|
return
|
||||||
|
|||||||
78
test/utils/test_db.py
Normal file
78
test/utils/test_db.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session_cache():
|
||||||
|
"""Provides a DorisSessionCache instance with a mock connection manager."""
|
||||||
|
connection_manager = MagicMock()
|
||||||
|
cache = DorisSessionCache(connection_manager=connection_manager)
|
||||||
|
yield cache, connection_manager
|
||||||
|
|
||||||
|
|
||||||
|
class TestDorisSessionCache:
|
||||||
|
|
||||||
|
def test_initialization(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
assert cache.cache_system_session is True
|
||||||
|
assert cache.cache_user_session is False
|
||||||
|
assert not cache.cached
|
||||||
|
|
||||||
|
def test_should_cache(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
assert cache._should_cache("query") is True
|
||||||
|
assert cache._should_cache("system") is True
|
||||||
|
assert cache._should_cache("user-test-session-id") is False
|
||||||
|
|
||||||
|
cache.cache_user_session = True
|
||||||
|
assert cache._should_cache("user-test-session-id") is True
|
||||||
|
|
||||||
|
def test_save_and_get_session(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
mock_connection = MagicMock(spec=DorisConnection)
|
||||||
|
mock_connection.session_id = "query"
|
||||||
|
|
||||||
|
cache.save(mock_connection)
|
||||||
|
retrieved_conn = cache.get("query")
|
||||||
|
assert retrieved_conn is mock_connection
|
||||||
|
|
||||||
|
mock_user_connection = MagicMock(spec=DorisConnection)
|
||||||
|
mock_user_connection.session_id = "user-test-session-id"
|
||||||
|
cache.save(mock_user_connection)
|
||||||
|
assert cache.get("user-test-session-id") is None
|
||||||
|
|
||||||
|
cache.cache_user_session = True
|
||||||
|
cache.save(mock_user_connection)
|
||||||
|
retrieved_user_conn = cache.get("user-test-session-id")
|
||||||
|
assert retrieved_user_conn is mock_user_connection
|
||||||
|
|
||||||
|
def test_remove_session(self, session_cache):
|
||||||
|
cache, _ = session_cache
|
||||||
|
mock_connection = MagicMock(spec=DorisConnection)
|
||||||
|
mock_connection.session_id = "system"
|
||||||
|
|
||||||
|
cache.save(mock_connection)
|
||||||
|
assert cache.get("system") is not None
|
||||||
|
|
||||||
|
cache.remove("system")
|
||||||
|
assert cache.get("system") is None
|
||||||
|
|
||||||
|
def test_clear_cache(self, session_cache):
|
||||||
|
cache, connection_manager = session_cache
|
||||||
|
mock_conn1 = MagicMock(spec=DorisConnection)
|
||||||
|
mock_conn1.session_id = "query"
|
||||||
|
mock_conn2 = MagicMock(spec=DorisConnection)
|
||||||
|
mock_conn2.session_id = "system"
|
||||||
|
|
||||||
|
cache.save(mock_conn1)
|
||||||
|
cache.save(mock_conn2)
|
||||||
|
assert len(cache.cached) == 2
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
assert not cache.cached
|
||||||
|
connection_manager.release_connection.assert_any_call("query", mock_conn1)
|
||||||
|
connection_manager.release_connection.assert_any_call("system", mock_conn2)
|
||||||
|
assert connection_manager.release_connection.call_count == 2
|
||||||
Reference in New Issue
Block a user