[feature]Implement session cache for Doris connections (#44)

* [feature]Implement session cache for Doris connections

This PR introduces a `DorisSessionCache` to cache and reuse `DorisConnection` objects in memory.
This reduces the overhead of creating new connections, especially for frequently used system sessions like "query"
and "system". It also avoids the `Connection acquisition timed out` error that occurs when callers never call
`release_connection` and the number of connections in the pool reaches its maximum.

PR #34 fixed this issue for the `exec_query` tool, but many other tools in the codebase call
`get_connection("query")` directly to obtain a connection object and never call `release_connection`, so
acquiring a connection fails after a certain number of calls.

Key changes:
- Added `DorisSessionCache` class to manage the lifecycle of cached sessions.
- The cache is configurable to store system sessions, user sessions, or both. By default, only system sessions are cached.
- Integrated the session cache into `DorisConnectionManager`.
- `get_connection` now checks the cache before creating a new connection.
- `release_connection` removes the connection from the cache.

* Add tests
This commit is contained in:
ivin
2025-08-11 13:39:30 +08:00
committed by GitHub
parent 55dbdd5e14
commit cc84d605e5
3 changed files with 144 additions and 6 deletions

View File

@@ -516,7 +516,7 @@ class SQLAnalyzer:
try: try:
# Switch to specified database/catalog if provided # Switch to specified database/catalog if provided
if catalog_name: if catalog_name:
await connection.execute(f"USE `{catalog_name}`") await connection.execute(f"SWITCH `{catalog_name}`")
if db_name: if db_name:
await connection.execute(f"USE `{db_name}`") await connection.execute(f"USE `{db_name}`")

View File

@@ -28,8 +28,7 @@ import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import Any, Dict, Optional
import random
import aiomysql import aiomysql
from aiomysql import Connection, Pool from aiomysql import Connection, Pool
@@ -191,6 +190,51 @@ class DorisConnection:
logging.error(f"Error occurred while closing connection: {e}") logging.error(f"Error occurred while closing connection: {e}")
class DorisSessionCache:
    """In-memory cache of Doris sessions, keyed by session id.

    ``cache_system_session`` and ``cache_user_session`` select which kinds of
    session are stored. By default only the system sessions (session ids
    "query" and "system") are cached.
    """

    # Session ids that are treated as system sessions.
    SYSTEM_SESSION_IDS = ("query", "system")

    def __init__(self, connection_manager=None, cache_system_session=True, cache_user_session=False):
        """Create an empty cache.

        Args:
            connection_manager: Optional object whose
                ``release_connection(session_id, connection)`` is called for
                every cached entry by :meth:`clear`.
            cache_system_session: Whether system sessions are cached.
            cache_user_session: Whether non-system sessions are cached.
        """
        self.logger = get_logger(__name__)
        self.cached = {}
        self.connection_manager = connection_manager
        self.cache_system_session = cache_system_session
        self.cache_user_session = cache_user_session
        # Strong references to release tasks scheduled by clear(); the event
        # loop only keeps weak references to tasks.
        self._pending_releases = set()
        self.logger.info(f"Session Cache initialized, save system session: {self.cache_system_session}, save user session: {self.cache_user_session}")

    def save(self, connection: "DorisConnection"):
        """Store ``connection`` under its session id if the policy allows it."""
        if self._should_cache(connection.session_id):
            self.cached[connection.session_id] = connection

    def get(self, session_id: str) -> Optional["DorisConnection"]:
        """Return the cached connection for ``session_id``, or ``None``."""
        connection = self.cached.get(session_id)
        # Only report a cache hit when there actually is one.
        if connection is not None:
            self.logger.debug(f"Use cached connection: {session_id}")
        return connection

    def remove(self, session_id):
        """Drop ``session_id`` from the cache.

        A warning is logged when the id is missing but would have been
        cached under the current policy.
        """
        if session_id in self.cached:
            del self.cached[session_id]
            self.logger.debug(f"Removed session {session_id} from cache.")
        elif self._should_cache(session_id):
            self.logger.warning(f"Session {session_id} is not existed.")

    def clear(self):
        """Empty the cache and release every cached connection.

        The cache is reset before any release is issued, so a manager that
        removes entries from this cache while releasing does not mutate a
        dict that is being iterated.

        When ``release_connection`` returns an awaitable (it is a coroutine
        function on the connection manager), the awaitable is scheduled on
        the running event loop; merely calling it would never execute the
        release. Without a running loop the release cannot be executed from
        this synchronous method, so the coroutine is closed and a warning is
        logged.
        """
        import asyncio
        import inspect

        entries = list(self.cached.items())
        self.cached = {}
        if not self.connection_manager:
            return

        for session_id, connection in entries:
            outcome = self.connection_manager.release_connection(session_id, connection)
            if not inspect.isawaitable(outcome):
                # Synchronous manager: the release already happened.
                continue
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                if inspect.iscoroutine(outcome):
                    # Avoid a "coroutine was never awaited" RuntimeWarning.
                    outcome.close()
                self.logger.warning(
                    f"No running event loop; connection for session {session_id} was not released."
                )
                continue
            task = asyncio.ensure_future(outcome)
            self._pending_releases.add(task)
            task.add_done_callback(self._pending_releases.discard)

    def _is_system_session(self, session_id) -> bool:
        """Return True when ``session_id`` is one of the system session ids."""
        return session_id in self.SYSTEM_SESSION_IDS

    def _should_cache(self, session_id):
        """Return whether a session with ``session_id`` may be cached."""
        is_system = self._is_system_session(session_id)
        return (self.cache_system_session and is_system) or (self.cache_user_session and not is_system)
class DorisConnectionManager: class DorisConnectionManager:
"""Doris database connection manager - Enhanced Strategy """Doris database connection manager - Enhanced Strategy
@@ -198,11 +242,13 @@ class DorisConnectionManager:
Implements connection pool health monitoring and proactive cleanup Implements connection pool health monitoring and proactive cleanup
""" """
def __init__(self, config, security_manager=None): def __init__(self, config, security_manager=None):
self.config = config self.config = config
self.pool: Pool | None = None self.pool: Pool | None = None
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.security_manager = security_manager self.security_manager = security_manager
self.session_cache = DorisSessionCache(self)
# Connection pool state management # Connection pool state management
self.pool_recovering = False self.pool_recovering = False
@@ -521,8 +567,13 @@ class DorisConnectionManager:
async def get_connection(self, session_id: str) -> DorisConnection: async def get_connection(self, session_id: str) -> DorisConnection:
"""🔧 FIX: Simplified connection acquisition without double locking """🔧 FIX: Simplified connection acquisition without double locking
Uses only semaphore to prevent too many concurrent acquisitions Uses only semaphore to prevent too many concurrent acquisitions.
If the connection is successfully obtained, it will be added to the connection pool cache.
""" """
cached_conn = self.session_cache.get(session_id)
if cached_conn:
return cached_conn
# 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking) # 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking)
async with self._connection_semaphore: async with self._connection_semaphore:
try: try:
@@ -582,6 +633,8 @@ class DorisConnectionManager:
raise RuntimeError("Acquired connection is already closed") raise RuntimeError("Acquired connection is already closed")
self.logger.debug(f"✅ Acquired fresh connection for session {session_id}") self.logger.debug(f"✅ Acquired fresh connection for session {session_id}")
self.session_cache.save(doris_conn)
return doris_conn return doris_conn
except Exception as e: except Exception as e:
@@ -590,6 +643,13 @@ class DorisConnectionManager:
async def release_connection(self, session_id: str, connection: DorisConnection): async def release_connection(self, session_id: str, connection: DorisConnection):
"""🔧 FIX: Release connection back to pool with proper error handling""" """🔧 FIX: Release connection back to pool with proper error handling"""
cached_conn = self.session_cache.get(session_id)
if cached_conn:
self.session_cache.remove(session_id)
if not (cached_conn is connection):
self.logger.warning("Invalid connection")
connection = cached_conn
if not connection or not connection.connection: if not connection or not connection.connection:
self.logger.debug(f"No connection to release for session {session_id}") self.logger.debug(f"No connection to release for session {session_id}")
return return

78
test/utils/test_db.py Normal file
View File

@@ -0,0 +1,78 @@
from unittest.mock import MagicMock
import pytest
from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache
@pytest.fixture
def session_cache():
    """Yield a (cache, manager) pair: a DorisSessionCache wired to a mocked connection manager."""
    manager_mock = MagicMock()
    yield DorisSessionCache(connection_manager=manager_mock), manager_mock
class TestDorisSessionCache:
    """Unit tests for DorisSessionCache's caching policy and bookkeeping."""

    @staticmethod
    def _connection(session_id):
        """Build a DorisConnection-shaped mock carrying ``session_id``."""
        conn = MagicMock(spec=DorisConnection)
        conn.session_id = session_id
        return conn

    def test_initialization(self, session_cache):
        """A fresh cache stores system sessions only and starts empty."""
        cache = session_cache[0]
        assert cache.cache_system_session is True
        assert cache.cache_user_session is False
        assert not cache.cached

    def test_should_cache(self, session_cache):
        """System ids are always cacheable; user ids only once enabled."""
        cache = session_cache[0]
        user_id = "user-test-session-id"

        assert cache._should_cache("query") is True
        assert cache._should_cache("system") is True
        assert cache._should_cache(user_id) is False

        cache.cache_user_session = True
        assert cache._should_cache(user_id) is True

    def test_save_and_get_session(self, session_cache):
        """Saved sessions are returned by id; user sessions need opt-in."""
        cache = session_cache[0]
        user_id = "user-test-session-id"

        system_conn = self._connection("query")
        cache.save(system_conn)
        assert cache.get("query") is system_conn

        # With the default policy a user session is silently not stored.
        user_conn = self._connection(user_id)
        cache.save(user_conn)
        assert cache.get(user_id) is None

        # After enabling user-session caching the same save is kept.
        cache.cache_user_session = True
        cache.save(user_conn)
        assert cache.get(user_id) is user_conn

    def test_remove_session(self, session_cache):
        """Removing a cached id makes later lookups return None."""
        cache = session_cache[0]
        cache.save(self._connection("system"))
        assert cache.get("system") is not None

        cache.remove("system")
        assert cache.get("system") is None

    def test_clear_cache(self, session_cache):
        """clear() empties the cache and releases each entry via the manager."""
        cache, manager = session_cache
        connections = [self._connection("query"), self._connection("system")]
        for conn in connections:
            cache.save(conn)
        assert len(cache.cached) == 2

        cache.clear()

        assert not cache.cached
        for conn in connections:
            manager.release_connection.assert_any_call(conn.session_id, conn)
        assert manager.release_connection.call_count == 2