efro.rpc

Remote procedure call related functionality.

# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""

from __future__ import annotations

import time
import asyncio
import logging
import weakref
from enum import Enum
from functools import partial
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated, assert_never

from efro.error import (
    CommunicationError,
    is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
    dataclass_to_json,
    dataclass_from_json,
    ioprepped,
    IOAttrs,
)

if TYPE_CHECKING:
    from typing import Literal, Awaitable, Callable

# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
#         payload. Even though we use streams we organize our transmission
#         into 'packets'.
# Message: User data which we transmit using one or more packets.

class _PacketType(Enum):
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5


_BYTE_ORDER: Literal['big'] = 'big'


@ioprepped
@dataclass
class _PeerInfo:
    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]


# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) message/response packets
OUR_PROTOCOL = 2
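# (Compatibility with protocol 1 is enforced below: _send_message() and
# _handle_raw_message() refuse payloads over 65535 bytes when the peer
# reports protocol 1, since it predates the BIG packet types.)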


def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """For debugging SSL Stream connections; returns raw transport info."""
    # Note: accessing internals here so just returning info and not
    # actual objs to reduce potential for breakage.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                return str(raw_transport)
    return '(not found)'


def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Ensure a writer is closed; hacky workaround for odd hang."""
    from threading import Thread

    # Disabling for now..
    if bool(True):
        return

    # Hopefully can remove this in Python 3.11?...
    # see issue with is_closing() below for more details.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                Thread(
                    target=partial(
                        _do_writer_force_close_check, weakref.ref(raw_transport)
                    ),
                    daemon=True,
                ).start()


def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
    try:
        # Attempt to bail as soon as the obj dies.
        # If it hasn't done so by our timeout, force-kill it.
        starttime = time.monotonic()
        while time.monotonic() - starttime < 10.0:
            time.sleep(0.1)
            if transport_weak() is None:
                return
        transport = transport_weak()
        if transport is not None:
            logging.info('Forcing abort on stuck transport %s.', transport)
            transport.abort()
    except Exception:
        logging.warning('Error in writer-force-close-check', exc_info=True)


class _InFlightMessage:
    """Represents a message that is out on the wire."""

    def __init__(self) -> None:
        self._response: bytes | None = None
        self._got_response = asyncio.Event()
        self.wait_task = asyncio.create_task(
            self._wait(), name='rpc in flight msg wait'
        )

    async def _wait(self) -> bytes:
        await self._got_response.wait()
        assert self._response is not None
        return self._response

    def set_response(self, data: bytes) -> None:
        """Set response data."""
        assert self._response is None
        self._response = data
        self._got_response.set()


class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives."""


class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either direction
    simultaneously, packets are still sent serially in a single
    stream. So excessively long messages/responses will delay all other
    communication. If/when this becomes an issue we can look into breaking up
    long messages into multiple packets.
    """

    # Set to True on an instance to test keepalive failures.
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid an overly regular value)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0
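    # (With the defaults above, a silent peer is dropped after roughly
    # three missed keepalives: 30.0 / 10.73 is about 2.8 intervals.)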

    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        *,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these; otherwise they create dep-loops
        # which keep us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value).
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound.)
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )

    def __del__(self) -> None:
        if self._run_called:
            if not self._did_close_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run'
                    ' called but writer not closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )
            elif not self._did_wait_closed_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run called'
                    ' but writer not wait-closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )

        # Currently seeing rare issue where sockets don't go down;
        # let's add a timer to force the issue until we can figure it out.
        ssl_stream_writer_force_close_check(self._writer)

    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise

    async def _do_run(self) -> None:
        self._check_env()

        if self._run_called:
            raise RuntimeError('Run can be called only once per endpoint.')
        self._run_called = True

        core_tasks = [
            asyncio.create_task(
                self._run_core_task('keepalive', self._run_keepalive_task()),
                name='rpc keepalive',
            ),
            asyncio.create_task(
                self._run_core_task('read', self._run_read_task()),
                name='rpc read',
            ),
            asyncio.create_task(
                self._run_core_task('write', self._run_write_task()),
                name='rpc write',
            ),
        ]
        self._tasks += core_tasks

        # Run our core tasks until they all complete.
        results = await asyncio.gather(*core_tasks, return_exceptions=True)

        # Core tasks should handle their own errors; the only ones
        # we expect to bubble up are CancelledError.
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error from %s core task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in core_tasks):
            logging.warning(
                'RPCEndpoint %d: not all core tasks marked done after gather.',
                id(self),
            )

        # Shut ourself down.
        try:
            self.close()
            await self.wait_closed()
        except Exception:
            logging.exception('Error closing %s.', self._label)

        if self.debug_print:
            self.debug_print_call(f'{self._label}: finished.')

    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it were
        # a pure async call it could be possible for send order to vary
        # based on how the async tasks get processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16 bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536
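        # (For example: from the initial value of 65530 the ids run
        # 65530 ... 65535 and then wrap around to 0.)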

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME - should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )

    async def _send_message(
        self,
        message: bytes,
        timeout: float | None,
        close_on_error: bool,
        bytes_awaitable: asyncio.Task[bytes],
        message_id: int,
    ) -> bytes:
        # pylint: disable=too-many-positional-arguments
        # We need to know their protocol, so if we haven't gotten a handshake
        # from them yet, just wait.
        while self._peer_info is None:
            await asyncio.sleep(0.01)
        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(message) > 65535:
                raise RuntimeError('Message cannot be larger than 65535 bytes')

        try:
            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
        except asyncio.CancelledError as exc:
            # Question: we assume this means the above wait_for() was
            # cancelled; how do we distinguish between this and *us* being
            # cancelled though?
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: message {message_id} was cancelled.'
                )
            if close_on_error:
                self.close()

            raise CommunicationError() from exc
        except Exception as exc:
            # If our timer timed-out or anything else went wrong with
            # the stream, lump it in as a communication error.
            if isinstance(
                exc, asyncio.TimeoutError
            ) or is_asyncio_streams_communication_error(exc):
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: got {type(exc)} sending message'
                        f' {message_id}; raising CommunicationError.'
                    )

                # Stop waiting on the response.
                bytes_awaitable.cancel()

                # Remove the record of this message.
                del self._in_flight_messages[message_id]

                if close_on_error:
                    self.close()

                # Let the user know something went wrong.
                raise CommunicationError() from exc

            # Some unexpected error; let it bubble up.
            raise

    def close(self) -> None:
        """Begin closing the endpoint."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call

    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing

    async def wait_closed(self) -> None:
        """Wait for the endpoint to finish closing.

        This is called by run() so generally does not need to be
        explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should
        # break any cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish'
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' this should be fixed in future Python versions.'
                )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

    def _tm(self) -> str:
        """Simple readable time value for debugging."""
        tval = time.monotonic() % 100.0
        return f'{tval:.2f}'

    async def _run_read_task(self) -> None:
        """Read from the peer."""
        self._check_env()
        assert self._peer_info is None

        # Bug fix: if we don't have this set we will never time out
        # if we never receive any data from the other end.
        self._last_keepalive_receive_time = time.monotonic()

        # The first thing they should send us is their handshake; then
        # we'll know if/how we can talk to them.
        mlen = await self._read_int_32()
        message = await self._reader.readexactly(mlen)
        self._total_bytes_read += mlen
        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
        self._last_keepalive_receive_time = time.monotonic()
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: received handshake at {self._tm()}.'
            )

        # Now just sit and handle stuff as it comes in.
        while True:
            if self._closing:
                return

            # Read message type.
            mtype = _PacketType(await self._read_int_8())
            if mtype is _PacketType.HANDSHAKE:
                raise RuntimeError('Got multiple handshakes')

            if mtype is _PacketType.KEEPALIVE:
                if self.debug_print_io:
                    self.debug_print_call(
                        f'{self._label}: received keepalive'
                        f' at {self._tm()}.'
                    )
                self._last_keepalive_receive_time = time.monotonic()

            elif mtype is _PacketType.MESSAGE:
                await self._handle_message_packet(big=False)

            elif mtype is _PacketType.MESSAGE_BIG:
                await self._handle_message_packet(big=True)

            elif mtype is _PacketType.RESPONSE:
                await self._handle_response_packet(big=False)

            elif mtype is _PacketType.RESPONSE_BIG:
                await self._handle_response_packet(big=True)

            else:
                assert_never(mtype)

    async def _handle_message_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        if big:
            msglen = await self._read_int_32()
        else:
            msglen = await self._read_int_16()
        msg = await self._reader.readexactly(msglen)
        self._total_bytes_read += msglen
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received message {msgid}'
                f' of size {msglen} at {self._tm()}.'
            )

        # Create a message-task to handle this message and return
        # a response (we don't want to block while that happens).
        assert not self._closing
        self._prune_tasks()  # Keep from filling with dead tasks.
        self._tasks.append(
            asyncio.create_task(
                self._handle_raw_message(message_id=msgid, message=msg),
                name='efro rpc message handle',
            )
        )
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: done handling message at {self._tm()}.'
            )

    async def _handle_response_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        # Protocol 2 gained 32 bit data lengths.
        if big:
            rsplen = await self._read_int_32()
        else:
            rsplen = await self._read_int_16()
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received response {msgid}'
                f' of size {rsplen} at {self._tm()}.'
            )
        rsp = await self._reader.readexactly(rsplen)
        self._total_bytes_read += rsplen
        msgobj = self._in_flight_messages.get(msgid)
        if msgobj is None:
            # It's possible for us to get a response to a message
            # that has timed out. In this case we will have no local
            # record of it.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got response for nonexistent'
                    f' message id {msgid}; perhaps it timed out?'
                )
        else:
            msgobj.set_response(rsp)

    async def _run_write_task(self) -> None:
        """Write to the peer."""

        self._check_env()

        # Introduce ourself so our peer knows how it can talk to us.
        data = dataclass_to_json(
            _PeerInfo(
                protocol=OUR_PROTOCOL,
                keepalive_interval=self._keepalive_interval,
            )
        ).encode()
        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
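        # (The handshake is the one packet without a type byte: a 4-byte
        # big-endian length followed by the JSON-encoded _PeerInfo; e.g.
        # something like {"p": 2, "k": 10.73} with the defaults.)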

        # Now just write out-messages as they come in.
        while True:
            # Wait until some data comes in.
            await self._have_out_packets.wait()

            assert self._out_packets
            data = self._out_packets.popleft()

            # Important: only clear this once all packets are sent.
            if not self._out_packets:
                self._have_out_packets.clear()

            self._writer.write(data)

            # This should keep our writer from buffering huge amounts
            # of outgoing data. We must remember though that we also
            # need to prevent _out_packets from growing too large and
            # that part's on us.
            await self._writer.drain()

            # For now we're not applying backpressure, but let's make
            # noise if this gets out of hand.
            if len(self._out_packets) > 200:
                if not self._did_out_packets_buildup_warning:
                    logging.warning(
                        '_out_packets building up too'
                        ' much on RPCEndpoint %s.',
                        id(self),
                    )
                    self._did_out_packets_buildup_warning = True

    async def _run_keepalive_task(self) -> None:
        """Send periodic keepalive packets."""
        self._check_env()

        # We explicitly send our own keepalive packets so we can stay
        # more on top of the connection state and possibly decide to
        # kill it when contact is lost more quickly than the OS would
        # do itself (or at least keep the user informed that the
        # connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself but that might be
        # OS-specific so gonna go this way for now.
        while True:
            assert not self._closing
            await asyncio.sleep(self._keepalive_interval)
            if not self.test_suppress_keepalives:
                self._enqueue_outgoing_packet(
                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
                )

            # Also go ahead and handle dropping the connection if we
            # haven't heard from the peer in a while.
            # NOTE: perhaps we want to do something more exact than
            # this which only checks once per keepalive-interval?..
            now = time.monotonic()
            if (
                self._last_keepalive_receive_time is not None
                and now - self._last_keepalive_receive_time
                > self._keepalive_timeout
            ):
                if self.debug_print:
                    since = now - self._last_keepalive_receive_time
                    self.debug_print_call(
                        f'{self._label}: reached keepalive time-out'
                        f' ({since:.1f}s).'
                    )
                raise _KeepaliveTimeoutError()

    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
        try:
            await call
        except Exception as exc:
            # We expect connection errors to put us here, but make noise
            # if something else does.
            if not self._is_expected_connection_error(exc):
                logging.exception(
                    'Unexpected error in rpc %s %s task'
                    ' (age=%.1f, total_bytes_read=%d).',
                    self._label,
                    tasklabel,
                    time.monotonic() - self._create_time,
                    self._total_bytes_read,
                )
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: {tasklabel} task will exit cleanly'
                        f' due to {exc!r}.'
                    )
        finally:
            # Any core task exiting triggers shutdown.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: {tasklabel} task exiting...'
                )
            self.close()

    async def _handle_raw_message(
        self, message_id: int, message: bytes
    ) -> None:
        try:
            response = await self._handle_raw_message_call(message)
        except Exception:
            # We expect the local message handler to always succeed.
            # If that doesn't happen, make a fuss so we know to fix it.
            # The other end will simply never get a response to this
            # message.
            logging.exception('Error handling raw rpc message')
            return

        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(response) > 65535:
                raise RuntimeError('Response cannot be larger than 65535 bytes')

        # Now send back our response.
        # Payload consists of type (1b), msgid (2b), len (2b or 4b), and data.
        if len(response) > 65535:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(4, _BYTE_ORDER)
                + response
            )
        else:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(2, _BYTE_ORDER)
                + response
            )

    async def _read_int_8(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
        self._total_bytes_read += 1
        return out

    async def _read_int_16(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
        self._total_bytes_read += 2
        return out

    async def _read_int_32(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
        self._total_bytes_read += 4
        return out

    @classmethod
    def _is_expected_connection_error(cls, exc: Exception) -> bool:
        """Stuff we expect to end our connection in normal circumstances."""

        if isinstance(exc, _KeepaliveTimeoutError):
            return True

        return is_asyncio_streams_communication_error(exc)

    def _check_env(self) -> None:
        # I was seeing that asyncio stuff wasn't working as expected if
        # created in one thread and used in another (and have verified
        # that this is part of the design), so let's enforce a single
        # thread for all use of an instance.
        if current_thread() is not self._thread:
            raise RuntimeError(
                'This must be called from the same thread'
                ' that the endpoint was created in.'
            )

        # This should always be the case if thread is the same.
        assert asyncio.get_running_loop() is self._event_loop

    def _enqueue_outgoing_packet(self, data: bytes) -> None:
        """Enqueue a raw packet to be sent. Must be called from our loop."""
        self._check_env()

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueueing outgoing packet'
                f' {data[:50]!r} at {self._tm()}.'
            )

        # Add the data and let our write task know about it.
        self._out_packets.append(data)
        self._have_out_packets.set()

    def _prune_tasks(self) -> None:
        self._tasks = self._get_live_tasks()

    def _get_live_tasks(self) -> list[asyncio.Task]:
        return [t for t in self._tasks if not t.done()]
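
As a usage illustration, here is a minimal sketch of a client endpoint wired to a plain asyncio TCP connection. The handler, host, and port are placeholders, and the peer at the other end is assumed to be running its own RPCEndpoint:

async def _handle_message(message: bytes) -> bytes:
    # Application-level handler for incoming messages; this one just echoes.
    return message

async def demo() -> None:
    reader, writer = await asyncio.open_connection('127.0.0.1', 8888)
    endpoint = RPCEndpoint(_handle_message, reader, writer, label='demo')

    # run() owns the connection; keep it going in the background while
    # we make calls against the endpoint.
    run_task = asyncio.create_task(endpoint.run(), name='demo endpoint run')
    try:
        response = await endpoint.send_message(b'ping', timeout=10.0)
        print(f'got response: {response!r}')
    finally:
        endpoint.close()
        await endpoint.wait_closed()
        await run_task

asyncio.run(demo())

Note that calling close() and wait_closed() here is safe even though run() performs its own shutdown; the _closing and _did_wait_closed guards make those calls idempotent.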
OUR_PROTOCOL = 2
def ssl_stream_writer_underlying_transport_info(writer: asyncio.streams.StreamWriter) -> str:
71def ssl_stream_writer_underlying_transport_info(
72    writer: asyncio.StreamWriter,
73) -> str:
74    """For debugging SSL Stream connections; returns raw transport info."""
75    # Note: accessing internals here so just returning info and not
76    # actual objs to reduce potential for breakage.
77    transport = getattr(writer, '_transport', None)
78    if transport is not None:
79        sslproto = getattr(transport, '_ssl_protocol', None)
80        if sslproto is not None:
81            raw_transport = getattr(sslproto, '_transport', None)
82            if raw_transport is not None:
83                return str(raw_transport)
84    return '(not found)'

For debugging SSL Stream connections; returns raw transport info.

def ssl_stream_writer_force_close_check(writer: asyncio.streams.StreamWriter) -> None:
 87def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
 88    """Ensure a writer is closed; hacky workaround for odd hang."""
 89    from threading import Thread
 90
 91    # Disabling for now..
 92    if bool(True):
 93        return
 94
 95    # Hopefully can remove this in Python 3.11?...
 96    # see issue with is_closing() below for more details.
 97    transport = getattr(writer, '_transport', None)
 98    if transport is not None:
 99        sslproto = getattr(transport, '_ssl_protocol', None)
100        if sslproto is not None:
101            raw_transport = getattr(sslproto, '_transport', None)
102            if raw_transport is not None:
103                Thread(
104                    target=partial(
105                        _do_writer_force_close_check, weakref.ref(raw_transport)
106                    ),
107                    daemon=True,
108                ).start()

Ensure a writer is closed; hacky workaround for odd hang.

class RPCEndpoint:
154class RPCEndpoint:
155    """Facilitates asynchronous multiplexed remote procedure calls.
156
157    Be aware that, while multiple calls can be in flight in either direction
158    simultaneously, packets are still sent serially in a single
159    stream. So excessively long messages/responses will delay all other
160    communication. If/when this becomes an issue we can look into breaking up
161    long messages into multiple packets.
162    """
163
164    # Set to True on an instance to test keepalive failures.
165    test_suppress_keepalives: bool = False
166
167    # How long we should wait before giving up on a message by default.
168    # Note this includes processing time on the other end.
169    DEFAULT_MESSAGE_TIMEOUT = 60.0
170
171    # How often we send out keepalive packets by default.
172    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)
173
174    # How long we can go without receiving a keepalive packet before we
175    # disconnect.
176    DEFAULT_KEEPALIVE_TIMEOUT = 30.0
177
178    def __init__(
179        self,
180        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
181        reader: asyncio.StreamReader,
182        writer: asyncio.StreamWriter,
183        label: str,
184        *,
185        debug_print: bool = False,
186        debug_print_io: bool = False,
187        debug_print_call: Callable[[str], None] | None = None,
188        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
189        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
190    ) -> None:
191        self._handle_raw_message_call = handle_raw_message_call
192        self._reader = reader
193        self._writer = writer
194        self.debug_print = debug_print
195        self.debug_print_io = debug_print_io
196        if debug_print_call is None:
197            debug_print_call = print
198        self.debug_print_call: Callable[[str], None] = debug_print_call
199        self._label = label
200        self._thread = current_thread()
201        self._closing = False
202        self._did_wait_closed = False
203        self._event_loop = asyncio.get_running_loop()
204        self._out_packets = deque[bytes]()
205        self._have_out_packets = asyncio.Event()
206        self._run_called = False
207        self._peer_info: _PeerInfo | None = None
208        self._keepalive_interval = keepalive_interval
209        self._keepalive_timeout = keepalive_timeout
210        self._did_close_writer = False
211        self._did_wait_closed_writer = False
212        self._did_out_packets_buildup_warning = False
213        self._total_bytes_read = 0
214        self._create_time = time.monotonic()
215
216        # Need to hold weak-refs to these otherwise it creates dep-loops
217        # which keeps us alive.
218        self._tasks: list[asyncio.Task] = []
219
220        # When we last got a keepalive or equivalent (time.monotonic value)
221        self._last_keepalive_receive_time: float | None = None
222
223        # (Start near the end to make sure our looping logic is sound).
224        self._next_message_id = 65530
225
226        self._in_flight_messages: dict[int, _InFlightMessage] = {}
227
228        if self.debug_print:
229            peername = self._writer.get_extra_info('peername')
230            self.debug_print_call(
231                f'{self._label}: connected to {peername} at {self._tm()}.'
232            )
233
234    def __del__(self) -> None:
235        if self._run_called:
236            if not self._did_close_writer:
237                logging.warning(
238                    'RPCEndpoint %d dying with run'
239                    ' called but writer not closed (transport=%s).',
240                    id(self),
241                    ssl_stream_writer_underlying_transport_info(self._writer),
242                )
243            elif not self._did_wait_closed_writer:
244                logging.warning(
245                    'RPCEndpoint %d dying with run called'
246                    ' but writer not wait-closed (transport=%s).',
247                    id(self),
248                    ssl_stream_writer_underlying_transport_info(self._writer),
249                )
250
251        # Currently seeing rare issue where sockets don't go down;
252        # let's add a timer to force the issue until we can figure it out.
253        ssl_stream_writer_force_close_check(self._writer)
254
255    async def run(self) -> None:
256        """Run the endpoint until the connection is lost or closed.
257
258        Handles closing the provided reader/writer on close.
259        """
260        try:
261            await self._do_run()
262        except asyncio.CancelledError:
263            # We aren't really designed to be cancelled so let's warn
264            # if it happens.
265            logging.warning(
266                'RPCEndpoint.run got CancelledError;'
267                ' want to try and avoid this.'
268            )
269            raise
270
271    async def _do_run(self) -> None:
272        self._check_env()
273
274        if self._run_called:
275            raise RuntimeError('Run can be called only once per endpoint.')
276        self._run_called = True
277
278        core_tasks = [
279            asyncio.create_task(
280                self._run_core_task('keepalive', self._run_keepalive_task()),
281                name='rpc keepalive',
282            ),
283            asyncio.create_task(
284                self._run_core_task('read', self._run_read_task()),
285                name='rpc read',
286            ),
287            asyncio.create_task(
288                self._run_core_task('write', self._run_write_task()),
289                name='rpc write',
290            ),
291        ]
292        self._tasks += core_tasks
293
294        # Run our core tasks until they all complete.
295        results = await asyncio.gather(*core_tasks, return_exceptions=True)
296
297        # Core tasks should handle their own errors; the only ones
298        # we expect to bubble up are CancelledError.
299        for result in results:
300            # We want to know if any errors happened aside from CancelledError
301            # (which are BaseExceptions, not Exception).
302            if isinstance(result, Exception):
303                logging.warning(
304                    'Got unexpected error from %s core task: %s',
305                    self._label,
306                    result,
307                )
308
309        if not all(task.done() for task in core_tasks):
310            logging.warning(
311                'RPCEndpoint %d: not all core tasks marked done after gather.',
312                id(self),
313            )
314
315        # Shut ourself down.
316        try:
317            self.close()
318            await self.wait_closed()
319        except Exception:
320            logging.exception('Error closing %s.', self._label)
321
322        if self.debug_print:
323            self.debug_print_call(f'{self._label}: finished.')
324
325    def send_message(
326        self,
327        message: bytes,
328        timeout: float | None = None,
329        close_on_error: bool = True,
330    ) -> Awaitable[bytes]:
331        """Send a message to the peer and return a response.
332
333        If timeout is not provided, the default will be used.
334        Raises a CommunicationError if the round trip is not completed
335        for any reason.
336
337        By default, the entire endpoint will go down in the case of
338        errors. This allows messages to be treated as 'reliable' with
339        respect to a given endpoint. Pass close_on_error=False to
340        override this for a particular message.
341        """
342        # Note: This call is synchronous so that the first part of it
343        # (enqueueing outgoing messages) happens synchronously. If it were
344        # a pure async call it could be possible for send order to vary
345        # based on how the async tasks get processed.
346
347        if self.debug_print_io:
348            self.debug_print_call(
349                f'{self._label}: sending message of size {len(message)}'
350                f' at {self._tm()}.'
351            )
352
353        self._check_env()
354
355        if self._closing:
356            raise CommunicationError('Endpoint is closed.')
357
358        if self.debug_print_io:
359            self.debug_print_call(
360                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
361            )
362
363        # message_id is a 16 bit looping value.
364        message_id = self._next_message_id
365        self._next_message_id = (self._next_message_id + 1) % 65536
366
367        if self.debug_print_io:
368            self.debug_print_call(
369                f'{self._label}: will enqueue at {self._tm()}.'
370            )
371
372        # FIXME - should handle backpressure (waiting here if there are
373        # enough packets already enqueued).
374
375        if len(message) > 65535:
376            # Payload consists of type (1b), message_id (2b),
377            # len (4b), and data.
378            self._enqueue_outgoing_packet(
379                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
380                + message_id.to_bytes(2, _BYTE_ORDER)
381                + len(message).to_bytes(4, _BYTE_ORDER)
382                + message
383            )
384        else:
385            # Payload consists of type (1b), message_id (2b),
386            # len (2b), and data.
387            self._enqueue_outgoing_packet(
388                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
389                + message_id.to_bytes(2, _BYTE_ORDER)
390                + len(message).to_bytes(2, _BYTE_ORDER)
391                + message
392            )
393
394        if self.debug_print_io:
395            self.debug_print_call(
396                f'{self._label}: enqueued message of size {len(message)}'
397                f' at {self._tm()}.'
398            )
399
400        # Make an entry so we know this message is out there.
401        assert message_id not in self._in_flight_messages
402        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
403
404        # Also add its task to our list so we properly cancel it if we die.
405        self._prune_tasks()  # Keep our list from filling with dead tasks.
406        self._tasks.append(msgobj.wait_task)
407
408        # Note: we always want to incorporate a timeout. Individual
409        # messages may hang or error on the other end and this ensures
410        # we won't build up lots of zombie tasks waiting around for
411        # responses that will never arrive.
412        if timeout is None:
413            timeout = self.DEFAULT_MESSAGE_TIMEOUT
414        assert timeout is not None
415
416        bytes_awaitable = msgobj.wait_task
417
418        # Now complete the send asynchronously.
419        return self._send_message(
420            message, timeout, close_on_error, bytes_awaitable, message_id
421        )
422
423    async def _send_message(
424        self,
425        message: bytes,
426        timeout: float | None,
427        close_on_error: bool,
428        bytes_awaitable: asyncio.Task[bytes],
429        message_id: int,
430    ) -> bytes:
431        # pylint: disable=too-many-positional-arguments
432        # We need to know their protocol, so if we haven't gotten a handshake
433        # from them yet, just wait.
434        while self._peer_info is None:
435            await asyncio.sleep(0.01)
436        assert self._peer_info is not None
437
438        if self._peer_info.protocol == 1:
439            if len(message) > 65535:
440                raise RuntimeError('Message cannot be larger than 65535 bytes')
441
442        try:
443            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
444        except asyncio.CancelledError as exc:
445            # Question: we assume this means the above wait_for() was
446            # cancelled; how do we distinguish between this and *us* being
447            # cancelled though?
448            if self.debug_print:
449                self.debug_print_call(
450                    f'{self._label}: message {message_id} was cancelled.'
451                )
452            if close_on_error:
453                self.close()
454
455            raise CommunicationError() from exc
456        except Exception as exc:
457            # If our timer timed-out or anything else went wrong with
458            # the stream, lump it in as a communication error.
459            if isinstance(
460                exc, asyncio.TimeoutError
461            ) or is_asyncio_streams_communication_error(exc):
462                if self.debug_print:
463                    self.debug_print_call(
464                        f'{self._label}: got {type(exc)} sending message'
465                        f' {message_id}; raising CommunicationError.'
466                    )
467
468                # Stop waiting on the response.
469                bytes_awaitable.cancel()
470
471                # Remove the record of this message.
472                del self._in_flight_messages[message_id]
473
474                if close_on_error:
475                    self.close()
476
477                # Let the user know something went wrong.
478                raise CommunicationError() from exc
479
480            # Some unexpected error; let it bubble up.
481            raise
482
483    def close(self) -> None:
484        """I said seagulls; mmmm; stop it now."""
485        self._check_env()
486
487        if self._closing:
488            return
489
490        if self.debug_print:
491            self.debug_print_call(f'{self._label}: closing...')
492
493        self._closing = True
494
495        # Kill all of our in-flight tasks.
496        if self.debug_print:
497            self.debug_print_call(f'{self._label}: cancelling tasks...')
498        for task in self._get_live_tasks():
499            task.cancel()
500
501        # Close our writer.
502        assert not self._did_close_writer
503        if self.debug_print:
504            self.debug_print_call(f'{self._label}: closing writer...')
505        self._writer.close()
506        self._did_close_writer = True
507
508        # We don't need this anymore and it is likely to be creating a
509        # dependency loop.
510        del self._handle_raw_message_call
511
512    def is_closing(self) -> bool:
513        """Have we begun the process of closing?"""
514        return self._closing
515
516    async def wait_closed(self) -> None:
517        """I said seagulls; mmmm; stop it now.
518
519        Wait for the endpoint to finish closing. This is called by run()
520        so generally does not need to be explicitly called.
521        """
522        # pylint: disable=too-many-branches
523        self._check_env()
524
525        # Make sure we only *enter* this call once.
526        if self._did_wait_closed:
527            return
528        self._did_wait_closed = True
529
530        if not self._closing:
531            raise RuntimeError('Must be called after close()')
532
533        if not self._did_close_writer:
534            logging.warning(
535                'RPCEndpoint wait_closed() called but never'
536                ' explicitly closed writer.'
537            )
538
539        live_tasks = self._get_live_tasks()
540
541        # Don't need our task list anymore; this should
542        # break any cyclical refs from tasks referring to us.
543        self._tasks = []
544
545        if self.debug_print:
546            self.debug_print_call(
547                f'{self._label}: waiting for tasks to finish: '
548                f' ({live_tasks=})...'
549            )
550
551        # Wait for all of our in-flight tasks to wrap up.
552        results = await asyncio.gather(*live_tasks, return_exceptions=True)
553        for result in results:
554            # We want to know if any errors happened aside from CancelledError
555            # (which are BaseExceptions, not Exception).
556            if isinstance(result, Exception):
557                logging.warning(
558                    'Got unexpected error cleaning up %s task: %s',
559                    self._label,
560                    result,
561                )
562
563        if not all(task.done() for task in live_tasks):
564            logging.warning(
565                'RPCEndpoint %d: not all live tasks marked done after gather.',
566                id(self),
567            )
568
569        if self.debug_print:
570            self.debug_print_call(
571                f'{self._label}: tasks finished; waiting for writer close...'
572            )
573
574        # Now wait for our writer to finish going down.
575        # When we close our writer it generally triggers errors
576        # in our current blocked read/writes. However that same
577        # error is also sometimes returned from _writer.wait_closed().
578        # See connection_lost() in asyncio/streams.py to see why.
579        # So let's silently ignore it when that happens.
580        assert self._writer.is_closing()
581        try:
582            # It seems that as of Python 3.9.x it is possible for this to hang
583            # indefinitely. See https://github.com/python/cpython/issues/83939
584            # It sounds like this should be fixed in 3.11 but for now just
585            # forcing the issue with a timeout here.
586            await asyncio.wait_for(
587                self._writer.wait_closed(),
588                # timeout=60.0 * 6.0,
589                timeout=30.0,
590            )
591        except asyncio.TimeoutError:
592            logging.info(
593                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
594                self._label,
595                ssl_stream_writer_underlying_transport_info(self._writer),
596            )
597            if self.debug_print:
598                self.debug_print_call(
599                    f'{self._label}: got timeout in _writer.wait_closed();'
600                    ' This should be fixed in future Python versions.'
601                )
602        except Exception as exc:
603            if not self._is_expected_connection_error(exc):
604                logging.exception('Error closing _writer for %s.', self._label)
605            else:
606                if self.debug_print:
607                    self.debug_print_call(
608                        f'{self._label}: silently ignoring error in'
609                        f' _writer.wait_closed(): {exc}.'
610                    )
611        except asyncio.CancelledError:
612            logging.warning(
613                'RPCEndpoint.wait_closed()'
614                ' got asyncio.CancelledError; not expected.'
615            )
616            raise
617        assert not self._did_wait_closed_writer
618        self._did_wait_closed_writer = True
619
620    def _tm(self) -> str:
621        """Simple readable time value for debugging."""
622        tval = time.monotonic() % 100.0
623        return f'{tval:.2f}'
624
625    async def _run_read_task(self) -> None:
626        """Read from the peer."""
627        self._check_env()
628        assert self._peer_info is None
629
630        # Bug fix: without setting this initially we would never
631        # time out if the peer never sent us any data.
632        self._last_keepalive_receive_time = time.monotonic()
633
634        # The first thing they should send us is their handshake; then
635        # we'll know if/how we can talk to them.
636        mlen = await self._read_int_32()
637        message = await self._reader.readexactly(mlen)
638        self._total_bytes_read += mlen
639        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
640        self._last_keepalive_receive_time = time.monotonic()
641        if self.debug_print:
642            self.debug_print_call(
643                f'{self._label}: received handshake at {self._tm()}.'
644            )
645
646        # Now just sit and handle stuff as it comes in.
647        while True:
648            if self._closing:
649                return
650
651            # Read message type.
652            mtype = _PacketType(await self._read_int_8())
653            if mtype is _PacketType.HANDSHAKE:
654                raise RuntimeError('Got multiple handshakes')
655
656            if mtype is _PacketType.KEEPALIVE:
657                if self.debug_print_io:
658                    self.debug_print_call(
659                        f'{self._label}: received keepalive'
660                        f' at {self._tm()}.'
661                    )
662                self._last_keepalive_receive_time = time.monotonic()
663
664            elif mtype is _PacketType.MESSAGE:
665                await self._handle_message_packet(big=False)
666
667            elif mtype is _PacketType.MESSAGE_BIG:
668                await self._handle_message_packet(big=True)
669
670            elif mtype is _PacketType.RESPONSE:
671                await self._handle_response_packet(big=False)
672
673            elif mtype is _PacketType.RESPONSE_BIG:
674                await self._handle_response_packet(big=True)
675
676            else:
677                assert_never(mtype)
678
679    async def _handle_message_packet(self, big: bool) -> None:
680        assert self._peer_info is not None
681        msgid = await self._read_int_16()
682        if big:
683            msglen = await self._read_int_32()
684        else:
685            msglen = await self._read_int_16()
686        msg = await self._reader.readexactly(msglen)
687        self._total_bytes_read += msglen
688        if self.debug_print_io:
689            self.debug_print_call(
690                f'{self._label}: received message {msgid}'
691                f' of size {msglen} at {self._tm()}.'
692            )
693
694        # Create a message-task to handle this message and return
695        # a response (we don't want to block while that happens).
696        assert not self._closing
697        self._prune_tasks()  # Keep from filling with dead tasks.
698        self._tasks.append(
699            asyncio.create_task(
700                self._handle_raw_message(message_id=msgid, message=msg),
701                name='efro rpc message handle',
702            )
703        )
704        if self.debug_print:
705            self.debug_print_call(
706                f'{self._label}: done handling message at {self._tm()}.'
707            )
708
709    async def _handle_response_packet(self, big: bool) -> None:
710        assert self._peer_info is not None
711        msgid = await self._read_int_16()
712        # Protocol 2 gained 32 bit data lengths.
713        if big:
714            rsplen = await self._read_int_32()
715        else:
716            rsplen = await self._read_int_16()
717        if self.debug_print_io:
718            self.debug_print_call(
719                f'{self._label}: received response {msgid}'
720                f' of size {rsplen} at {self._tm()}.'
721            )
722        rsp = await self._reader.readexactly(rsplen)
723        self._total_bytes_read += rsplen
724        msgobj = self._in_flight_messages.get(msgid)
725        if msgobj is None:
726            # It's possible for us to get a response to a message
727            # that has timed out. In this case we will have no local
728            # record of it.
729            if self.debug_print:
730                self.debug_print_call(
731                    f'{self._label}: got response for nonexistent'
732                    f' message id {msgid}; perhaps it timed out?'
733                )
734        else:
735            msgobj.set_response(rsp)
736
737    async def _run_write_task(self) -> None:
738        """Write to the peer."""
739
740        self._check_env()
741
742        # Introduce ourself so our peer knows how it can talk to us.
743        data = dataclass_to_json(
744            _PeerInfo(
745                protocol=OUR_PROTOCOL,
746                keepalive_interval=self._keepalive_interval,
747            )
748        ).encode()
749        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
750
751        # Now just write out-messages as they come in.
752        while True:
753            # Wait until some data comes in.
754            await self._have_out_packets.wait()
755
756            assert self._out_packets
757            data = self._out_packets.popleft()
758
759            # Important: only clear this once all packets are sent.
760            if not self._out_packets:
761                self._have_out_packets.clear()
762
763            self._writer.write(data)
764
765            # This should keep our writer from buffering huge amounts
766            # of outgoing data. We must remember though that we also
767            # need to prevent _out_packets from growing too large and
768            # that part's on us.
769            await self._writer.drain()
770
771            # For now we're not applying backpressure, but let's make
772            # noise if this gets out of hand.
773            if len(self._out_packets) > 200:
774                if not self._did_out_packets_buildup_warning:
775                    logging.warning(
776                        '_out_packets building up too'
777                        ' much on RPCEndpoint %s.',
778                        id(self),
779                    )
780                    self._did_out_packets_buildup_warning = True
781
782    async def _run_keepalive_task(self) -> None:
783        """Send periodic keepalive packets."""
784        self._check_env()
785
786        # We explicitly send our own keepalive packets so we can stay
787        # more on top of the connection state and possibly decide to
788        # kill it when contact is lost more quickly than the OS would
789        # do itself (or at least keep the user informed that the
790        # connection is lagging). It sounds like we could have the TCP
791        # layer do this sort of thing itself, but that might be
792        # OS-specific, so we'll go this way for now.
793        while True:
794            assert not self._closing
795            await asyncio.sleep(self._keepalive_interval)
796            if not self.test_suppress_keepalives:
797                self._enqueue_outgoing_packet(
798                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
799                )
800
801            # Also go ahead and handle dropping the connection if we
802            # haven't heard from the peer in a while.
803            # NOTE: perhaps we want something more precise than this,
804            # which only checks once per keepalive-interval.
805            now = time.monotonic()
806            if (
807                self._last_keepalive_receive_time is not None
808                and now - self._last_keepalive_receive_time
809                > self._keepalive_timeout
810            ):
811                if self.debug_print:
812                    since = now - self._last_keepalive_receive_time
813                    self.debug_print_call(
814                        f'{self._label}: reached keepalive time-out'
815                        f' ({since:.1f}s).'
816                    )
817                raise _KeepaliveTimeoutError()
818
819    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
820        try:
821            await call
822        except Exception as exc:
823            # We expect connection errors to put us here, but make noise
824            # if something else does.
825            if not self._is_expected_connection_error(exc):
826                logging.exception(
827                    'Unexpected error in rpc %s %s task'
828                    ' (age=%.1f, total_bytes_read=%d).',
829                    self._label,
830                    tasklabel,
831                    time.monotonic() - self._create_time,
832                    self._total_bytes_read,
833                )
834            else:
835                if self.debug_print:
836                    self.debug_print_call(
837                        f'{self._label}: {tasklabel} task will exit cleanly'
838                        f' due to {exc!r}.'
839                    )
840        finally:
841            # Any core task exiting triggers shutdown.
842            if self.debug_print:
843                self.debug_print_call(
844                    f'{self._label}: {tasklabel} task exiting...'
845                )
846            self.close()
847
848    async def _handle_raw_message(
849        self, message_id: int, message: bytes
850    ) -> None:
851        try:
852            response = await self._handle_raw_message_call(message)
853        except Exception:
854            # We expect local message handler to always succeed.
855            # If that doesn't happen, make a fuss so we know to fix it.
856            # The other end will simply never get a response to this
857            # message.
858            logging.exception('Error handling raw rpc message')
859            return
860
861        assert self._peer_info is not None
862
863        if self._peer_info.protocol == 1:
864            if len(response) > 65535:
865                raise RuntimeError('Response cannot be larger than 65535 bytes')
866
867        # Now send back our response.
868        # Payload consists of type (1b), msgid (2b), len (2b), and data.
869        if len(response) > 65535:
870            self._enqueue_outgoing_packet(
871                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
872                + message_id.to_bytes(2, _BYTE_ORDER)
873                + len(response).to_bytes(4, _BYTE_ORDER)
874                + response
875            )
876        else:
877            self._enqueue_outgoing_packet(
878                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
879                + message_id.to_bytes(2, _BYTE_ORDER)
880                + len(response).to_bytes(2, _BYTE_ORDER)
881                + response
882            )
883
884    async def _read_int_8(self) -> int:
885        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
886        self._total_bytes_read += 1
887        return out
888
889    async def _read_int_16(self) -> int:
890        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
891        self._total_bytes_read += 2
892        return out
893
894    async def _read_int_32(self) -> int:
895        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
896        self._total_bytes_read += 4
897        return out
898
899    @classmethod
900    def _is_expected_connection_error(cls, exc: Exception) -> bool:
901        """Stuff we expect to end our connection in normal circumstances."""
902
903        if isinstance(exc, _KeepaliveTimeoutError):
904            return True
905
906        return is_asyncio_streams_communication_error(exc)
907
908    def _check_env(self) -> None:
909        # I was seeing that asyncio stuff wasn't working as expected if
910        # created in one thread and used in another (and have verified
911        # that this is part of the design), so let's enforce a single
912        # thread for all use of an instance.
913        if current_thread() is not self._thread:
914            raise RuntimeError(
915                'This must be called from the same thread'
916                ' that the endpoint was created in.'
917            )
918
919        # This should always be the case if thread is the same.
920        assert asyncio.get_running_loop() is self._event_loop
921
922    def _enqueue_outgoing_packet(self, data: bytes) -> None:
923        """Enqueue a raw packet to be sent. Must be called from our loop."""
924        self._check_env()
925
926        if self.debug_print_io:
927            self.debug_print_call(
928                f'{self._label}: enqueueing outgoing packet'
929                f' {data[:50]!r} at {self._tm()}.'
930            )
931
932        # Add the data and let our write task know about it.
933        self._out_packets.append(data)
934        self._have_out_packets.set()
935
936    def _prune_tasks(self) -> None:
937        self._tasks = self._get_live_tasks()
938
939    def _get_live_tasks(self) -> list[asyncio.Task]:
940        return [t for t in self._tasks if not t.done()]
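
For reference, the packet layout the listing above implements (per its comments): a 1-byte packet type; for messages and responses, a 2-byte message id, a payload length (2 bytes normally, 4 bytes for the BIG variants), then the payload, with all integers big-endian. A hand-packing sketch (the message id and payload here are made up):

    msg = b'hello'
    packet = (
        bytes([2])                     # _PacketType.MESSAGE.value
        + (42).to_bytes(2, 'big')      # message id
        + len(msg).to_bytes(2, 'big')  # payload length (small message)
        + msg                          # payload
    )
    assert packet == b'\x02\x00\x2a\x00\x05hello'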

Facilitates asynchronous multiplexed remote procedure calls.

Be aware that, while multiple calls can be in flight in either direction simultaneously, packets are still sent serially in a single stream. So excessively long messages/responses will delay all other communication. If/when this becomes an issue we can look into breaking up long messages into multiple packets.
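
A minimal usage sketch; the handler, host, and port here are hypothetical, and a real setup needs a corresponding endpoint on the other side. The handler takes raw message bytes and must return raw response bytes:

    import asyncio

    from efro.rpc import RPCEndpoint

    async def handle_message(message: bytes) -> bytes:
        # Hypothetical handler: just echo the message back.
        return message

    async def main() -> None:
        reader, writer = await asyncio.open_connection('localhost', 8374)
        endpoint = RPCEndpoint(handle_message, reader, writer, label='client')
        # run() drives reads, writes, and keepalives until the
        # connection is lost or close() is called.
        await endpoint.run()

    asyncio.run(main())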

RPCEndpoint(handle_raw_message_call: Callable[[bytes], Awaitable[bytes]], reader: asyncio.StreamReader, writer: asyncio.StreamWriter, label: str, *, debug_print: bool = False, debug_print_io: bool = False, debug_print_call: Callable[[str], None] | None = None, keepalive_interval: float = 10.73, keepalive_timeout: float = 30.0)
178    def __init__(
179        self,
180        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
181        reader: asyncio.StreamReader,
182        writer: asyncio.StreamWriter,
183        label: str,
184        *,
185        debug_print: bool = False,
186        debug_print_io: bool = False,
187        debug_print_call: Callable[[str], None] | None = None,
188        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
189        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
190    ) -> None:
191        self._handle_raw_message_call = handle_raw_message_call
192        self._reader = reader
193        self._writer = writer
194        self.debug_print = debug_print
195        self.debug_print_io = debug_print_io
196        if debug_print_call is None:
197            debug_print_call = print
198        self.debug_print_call: Callable[[str], None] = debug_print_call
199        self._label = label
200        self._thread = current_thread()
201        self._closing = False
202        self._did_wait_closed = False
203        self._event_loop = asyncio.get_running_loop()
204        self._out_packets = deque[bytes]()
205        self._have_out_packets = asyncio.Event()
206        self._run_called = False
207        self._peer_info: _PeerInfo | None = None
208        self._keepalive_interval = keepalive_interval
209        self._keepalive_timeout = keepalive_timeout
210        self._did_close_writer = False
211        self._did_wait_closed_writer = False
212        self._did_out_packets_buildup_warning = False
213        self._total_bytes_read = 0
214        self._create_time = time.monotonic()
215
216        # Tasks can create dependency loops that keep us alive; we
217        # clear this list during shutdown to break them.
218        self._tasks: list[asyncio.Task] = []
219
220        # When we last got a keepalive or equivalent (time.monotonic value)
221        self._last_keepalive_receive_time: float | None = None
222
223        # (Start near the end to make sure our looping logic is sound).
224        self._next_message_id = 65530
225
226        self._in_flight_messages: dict[int, _InFlightMessage] = {}
227
228        if self.debug_print:
229            peername = self._writer.get_extra_info('peername')
230            self.debug_print_call(
231                f'{self._label}: connected to {peername} at {self._tm()}.'
232            )
test_suppress_keepalives: bool = False
DEFAULT_MESSAGE_TIMEOUT = 60.0
DEFAULT_KEEPALIVE_INTERVAL = 10.73
DEFAULT_KEEPALIVE_TIMEOUT = 30.0
debug_print: bool
debug_print_io: bool
debug_print_call: Callable[[str], None]
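
The keepalive defaults above can be overridden per endpoint. A sketch with illustrative values (each side announces its send interval in its handshake, so a timeout should comfortably exceed the peer's interval):

    endpoint = RPCEndpoint(
        handle_message,  # hypothetical handler from the earlier sketch
        reader,
        writer,
        label='chatty-client',
        keepalive_interval=2.0,  # send a keepalive every 2 seconds
        keepalive_timeout=6.0,   # drop a peer that is silent for 6 seconds
    )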
async def run(self) -> None:
255    async def run(self) -> None:
256        """Run the endpoint until the connection is lost or closed.
257
258        Handles closing the provided reader/writer on close.
259        """
260        try:
261            await self._do_run()
262        except asyncio.CancelledError:
263            # We aren't really designed to be cancelled so let's warn
264            # if it happens.
265            logging.warning(
266                'RPCEndpoint.run got CancelledError;'
267                ' want to try and avoid this.'
268            )
269            raise

Run the endpoint until the connection is lost or closed.

Handles closing the provided reader/writer on close.
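
Since run() blocks until the connection ends, a common pattern (a sketch; the module does not prescribe this) is to spawn it as a task and talk over the endpoint alongside it:

    run_task = asyncio.create_task(endpoint.run(), name='rpc run')
    try:
        response = await endpoint.send_message(b'ping')
    finally:
        endpoint.close()
        # run() calls wait_closed() itself, so awaiting the task
        # completes the shutdown.
        await run_task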

def send_message(self, message: bytes, timeout: float | None = None, close_on_error: bool = True) -> Awaitable[bytes]:
325    def send_message(
326        self,
327        message: bytes,
328        timeout: float | None = None,
329        close_on_error: bool = True,
330    ) -> Awaitable[bytes]:
331        """Send a message to the peer and return a response.
332
333        If timeout is not provided, the default will be used.
334        Raises a CommunicationError if the round trip is not completed
335        for any reason.
336
337        By default, the entire endpoint will go down in the case of
338        errors. This allows messages to be treated as 'reliable' with
339        respect to a given endpoint. Pass close_on_error=False to
340        override this for a particular message.
341        """
342        # Note: This call is synchronous so that the first part of it
343        # (enqueueing outgoing messages) happens synchronously. If it were
344        # a pure async call it could be possible for send order to vary
345        # based on how the async tasks get processed.
346
347        if self.debug_print_io:
348            self.debug_print_call(
349                f'{self._label}: sending message of size {len(message)}'
350                f' at {self._tm()}.'
351            )
352
353        self._check_env()
354
355        if self._closing:
356            raise CommunicationError('Endpoint is closed.')
357
358        if self.debug_print_io:
359            self.debug_print_call(
360                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
361            )
362
363        # message_id is a 16 bit looping value.
364        message_id = self._next_message_id
365        self._next_message_id = (self._next_message_id + 1) % 65536
366
367        if self.debug_print_io:
368            self.debug_print_call(
369                f'{self._label}: will enqueue at {self._tm()}.'
370            )
371
372        # FIXME - should handle backpressure (waiting here if there are
373        # enough packets already enqueued).
374
375        if len(message) > 65535:
376            # Payload consists of type (1b), message_id (2b),
377            # len (4b), and data.
378            self._enqueue_outgoing_packet(
379                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
380                + message_id.to_bytes(2, _BYTE_ORDER)
381                + len(message).to_bytes(4, _BYTE_ORDER)
382                + message
383            )
384        else:
385            # Payload consists of type (1b), message_id (2b),
386            # len (2b), and data.
387            self._enqueue_outgoing_packet(
388                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
389                + message_id.to_bytes(2, _BYTE_ORDER)
390                + len(message).to_bytes(2, _BYTE_ORDER)
391                + message
392            )
393
394        if self.debug_print_io:
395            self.debug_print_call(
396                f'{self._label}: enqueued message of size {len(message)}'
397                f' at {self._tm()}.'
398            )
399
400        # Make an entry so we know this message is out there.
401        assert message_id not in self._in_flight_messages
402        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
403
404        # Also add its task to our list so we properly cancel it if we die.
405        self._prune_tasks()  # Keep our list from filling with dead tasks.
406        self._tasks.append(msgobj.wait_task)
407
408        # Note: we always want to incorporate a timeout. Individual
409        # messages may hang or error on the other end and this ensures
410        # we won't build up lots of zombie tasks waiting around for
411        # responses that will never arrive.
412        if timeout is None:
413            timeout = self.DEFAULT_MESSAGE_TIMEOUT
414        assert timeout is not None
415
416        bytes_awaitable = msgobj.wait_task
417
418        # Now complete the send asynchronously.
419        return self._send_message(
420            message, timeout, close_on_error, bytes_awaitable, message_id
421        )

Send a message to the peer and return a response.

If timeout is not provided, the default will be used. Raises a CommunicationError if the round trip is not completed for any reason.

By default, the entire endpoint will go down in the case of errors. This allows messages to be treated as 'reliable' with respect to a given endpoint. Pass close_on_error=False to override this for a particular message.
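
A sketch of the error-handling contract described above (the payload and timeout are illustrative):

    from efro.error import CommunicationError

    try:
        response = await endpoint.send_message(
            b'do-something', timeout=5.0, close_on_error=False
        )
    except CommunicationError:
        # The round trip failed (timed out, connection lost, etc.).
        # With close_on_error=False the endpoint itself stays usable.
        response = None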

def close(self) -> None:
483    def close(self) -> None:
484        """I said seagulls; mmmm; stop it now."""
485        self._check_env()
486
487        if self._closing:
488            return
489
490        if self.debug_print:
491            self.debug_print_call(f'{self._label}: closing...')
492
493        self._closing = True
494
495        # Kill all of our in-flight tasks.
496        if self.debug_print:
497            self.debug_print_call(f'{self._label}: cancelling tasks...')
498        for task in self._get_live_tasks():
499            task.cancel()
500
501        # Close our writer.
502        assert not self._did_close_writer
503        if self.debug_print:
504            self.debug_print_call(f'{self._label}: closing writer...')
505        self._writer.close()
506        self._did_close_writer = True
507
508        # We don't need this anymore and it is likely to be creating a
509        # dependency loop.
510        del self._handle_raw_message_call

Begin the process of closing the endpoint.

def is_closing(self) -> bool:
512    def is_closing(self) -> bool:
513        """Have we begun the process of closing?"""
514        return self._closing

Have we begun the process of closing?

async def wait_closed(self) -> None:
516    async def wait_closed(self) -> None:
517        """I said seagulls; mmmm; stop it now.
518
519        Wait for the endpoint to finish closing. This is called by run()
520        so generally does not need to be explicitly called.
521        """
522        # pylint: disable=too-many-branches
523        self._check_env()
524
525        # Make sure we only *enter* this call once.
526        if self._did_wait_closed:
527            return
528        self._did_wait_closed = True
529
530        if not self._closing:
531            raise RuntimeError('Must be called after close()')
532
533        if not self._did_close_writer:
534            logging.warning(
535                'RPCEndpoint wait_closed() called but never'
536                ' explicitly closed writer.'
537            )
538
539        live_tasks = self._get_live_tasks()
540
541        # Don't need our task list anymore; this should
542        # break any cyclical refs from tasks referring to us.
543        self._tasks = []
544
545        if self.debug_print:
546            self.debug_print_call(
547                f'{self._label}: waiting for tasks to finish'
548                f' ({live_tasks=})...'
549            )
550
551        # Wait for all of our in-flight tasks to wrap up.
552        results = await asyncio.gather(*live_tasks, return_exceptions=True)
553        for result in results:
554            # We want to know if any errors happened aside from CancelledError
555            # (which are BaseExceptions, not Exception).
556            if isinstance(result, Exception):
557                logging.warning(
558                    'Got unexpected error cleaning up %s task: %s',
559                    self._label,
560                    result,
561                )
562
563        if not all(task.done() for task in live_tasks):
564            logging.warning(
565                'RPCEndpoint %d: not all live tasks marked done after gather.',
566                id(self),
567            )
568
569        if self.debug_print:
570            self.debug_print_call(
571                f'{self._label}: tasks finished; waiting for writer close...'
572            )
573
574        # Now wait for our writer to finish going down.
575        # When we close our writer it generally triggers errors
576        # in our current blocked read/writes. However that same
577        # error is also sometimes returned from _writer.wait_closed().
578        # See connection_lost() in asyncio/streams.py to see why.
579        # So let's silently ignore it when that happens.
580        assert self._writer.is_closing()
581        try:
582            # It seems that as of Python 3.9.x it is possible for this to hang
583            # indefinitely. See https://github.com/python/cpython/issues/83939
584            # It sounds like this should be fixed in 3.11 but for now just
585            # forcing the issue with a timeout here.
586            await asyncio.wait_for(
587                self._writer.wait_closed(),
588                # timeout=60.0 * 6.0,
589                timeout=30.0,
590            )
591        except asyncio.TimeoutError:
592            logging.info(
593                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
594                self._label,
595                ssl_stream_writer_underlying_transport_info(self._writer),
596            )
597            if self.debug_print:
598                self.debug_print_call(
599                    f'{self._label}: got timeout in _writer.wait_closed();'
600                    ' This should be fixed in future Python versions.'
601                )
602        except Exception as exc:
603            if not self._is_expected_connection_error(exc):
604                logging.exception('Error closing _writer for %s.', self._label)
605            else:
606                if self.debug_print:
607                    self.debug_print_call(
608                        f'{self._label}: silently ignoring error in'
609                        f' _writer.wait_closed(): {exc}.'
610                    )
611        except asyncio.CancelledError:
612            logging.warning(
613                'RPCEndpoint.wait_closed()'
614                ' got asyncio.CancelledError; not expected.'
615            )
616            raise
617        assert not self._did_wait_closed_writer
618        self._did_wait_closed_writer = True

Wait for the endpoint to finish closing.

This is called by run() so generally does not need to be explicitly called.
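
A closing sketch tying these together; this is only needed when managing shutdown manually, since run() performs the same sequence when any of its core tasks exits:

    if not endpoint.is_closing():
        endpoint.close()
    await endpoint.wait_closed()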