From cc84d605e5e1417b2259534f861659d31129017c Mon Sep 17 00:00:00 2001 From: ivin Date: Mon, 11 Aug 2025 13:39:30 +0800 Subject: [PATCH] [feature]Implement session cache for Doris connections (#44) * [feature]Implement session cache for Doris connections This PR introduces a `DorisSessionCache` to cache and reuse `DorisConnection` objects in memory. This reduces the overhead of creating new connections, especially for frequently used system sessions like "query" and "system". It also avoids the `Connection acquisition timed out` error that occurs when callers do not call `release_connection` and the number of connections in the pool reaches the maximum value. PR #34 fixed this issue for the `exec_query` tool, but a large number of other tools in the codebase call `get_connection("query")` directly to obtain a connection object without ever calling `release_connection`, which causes connection acquisition to fail after a certain number of calls. Key changes: - Added `DorisSessionCache` class to manage the lifecycle of cached sessions. - The cache is configurable to store system sessions, user sessions, or both. By default, only system sessions are cached. - Integrated the session cache into `DorisConnectionManager`. - `get_connection` now checks the cache before creating a new connection. - `release_connection` removes the connection from the cache. 
* Add tests --- doris_mcp_server/utils/analysis_tools.py | 4 +- doris_mcp_server/utils/db.py | 68 +++++++++++++++++++-- test/utils/test_db.py | 78 ++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 6 deletions(-) create mode 100644 test/utils/test_db.py diff --git a/doris_mcp_server/utils/analysis_tools.py b/doris_mcp_server/utils/analysis_tools.py index 84715e2..b1c9283 100644 --- a/doris_mcp_server/utils/analysis_tools.py +++ b/doris_mcp_server/utils/analysis_tools.py @@ -516,7 +516,7 @@ class SQLAnalyzer: try: # Switch to specified database/catalog if provided if catalog_name: - await connection.execute(f"USE `{catalog_name}`") + await connection.execute(f"SWITCH `{catalog_name}`") if db_name: await connection.execute(f"USE `{db_name}`") @@ -1237,4 +1237,4 @@ class MemoryTracker: "tracker_names": tracker_names, "time_range": time_range, "timestamp": datetime.now().isoformat() - } \ No newline at end of file + } diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py index c71928a..8cc050b 100644 --- a/doris_mcp_server/utils/db.py +++ b/doris_mcp_server/utils/db.py @@ -28,8 +28,7 @@ import time from contextlib import asynccontextmanager from dataclasses import dataclass from datetime import datetime -from typing import Any, Dict, List -import random +from typing import Any, Dict, Optional import aiomysql from aiomysql import Connection, Pool @@ -191,6 +190,51 @@ class DorisConnection: logging.error(f"Error occurred while closing connection: {e}") +class DorisSessionCache: + """Doris database session cache + + Save doris session in memory and get session by session id. + Provide cache_system_session/cache_user_session to specify whether to save system/user type sessions. + By default, only session_id is "query" or "system" will be saved. 
+ """ + + def __init__(self, connection_manager=None, cache_system_session=True, cache_user_session=False): + self.logger = get_logger(__name__) + self.cached = {} + self.connection_manager = connection_manager + self.cache_system_session = cache_system_session + self.cache_user_session = cache_user_session + self.logger.info(f"Session Cache initialized, save system session: {self.cache_system_session}, save user session: {self.cache_user_session}") + + def save(self, connection: DorisConnection): + if self._should_cache(connection.session_id): + self.cached[connection.session_id] = connection + + def get(self, session_id: str) -> Optional[DorisConnection]: + self.logger.debug(f"Use cached connection: {session_id}") + return self.cached.get(session_id) + + def remove(self, session_id): + if session_id in self.cached: + del self.cached[session_id] + self.logger.debug(f"Removed session {session_id} from cache.") + else: + if self._should_cache(session_id): + self.logger.warning(f"Session {session_id} is not existed.") + + def clear(self): + if self.connection_manager: + for k, v in self.cached.items(): + self.connection_manager.release_connection(k, v) + self.cached = {} + + def _is_system_session(self, session_id) -> bool: + return session_id in ["query", "system"] + + def _should_cache(self, session_id): + return (self.cache_system_session and self._is_system_session(session_id)) or (self.cache_user_session and not self._is_system_session(session_id)) + + class DorisConnectionManager: """Doris database connection manager - Enhanced Strategy @@ -198,11 +242,13 @@ class DorisConnectionManager: Implements connection pool health monitoring and proactive cleanup """ + def __init__(self, config, security_manager=None): self.config = config self.pool: Pool | None = None self.logger = get_logger(__name__) self.security_manager = security_manager + self.session_cache = DorisSessionCache(self) # Connection pool state management self.pool_recovering = False @@ -521,8 +567,13 
@@ class DorisConnectionManager: async def get_connection(self, session_id: str) -> DorisConnection: """🔧 FIX: Simplified connection acquisition without double locking - Uses only semaphore to prevent too many concurrent acquisitions + Uses only semaphore to prevent too many concurrent acquisitions. + If the connection is successfully obtained, it will be added to the connection pool cache. """ + cached_conn = self.session_cache.get(session_id) + if cached_conn: + return cached_conn + # 🔧 FIX: Use only semaphore to limit concurrent acquisitions (remove double locking) async with self._connection_semaphore: try: @@ -582,6 +633,8 @@ class DorisConnectionManager: raise RuntimeError("Acquired connection is already closed") self.logger.debug(f"✅ Acquired fresh connection for session {session_id}") + + self.session_cache.save(doris_conn) return doris_conn except Exception as e: @@ -590,6 +643,13 @@ class DorisConnectionManager: async def release_connection(self, session_id: str, connection: DorisConnection): """🔧 FIX: Release connection back to pool with proper error handling""" + cached_conn = self.session_cache.get(session_id) + if cached_conn: + self.session_cache.remove(session_id) + if not (cached_conn is connection): + self.logger.warning("Invalid connection") + connection = cached_conn + if not connection or not connection.connection: self.logger.debug(f"No connection to release for session {session_id}") return @@ -811,4 +871,4 @@ class ConnectionPoolMonitor: if pool_status["free_connections"] == 0: report["recommendations"].append("No free connections available, consider increasing pool size") - return report \ No newline at end of file + return report diff --git a/test/utils/test_db.py b/test/utils/test_db.py new file mode 100644 index 0000000..4ad721a --- /dev/null +++ b/test/utils/test_db.py @@ -0,0 +1,78 @@ +from unittest.mock import MagicMock +import pytest + +from doris_mcp_server.utils.db import DorisConnection, DorisSessionCache + + +@pytest.fixture +def 
session_cache(): + """Provides a DorisSessionCache instance with a mock connection manager.""" + connection_manager = MagicMock() + cache = DorisSessionCache(connection_manager=connection_manager) + yield cache, connection_manager + + +class TestDorisSessionCache: + + def test_initialization(self, session_cache): + cache, _ = session_cache + assert cache.cache_system_session is True + assert cache.cache_user_session is False + assert not cache.cached + + def test_should_cache(self, session_cache): + cache, _ = session_cache + assert cache._should_cache("query") is True + assert cache._should_cache("system") is True + assert cache._should_cache("user-test-session-id") is False + + cache.cache_user_session = True + assert cache._should_cache("user-test-session-id") is True + + def test_save_and_get_session(self, session_cache): + cache, _ = session_cache + mock_connection = MagicMock(spec=DorisConnection) + mock_connection.session_id = "query" + + cache.save(mock_connection) + retrieved_conn = cache.get("query") + assert retrieved_conn is mock_connection + + mock_user_connection = MagicMock(spec=DorisConnection) + mock_user_connection.session_id = "user-test-session-id" + cache.save(mock_user_connection) + assert cache.get("user-test-session-id") is None + + cache.cache_user_session = True + cache.save(mock_user_connection) + retrieved_user_conn = cache.get("user-test-session-id") + assert retrieved_user_conn is mock_user_connection + + def test_remove_session(self, session_cache): + cache, _ = session_cache + mock_connection = MagicMock(spec=DorisConnection) + mock_connection.session_id = "system" + + cache.save(mock_connection) + assert cache.get("system") is not None + + cache.remove("system") + assert cache.get("system") is None + + def test_clear_cache(self, session_cache): + cache, connection_manager = session_cache + mock_conn1 = MagicMock(spec=DorisConnection) + mock_conn1.session_id = "query" + mock_conn2 = MagicMock(spec=DorisConnection) + 
mock_conn2.session_id = "system" + + cache.save(mock_conn1) + cache.save(mock_conn2) + assert len(cache.cached) == 2 + + cache.clear() + + assert not cache.cached + connection_manager.release_connection.assert_any_call("query", mock_conn1) + connection_manager.release_connection.assert_any_call("system", mock_conn2) + assert connection_manager.release_connection.call_count == 2