efro.rpc
Remote procedure call related functionality.
# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""

from __future__ import annotations

import time
import asyncio
import logging
import weakref
from enum import Enum
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated, assert_never

from efro.error import (
    CommunicationError,
    is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
    dataclass_to_json,
    dataclass_from_json,
    ioprepped,
    IOAttrs,
)

if TYPE_CHECKING:
    from typing import Literal, Awaitable, Callable

# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
# payload. Even though we use streams we organize our transmission
# into 'packets'.
# Message: User data which we transmit using one or more packets.


class _PacketType(Enum):
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5


_BYTE_ORDER: Literal['big'] = 'big'


@ioprepped
@dataclass
class _PeerInfo:
    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]


# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) package/response packets
OUR_PROTOCOL = 2


def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """For debugging SSL Stream connections; returns raw transport info."""
    # Note: accessing internals here so just returning info and not
    # actual objs to reduce potential for breakage.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                return str(raw_transport)
    return '(not found)'


def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Ensure a writer is closed; hacky workaround for odd hang."""
    from efro.call import tpartial
    from threading import Thread

    # Disabling for now..
    if bool(True):
        return

    # Hopefully can remove this in Python 3.11?...
    # see issue with is_closing() below for more details.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                Thread(
                    target=tpartial(
                        _do_writer_force_close_check,
                        weakref.ref(raw_transport),
                    ),
                    daemon=True,
                ).start()


def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
    try:
        # Attempt to bail as soon as the obj dies.
        # If it hasn't done so by our timeout, force-kill it.
        starttime = time.monotonic()
        while time.monotonic() - starttime < 10.0:
            time.sleep(0.1)
            if transport_weak() is None:
                return
        transport = transport_weak()
        if transport is not None:
            logging.info('Forcing abort on stuck transport %s.', transport)
            transport.abort()
    except Exception:
        logging.warning('Error in writer-force-close-check', exc_info=True)


class _InFlightMessage:
    """Represents a message that is out on the wire."""

    def __init__(self) -> None:
        self._response: bytes | None = None
        self._got_response = asyncio.Event()
        self.wait_task = asyncio.create_task(
            self._wait(), name='rpc in flight msg wait'
        )

    async def _wait(self) -> bytes:
        await self._got_response.wait()
        assert self._response is not None
        return self._response

    def set_response(self, data: bytes) -> None:
        """Set response data."""
        assert self._response is None
        self._response = data
        self._got_response.set()


class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives."""


class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either direction
    simultaneously, packets are still sent serially in a single
    stream. So excessively long messages/responses will delay all other
    communication. If/when this becomes an issue we can look into breaking up
    long messages into multiple packets.
    """

    # Set to True on an instance to test keepalive failures.
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0

    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these otherwise it creates dep-loops
        # which keeps us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value)
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound).
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )

    def __del__(self) -> None:
        if self._run_called:
            if not self._did_close_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run'
                    ' called but writer not closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )
            elif not self._did_wait_closed_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run called'
                    ' but writer not wait-closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )

        # Currently seeing rare issue where sockets don't go down;
        # let's add a timer to force the issue until we can figure it out.
        ssl_stream_writer_force_close_check(self._writer)

    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise

    async def _do_run(self) -> None:
        self._check_env()

        if self._run_called:
            raise RuntimeError('Run can be called only once per endpoint.')
        self._run_called = True

        core_tasks = [
            asyncio.create_task(
                self._run_core_task('keepalive', self._run_keepalive_task()),
                name='rpc keepalive',
            ),
            asyncio.create_task(
                self._run_core_task('read', self._run_read_task()),
                name='rpc read',
            ),
            asyncio.create_task(
                self._run_core_task('write', self._run_write_task()),
                name='rpc write',
            ),
        ]
        self._tasks += core_tasks

        # Run our core tasks until they all complete.
        results = await asyncio.gather(*core_tasks, return_exceptions=True)

        # Core tasks should handle their own errors; the only ones
        # we expect to bubble up are CancelledError.
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error from %s core task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in core_tasks):
            logging.warning(
                'RPCEndpoint %d: not all core tasks marked done after gather.',
                id(self),
            )

        # Shut ourself down.
        try:
            self.close()
            await self.wait_closed()
        except Exception:
            logging.exception('Error closing %s.', self._label)

        if self.debug_print:
            self.debug_print_call(f'{self._label}: finished.')

    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it were
        # a pure async call it could be possible for send order to vary
        # based on how the async tasks get processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16 bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME - should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )

    async def _send_message(
        self,
        message: bytes,
        timeout: float | None,
        close_on_error: bool,
        bytes_awaitable: asyncio.Task[bytes],
        message_id: int,
    ) -> bytes:
        # We need to know their protocol, so if we haven't gotten a handshake
        # from them yet, just wait.
        while self._peer_info is None:
            await asyncio.sleep(0.01)
        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(message) > 65535:
                raise RuntimeError('Message cannot be larger than 65535 bytes')

        try:
            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
        except asyncio.CancelledError as exc:
            # Question: we assume this means the above wait_for() was
            # cancelled; how do we distinguish between this and *us* being
            # cancelled though?
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: message {message_id} was cancelled.'
                )
            if close_on_error:
                self.close()

            raise CommunicationError() from exc
        except Exception as exc:
            # If our timer timed-out or anything else went wrong with
            # the stream, lump it in as a communication error.
            if isinstance(
                exc, asyncio.TimeoutError
            ) or is_asyncio_streams_communication_error(exc):
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: got {type(exc)} sending message'
                        f' {message_id}; raising CommunicationError.'
                    )

                # Stop waiting on the response.
                bytes_awaitable.cancel()

                # Remove the record of this message.
                del self._in_flight_messages[message_id]

                if close_on_error:
                    self.close()

                # Let the user know something went wrong.
                raise CommunicationError() from exc

            # Some unexpected error; let it bubble up.
            raise

    def close(self) -> None:
        """I said seagulls; mmmm; stop it now."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call

    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing

    async def wait_closed(self) -> None:
        """I said seagulls; mmmm; stop it now.

        Wait for the endpoint to finish closing. This is called by run()
        so generally does not need to be explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should
        # break any cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish: '
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                # timeout=60.0 * 6.0,
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' This should be fixed in future Python versions.'
                )
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

    def _tm(self) -> str:
        """Simple readable time value for debugging."""
        tval = time.monotonic() % 100.0
        return f'{tval:.2f}'

    async def _run_read_task(self) -> None:
        """Read from the peer."""
        self._check_env()
        assert self._peer_info is None

        # Bug fix: if we don't have this set we will never time out
        # if we never receive any data from the other end.
        self._last_keepalive_receive_time = time.monotonic()

        # The first thing they should send us is their handshake; then
        # we'll know if/how we can talk to them.
        mlen = await self._read_int_32()
        message = await self._reader.readexactly(mlen)
        self._total_bytes_read += mlen
        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
        self._last_keepalive_receive_time = time.monotonic()
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: received handshake at {self._tm()}.'
            )

        # Now just sit and handle stuff as it comes in.
        while True:
            if self._closing:
                return

            # Read message type.
            mtype = _PacketType(await self._read_int_8())
            if mtype is _PacketType.HANDSHAKE:
                raise RuntimeError('Got multiple handshakes')

            if mtype is _PacketType.KEEPALIVE:
                if self.debug_print_io:
                    self.debug_print_call(
                        f'{self._label}: received keepalive'
                        f' at {self._tm()}.'
                    )
                self._last_keepalive_receive_time = time.monotonic()

            elif mtype is _PacketType.MESSAGE:
                await self._handle_message_packet(big=False)

            elif mtype is _PacketType.MESSAGE_BIG:
                await self._handle_message_packet(big=True)

            elif mtype is _PacketType.RESPONSE:
                await self._handle_response_packet(big=False)

            elif mtype is _PacketType.RESPONSE_BIG:
                await self._handle_response_packet(big=True)

            else:
                assert_never(mtype)

    async def _handle_message_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        if big:
            msglen = await self._read_int_32()
        else:
            msglen = await self._read_int_16()
        msg = await self._reader.readexactly(msglen)
        self._total_bytes_read += msglen
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received message {msgid}'
                f' of size {msglen} at {self._tm()}.'
            )

        # Create a message-task to handle this message and return
        # a response (we don't want to block while that happens).
        assert not self._closing
        self._prune_tasks()  # Keep from filling with dead tasks.
        self._tasks.append(
            asyncio.create_task(
                self._handle_raw_message(message_id=msgid, message=msg),
                name='efro rpc message handle',
            )
        )
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: done handling message at {self._tm()}.'
            )

    async def _handle_response_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        # Protocol 2 gained 32 bit data lengths.
        if big:
            rsplen = await self._read_int_32()
        else:
            rsplen = await self._read_int_16()
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received response {msgid}'
                f' of size {rsplen} at {self._tm()}.'
            )
        rsp = await self._reader.readexactly(rsplen)
        self._total_bytes_read += rsplen
        msgobj = self._in_flight_messages.get(msgid)
        if msgobj is None:
            # It's possible for us to get a response to a message
            # that has timed out. In this case we will have no local
            # record of it.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got response for nonexistent'
                    f' message id {msgid}; perhaps it timed out?'
                )
        else:
            msgobj.set_response(rsp)

    async def _run_write_task(self) -> None:
        """Write to the peer."""

        self._check_env()

        # Introduce ourself so our peer knows how it can talk to us.
        data = dataclass_to_json(
            _PeerInfo(
                protocol=OUR_PROTOCOL,
                keepalive_interval=self._keepalive_interval,
            )
        ).encode()
        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)

        # Now just write out-messages as they come in.
        while True:
            # Wait until some data comes in.
            await self._have_out_packets.wait()

            assert self._out_packets
            data = self._out_packets.popleft()

            # Important: only clear this once all packets are sent.
            if not self._out_packets:
                self._have_out_packets.clear()

            self._writer.write(data)

            # This should keep our writer from buffering huge amounts
            # of outgoing data. We must remember though that we also
            # need to prevent _out_packets from growing too large and
            # that part's on us.
            await self._writer.drain()

            # For now we're not applying backpressure, but let's make
            # noise if this gets out of hand.
            if len(self._out_packets) > 200:
                if not self._did_out_packets_buildup_warning:
                    logging.warning(
                        '_out_packets building up too'
                        ' much on RPCEndpoint %s.',
                        id(self),
                    )
                    self._did_out_packets_buildup_warning = True

    async def _run_keepalive_task(self) -> None:
        """Send periodic keepalive packets."""
        self._check_env()

        # We explicitly send our own keepalive packets so we can stay
        # more on top of the connection state and possibly decide to
        # kill it when contact is lost more quickly than the OS would
        # do itself (or at least keep the user informed that the
        # connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself but that might be
        # OS-specific so gonna go this way for now.
        while True:
            assert not self._closing
            await asyncio.sleep(self._keepalive_interval)
            if not self.test_suppress_keepalives:
                self._enqueue_outgoing_packet(
                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
                )

            # Also go ahead and handle dropping the connection if we
            # haven't heard from the peer in a while.
            # NOTE: perhaps we want to do something more exact than
            # this which only checks once per keepalive-interval?..
            now = time.monotonic()
            if (
                self._last_keepalive_receive_time is not None
                and now - self._last_keepalive_receive_time
                > self._keepalive_timeout
            ):
                if self.debug_print:
                    since = now - self._last_keepalive_receive_time
                    self.debug_print_call(
                        f'{self._label}: reached keepalive time-out'
                        f' ({since:.1f}s).'
                    )
                raise _KeepaliveTimeoutError()

    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
        try:
            await call
        except Exception as exc:
            # We expect connection errors to put us here, but make noise
            # if something else does.
            if not self._is_expected_connection_error(exc):
                logging.exception(
                    'Unexpected error in rpc %s %s task'
                    ' (age=%.1f, total_bytes_read=%d).',
                    self._label,
                    tasklabel,
                    time.monotonic() - self._create_time,
                    self._total_bytes_read,
                )
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: {tasklabel} task will exit cleanly'
                        f' due to {exc!r}.'
                    )
        finally:
            # Any core task exiting triggers shutdown.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: {tasklabel} task exiting...'
                )
            self.close()

    async def _handle_raw_message(
        self, message_id: int, message: bytes
    ) -> None:
        try:
            response = await self._handle_raw_message_call(message)
        except Exception:
            # We expect local message handler to always succeed.
            # If that doesn't happen, make a fuss so we know to fix it.
            # The other end will simply never get a response to this
            # message.
            logging.exception('Error handling raw rpc message')
            return

        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(response) > 65535:
                raise RuntimeError('Response cannot be larger than 65535 bytes')

        # Now send back our response.
        # Payload consists of type (1b), msgid (2b), len (2b), and data.
        if len(response) > 65535:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(4, _BYTE_ORDER)
                + response
            )
        else:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(2, _BYTE_ORDER)
                + response
            )

    async def _read_int_8(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
        self._total_bytes_read += 1
        return out

    async def _read_int_16(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
        self._total_bytes_read += 2
        return out

    async def _read_int_32(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
        self._total_bytes_read += 4
        return out

    @classmethod
    def _is_expected_connection_error(cls, exc: Exception) -> bool:
        """Stuff we expect to end our connection in normal circumstances."""

        if isinstance(exc, _KeepaliveTimeoutError):
            return True

        return is_asyncio_streams_communication_error(exc)

    def _check_env(self) -> None:
        # I was seeing that asyncio stuff wasn't working as expected if
        # created in one thread and used in another (and have verified
        # that this is part of the design), so let's enforce a single
        # thread for all use of an instance.
        if current_thread() is not self._thread:
            raise RuntimeError(
                'This must be called from the same thread'
                ' that the endpoint was created in.'
            )

        # This should always be the case if thread is the same.
        assert asyncio.get_running_loop() is self._event_loop

    def _enqueue_outgoing_packet(self, data: bytes) -> None:
        """Enqueue a raw packet to be sent. Must be called from our loop."""
        self._check_env()

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueueing outgoing packet'
                f' {data[:50]!r} at {self._tm()}.'
            )

        # Add the data and let our write task know about it.
        self._out_packets.append(data)
        self._have_out_packets.set()

    def _prune_tasks(self) -> None:
        self._tasks = self._get_live_tasks()

    def _get_live_tasks(self) -> list[asyncio.Task]:
        return [t for t in self._tasks if not t.done()]
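As the listing shows, each outgoing message is framed as a one-byte packet type, a 16-bit message id, a 16-bit length (or 32-bit for the BIG variants), and the payload, all big-endian. A minimal standalone sketch of that framing, using assumed constant values that mirror _PacketType above rather than the module's private names:

MESSAGE = 2      # mirrors _PacketType.MESSAGE.value in the listing above
MESSAGE_BIG = 4  # mirrors _PacketType.MESSAGE_BIG.value


def frame_message(message_id: int, payload: bytes) -> bytes:
    """Frame a message the way send_message() does (illustration only)."""
    if len(payload) > 65535:
        return (
            MESSAGE_BIG.to_bytes(1, 'big')
            + message_id.to_bytes(2, 'big')
            + len(payload).to_bytes(4, 'big')
            + payload
        )
    return (
        MESSAGE.to_bytes(1, 'big')
        + message_id.to_bytes(2, 'big')
        + len(payload).to_bytes(2, 'big')
        + payload
    )


print(frame_message(1, b'ping').hex())  # -> '020001000470696e67'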
ssl_stream_writer_underlying_transport_info(writer: asyncio.StreamWriter) -> str
For debugging SSL Stream connections; returns raw transport info.
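A small usage sketch (assumed context, not from the module): the helper only inspects the writer and returns a string, so it is safe to drop into debug logging.

import asyncio
import logging

from efro.rpc import ssl_stream_writer_underlying_transport_info


async def log_transport_info(writer: asyncio.StreamWriter) -> None:
    # Returns '(not found)' for writers without an SSL transport.
    logging.info(
        'transport: %s',
        ssl_stream_writer_underlying_transport_info(writer),
    )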
ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None
Ensure a writer is closed; hacky workaround for odd hang.
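A hedged sketch of where this check might be called during shutdown; note that, per the listing above, the function currently returns early and is effectively a no-op.

import asyncio

from efro.rpc import ssl_stream_writer_force_close_check


async def close_writer(writer: asyncio.StreamWriter) -> None:
    writer.close()
    # Schedules a watchdog thread that would abort a transport stuck in
    # closing (disabled internally at the moment).
    ssl_stream_writer_force_close_check(writer)
    await writer.wait_closed()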
class RPCEndpoint
Facilitates asynchronous multiplexed remote procedure calls.
Be aware that, while multiple calls can be in flight in either direction simultaneously, packets are still sent serially in a single stream. So excessively long messages/responses will delay all other communication. If/when this becomes an issue we can look into breaking up long messages into multiple packets.
179 def __init__( 180 self, 181 handle_raw_message_call: Callable[[bytes], Awaitable[bytes]], 182 reader: asyncio.StreamReader, 183 writer: asyncio.StreamWriter, 184 label: str, 185 debug_print: bool = False, 186 debug_print_io: bool = False, 187 debug_print_call: Callable[[str], None] | None = None, 188 keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL, 189 keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT, 190 ) -> None: 191 self._handle_raw_message_call = handle_raw_message_call 192 self._reader = reader 193 self._writer = writer 194 self.debug_print = debug_print 195 self.debug_print_io = debug_print_io 196 if debug_print_call is None: 197 debug_print_call = print 198 self.debug_print_call: Callable[[str], None] = debug_print_call 199 self._label = label 200 self._thread = current_thread() 201 self._closing = False 202 self._did_wait_closed = False 203 self._event_loop = asyncio.get_running_loop() 204 self._out_packets = deque[bytes]() 205 self._have_out_packets = asyncio.Event() 206 self._run_called = False 207 self._peer_info: _PeerInfo | None = None 208 self._keepalive_interval = keepalive_interval 209 self._keepalive_timeout = keepalive_timeout 210 self._did_close_writer = False 211 self._did_wait_closed_writer = False 212 self._did_out_packets_buildup_warning = False 213 self._total_bytes_read = 0 214 self._create_time = time.monotonic() 215 216 # Need to hold weak-refs to these otherwise it creates dep-loops 217 # which keeps us alive. 218 self._tasks: list[asyncio.Task] = [] 219 220 # When we last got a keepalive or equivalent (time.monotonic value) 221 self._last_keepalive_receive_time: float | None = None 222 223 # (Start near the end to make sure our looping logic is sound). 224 self._next_message_id = 65530 225 226 self._in_flight_messages: dict[int, _InFlightMessage] = {} 227 228 if self.debug_print: 229 peername = self._writer.get_extra_info('peername') 230 self.debug_print_call( 231 f'{self._label}: connected to {peername} at {self._tm()}.' 232 )
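The message-id counter initialized above is a 16-bit looping value, and starting it near 65535 means the wraparound path is exercised almost immediately rather than only after tens of thousands of messages. A tiny sketch of that arithmetic (values chosen purely for illustration):

    next_message_id = 65530  # start near the 16-bit limit on purpose
    ids = []
    for _ in range(8):
        ids.append(next_message_id)
        next_message_id = (next_message_id + 1) % 65536
    assert ids == [65530, 65531, 65532, 65533, 65534, 65535, 0, 1]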
255 async def run(self) -> None: 256 """Run the endpoint until the connection is lost or closed. 257 258 Handles closing the provided reader/writer on close. 259 """ 260 try: 261 await self._do_run() 262 except asyncio.CancelledError: 263 # We aren't really designed to be cancelled so let's warn 264 # if it happens. 265 logging.warning( 266 'RPCEndpoint.run got CancelledError;' 267 ' want to try and avoid this.' 268 ) 269 raise
Run the endpoint until the connection is lost or closed.
Handles closing the provided reader/writer on close.
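As a rough usage sketch, assuming a trivial echo handler and a plain TCP connection (apart from RPCEndpoint itself, none of the names below come from this module), an endpoint is constructed inside a running event loop and then driven by run() as a background task:

    import asyncio
    from efro.rpc import RPCEndpoint

    async def handle_message(message: bytes) -> bytes:
        # Application-level handler; here we simply echo the payload back.
        return b'echo:' + message

    async def connect(host: str, port: int) -> RPCEndpoint:
        reader, writer = await asyncio.open_connection(host, port)
        endpoint = RPCEndpoint(
            handle_raw_message_call=handle_message,
            reader=reader,
            writer=writer,
            label='client',
        )
        # run() drives reads, writes, and keepalives until the connection
        # is lost or close() is called. In real code, hold a reference to
        # this task so it is not garbage-collected.
        asyncio.create_task(endpoint.run())
        return endpoint

Note that the endpoint must be created and used from the same thread and event loop; _check_env() enforces this.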
325 def send_message( 326 self, 327 message: bytes, 328 timeout: float | None = None, 329 close_on_error: bool = True, 330 ) -> Awaitable[bytes]: 331 """Send a message to the peer and return a response. 332 333 If timeout is not provided, the default will be used. 334 Raises a CommunicationError if the round trip is not completed 335 for any reason. 336 337 By default, the entire endpoint will go down in the case of 338 errors. This allows messages to be treated as 'reliable' with 339 respect to a given endpoint. Pass close_on_error=False to 340 override this for a particular message. 341 """ 342 # Note: This call is synchronous so that the first part of it 343 # (enqueueing outgoing messages) happens synchronously. If it were 344 # a pure async call it could be possible for send order to vary 345 # based on how the async tasks get processed. 346 347 if self.debug_print_io: 348 self.debug_print_call( 349 f'{self._label}: sending message of size {len(message)}' 350 f' at {self._tm()}.' 351 ) 352 353 self._check_env() 354 355 if self._closing: 356 raise CommunicationError('Endpoint is closed.') 357 358 if self.debug_print_io: 359 self.debug_print_call( 360 f'{self._label}: have peerinfo? {self._peer_info is not None}.' 361 ) 362 363 # message_id is a 16 bit looping value. 364 message_id = self._next_message_id 365 self._next_message_id = (self._next_message_id + 1) % 65536 366 367 if self.debug_print_io: 368 self.debug_print_call( 369 f'{self._label}: will enqueue at {self._tm()}.' 370 ) 371 372 # FIXME - should handle backpressure (waiting here if there are 373 # enough packets already enqueued). 374 375 if len(message) > 65535: 376 # Payload consists of type (1b), message_id (2b), 377 # len (4b), and data. 378 self._enqueue_outgoing_packet( 379 _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER) 380 + message_id.to_bytes(2, _BYTE_ORDER) 381 + len(message).to_bytes(4, _BYTE_ORDER) 382 + message 383 ) 384 else: 385 # Payload consists of type (1b), message_id (2b), 386 # len (2b), and data. 387 self._enqueue_outgoing_packet( 388 _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER) 389 + message_id.to_bytes(2, _BYTE_ORDER) 390 + len(message).to_bytes(2, _BYTE_ORDER) 391 + message 392 ) 393 394 if self.debug_print_io: 395 self.debug_print_call( 396 f'{self._label}: enqueued message of size {len(message)}' 397 f' at {self._tm()}.' 398 ) 399 400 # Make an entry so we know this message is out there. 401 assert message_id not in self._in_flight_messages 402 msgobj = self._in_flight_messages[message_id] = _InFlightMessage() 403 404 # Also add its task to our list so we properly cancel it if we die. 405 self._prune_tasks() # Keep our list from filling with dead tasks. 406 self._tasks.append(msgobj.wait_task) 407 408 # Note: we always want to incorporate a timeout. Individual 409 # messages may hang or error on the other end and this ensures 410 # we won't build up lots of zombie tasks waiting around for 411 # responses that will never arrive. 412 if timeout is None: 413 timeout = self.DEFAULT_MESSAGE_TIMEOUT 414 assert timeout is not None 415 416 bytes_awaitable = msgobj.wait_task 417 418 # Now complete the send asynchronously. 419 return self._send_message( 420 message, timeout, close_on_error, bytes_awaitable, message_id 421 )
Send a message to the peer and return a response.
If timeout is not provided, the default will be used. Raises a CommunicationError if the round trip is not completed for any reason.
By default, the entire endpoint will go down in the case of errors. This allows messages to be treated as 'reliable' with respect to a given endpoint. Pass close_on_error=False to override this for a particular message.
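A hedged sketch of calling it (the surrounding setup is assumed; only RPCEndpoint, send_message, and CommunicationError come from this codebase):

    from efro.error import CommunicationError
    from efro.rpc import RPCEndpoint

    async def ask_peer(endpoint: RPCEndpoint, payload: bytes) -> bytes | None:
        try:
            # send_message() enqueues synchronously (preserving send order)
            # and returns an awaitable for the peer's response.
            return await endpoint.send_message(payload, timeout=10.0)
        except CommunicationError:
            # The round trip failed (timeout, connection loss, etc.).
            # With close_on_error=True (the default) the endpoint is
            # closed at this point.
            return None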
482 def close(self) -> None: 483 """I said seagulls; mmmm; stop it now.""" 484 self._check_env() 485 486 if self._closing: 487 return 488 489 if self.debug_print: 490 self.debug_print_call(f'{self._label}: closing...') 491 492 self._closing = True 493 494 # Kill all of our in-flight tasks. 495 if self.debug_print: 496 self.debug_print_call(f'{self._label}: cancelling tasks...') 497 for task in self._get_live_tasks(): 498 task.cancel() 499 500 # Close our writer. 501 assert not self._did_close_writer 502 if self.debug_print: 503 self.debug_print_call(f'{self._label}: closing writer...') 504 self._writer.close() 505 self._did_close_writer = True 506 507 # We don't need this anymore and it is likely to be creating a 508 # dependency loop. 509 del self._handle_raw_message_call
I said seagulls; mmmm; stop it now. Begins closing the endpoint: in-flight tasks are cancelled and the underlying writer is closed.
511 def is_closing(self) -> bool: 512 """Have we begun the process of closing?""" 513 return self._closing
Have we begun the process of closing?
515 async def wait_closed(self) -> None: 516 """I said seagulls; mmmm; stop it now. 517 518 Wait for the endpoint to finish closing. This is called by run() 519 so generally does not need to be explicitly called. 520 """ 521 # pylint: disable=too-many-branches 522 self._check_env() 523 524 # Make sure we only *enter* this call once. 525 if self._did_wait_closed: 526 return 527 self._did_wait_closed = True 528 529 if not self._closing: 530 raise RuntimeError('Must be called after close()') 531 532 if not self._did_close_writer: 533 logging.warning( 534 'RPCEndpoint wait_closed() called but never' 535 ' explicitly closed writer.' 536 ) 537 538 live_tasks = self._get_live_tasks() 539 540 # Don't need our task list anymore; this should 541 # break any cyclical refs from tasks referring to us. 542 self._tasks = [] 543 544 if self.debug_print: 545 self.debug_print_call( 546 f'{self._label}: waiting for tasks to finish: ' 547 f' ({live_tasks=})...' 548 ) 549 550 # Wait for all of our in-flight tasks to wrap up. 551 results = await asyncio.gather(*live_tasks, return_exceptions=True) 552 for result in results: 553 # We want to know if any errors happened aside from CancelledError 554 # (which are BaseExceptions, not Exception). 555 if isinstance(result, Exception): 556 logging.warning( 557 'Got unexpected error cleaning up %s task: %s', 558 self._label, 559 result, 560 ) 561 562 if not all(task.done() for task in live_tasks): 563 logging.warning( 564 'RPCEndpoint %d: not all live tasks marked done after gather.', 565 id(self), 566 ) 567 568 if self.debug_print: 569 self.debug_print_call( 570 f'{self._label}: tasks finished; waiting for writer close...' 571 ) 572 573 # Now wait for our writer to finish going down. 574 # When we close our writer it generally triggers errors 575 # in our current blocked read/writes. However that same 576 # error is also sometimes returned from _writer.wait_closed(). 577 # See connection_lost() in asyncio/streams.py to see why. 578 # So let's silently ignore it when that happens. 579 assert self._writer.is_closing() 580 try: 581 # It seems that as of Python 3.9.x it is possible for this to hang 582 # indefinitely. See https://github.com/python/cpython/issues/83939 583 # It sounds like this should be fixed in 3.11 but for now just 584 # forcing the issue with a timeout here. 585 await asyncio.wait_for( 586 self._writer.wait_closed(), 587 # timeout=60.0 * 6.0, 588 timeout=30.0, 589 ) 590 except asyncio.TimeoutError: 591 logging.info( 592 'Timeout on _writer.wait_closed() for %s rpc (transport=%s).', 593 self._label, 594 ssl_stream_writer_underlying_transport_info(self._writer), 595 ) 596 if self.debug_print: 597 self.debug_print_call( 598 f'{self._label}: got timeout in _writer.wait_closed();' 599 ' This should be fixed in future Python versions.' 600 ) 601 except Exception as exc: 602 if not self._is_expected_connection_error(exc): 603 logging.exception('Error closing _writer for %s.', self._label) 604 else: 605 if self.debug_print: 606 self.debug_print_call( 607 f'{self._label}: silently ignoring error in' 608 f' _writer.wait_closed(): {exc}.' 609 ) 610 except asyncio.CancelledError: 611 logging.warning( 612 'RPCEndpoint.wait_closed()' 613 ' got asyncio.CancelledError; not expected.' 614 ) 615 raise 616 assert not self._did_wait_closed_writer 617 self._did_wait_closed_writer = True
I said seagulls; mmmm; stop it now.
Wait for the endpoint to finish closing. This is called by run() so generally does not need to be explicitly called.
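Shutdown therefore follows a close-then-wait pattern; a minimal sketch (endpoint construction assumed elsewhere, and both calls made from the thread and loop the endpoint was created in):

    from efro.rpc import RPCEndpoint

    async def shutdown(endpoint: RPCEndpoint) -> None:
        if not endpoint.is_closing():
            endpoint.close()          # cancel in-flight tasks, close the writer
        await endpoint.wait_closed()  # wait for tasks and the writer to wind down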