efro.rpc
Remote procedure call related functionality.
# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""

from __future__ import annotations

import time
import asyncio
import logging
import weakref
from enum import Enum
from functools import partial
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated, assert_never

from efro.error import (
    CommunicationError,
    is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
    dataclass_to_json,
    dataclass_from_json,
    ioprepped,
    IOAttrs,
)

if TYPE_CHECKING:
    from typing import Literal, Awaitable, Callable

# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
#   payload. Even though we use streams we organize our transmission
#   into 'packets'.
# Message: User data which we transmit using one or more packets.


class _PacketType(Enum):
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5


_BYTE_ORDER: Literal['big'] = 'big'


@ioprepped
@dataclass
class _PeerInfo:
    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]


# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) package/response packets
OUR_PROTOCOL = 2


def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """For debugging SSL Stream connections; returns raw transport info."""
    # Note: accessing internals here so just returning info and not
    # actual objs to reduce potential for breakage.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                return str(raw_transport)
    return '(not found)'


def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Ensure a writer is closed; hacky workaround for odd hang."""
    from threading import Thread

    # Disabling for now..
    if bool(True):
        return

    # Hopefully can remove this in Python 3.11?...
    # see issue with is_closing() below for more details.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                Thread(
                    target=partial(
                        _do_writer_force_close_check, weakref.ref(raw_transport)
                    ),
                    daemon=True,
                ).start()


def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
    try:
        # Attempt to bail as soon as the obj dies.
        # If it hasn't done so by our timeout, force-kill it.
        starttime = time.monotonic()
        while time.monotonic() - starttime < 10.0:
            time.sleep(0.1)
            if transport_weak() is None:
                return
        transport = transport_weak()
        if transport is not None:
            logging.info('Forcing abort on stuck transport %s.', transport)
            transport.abort()
    except Exception:
        logging.warning('Error in writer-force-close-check', exc_info=True)


class _InFlightMessage:
    """Represents a message that is out on the wire."""

    def __init__(self) -> None:
        self._response: bytes | None = None
        self._got_response = asyncio.Event()
        self.wait_task = asyncio.create_task(
            self._wait(), name='rpc in flight msg wait'
        )

    async def _wait(self) -> bytes:
        await self._got_response.wait()
        assert self._response is not None
        return self._response

    def set_response(self, data: bytes) -> None:
        """Set response data."""
        assert self._response is None
        self._response = data
        self._got_response.set()


class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives."""


class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either
    direction simultaneously, packets are still sent serially in a
    single stream. So excessively long messages/responses will delay
    all other communication. If/when this becomes an issue we can
    look into breaking up long messages into multiple packets.
    """

    # Set to True on an instance to test keepalive failures.
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0

    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these otherwise it creates dep-loops
        # which keeps us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value)
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound).
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )

    def __del__(self) -> None:
        if self._run_called:
            if not self._did_close_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run'
                    ' called but writer not closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )
            elif not self._did_wait_closed_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run called'
                    ' but writer not wait-closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )

        # Currently seeing rare issue where sockets don't go down;
        # let's add a timer to force the issue until we can figure it out.
        ssl_stream_writer_force_close_check(self._writer)

    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise

    async def _do_run(self) -> None:
        self._check_env()

        if self._run_called:
            raise RuntimeError('Run can be called only once per endpoint.')
        self._run_called = True

        core_tasks = [
            asyncio.create_task(
                self._run_core_task('keepalive', self._run_keepalive_task()),
                name='rpc keepalive',
            ),
            asyncio.create_task(
                self._run_core_task('read', self._run_read_task()),
                name='rpc read',
            ),
            asyncio.create_task(
                self._run_core_task('write', self._run_write_task()),
                name='rpc write',
            ),
        ]
        self._tasks += core_tasks

        # Run our core tasks until they all complete.
        results = await asyncio.gather(*core_tasks, return_exceptions=True)

        # Core tasks should handle their own errors; the only ones
        # we expect to bubble up are CancelledError.
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error from %s core task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in core_tasks):
            logging.warning(
                'RPCEndpoint %d: not all core tasks marked done after gather.',
                id(self),
            )

        # Shut ourself down.
        try:
            self.close()
            await self.wait_closed()
        except Exception:
            logging.exception('Error closing %s.', self._label)

        if self.debug_print:
            self.debug_print_call(f'{self._label}: finished.')

    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it were
        # a pure async call it could be possible for send order to vary
        # based on how the async tasks get processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16 bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME - should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )

    async def _send_message(
        self,
        message: bytes,
        timeout: float | None,
        close_on_error: bool,
        bytes_awaitable: asyncio.Task[bytes],
        message_id: int,
    ) -> bytes:
        # We need to know their protocol, so if we haven't gotten a handshake
        # from them yet, just wait.
        while self._peer_info is None:
            await asyncio.sleep(0.01)
        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(message) > 65535:
                raise RuntimeError('Message cannot be larger than 65535 bytes')

        try:
            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
        except asyncio.CancelledError as exc:
            # Question: we assume this means the above wait_for() was
            # cancelled; how do we distinguish between this and *us* being
            # cancelled though?
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: message {message_id} was cancelled.'
                )
            if close_on_error:
                self.close()

            raise CommunicationError() from exc
        except Exception as exc:
            # If our timer timed-out or anything else went wrong with
            # the stream, lump it in as a communication error.
            if isinstance(
                exc, asyncio.TimeoutError
            ) or is_asyncio_streams_communication_error(exc):
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: got {type(exc)} sending message'
                        f' {message_id}; raising CommunicationError.'
                    )

                # Stop waiting on the response.
                bytes_awaitable.cancel()

                # Remove the record of this message.
                del self._in_flight_messages[message_id]

                if close_on_error:
                    self.close()

                # Let the user know something went wrong.
                raise CommunicationError() from exc

            # Some unexpected error; let it bubble up.
            raise

    def close(self) -> None:
        """I said seagulls; mmmm; stop it now."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call

    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing

    async def wait_closed(self) -> None:
        """I said seagulls; mmmm; stop it now.

        Wait for the endpoint to finish closing. This is called by run()
        so generally does not need to be explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should
        # break any cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish: '
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                # timeout=60.0 * 6.0,
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' This should be fixed in future Python versions.'
                )
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

    def _tm(self) -> str:
        """Simple readable time value for debugging."""
        tval = time.monotonic() % 100.0
        return f'{tval:.2f}'

    async def _run_read_task(self) -> None:
        """Read from the peer."""
        self._check_env()
        assert self._peer_info is None

        # Bug fix: if we don't have this set we will never time out
        # if we never receive any data from the other end.
        self._last_keepalive_receive_time = time.monotonic()

        # The first thing they should send us is their handshake; then
        # we'll know if/how we can talk to them.
        mlen = await self._read_int_32()
        message = await self._reader.readexactly(mlen)
        self._total_bytes_read += mlen
        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
        self._last_keepalive_receive_time = time.monotonic()
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: received handshake at {self._tm()}.'
            )

        # Now just sit and handle stuff as it comes in.
        while True:
            if self._closing:
                return

            # Read message type.
            mtype = _PacketType(await self._read_int_8())
            if mtype is _PacketType.HANDSHAKE:
                raise RuntimeError('Got multiple handshakes')

            if mtype is _PacketType.KEEPALIVE:
                if self.debug_print_io:
                    self.debug_print_call(
                        f'{self._label}: received keepalive'
                        f' at {self._tm()}.'
                    )
                self._last_keepalive_receive_time = time.monotonic()

            elif mtype is _PacketType.MESSAGE:
                await self._handle_message_packet(big=False)

            elif mtype is _PacketType.MESSAGE_BIG:
                await self._handle_message_packet(big=True)

            elif mtype is _PacketType.RESPONSE:
                await self._handle_response_packet(big=False)

            elif mtype is _PacketType.RESPONSE_BIG:
                await self._handle_response_packet(big=True)

            else:
                assert_never(mtype)

    async def _handle_message_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        if big:
            msglen = await self._read_int_32()
        else:
            msglen = await self._read_int_16()
        msg = await self._reader.readexactly(msglen)
        self._total_bytes_read += msglen
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received message {msgid}'
                f' of size {msglen} at {self._tm()}.'
            )

        # Create a message-task to handle this message and return
        # a response (we don't want to block while that happens).
        assert not self._closing
        self._prune_tasks()  # Keep from filling with dead tasks.
        self._tasks.append(
            asyncio.create_task(
                self._handle_raw_message(message_id=msgid, message=msg),
                name='efro rpc message handle',
            )
        )
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: done handling message at {self._tm()}.'
            )

    async def _handle_response_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        # Protocol 2 gained 32 bit data lengths.
        if big:
            rsplen = await self._read_int_32()
        else:
            rsplen = await self._read_int_16()
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received response {msgid}'
                f' of size {rsplen} at {self._tm()}.'
            )
        rsp = await self._reader.readexactly(rsplen)
        self._total_bytes_read += rsplen
        msgobj = self._in_flight_messages.get(msgid)
        if msgobj is None:
            # It's possible for us to get a response to a message
            # that has timed out. In this case we will have no local
            # record of it.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got response for nonexistent'
                    f' message id {msgid}; perhaps it timed out?'
                )
        else:
            msgobj.set_response(rsp)

    async def _run_write_task(self) -> None:
        """Write to the peer."""

        self._check_env()

        # Introduce ourself so our peer knows how it can talk to us.
        data = dataclass_to_json(
            _PeerInfo(
                protocol=OUR_PROTOCOL,
                keepalive_interval=self._keepalive_interval,
            )
        ).encode()
        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)

        # Now just write out-messages as they come in.
        while True:
            # Wait until some data comes in.
            await self._have_out_packets.wait()

            assert self._out_packets
            data = self._out_packets.popleft()

            # Important: only clear this once all packets are sent.
            if not self._out_packets:
                self._have_out_packets.clear()

            self._writer.write(data)

            # This should keep our writer from buffering huge amounts
            # of outgoing data. We must remember though that we also
            # need to prevent _out_packets from growing too large and
            # that part's on us.
            await self._writer.drain()

            # For now we're not applying backpressure, but let's make
            # noise if this gets out of hand.
            if len(self._out_packets) > 200:
                if not self._did_out_packets_buildup_warning:
                    logging.warning(
                        '_out_packets building up too'
                        ' much on RPCEndpoint %s.',
                        id(self),
                    )
                    self._did_out_packets_buildup_warning = True

    async def _run_keepalive_task(self) -> None:
        """Send periodic keepalive packets."""
        self._check_env()

        # We explicitly send our own keepalive packets so we can stay
        # more on top of the connection state and possibly decide to
        # kill it when contact is lost more quickly than the OS would
        # do itself (or at least keep the user informed that the
        # connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself but that might be
        # OS-specific so gonna go this way for now.
        while True:
            assert not self._closing
            await asyncio.sleep(self._keepalive_interval)
            if not self.test_suppress_keepalives:
                self._enqueue_outgoing_packet(
                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
                )

            # Also go ahead and handle dropping the connection if we
            # haven't heard from the peer in a while.
            # NOTE: perhaps we want to do something more exact than
            # this which only checks once per keepalive-interval?..
            now = time.monotonic()
            if (
                self._last_keepalive_receive_time is not None
                and now - self._last_keepalive_receive_time
                > self._keepalive_timeout
            ):
                if self.debug_print:
                    since = now - self._last_keepalive_receive_time
                    self.debug_print_call(
                        f'{self._label}: reached keepalive time-out'
                        f' ({since:.1f}s).'
                    )
                raise _KeepaliveTimeoutError()

    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
        try:
            await call
        except Exception as exc:
            # We expect connection errors to put us here, but make noise
            # if something else does.
            if not self._is_expected_connection_error(exc):
                logging.exception(
                    'Unexpected error in rpc %s %s task'
                    ' (age=%.1f, total_bytes_read=%d).',
                    self._label,
                    tasklabel,
                    time.monotonic() - self._create_time,
                    self._total_bytes_read,
                )
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: {tasklabel} task will exit cleanly'
                        f' due to {exc!r}.'
                    )
        finally:
            # Any core task exiting triggers shutdown.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: {tasklabel} task exiting...'
                )
            self.close()

    async def _handle_raw_message(
        self, message_id: int, message: bytes
    ) -> None:
        try:
            response = await self._handle_raw_message_call(message)
        except Exception:
            # We expect local message handler to always succeed.
            # If that doesn't happen, make a fuss so we know to fix it.
            # The other end will simply never get a response to this
            # message.
            logging.exception('Error handling raw rpc message')
            return

        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(response) > 65535:
                raise RuntimeError('Response cannot be larger than 65535 bytes')

        # Now send back our response.
        # Payload consists of type (1b), msgid (2b), len (2b), and data.
        if len(response) > 65535:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(4, _BYTE_ORDER)
                + response
            )
        else:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(2, _BYTE_ORDER)
                + response
            )

    async def _read_int_8(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
        self._total_bytes_read += 1
        return out

    async def _read_int_16(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
        self._total_bytes_read += 2
        return out

    async def _read_int_32(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
        self._total_bytes_read += 4
        return out

    @classmethod
    def _is_expected_connection_error(cls, exc: Exception) -> bool:
        """Stuff we expect to end our connection in normal circumstances."""

        if isinstance(exc, _KeepaliveTimeoutError):
            return True

        return is_asyncio_streams_communication_error(exc)

    def _check_env(self) -> None:
        # I was seeing that asyncio stuff wasn't working as expected if
        # created in one thread and used in another (and have verified
        # that this is part of the design), so let's enforce a single
        # thread for all use of an instance.
        if current_thread() is not self._thread:
            raise RuntimeError(
                'This must be called from the same thread'
                ' that the endpoint was created in.'
            )

        # This should always be the case if thread is the same.
        assert asyncio.get_running_loop() is self._event_loop

    def _enqueue_outgoing_packet(self, data: bytes) -> None:
        """Enqueue a raw packet to be sent. Must be called from our loop."""
        self._check_env()

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueueing outgoing packet'
                f' {data[:50]!r} at {self._tm()}.'
            )

        # Add the data and let our write task know about it.
        self._out_packets.append(data)
        self._have_out_packets.set()

    def _prune_tasks(self) -> None:
        self._tasks = self._get_live_tasks()

    def _get_live_tasks(self) -> list[asyncio.Task]:
        return [t for t in self._tasks if not t.done()]
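A minimal usage sketch, not part of the module itself: it wires a client endpoint to a server endpoint over a local asyncio TCP connection. The echo-style handler, port number, and labels are illustrative assumptions; all an endpoint needs is an async bytes-to-bytes handler plus a stream reader/writer pair, and run() must be active on both sides for traffic to flow.

    import asyncio

    from efro.rpc import RPCEndpoint


    async def handle_message(message: bytes) -> bytes:
        # Our half of the RPC contract: take request bytes,
        # return response bytes.
        return b'echo:' + message


    async def serve(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        # One endpoint per incoming connection; run() owns it until close.
        await RPCEndpoint(handle_message, reader, writer, label='server').run()


    async def main() -> None:
        server = await asyncio.start_server(serve, 'localhost', 8585)
        reader, writer = await asyncio.open_connection('localhost', 8585)
        client = RPCEndpoint(handle_message, reader, writer, label='client')

        # run() must be going before any round trips can complete.
        run_task = asyncio.create_task(client.run())

        response = await client.send_message(b'hello')
        assert response == b'echo:hello'

        # close() begins shutdown; run() waits out the rest internally.
        client.close()
        await run_task
        server.close()
        await server.wait_closed()


    asyncio.run(main())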
def ssl_stream_writer_underlying_transport_info(writer: asyncio.StreamWriter) -> str
For debugging SSL Stream connections; returns raw transport info.
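A hedged example of where this helper might be dropped in while debugging; the logging call and the writer variable are illustrative only:

    logging.info(
        'rpc writer transport: %s',
        ssl_stream_writer_underlying_transport_info(writer),
    )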
def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None
Ensure a writer is closed; hacky workaround for odd hang.
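Note that the function currently returns immediately (disabled via the `if bool(True)` guard in the source above). When enabled, it is intended to run after the writer has been closed, as RPCEndpoint.__del__ does; a sketch:

    writer.close()
    # Spawns a daemon thread that force-aborts the raw transport if it
    # is still alive after roughly 10 seconds.
    ssl_stream_writer_force_close_check(writer)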
class RPCEndpoint(
    handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    label: str,
    debug_print: bool = False,
    debug_print_io: bool = False,
    debug_print_call: Callable[[str], None] | None = None,
    keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
    keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
)
    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these; otherwise they create
        # dep-loops which keep us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value).
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound.)
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )
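A minimal construction sketch, not taken from the source: it assumes a trivial echo handler and illustrative host/port values, and shows how an endpoint can be wired to an asyncio stream pair. Note that construction must happen on the thread/loop the endpoint will be used from, since __init__ captures current_thread() and the running loop:

import asyncio

from efro.rpc import RPCEndpoint


async def _echo_handler(message: bytes) -> bytes:
    # Trivial handler for illustration; real code would dispatch on
    # the message contents.
    return message


async def make_client_endpoint(host: str, port: int) -> RPCEndpoint:
    reader, writer = await asyncio.open_connection(host, port)
    return RPCEndpoint(
        handle_raw_message_call=_echo_handler,
        reader=reader,
        writer=writer,
        label='client',
    )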
    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise
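Since run() drives the endpoint for the life of the connection, it is typically scheduled as its own task alongside the code doing the talking. A hedged sketch continuing the hypothetical make_client_endpoint() helper above; per the docstrings here, run() takes care of teardown, including wait_closed():

async def main() -> None:
    endpoint = await make_client_endpoint('localhost', 43210)
    run_task = asyncio.create_task(endpoint.run(), name='rpc run')

    # ...exchange messages here...

    endpoint.close()
    # run() handles the rest of teardown, including wait_closed().
    await run_task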
    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it
        # were a pure async call, send order could vary based on how
        # the async tasks happened to be processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16-bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME: should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )
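Usage sketch (the query() helper is illustrative, not part of the module): because send_message() enqueues synchronously and returns an awaitable, calls placed in sequence go out in that order, and the await covers the full round trip:

from efro.error import CommunicationError


async def query(endpoint: RPCEndpoint, payload: bytes) -> bytes | None:
    try:
        # Enqueued synchronously; the await covers the round trip.
        return await endpoint.send_message(payload, timeout=10.0)
    except CommunicationError:
        # Raised when the round trip cannot be completed; by default
        # the endpoint also closes itself on such errors.
        return None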
    def close(self) -> None:
        """Begin the process of closing the endpoint."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call
    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing
    async def wait_closed(self) -> None:
        """Wait for the endpoint to finish closing.

        This is called by run() so generally does not need to be
        explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should break any
        # cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish:'
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from
            # CancelledError (which is a BaseException, not an Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our currently blocked reads/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this
            # to hang indefinitely.
            # See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11, but for now
            # just force the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' this should be fixed in future Python versions.'
                )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True
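A teardown sketch tying these together (the shutdown() helper is illustrative, not part of the module): explicit wait_closed() is mainly useful when closing from outside, since run() already performs it:

async def shutdown(endpoint: RPCEndpoint) -> None:
    # close() is a no-op once closing has begun, so this is safe to
    # call unconditionally.
    if not endpoint.is_closing():
        endpoint.close()
    # Waits for in-flight tasks to wind down and the writer to close.
    await endpoint.wait_closed()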