efro.rpc

Remote procedure call related functionality.
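
A minimal usage sketch (illustrative; the echo handler and the address
are placeholders, and a real peer would run an RPCEndpoint of its own
on the other side of the connection):

    import asyncio
    from efro.rpc import RPCEndpoint

    async def handle_raw_message(message: bytes) -> bytes:
        # Hypothetical handler: just echo the peer's message back.
        return message

    async def main() -> None:
        # Placeholder address; substitute a real server.
        reader, writer = await asyncio.open_connection('example.com', 1234)
        endpoint = RPCEndpoint(handle_raw_message, reader, writer, 'client')

        # run() must be active for reads/writes to occur.
        run_task = asyncio.create_task(endpoint.run())

        # send_message() returns an awaitable for the peer's response.
        response = await endpoint.send_message(b'ping')
        print(response)

        endpoint.close()
        await run_task  # run() handles wait_closed() on the way out.

    asyncio.run(main())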

# Released under the MIT License. See LICENSE for details.
#
"""Remote procedure call related functionality."""

from __future__ import annotations

import time
import asyncio
import logging
import weakref
from enum import Enum
from collections import deque
from dataclasses import dataclass
from threading import current_thread
from typing import TYPE_CHECKING, Annotated, assert_never

from efro.error import (
    CommunicationError,
    is_asyncio_streams_communication_error,
)
from efro.dataclassio import (
    dataclass_to_json,
    dataclass_from_json,
    ioprepped,
    IOAttrs,
)

if TYPE_CHECKING:
    from typing import Literal, Awaitable, Callable

# Terminology:
# Packet: A chunk of data consisting of a type and some type-dependent
#         payload. Even though we use streams we organize our transmission
#         into 'packets'.
# Message: User data which we transmit using one or more packets.
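#
# Example (illustrative): sending the 4-byte message b'ping' produces a
# single MESSAGE packet: a 1-byte type, a 2-byte message-id, a 2-byte
# length, and then the 4 payload bytes; the reply arrives as a RESPONSE
# packet carrying the same message-id.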


class _PacketType(Enum):
    HANDSHAKE = 0
    KEEPALIVE = 1
    MESSAGE = 2
    RESPONSE = 3
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5


_BYTE_ORDER: Literal['big'] = 'big'


@ioprepped
@dataclass
class _PeerInfo:
    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]


# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit length) message/response packets
OUR_PROTOCOL = 2
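
# For reference (an illustrative sketch of the encoding, not an exact
# dump): with the defaults above, the handshake sent by _run_write_task()
# is the _PeerInfo JSON (roughly '{"p": 2, "k": 10.73}') prefixed with
# its 4-byte big-endian length.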


def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """For debugging SSL Stream connections; returns raw transport info."""
    # Note: accessing internals here so just returning info and not
    # actual objs to reduce potential for breakage.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                return str(raw_transport)
    return '(not found)'


def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
    """Ensure a writer is closed; hacky workaround for odd hang."""
    from efro.call import tpartial
    from threading import Thread

    # Disabling for now..
    if bool(True):
        return

    # Hopefully can remove this in Python 3.11?...
    # see issue with is_closing() below for more details.
    transport = getattr(writer, '_transport', None)
    if transport is not None:
        sslproto = getattr(transport, '_ssl_protocol', None)
        if sslproto is not None:
            raw_transport = getattr(sslproto, '_transport', None)
            if raw_transport is not None:
                Thread(
                    target=tpartial(
                        _do_writer_force_close_check,
                        weakref.ref(raw_transport),
                    ),
                    daemon=True,
                ).start()


def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
    try:
        # Attempt to bail as soon as the obj dies.
        # If it hasn't done so by our timeout, force-kill it.
        starttime = time.monotonic()
        while time.monotonic() - starttime < 10.0:
            time.sleep(0.1)
            if transport_weak() is None:
                return
        transport = transport_weak()
        if transport is not None:
            logging.info('Forcing abort on stuck transport %s.', transport)
            transport.abort()
    except Exception:
        logging.warning('Error in writer-force-close-check', exc_info=True)


class _InFlightMessage:
    """Represents a message that is out on the wire."""

    def __init__(self) -> None:
        self._response: bytes | None = None
        self._got_response = asyncio.Event()
        self.wait_task = asyncio.create_task(
            self._wait(), name='rpc in flight msg wait'
        )

    async def _wait(self) -> bytes:
        await self._got_response.wait()
        assert self._response is not None
        return self._response

    def set_response(self, data: bytes) -> None:
        """Set response data."""
        assert self._response is None
        self._response = data
        self._got_response.set()


class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives."""


class RPCEndpoint:
    """Facilitates asynchronous multiplexed remote procedure calls.

    Be aware that, while multiple calls can be in flight in either direction
    simultaneously, packets are still sent serially in a single
    stream. So excessively long messages/responses will delay all other
    communication. If/when this becomes an issue we can look into breaking up
    long messages into multiple packets.
    """

    # Set to True on an instance to test keepalive failures.
    test_suppress_keepalives: bool = False

    # How long we should wait before giving up on a message by default.
    # Note this includes processing time on the other end.
    DEFAULT_MESSAGE_TIMEOUT = 60.0

    # How often we send out keepalive packets by default.
    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid overly regular values)

    # How long we can go without receiving a keepalive packet before we
    # disconnect.
    DEFAULT_KEEPALIVE_TIMEOUT = 30.0

    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Need to hold weak-refs to these otherwise it creates dep-loops
        # which keeps us alive.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value)
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound).
        self._next_message_id = 65530

        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )

    def __del__(self) -> None:
        if self._run_called:
            if not self._did_close_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run'
                    ' called but writer not closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )
            elif not self._did_wait_closed_writer:
                logging.warning(
                    'RPCEndpoint %d dying with run called'
                    ' but writer not wait-closed (transport=%s).',
                    id(self),
                    ssl_stream_writer_underlying_transport_info(self._writer),
                )

        # Currently seeing rare issue where sockets don't go down;
        # let's add a timer to force the issue until we can figure it out.
        ssl_stream_writer_force_close_check(self._writer)

    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise

    async def _do_run(self) -> None:
        self._check_env()

        if self._run_called:
            raise RuntimeError('Run can be called only once per endpoint.')
        self._run_called = True

        core_tasks = [
            asyncio.create_task(
                self._run_core_task('keepalive', self._run_keepalive_task()),
                name='rpc keepalive',
            ),
            asyncio.create_task(
                self._run_core_task('read', self._run_read_task()),
                name='rpc read',
            ),
            asyncio.create_task(
                self._run_core_task('write', self._run_write_task()),
                name='rpc write',
            ),
        ]
        self._tasks += core_tasks

        # Run our core tasks until they all complete.
        results = await asyncio.gather(*core_tasks, return_exceptions=True)

        # Core tasks should handle their own errors; the only ones
        # we expect to bubble up are CancelledError.
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error from %s core task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in core_tasks):
            logging.warning(
                'RPCEndpoint %d: not all core tasks marked done after gather.',
                id(self),
            )

        # Shut ourself down.
        try:
            self.close()
            await self.wait_closed()
        except Exception:
            logging.exception('Error closing %s.', self._label)

        if self.debug_print:
            self.debug_print_call(f'{self._label}: finished.')

    def send_message(
        self,
        message: bytes,
        timeout: float | None = None,
        close_on_error: bool = True,
    ) -> Awaitable[bytes]:
        """Send a message to the peer and return a response.

        If timeout is not provided, the default will be used.
        Raises a CommunicationError if the round trip is not completed
        for any reason.

        By default, the entire endpoint will go down in the case of
        errors. This allows messages to be treated as 'reliable' with
        respect to a given endpoint. Pass close_on_error=False to
        override this for a particular message.
        """
        # Note: This call is synchronous so that the first part of it
        # (enqueueing outgoing messages) happens synchronously. If it were
        # a pure async call it could be possible for send order to vary
        # based on how the async tasks get processed.

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: sending message of size {len(message)}'
                f' at {self._tm()}.'
            )

        self._check_env()

        if self._closing:
            raise CommunicationError('Endpoint is closed.')

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
            )

        # message_id is a 16 bit looping value.
        message_id = self._next_message_id
        self._next_message_id = (self._next_message_id + 1) % 65536
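        # (e.g. once id 65535 is handed out, the next message id wraps to 0.)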

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: will enqueue at {self._tm()}.'
            )

        # FIXME - should handle backpressure (waiting here if there are
        # enough packets already enqueued).

        if len(message) > 65535:
            # Payload consists of type (1b), message_id (2b),
            # len (4b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(4, _BYTE_ORDER)
                + message
            )
        else:
            # Payload consists of type (1b), message_id (2b),
            # len (2b), and data.
            self._enqueue_outgoing_packet(
                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(message).to_bytes(2, _BYTE_ORDER)
                + message
            )
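
        # (Concretely: message b'ping' with id 7 would go out as
        # b'\x02\x00\x07\x00\x04ping', i.e. type 2, id 7, length 4,
        # then the payload.)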

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueued message of size {len(message)}'
                f' at {self._tm()}.'
            )

        # Make an entry so we know this message is out there.
        assert message_id not in self._in_flight_messages
        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()

        # Also add its task to our list so we properly cancel it if we die.
        self._prune_tasks()  # Keep our list from filling with dead tasks.
        self._tasks.append(msgobj.wait_task)

        # Note: we always want to incorporate a timeout. Individual
        # messages may hang or error on the other end and this ensures
        # we won't build up lots of zombie tasks waiting around for
        # responses that will never arrive.
        if timeout is None:
            timeout = self.DEFAULT_MESSAGE_TIMEOUT
        assert timeout is not None

        bytes_awaitable = msgobj.wait_task

        # Now complete the send asynchronously.
        return self._send_message(
            message, timeout, close_on_error, bytes_awaitable, message_id
        )

    async def _send_message(
        self,
        message: bytes,
        timeout: float | None,
        close_on_error: bool,
        bytes_awaitable: asyncio.Task[bytes],
        message_id: int,
    ) -> bytes:
        # We need to know their protocol, so if we haven't gotten a handshake
        # from them yet, just wait.
        while self._peer_info is None:
            await asyncio.sleep(0.01)
        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(message) > 65535:
                raise RuntimeError('Message cannot be larger than 65535 bytes')

        try:
            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
        except asyncio.CancelledError as exc:
            # Question: we assume this means the above wait_for() was
            # cancelled; how do we distinguish between this and *us* being
            # cancelled though?
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: message {message_id} was cancelled.'
                )
            if close_on_error:
                self.close()

            raise CommunicationError() from exc
        except Exception as exc:
            # If our timer timed-out or anything else went wrong with
            # the stream, lump it in as a communication error.
            if isinstance(
                exc, asyncio.TimeoutError
            ) or is_asyncio_streams_communication_error(exc):
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: got {type(exc)} sending message'
                        f' {message_id}; raising CommunicationError.'
                    )

                # Stop waiting on the response.
                bytes_awaitable.cancel()

                # Remove the record of this message.
                del self._in_flight_messages[message_id]

                if close_on_error:
                    self.close()

                # Let the user know something went wrong.
                raise CommunicationError() from exc

            # Some unexpected error; let it bubble up.
            raise

    def close(self) -> None:
        """Begin closing the endpoint."""
        self._check_env()

        if self._closing:
            return

        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing...')

        self._closing = True

        # Kill all of our in-flight tasks.
        if self.debug_print:
            self.debug_print_call(f'{self._label}: cancelling tasks...')
        for task in self._get_live_tasks():
            task.cancel()

        # Close our writer.
        assert not self._did_close_writer
        if self.debug_print:
            self.debug_print_call(f'{self._label}: closing writer...')
        self._writer.close()
        self._did_close_writer = True

        # We don't need this anymore and it is likely to be creating a
        # dependency loop.
        del self._handle_raw_message_call

    def is_closing(self) -> bool:
        """Have we begun the process of closing?"""
        return self._closing

    async def wait_closed(self) -> None:
        """Wait for the endpoint to finish closing.

        This is called by run() so generally does not need to be
        explicitly called.
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should
        # break any cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish: '
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                # timeout=60.0 * 6.0,
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' This should be fixed in future Python versions.'
                )
        except Exception as exc:
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        except asyncio.CancelledError:
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

    def _tm(self) -> str:
        """Simple readable time value for debugging."""
        tval = time.monotonic() % 100.0
        return f'{tval:.2f}'

    async def _run_read_task(self) -> None:
        """Read from the peer."""
        self._check_env()
        assert self._peer_info is None

        # Bug fix: if we don't have this set we will never time out
        # if we never receive any data from the other end.
        self._last_keepalive_receive_time = time.monotonic()

        # The first thing they should send us is their handshake; then
        # we'll know if/how we can talk to them.
        mlen = await self._read_int_32()
        message = await self._reader.readexactly(mlen)
        self._total_bytes_read += mlen
        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
        self._last_keepalive_receive_time = time.monotonic()
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: received handshake at {self._tm()}.'
            )

        # Now just sit and handle stuff as it comes in.
        while True:
            if self._closing:
                return

            # Read message type.
            mtype = _PacketType(await self._read_int_8())
            if mtype is _PacketType.HANDSHAKE:
                raise RuntimeError('Got multiple handshakes')

            if mtype is _PacketType.KEEPALIVE:
                if self.debug_print_io:
                    self.debug_print_call(
                        f'{self._label}: received keepalive'
                        f' at {self._tm()}.'
                    )
                self._last_keepalive_receive_time = time.monotonic()

            elif mtype is _PacketType.MESSAGE:
                await self._handle_message_packet(big=False)

            elif mtype is _PacketType.MESSAGE_BIG:
                await self._handle_message_packet(big=True)

            elif mtype is _PacketType.RESPONSE:
                await self._handle_response_packet(big=False)

            elif mtype is _PacketType.RESPONSE_BIG:
                await self._handle_response_packet(big=True)

            else:
                assert_never(mtype)

    async def _handle_message_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        if big:
            msglen = await self._read_int_32()
        else:
            msglen = await self._read_int_16()
        msg = await self._reader.readexactly(msglen)
        self._total_bytes_read += msglen
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received message {msgid}'
                f' of size {msglen} at {self._tm()}.'
            )

        # Create a message-task to handle this message and return
        # a response (we don't want to block while that happens).
        assert not self._closing
        self._prune_tasks()  # Keep from filling with dead tasks.
        self._tasks.append(
            asyncio.create_task(
                self._handle_raw_message(message_id=msgid, message=msg),
                name='efro rpc message handle',
            )
        )
        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: done handling message at {self._tm()}.'
            )

    async def _handle_response_packet(self, big: bool) -> None:
        assert self._peer_info is not None
        msgid = await self._read_int_16()
        # Protocol 2 gained 32 bit data lengths.
        if big:
            rsplen = await self._read_int_32()
        else:
            rsplen = await self._read_int_16()
        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: received response {msgid}'
                f' of size {rsplen} at {self._tm()}.'
            )
        rsp = await self._reader.readexactly(rsplen)
        self._total_bytes_read += rsplen
        msgobj = self._in_flight_messages.get(msgid)
        if msgobj is None:
            # It's possible for us to get a response to a message
            # that has timed out. In this case we will have no local
            # record of it.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got response for nonexistent'
                    f' message id {msgid}; perhaps it timed out?'
                )
        else:
            msgobj.set_response(rsp)

    async def _run_write_task(self) -> None:
        """Write to the peer."""

        self._check_env()

        # Introduce ourself so our peer knows how it can talk to us.
        data = dataclass_to_json(
            _PeerInfo(
                protocol=OUR_PROTOCOL,
                keepalive_interval=self._keepalive_interval,
            )
        ).encode()
        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)

        # Now just write out-messages as they come in.
        while True:
            # Wait until some data comes in.
            await self._have_out_packets.wait()

            assert self._out_packets
            data = self._out_packets.popleft()

            # Important: only clear this once all packets are sent.
            if not self._out_packets:
                self._have_out_packets.clear()

            self._writer.write(data)

            # This should keep our writer from buffering huge amounts
            # of outgoing data. We must remember though that we also
            # need to prevent _out_packets from growing too large and
            # that part's on us.
            await self._writer.drain()

            # For now we're not applying backpressure, but let's make
            # noise if this gets out of hand.
            if len(self._out_packets) > 200:
                if not self._did_out_packets_buildup_warning:
                    logging.warning(
                        '_out_packets building up too'
                        ' much on RPCEndpoint %s.',
                        id(self),
                    )
                    self._did_out_packets_buildup_warning = True

    async def _run_keepalive_task(self) -> None:
        """Send periodic keepalive packets."""
        self._check_env()

        # We explicitly send our own keepalive packets so we can stay
        # more on top of the connection state and possibly decide to
        # kill it when contact is lost more quickly than the OS would
        # do itself (or at least keep the user informed that the
        # connection is lagging). It sounds like we could have the TCP
        # layer do this sort of thing itself but that might be
        # OS-specific so gonna go this way for now.
        while True:
            assert not self._closing
            await asyncio.sleep(self._keepalive_interval)
            if not self.test_suppress_keepalives:
                self._enqueue_outgoing_packet(
                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
                )
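
            # (On the wire a keepalive is just the single type byte
            # b'\x01'; it carries no id or payload.)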

            # Also go ahead and handle dropping the connection if we
            # haven't heard from the peer in a while.
            # NOTE: perhaps we want to do something more exact than
            # this which only checks once per keepalive-interval?..
            now = time.monotonic()
            if (
                self._last_keepalive_receive_time is not None
                and now - self._last_keepalive_receive_time
                > self._keepalive_timeout
            ):
                if self.debug_print:
                    since = now - self._last_keepalive_receive_time
                    self.debug_print_call(
                        f'{self._label}: reached keepalive time-out'
                        f' ({since:.1f}s).'
                    )
                raise _KeepaliveTimeoutError()

    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
        try:
            await call
        except Exception as exc:
            # We expect connection errors to put us here, but make noise
            # if something else does.
            if not self._is_expected_connection_error(exc):
                logging.exception(
                    'Unexpected error in rpc %s %s task'
                    ' (age=%.1f, total_bytes_read=%d).',
                    self._label,
                    tasklabel,
                    time.monotonic() - self._create_time,
                    self._total_bytes_read,
                )
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: {tasklabel} task will exit cleanly'
                        f' due to {exc!r}.'
                    )
        finally:
            # Any core task exiting triggers shutdown.
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: {tasklabel} task exiting...'
                )
            self.close()

    async def _handle_raw_message(
        self, message_id: int, message: bytes
    ) -> None:
        try:
            response = await self._handle_raw_message_call(message)
        except Exception:
            # We expect local message handler to always succeed.
            # If that doesn't happen, make a fuss so we know to fix it.
            # The other end will simply never get a response to this
            # message.
            logging.exception('Error handling raw rpc message')
            return

        assert self._peer_info is not None

        if self._peer_info.protocol == 1:
            if len(response) > 65535:
                raise RuntimeError('Response cannot be larger than 65535 bytes')

        # Now send back our response.
        # Payload consists of type (1b), msgid (2b), len (2b or 4b), and data.
        if len(response) > 65535:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(4, _BYTE_ORDER)
                + response
            )
        else:
            self._enqueue_outgoing_packet(
                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
                + message_id.to_bytes(2, _BYTE_ORDER)
                + len(response).to_bytes(2, _BYTE_ORDER)
                + response
            )

    async def _read_int_8(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
        self._total_bytes_read += 1
        return out

    async def _read_int_16(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
        self._total_bytes_read += 2
        return out

    async def _read_int_32(self) -> int:
        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
        self._total_bytes_read += 4
        return out
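
    # (These three helpers read the fixed-width big-endian integer
    # fields used throughout the packet format above.)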

    @classmethod
    def _is_expected_connection_error(cls, exc: Exception) -> bool:
        """Stuff we expect to end our connection in normal circumstances."""

        if isinstance(exc, _KeepaliveTimeoutError):
            return True

        return is_asyncio_streams_communication_error(exc)

    def _check_env(self) -> None:
        # I was seeing that asyncio stuff wasn't working as expected if
        # created in one thread and used in another (and have verified
        # that this is part of the design), so let's enforce a single
        # thread for all use of an instance.
        if current_thread() is not self._thread:
            raise RuntimeError(
                'This must be called from the same thread'
                ' that the endpoint was created in.'
            )

        # This should always be the case if thread is the same.
        assert asyncio.get_running_loop() is self._event_loop

    def _enqueue_outgoing_packet(self, data: bytes) -> None:
        """Enqueue a raw packet to be sent. Must be called from our loop."""
        self._check_env()

        if self.debug_print_io:
            self.debug_print_call(
                f'{self._label}: enqueueing outgoing packet'
                f' {data[:50]!r} at {self._tm()}.'
            )

        # Add the data and let our write task know about it.
        self._out_packets.append(data)
        self._have_out_packets.set()

    def _prune_tasks(self) -> None:
        self._tasks = self._get_live_tasks()

    def _get_live_tasks(self) -> list[asyncio.Task]:
        return [t for t in self._tasks if not t.done()]
OUR_PROTOCOL = 2
def ssl_stream_writer_underlying_transport_info(writer: asyncio.streams.StreamWriter) -> str:
70def ssl_stream_writer_underlying_transport_info(
71    writer: asyncio.StreamWriter,
72) -> str:
73    """For debugging SSL Stream connections; returns raw transport info."""
74    # Note: accessing internals here so just returning info and not
75    # actual objs to reduce potential for breakage.
76    transport = getattr(writer, '_transport', None)
77    if transport is not None:
78        sslproto = getattr(transport, '_ssl_protocol', None)
79        if sslproto is not None:
80            raw_transport = getattr(sslproto, '_transport', None)
81            if raw_transport is not None:
82                return str(raw_transport)
83    return '(not found)'

For debugging SSL Stream connections; returns raw transport info.

def ssl_stream_writer_force_close_check(writer: asyncio.streams.StreamWriter) -> None:
 86def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
 87    """Ensure a writer is closed; hacky workaround for odd hang."""
 88    from efro.call import tpartial
 89    from threading import Thread
 90
 91    # Disabling for now..
 92    if bool(True):
 93        return
 94
 95    # Hopefully can remove this in Python 3.11?...
 96    # see issue with is_closing() below for more details.
 97    transport = getattr(writer, '_transport', None)
 98    if transport is not None:
 99        sslproto = getattr(transport, '_ssl_protocol', None)
100        if sslproto is not None:
101            raw_transport = getattr(sslproto, '_transport', None)
102            if raw_transport is not None:
103                Thread(
104                    target=tpartial(
105                        _do_writer_force_close_check,
106                        weakref.ref(raw_transport),
107                    ),
108                    daemon=True,
109                ).start()

Ensure a writer is closed; hacky workaround for odd hang.

class RPCEndpoint:
155class RPCEndpoint:
156    """Facilitates asynchronous multiplexed remote procedure calls.
157
158    Be aware that, while multiple calls can be in flight in either direction
159    simultaneously, packets are still sent serially in a single
160    stream. So excessively long messages/responses will delay all other
161    communication. If/when this becomes an issue we can look into breaking up
162    long messages into multiple packets.
163    """
164
165    # Set to True on an instance to test keepalive failures.
166    test_suppress_keepalives: bool = False
167
168    # How long we should wait before giving up on a message by default.
169    # Note this includes processing time on the other end.
170    DEFAULT_MESSAGE_TIMEOUT = 60.0
171
172    # How often we send out keepalive packets by default.
173    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)
174
175    # How long we can go without receiving a keepalive packet before we
176    # disconnect.
177    DEFAULT_KEEPALIVE_TIMEOUT = 30.0
178
179    def __init__(
180        self,
181        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
182        reader: asyncio.StreamReader,
183        writer: asyncio.StreamWriter,
184        label: str,
185        debug_print: bool = False,
186        debug_print_io: bool = False,
187        debug_print_call: Callable[[str], None] | None = None,
188        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
189        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
190    ) -> None:
191        self._handle_raw_message_call = handle_raw_message_call
192        self._reader = reader
193        self._writer = writer
194        self.debug_print = debug_print
195        self.debug_print_io = debug_print_io
196        if debug_print_call is None:
197            debug_print_call = print
198        self.debug_print_call: Callable[[str], None] = debug_print_call
199        self._label = label
200        self._thread = current_thread()
201        self._closing = False
202        self._did_wait_closed = False
203        self._event_loop = asyncio.get_running_loop()
204        self._out_packets = deque[bytes]()
205        self._have_out_packets = asyncio.Event()
206        self._run_called = False
207        self._peer_info: _PeerInfo | None = None
208        self._keepalive_interval = keepalive_interval
209        self._keepalive_timeout = keepalive_timeout
210        self._did_close_writer = False
211        self._did_wait_closed_writer = False
212        self._did_out_packets_buildup_warning = False
213        self._total_bytes_read = 0
214        self._create_time = time.monotonic()
215
216        # Need to hold weak-refs to these otherwise it creates dep-loops
217        # which keeps us alive.
218        self._tasks: list[asyncio.Task] = []
219
220        # When we last got a keepalive or equivalent (time.monotonic value)
221        self._last_keepalive_receive_time: float | None = None
222
223        # (Start near the end to make sure our looping logic is sound).
224        self._next_message_id = 65530
225
226        self._in_flight_messages: dict[int, _InFlightMessage] = {}
227
228        if self.debug_print:
229            peername = self._writer.get_extra_info('peername')
230            self.debug_print_call(
231                f'{self._label}: connected to {peername} at {self._tm()}.'
232            )
233
234    def __del__(self) -> None:
235        if self._run_called:
236            if not self._did_close_writer:
237                logging.warning(
238                    'RPCEndpoint %d dying with run'
239                    ' called but writer not closed (transport=%s).',
240                    id(self),
241                    ssl_stream_writer_underlying_transport_info(self._writer),
242                )
243            elif not self._did_wait_closed_writer:
244                logging.warning(
245                    'RPCEndpoint %d dying with run called'
246                    ' but writer not wait-closed (transport=%s).',
247                    id(self),
248                    ssl_stream_writer_underlying_transport_info(self._writer),
249                )
250
251        # Currently seeing rare issue where sockets don't go down;
252        # let's add a timer to force the issue until we can figure it out.
253        ssl_stream_writer_force_close_check(self._writer)
254
255    async def run(self) -> None:
256        """Run the endpoint until the connection is lost or closed.
257
258        Handles closing the provided reader/writer on close.
259        """
260        try:
261            await self._do_run()
262        except asyncio.CancelledError:
263            # We aren't really designed to be cancelled so let's warn
264            # if it happens.
265            logging.warning(
266                'RPCEndpoint.run got CancelledError;'
267                ' want to try and avoid this.'
268            )
269            raise
270
271    async def _do_run(self) -> None:
272        self._check_env()
273
274        if self._run_called:
275            raise RuntimeError('Run can be called only once per endpoint.')
276        self._run_called = True
277
278        core_tasks = [
279            asyncio.create_task(
280                self._run_core_task('keepalive', self._run_keepalive_task()),
281                name='rpc keepalive',
282            ),
283            asyncio.create_task(
284                self._run_core_task('read', self._run_read_task()),
285                name='rpc read',
286            ),
287            asyncio.create_task(
288                self._run_core_task('write', self._run_write_task()),
289                name='rpc write',
290            ),
291        ]
292        self._tasks += core_tasks
293
294        # Run our core tasks until they all complete.
295        results = await asyncio.gather(*core_tasks, return_exceptions=True)
296
297        # Core tasks should handle their own errors; the only ones
298        # we expect to bubble up are CancelledError.
299        for result in results:
300            # We want to know if any errors happened aside from CancelledError
301            # (which are BaseExceptions, not Exception).
302            if isinstance(result, Exception):
303                logging.warning(
304                    'Got unexpected error from %s core task: %s',
305                    self._label,
306                    result,
307                )
308
309        if not all(task.done() for task in core_tasks):
310            logging.warning(
311                'RPCEndpoint %d: not all core tasks marked done after gather.',
312                id(self),
313            )
314
315        # Shut ourself down.
316        try:
317            self.close()
318            await self.wait_closed()
319        except Exception:
320            logging.exception('Error closing %s.', self._label)
321
322        if self.debug_print:
323            self.debug_print_call(f'{self._label}: finished.')
324
325    def send_message(
326        self,
327        message: bytes,
328        timeout: float | None = None,
329        close_on_error: bool = True,
330    ) -> Awaitable[bytes]:
331        """Send a message to the peer and return a response.
332
333        If timeout is not provided, the default will be used.
334        Raises a CommunicationError if the round trip is not completed
335        for any reason.
336
337        By default, the entire endpoint will go down in the case of
338        errors. This allows messages to be treated as 'reliable' with
339        respect to a given endpoint. Pass close_on_error=False to
340        override this for a particular message.
341        """
342        # Note: This call is synchronous so that the first part of it
343        # (enqueueing outgoing messages) happens synchronously. If it were
344        # a pure async call it could be possible for send order to vary
345        # based on how the async tasks get processed.
346
347        if self.debug_print_io:
348            self.debug_print_call(
349                f'{self._label}: sending message of size {len(message)}'
350                f' at {self._tm()}.'
351            )
352
353        self._check_env()
354
355        if self._closing:
356            raise CommunicationError('Endpoint is closed.')
357
358        if self.debug_print_io:
359            self.debug_print_call(
360                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
361            )
362
363        # message_id is a 16 bit looping value.
364        message_id = self._next_message_id
365        self._next_message_id = (self._next_message_id + 1) % 65536
366
367        if self.debug_print_io:
368            self.debug_print_call(
369                f'{self._label}: will enqueue at {self._tm()}.'
370            )
371
372        # FIXME - should handle backpressure (waiting here if there are
373        # enough packets already enqueued).
374
375        if len(message) > 65535:
376            # Payload consists of type (1b), message_id (2b),
377            # len (4b), and data.
378            self._enqueue_outgoing_packet(
379                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
380                + message_id.to_bytes(2, _BYTE_ORDER)
381                + len(message).to_bytes(4, _BYTE_ORDER)
382                + message
383            )
384        else:
385            # Payload consists of type (1b), message_id (2b),
386            # len (2b), and data.
387            self._enqueue_outgoing_packet(
388                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
389                + message_id.to_bytes(2, _BYTE_ORDER)
390                + len(message).to_bytes(2, _BYTE_ORDER)
391                + message
392            )
393
394        if self.debug_print_io:
395            self.debug_print_call(
396                f'{self._label}: enqueued message of size {len(message)}'
397                f' at {self._tm()}.'
398            )
399
400        # Make an entry so we know this message is out there.
401        assert message_id not in self._in_flight_messages
402        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
403
404        # Also add its task to our list so we properly cancel it if we die.
405        self._prune_tasks()  # Keep our list from filling with dead tasks.
406        self._tasks.append(msgobj.wait_task)
407
408        # Note: we always want to incorporate a timeout. Individual
409        # messages may hang or error on the other end and this ensures
410        # we won't build up lots of zombie tasks waiting around for
411        # responses that will never arrive.
412        if timeout is None:
413            timeout = self.DEFAULT_MESSAGE_TIMEOUT
414        assert timeout is not None
415
416        bytes_awaitable = msgobj.wait_task
417
418        # Now complete the send asynchronously.
419        return self._send_message(
420            message, timeout, close_on_error, bytes_awaitable, message_id
421        )
422
423    async def _send_message(
424        self,
425        message: bytes,
426        timeout: float | None,
427        close_on_error: bool,
428        bytes_awaitable: asyncio.Task[bytes],
429        message_id: int,
430    ) -> bytes:
431        # We need to know their protocol, so if we haven't gotten a handshake
432        # from them yet, just wait.
433        while self._peer_info is None:
434            await asyncio.sleep(0.01)
435        assert self._peer_info is not None
436
437        if self._peer_info.protocol == 1:
438            if len(message) > 65535:
439                raise RuntimeError('Message cannot be larger than 65535 bytes')
440
441        try:
442            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
443        except asyncio.CancelledError as exc:
444            # Question: we assume this means the above wait_for() was
445            # cancelled; how do we distinguish between this and *us* being
446            # cancelled though?
447            if self.debug_print:
448                self.debug_print_call(
449                    f'{self._label}: message {message_id} was cancelled.'
450                )
451            if close_on_error:
452                self.close()
453
454            raise CommunicationError() from exc
455        except Exception as exc:
456            # If our timer timed-out or anything else went wrong with
457            # the stream, lump it in as a communication error.
458            if isinstance(
459                exc, asyncio.TimeoutError
460            ) or is_asyncio_streams_communication_error(exc):
461                if self.debug_print:
462                    self.debug_print_call(
463                        f'{self._label}: got {type(exc)} sending message'
464                        f' {message_id}; raising CommunicationError.'
465                    )
466
467                # Stop waiting on the response.
468                bytes_awaitable.cancel()
469
470                # Remove the record of this message.
471                del self._in_flight_messages[message_id]
472
473                if close_on_error:
474                    self.close()
475
476                # Let the user know something went wrong.
477                raise CommunicationError() from exc
478
479            # Some unexpected error; let it bubble up.
480            raise
481
482    def close(self) -> None:
483        """I said seagulls; mmmm; stop it now."""
484        self._check_env()
485
486        if self._closing:
487            return
488
489        if self.debug_print:
490            self.debug_print_call(f'{self._label}: closing...')
491
492        self._closing = True
493
494        # Kill all of our in-flight tasks.
495        if self.debug_print:
496            self.debug_print_call(f'{self._label}: cancelling tasks...')
497        for task in self._get_live_tasks():
498            task.cancel()
499
500        # Close our writer.
501        assert not self._did_close_writer
502        if self.debug_print:
503            self.debug_print_call(f'{self._label}: closing writer...')
504        self._writer.close()
505        self._did_close_writer = True
506
507        # We don't need this anymore and it is likely to be creating a
508        # dependency loop.
509        del self._handle_raw_message_call
510
511    def is_closing(self) -> bool:
512        """Have we begun the process of closing?"""
513        return self._closing
514
515    async def wait_closed(self) -> None:
516        """I said seagulls; mmmm; stop it now.
517
518        Wait for the endpoint to finish closing. This is called by run()
519        so generally does not need to be explicitly called.
520        """
521        # pylint: disable=too-many-branches
522        self._check_env()
523
524        # Make sure we only *enter* this call once.
525        if self._did_wait_closed:
526            return
527        self._did_wait_closed = True
528
529        if not self._closing:
530            raise RuntimeError('Must be called after close()')
531
532        if not self._did_close_writer:
533            logging.warning(
534                'RPCEndpoint wait_closed() called but never'
535                ' explicitly closed writer.'
536            )
537
538        live_tasks = self._get_live_tasks()
539
540        # Don't need our task list anymore; this should
541        # break any cyclical refs from tasks referring to us.
542        self._tasks = []
543
544        if self.debug_print:
545            self.debug_print_call(
546                f'{self._label}: waiting for tasks to finish: '
547                f' ({live_tasks=})...'
548            )
549
550        # Wait for all of our in-flight tasks to wrap up.
551        results = await asyncio.gather(*live_tasks, return_exceptions=True)
552        for result in results:
553            # We want to know if any errors happened aside from CancelledError
554            # (which are BaseExceptions, not Exception).
555            if isinstance(result, Exception):
556                logging.warning(
557                    'Got unexpected error cleaning up %s task: %s',
558                    self._label,
559                    result,
560                )
561
562        if not all(task.done() for task in live_tasks):
563            logging.warning(
564                'RPCEndpoint %d: not all live tasks marked done after gather.',
565                id(self),
566            )
567
568        if self.debug_print:
569            self.debug_print_call(
570                f'{self._label}: tasks finished; waiting for writer close...'
571            )
572
573        # Now wait for our writer to finish going down.
574        # When we close our writer it generally triggers errors
575        # in our current blocked read/writes. However that same
576        # error is also sometimes returned from _writer.wait_closed().
577        # See connection_lost() in asyncio/streams.py to see why.
578        # So let's silently ignore it when that happens.
579        assert self._writer.is_closing()
580        try:
581            # It seems that as of Python 3.9.x it is possible for this to hang
582            # indefinitely. See https://github.com/python/cpython/issues/83939
583            # It sounds like this should be fixed in 3.11 but for now just
584            # forcing the issue with a timeout here.
585            await asyncio.wait_for(
586                self._writer.wait_closed(),
587                # timeout=60.0 * 6.0,
588                timeout=30.0,
589            )
590        except asyncio.TimeoutError:
591            logging.info(
592                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
593                self._label,
594                ssl_stream_writer_underlying_transport_info(self._writer),
595            )
596            if self.debug_print:
597                self.debug_print_call(
598                    f'{self._label}: got timeout in _writer.wait_closed();'
599                    ' this should be fixed in future Python versions.'
600                )
601        except Exception as exc:
602            if not self._is_expected_connection_error(exc):
603                logging.exception('Error closing _writer for %s.', self._label)
604            else:
605                if self.debug_print:
606                    self.debug_print_call(
607                        f'{self._label}: silently ignoring error in'
608                        f' _writer.wait_closed(): {exc}.'
609                    )
610        except asyncio.CancelledError:
611            logging.warning(
612                'RPCEndpoint.wait_closed()'
613                ' got asyncio.CancelledError; not expected.'
614            )
615            raise
616        assert not self._did_wait_closed_writer
617        self._did_wait_closed_writer = True
618
619    def _tm(self) -> str:
620        """Simple readable time value for debugging."""
621        tval = time.monotonic() % 100.0
622        return f'{tval:.2f}'
623
624    async def _run_read_task(self) -> None:
625        """Read from the peer."""
626        self._check_env()
627        assert self._peer_info is None
628
629        # Bug fix: if we don't set this initially, we will never
630        # time out when the other end never sends us any data.
631        self._last_keepalive_receive_time = time.monotonic()
632
633        # The first thing they should send us is their handshake; then
634        # we'll know if/how we can talk to them.
635        mlen = await self._read_int_32()
636        message = await self._reader.readexactly(mlen)
637        self._total_bytes_read += mlen
638        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
639        self._last_keepalive_receive_time = time.monotonic()
640        if self.debug_print:
641            self.debug_print_call(
642                f'{self._label}: received handshake at {self._tm()}.'
643            )
644
645        # Now just sit and handle stuff as it comes in.
646        while True:
647            if self._closing:
648                return
649
650            # Read packet type.
651            mtype = _PacketType(await self._read_int_8())
652            if mtype is _PacketType.HANDSHAKE:
653                raise RuntimeError('Got multiple handshakes')
654
655            if mtype is _PacketType.KEEPALIVE:
656                if self.debug_print_io:
657                    self.debug_print_call(
658                        f'{self._label}: received keepalive'
659                        f' at {self._tm()}.'
660                    )
661                self._last_keepalive_receive_time = time.monotonic()
662
663            elif mtype is _PacketType.MESSAGE:
664                await self._handle_message_packet(big=False)
665
666            elif mtype is _PacketType.MESSAGE_BIG:
667                await self._handle_message_packet(big=True)
668
669            elif mtype is _PacketType.RESPONSE:
670                await self._handle_response_packet(big=False)
671
672            elif mtype is _PacketType.RESPONSE_BIG:
673                await self._handle_response_packet(big=True)
674
675            else:
676                assert_never(mtype)
677
678    async def _handle_message_packet(self, big: bool) -> None:
679        assert self._peer_info is not None
680        msgid = await self._read_int_16()
681        if big:
682            msglen = await self._read_int_32()
683        else:
684            msglen = await self._read_int_16()
685        msg = await self._reader.readexactly(msglen)
686        self._total_bytes_read += msglen
687        if self.debug_print_io:
688            self.debug_print_call(
689                f'{self._label}: received message {msgid}'
690                f' of size {msglen} at {self._tm()}.'
691            )
692
693        # Create a message-task to handle this message and return
694        # a response (we don't want to block while that happens).
695        assert not self._closing
696        self._prune_tasks()  # Keep from filling with dead tasks.
697        self._tasks.append(
698            asyncio.create_task(
699                self._handle_raw_message(message_id=msgid, message=msg),
700                name='efro rpc message handle',
701            )
702        )
703        if self.debug_print:
704            self.debug_print_call(
705                f'{self._label}: done handling message at {self._tm()}.'
706            )
707
708    async def _handle_response_packet(self, big: bool) -> None:
709        assert self._peer_info is not None
710        msgid = await self._read_int_16()
711        # Protocol 2 gained 32-bit data lengths.
712        if big:
713            rsplen = await self._read_int_32()
714        else:
715            rsplen = await self._read_int_16()
716        if self.debug_print_io:
717            self.debug_print_call(
718                f'{self._label}: received response {msgid}'
719                f' of size {rsplen} at {self._tm()}.'
720            )
721        rsp = await self._reader.readexactly(rsplen)
722        self._total_bytes_read += rsplen
723        msgobj = self._in_flight_messages.get(msgid)
724        if msgobj is None:
725            # It's possible for us to get a response to a message
726            # that has timed out. In this case we will have no local
727            # record of it.
728            if self.debug_print:
729                self.debug_print_call(
730                    f'{self._label}: got response for nonexistent'
731                    f' message id {msgid}; perhaps it timed out?'
732                )
733        else:
734            msgobj.set_response(rsp)
735
736    async def _run_write_task(self) -> None:
737        """Write to the peer."""
738
739        self._check_env()
740
741        # Introduce ourselves so our peer knows how it can talk to us.
742        data = dataclass_to_json(
743            _PeerInfo(
744                protocol=OUR_PROTOCOL,
745                keepalive_interval=self._keepalive_interval,
746            )
747        ).encode()
748        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
749
750        # Now just write out-messages as they come in.
751        while True:
752            # Wait until some data comes in.
753            await self._have_out_packets.wait()
754
755            assert self._out_packets
756            data = self._out_packets.popleft()
757
758            # Important: only clear this once all packets are sent.
759            if not self._out_packets:
760                self._have_out_packets.clear()
761
762            self._writer.write(data)
763
764            # This should keep our writer from buffering huge amounts
765            # of outgoing data. We must remember though that we also
766            # need to prevent _out_packets from growing too large and
767            # that part's on us.
768            await self._writer.drain()
769
770            # For now we're not applying backpressure, but let's make
771            # noise if this gets out of hand.
772            if len(self._out_packets) > 200:
773                if not self._did_out_packets_buildup_warning:
774                    logging.warning(
775                        '_out_packets building up too'
776                        ' much on RPCEndpoint %s.',
777                        id(self),
778                    )
779                    self._did_out_packets_buildup_warning = True
780
781    async def _run_keepalive_task(self) -> None:
782        """Send periodic keepalive packets."""
783        self._check_env()
784
785        # We explicitly send our own keepalive packets so we can stay
786        # more on top of the connection state and possibly decide to
787        # kill it when contact is lost more quickly than the OS would
788        # do itself (or at least keep the user informed that the
789        # connection is lagging). It sounds like we could have the TCP
790            # layer do this sort of thing itself, but that might be
791            # OS-specific, so we'll go this way for now.
792        while True:
793            assert not self._closing
794            await asyncio.sleep(self._keepalive_interval)
795            if not self.test_suppress_keepalives:
796                self._enqueue_outgoing_packet(
797                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
798                )
799
800            # Also go ahead and handle dropping the connection if we
801            # haven't heard from the peer in a while.
802            # NOTE: perhaps we want something more precise than this
803            # check, which only runs once per keepalive-interval.
804            now = time.monotonic()
805            if (
806                self._last_keepalive_receive_time is not None
807                and now - self._last_keepalive_receive_time
808                > self._keepalive_timeout
809            ):
810                if self.debug_print:
811                    since = now - self._last_keepalive_receive_time
812                    self.debug_print_call(
813                        f'{self._label}: reached keepalive time-out'
814                        f' ({since:.1f}s).'
815                    )
816                raise _KeepaliveTimeoutError()
817
818    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
819        try:
820            await call
821        except Exception as exc:
822            # We expect connection errors to put us here, but make noise
823            # if something else does.
824            if not self._is_expected_connection_error(exc):
825                logging.exception(
826                    'Unexpected error in rpc %s %s task'
827                    ' (age=%.1f, total_bytes_read=%d).',
828                    self._label,
829                    tasklabel,
830                    time.monotonic() - self._create_time,
831                    self._total_bytes_read,
832                )
833            else:
834                if self.debug_print:
835                    self.debug_print_call(
836                        f'{self._label}: {tasklabel} task will exit cleanly'
837                        f' due to {exc!r}.'
838                    )
839        finally:
840            # Any core task exiting triggers shutdown.
841            if self.debug_print:
842                self.debug_print_call(
843                    f'{self._label}: {tasklabel} task exiting...'
844                )
845            self.close()
846
847    async def _handle_raw_message(
848        self, message_id: int, message: bytes
849    ) -> None:
850        try:
851            response = await self._handle_raw_message_call(message)
852        except Exception:
853            # We expect the local message handler to always succeed.
854            # If that doesn't happen, make a fuss so we know to fix it.
855            # The other end will simply never get a response to this
856            # message.
857            logging.exception('Error handling raw rpc message')
858            return
859
860        assert self._peer_info is not None
861
862        if self._peer_info.protocol == 1:
863            if len(response) > 65535:
864                raise RuntimeError('Response cannot be larger than 65535 bytes')
865
866        # Now send back our response.
867        # Payload: type (1b), msgid (2b), len (2b or 4b), and data.
868        if len(response) > 65535:
869            self._enqueue_outgoing_packet(
870                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
871                + message_id.to_bytes(2, _BYTE_ORDER)
872                + len(response).to_bytes(4, _BYTE_ORDER)
873                + response
874            )
875        else:
876            self._enqueue_outgoing_packet(
877                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
878                + message_id.to_bytes(2, _BYTE_ORDER)
879                + len(response).to_bytes(2, _BYTE_ORDER)
880                + response
881            )
882
883    async def _read_int_8(self) -> int:
884        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
885        self._total_bytes_read += 1
886        return out
887
888    async def _read_int_16(self) -> int:
889        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
890        self._total_bytes_read += 2
891        return out
892
893    async def _read_int_32(self) -> int:
894        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
895        self._total_bytes_read += 4
896        return out
897
898    @classmethod
899    def _is_expected_connection_error(cls, exc: Exception) -> bool:
900        """Stuff we expect to end our connection in normal circumstances."""
901
902        if isinstance(exc, _KeepaliveTimeoutError):
903            return True
904
905        return is_asyncio_streams_communication_error(exc)
906
907    def _check_env(self) -> None:
908        # I was seeing that asyncio stuff wasn't working as expected if
909        # created in one thread and used in another (and have verified
910        # that this is part of the design), so let's enforce a single
911        # thread for all use of an instance.
912        if current_thread() is not self._thread:
913            raise RuntimeError(
914                'This must be called from the same thread'
915                ' that the endpoint was created in.'
916            )
917
918        # This should always be the case if the thread is the same.
919        assert asyncio.get_running_loop() is self._event_loop
920
921    def _enqueue_outgoing_packet(self, data: bytes) -> None:
922        """Enqueue a raw packet to be sent. Must be called from our loop."""
923        self._check_env()
924
925        if self.debug_print_io:
926            self.debug_print_call(
927                f'{self._label}: enqueueing outgoing packet'
928                f' {data[:50]!r} at {self._tm()}.'
929            )
930
931        # Add the data and let our write task know about it.
932        self._out_packets.append(data)
933        self._have_out_packets.set()
934
935    def _prune_tasks(self) -> None:
936        self._tasks = self._get_live_tasks()
937
938    def _get_live_tasks(self) -> list[asyncio.Task]:
939        return [t for t in self._tasks if not t.done()]
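
For reference, the wire format implemented above: the connection opens with a handshake packet (a 4-byte big-endian length followed by JSON-encoded _PeerInfo), after which every packet begins with a 1-byte type followed by type-dependent fields. A small sketch of framing a regular (non-big) MESSAGE packet, mirroring the arithmetic in send_message(); the msg_id and payload values are illustrative:

    # Frame layout for a small MESSAGE packet:
    # type (1 byte) + message_id (2 bytes) + length (2 bytes) + payload.
    msg_id = 17
    payload = b'hello'
    frame = (
        (2).to_bytes(1, 'big')  # _PacketType.MESSAGE.value == 2
        + msg_id.to_bytes(2, 'big')
        + len(payload).to_bytes(2, 'big')
        + payload
    )
    assert frame == b'\x02\x00\x11\x00\x05hello'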

Facilitates asynchronous multiplexed remote procedure calls.

Be aware that, while multiple calls can be in flight in either direction simultaneously, packets are still sent serially in a single stream. So excessively long messages/responses will delay all other communication. If/when this becomes an issue we can look into breaking up long messages into multiple packets.
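
For orientation, here is a minimal client-side sketch; the handle_message handler and the connect() helper are illustrative, not part of this module:

    import asyncio

    from efro.rpc import RPCEndpoint


    async def handle_message(message: bytes) -> bytes:
        # Hypothetical application handler; echoes the payload back.
        return b'echo:' + message


    async def connect(host: str, port: int) -> RPCEndpoint:
        reader, writer = await asyncio.open_connection(host, port)
        endpoint = RPCEndpoint(
            handle_raw_message_call=handle_message,
            reader=reader,
            writer=writer,
            label='client',
        )
        # run() services reads, writes, and keepalives until the
        # connection is lost or close() is called; spawn it as its
        # own task so connect() can return immediately.
        asyncio.create_task(endpoint.run())
        return endpoint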

RPCEndpoint(handle_raw_message_call: Callable[[bytes], Awaitable[bytes]], reader: asyncio.StreamReader, writer: asyncio.StreamWriter, label: str, debug_print: bool = False, debug_print_io: bool = False, debug_print_call: Callable[[str], None] | None = None, keepalive_interval: float = 10.73, keepalive_timeout: float = 30.0)
179    def __init__(
180        self,
181        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
182        reader: asyncio.StreamReader,
183        writer: asyncio.StreamWriter,
184        label: str,
185        debug_print: bool = False,
186        debug_print_io: bool = False,
187        debug_print_call: Callable[[str], None] | None = None,
188        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
189        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
190    ) -> None:
191        self._handle_raw_message_call = handle_raw_message_call
192        self._reader = reader
193        self._writer = writer
194        self.debug_print = debug_print
195        self.debug_print_io = debug_print_io
196        if debug_print_call is None:
197            debug_print_call = print
198        self.debug_print_call: Callable[[str], None] = debug_print_call
199        self._label = label
200        self._thread = current_thread()
201        self._closing = False
202        self._did_wait_closed = False
203        self._event_loop = asyncio.get_running_loop()
204        self._out_packets = deque[bytes]()
205        self._have_out_packets = asyncio.Event()
206        self._run_called = False
207        self._peer_info: _PeerInfo | None = None
208        self._keepalive_interval = keepalive_interval
209        self._keepalive_timeout = keepalive_timeout
210        self._did_close_writer = False
211        self._did_wait_closed_writer = False
212        self._did_out_packets_buildup_warning = False
213        self._total_bytes_read = 0
214        self._create_time = time.monotonic()
215
216        # Need to hold weak-refs to these; otherwise we create
217        # dep-loops that keep us alive.
218        self._tasks: list[asyncio.Task] = []
219
220        # When we last got a keepalive or equivalent (time.monotonic value).
221        self._last_keepalive_receive_time: float | None = None
222
223        # (Start near the end to make sure our looping logic is sound).
224        self._next_message_id = 65530
225
226        self._in_flight_messages: dict[int, _InFlightMessage] = {}
227
228        if self.debug_print:
229            peername = self._writer.get_extra_info('peername')
230            self.debug_print_call(
231                f'{self._label}: connected to {peername} at {self._tm()}.'
232            )
test_suppress_keepalives: bool = False
DEFAULT_MESSAGE_TIMEOUT = 60.0
DEFAULT_KEEPALIVE_INTERVAL = 10.73
DEFAULT_KEEPALIVE_TIMEOUT = 30.0
debug_print: bool
debug_print_io: bool
debug_print_call: Callable[[str], None]
async def run(self) -> None:
255    async def run(self) -> None:
256        """Run the endpoint until the connection is lost or closed.
257
258        Handles closing the provided reader/writer on close.
259        """
260        try:
261            await self._do_run()
262        except asyncio.CancelledError:
263            # We aren't really designed to be cancelled so let's warn
264            # if it happens.
265            logging.warning(
266                'RPCEndpoint.run got CancelledError;'
267                ' this is not expected and should be avoided.'
268            )
269            raise

Run the endpoint until the connection is lost or closed.

Handles closing the provided reader/writer on close.
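
The server side looks much the same; a sketch reusing the hypothetical handle_message handler from above:

    async def serve_connection(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        endpoint = RPCEndpoint(
            handle_raw_message_call=handle_message,
            reader=reader,
            writer=writer,
            label='server',
        )
        # Returns once the connection is lost or closed; the endpoint
        # takes care of closing the provided reader/writer itself.
        await endpoint.run()


    async def main() -> None:
        server = await asyncio.start_server(serve_connection, 'localhost', 8000)
        async with server:
            await server.serve_forever()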

def send_message(self, message: bytes, timeout: float | None = None, close_on_error: bool = True) -> Awaitable[bytes]:
325    def send_message(
326        self,
327        message: bytes,
328        timeout: float | None = None,
329        close_on_error: bool = True,
330    ) -> Awaitable[bytes]:
331        """Send a message to the peer and return a response.
332
333        If timeout is not provided, the default will be used.
334        Raises a CommunicationError if the round trip is not completed
335        for any reason.
336
337        By default, the entire endpoint will go down in the case of
338        errors. This allows messages to be treated as 'reliable' with
339        respect to a given endpoint. Pass close_on_error=False to
340        override this for a particular message.
341        """
342        # Note: This call is synchronous so that the first part of it
343        # (enqueueing outgoing messages) happens synchronously. If it were
344        # a pure async call it could be possible for send order to vary
345        # based on how the async tasks get processed.
346
347        if self.debug_print_io:
348            self.debug_print_call(
349                f'{self._label}: sending message of size {len(message)}'
350                f' at {self._tm()}.'
351            )
352
353        self._check_env()
354
355        if self._closing:
356            raise CommunicationError('Endpoint is closed.')
357
358        if self.debug_print_io:
359            self.debug_print_call(
360                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
361            )
362
363        # message_id is a 16 bit looping value.
364        message_id = self._next_message_id
365        self._next_message_id = (self._next_message_id + 1) % 65536
366
367        if self.debug_print_io:
368            self.debug_print_call(
369                f'{self._label}: will enqueue at {self._tm()}.'
370            )
371
372        # FIXME - should handle backpressure (waiting here if there are
373        # enough packets already enqueued).
374
375        if len(message) > 65535:
376            # Payload consists of type (1b), message_id (2b),
377            # len (4b), and data.
378            self._enqueue_outgoing_packet(
379                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
380                + message_id.to_bytes(2, _BYTE_ORDER)
381                + len(message).to_bytes(4, _BYTE_ORDER)
382                + message
383            )
384        else:
385            # Payload consists of type (1b), message_id (2b),
386            # len (2b), and data.
387            self._enqueue_outgoing_packet(
388                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
389                + message_id.to_bytes(2, _BYTE_ORDER)
390                + len(message).to_bytes(2, _BYTE_ORDER)
391                + message
392            )
393
394        if self.debug_print_io:
395            self.debug_print_call(
396                f'{self._label}: enqueued message of size {len(message)}'
397                f' at {self._tm()}.'
398            )
399
400        # Make an entry so we know this message is out there.
401        assert message_id not in self._in_flight_messages
402        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
403
404        # Also add its task to our list so we properly cancel it if we die.
405        self._prune_tasks()  # Keep our list from filling with dead tasks.
406        self._tasks.append(msgobj.wait_task)
407
408        # Note: we always want to incorporate a timeout. Individual
409        # messages may hang or error on the other end and this ensures
410        # we won't build up lots of zombie tasks waiting around for
411        # responses that will never arrive.
412        if timeout is None:
413            timeout = self.DEFAULT_MESSAGE_TIMEOUT
414        assert timeout is not None
415
416        bytes_awaitable = msgobj.wait_task
417
418        # Now complete the send asynchronously.
419        return self._send_message(
420            message, timeout, close_on_error, bytes_awaitable, message_id
421        )

Send a message to the peer and return a response.

If timeout is not provided, the default will be used. Raises a CommunicationError if the round trip is not completed for any reason.

By default, the entire endpoint will go down in the case of errors. This allows messages to be treated as 'reliable' with respect to a given endpoint. Pass close_on_error=False to override this for a particular message.
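
A sketch of a round trip against an already-running endpoint; the ping() helper and its payload are illustrative:

    from efro.error import CommunicationError


    async def ping(endpoint: RPCEndpoint) -> bytes | None:
        try:
            # send_message() enqueues the outgoing packet synchronously
            # and returns an awaitable that completes with the response.
            return await endpoint.send_message(
                b'ping', timeout=5.0, close_on_error=False
            )
        except CommunicationError:
            # Raised when the round trip cannot complete (timeout,
            # connection loss, endpoint already closed, etc.).
            return None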

def close(self) -> None:

Begin the process of closing the endpoint.

def is_closing(self) -> bool:

Have we begun the process of closing?

async def wait_closed(self) -> None:

Wait for the endpoint to finish closing. This is called by run() so generally does not need to be explicitly called.
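
A sketch of an explicit shutdown for code managing an endpoint outside of run(); the shutdown() helper is illustrative:

    async def shutdown(endpoint: RPCEndpoint) -> None:
        # close() is synchronous and idempotent; it cancels in-flight
        # tasks and closes the underlying writer.
        endpoint.close()
        # wait_closed() must be called after close(). run() already does
        # this, so an explicit call is needed only when driving the
        # endpoint manually.
        await endpoint.wait_closed()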