@@ -42,11 +42,12 @@ class Connection(metaclass=ConnectionMeta):
4242 '_stmt_cache' , '_stmts_to_close' ,
4343 '_addr' , '_opts' , '_command_timeout' , '_listeners' ,
4444 '_server_version' , '_server_caps' , '_intro_query' ,
45- '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
45+ '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
46+ '_ssl_context' )
4647
4748 def __init__ (self , protocol , transport , loop , addr , opts , * ,
4849 statement_cache_size , command_timeout ,
49- max_cached_statement_lifetime ):
50+ max_cached_statement_lifetime , ssl_context ):
5051 self ._protocol = protocol
5152 self ._transport = transport
5253 self ._loop = loop
@@ -58,6 +59,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5859
5960 self ._addr = addr
6061 self ._opts = opts
62+ self ._ssl_context = ssl_context
6163
6264 self ._stmt_cache = _StatementCache (
6365 loop = loop ,
@@ -521,12 +523,24 @@ async def cancel():
521523 r , w = await asyncio .open_unix_connection (
522524 self ._addr , loop = self ._loop )
523525 else :
524- r , w = await asyncio .open_connection (
525- * self ._addr , loop = self ._loop )
526-
527- sock = w .transport .get_extra_info ('socket' )
528- sock .setsockopt (socket .IPPROTO_TCP ,
529- socket .TCP_NODELAY , 1 )
526+ if self ._ssl_context :
527+ sock = await _get_ssl_ready_socket (
528+ * self ._addr , loop = self ._loop )
529+
530+ try :
531+ r , w = await asyncio .open_connection (
532+ sock = sock ,
533+ loop = self ._loop ,
534+ ssl = self ._ssl_context ,
535+ server_hostname = self ._addr [0 ])
536+ except Exception :
537+ sock .close ()
538+ raise
539+
540+ else :
541+ r , w = await asyncio .open_connection (
542+ * self ._addr , loop = self ._loop )
543+ _set_nodelay (_get_socket (w .transport ))
530544
531545 # Pack CancelRequest message
532546 msg = struct .pack ('!llll' , 16 , 80877102 ,
@@ -708,9 +722,10 @@ async def connect(dsn=None, *,
708722 statement_cache_size = 100 ,
709723 max_cached_statement_lifetime = 300 ,
710724 command_timeout = None ,
725+ ssl = None ,
711726 __connection_class__ = Connection ,
712727 ** opts ):
713- """A coroutine to establish a connection to a PostgreSQL server.
728+ r """A coroutine to establish a connection to a PostgreSQL server.
714729
715730 Returns a new :class:`~asyncpg.connection.Connection` object.
716731
@@ -761,6 +776,12 @@ async def connect(dsn=None, *,
761776 the default timeout for operations on this connection
762777 (the default is no timeout).
763778
779+ :param ssl:
780+ pass ``True`` or an `ssl.SSLContext <SSLContext_>`_ instance to
781+ require an SSL connection. If ``True``, a default SSL context
782+ returned by `ssl.create_default_context() <create_default_context_>`_
783+ will be used.
784+
764785 :return: A :class:`~asyncpg.connection.Connection` instance.
765786
766787 Example:
@@ -778,42 +799,51 @@ async def connect(dsn=None, *,
778799
779800 .. versionchanged:: 0.10.0
780801 Added ``max_cached_statement_use_count`` parameter.
802+
803+ .. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
804+ .. _create_default_context: https://docs.python.org/3/library/ssl.html#\
805+ ssl.create_default_context
781806 """
782807 if loop is None :
783808 loop = asyncio .get_event_loop ()
784809
785- host , port , opts = _parse_connect_params (
810+ addrs , opts = _parse_connect_params (
786811 dsn = dsn , host = host , port = port , user = user , password = password ,
787812 database = database , opts = opts )
788813
789- last_ex = None
814+ if ssl :
815+ for addr in addrs :
816+ if isinstance (addr , str ):
817+ # UNIX socket
818+ raise exceptions .InterfaceError (
819+ '`ssl` parameter can only be enabled for TCP addresses, '
820+ 'got a UNIX socket path: {!r}' .format (addr ))
821+
822+ last_error = None
790823 addr = None
791- for h in host :
824+ for addr in addrs :
792825 connected = _create_future (loop )
793- unix = h .startswith ('/' )
794-
795- if unix :
796- # UNIX socket name
797- addr = h
798- if '.s.PGSQL.' not in addr :
799- addr = os .path .join (addr , '.s.PGSQL.{}' .format (port ))
800- conn = loop .create_unix_connection (
801- lambda : protocol .Protocol (addr , connected , opts , loop ),
802- addr )
826+ proto_factory = lambda : protocol .Protocol (addr , connected , opts , loop )
827+
828+ if isinstance (addr , str ):
829+ # UNIX socket
830+ assert ssl is None
831+ connector = loop .create_unix_connection (proto_factory , addr )
832+ elif ssl :
833+ connector = _create_ssl_connection (
834+ proto_factory , * addr , loop = loop , ssl_context = ssl )
803835 else :
804- addr = (h , port )
805- conn = loop .create_connection (
806- lambda : protocol .Protocol (addr , connected , opts , loop ),
807- h , port )
836+ connector = loop .create_connection (proto_factory , * addr )
808837
809838 try :
810- tr , pr = await asyncio .wait_for (conn , timeout = timeout , loop = loop )
811- except (OSError , asyncio .TimeoutError ) as ex :
812- last_ex = ex
839+ tr , pr = await asyncio .wait_for (
840+ connector , timeout = timeout , loop = loop )
841+ except (OSError , asyncio .TimeoutError , ConnectionError ) as ex :
842+ last_error = ex
813843 else :
814844 break
815845 else :
816- raise last_ex
846+ raise last_error
817847
818848 try :
819849 await connected
@@ -825,12 +855,60 @@ async def connect(dsn=None, *,
825855 pr , tr , loop , addr , opts ,
826856 statement_cache_size = statement_cache_size ,
827857 max_cached_statement_lifetime = max_cached_statement_lifetime ,
828- command_timeout = command_timeout )
858+ command_timeout = command_timeout , ssl_context = ssl )
829859
830860 pr .set_connection (con )
831861 return con
832862
833863
864+ async def _get_ssl_ready_socket (host , port , * , loop ):
865+ reader , writer = await asyncio .open_connection (host , port , loop = loop )
866+
867+ tr = writer .transport
868+ try :
869+ sock = _get_socket (tr )
870+ _set_nodelay (sock )
871+
872+ writer .write (struct .pack ('!ll' , 8 , 80877103 )) # SSLRequest message.
873+ await writer .drain ()
874+ resp = await reader .readexactly (1 )
875+
876+ if resp == b'S' :
877+ return sock .dup ()
878+ else :
879+ raise ConnectionError (
880+ 'PostgreSQL server at "{}:{}" rejected SSL upgrade' .format (
881+ host , port ))
882+ finally :
883+ tr .close ()
884+
885+
886+ async def _create_ssl_connection (protocol_factory , host , port , * ,
887+ loop , ssl_context ):
888+ sock = await _get_ssl_ready_socket (host , port , loop = loop )
889+ try :
890+ return await loop .create_connection (
891+ protocol_factory , sock = sock , ssl = ssl_context ,
892+ server_hostname = host )
893+ except Exception :
894+ sock .close ()
895+ raise
896+
897+
898+ def _get_socket (transport ):
899+ sock = transport .get_extra_info ('socket' )
900+ if sock is None :
901+ # Shouldn't happen with any asyncio-complaint event loop.
902+ raise ConnectionError (
903+ 'could not get the socket for transport {!r}' .format (transport ))
904+ return sock
905+
906+
907+ def _set_nodelay (sock ):
908+ if not hasattr (socket , 'AF_UNIX' ) or sock .family != socket .AF_UNIX :
909+ sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
910+
911+
834912class _StatementCacheEntry :
835913
836914 __slots__ = ('_query' , '_statement' , '_cache' , '_cleanup_cb' )
@@ -1116,7 +1194,18 @@ def _parse_connect_params(*, dsn, host, port, user,
11161194 'invalid connection parameter {!r}: {!r} (str expected)'
11171195 .format (param , opts [param ]))
11181196
1119- return host , port , opts
1197+ addrs = []
1198+ for h in host :
1199+ if h .startswith ('/' ):
1200+ # UNIX socket name
1201+ if '.s.PGSQL.' not in h :
1202+ h = os .path .join (h , '.s.PGSQL.{}' .format (port ))
1203+ addrs .append (h )
1204+ else :
1205+ # TCP host/port
1206+ addrs .append ((h , port ))
1207+
1208+ return addrs , opts
11201209
11211210
11221211def _create_future (loop ):
0 commit comments