99import asyncpg
1010import collections
1111import collections .abc
12+ import contextlib
1213import functools
1314import itertools
1415import inspect
@@ -53,7 +54,7 @@ class Connection(metaclass=ConnectionMeta):
5354 '_intro_query' , '_reset_query' , '_proxy' ,
5455 '_stmt_exclusive_section' , '_config' , '_params' , '_addr' ,
5556 '_log_listeners' , '_termination_listeners' , '_cancellations' ,
56- '_source_traceback' , '__weakref__' )
57+ '_source_traceback' , '_query_loggers' , ' __weakref__' )
5758
5859 def __init__ (self , protocol , transport , loop ,
5960 addr ,
@@ -87,6 +88,7 @@ def __init__(self, protocol, transport, loop,
8788 self ._log_listeners = set ()
8889 self ._cancellations = set ()
8990 self ._termination_listeners = set ()
91+ self ._query_loggers = set ()
9092
9193 settings = self ._protocol .get_settings ()
9294 ver_string = settings .server_version
@@ -224,6 +226,30 @@ def remove_termination_listener(self, callback):
224226 """
225227 self ._termination_listeners .discard (_Callback .from_callable (callback ))
226228
229+ def add_query_logger (self , callback ):
230+ """Add a logger that will be called when queries are executed.
231+
232+ :param callable callback:
233+ A callable or a coroutine function receiving one argument:
234+ **record**: a LoggedQuery containing `query`, `args`, `timeout`,
235+ `elapsed`, `exception`, `conn_addr`, and
236+ `conn_params`.
237+
238+ .. versionadded:: 0.29.0
239+ """
240+ self ._query_loggers .add (_Callback .from_callable (callback ))
241+
242+ def remove_query_logger (self , callback ):
243+ """Remove a query logger callback.
244+
245+ :param callable callback:
246+ The callable or coroutine function that was passed to
247+ :meth:`Connection.add_query_logger`.
248+
249+ .. versionadded:: 0.29.0
250+ """
251+ self ._query_loggers .discard (_Callback .from_callable (callback ))
252+
227253 def get_server_pid (self ):
228254 """Return the PID of the Postgres server the connection is bound to."""
229255 return self ._protocol .get_server_pid ()
@@ -317,7 +343,12 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
317343 self ._check_open ()
318344
319345 if not args :
320- return await self ._protocol .query (query , timeout )
346+ if self ._query_loggers :
347+ with self ._time_and_log (query , args , timeout ):
348+ result = await self ._protocol .query (query , timeout )
349+ else :
350+ result = await self ._protocol .query (query , timeout )
351+ return result
321352
322353 _ , status , _ = await self ._execute (
323354 query ,
@@ -1487,6 +1518,7 @@ def _cleanup(self):
14871518 self ._mark_stmts_as_closed ()
14881519 self ._listeners .clear ()
14891520 self ._log_listeners .clear ()
1521+ self ._query_loggers .clear ()
14901522 self ._clean_tasks ()
14911523
14921524 def _clean_tasks (self ):
@@ -1770,6 +1802,63 @@ async def _execute(
17701802 )
17711803 return result
17721804
1805+ @contextlib .contextmanager
1806+ def query_logger (self , callback ):
1807+ """Context manager that adds `callback` to the list of query loggers,
1808+ and removes it upon exit.
1809+
1810+ :param callable callback:
1811+ A callable or a coroutine function receiving one argument:
1812+ **record**: a LoggedQuery containing `query`, `args`, `timeout`,
1813+ `elapsed`, `exception`, `conn_addr`, and
1814+ `conn_params`.
1815+
1816+ Example:
1817+
1818+ .. code-block:: pycon
1819+
1820+ >>> class QuerySaver:
1821+ def __init__(self):
1822+ self.queries = []
1823+ def __call__(self, record):
1824+ self.queries.append(record.query)
1825+ >>> with con.query_logger(QuerySaver()):
1826+ >>> await con.execute("SELECT 1")
1827+ >>> print(log.queries)
1828+ ['SELECT 1']
1829+
1830+ .. versionadded:: 0.29.0
1831+ """
1832+ self .add_query_logger (callback )
1833+ yield
1834+ self .remove_query_logger (callback )
1835+
1836+ @contextlib .contextmanager
1837+ def _time_and_log (self , query , args , timeout ):
1838+ start = time .monotonic ()
1839+ exception = None
1840+ try :
1841+ yield
1842+ except BaseException as ex :
1843+ exception = ex
1844+ raise
1845+ finally :
1846+ elapsed = time .monotonic () - start
1847+ record = LoggedQuery (
1848+ query = query ,
1849+ args = args ,
1850+ timeout = timeout ,
1851+ elapsed = elapsed ,
1852+ exception = exception ,
1853+ conn_addr = self ._addr ,
1854+ conn_params = self ._params ,
1855+ )
1856+ for cb in self ._query_loggers :
1857+ if cb .is_async :
1858+ self ._loop .create_task (cb .cb (record ))
1859+ else :
1860+ self ._loop .call_soon (cb .cb , record )
1861+
17731862 async def __execute (
17741863 self ,
17751864 query ,
@@ -1790,13 +1879,24 @@ async def __execute(
17901879 timeout = timeout ,
17911880 )
17921881 timeout = self ._protocol ._get_timeout (timeout )
1793- return await self ._do_execute (
1794- query ,
1795- executor ,
1796- timeout ,
1797- record_class = record_class ,
1798- ignore_custom_codec = ignore_custom_codec ,
1799- )
1882+ if self ._query_loggers :
1883+ with self ._time_and_log (query , args , timeout ):
1884+ result , stmt = await self ._do_execute (
1885+ query ,
1886+ executor ,
1887+ timeout ,
1888+ record_class = record_class ,
1889+ ignore_custom_codec = ignore_custom_codec ,
1890+ )
1891+ else :
1892+ result , stmt = await self ._do_execute (
1893+ query ,
1894+ executor ,
1895+ timeout ,
1896+ record_class = record_class ,
1897+ ignore_custom_codec = ignore_custom_codec ,
1898+ )
1899+ return result , stmt
18001900
18011901 async def _executemany (self , query , args , timeout ):
18021902 executor = lambda stmt , timeout : self ._protocol .bind_execute_many (
@@ -1807,7 +1907,8 @@ async def _executemany(self, query, args, timeout):
18071907 )
18081908 timeout = self ._protocol ._get_timeout (timeout )
18091909 with self ._stmt_exclusive_section :
1810- result , _ = await self ._do_execute (query , executor , timeout )
1910+ with self ._time_and_log (query , args , timeout ):
1911+ result , _ = await self ._do_execute (query , executor , timeout )
18111912 return result
18121913
18131914 async def _do_execute (
@@ -2440,6 +2541,13 @@ class _ConnectionProxy:
24402541 __slots__ = ()
24412542
24422543
2544+ LoggedQuery = collections .namedtuple (
2545+ 'LoggedQuery' ,
2546+ ['query' , 'args' , 'timeout' , 'elapsed' , 'exception' , 'conn_addr' ,
2547+ 'conn_params' ])
2548+ LoggedQuery .__doc__ = 'Log record of an executed query.'
2549+
2550+
24432551ServerCapabilities = collections .namedtuple (
24442552 'ServerCapabilities' ,
24452553 ['advisory_locks' , 'notifications' , 'plpgsql' , 'sql_reset' ,
0 commit comments