efro.rpc

Remote procedure call related functionality.

  1# Released under the MIT License. See LICENSE for details.
  2#
  3"""Remote procedure call related functionality."""
  4
  5from __future__ import annotations
  6
  7import time
  8import asyncio
  9import logging
 10import weakref
 11from enum import Enum
 12from functools import partial
 13from collections import deque
 14from dataclasses import dataclass
 15from threading import current_thread
 16from typing import TYPE_CHECKING, Annotated, assert_never
 17
 18from efro.error import (
 19    CommunicationError,
 20    is_asyncio_streams_communication_error,
 21)
 22from efro.dataclassio import (
 23    dataclass_to_json,
 24    dataclass_from_json,
 25    ioprepped,
 26    IOAttrs,
 27)
 28
 29if TYPE_CHECKING:
 30    from typing import Literal, Awaitable, Callable
 31
 32# Terminology:
 33# Packet: A chunk of data consisting of a type and some type-dependent
 34#         payload. Even though we use streams we organize our transmission
 35#         into 'packets'.
 36# Message: User data which we transmit using one or more packets.
 37
 38
class _PacketType(Enum):
    """Type tag written as the first byte of each post-handshake packet."""

    # Only valid as the very first exchange; the read task raises if it
    # sees this type again later.
    HANDSHAKE = 0
    # Single-byte packet with no payload.
    KEEPALIVE = 1
    # Message/response carrying a 16-bit payload length.
    MESSAGE = 2
    RESPONSE = 3
    # Message/response carrying a 32-bit payload length.
    MESSAGE_BIG = 4
    RESPONSE_BIG = 5
 46
 47
# Byte order used when encoding integer fields (type tags, message ids,
# lengths) into outgoing packets.
_BYTE_ORDER: Literal['big'] = 'big'
 49
 50
@ioprepped
@dataclass
class _PeerInfo:
    """Handshake payload describing one side of the connection.

    Each endpoint writes one of these (as JSON) at the start of its write
    task and parses the one sent by its peer at the start of its read task.
    """

    # So we can gracefully evolve how we communicate in the future.
    protocol: Annotated[int, IOAttrs('p')]

    # How often we'll be sending out keepalives (in seconds).
    keepalive_interval: Annotated[float, IOAttrs('k')]
 59
 60
# Note: we are expected to be forward and backward compatible; we can
# increment protocol freely and expect everyone else to still talk to us.
# Likewise we should retain logic to communicate with older protocols.
# Protocol history:
# 1 - initial release
# 2 - gained big (32-bit len val) message/response packets
OUR_PROTOCOL = 2
 68
 69
 70def ssl_stream_writer_underlying_transport_info(
 71    writer: asyncio.StreamWriter,
 72) -> str:
 73    """For debugging SSL Stream connections; returns raw transport info."""
 74    # Note: accessing internals here so just returning info and not
 75    # actual objs to reduce potential for breakage.
 76    transport = getattr(writer, '_transport', None)
 77    if transport is not None:
 78        sslproto = getattr(transport, '_ssl_protocol', None)
 79        if sslproto is not None:
 80            raw_transport = getattr(sslproto, '_transport', None)
 81            if raw_transport is not None:
 82                return str(raw_transport)
 83    return '(not found)'
 84
 85
 86def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
 87    """Ensure a writer is closed; hacky workaround for odd hang."""
 88    from threading import Thread
 89
 90    # Disabling for now..
 91    if bool(True):
 92        return
 93
 94    # Hopefully can remove this in Python 3.11?...
 95    # see issue with is_closing() below for more details.
 96    transport = getattr(writer, '_transport', None)
 97    if transport is not None:
 98        sslproto = getattr(transport, '_ssl_protocol', None)
 99        if sslproto is not None:
100            raw_transport = getattr(sslproto, '_transport', None)
101            if raw_transport is not None:
102                Thread(
103                    target=partial(
104                        _do_writer_force_close_check, weakref.ref(raw_transport)
105                    ),
106                    daemon=True,
107                ).start()
108
109
110def _do_writer_force_close_check(transport_weak: weakref.ref) -> None:
111    try:
112        # Attempt to bail as soon as the obj dies.
113        # If it hasn't done so by our timeout, force-kill it.
114        starttime = time.monotonic()
115        while time.monotonic() - starttime < 10.0:
116            time.sleep(0.1)
117            if transport_weak() is None:
118                return
119        transport = transport_weak()
120        if transport is not None:
121            logging.info('Forcing abort on stuck transport %s.', transport)
122            transport.abort()
123    except Exception:
124        logging.warning('Error in writer-force-close-check', exc_info=True)
125
126
127class _InFlightMessage:
128    """Represents a message that is out on the wire."""
129
130    def __init__(self) -> None:
131        self._response: bytes | None = None
132        self._got_response = asyncio.Event()
133        self.wait_task = asyncio.create_task(
134            self._wait(), name='rpc in flight msg wait'
135        )
136
137    async def _wait(self) -> bytes:
138        await self._got_response.wait()
139        assert self._response is not None
140        return self._response
141
142    def set_response(self, data: bytes) -> None:
143        """Set response data."""
144        assert self._response is None
145        self._response = data
146        self._got_response.set()
147
148
class _KeepaliveTimeoutError(Exception):
    """Raised if we time out due to not receiving keepalives.

    Raised by the endpoint's keepalive task when the time since the last
    recorded keepalive-receive timestamp exceeds the keepalive timeout.
    """
151
152
153class RPCEndpoint:
154    """Facilitates asynchronous multiplexed remote procedure calls.
155
156    Be aware that, while multiple calls can be in flight in either direction
157    simultaneously, packets are still sent serially in a single
158    stream. So excessively long messages/responses will delay all other
159    communication. If/when this becomes an issue we can look into breaking up
160    long messages into multiple packets.
161    """
162
163    # Set to True on an instance to test keepalive failures.
164    test_suppress_keepalives: bool = False
165
166    # How long we should wait before giving up on a message by default.
167    # Note this includes processing time on the other end.
168    DEFAULT_MESSAGE_TIMEOUT = 60.0
169
170    # How often we send out keepalive packets by default.
171    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)
172
173    # How long we can go without receiving a keepalive packet before we
174    # disconnect.
175    DEFAULT_KEEPALIVE_TIMEOUT = 30.0
176
177    def __init__(
178        self,
179        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
180        reader: asyncio.StreamReader,
181        writer: asyncio.StreamWriter,
182        label: str,
183        debug_print: bool = False,
184        debug_print_io: bool = False,
185        debug_print_call: Callable[[str], None] | None = None,
186        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
187        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
188    ) -> None:
189        self._handle_raw_message_call = handle_raw_message_call
190        self._reader = reader
191        self._writer = writer
192        self.debug_print = debug_print
193        self.debug_print_io = debug_print_io
194        if debug_print_call is None:
195            debug_print_call = print
196        self.debug_print_call: Callable[[str], None] = debug_print_call
197        self._label = label
198        self._thread = current_thread()
199        self._closing = False
200        self._did_wait_closed = False
201        self._event_loop = asyncio.get_running_loop()
202        self._out_packets = deque[bytes]()
203        self._have_out_packets = asyncio.Event()
204        self._run_called = False
205        self._peer_info: _PeerInfo | None = None
206        self._keepalive_interval = keepalive_interval
207        self._keepalive_timeout = keepalive_timeout
208        self._did_close_writer = False
209        self._did_wait_closed_writer = False
210        self._did_out_packets_buildup_warning = False
211        self._total_bytes_read = 0
212        self._create_time = time.monotonic()
213
214        # Need to hold weak-refs to these otherwise it creates dep-loops
215        # which keeps us alive.
216        self._tasks: list[asyncio.Task] = []
217
218        # When we last got a keepalive or equivalent (time.monotonic value)
219        self._last_keepalive_receive_time: float | None = None
220
221        # (Start near the end to make sure our looping logic is sound).
222        self._next_message_id = 65530
223
224        self._in_flight_messages: dict[int, _InFlightMessage] = {}
225
226        if self.debug_print:
227            peername = self._writer.get_extra_info('peername')
228            self.debug_print_call(
229                f'{self._label}: connected to {peername} at {self._tm()}.'
230            )
231
232    def __del__(self) -> None:
233        if self._run_called:
234            if not self._did_close_writer:
235                logging.warning(
236                    'RPCEndpoint %d dying with run'
237                    ' called but writer not closed (transport=%s).',
238                    id(self),
239                    ssl_stream_writer_underlying_transport_info(self._writer),
240                )
241            elif not self._did_wait_closed_writer:
242                logging.warning(
243                    'RPCEndpoint %d dying with run called'
244                    ' but writer not wait-closed (transport=%s).',
245                    id(self),
246                    ssl_stream_writer_underlying_transport_info(self._writer),
247                )
248
249        # Currently seeing rare issue where sockets don't go down;
250        # let's add a timer to force the issue until we can figure it out.
251        ssl_stream_writer_force_close_check(self._writer)
252
    async def run(self) -> None:
        """Run the endpoint until the connection is lost or closed.

        Handles closing the provided reader/writer on close.

        Raises:
            RuntimeError: If run has already been called on this endpoint.
            asyncio.CancelledError: Re-raised, after logging a warning, if
                this coroutine gets cancelled.
        """
        try:
            await self._do_run()
        except asyncio.CancelledError:
            # We aren't really designed to be cancelled so let's warn
            # if it happens.
            logging.warning(
                'RPCEndpoint.run got CancelledError;'
                ' want to try and avoid this.'
            )
            raise
268
269    async def _do_run(self) -> None:
270        self._check_env()
271
272        if self._run_called:
273            raise RuntimeError('Run can be called only once per endpoint.')
274        self._run_called = True
275
276        core_tasks = [
277            asyncio.create_task(
278                self._run_core_task('keepalive', self._run_keepalive_task()),
279                name='rpc keepalive',
280            ),
281            asyncio.create_task(
282                self._run_core_task('read', self._run_read_task()),
283                name='rpc read',
284            ),
285            asyncio.create_task(
286                self._run_core_task('write', self._run_write_task()),
287                name='rpc write',
288            ),
289        ]
290        self._tasks += core_tasks
291
292        # Run our core tasks until they all complete.
293        results = await asyncio.gather(*core_tasks, return_exceptions=True)
294
295        # Core tasks should handle their own errors; the only ones
296        # we expect to bubble up are CancelledError.
297        for result in results:
298            # We want to know if any errors happened aside from CancelledError
299            # (which are BaseExceptions, not Exception).
300            if isinstance(result, Exception):
301                logging.warning(
302                    'Got unexpected error from %s core task: %s',
303                    self._label,
304                    result,
305                )
306
307        if not all(task.done() for task in core_tasks):
308            logging.warning(
309                'RPCEndpoint %d: not all core tasks marked done after gather.',
310                id(self),
311            )
312
313        # Shut ourself down.
314        try:
315            self.close()
316            await self.wait_closed()
317        except Exception:
318            logging.exception('Error closing %s.', self._label)
319
320        if self.debug_print:
321            self.debug_print_call(f'{self._label}: finished.')
322
323    def send_message(
324        self,
325        message: bytes,
326        timeout: float | None = None,
327        close_on_error: bool = True,
328    ) -> Awaitable[bytes]:
329        """Send a message to the peer and return a response.
330
331        If timeout is not provided, the default will be used.
332        Raises a CommunicationError if the round trip is not completed
333        for any reason.
334
335        By default, the entire endpoint will go down in the case of
336        errors. This allows messages to be treated as 'reliable' with
337        respect to a given endpoint. Pass close_on_error=False to
338        override this for a particular message.
339        """
340        # Note: This call is synchronous so that the first part of it
341        # (enqueueing outgoing messages) happens synchronously. If it were
342        # a pure async call it could be possible for send order to vary
343        # based on how the async tasks get processed.
344
345        if self.debug_print_io:
346            self.debug_print_call(
347                f'{self._label}: sending message of size {len(message)}'
348                f' at {self._tm()}.'
349            )
350
351        self._check_env()
352
353        if self._closing:
354            raise CommunicationError('Endpoint is closed.')
355
356        if self.debug_print_io:
357            self.debug_print_call(
358                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
359            )
360
361        # message_id is a 16 bit looping value.
362        message_id = self._next_message_id
363        self._next_message_id = (self._next_message_id + 1) % 65536
364
365        if self.debug_print_io:
366            self.debug_print_call(
367                f'{self._label}: will enqueue at {self._tm()}.'
368            )
369
370        # FIXME - should handle backpressure (waiting here if there are
371        # enough packets already enqueued).
372
373        if len(message) > 65535:
374            # Payload consists of type (1b), message_id (2b),
375            # len (4b), and data.
376            self._enqueue_outgoing_packet(
377                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
378                + message_id.to_bytes(2, _BYTE_ORDER)
379                + len(message).to_bytes(4, _BYTE_ORDER)
380                + message
381            )
382        else:
383            # Payload consists of type (1b), message_id (2b),
384            # len (2b), and data.
385            self._enqueue_outgoing_packet(
386                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
387                + message_id.to_bytes(2, _BYTE_ORDER)
388                + len(message).to_bytes(2, _BYTE_ORDER)
389                + message
390            )
391
392        if self.debug_print_io:
393            self.debug_print_call(
394                f'{self._label}: enqueued message of size {len(message)}'
395                f' at {self._tm()}.'
396            )
397
398        # Make an entry so we know this message is out there.
399        assert message_id not in self._in_flight_messages
400        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
401
402        # Also add its task to our list so we properly cancel it if we die.
403        self._prune_tasks()  # Keep our list from filling with dead tasks.
404        self._tasks.append(msgobj.wait_task)
405
406        # Note: we always want to incorporate a timeout. Individual
407        # messages may hang or error on the other end and this ensures
408        # we won't build up lots of zombie tasks waiting around for
409        # responses that will never arrive.
410        if timeout is None:
411            timeout = self.DEFAULT_MESSAGE_TIMEOUT
412        assert timeout is not None
413
414        bytes_awaitable = msgobj.wait_task
415
416        # Now complete the send asynchronously.
417        return self._send_message(
418            message, timeout, close_on_error, bytes_awaitable, message_id
419        )
420
421    async def _send_message(
422        self,
423        message: bytes,
424        timeout: float | None,
425        close_on_error: bool,
426        bytes_awaitable: asyncio.Task[bytes],
427        message_id: int,
428    ) -> bytes:
429        # We need to know their protocol, so if we haven't gotten a handshake
430        # from them yet, just wait.
431        while self._peer_info is None:
432            await asyncio.sleep(0.01)
433        assert self._peer_info is not None
434
435        if self._peer_info.protocol == 1:
436            if len(message) > 65535:
437                raise RuntimeError('Message cannot be larger than 65535 bytes')
438
439        try:
440            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
441        except asyncio.CancelledError as exc:
442            # Question: we assume this means the above wait_for() was
443            # cancelled; how do we distinguish between this and *us* being
444            # cancelled though?
445            if self.debug_print:
446                self.debug_print_call(
447                    f'{self._label}: message {message_id} was cancelled.'
448                )
449            if close_on_error:
450                self.close()
451
452            raise CommunicationError() from exc
453        except Exception as exc:
454            # If our timer timed-out or anything else went wrong with
455            # the stream, lump it in as a communication error.
456            if isinstance(
457                exc, asyncio.TimeoutError
458            ) or is_asyncio_streams_communication_error(exc):
459                if self.debug_print:
460                    self.debug_print_call(
461                        f'{self._label}: got {type(exc)} sending message'
462                        f' {message_id}; raising CommunicationError.'
463                    )
464
465                # Stop waiting on the response.
466                bytes_awaitable.cancel()
467
468                # Remove the record of this message.
469                del self._in_flight_messages[message_id]
470
471                if close_on_error:
472                    self.close()
473
474                # Let the user know something went wrong.
475                raise CommunicationError() from exc
476
477            # Some unexpected error; let it bubble up.
478            raise
479
480    def close(self) -> None:
481        """I said seagulls; mmmm; stop it now."""
482        self._check_env()
483
484        if self._closing:
485            return
486
487        if self.debug_print:
488            self.debug_print_call(f'{self._label}: closing...')
489
490        self._closing = True
491
492        # Kill all of our in-flight tasks.
493        if self.debug_print:
494            self.debug_print_call(f'{self._label}: cancelling tasks...')
495        for task in self._get_live_tasks():
496            task.cancel()
497
498        # Close our writer.
499        assert not self._did_close_writer
500        if self.debug_print:
501            self.debug_print_call(f'{self._label}: closing writer...')
502        self._writer.close()
503        self._did_close_writer = True
504
505        # We don't need this anymore and it is likely to be creating a
506        # dependency loop.
507        del self._handle_raw_message_call
508
    def is_closing(self) -> bool:
        """Have we begun the process of closing?

        Returns True from the moment close() has set the closing flag;
        it says nothing about whether shutdown has finished.
        """
        return self._closing
512
513    async def wait_closed(self) -> None:
514        """I said seagulls; mmmm; stop it now.
515
516        Wait for the endpoint to finish closing. This is called by run()
517        so generally does not need to be explicitly called.
518        """
519        # pylint: disable=too-many-branches
520        self._check_env()
521
522        # Make sure we only *enter* this call once.
523        if self._did_wait_closed:
524            return
525        self._did_wait_closed = True
526
527        if not self._closing:
528            raise RuntimeError('Must be called after close()')
529
530        if not self._did_close_writer:
531            logging.warning(
532                'RPCEndpoint wait_closed() called but never'
533                ' explicitly closed writer.'
534            )
535
536        live_tasks = self._get_live_tasks()
537
538        # Don't need our task list anymore; this should
539        # break any cyclical refs from tasks referring to us.
540        self._tasks = []
541
542        if self.debug_print:
543            self.debug_print_call(
544                f'{self._label}: waiting for tasks to finish: '
545                f' ({live_tasks=})...'
546            )
547
548        # Wait for all of our in-flight tasks to wrap up.
549        results = await asyncio.gather(*live_tasks, return_exceptions=True)
550        for result in results:
551            # We want to know if any errors happened aside from CancelledError
552            # (which are BaseExceptions, not Exception).
553            if isinstance(result, Exception):
554                logging.warning(
555                    'Got unexpected error cleaning up %s task: %s',
556                    self._label,
557                    result,
558                )
559
560        if not all(task.done() for task in live_tasks):
561            logging.warning(
562                'RPCEndpoint %d: not all live tasks marked done after gather.',
563                id(self),
564            )
565
566        if self.debug_print:
567            self.debug_print_call(
568                f'{self._label}: tasks finished; waiting for writer close...'
569            )
570
571        # Now wait for our writer to finish going down.
572        # When we close our writer it generally triggers errors
573        # in our current blocked read/writes. However that same
574        # error is also sometimes returned from _writer.wait_closed().
575        # See connection_lost() in asyncio/streams.py to see why.
576        # So let's silently ignore it when that happens.
577        assert self._writer.is_closing()
578        try:
579            # It seems that as of Python 3.9.x it is possible for this to hang
580            # indefinitely. See https://github.com/python/cpython/issues/83939
581            # It sounds like this should be fixed in 3.11 but for now just
582            # forcing the issue with a timeout here.
583            await asyncio.wait_for(
584                self._writer.wait_closed(),
585                # timeout=60.0 * 6.0,
586                timeout=30.0,
587            )
588        except asyncio.TimeoutError:
589            logging.info(
590                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
591                self._label,
592                ssl_stream_writer_underlying_transport_info(self._writer),
593            )
594            if self.debug_print:
595                self.debug_print_call(
596                    f'{self._label}: got timeout in _writer.wait_closed();'
597                    ' This should be fixed in future Python versions.'
598                )
599        except Exception as exc:
600            if not self._is_expected_connection_error(exc):
601                logging.exception('Error closing _writer for %s.', self._label)
602            else:
603                if self.debug_print:
604                    self.debug_print_call(
605                        f'{self._label}: silently ignoring error in'
606                        f' _writer.wait_closed(): {exc}.'
607                    )
608        except asyncio.CancelledError:
609            logging.warning(
610                'RPCEndpoint.wait_closed()'
611                ' got asyncio.CancelledError; not expected.'
612            )
613            raise
614        assert not self._did_wait_closed_writer
615        self._did_wait_closed_writer = True
616
617    def _tm(self) -> str:
618        """Simple readable time value for debugging."""
619        tval = time.monotonic() % 100.0
620        return f'{tval:.2f}'
621
622    async def _run_read_task(self) -> None:
623        """Read from the peer."""
624        self._check_env()
625        assert self._peer_info is None
626
627        # Bug fix: if we don't have this set we will never time out
628        # if we never receive any data from the other end.
629        self._last_keepalive_receive_time = time.monotonic()
630
631        # The first thing they should send us is their handshake; then
632        # we'll know if/how we can talk to them.
633        mlen = await self._read_int_32()
634        message = await self._reader.readexactly(mlen)
635        self._total_bytes_read += mlen
636        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
637        self._last_keepalive_receive_time = time.monotonic()
638        if self.debug_print:
639            self.debug_print_call(
640                f'{self._label}: received handshake at {self._tm()}.'
641            )
642
643        # Now just sit and handle stuff as it comes in.
644        while True:
645            if self._closing:
646                return
647
648            # Read message type.
649            mtype = _PacketType(await self._read_int_8())
650            if mtype is _PacketType.HANDSHAKE:
651                raise RuntimeError('Got multiple handshakes')
652
653            if mtype is _PacketType.KEEPALIVE:
654                if self.debug_print_io:
655                    self.debug_print_call(
656                        f'{self._label}: received keepalive'
657                        f' at {self._tm()}.'
658                    )
659                self._last_keepalive_receive_time = time.monotonic()
660
661            elif mtype is _PacketType.MESSAGE:
662                await self._handle_message_packet(big=False)
663
664            elif mtype is _PacketType.MESSAGE_BIG:
665                await self._handle_message_packet(big=True)
666
667            elif mtype is _PacketType.RESPONSE:
668                await self._handle_response_packet(big=False)
669
670            elif mtype is _PacketType.RESPONSE_BIG:
671                await self._handle_response_packet(big=True)
672
673            else:
674                assert_never(mtype)
675
676    async def _handle_message_packet(self, big: bool) -> None:
677        assert self._peer_info is not None
678        msgid = await self._read_int_16()
679        if big:
680            msglen = await self._read_int_32()
681        else:
682            msglen = await self._read_int_16()
683        msg = await self._reader.readexactly(msglen)
684        self._total_bytes_read += msglen
685        if self.debug_print_io:
686            self.debug_print_call(
687                f'{self._label}: received message {msgid}'
688                f' of size {msglen} at {self._tm()}.'
689            )
690
691        # Create a message-task to handle this message and return
692        # a response (we don't want to block while that happens).
693        assert not self._closing
694        self._prune_tasks()  # Keep from filling with dead tasks.
695        self._tasks.append(
696            asyncio.create_task(
697                self._handle_raw_message(message_id=msgid, message=msg),
698                name='efro rpc message handle',
699            )
700        )
701        if self.debug_print:
702            self.debug_print_call(
703                f'{self._label}: done handling message at {self._tm()}.'
704            )
705
706    async def _handle_response_packet(self, big: bool) -> None:
707        assert self._peer_info is not None
708        msgid = await self._read_int_16()
709        # Protocol 2 gained 32 bit data lengths.
710        if big:
711            rsplen = await self._read_int_32()
712        else:
713            rsplen = await self._read_int_16()
714        if self.debug_print_io:
715            self.debug_print_call(
716                f'{self._label}: received response {msgid}'
717                f' of size {rsplen} at {self._tm()}.'
718            )
719        rsp = await self._reader.readexactly(rsplen)
720        self._total_bytes_read += rsplen
721        msgobj = self._in_flight_messages.get(msgid)
722        if msgobj is None:
723            # It's possible for us to get a response to a message
724            # that has timed out. In this case we will have no local
725            # record of it.
726            if self.debug_print:
727                self.debug_print_call(
728                    f'{self._label}: got response for nonexistent'
729                    f' message id {msgid}; perhaps it timed out?'
730                )
731        else:
732            msgobj.set_response(rsp)
733
734    async def _run_write_task(self) -> None:
735        """Write to the peer."""
736
737        self._check_env()
738
739        # Introduce ourself so our peer knows how it can talk to us.
740        data = dataclass_to_json(
741            _PeerInfo(
742                protocol=OUR_PROTOCOL,
743                keepalive_interval=self._keepalive_interval,
744            )
745        ).encode()
746        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
747
748        # Now just write out-messages as they come in.
749        while True:
750            # Wait until some data comes in.
751            await self._have_out_packets.wait()
752
753            assert self._out_packets
754            data = self._out_packets.popleft()
755
756            # Important: only clear this once all packets are sent.
757            if not self._out_packets:
758                self._have_out_packets.clear()
759
760            self._writer.write(data)
761
762            # This should keep our writer from buffering huge amounts
763            # of outgoing data. We must remember though that we also
764            # need to prevent _out_packets from growing too large and
765            # that part's on us.
766            await self._writer.drain()
767
768            # For now we're not applying backpressure, but let's make
769            # noise if this gets out of hand.
770            if len(self._out_packets) > 200:
771                if not self._did_out_packets_buildup_warning:
772                    logging.warning(
773                        '_out_packets building up too'
774                        ' much on RPCEndpoint %s.',
775                        id(self),
776                    )
777                    self._did_out_packets_buildup_warning = True
778
779    async def _run_keepalive_task(self) -> None:
780        """Send periodic keepalive packets."""
781        self._check_env()
782
783        # We explicitly send our own keepalive packets so we can stay
784        # more on top of the connection state and possibly decide to
785        # kill it when contact is lost more quickly than the OS would
786        # do itself (or at least keep the user informed that the
787        # connection is lagging). It sounds like we could have the TCP
788        # layer do this sort of thing itself but that might be
789        # OS-specific so gonna go this way for now.
790        while True:
791            assert not self._closing
792            await asyncio.sleep(self._keepalive_interval)
793            if not self.test_suppress_keepalives:
794                self._enqueue_outgoing_packet(
795                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
796                )
797
798            # Also go ahead and handle dropping the connection if we
799            # haven't heard from the peer in a while.
800            # NOTE: perhaps we want to do something more exact than
801            # this which only checks once per keepalive-interval?..
802            now = time.monotonic()
803            if (
804                self._last_keepalive_receive_time is not None
805                and now - self._last_keepalive_receive_time
806                > self._keepalive_timeout
807            ):
808                if self.debug_print:
809                    since = now - self._last_keepalive_receive_time
810                    self.debug_print_call(
811                        f'{self._label}: reached keepalive time-out'
812                        f' ({since:.1f}s).'
813                    )
814                raise _KeepaliveTimeoutError()
815
816    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
817        try:
818            await call
819        except Exception as exc:
820            # We expect connection errors to put us here, but make noise
821            # if something else does.
822            if not self._is_expected_connection_error(exc):
823                logging.exception(
824                    'Unexpected error in rpc %s %s task'
825                    ' (age=%.1f, total_bytes_read=%d).',
826                    self._label,
827                    tasklabel,
828                    time.monotonic() - self._create_time,
829                    self._total_bytes_read,
830                )
831            else:
832                if self.debug_print:
833                    self.debug_print_call(
834                        f'{self._label}: {tasklabel} task will exit cleanly'
835                        f' due to {exc!r}.'
836                    )
837        finally:
838            # Any core task exiting triggers shutdown.
839            if self.debug_print:
840                self.debug_print_call(
841                    f'{self._label}: {tasklabel} task exiting...'
842                )
843            self.close()
844
845    async def _handle_raw_message(
846        self, message_id: int, message: bytes
847    ) -> None:
848        try:
849            response = await self._handle_raw_message_call(message)
850        except Exception:
851            # We expect local message handler to always succeed.
852            # If that doesn't happen, make a fuss so we know to fix it.
853            # The other end will simply never get a response to this
854            # message.
855            logging.exception('Error handling raw rpc message')
856            return
857
858        assert self._peer_info is not None
859
860        if self._peer_info.protocol == 1:
861            if len(response) > 65535:
862                raise RuntimeError('Response cannot be larger than 65535 bytes')
863
864        # Now send back our response.
865        # Payload consists of type (1b), msgid (2b), len (2b), and data.
866        if len(response) > 65535:
867            self._enqueue_outgoing_packet(
868                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
869                + message_id.to_bytes(2, _BYTE_ORDER)
870                + len(response).to_bytes(4, _BYTE_ORDER)
871                + response
872            )
873        else:
874            self._enqueue_outgoing_packet(
875                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
876                + message_id.to_bytes(2, _BYTE_ORDER)
877                + len(response).to_bytes(2, _BYTE_ORDER)
878                + response
879            )
880
881    async def _read_int_8(self) -> int:
882        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
883        self._total_bytes_read += 1
884        return out
885
886    async def _read_int_16(self) -> int:
887        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
888        self._total_bytes_read += 2
889        return out
890
891    async def _read_int_32(self) -> int:
892        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
893        self._total_bytes_read += 4
894        return out
895
896    @classmethod
897    def _is_expected_connection_error(cls, exc: Exception) -> bool:
898        """Stuff we expect to end our connection in normal circumstances."""
899
900        if isinstance(exc, _KeepaliveTimeoutError):
901            return True
902
903        return is_asyncio_streams_communication_error(exc)
904
905    def _check_env(self) -> None:
906        # I was seeing that asyncio stuff wasn't working as expected if
907        # created in one thread and used in another (and have verified
908        # that this is part of the design), so let's enforce a single
909        # thread for all use of an instance.
910        if current_thread() is not self._thread:
911            raise RuntimeError(
912                'This must be called from the same thread'
913                ' that the endpoint was created in.'
914            )
915
916        # This should always be the case if thread is the same.
917        assert asyncio.get_running_loop() is self._event_loop
918
919    def _enqueue_outgoing_packet(self, data: bytes) -> None:
920        """Enqueue a raw packet to be sent. Must be called from our loop."""
921        self._check_env()
922
923        if self.debug_print_io:
924            self.debug_print_call(
925                f'{self._label}: enqueueing outgoing packet'
926                f' {data[:50]!r} at {self._tm()}.'
927            )
928
929        # Add the data and let our write task know about it.
930        self._out_packets.append(data)
931        self._have_out_packets.set()
932
933    def _prune_tasks(self) -> None:
934        self._tasks = self._get_live_tasks()
935
936    def _get_live_tasks(self) -> list[asyncio.Task]:
937        return [t for t in self._tasks if not t.done()]
OUR_PROTOCOL = 2
def ssl_stream_writer_underlying_transport_info(writer: asyncio.streams.StreamWriter) -> str:
def ssl_stream_writer_underlying_transport_info(
    writer: asyncio.StreamWriter,
) -> str:
    """For debugging SSL Stream connections; returns raw transport info.

    Follows the private attribute chain
    ``writer._transport._ssl_protocol._transport`` and returns ``str()``
    of the final object, or ``'(not found)'`` if any link is missing or
    None.
    """
    # Note: these are internals, so only descriptive text is returned
    # (never the objects themselves) to limit the potential for breakage.
    current: object = writer
    for attr in ('_transport', '_ssl_protocol', '_transport'):
        current = getattr(current, attr, None)
        if current is None:
            return '(not found)'
    return str(current)

For debugging SSL Stream connections; returns raw transport info.

def ssl_stream_writer_force_close_check(writer: asyncio.streams.StreamWriter) -> None:
 87def ssl_stream_writer_force_close_check(writer: asyncio.StreamWriter) -> None:
 88    """Ensure a writer is closed; hacky workaround for odd hang."""
 89    from threading import Thread
 90
 91    # Disabling for now..
 92    if bool(True):
 93        return
 94
 95    # Hopefully can remove this in Python 3.11?...
 96    # see issue with is_closing() below for more details.
 97    transport = getattr(writer, '_transport', None)
 98    if transport is not None:
 99        sslproto = getattr(transport, '_ssl_protocol', None)
100        if sslproto is not None:
101            raw_transport = getattr(sslproto, '_transport', None)
102            if raw_transport is not None:
103                Thread(
104                    target=partial(
105                        _do_writer_force_close_check, weakref.ref(raw_transport)
106                    ),
107                    daemon=True,
108                ).start()

Ensure a writer is closed; hacky workaround for odd hang.

class RPCEndpoint:
154class RPCEndpoint:
155    """Facilitates asynchronous multiplexed remote procedure calls.
156
157    Be aware that, while multiple calls can be in flight in either direction
158    simultaneously, packets are still sent serially in a single
159    stream. So excessively long messages/responses will delay all other
160    communication. If/when this becomes an issue we can look into breaking up
161    long messages into multiple packets.
162    """
163
164    # Set to True on an instance to test keepalive failures.
165    test_suppress_keepalives: bool = False
166
167    # How long we should wait before giving up on a message by default.
168    # Note this includes processing time on the other end.
169    DEFAULT_MESSAGE_TIMEOUT = 60.0
170
171    # How often we send out keepalive packets by default.
172    DEFAULT_KEEPALIVE_INTERVAL = 10.73  # (avoid too regular of values)
173
174    # How long we can go without receiving a keepalive packet before we
175    # disconnect.
176    DEFAULT_KEEPALIVE_TIMEOUT = 30.0
177
178    def __init__(
179        self,
180        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
181        reader: asyncio.StreamReader,
182        writer: asyncio.StreamWriter,
183        label: str,
184        debug_print: bool = False,
185        debug_print_io: bool = False,
186        debug_print_call: Callable[[str], None] | None = None,
187        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
188        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
189    ) -> None:
190        self._handle_raw_message_call = handle_raw_message_call
191        self._reader = reader
192        self._writer = writer
193        self.debug_print = debug_print
194        self.debug_print_io = debug_print_io
195        if debug_print_call is None:
196            debug_print_call = print
197        self.debug_print_call: Callable[[str], None] = debug_print_call
198        self._label = label
199        self._thread = current_thread()
200        self._closing = False
201        self._did_wait_closed = False
202        self._event_loop = asyncio.get_running_loop()
203        self._out_packets = deque[bytes]()
204        self._have_out_packets = asyncio.Event()
205        self._run_called = False
206        self._peer_info: _PeerInfo | None = None
207        self._keepalive_interval = keepalive_interval
208        self._keepalive_timeout = keepalive_timeout
209        self._did_close_writer = False
210        self._did_wait_closed_writer = False
211        self._did_out_packets_buildup_warning = False
212        self._total_bytes_read = 0
213        self._create_time = time.monotonic()
214
215        # Need to hold weak-refs to these otherwise it creates dep-loops
216        # which keeps us alive.
217        self._tasks: list[asyncio.Task] = []
218
219        # When we last got a keepalive or equivalent (time.monotonic value)
220        self._last_keepalive_receive_time: float | None = None
221
222        # (Start near the end to make sure our looping logic is sound).
223        self._next_message_id = 65530
224
225        self._in_flight_messages: dict[int, _InFlightMessage] = {}
226
227        if self.debug_print:
228            peername = self._writer.get_extra_info('peername')
229            self.debug_print_call(
230                f'{self._label}: connected to {peername} at {self._tm()}.'
231            )
232
233    def __del__(self) -> None:
234        if self._run_called:
235            if not self._did_close_writer:
236                logging.warning(
237                    'RPCEndpoint %d dying with run'
238                    ' called but writer not closed (transport=%s).',
239                    id(self),
240                    ssl_stream_writer_underlying_transport_info(self._writer),
241                )
242            elif not self._did_wait_closed_writer:
243                logging.warning(
244                    'RPCEndpoint %d dying with run called'
245                    ' but writer not wait-closed (transport=%s).',
246                    id(self),
247                    ssl_stream_writer_underlying_transport_info(self._writer),
248                )
249
250        # Currently seeing rare issue where sockets don't go down;
251        # let's add a timer to force the issue until we can figure it out.
252        ssl_stream_writer_force_close_check(self._writer)
253
254    async def run(self) -> None:
255        """Run the endpoint until the connection is lost or closed.
256
257        Handles closing the provided reader/writer on close.
258        """
259        try:
260            await self._do_run()
261        except asyncio.CancelledError:
262            # We aren't really designed to be cancelled so let's warn
263            # if it happens.
264            logging.warning(
265                'RPCEndpoint.run got CancelledError;'
266                ' want to try and avoid this.'
267            )
268            raise
269
270    async def _do_run(self) -> None:
271        self._check_env()
272
273        if self._run_called:
274            raise RuntimeError('Run can be called only once per endpoint.')
275        self._run_called = True
276
277        core_tasks = [
278            asyncio.create_task(
279                self._run_core_task('keepalive', self._run_keepalive_task()),
280                name='rpc keepalive',
281            ),
282            asyncio.create_task(
283                self._run_core_task('read', self._run_read_task()),
284                name='rpc read',
285            ),
286            asyncio.create_task(
287                self._run_core_task('write', self._run_write_task()),
288                name='rpc write',
289            ),
290        ]
291        self._tasks += core_tasks
292
293        # Run our core tasks until they all complete.
294        results = await asyncio.gather(*core_tasks, return_exceptions=True)
295
296        # Core tasks should handle their own errors; the only ones
297        # we expect to bubble up are CancelledError.
298        for result in results:
299            # We want to know if any errors happened aside from CancelledError
300            # (which are BaseExceptions, not Exception).
301            if isinstance(result, Exception):
302                logging.warning(
303                    'Got unexpected error from %s core task: %s',
304                    self._label,
305                    result,
306                )
307
308        if not all(task.done() for task in core_tasks):
309            logging.warning(
310                'RPCEndpoint %d: not all core tasks marked done after gather.',
311                id(self),
312            )
313
314        # Shut ourself down.
315        try:
316            self.close()
317            await self.wait_closed()
318        except Exception:
319            logging.exception('Error closing %s.', self._label)
320
321        if self.debug_print:
322            self.debug_print_call(f'{self._label}: finished.')
323
324    def send_message(
325        self,
326        message: bytes,
327        timeout: float | None = None,
328        close_on_error: bool = True,
329    ) -> Awaitable[bytes]:
330        """Send a message to the peer and return a response.
331
332        If timeout is not provided, the default will be used.
333        Raises a CommunicationError if the round trip is not completed
334        for any reason.
335
336        By default, the entire endpoint will go down in the case of
337        errors. This allows messages to be treated as 'reliable' with
338        respect to a given endpoint. Pass close_on_error=False to
339        override this for a particular message.
340        """
341        # Note: This call is synchronous so that the first part of it
342        # (enqueueing outgoing messages) happens synchronously. If it were
343        # a pure async call it could be possible for send order to vary
344        # based on how the async tasks get processed.
345
346        if self.debug_print_io:
347            self.debug_print_call(
348                f'{self._label}: sending message of size {len(message)}'
349                f' at {self._tm()}.'
350            )
351
352        self._check_env()
353
354        if self._closing:
355            raise CommunicationError('Endpoint is closed.')
356
357        if self.debug_print_io:
358            self.debug_print_call(
359                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
360            )
361
362        # message_id is a 16 bit looping value.
363        message_id = self._next_message_id
364        self._next_message_id = (self._next_message_id + 1) % 65536
365
366        if self.debug_print_io:
367            self.debug_print_call(
368                f'{self._label}: will enqueue at {self._tm()}.'
369            )
370
371        # FIXME - should handle backpressure (waiting here if there are
372        # enough packets already enqueued).
373
374        if len(message) > 65535:
375            # Payload consists of type (1b), message_id (2b),
376            # len (4b), and data.
377            self._enqueue_outgoing_packet(
378                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
379                + message_id.to_bytes(2, _BYTE_ORDER)
380                + len(message).to_bytes(4, _BYTE_ORDER)
381                + message
382            )
383        else:
384            # Payload consists of type (1b), message_id (2b),
385            # len (2b), and data.
386            self._enqueue_outgoing_packet(
387                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
388                + message_id.to_bytes(2, _BYTE_ORDER)
389                + len(message).to_bytes(2, _BYTE_ORDER)
390                + message
391            )
392
393        if self.debug_print_io:
394            self.debug_print_call(
395                f'{self._label}: enqueued message of size {len(message)}'
396                f' at {self._tm()}.'
397            )
398
399        # Make an entry so we know this message is out there.
400        assert message_id not in self._in_flight_messages
401        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
402
403        # Also add its task to our list so we properly cancel it if we die.
404        self._prune_tasks()  # Keep our list from filling with dead tasks.
405        self._tasks.append(msgobj.wait_task)
406
407        # Note: we always want to incorporate a timeout. Individual
408        # messages may hang or error on the other end and this ensures
409        # we won't build up lots of zombie tasks waiting around for
410        # responses that will never arrive.
411        if timeout is None:
412            timeout = self.DEFAULT_MESSAGE_TIMEOUT
413        assert timeout is not None
414
415        bytes_awaitable = msgobj.wait_task
416
417        # Now complete the send asynchronously.
418        return self._send_message(
419            message, timeout, close_on_error, bytes_awaitable, message_id
420        )
421
422    async def _send_message(
423        self,
424        message: bytes,
425        timeout: float | None,
426        close_on_error: bool,
427        bytes_awaitable: asyncio.Task[bytes],
428        message_id: int,
429    ) -> bytes:
430        # We need to know their protocol, so if we haven't gotten a handshake
431        # from them yet, just wait.
432        while self._peer_info is None:
433            await asyncio.sleep(0.01)
434        assert self._peer_info is not None
435
436        if self._peer_info.protocol == 1:
437            if len(message) > 65535:
438                raise RuntimeError('Message cannot be larger than 65535 bytes')
439
440        try:
441            return await asyncio.wait_for(bytes_awaitable, timeout=timeout)
442        except asyncio.CancelledError as exc:
443            # Question: we assume this means the above wait_for() was
444            # cancelled; how do we distinguish between this and *us* being
445            # cancelled though?
446            if self.debug_print:
447                self.debug_print_call(
448                    f'{self._label}: message {message_id} was cancelled.'
449                )
450            if close_on_error:
451                self.close()
452
453            raise CommunicationError() from exc
454        except Exception as exc:
455            # If our timer timed-out or anything else went wrong with
456            # the stream, lump it in as a communication error.
457            if isinstance(
458                exc, asyncio.TimeoutError
459            ) or is_asyncio_streams_communication_error(exc):
460                if self.debug_print:
461                    self.debug_print_call(
462                        f'{self._label}: got {type(exc)} sending message'
463                        f' {message_id}; raising CommunicationError.'
464                    )
465
466                # Stop waiting on the response.
467                bytes_awaitable.cancel()
468
469                # Remove the record of this message.
470                del self._in_flight_messages[message_id]
471
472                if close_on_error:
473                    self.close()
474
475                # Let the user know something went wrong.
476                raise CommunicationError() from exc
477
478            # Some unexpected error; let it bubble up.
479            raise
480
481    def close(self) -> None:
482        """I said seagulls; mmmm; stop it now."""
483        self._check_env()
484
485        if self._closing:
486            return
487
488        if self.debug_print:
489            self.debug_print_call(f'{self._label}: closing...')
490
491        self._closing = True
492
493        # Kill all of our in-flight tasks.
494        if self.debug_print:
495            self.debug_print_call(f'{self._label}: cancelling tasks...')
496        for task in self._get_live_tasks():
497            task.cancel()
498
499        # Close our writer.
500        assert not self._did_close_writer
501        if self.debug_print:
502            self.debug_print_call(f'{self._label}: closing writer...')
503        self._writer.close()
504        self._did_close_writer = True
505
506        # We don't need this anymore and it is likely to be creating a
507        # dependency loop.
508        del self._handle_raw_message_call
509
510    def is_closing(self) -> bool:
511        """Have we begun the process of closing?"""
512        return self._closing
513
514    async def wait_closed(self) -> None:
515        """I said seagulls; mmmm; stop it now.
516
517        Wait for the endpoint to finish closing. This is called by run()
518        so generally does not need to be explicitly called.
519        """
520        # pylint: disable=too-many-branches
521        self._check_env()
522
523        # Make sure we only *enter* this call once.
524        if self._did_wait_closed:
525            return
526        self._did_wait_closed = True
527
528        if not self._closing:
529            raise RuntimeError('Must be called after close()')
530
531        if not self._did_close_writer:
532            logging.warning(
533                'RPCEndpoint wait_closed() called but never'
534                ' explicitly closed writer.'
535            )
536
537        live_tasks = self._get_live_tasks()
538
539        # Don't need our task list anymore; this should
540        # break any cyclical refs from tasks referring to us.
541        self._tasks = []
542
543        if self.debug_print:
544            self.debug_print_call(
545                f'{self._label}: waiting for tasks to finish: '
546                f' ({live_tasks=})...'
547            )
548
549        # Wait for all of our in-flight tasks to wrap up.
550        results = await asyncio.gather(*live_tasks, return_exceptions=True)
551        for result in results:
552            # We want to know if any errors happened aside from CancelledError
553            # (which are BaseExceptions, not Exception).
554            if isinstance(result, Exception):
555                logging.warning(
556                    'Got unexpected error cleaning up %s task: %s',
557                    self._label,
558                    result,
559                )
560
561        if not all(task.done() for task in live_tasks):
562            logging.warning(
563                'RPCEndpoint %d: not all live tasks marked done after gather.',
564                id(self),
565            )
566
567        if self.debug_print:
568            self.debug_print_call(
569                f'{self._label}: tasks finished; waiting for writer close...'
570            )
571
572        # Now wait for our writer to finish going down.
573        # When we close our writer it generally triggers errors
574        # in our current blocked read/writes. However that same
575        # error is also sometimes returned from _writer.wait_closed().
576        # See connection_lost() in asyncio/streams.py to see why.
577        # So let's silently ignore it when that happens.
578        assert self._writer.is_closing()
579        try:
580            # It seems that as of Python 3.9.x it is possible for this to hang
581            # indefinitely. See https://github.com/python/cpython/issues/83939
582            # It sounds like this should be fixed in 3.11 but for now just
583            # forcing the issue with a timeout here.
584            await asyncio.wait_for(
585                self._writer.wait_closed(),
586                # timeout=60.0 * 6.0,
587                timeout=30.0,
588            )
589        except asyncio.TimeoutError:
590            logging.info(
591                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
592                self._label,
593                ssl_stream_writer_underlying_transport_info(self._writer),
594            )
595            if self.debug_print:
596                self.debug_print_call(
597                    f'{self._label}: got timeout in _writer.wait_closed();'
598                    ' This should be fixed in future Python versions.'
599                )
600        except Exception as exc:
601            if not self._is_expected_connection_error(exc):
602                logging.exception('Error closing _writer for %s.', self._label)
603            else:
604                if self.debug_print:
605                    self.debug_print_call(
606                        f'{self._label}: silently ignoring error in'
607                        f' _writer.wait_closed(): {exc}.'
608                    )
609        except asyncio.CancelledError:
610            logging.warning(
611                'RPCEndpoint.wait_closed()'
612                ' got asyncio.CancelledError; not expected.'
613            )
614            raise
615        assert not self._did_wait_closed_writer
616        self._did_wait_closed_writer = True
617
618    def _tm(self) -> str:
619        """Simple readable time value for debugging."""
620        tval = time.monotonic() % 100.0
621        return f'{tval:.2f}'
622
623    async def _run_read_task(self) -> None:
624        """Read from the peer."""
625        self._check_env()
626        assert self._peer_info is None
627
628        # Bug fix: if we don't have this set we will never time out
629        # if we never receive any data from the other end.
630        self._last_keepalive_receive_time = time.monotonic()
631
632        # The first thing they should send us is their handshake; then
633        # we'll know if/how we can talk to them.
634        mlen = await self._read_int_32()
635        message = await self._reader.readexactly(mlen)
636        self._total_bytes_read += mlen
637        self._peer_info = dataclass_from_json(_PeerInfo, message.decode())
638        self._last_keepalive_receive_time = time.monotonic()
639        if self.debug_print:
640            self.debug_print_call(
641                f'{self._label}: received handshake at {self._tm()}.'
642            )
643
644        # Now just sit and handle stuff as it comes in.
645        while True:
646            if self._closing:
647                return
648
649            # Read message type.
650            mtype = _PacketType(await self._read_int_8())
651            if mtype is _PacketType.HANDSHAKE:
652                raise RuntimeError('Got multiple handshakes')
653
654            if mtype is _PacketType.KEEPALIVE:
655                if self.debug_print_io:
656                    self.debug_print_call(
657                        f'{self._label}: received keepalive'
658                        f' at {self._tm()}.'
659                    )
660                self._last_keepalive_receive_time = time.monotonic()
661
662            elif mtype is _PacketType.MESSAGE:
663                await self._handle_message_packet(big=False)
664
665            elif mtype is _PacketType.MESSAGE_BIG:
666                await self._handle_message_packet(big=True)
667
668            elif mtype is _PacketType.RESPONSE:
669                await self._handle_response_packet(big=False)
670
671            elif mtype is _PacketType.RESPONSE_BIG:
672                await self._handle_response_packet(big=True)
673
674            else:
675                assert_never(mtype)
676
677    async def _handle_message_packet(self, big: bool) -> None:
678        assert self._peer_info is not None
679        msgid = await self._read_int_16()
680        if big:
681            msglen = await self._read_int_32()
682        else:
683            msglen = await self._read_int_16()
684        msg = await self._reader.readexactly(msglen)
685        self._total_bytes_read += msglen
686        if self.debug_print_io:
687            self.debug_print_call(
688                f'{self._label}: received message {msgid}'
689                f' of size {msglen} at {self._tm()}.'
690            )
691
692        # Create a message-task to handle this message and return
693        # a response (we don't want to block while that happens).
694        assert not self._closing
695        self._prune_tasks()  # Keep from filling with dead tasks.
696        self._tasks.append(
697            asyncio.create_task(
698                self._handle_raw_message(message_id=msgid, message=msg),
699                name='efro rpc message handle',
700            )
701        )
702        if self.debug_print:
703            self.debug_print_call(
704                f'{self._label}: done handling message at {self._tm()}.'
705            )
706
707    async def _handle_response_packet(self, big: bool) -> None:
708        assert self._peer_info is not None
709        msgid = await self._read_int_16()
710        # Protocol 2 gained 32 bit data lengths.
711        if big:
712            rsplen = await self._read_int_32()
713        else:
714            rsplen = await self._read_int_16()
715        if self.debug_print_io:
716            self.debug_print_call(
717                f'{self._label}: received response {msgid}'
718                f' of size {rsplen} at {self._tm()}.'
719            )
720        rsp = await self._reader.readexactly(rsplen)
721        self._total_bytes_read += rsplen
722        msgobj = self._in_flight_messages.get(msgid)
723        if msgobj is None:
724            # It's possible for us to get a response to a message
725            # that has timed out. In this case we will have no local
726            # record of it.
727            if self.debug_print:
728                self.debug_print_call(
729                    f'{self._label}: got response for nonexistent'
730                    f' message id {msgid}; perhaps it timed out?'
731                )
732        else:
733            msgobj.set_response(rsp)
734
735    async def _run_write_task(self) -> None:
736        """Write to the peer."""
737
738        self._check_env()
739
740        # Introduce ourself so our peer knows how it can talk to us.
741        data = dataclass_to_json(
742            _PeerInfo(
743                protocol=OUR_PROTOCOL,
744                keepalive_interval=self._keepalive_interval,
745            )
746        ).encode()
747        self._writer.write(len(data).to_bytes(4, _BYTE_ORDER) + data)
748
749        # Now just write out-messages as they come in.
750        while True:
751            # Wait until some data comes in.
752            await self._have_out_packets.wait()
753
754            assert self._out_packets
755            data = self._out_packets.popleft()
756
757            # Important: only clear this once all packets are sent.
758            if not self._out_packets:
759                self._have_out_packets.clear()
760
761            self._writer.write(data)
762
763            # This should keep our writer from buffering huge amounts
764            # of outgoing data. We must remember though that we also
765            # need to prevent _out_packets from growing too large and
766            # that part's on us.
767            await self._writer.drain()
768
769            # For now we're not applying backpressure, but let's make
770            # noise if this gets out of hand.
771            if len(self._out_packets) > 200:
772                if not self._did_out_packets_buildup_warning:
773                    logging.warning(
774                        '_out_packets building up too'
775                        ' much on RPCEndpoint %s.',
776                        id(self),
777                    )
778                    self._did_out_packets_buildup_warning = True
779
780    async def _run_keepalive_task(self) -> None:
781        """Send periodic keepalive packets."""
782        self._check_env()
783
784        # We explicitly send our own keepalive packets so we can stay
785        # more on top of the connection state and possibly decide to
786        # kill it when contact is lost more quickly than the OS would
787        # do itself (or at least keep the user informed that the
788        # connection is lagging). It sounds like we could have the TCP
789        # layer do this sort of thing itself but that might be
790        # OS-specific so gonna go this way for now.
791        while True:
792            assert not self._closing
793            await asyncio.sleep(self._keepalive_interval)
794            if not self.test_suppress_keepalives:
795                self._enqueue_outgoing_packet(
796                    _PacketType.KEEPALIVE.value.to_bytes(1, _BYTE_ORDER)
797                )
798
799            # Also go ahead and handle dropping the connection if we
800            # haven't heard from the peer in a while.
801            # NOTE: perhaps we want to do something more exact than
802            # this which only checks once per keepalive-interval?..
803            now = time.monotonic()
804            if (
805                self._last_keepalive_receive_time is not None
806                and now - self._last_keepalive_receive_time
807                > self._keepalive_timeout
808            ):
809                if self.debug_print:
810                    since = now - self._last_keepalive_receive_time
811                    self.debug_print_call(
812                        f'{self._label}: reached keepalive time-out'
813                        f' ({since:.1f}s).'
814                    )
815                raise _KeepaliveTimeoutError()
816
817    async def _run_core_task(self, tasklabel: str, call: Awaitable) -> None:
818        try:
819            await call
820        except Exception as exc:
821            # We expect connection errors to put us here, but make noise
822            # if something else does.
823            if not self._is_expected_connection_error(exc):
824                logging.exception(
825                    'Unexpected error in rpc %s %s task'
826                    ' (age=%.1f, total_bytes_read=%d).',
827                    self._label,
828                    tasklabel,
829                    time.monotonic() - self._create_time,
830                    self._total_bytes_read,
831                )
832            else:
833                if self.debug_print:
834                    self.debug_print_call(
835                        f'{self._label}: {tasklabel} task will exit cleanly'
836                        f' due to {exc!r}.'
837                    )
838        finally:
839            # Any core task exiting triggers shutdown.
840            if self.debug_print:
841                self.debug_print_call(
842                    f'{self._label}: {tasklabel} task exiting...'
843                )
844            self.close()
845
846    async def _handle_raw_message(
847        self, message_id: int, message: bytes
848    ) -> None:
849        try:
850            response = await self._handle_raw_message_call(message)
851        except Exception:
852            # We expect local message handler to always succeed.
853            # If that doesn't happen, make a fuss so we know to fix it.
854            # The other end will simply never get a response to this
855            # message.
856            logging.exception('Error handling raw rpc message')
857            return
858
859        assert self._peer_info is not None
860
861        if self._peer_info.protocol == 1:
862            if len(response) > 65535:
863                raise RuntimeError('Response cannot be larger than 65535 bytes')
864
865        # Now send back our response.
866        # Payload consists of type (1b), msgid (2b), len (2b), and data.
867        if len(response) > 65535:
868            self._enqueue_outgoing_packet(
869                _PacketType.RESPONSE_BIG.value.to_bytes(1, _BYTE_ORDER)
870                + message_id.to_bytes(2, _BYTE_ORDER)
871                + len(response).to_bytes(4, _BYTE_ORDER)
872                + response
873            )
874        else:
875            self._enqueue_outgoing_packet(
876                _PacketType.RESPONSE.value.to_bytes(1, _BYTE_ORDER)
877                + message_id.to_bytes(2, _BYTE_ORDER)
878                + len(response).to_bytes(2, _BYTE_ORDER)
879                + response
880            )
881
882    async def _read_int_8(self) -> int:
883        out = int.from_bytes(await self._reader.readexactly(1), _BYTE_ORDER)
884        self._total_bytes_read += 1
885        return out
886
887    async def _read_int_16(self) -> int:
888        out = int.from_bytes(await self._reader.readexactly(2), _BYTE_ORDER)
889        self._total_bytes_read += 2
890        return out
891
892    async def _read_int_32(self) -> int:
893        out = int.from_bytes(await self._reader.readexactly(4), _BYTE_ORDER)
894        self._total_bytes_read += 4
895        return out
896
897    @classmethod
898    def _is_expected_connection_error(cls, exc: Exception) -> bool:
899        """Stuff we expect to end our connection in normal circumstances."""
900
901        if isinstance(exc, _KeepaliveTimeoutError):
902            return True
903
904        return is_asyncio_streams_communication_error(exc)
905
906    def _check_env(self) -> None:
907        # I was seeing that asyncio stuff wasn't working as expected if
908        # created in one thread and used in another (and have verified
909        # that this is part of the design), so let's enforce a single
910        # thread for all use of an instance.
911        if current_thread() is not self._thread:
912            raise RuntimeError(
913                'This must be called from the same thread'
914                ' that the endpoint was created in.'
915            )
916
917        # This should always be the case if thread is the same.
918        assert asyncio.get_running_loop() is self._event_loop
919
920    def _enqueue_outgoing_packet(self, data: bytes) -> None:
921        """Enqueue a raw packet to be sent. Must be called from our loop."""
922        self._check_env()
923
924        if self.debug_print_io:
925            self.debug_print_call(
926                f'{self._label}: enqueueing outgoing packet'
927                f' {data[:50]!r} at {self._tm()}.'
928            )
929
930        # Add the data and let our write task know about it.
931        self._out_packets.append(data)
932        self._have_out_packets.set()
933
934    def _prune_tasks(self) -> None:
935        self._tasks = self._get_live_tasks()
936
937    def _get_live_tasks(self) -> list[asyncio.Task]:
938        return [t for t in self._tasks if not t.done()]

Facilitates asynchronous multiplexed remote procedure calls.

Be aware that, while multiple calls can be in flight in either direction simultaneously, packets are still sent serially in a single stream. So excessively long messages/responses will delay all other communication. If/when this becomes an issue we can look into breaking up long messages into multiple packets.

RPCEndpoint( handle_raw_message_call: Callable[[bytes], Awaitable[bytes]], reader: asyncio.streams.StreamReader, writer: asyncio.streams.StreamWriter, label: str, debug_print: bool = False, debug_print_io: bool = False, debug_print_call: Optional[Callable[[str], NoneType]] = None, keepalive_interval: float = 10.73, keepalive_timeout: float = 30.0)
    def __init__(
        self,
        handle_raw_message_call: Callable[[bytes], Awaitable[bytes]],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        label: str,
        debug_print: bool = False,
        debug_print_io: bool = False,
        debug_print_call: Callable[[str], None] | None = None,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
        keepalive_timeout: float = DEFAULT_KEEPALIVE_TIMEOUT,
    ) -> None:
        """Set up an endpoint around an existing reader/writer pair.

        Must be called while an asyncio event loop is running; the
        running loop and the current thread are recorded and all later
        use of the endpoint is checked against them.

        Args:
            handle_raw_message_call: Async callable given the raw bytes
                of each incoming message; its return value is sent back
                to the peer as the response.
            reader: Stream that incoming data is read from.
            writer: Stream that outgoing data is written to.
            label: Name used to identify this endpoint in debug output
                and log messages.
            debug_print: Whether to emit general debug messages.
            debug_print_io: Whether to emit per-packet io debug messages.
            debug_print_call: Callable receiving debug messages; print
                is used if none is provided.
            keepalive_interval: Seconds between the keepalive packets we
                send (also advertised to the peer in our handshake).
            keepalive_timeout: Seconds of silence from the peer after
                which the keepalive task gives up on the connection.
        """
        self._handle_raw_message_call = handle_raw_message_call
        self._reader = reader
        self._writer = writer
        self.debug_print = debug_print
        self.debug_print_io = debug_print_io
        if debug_print_call is None:
            debug_print_call = print
        self.debug_print_call: Callable[[str], None] = debug_print_call
        self._label = label

        # Remembered so we can enforce single-thread/single-loop usage.
        self._thread = current_thread()
        self._closing = False
        self._did_wait_closed = False
        self._event_loop = asyncio.get_running_loop()

        # Raw packets waiting to be written, plus the event used to wake
        # the write task when there are some.
        self._out_packets = deque[bytes]()
        self._have_out_packets = asyncio.Event()
        self._run_called = False

        # Stays None until we have info about our peer.
        self._peer_info: _PeerInfo | None = None
        self._keepalive_interval = keepalive_interval
        self._keepalive_timeout = keepalive_timeout
        self._did_close_writer = False
        self._did_wait_closed_writer = False
        self._did_out_packets_buildup_warning = False

        # Simple stats included in error logs.
        self._total_bytes_read = 0
        self._create_time = time.monotonic()

        # Tasks associated with this endpoint.
        # NOTE(review): an earlier comment here spoke of holding
        # weak-refs to these to avoid dependency loops keeping us alive,
        # but this is a plain list of strong references (which
        # wait_closed() empties) - confirm which is intended.
        self._tasks: list[asyncio.Task] = []

        # When we last got a keepalive or equivalent (time.monotonic value)
        self._last_keepalive_receive_time: float | None = None

        # (Start near the end to make sure our looping logic is sound).
        self._next_message_id = 65530

        # Messages we have sent that are awaiting responses, by id.
        self._in_flight_messages: dict[int, _InFlightMessage] = {}

        if self.debug_print:
            peername = self._writer.get_extra_info('peername')
            self.debug_print_call(
                f'{self._label}: connected to {peername} at {self._tm()}.'
            )
test_suppress_keepalives: bool = False
DEFAULT_MESSAGE_TIMEOUT = 60.0
DEFAULT_KEEPALIVE_INTERVAL = 10.73
DEFAULT_KEEPALIVE_TIMEOUT = 30.0
debug_print
debug_print_io
debug_print_call: Callable[[str], NoneType]
async def run(self) -> None:
254    async def run(self) -> None:
255        """Run the endpoint until the connection is lost or closed.
256
257        Handles closing the provided reader/writer on close.
258        """
259        try:
260            await self._do_run()
261        except asyncio.CancelledError:
262            # We aren't really designed to be cancelled so let's warn
263            # if it happens.
264            logging.warning(
265                'RPCEndpoint.run got CancelledError;'
266                ' want to try and avoid this.'
267            )
268            raise

Run the endpoint until the connection is lost or closed.

Handles closing the provided reader/writer on close.

def send_message( self, message: bytes, timeout: float | None = None, close_on_error: bool = True) -> Awaitable[bytes]:
324    def send_message(
325        self,
326        message: bytes,
327        timeout: float | None = None,
328        close_on_error: bool = True,
329    ) -> Awaitable[bytes]:
330        """Send a message to the peer and return a response.
331
332        If timeout is not provided, the default will be used.
333        Raises a CommunicationError if the round trip is not completed
334        for any reason.
335
336        By default, the entire endpoint will go down in the case of
337        errors. This allows messages to be treated as 'reliable' with
338        respect to a given endpoint. Pass close_on_error=False to
339        override this for a particular message.
340        """
341        # Note: This call is synchronous so that the first part of it
342        # (enqueueing outgoing messages) happens synchronously. If it were
343        # a pure async call it could be possible for send order to vary
344        # based on how the async tasks get processed.
345
346        if self.debug_print_io:
347            self.debug_print_call(
348                f'{self._label}: sending message of size {len(message)}'
349                f' at {self._tm()}.'
350            )
351
352        self._check_env()
353
354        if self._closing:
355            raise CommunicationError('Endpoint is closed.')
356
357        if self.debug_print_io:
358            self.debug_print_call(
359                f'{self._label}: have peerinfo? {self._peer_info is not None}.'
360            )
361
362        # message_id is a 16 bit looping value.
363        message_id = self._next_message_id
364        self._next_message_id = (self._next_message_id + 1) % 65536
365
366        if self.debug_print_io:
367            self.debug_print_call(
368                f'{self._label}: will enqueue at {self._tm()}.'
369            )
370
371        # FIXME - should handle backpressure (waiting here if there are
372        # enough packets already enqueued).
373
374        if len(message) > 65535:
375            # Payload consists of type (1b), message_id (2b),
376            # len (4b), and data.
377            self._enqueue_outgoing_packet(
378                _PacketType.MESSAGE_BIG.value.to_bytes(1, _BYTE_ORDER)
379                + message_id.to_bytes(2, _BYTE_ORDER)
380                + len(message).to_bytes(4, _BYTE_ORDER)
381                + message
382            )
383        else:
384            # Payload consists of type (1b), message_id (2b),
385            # len (2b), and data.
386            self._enqueue_outgoing_packet(
387                _PacketType.MESSAGE.value.to_bytes(1, _BYTE_ORDER)
388                + message_id.to_bytes(2, _BYTE_ORDER)
389                + len(message).to_bytes(2, _BYTE_ORDER)
390                + message
391            )
392
393        if self.debug_print_io:
394            self.debug_print_call(
395                f'{self._label}: enqueued message of size {len(message)}'
396                f' at {self._tm()}.'
397            )
398
399        # Make an entry so we know this message is out there.
400        assert message_id not in self._in_flight_messages
401        msgobj = self._in_flight_messages[message_id] = _InFlightMessage()
402
403        # Also add its task to our list so we properly cancel it if we die.
404        self._prune_tasks()  # Keep our list from filling with dead tasks.
405        self._tasks.append(msgobj.wait_task)
406
407        # Note: we always want to incorporate a timeout. Individual
408        # messages may hang or error on the other end and this ensures
409        # we won't build up lots of zombie tasks waiting around for
410        # responses that will never arrive.
411        if timeout is None:
412            timeout = self.DEFAULT_MESSAGE_TIMEOUT
413        assert timeout is not None
414
415        bytes_awaitable = msgobj.wait_task
416
417        # Now complete the send asynchronously.
418        return self._send_message(
419            message, timeout, close_on_error, bytes_awaitable, message_id
420        )

Send a message to the peer and return a response.

If timeout is not provided, the default will be used. Raises a CommunicationError if the round trip is not completed for any reason.

By default, the entire endpoint will go down in the case of errors. This allows messages to be treated as 'reliable' with respect to a given endpoint. Pass close_on_error=False to override this for a particular message.

def close(self) -> None:
481    def close(self) -> None:
482        """I said seagulls; mmmm; stop it now."""
483        self._check_env()
484
485        if self._closing:
486            return
487
488        if self.debug_print:
489            self.debug_print_call(f'{self._label}: closing...')
490
491        self._closing = True
492
493        # Kill all of our in-flight tasks.
494        if self.debug_print:
495            self.debug_print_call(f'{self._label}: cancelling tasks...')
496        for task in self._get_live_tasks():
497            task.cancel()
498
499        # Close our writer.
500        assert not self._did_close_writer
501        if self.debug_print:
502            self.debug_print_call(f'{self._label}: closing writer...')
503        self._writer.close()
504        self._did_close_writer = True
505
506        # We don't need this anymore and it is likely to be creating a
507        # dependency loop.
508        del self._handle_raw_message_call

Begin closing the endpoint: cancels in-flight tasks and closes the writer. Does nothing if closing has already begun.

def is_closing(self) -> bool:
510    def is_closing(self) -> bool:
511        """Have we begun the process of closing?"""
512        return self._closing

Have we begun the process of closing?

async def wait_closed(self) -> None:
    async def wait_closed(self) -> None:
        """Wait for the endpoint to finish closing.

        This is called by run() so generally does not need to be
        explicitly called.

        Only the first call does any work; later calls return
        immediately. Raises RuntimeError if close() has not been called
        yet (note that such a failed call still counts as the first).
        """
        # pylint: disable=too-many-branches
        self._check_env()

        # Make sure we only *enter* this call once.
        if self._did_wait_closed:
            return
        self._did_wait_closed = True

        if not self._closing:
            raise RuntimeError('Must be called after close()')

        if not self._did_close_writer:
            logging.warning(
                'RPCEndpoint wait_closed() called but never'
                ' explicitly closed writer.'
            )

        live_tasks = self._get_live_tasks()

        # Don't need our task list anymore; this should
        # break any cyclical refs from tasks referring to us.
        self._tasks = []

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: waiting for tasks to finish: '
                f' ({live_tasks=})...'
            )

        # Wait for all of our in-flight tasks to wrap up.
        # (return_exceptions=True hands errors back to us as results
        # instead of raising them here.)
        results = await asyncio.gather(*live_tasks, return_exceptions=True)
        for result in results:
            # We want to know if any errors happened aside from CancelledError
            # (which are BaseExceptions, not Exception).
            if isinstance(result, Exception):
                logging.warning(
                    'Got unexpected error cleaning up %s task: %s',
                    self._label,
                    result,
                )

        # Sanity check; everything we just gathered should be finished.
        if not all(task.done() for task in live_tasks):
            logging.warning(
                'RPCEndpoint %d: not all live tasks marked done after gather.',
                id(self),
            )

        if self.debug_print:
            self.debug_print_call(
                f'{self._label}: tasks finished; waiting for writer close...'
            )

        # Now wait for our writer to finish going down.
        # When we close our writer it generally triggers errors
        # in our current blocked read/writes. However that same
        # error is also sometimes returned from _writer.wait_closed().
        # See connection_lost() in asyncio/streams.py to see why.
        # So let's silently ignore it when that happens.
        assert self._writer.is_closing()
        try:
            # It seems that as of Python 3.9.x it is possible for this to hang
            # indefinitely. See https://github.com/python/cpython/issues/83939
            # It sounds like this should be fixed in 3.11 but for now just
            # forcing the issue with a timeout here.
            await asyncio.wait_for(
                self._writer.wait_closed(),
                # timeout=60.0 * 6.0,
                timeout=30.0,
            )
        except asyncio.TimeoutError:
            # Giving up on the wait is not treated as an error; we just
            # make a note of it and carry on.
            logging.info(
                'Timeout on _writer.wait_closed() for %s rpc (transport=%s).',
                self._label,
                ssl_stream_writer_underlying_transport_info(self._writer),
            )
            if self.debug_print:
                self.debug_print_call(
                    f'{self._label}: got timeout in _writer.wait_closed();'
                    ' This should be fixed in future Python versions.'
                )
        except Exception as exc:
            # Expected connection errors are ignored (see above);
            # anything else gets logged with a traceback.
            if not self._is_expected_connection_error(exc):
                logging.exception('Error closing _writer for %s.', self._label)
            else:
                if self.debug_print:
                    self.debug_print_call(
                        f'{self._label}: silently ignoring error in'
                        f' _writer.wait_closed(): {exc}.'
                    )
        except asyncio.CancelledError:
            # Being cancelled here is unexpected; warn and propagate.
            logging.warning(
                'RPCEndpoint.wait_closed()'
                ' got asyncio.CancelledError; not expected.'
            )
            raise
        assert not self._did_wait_closed_writer
        self._did_wait_closed_writer = True

I said seagulls; mmmm; stop it now.

Wait for the endpoint to finish closing. This is called by run() so generally does not need to be explicitly called.