@@ -26,12 +26,13 @@ class Pool:
2626 '_connect_args' , '_connect_kwargs' ,
2727 '_working_addr' , '_working_opts' ,
2828 '_con_count' , '_max_queries' , '_connections' ,
29- '_initialized' , '_closed' )
29+ '_initialized' , '_closed' , '_setup' )
3030
3131 def __init__ (self , * connect_args ,
3232 min_size ,
3333 max_size ,
3434 max_queries ,
35+ setup ,
3536 loop ,
3637 ** connect_kwargs ):
3738
@@ -55,6 +56,8 @@ def __init__(self, *connect_args,
5556 self ._maxsize = max_size
5657 self ._max_queries = max_queries
5758
59+ self ._setup = setup
60+
5861 self ._connect_args = connect_args
5962 self ._connect_kwargs = connect_kwargs
6063
@@ -65,10 +68,9 @@ def __init__(self, *connect_args,
6568
6669 self ._closed = False
6770
68- async def _new_connection (self , timeout = None ):
71+ async def _new_connection (self ):
6972 if self ._working_addr is None :
7073 con = await connection .connect (* self ._connect_args ,
71- timeout = timeout ,
7274 loop = self ._loop ,
7375 ** self ._connect_kwargs )
7476
@@ -83,7 +85,6 @@ async def _new_connection(self, timeout=None):
8385 host , port = self ._working_addr
8486
8587 con = await connection .connect (host = host , port = port ,
86- timeout = timeout ,
8788 loop = self ._loop ,
8889 ** self ._working_opts )
8990
@@ -134,27 +135,40 @@ def acquire(self, *, timeout=None):
134135 return PoolAcquireContext (self , timeout )
135136
136137 async def _acquire (self , timeout ):
138+ if timeout is None :
139+ return await self ._acquire_impl ()
140+ else :
141+ return await asyncio .wait_for (self ._acquire_impl (),
142+ timeout = timeout ,
143+ loop = self ._loop )
144+
145+ async def _acquire_impl (self ):
137146 self ._check_init ()
138147
139148 try :
140- return self ._queue .get_nowait ()
149+ con = self ._queue .get_nowait ()
141150 except asyncio .QueueEmpty :
142- pass
151+ con = None
152+
153+ if con is None :
154+ if self ._con_count < self ._maxsize :
155+ self ._con_count += 1
156+ try :
157+ con = await self ._new_connection ()
158+ except :
159+ self ._con_count -= 1
160+ raise
161+ else :
162+ con = await self ._queue .get ()
143163
144- if self ._con_count < self ._maxsize :
145- self ._con_count += 1
164+ if self ._setup is not None :
146165 try :
147- con = await self ._new_connection ( timeout = timeout )
166+ await self ._setup ( con )
148167 except :
149- self ._con_count -= 1
168+ await self .release ( con )
150169 raise
151- return con
152170
153- if timeout is None :
154- return await self ._queue .get ()
155- else :
156- return await asyncio .wait_for (self ._queue .get (), timeout = timeout ,
157- loop = self ._loop )
171+ return con
158172
159173 async def release (self , connection ):
160174 """Release a database connection back to the pool."""
@@ -246,6 +260,7 @@ def create_pool(dsn=None, *,
246260 min_size = 10 ,
247261 max_size = 10 ,
248262 max_queries = 50000 ,
263+ setup = None ,
249264 loop = None ,
250265 ** connect_kwargs ):
251266 r"""Create a connection pool.
@@ -281,11 +296,16 @@ def create_pool(dsn=None, *,
281296 :param int max_size: Max number of connections in the pool.
282297 :param int max_queries: Number of queries after a connection is closed
283298 and replaced with a new connection.
299+ :param coroutine setup: A coroutine to initialize a connection right before
300+ it is returned from :meth:`~pool.Pool.acquire`.
301+ An example use case would be to automatically
302+ set up notifications listeners for all connections
303+ of a pool.
284304 :param loop: An asyncio event loop instance. If ``None``, the default
285305 event loop will be used.
286306 :return: An instance of :class:`~asyncpg.pool.Pool`.
287307 """
288308 return Pool (dsn ,
289309 min_size = min_size , max_size = max_size ,
290- max_queries = max_queries , loop = loop ,
310+ max_queries = max_queries , loop = loop , setup = setup ,
291311 ** connect_kwargs )
0 commit comments