@@ -334,6 +334,7 @@ def __init__(
334334 encoding : str = "utf-8" ,
335335 encoding_errors : str = "strict" ,
336336 decode_responses : bool = False ,
337+ check_server_ready : bool = False ,
337338 parser_class = DefaultParser ,
338339 socket_read_size : int = 65536 ,
339340 health_check_interval : int = 0 ,
@@ -412,6 +413,7 @@ def __init__(
412413 self .redis_connect_func = redis_connect_func
413414 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
414415 self .handshake_metadata = None
416+ self .check_server_ready = check_server_ready
415417 self ._sock = None
416418 self ._socket_read_size = socket_read_size
417419 self ._connect_callbacks = []
@@ -575,17 +577,17 @@ def connect_check_health(
575577 return
576578 try :
577579 if retry_socket_connect :
578- sock = self .retry .call_with_retry (
579- lambda : self ._connect (), lambda error : self .disconnect (error )
580+ self .retry .call_with_retry (
581+ lambda : self ._connect_check_server_ready (),
582+ lambda error : self .disconnect (error ),
580583 )
581584 else :
582- sock = self ._connect ()
585+ self ._connect_check_server_ready ()
583586 except socket .timeout :
584587 raise TimeoutError ("Timeout connecting to server" )
585588 except OSError as e :
586589 raise ConnectionError (self ._error_message (e ))
587590
588- self ._sock = sock
589591 try :
590592 if self .redis_connect_func is None :
591593 # Use the default on_connect function
@@ -607,8 +609,27 @@ def connect_check_health(
607609 if callback :
608610 callback (self )
609611
612+ def _connect_check_server_ready (self ):
613+ self ._connect ()
614+
615+ # Doing handshake since connect and send operations work even when Redis is not ready
616+ if self .check_server_ready :
617+ try :
618+ self .send_command ("PING" , check_health = False )
619+
620+ response = str_if_bytes (self ._sock .recv (1024 ))
621+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
622+ raise ResponseError (f"Invalid PING response: { response } " )
623+ except (ConnectionResetError , ResponseError ) as err :
624+ try :
625+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
626+ except OSError :
627+ pass
628+ self ._sock .close ()
629+ raise ConnectionError (self ._error_message (err ))
630+
610631 @abstractmethod
611- def _connect (self ):
632+ def _connect (self ) -> None :
612633 pass
613634
614635 @abstractmethod
@@ -1097,7 +1118,7 @@ def repr_pieces(self):
10971118 pieces .append (("client_name" , self .client_name ))
10981119 return pieces
10991120
1100- def _connect (self ):
1121+ def _connect (self ) -> None :
11011122 "Create a TCP socket connection"
11021123 # we want to mimic what socket.create_connection does to support
11031124 # ipv4/ipv6, but we want to set options prior to calling
@@ -1128,7 +1149,8 @@ def _connect(self):
11281149
11291150 # set the socket_timeout now that we're connected
11301151 sock .settimeout (self .socket_timeout )
1131- return sock
1152+ self ._sock = sock
1153+ return
11321154
11331155 except OSError as _ :
11341156 err = _
@@ -1448,15 +1470,15 @@ def __init__(
14481470 self .ssl_ciphers = ssl_ciphers
14491471 super ().__init__ (** kwargs )
14501472
1451- def _connect (self ):
1473+ def _connect (self ) -> None :
14521474 """
14531475 Wrap the socket with SSL support, handling potential errors.
14541476 """
1455- sock = super ()._connect ()
1477+ super ()._connect ()
14561478 try :
1457- return self ._wrap_socket_with_ssl (sock )
1479+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
14581480 except (OSError , RedisError ):
1459- sock .close ()
1481+ self . _sock .close ()
14601482 raise
14611483
14621484 def _wrap_socket_with_ssl (self , sock ):
@@ -1559,7 +1581,7 @@ def repr_pieces(self):
15591581 pieces .append (("client_name" , self .client_name ))
15601582 return pieces
15611583
1562- def _connect (self ):
1584+ def _connect (self ) -> None :
15631585 "Create a Unix domain socket connection"
15641586 sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
15651587 sock .settimeout (self .socket_connect_timeout )
@@ -1574,7 +1596,7 @@ def _connect(self):
15741596 sock .close ()
15751597 raise
15761598 sock .settimeout (self .socket_timeout )
1577- return sock
1599+ self . _sock = sock
15781600
15791601 def _host_error (self ):
15801602 return self .path
0 commit comments