|
34 | 34 | ) |
35 | 35 |
|
36 | 36 | from bson import DEFAULT_CODEC_OPTIONS |
37 | | -from pymongo import _csot, helpers_shared |
| 37 | +from pymongo import _csot, helpers_shared, network_layer |
38 | 38 | from pymongo.asynchronous.client_session import _validate_session_write_concern |
39 | 39 | from pymongo.asynchronous.helpers import _handle_reauth |
40 | 40 | from pymongo.asynchronous.network import command |
@@ -188,6 +188,41 @@ def __init__( |
188 | 188 | self.creation_time = time.monotonic() |
189 | 189 | # For gossiping $clusterTime from the connection handshake to the client. |
190 | 190 | self._cluster_time = None |
| 191 | + self.pending_response = False |
| 192 | + self.pending_bytes = 0 |
| 193 | + self.pending_deadline = 0.0 |
| 194 | + |
| 195 | + def mark_pending(self, nbytes: int) -> None: |
| 196 | + """Mark this connection as having a pending response.""" |
| 197 | + # TODO: add "if self.enable_pending:" |
| 198 | + self.pending_response = True |
| 199 | + self.pending_bytes = nbytes |
| 200 | + self.pending_deadline = time.monotonic() + 3 # 3 seconds timeout for pending response |
| 201 | + |
| 202 | + async def complete_pending(self) -> None: |
| 203 | + """Complete a pending response.""" |
| 204 | + if not self.pending_response: |
| 205 | + return |
| 206 | + |
| 207 | + timeout: Optional[Union[float, int]] |
| 208 | + timeout = self.conn.gettimeout |
| 209 | + if _csot.get_timeout(): |
| 210 | + deadline = min(_csot.get_deadline(), self.pending_deadline) |
| 211 | + elif timeout: |
| 212 | + deadline = min(time.monotonic() + timeout, self.pending_deadline) |
| 213 | + else: |
| 214 | + deadline = self.pending_deadline |
| 215 | + |
| 216 | + if not _IS_SYNC: |
| 217 | + # In async the reader task reads the whole message at once. |
| 218 | + # TODO: respect deadline |
| 219 | + await self.receive_message(None, True) |
| 220 | + else: |
| 221 | + # In sync we need to track the bytes left for the message. |
| 222 | + network_layer.receive_data(self.conn.get_conn, self.pending_byte, deadline) |
| 223 | + self.pending_response = False |
| 224 | + self.pending_bytes = 0 |
| 225 | + self.pending_deadline = 0.0 |
191 | 226 |
|
192 | 227 | def set_conn_timeout(self, timeout: Optional[float]) -> None: |
193 | 228 | """Cache last timeout to avoid duplicate calls to conn.settimeout.""" |
@@ -454,13 +489,17 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None: |
454 | 489 | except BaseException as error: |
455 | 490 | await self._raise_connection_failure(error) |
456 | 491 |
|
457 | | - async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]: |
| 492 | + async def receive_message( |
| 493 | + self, request_id: Optional[int], enable_pending: bool = False |
| 494 | + ) -> Union[_OpReply, _OpMsg]: |
458 | 495 | """Receive a raw BSON message or raise ConnectionFailure. |
459 | 496 |
|
460 | 497 | If any exception is raised, the socket is closed. |
461 | 498 | """ |
462 | 499 | try: |
463 | | - return await async_receive_message(self, request_id, self.max_message_size) |
| 500 | + return await async_receive_message( |
| 501 | + self, request_id, self.max_message_size, enable_pending |
| 502 | + ) |
464 | 503 | # Catch KeyboardInterrupt, CancelledError, etc. and cleanup. |
465 | 504 | except BaseException as error: |
466 | 505 | await self._raise_connection_failure(error) |
@@ -495,7 +534,9 @@ async def write_command( |
495 | 534 | :param msg: bytes, the command message. |
496 | 535 | """ |
497 | 536 | await self.send_message(msg, 0) |
498 | | - reply = await self.receive_message(request_id) |
| 537 | + reply = await self.receive_message( |
| 538 | + request_id, enable_pending=(_csot.get_timeout() is not None) |
| 539 | + ) |
499 | 540 | result = reply.command_response(codec_options) |
500 | 541 |
|
501 | 542 | # Raises NotPrimaryError or OperationFailure. |
@@ -635,7 +676,10 @@ async def _raise_connection_failure(self, error: BaseException) -> NoReturn: |
635 | 676 | reason = None |
636 | 677 | else: |
637 | 678 | reason = ConnectionClosedReason.ERROR |
638 | | - await self.close_conn(reason) |
| 679 | + |
| 680 | + # Pending connections should be placed back in the pool. |
| 681 | + if not self.pending_response: |
| 682 | + await self.close_conn(reason) |
639 | 683 | # SSLError from PyOpenSSL inherits directly from Exception. |
640 | 684 | if isinstance(error, (IOError, OSError, SSLError)): |
641 | 685 | details = _get_timeout_details(self.opts) |
@@ -1076,7 +1120,7 @@ async def checkout( |
1076 | 1120 |
|
1077 | 1121 | This method should always be used in a with-statement:: |
1078 | 1122 |
|
1079 | | - with pool.get_conn() as connection: |
| 1123 | + with pool.checkout() as connection: |
1080 | 1124 | connection.send_message(msg) |
1081 | 1125 | data = connection.receive_message(op_code, request_id) |
1082 | 1126 |
|
@@ -1388,6 +1432,7 @@ async def _perished(self, conn: AsyncConnection) -> bool: |
1388 | 1432 | pool, to keep performance reasonable - we can't avoid AutoReconnects |
1389 | 1433 | completely anyway. |
1390 | 1434 | """ |
| 1435 | + await conn.complete_pending() |
1391 | 1436 | idle_time_seconds = conn.idle_time_seconds() |
1392 | 1437 | # If socket is idle, open a new one. |
1393 | 1438 | if ( |
|
0 commit comments