[feature]Implement session cache for Doris connections (#44)
* [feature] Implement session cache for Doris connections. This PR introduces a `DorisSessionCache` to cache and reuse `DorisConnection` objects in memory. This reduces the overhead of creating new connections, especially for frequently used system sessions such as "query" and "system", and prevents the `Connection acquisition timed out` error that occurs when connections are never returned via `release_connection` and the connection pool reaches its maximum size. PR #34 fixed this issue for the `exec_query` tool, but many other tools in the codebase call `get_connection("query")` directly to obtain a connection object without ever calling `release_connection`, so after a certain number of calls no further connections can be obtained. Key changes: - Added the `DorisSessionCache` class to manage the lifecycle of cached sessions. - The cache can be configured to store system sessions, user sessions, or both. By default, only system sessions are cached. - Integrated the session cache into `DorisConnectionManager`. - `get_connection` now checks the cache before creating a new connection. - `release_connection` removes the connection from the cache. * Add tests
This commit is contained in:
@@ -516,7 +516,7 @@ class SQLAnalyzer:
|
||||
try:
|
||||
# Switch to specified database/catalog if provided
|
||||
if catalog_name:
|
||||
await connection.execute(f"USE `{catalog_name}`")
|
||||
await connection.execute(f"SWITCH `{catalog_name}`")
|
||||
if db_name:
|
||||
await connection.execute(f"USE `{db_name}`")
|
||||
|
||||
@@ -1237,4 +1237,4 @@ class MemoryTracker:
|
||||
"tracker_names": tracker_names,
|
||||
"time_range": time_range,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,8 +28,7 @@ import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiomysql
|
||||
from aiomysql import Connection, Pool
|
||||
@@ -191,6 +190,51 @@ class DorisConnection:
|
||||
logging.error(f"Error occurred while closing connection: {e}")
|
||||
|
||||
|
||||
class DorisSessionCache:
    """In-memory cache of Doris sessions, keyed by session id.

    ``save`` stores a connection under its ``session_id`` and ``get`` hands
    the same object back for later requests with that id.

    ``cache_system_session`` / ``cache_user_session`` select which kinds of
    session are stored. A session counts as a system session when its id is
    ``"query"`` or ``"system"``; every other id is treated as a user session.
    By default only system sessions are cached.
    """

    # Session ids that identify system sessions.
    _SYSTEM_SESSION_IDS = ("query", "system")

    def __init__(self, connection_manager=None, cache_system_session=True, cache_user_session=False):
        """Create an empty cache.

        Args:
            connection_manager: Object whose ``release_connection(session_id,
                connection)`` is called for every cached session by ``clear``.
                When it is falsy, ``clear`` only empties the cache.
            cache_system_session: Whether system sessions are stored.
            cache_user_session: Whether user sessions are stored.
        """
        self.logger = get_logger(__name__)
        self.cached = {}
        self.connection_manager = connection_manager
        self.cache_system_session = cache_system_session
        self.cache_user_session = cache_user_session
        # Strong references to release tasks scheduled by clear(); the event
        # loop itself only keeps weak references to running tasks.
        self._release_tasks = set()
        self.logger.info(f"Session Cache initialized, save system session: {self.cache_system_session}, save user session: {self.cache_user_session}")

    def save(self, connection: "DorisConnection"):
        """Store ``connection`` under its session id if that kind of session is cached."""
        if self._should_cache(connection.session_id):
            self.cached[connection.session_id] = connection

    def get(self, session_id: str) -> Optional["DorisConnection"]:
        """Return the connection cached for ``session_id``, or ``None`` on a miss."""
        connection = self.cached.get(session_id)
        # Only report a cache hit when there actually is one.
        if connection is not None:
            self.logger.debug(f"Use cached connection: {session_id}")
        return connection

    def remove(self, session_id):
        """Drop ``session_id`` from the cache.

        A missing entry is only worth a warning when sessions of that kind
        are supposed to be cached in the first place.
        """
        if session_id in self.cached:
            del self.cached[session_id]
            self.logger.debug(f"Removed session {session_id} from cache.")
        elif self._should_cache(session_id):
            self.logger.warning(f"Session {session_id} is not existed.")

    def clear(self):
        """Empty the cache and release every cached session through the manager.

        The cache is emptied before anything is released, so the release path
        may call back into ``remove`` without mutating a dict that is being
        iterated.

        ``release_connection`` may be a coroutine function. Because this
        method is synchronous, a returned coroutine is scheduled on the
        running event loop instead of being dropped un-awaited. Without a
        running loop the coroutine cannot be executed; it is closed and a
        warning is logged.
        """
        # Local import keeps this block independent of the module's import list.
        import asyncio

        sessions = list(self.cached.items())
        self.cached = {}
        if not self.connection_manager:
            return

        for session_id, connection in sessions:
            try:
                result = self.connection_manager.release_connection(session_id, connection)
            except Exception as e:
                # One failing release must not leave the other sessions unreleased.
                self.logger.error(f"Failed to release cached session {session_id}: {e}")
                continue

            if not asyncio.iscoroutine(result):
                # Synchronous manager: the release already happened.
                continue

            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop: close the coroutine so it is not reported
                # as "never awaited", and make the skipped release visible.
                result.close()
                self.logger.warning(
                    f"No running event loop, cached session {session_id} was not released"
                )
                continue

            task = loop.create_task(result)
            self._release_tasks.add(task)
            task.add_done_callback(self._on_release_done)

    def _on_release_done(self, task):
        """Forget a finished release task and log the error it ended with, if any."""
        self._release_tasks.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            self.logger.error(f"Failed to release cached session: {error}")

    def _is_system_session(self, session_id) -> bool:
        """Return True when ``session_id`` is one of the system session ids."""
        return session_id in self._SYSTEM_SESSION_IDS

    def _should_cache(self, session_id) -> bool:
        """Return True when sessions of ``session_id``'s kind are cached."""
        if self._is_system_session(session_id):
            return bool(self.cache_system_session)
        return bool(self.cache_user_session)
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager - Enhanced Strategy
|
||||
|
||||
@@ -198,11 +242,13 @@ class DorisConnectionManager:
|
||||
Implements connection pool health monitoring and proactive cleanup
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.pool: Pool | None = None
|
||||
self.logger = get_logger(__name__)
|
||||
self.security_manager = security_manager
|
||||
self.session_cache = DorisSessionCache(self)
|
||||
|
||||
# Connection pool state management
|
||||
self.pool_recovering = False
|
||||
@@ -521,8 +567,13 @@ class DorisConnectionManager:
|
||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||
"""🔧 FIX: Simplified connection acquisition without double locking
|
||||
|
||||
Uses only semaphore to prevent too many concurrent acquisitions
|
||||
Uses only semaphore to prevent too many concurrent acquisitions.
|
||||
If the connection is successfully obtained, it will be added to the connection pool cache.
|
||||
"""
|
||||
cached_conn = self.session_cache.get(session_id)
|
||||
if cached_conn:
|
||||
return cached_conn
|
||||
|
||||
# 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking)
|
||||
async with self._connection_semaphore:
|
||||
try:
|
||||
@@ -582,6 +633,8 @@ class DorisConnectionManager:
|
||||
raise RuntimeError("Acquired connection is already closed")
|
||||
|
||||
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
|
||||
|
||||
self.session_cache.save(doris_conn)
|
||||
return doris_conn
|
||||
|
||||
except Exception as e:
|
||||
@@ -590,6 +643,13 @@ class DorisConnectionManager:
|
||||
|
||||
async def release_connection(self, session_id: str, connection: DorisConnection):
|
||||
"""🔧 FIX: Release connection back to pool with proper error handling"""
|
||||
cached_conn = self.session_cache.get(session_id)
|
||||
if cached_conn:
|
||||
self.session_cache.remove(session_id)
|
||||
if not (cached_conn is connection):
|
||||
self.logger.warning("Invalid connection")
|
||||
connection = cached_conn
|
||||
|
||||
if not connection or not connection.connection:
|
||||
self.logger.debug(f"No connection to release for session {session_id}")
|
||||
return
|
||||
@@ -811,4 +871,4 @@ class ConnectionPoolMonitor:
|
||||
if pool_status["free_connections"] == 0:
|
||||
report["recommendations"].append("No free connections available, consider increasing pool size")
|
||||
|
||||
return report
|
||||
return report
|
||||
|
||||
78
test/utils/test_db.py
Normal file
78
test/utils/test_db.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
|
||||
|
||||
|
||||
@pytest.fixture
def session_cache():
    """Yield a ``(cache, manager)`` pair: a DorisSessionCache wired to a MagicMock manager."""
    manager_mock = MagicMock()
    yield DorisSessionCache(connection_manager=manager_mock), manager_mock
|
||||
|
||||
|
||||
class TestDorisSessionCache:
    """Unit tests for DorisSessionCache, driven by mocked connections."""

    @staticmethod
    def _make_connection(session_id):
        """Build a DorisConnection-shaped mock carrying ``session_id``."""
        conn = MagicMock(spec=DorisConnection)
        conn.session_id = session_id
        return conn

    def test_initialization(self, session_cache):
        """A fresh cache stores system sessions only and starts out empty."""
        cache = session_cache[0]
        assert cache.cache_system_session is True
        assert cache.cache_user_session is False
        assert not cache.cached

    def test_should_cache(self, session_cache):
        """System ids are always cacheable; user ids only once enabled."""
        cache = session_cache[0]
        user_id = "user-test-session-id"

        assert cache._should_cache("query") is True
        assert cache._should_cache("system") is True
        assert cache._should_cache(user_id) is False

        cache.cache_user_session = True
        assert cache._should_cache(user_id) is True

    def test_save_and_get_session(self, session_cache):
        """Saved sessions come back by id; user sessions need the flag set."""
        cache = session_cache[0]
        user_id = "user-test-session-id"

        system_conn = self._make_connection("query")
        cache.save(system_conn)
        assert cache.get("query") is system_conn

        # With default settings a user session is silently not stored.
        user_conn = self._make_connection(user_id)
        cache.save(user_conn)
        assert cache.get(user_id) is None

        cache.cache_user_session = True
        cache.save(user_conn)
        assert cache.get(user_id) is user_conn

    def test_remove_session(self, session_cache):
        """Removing a cached session makes later lookups miss."""
        cache = session_cache[0]
        cache.save(self._make_connection("system"))
        assert cache.get("system") is not None

        cache.remove("system")
        assert cache.get("system") is None

    def test_clear_cache(self, session_cache):
        """Clearing empties the cache and releases each session via the manager."""
        cache, manager = session_cache
        query_conn = self._make_connection("query")
        system_conn = self._make_connection("system")

        for conn in (query_conn, system_conn):
            cache.save(conn)
        assert len(cache.cached) == 2

        cache.clear()

        assert not cache.cached
        release = manager.release_connection
        release.assert_any_call("query", query_conn)
        release.assert_any_call("system", system_conn)
        assert release.call_count == 2
|
||||
Reference in New Issue
Block a user