efro.rpc
Remote procedure call related functionality.
# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""

from __future__ import annotations

import time
import asyncio
import logging
import weakref
from enum import Enum
from functools import partial
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated, assert_never

from efro.error import (
    CommunicationError,
    is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
    dataclass_to_json,
    dataclass_from_json,
    ioprepped,
    IOAttrs,
)

if TYPE_CHECKING:
    from typing import Literal, Awaitable, Callable

# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
#   payload. Even though we use streams we organize our transmission
#   into 'packets'.
# Message: User data which we transmit using one or more packets.


class _PacketType(Enum):
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5


_BYTE_ORDER: Literal['big'] = 'big'


@ioprepped
@dataclass
class _PeerInfo:
    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]


# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) message/response packets
OUR_PROTOCOL = 2


def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """For debugging SSL Stream connections; returns raw transport info."""
    # Note: accessing internals here so just returning info and not
    # actual objs to reduce potential for breakage.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                return str(raw_transport)
    return '(not found)'


def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Ensure a writer is closed; hacky workaround for odd hang."""
    from threading import Thread

    # Disabling for now..
    if bool(True):
        return

    # Hopefully can remove this in Python 3.11?...
    # see issue with is_closing() below for more details.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                Thread(
                    target=partial(
                        _do_writer_force_close_check, weakref.ref(raw_transport)
                    ),
                    daemon=True,
                ).start()


def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
    try:
        # Attempt to bail as soon as the obj dies.
        # If it hasn't done so by our timeout, force-kill it.
        starttime = time.monotonic()
        while time.monotonic() - starttime < 10.0:
            time.sleep(0.1)
            if transport_weak() is None:
                return
        transport = transport_weak()
        if transport is not None:
            logging.info('Forcing abort on stuck transport %s.', transport)
            transport.abort()
    except Exception:
        logging.warning('Error in writer-force-close-check', exc_info=True)


class _InFlightMessage:
    """Represents a message that is out on the wire."""

    def __init__(self) -> None:
        self._response: bytes | None = None
        self._got_response = asyncio.Event()
        self.wait_task = asyncio.create_task(
            self._wait(), name='rpc in flight msg wait'
        )

    async def _wait(self) -> bytes:
        await self._got_response.wait()
        assert self._response is not None
        return self._response

    def set_response(self, data: bytes) -> None:
        """Set response data."""
        assert self._response is None
        self._response = data
        self._got_response.set()


class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives."""


class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either
    direction simultaneously, packets are still sent serially in a
    single stream. So excessively long messages/responses will delay
    all other communication. If/when this becomes an issue we can look
    into breaking up long messages into multiple packets.
    """

    # Set to True on an instance to test keepalive failures.
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid overly regular values)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0

    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        *,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these; otherwise they create
        # dep-loops which keep us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value).
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound.)
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )

    def __del__(self) -> None:
        if self._run_called:
            if not self._did_close_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run'
                    ' called but writer not closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )
            elif not self._did_wait_closed_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run called'
                    ' but writer not wait-closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )

        # Currently seeing a rare issue where sockets don't go down;
        # let's add a timer to force the issue until we can figure it out.
        ssl_stream_writer_force_close_check(self._writer)

    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise

    async def _do_run(self) -> None:
        self._check_env()

        if self._run_called:
            raise RuntimeError('Run can be called only once per endpoint.')
        self._run_called = True

        core_tasks = [
            asyncio.create_task(
                self._run_core_task('keepalive', self._run_keepalive_task()),
                name='rpc keepalive',
            ),
            asyncio.create_task(
                self._run_core_task('read', self._run_read_task()),
                name='rpc read',
            ),
            asyncio.create_task(
                self._run_core_task('write', self._run_write_task()),
                name='rpc write',
            ),
        ]
        self._tasks += core_tasks

        # Run our core tasks until they all complete.
        results = await asyncio.gather(*core_tasks, return_exceptions=True)

        # Core tasks should handle their own errors; the only ones
        # we expect to bubble up are CancelledError.
        for result in results:
            # We want to know if any errors happened aside from
            # CancelledError (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error from %s core task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in core_tasks):
            logging.warning(
                'RPCEndpoint %d: not all core tasks marked done after gather.',
                id(self),
            )

        # Shut ourself down.
        try:
            self.close()
            await self.wait_closed()
        except Exception:
            logging.exception('Error closing %s.', self._label)

        if self.debug_print:
            self.debug_print_call(f'{self._label}: finished.')

    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it
        # were a pure async call it could be possible for send order
        # to vary based on how the async tasks get processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16 bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME - should handle backpressure (waiting here if there are
        #  enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
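            # Illustrative framing example (hypothetical values, not
            # from a real trace): a 70000-byte message with id 7 would
            # go out as b'\x04' (MESSAGE_BIG) + b'\x00\x07' (id)
            # + b'\x00\x01\x11\x70' (length) + the payload bytes.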
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )

    async def _send_message(
        self,
        message: bytes,
        timeout: float | None,
        close_on_error: bool,
        bytes_awaitable: asyncio.Task[bytes],
        message_id: int,
    ) -> bytes:
        # pylint: disable=too-many-positional-arguments
        # We need to know their protocol, so if we haven't gotten a
        # handshake from them yet, just wait.
        while self._peer_info is None:
            await asyncio.sleep(0.01)
        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(message) > 65535:
                raise RuntimeError('Message cannot be larger than 65535 bytes')

        try:
            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
        except asyncio.CancelledError as exc:
            # Question: we assume this means the above wait_for() was
            # cancelled; how do we distinguish between this and *us*
            # being cancelled though?
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: message {message_id} was cancelled.'
                )
            if close_on_error:
                self.close()

            raise CommunicationError() from exc
        except Exception as exc:
            # If our timer timed-out or anything else went wrong with
            # the stream, lump it in as a communication error.
            if isinstance(
                exc, asyncio.TimeoutError
            ) or is_asyncio_streams_communication_error(exc):
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: got {type(exc)} sending message'
                        f' {message_id}; raising CommunicationError.'
                    )

                # Stop waiting on the response.
                bytes_awaitable.cancel()

                # Remove the record of this message.
                del self._in_flight_messages[message_id]

                if close_on_error:
                    self.close()

                # Let the user know something went wrong.
                raise CommunicationError() from exc

            # Some unexpected error; let it bubble up.
            raise

    def close(self) -> None:
        """Begin the process of closing the endpoint."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call

    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing

    async def wait_closed(self) -> None:
        """Wait for the endpoint to finish closing.

        This is called by run() so generally does not need to be
        explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should
        # break any cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish:'
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from
            # CancelledError (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                # timeout=60.0 * 6.0,
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' this should be fixed in future Python versions.'
                )
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

    def _tm(self) -> str:
        """Simple readable time value for debugging."""
        tval = time.monotonic() % 100.0
        return f'{tval:.2f}'

    async def _run_read_task(self) -> None:
        """Read from the peer."""
        self._check_env()
        assert self._peer_info is None

        # Bug fix: if we don't have this set we will never time out
        # if we never receive any data from the other end.
        self._last_keepalive_receive_time = time.monotonic()

        # The first thing they should send us is their handshake; then
        # we'll know if/how we can talk to them.
        mlen = await self._read_int_32()
        message = await self._reader.readexactly(mlen)
        self._total_bytes_read += mlen
        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
        self._last_keepalive_receive_time = time.monotonic()
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: received handshake at {self._tm()}.'
            )

        # Now just sit and handle stuff as it comes in.
        while True:
            if self._closing:
                return

            # Read message type.
            mtype = _PacketType(await self._read_int_8())
            if mtype is _PacketType.HANDSHAKE:
                raise RuntimeError('Got multiple handshakes')

            if mtype is _PacketType.KEEPALIVE:
                if self.debug_print_io:
                    self.debug_print_call(
                        f'{self._label}: received keepalive'
                        f' at {self._tm()}.'
                    )
                self._last_keepalive_receive_time = time.monotonic()

            elif mtype is _PacketType.MESSAGE:
                await self._handle_message_packet(big=False)

            elif mtype is _PacketType.MESSAGE_BIG:
                await self._handle_message_packet(big=True)

            elif mtype is _PacketType.RESPONSE:
                await self._handle_response_packet(big=False)

            elif mtype is _PacketType.RESPONSE_BIG:
                await self._handle_response_packet(big=True)

            else:
                assert_never(mtype)

    async def _handle_message_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        if big:
            msglen = await self._read_int_32()
        else:
            msglen = await self._read_int_16()
        msg = await self._reader.readexactly(msglen)
        self._total_bytes_read += msglen
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received message {msgid}'
                f' of size {msglen} at {self._tm()}.'
            )

        # Create a message-task to handle this message and return
        # a response (we don't want to block while that happens).
        assert not self._closing
        self._prune_tasks()  # Keep from filling with dead tasks.
        self._tasks.append(
            asyncio.create_task(
                self._handle_raw_message(message_id=msgid, message=msg),
                name='efro rpc message handle',
            )
        )
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: done handling message at {self._tm()}.'
            )

    async def _handle_response_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        # Protocol 2 gained 32 bit data lengths.
        if big:
            rsplen = await self._read_int_32()
        else:
            rsplen = await self._read_int_16()
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received response {msgid}'
                f' of size {rsplen} at {self._tm()}.'
            )
        rsp = await self._reader.readexactly(rsplen)
        self._total_bytes_read += rsplen
        msgobj = self._in_flight_messages.get(msgid)
        if msgobj is None:
            # It's possible for us to get a response to a message
            # that has timed out. In this case we will have no local
            # record of it.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got response for nonexistent'
                    f' message id {msgid}; perhaps it timed out?'
                )
        else:
            msgobj.set_response(rsp)

    async def _run_write_task(self) -> None:
        """Write to the peer."""

        self._check_env()

        # Introduce ourself so our peer knows how it can talk to us.
        data = dataclass_to_json(
            _PeerInfo(
                protocol=OUR_PROTOCOL,
                keepalive_interval=self._keepalive_interval,
            )
        ).encode()
        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)

        # Now just write out-messages as they come in.
        while True:
            # Wait until some data comes in.
            await self._have_out_packets.wait()

            assert self._out_packets
            data = self._out_packets.popleft()

            # Important: only clear this once all packets are sent.
            if not self._out_packets:
                self._have_out_packets.clear()

            self._writer.write(data)

            # This should keep our writer from buffering huge amounts
            # of outgoing data. We must remember though that we also
            # need to prevent _out_packets from growing too large and
            # that part's on us.
            await self._writer.drain()

            # For now we're not applying backpressure, but let's make
            # noise if this gets out of hand.
            if len(self._out_packets) > 200:
                if not self._did_out_packets_buildup_warning:
                    logging.warning(
                        '_out_packets building up too'
                        ' much on RPCEndpoint %s.',
                        id(self),
                    )
                    self._did_out_packets_buildup_warning = True

    async def _run_keepalive_task(self) -> None:
        """Send periodic keepalive packets."""
        self._check_env()

        # We explicitly send our own keepalive packets so we can stay
        # more on top of the connection state and possibly decide to
        # kill it when contact is lost more quickly than the OS would
        # do itself (or at least keep the user informed that the
        # connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself but that might be
        # OS-specific so gonna go this way for now.
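        # With the default values above (interval 10.73s, timeout
        # 30.0s) this works out to roughly three missed keepalives
        # before we give up; instances may of course be configured
        # with different values.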
        while True:
            assert not self._closing
            await asyncio.sleep(self._keepalive_interval)
            if not self.test_suppress_keepalives:
                self._enqueue_outgoing_packet(
                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
                )

            # Also go ahead and handle dropping the connection if we
            # haven't heard from the peer in a while.
            # NOTE: perhaps we want to do something more exact than
            # this which only checks once per keepalive-interval?..
            now = time.monotonic()
            if (
                self._last_keepalive_receive_time is not None
                and now - self._last_keepalive_receive_time
                > self._keepalive_timeout
            ):
                if self.debug_print:
                    since = now - self._last_keepalive_receive_time
                    self.debug_print_call(
                        f'{self._label}: reached keepalive time-out'
                        f' ({since:.1f}s).'
                    )
                raise _KeepaliveTimeoutError()

    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
        try:
            await call
        except Exception as exc:
            # We expect connection errors to put us here, but make noise
            # if something else does.
            if not self._is_expected_connection_error(exc):
                logging.exception(
                    'Unexpected error in rpc %s %s task'
                    ' (age=%.1f, total_bytes_read=%d).',
                    self._label,
                    tasklabel,
                    time.monotonic() - self._create_time,
                    self._total_bytes_read,
                )
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: {tasklabel} task will exit cleanly'
                        f' due to {exc!r}.'
                    )
        finally:
            # Any core task exiting triggers shutdown.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: {tasklabel} task exiting...'
                )
            self.close()

    async def _handle_raw_message(
        self, message_id: int, message: bytes
    ) -> None:
        try:
            response = await self._handle_raw_message_call(message)
        except Exception:
            # We expect the local message handler to always succeed.
            # If that doesn't happen, make a fuss so we know to fix it.
            # The other end will simply never get a response to this
            # message.
            logging.exception('Error handling raw rpc message')
            return

        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(response) > 65535:
                raise RuntimeError('Response cannot be larger than 65535 bytes')

        # Now send back our response.
        # Payload consists of type (1b), msgid (2b), len (2b or 4b),
        # and data.
        if len(response) > 65535:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(4, _BYTE_ORDER)
                + response
            )
        else:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(2, _BYTE_ORDER)
                + response
            )

    async def _read_int_8(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
        self._total_bytes_read += 1
        return out

    async def _read_int_16(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
        self._total_bytes_read += 2
        return out

    async def _read_int_32(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
        self._total_bytes_read += 4
        return out

    @classmethod
    def _is_expected_connection_error(cls, exc: Exception) -> bool:
        """Stuff we expect to end our connection in normal circumstances."""

        if isinstance(exc, _KeepaliveTimeoutError):
            return True

        return is_asyncio_streams_communication_error(exc)

    def _check_env(self) -> None:
        # I was seeing that asyncio stuff wasn't working as expected if
        # created in one thread and used in another (and have verified
        # that this is part of the design), so let's enforce a single
        # thread for all use of an instance.
        if current_thread() is not self._thread:
            raise RuntimeError(
                'This must be called from the same thread'
                ' that the endpoint was created in.'
            )

        # This should always be the case if thread is the same.
        assert asyncio.get_running_loop() is self._event_loop

    def _enqueue_outgoing_packet(self, data: bytes) -> None:
        """Enqueue a raw packet to be sent. Must be called from our loop."""
        self._check_env()

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueueing outgoing packet'
                f' {data[:50]!r} at {self._tm()}.'
            )

        # Add the data and let our write task know about it.
        self._out_packets.append(data)
        self._have_out_packets.set()

    def _prune_tasks(self) -> None:
        self._tasks = self._get_live_tasks()

    def _get_live_tasks(self) -> list[asyncio.Task]:
        return [t for t in self._tasks if not t.done()]
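For reference, a minimal usage sketch. The echo handler, host, and port below are hypothetical placeholders, not part of the module; it assumes a compatible RPCEndpoint is serving on the other end of the connection.

import asyncio

from efro.rpc import RPCEndpoint


async def _handle_message(message: bytes) -> bytes:
    # Echo the peer's message back as our response.
    return message


async def main() -> None:
    # 'localhost' and 43210 are placeholder connection details.
    reader, writer = await asyncio.open_connection('localhost', 43210)
    endpoint = RPCEndpoint(_handle_message, reader, writer, label='demo')

    # run() must be driven for the endpoint to do anything; here we
    # let it run alongside our send.
    run_task = asyncio.create_task(endpoint.run())

    # send_message() synchronously enqueues and returns an awaitable
    # for the peer's response.
    response = await endpoint.send_message(b'ping')
    print(response)

    endpoint.close()
    await endpoint.wait_closed()
    await run_task


if __name__ == '__main__':
    asyncio.run(main())

Note that the endpoint must be created and used from a single thread with a running event loop; _check_env() enforces this.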
71def ssl_stream_writer_underlying_transport_info( 72 writer: asyncio.StreamWriter, 73) -> str: 74 """For debugging SSL Stream connections; returns raw transport info.""" 75 # Note: accessing internals here so just returning info and not 76 # actual objs to reduce potential for breakage. 77 transport = getattr(writer, '_transport', None) 78 if transport is not None: 79 sslproto = getattr(transport, '_ssl_protocol', None) 80 if sslproto is not None: 81 raw_transport = getattr(sslproto, '_transport', None) 82 if raw_transport is not None: 83 return str(raw_transport) 84 return '(not found)'
For debugging SSL Stream connections; returns raw transport info.
87def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None: 88 """Ensure a writer is closed; hacky workaround for odd hang.""" 89 from threading import Thread 90 91 # Disabling for now.. 92 if bool(True): 93 return 94 95 # Hopefully can remove this in Python 3.11?... 96 # see issue with is_closing() below for more details. 97 transport = getattr(writer, '_transport', None) 98 if transport is not None: 99 sslproto = getattr(transport, '_ssl_protocol', None) 100 if sslproto is not None: 101 raw_transport = getattr(sslproto, '_transport', None) 102 if raw_transport is not None: 103 Thread( 104 target=partial( 105 _do_writer_force_close_check, weakref.ref(raw_transport) 106 ), 107 daemon=True, 108 ).start()
Ensure a writer is closed; hacky workaround for odd hang.
154class RPCEndpoint: 155 """Facilitates asynchronous multiplexed remote procedure calls. 156 157 Be aware that, while multiple calls can be in flight in either direction 158 simultaneously, packets are still sent serially in a single 159 stream. So excessively long messages/responses will delay all other 160 communication. If/when this becomes an issue we can look into breaking up 161 long messages into multiple packets. 162 """ 163 164 # Set to True on an instance to test keepalive failures. 165 test_suppress_keepalives: bool = False 166 167 # How long we should wait before giving up on a message by default. 168 # Note this includes processing time on the other end. 169 DEFAULT_MESSAGE_TIMEOUT = 60.0 170 171 # How often we send out keepalive packets by default. 172 DEFAULT_KEEPALIVE_INTERVAL = 10.73 # (avoid too regular of values) 173 174 # How long we can go without receiving a keepalive packet before we 175 # disconnect. 176 DEFAULT_KEEPALIVE_TIMEOUT = 30.0 177 178 def __init__( 179 self, 180 handle_raw_message_call: Callable[[bytes], Awaitable[bytes]], 181 reader: asyncio.StreamReader, 182 writer: asyncio.StreamWriter, 183 label: str, 184 *, 185 debug_print: bool = False, 186 debug_print_io: bool = False, 187 debug_print_call: Callable[[str], None] | None = None, 188 keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL, 189 keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT, 190 ) -> None: 191 self._handle_raw_message_call = handle_raw_message_call 192 self._reader = reader 193 self._writer = writer 194 self.debug_print = debug_print 195 self.debug_print_io = debug_print_io 196 if debug_print_call is None: 197 debug_print_call = print 198 self.debug_print_call: Callable[[str], None] = debug_print_call 199 self._label = label 200 self._thread = current_thread() 201 self._closing = False 202 self._did_wait_closed = False 203 self._event_loop = asyncio.get_running_loop() 204 self._out_packets = deque[bytes]() 205 self._have_out_packets = asyncio.Event() 206 self._run_called = False 207 self._peer_info: _PeerInfo | None = None 208 self._keepalive_interval = keepalive_interval 209 self._keepalive_timeout = keepalive_timeout 210 self._did_close_writer = False 211 self._did_wait_closed_writer = False 212 self._did_out_packets_buildup_warning = False 213 self._total_bytes_read = 0 214 self._create_time = time.monotonic() 215 216 # Need to hold weak-refs to these otherwise it creates dep-loops 217 # which keeps us alive. 218 self._tasks: list[asyncio.Task] = [] 219 220 # When we last got a keepalive or equivalent (time.monotonic value) 221 self._last_keepalive_receive_time: float | None = None 222 223 # (Start near the end to make sure our looping logic is sound). 224 self._next_message_id = 65530 225 226 self._in_flight_messages: dict[int, _InFlightMessage] = {} 227 228 if self.debug_print: 229 peername = self._writer.get_extra_info('peername') 230 self.debug_print_call( 231 f'{self._label}: connected to {peername} at {self._tm()}.' 
232 ) 233 234 def __del__(self) -> None: 235 if self._run_called: 236 if not self._did_close_writer: 237 logging.warning( 238 'RPCEndpoint %d dying with run' 239 ' called but writer not closed (transport=%s).', 240 id(self), 241 ssl_stream_writer_underlying_transport_info(self._writer), 242 ) 243 elif not self._did_wait_closed_writer: 244 logging.warning( 245 'RPCEndpoint %d dying with run called' 246 ' but writer not wait-closed (transport=%s).', 247 id(self), 248 ssl_stream_writer_underlying_transport_info(self._writer), 249 ) 250 251 # Currently seeing rare issue where sockets don't go down; 252 # let's add a timer to force the issue until we can figure it out. 253 ssl_stream_writer_force_close_check(self._writer) 254 255 async def run(self) -> None: 256 """Run the endpoint until the connection is lost or closed. 257 258 Handles closing the provided reader/writer on close. 259 """ 260 try: 261 await self._do_run() 262 except asyncio.CancelledError: 263 # We aren't really designed to be cancelled so let's warn 264 # if it happens. 265 logging.warning( 266 'RPCEndpoint.run got CancelledError;' 267 ' want to try and avoid this.' 268 ) 269 raise 270 271 async def _do_run(self) -> None: 272 self._check_env() 273 274 if self._run_called: 275 raise RuntimeError('Run can be called only once per endpoint.') 276 self._run_called = True 277 278 core_tasks = [ 279 asyncio.create_task( 280 self._run_core_task('keepalive', self._run_keepalive_task()), 281 name='rpc keepalive', 282 ), 283 asyncio.create_task( 284 self._run_core_task('read', self._run_read_task()), 285 name='rpc read', 286 ), 287 asyncio.create_task( 288 self._run_core_task('write', self._run_write_task()), 289 name='rpc write', 290 ), 291 ] 292 self._tasks += core_tasks 293 294 # Run our core tasks until they all complete. 295 results = await asyncio.gather(*core_tasks, return_exceptions=True) 296 297 # Core tasks should handle their own errors; the only ones 298 # we expect to bubble up are CancelledError. 299 for result in results: 300 # We want to know if any errors happened aside from CancelledError 301 # (which are BaseExceptions, not Exception). 302 if isinstance(result, Exception): 303 logging.warning( 304 'Got unexpected error from %s core task: %s', 305 self._label, 306 result, 307 ) 308 309 if not all(task.done() for task in core_tasks): 310 logging.warning( 311 'RPCEndpoint %d: not all core tasks marked done after gather.', 312 id(self), 313 ) 314 315 # Shut ourself down. 316 try: 317 self.close() 318 await self.wait_closed() 319 except Exception: 320 logging.exception('Error closing %s.', self._label) 321 322 if self.debug_print: 323 self.debug_print_call(f'{self._label}: finished.') 324 325 def send_message( 326 self, 327 message: bytes, 328 timeout: float | None = None, 329 close_on_error: bool = True, 330 ) -> Awaitable[bytes]: 331 """Send a message to the peer and return a response. 332 333 If timeout is not provided, the default will be used. 334 Raises a CommunicationError if the round trip is not completed 335 for any reason. 336 337 By default, the entire endpoint will go down in the case of 338 errors. This allows messages to be treated as 'reliable' with 339 respect to a given endpoint. Pass close_on_error=False to 340 override this for a particular message. 341 """ 342 # Note: This call is synchronous so that the first part of it 343 # (enqueueing outgoing messages) happens synchronously. 
If it were 344 # a pure async call it could be possible for send order to vary 345 # based on how the async tasks get processed. 346 347 if self.debug_print_io: 348 self.debug_print_call( 349 f'{self._label}: sending message of size {len(message)}' 350 f' at {self._tm()}.' 351 ) 352 353 self._check_env() 354 355 if self._closing: 356 raise CommunicationError('Endpoint is closed.') 357 358 if self.debug_print_io: 359 self.debug_print_call( 360 f'{self._label}: have peerinfo? {self._peer_info is not None}.' 361 ) 362 363 # message_id is a 16 bit looping value. 364 message_id = self._next_message_id 365 self._next_message_id = (self._next_message_id + 1) % 65536 366 367 if self.debug_print_io: 368 self.debug_print_call( 369 f'{self._label}: will enqueue at {self._tm()}.' 370 ) 371 372 # FIXME - should handle backpressure (waiting here if there are 373 # enough packets already enqueued). 374 375 if len(message) > 65535: 376 # Payload consists of type (1b), message_id (2b), 377 # len (4b), and data. 378 self._enqueue_outgoing_packet( 379 _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER) 380 + message_id.to_bytes(2, _BYTE_ORDER) 381 + len(message).to_bytes(4, _BYTE_ORDER) 382 + message 383 ) 384 else: 385 # Payload consists of type (1b), message_id (2b), 386 # len (2b), and data. 387 self._enqueue_outgoing_packet( 388 _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER) 389 + message_id.to_bytes(2, _BYTE_ORDER) 390 + len(message).to_bytes(2, _BYTE_ORDER) 391 + message 392 ) 393 394 if self.debug_print_io: 395 self.debug_print_call( 396 f'{self._label}: enqueued message of size {len(message)}' 397 f' at {self._tm()}.' 398 ) 399 400 # Make an entry so we know this message is out there. 401 assert message_id not in self._in_flight_messages 402 msgobj = self._in_flight_messages[message_id] = _InFlightMessage() 403 404 # Also add its task to our list so we properly cancel it if we die. 405 self._prune_tasks() # Keep our list from filling with dead tasks. 406 self._tasks.append(msgobj.wait_task) 407 408 # Note: we always want to incorporate a timeout. Individual 409 # messages may hang or error on the other end and this ensures 410 # we won't build up lots of zombie tasks waiting around for 411 # responses that will never arrive. 412 if timeout is None: 413 timeout = self.DEFAULT_MESSAGE_TIMEOUT 414 assert timeout is not None 415 416 bytes_awaitable = msgobj.wait_task 417 418 # Now complete the send asynchronously. 419 return self._send_message( 420 message, timeout, close_on_error, bytes_awaitable, message_id 421 ) 422 423 async def _send_message( 424 self, 425 message: bytes, 426 timeout: float | None, 427 close_on_error: bool, 428 bytes_awaitable: asyncio.Task[bytes], 429 message_id: int, 430 ) -> bytes: 431 # pylint: disable=too-many-positional-arguments 432 # We need to know their protocol, so if we haven't gotten a handshake 433 # from them yet, just wait. 434 while self._peer_info is None: 435 await asyncio.sleep(0.01) 436 assert self._peer_info is not None 437 438 if self._peer_info.protocol == 1: 439 if len(message) > 65535: 440 raise RuntimeError('Message cannot be larger than 65535 bytes') 441 442 try: 443 return await asyncio.wait_for(bytes_awaitable, timeout=timeout) 444 except asyncio.CancelledError as exc: 445 # Question: we assume this means the above wait_for() was 446 # cancelled; how do we distinguish between this and *us* being 447 # cancelled though? 448 if self.debug_print: 449 self.debug_print_call( 450 f'{self._label}: message {message_id} was cancelled.' 
451 ) 452 if close_on_error: 453 self.close() 454 455 raise CommunicationError() from exc 456 except Exception as exc: 457 # If our timer timed-out or anything else went wrong with 458 # the stream, lump it in as a communication error. 459 if isinstance( 460 exc, asyncio.TimeoutError 461 ) or is_asyncio_streams_communication_error(exc): 462 if self.debug_print: 463 self.debug_print_call( 464 f'{self._label}: got {type(exc)} sending message' 465 f' {message_id}; raising CommunicationError.' 466 ) 467 468 # Stop waiting on the response. 469 bytes_awaitable.cancel() 470 471 # Remove the record of this message. 472 del self._in_flight_messages[message_id] 473 474 if close_on_error: 475 self.close() 476 477 # Let the user know something went wrong. 478 raise CommunicationError() from exc 479 480 # Some unexpected error; let it bubble up. 481 raise 482 483 def close(self) -> None: 484 """I said seagulls; mmmm; stop it now.""" 485 self._check_env() 486 487 if self._closing: 488 return 489 490 if self.debug_print: 491 self.debug_print_call(f'{self._label}: closing...') 492 493 self._closing = True 494 495 # Kill all of our in-flight tasks. 496 if self.debug_print: 497 self.debug_print_call(f'{self._label}: cancelling tasks...') 498 for task in self._get_live_tasks(): 499 task.cancel() 500 501 # Close our writer. 502 assert not self._did_close_writer 503 if self.debug_print: 504 self.debug_print_call(f'{self._label}: closing writer...') 505 self._writer.close() 506 self._did_close_writer = True 507 508 # We don't need this anymore and it is likely to be creating a 509 # dependency loop. 510 del self._handle_raw_message_call 511 512 def is_closing(self) -> bool: 513 """Have we begun the process of closing?""" 514 return self._closing 515 516 async def wait_closed(self) -> None: 517 """I said seagulls; mmmm; stop it now. 518 519 Wait for the endpoint to finish closing. This is called by run() 520 so generally does not need to be explicitly called. 521 """ 522 # pylint: disable=too-many-branches 523 self._check_env() 524 525 # Make sure we only *enter* this call once. 526 if self._did_wait_closed: 527 return 528 self._did_wait_closed = True 529 530 if not self._closing: 531 raise RuntimeError('Must be called after close()') 532 533 if not self._did_close_writer: 534 logging.warning( 535 'RPCEndpoint wait_closed() called but never' 536 ' explicitly closed writer.' 537 ) 538 539 live_tasks = self._get_live_tasks() 540 541 # Don't need our task list anymore; this should 542 # break any cyclical refs from tasks referring to us. 543 self._tasks = [] 544 545 if self.debug_print: 546 self.debug_print_call( 547 f'{self._label}: waiting for tasks to finish: ' 548 f' ({live_tasks=})...' 549 ) 550 551 # Wait for all of our in-flight tasks to wrap up. 552 results = await asyncio.gather(*live_tasks, return_exceptions=True) 553 for result in results: 554 # We want to know if any errors happened aside from CancelledError 555 # (which are BaseExceptions, not Exception). 556 if isinstance(result, Exception): 557 logging.warning( 558 'Got unexpected error cleaning up %s task: %s', 559 self._label, 560 result, 561 ) 562 563 if not all(task.done() for task in live_tasks): 564 logging.warning( 565 'RPCEndpoint %d: not all live tasks marked done after gather.', 566 id(self), 567 ) 568 569 if self.debug_print: 570 self.debug_print_call( 571 f'{self._label}: tasks finished; waiting for writer close...' 572 ) 573 574 # Now wait for our writer to finish going down. 
575 # When we close our writer it generally triggers errors 576 # in our current blocked read/writes. However that same 577 # error is also sometimes returned from _writer.wait_closed(). 578 # See connection_lost() in asyncio/streams.py to see why. 579 # So let's silently ignore it when that happens. 580 assert self._writer.is_closing() 581 try: 582 # It seems that as of Python 3.9.x it is possible for this to hang 583 # indefinitely. See https://github.com/python/cpython/issues/83939 584 # It sounds like this should be fixed in 3.11 but for now just 585 # forcing the issue with a timeout here. 586 await asyncio.wait_for( 587 self._writer.wait_closed(), 588 # timeout=60.0 * 6.0, 589 timeout=30.0, 590 ) 591 except asyncio.TimeoutError: 592 logging.info( 593 'Timeout on _writer.wait_closed() for %s rpc (transport=%s).', 594 self._label, 595 ssl_stream_writer_underlying_transport_info(self._writer), 596 ) 597 if self.debug_print: 598 self.debug_print_call( 599 f'{self._label}: got timeout in _writer.wait_closed();' 600 ' This should be fixed in future Python versions.' 601 ) 602 except Exception as exc: 603 if not self._is_expected_connection_error(exc): 604 logging.exception('Error closing _writer for %s.', self._label) 605 else: 606 if self.debug_print: 607 self.debug_print_call( 608 f'{self._label}: silently ignoring error in' 609 f' _writer.wait_closed(): {exc}.' 610 ) 611 except asyncio.CancelledError: 612 logging.warning( 613 'RPCEndpoint.wait_closed()' 614 ' got asyncio.CancelledError; not expected.' 615 ) 616 raise 617 assert not self._did_wait_closed_writer 618 self._did_wait_closed_writer = True 619 620 def _tm(self) -> str: 621 """Simple readable time value for debugging.""" 622 tval = time.monotonic() % 100.0 623 return f'{tval:.2f}' 624 625 async def _run_read_task(self) -> None: 626 """Read from the peer.""" 627 self._check_env() 628 assert self._peer_info is None 629 630 # Bug fix: if we don't have this set we will never time out 631 # if we never receive any data from the other end. 632 self._last_keepalive_receive_time = time.monotonic() 633 634 # The first thing they should send us is their handshake; then 635 # we'll know if/how we can talk to them. 636 mlen = await self._read_int_32() 637 message = await self._reader.readexactly(mlen) 638 self._total_bytes_read += mlen 639 self._peer_info = dataclass_from_json(_PeerInfo, message.decode()) 640 self._last_keepalive_receive_time = time.monotonic() 641 if self.debug_print: 642 self.debug_print_call( 643 f'{self._label}: received handshake at {self._tm()}.' 644 ) 645 646 # Now just sit and handle stuff as it comes in. 647 while True: 648 if self._closing: 649 return 650 651 # Read message type. 652 mtype = _PacketType(await self._read_int_8()) 653 if mtype is _PacketType.HANDSHAKE: 654 raise RuntimeError('Got multiple handshakes') 655 656 if mtype is _PacketType.KEEPALIVE: 657 if self.debug_print_io: 658 self.debug_print_call( 659 f'{self._label}: received keepalive' 660 f' at {self._tm()}.' 
661 ) 662 self._last_keepalive_receive_time = time.monotonic() 663 664 elif mtype is _PacketType.MESSAGE: 665 await self._handle_message_packet(big=False) 666 667 elif mtype is _PacketType.MESSAGE_BIG: 668 await self._handle_message_packet(big=True) 669 670 elif mtype is _PacketType.RESPONSE: 671 await self._handle_response_packet(big=False) 672 673 elif mtype is _PacketType.RESPONSE_BIG: 674 await self._handle_response_packet(big=True) 675 676 else: 677 assert_never(mtype) 678 679 async def _handle_message_packet(self, big: bool) -> None: 680 assert self._peer_info is not None 681 msgid = await self._read_int_16() 682 if big: 683 msglen = await self._read_int_32() 684 else: 685 msglen = await self._read_int_16() 686 msg = await self._reader.readexactly(msglen) 687 self._total_bytes_read += msglen 688 if self.debug_print_io: 689 self.debug_print_call( 690 f'{self._label}: received message {msgid}' 691 f' of size {msglen} at {self._tm()}.' 692 ) 693 694 # Create a message-task to handle this message and return 695 # a response (we don't want to block while that happens). 696 assert not self._closing 697 self._prune_tasks() # Keep from filling with dead tasks. 698 self._tasks.append( 699 asyncio.create_task( 700 self._handle_raw_message(message_id=msgid, message=msg), 701 name='efro rpc message handle', 702 ) 703 ) 704 if self.debug_print: 705 self.debug_print_call( 706 f'{self._label}: done handling message at {self._tm()}.' 707 ) 708 709 async def _handle_response_packet(self, big: bool) -> None: 710 assert self._peer_info is not None 711 msgid = await self._read_int_16() 712 # Protocol 2 gained 32 bit data lengths. 713 if big: 714 rsplen = await self._read_int_32() 715 else: 716 rsplen = await self._read_int_16() 717 if self.debug_print_io: 718 self.debug_print_call( 719 f'{self._label}: received response {msgid}' 720 f' of size {rsplen} at {self._tm()}.' 721 ) 722 rsp = await self._reader.readexactly(rsplen) 723 self._total_bytes_read += rsplen 724 msgobj = self._in_flight_messages.get(msgid) 725 if msgobj is None: 726 # It's possible for us to get a response to a message 727 # that has timed out. In this case we will have no local 728 # record of it. 729 if self.debug_print: 730 self.debug_print_call( 731 f'{self._label}: got response for nonexistent' 732 f' message id {msgid}; perhaps it timed out?' 733 ) 734 else: 735 msgobj.set_response(rsp) 736 737 async def _run_write_task(self) -> None: 738 """Write to the peer.""" 739 740 self._check_env() 741 742 # Introduce ourself so our peer knows how it can talk to us. 743 data = dataclass_to_json( 744 _PeerInfo( 745 protocol=OUR_PROTOCOL, 746 keepalive_interval=self._keepalive_interval, 747 ) 748 ).encode() 749 self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data) 750 751 # Now just write out-messages as they come in. 752 while True: 753 # Wait until some data comes in. 754 await self._have_out_packets.wait() 755 756 assert self._out_packets 757 data = self._out_packets.popleft() 758 759 # Important: only clear this once all packets are sent. 760 if not self._out_packets: 761 self._have_out_packets.clear() 762 763 self._writer.write(data) 764 765 # This should keep our writer from buffering huge amounts 766 # of outgoing data. We must remember though that we also 767 # need to prevent _out_packets from growing too large and 768 # that part's on us. 769 await self._writer.drain() 770 771 # For now we're not applying backpressure, but let's make 772 # noise if this gets out of hand. 
773 if len(self._out_packets) > 200: 774 if not self._did_out_packets_buildup_warning: 775 logging.warning( 776 '_out_packets building up too' 777 ' much on RPCEndpoint %s.', 778 id(self), 779 ) 780 self._did_out_packets_buildup_warning = True 781 782 async def _run_keepalive_task(self) -> None: 783 """Send periodic keepalive packets.""" 784 self._check_env() 785 786 # We explicitly send our own keepalive packets so we can stay 787 # more on top of the connection state and possibly decide to 788 # kill it when contact is lost more quickly than the OS would 789 # do itself (or at least keep the user informed that the 790 # connection is lagging). It sounds like we could have the TCP 791 # layer do this sort of thing itself but that might be 792 # OS-specific so gonna go this way for now. 793 while True: 794 assert not self._closing 795 await asyncio.sleep(self._keepalive_interval) 796 if not self.test_suppress_keepalives: 797 self._enqueue_outgoing_packet( 798 _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER) 799 ) 800 801 # Also go ahead and handle dropping the connection if we 802 # haven't heard from the peer in a while. 803 # NOTE: perhaps we want to do something more exact than 804 # this which only checks once per keepalive-interval?.. 805 now = time.monotonic() 806 if ( 807 self._last_keepalive_receive_time is not None 808 and now - self._last_keepalive_receive_time 809 > self._keepalive_timeout 810 ): 811 if self.debug_print: 812 since = now - self._last_keepalive_receive_time 813 self.debug_print_call( 814 f'{self._label}: reached keepalive time-out' 815 f' ({since:.1f}s).' 816 ) 817 raise _KeepaliveTimeoutError() 818 819 async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None: 820 try: 821 await call 822 except Exception as exc: 823 # We expect connection errors to put us here, but make noise 824 # if something else does. 825 if not self._is_expected_connection_error(exc): 826 logging.exception( 827 'Unexpected error in rpc %s %s task' 828 ' (age=%.1f, total_bytes_read=%d).', 829 self._label, 830 tasklabel, 831 time.monotonic() - self._create_time, 832 self._total_bytes_read, 833 ) 834 else: 835 if self.debug_print: 836 self.debug_print_call( 837 f'{self._label}: {tasklabel} task will exit cleanly' 838 f' due to {exc!r}.' 839 ) 840 finally: 841 # Any core task exiting triggers shutdown. 842 if self.debug_print: 843 self.debug_print_call( 844 f'{self._label}: {tasklabel} task exiting...' 845 ) 846 self.close() 847 848 async def _handle_raw_message( 849 self, message_id: int, message: bytes 850 ) -> None: 851 try: 852 response = await self._handle_raw_message_call(message) 853 except Exception: 854 # We expect local message handler to always succeed. 855 # If that doesn't happen, make a fuss so we know to fix it. 856 # The other end will simply never get a response to this 857 # message. 858 logging.exception('Error handling raw rpc message') 859 return 860 861 assert self._peer_info is not None 862 863 if self._peer_info.protocol == 1: 864 if len(response) > 65535: 865 raise RuntimeError('Response cannot be larger than 65535 bytes') 866 867 # Now send back our response. 868 # Payload consists of type (1b), msgid (2b), len (2b), and data. 
869 if len(response) > 65535: 870 self._enqueue_outgoing_packet( 871 _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER) 872 + message_id.to_bytes(2, _BYTE_ORDER) 873 + len(response).to_bytes(4, _BYTE_ORDER) 874 + response 875 ) 876 else: 877 self._enqueue_outgoing_packet( 878 _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER) 879 + message_id.to_bytes(2, _BYTE_ORDER) 880 + len(response).to_bytes(2, _BYTE_ORDER) 881 + response 882 ) 883 884 async def _read_int_8(self) -> int: 885 out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER) 886 self._total_bytes_read += 1 887 return out 888 889 async def _read_int_16(self) -> int: 890 out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER) 891 self._total_bytes_read += 2 892 return out 893 894 async def _read_int_32(self) -> int: 895 out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER) 896 self._total_bytes_read += 4 897 return out 898 899 @classmethod 900 def _is_expected_connection_error(cls, exc: Exception) -> bool: 901 """Stuff we expect to end our connection in normal circumstances.""" 902 903 if isinstance(exc, _KeepaliveTimeoutError): 904 return True 905 906 return is_asyncio_streams_communication_error(exc) 907 908 def _check_env(self) -> None: 909 # I was seeing that asyncio stuff wasn't working as expected if 910 # created in one thread and used in another (and have verified 911 # that this is part of the design), so let's enforce a single 912 # thread for all use of an instance. 913 if current_thread() is not self._thread: 914 raise RuntimeError( 915 'This must be called from the same thread' 916 ' that the endpoint was created in.' 917 ) 918 919 # This should always be the case if thread is the same. 920 assert asyncio.get_running_loop() is self._event_loop 921 922 def _enqueue_outgoing_packet(self, data: bytes) -> None: 923 """Enqueue a raw packet to be sent. Must be called from our loop.""" 924 self._check_env() 925 926 if self.debug_print_io: 927 self.debug_print_call( 928 f'{self._label}: enqueueing outgoing packet' 929 f' {data[:50]!r} at {self._tm()}.' 930 ) 931 932 # Add the data and let our write task know about it. 933 self._out_packets.append(data) 934 self._have_out_packets.set() 935 936 def _prune_tasks(self) -> None: 937 self._tasks = self._get_live_tasks() 938 939 def _get_live_tasks(self) -> list[asyncio.Task]: 940 return [t for t in self._tasks if not t.done()]
    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        *,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these; otherwise we create
        # dep-loops which keep us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value).
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound.)
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )
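Because __init__ calls asyncio.get_running_loop() and records the current thread, an endpoint must be constructed from within the event loop (and thread) that will drive it. A minimal construction sketch, assuming a hypothetical echo handler; handle_raw_message and make_endpoint are names of ours, not part of this module:

async def handle_raw_message(message: bytes) -> bytes:
    # Hypothetical handler: simply echo the message back as the response.
    return message

async def make_endpoint(host: str, port: int) -> RPCEndpoint:
    # Must be awaited inside the loop/thread that will use the endpoint.
    reader, writer = await asyncio.open_connection(host, port)
    return RPCEndpoint(handle_raw_message, reader, writer, label='client')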
    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise
    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it
        # were a pure async call it could be possible for send order
        # to vary based on how the async tasks get processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16-bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME: should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )
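The 16-bit message-id counter wraps via plain modular arithmetic; since __init__ seeds it at 65530, the wrap-around path gets exercised after only a handful of messages. A quick worked example of the looping behavior:

next_id = 65530
ids = []
for _ in range(8):
    ids.append(next_id)
    next_id = (next_id + 1) % 65536
assert ids == [65530, 65531, 65532, 65533, 65534, 65535, 0, 1]

Note that reusing ids after the wrap is safe only because of the assert above checking the id is not already in _in_flight_messages; in other words, it assumes fewer than 65536 messages are ever in flight at once.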
    def close(self) -> None:
        """Close the endpoint if it has not already begun closing."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call
    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing
    async def wait_closed(self) -> None:
        """Wait for the endpoint to finish closing.

        This is called by run() so generally does not need to be
        explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close().')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # We don't need our task list anymore; this should break any
        # cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish:'
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from
            # CancelledError (which are BaseExceptions, not Exceptions).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors in our
        # currently blocked reads/writes. However that same error is
        # also sometimes returned from _writer.wait_closed(). See
        # connection_lost() in asyncio/streams.py to see why. So let's
        # silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this
            # to hang indefinitely. See
            # https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11, but for now
            # we force the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' this should be fixed in future Python versions.'
                )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )

        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True
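Since wait_closed() raises RuntimeError unless close() has already been called, code that is not relying on run() for teardown can shut down explicitly. A minimal sketch (the shutdown helper name is ours):

async def shutdown(endpoint: RPCEndpoint) -> None:
    # close() is a no-op if closing has already begun, so it is safe
    # to call unconditionally; wait_closed() then blocks until tasks
    # are gathered and the writer is fully torn down.
    endpoint.close()
    await endpoint.wait_closed()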