@@ -45,10 +45,9 @@ def __init__(self, config: MultiDbConfig):
4545 if not config .health_checks
4646 else config .health_checks
4747 )
48-
4948 self ._health_check_interval = config .health_check_interval
50- self ._health_check_policy : HealthCheckPolicy = config . health_check_policy . value (
51- config .health_check_probes , config . health_check_delay
49+ self ._health_check_policy : HealthCheckPolicy = (
50+ config .health_check_policy . value ()
5251 )
5352 self ._failure_detectors = (
5453 config .default_failure_detectors ()
@@ -89,14 +88,25 @@ async def __aenter__(self: "MultiDBClient") -> "MultiDBClient":
8988 await self .initialize ()
9089 return self
9190
92- async def __aexit__ (self , exc_type , exc_value , traceback ):
91+ async def aclose (self ):
92+ # Cancel background tasks
9393 if self ._recurring_hc_task :
9494 self ._recurring_hc_task .cancel ()
9595 if self ._half_open_state_task :
9696 self ._half_open_state_task .cancel ()
9797 for hc_task in self ._hc_tasks :
9898 hc_task .cancel ()
9999
100+ # Close health check connection pools
101+ await self ._health_check_policy .close ()
102+
103+ # Close database client
104+ if self .command_executor .active_database :
105+ await self .command_executor .active_database .client .aclose ()
106+
107+ async def __aexit__ (self , exc_type , exc_value , traceback ):
108+ await self .aclose ()
109+
100110 async def initialize (self ):
101111 """
102112 Perform initialization of databases to define their initial state.
@@ -326,23 +336,15 @@ async def _check_databases_health(self) -> dict[Database, bool]:
326336 Runs health checks as a recurring task.
327337 Runs health checks against all databases.
328338 """
329- try :
330- task_to_db : dict [asyncio .Task , Database ] = {}
339+ task_to_db : dict [asyncio .Task , Database ] = {}
331340
332- self ._hc_tasks = []
333- for database , _ in self ._databases :
334- task = asyncio .create_task (self ._check_db_health (database ))
335- task_to_db [task ] = database
336- self ._hc_tasks .append (task )
341+ self ._hc_tasks = []
342+ for database , _ in self ._databases :
343+ task = asyncio .create_task (self ._check_db_health (database ))
344+ task_to_db [task ] = database
345+ self ._hc_tasks .append (task )
337346
338- results = await asyncio .wait_for (
339- asyncio .gather (* self ._hc_tasks , return_exceptions = True ),
340- timeout = self ._health_check_interval ,
341- )
342- except asyncio .TimeoutError :
343- raise asyncio .TimeoutError (
344- "Health check execution exceeds health_check_interval"
345- )
347+ results = await asyncio .gather (* self ._hc_tasks , return_exceptions = True )
346348
347349 # Map end results to databases
348350 db_results = {
@@ -360,8 +362,6 @@ async def _check_databases_health(self) -> dict[Database, bool]:
360362 )
361363
362364 db_results [unhealthy_db ] = False
363- elif isinstance (result , Exception ):
364- db_results [database ] = False
365365
366366 return db_results
367367
@@ -427,10 +427,6 @@ def _on_circuit_state_change_callback(
427427 if old_state != CBState .CLOSED and new_state == CBState .CLOSED :
428428 logger .info (f"Database { circuit .database } is reachable again." )
429429
430- async def aclose (self ):
431- if self .command_executor .active_database :
432- await self .command_executor .active_database .client .aclose ()
433-
434430
435431def _half_open_circuit (circuit : CircuitBreaker ):
436432 circuit .state = CBState .HALF_OPEN
0 commit comments