from __future__ import annotations import hashlib import hmac import logging import random import socket import ssl import struct import typing import uuid import trio import capport.comm.message import capport.database if typing.TYPE_CHECKING: from ..config import Config _logger = logging.getLogger(__name__) class HubConnectionReadError(ConnectionError): pass class HubConnectionClosedError(ConnectionError): pass class LoopbackConnectionError(Exception): pass class Channel: def __init__(self, hub: Hub, transport_stream, server_side: bool): self._hub = hub self._serverside = server_side self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side) _logger.debug(f"{self}: created (server_side={server_side})") def __repr__(self) -> str: return f"Channel[0x{id(self):x}]" async def do_handshake(self) -> capport.comm.message.Hello: try: await self._ssl.do_handshake() ssl_binding = self._ssl.get_channel_binding() if not ssl_binding: # binding mustn't be None after successful handshake raise ConnectionError("Missing SSL channel binding") except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: raise ConnectionError(e) from None msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message() await self.send_msg(msg) peer_hello = (await self.recv_msg()).to_variant() if not isinstance(peer_hello, capport.comm.message.Hello): raise HubConnectionReadError("Expected Hello as first message") expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside) auth_succ = peer_hello.authentication == expected_auth await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message()) peer_auth = (await self.recv_msg()).to_variant() if not isinstance(peer_auth, capport.comm.message.AuthenticationResult): raise HubConnectionReadError("Expected AuthenticationResult as second message") if not auth_succ or not peer_auth.success: raise HubConnectionReadError("Authentication failed") return peer_hello async def _read(self, num: int) -> bytes: assert num > 0 buf = b"" # _logger.debug(f"{self}:_read({num})") while num > 0: try: part = await self._ssl.receive_some(num) except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: raise ConnectionError(e) from None # _logger.debug(f"{self}:_read({num}) got part {part!r}") if len(part) == 0: if len(buf) == 0: raise HubConnectionClosedError() raise HubConnectionReadError("Unexpected end of TLS stream") buf += part num -= len(part) if num < 0: raise HubConnectionReadError("TLS receive_some returned too much") return buf async def _recv_raw_msg(self) -> bytes: len_bytes = await self._read(4) (chunk_size,) = struct.unpack("!I", len_bytes) chunk = await self._read(chunk_size) if chunk is None: raise HubConnectionReadError("Unexpected end of TLS stream after chunk length") return chunk async def recv_msg(self) -> capport.comm.message.Message: try: chunk = await self._recv_raw_msg() except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: raise ConnectionError(e) from None msg = capport.comm.message.Message() msg.ParseFromString(chunk) return msg async def _send_raw(self, chunk: bytes) -> None: try: await self._ssl.send_all(chunk) except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: raise ConnectionError(e) from None async def send_msg(self, msg: capport.comm.message.Message): chunk = msg.SerializeToString(deterministic=True) chunk_size = len(chunk) len_bytes = struct.pack("!I", chunk_size) chunk = len_bytes + chunk await self._send_raw(chunk) async def aclose(self): try: await self._ssl.aclose() except (ssl.SSLSyscallError, trio.BrokenResourceError) as e: raise ConnectionError(e) from None class Connection: PING_INTERVAL = 10 RECEIVE_TIMEOUT = 15 SEND_TIMEOUT = 5 def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello): self._channel = channel self._hub = hub tx: trio.MemorySendChannel rx: trio.MemoryReceiveChannel (tx, rx) = trio.open_memory_channel(64) self._pending_tx = tx self._pending_rx = rx self.peer: capport.comm.message.Hello = peer self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id) self.closed = trio.Event() # set by Hub._lost_peer _logger.debug(f"{self._channel}: authenticated -> {self.peer_id}") async def _sender(self, cancel_scope: trio.CancelScope) -> None: try: msg: capport.comm.message.Message | None while True: msg = None # make sure we send something every PING_INTERVAL with trio.move_on_after(self.PING_INTERVAL): msg = await self._pending_rx.receive() # if send blocks too long we're in trouble with trio.fail_after(self.SEND_TIMEOUT): if msg: await self._channel.send_msg(msg) else: await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message()) except trio.TooSlowError: _logger.warning(f"{self._channel}: send timed out") except ConnectionError as e: _logger.warning(f"{self._channel}: failed sending: {e!r}") except Exception: _logger.exception(f"{self._channel}: failed sending") finally: cancel_scope.cancel() async def _receive(self, cancel_scope: trio.CancelScope) -> None: try: while True: try: with trio.fail_after(self.RECEIVE_TIMEOUT): msg = await self._channel.recv_msg() except (HubConnectionClosedError, ConnectionResetError): return except trio.TooSlowError: _logger.warning(f"{self._channel}: receive timed out") return await self._hub._received_msg(self.peer_id, msg) except ConnectionError as e: _logger.warning(f"{self._channel}: failed receiving: {e!r}") except Exception: _logger.exception(f"{self._channel}: failed receiving") finally: cancel_scope.cancel() async def _inner_run(self) -> None: if self.peer_id == self._hub._instance_id: # connected to ourself, don't need that raise LoopbackConnectionError() async with trio.open_nursery() as nursery: nursery.start_soon(self._sender, nursery.cancel_scope) # be nice and wait for new_peer beforce receiving messages # (won't work on failover to a second connection) await nursery.start(self._hub._new_peer, self.peer_id, self) nursery.start_soon(self._receive, nursery.cancel_scope) async def send_msg(self, *msgs: capport.comm.message.Message): try: for msg in msgs: await self._pending_tx.send(msg) except trio.ClosedResourceError: pass async def _run(self) -> None: try: await self._inner_run() finally: _logger.debug(f"{self._channel}: finished message handling") # basic (non-async) cleanup self._hub._lost_peer(self.peer_id, self) self._pending_tx.close() self._pending_rx.close() # allow 3 seconds for proper cleanup with trio.CancelScope(shield=True, deadline=trio.current_time() + 3): try: await self._channel.aclose() except OSError: pass @staticmethod async def run(hub: Hub, transport_stream, server_side: bool) -> None: channel = Channel(hub, transport_stream, server_side) try: with trio.fail_after(5): peer = await channel.do_handshake() except trio.TooSlowError: _logger.warning("Handshake timed out") return conn = Connection(hub, channel, peer) await conn._run() class ControllerConn: def __init__(self, hub: Hub, hostname: str): self._hub = hub self.hostname = hostname self.loopback = False async def _connect(self): _logger.info(f"Connecting to controller at {self.hostname}") with trio.fail_after(5): try: stream = await trio.open_tcp_stream(self.hostname, self._hub._config.controller_port) except OSError as e: _logger.warning(f"Failed to connect to controller at {self.hostname}: {e}") return try: await Connection.run(self._hub, stream, server_side=False) finally: _logger.info(f"Connection to {self.hostname} closed") async def run(self): while True: try: await self._connect() except LoopbackConnectionError: _logger.debug(f"Connection to {self.hostname} reached ourself") self.loopback = True return except trio.TooSlowError: pass # try again later retry_splay = random.random() * 5 await trio.sleep(10 + retry_splay) class HubApplication: async def new_peer(self, *, peer_id: uuid.UUID) -> None: _logger.info(f"New peer {peer_id}") def lost_peer(self, *, peer_id: uuid.UUID) -> None: _logger.warning(f"Lost peer {peer_id}") async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None: _logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}") async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None: if _logger.isEnabledFor(logging.DEBUG): _logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}") async def mac_states_changed( self, *, from_peer_id: uuid.UUID, pending_updates: capport.database.PendingUpdates, ) -> None: if _logger.isEnabledFor(logging.DEBUG): _logger.debug(f"Received new states from {from_peer_id}: {pending_updates}") class Hub: def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None: self._config = config self._instance_id = uuid.uuid4() self._hostname = socket.getfqdn() self._app = app self._is_controller = is_controller state_filename: str if is_controller: state_filename = config.database_file else: state_filename = "" self.database = capport.database.Database(state_filename=state_filename) self._anon_context = ssl.SSLContext() # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay on 1.2 for now to enable anon self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2 self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2 # -> AECDH-AES256-SHA # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way? self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0") self._controllers: dict[str, ControllerConn] = {} self._established: dict[uuid.UUID, Connection] = {} async def _accept(self, stream): remotename = stream.socket.getpeername() if isinstance(remotename, tuple) and len(remotename) == 2: remote = f"[{remotename[0]}]:{remotename[1]}" else: remote = str(remotename) try: await Connection.run(self, stream, server_side=True) except LoopbackConnectionError: pass except trio.TooSlowError: pass finally: _logger.debug(f"Connection from {remote} closed") async def _listen(self, task_status=trio.TASK_STATUS_IGNORED): await trio.serve_tcp(self._accept, self._config.controller_port, task_status=task_status) async def run(self, *, task_status=trio.TASK_STATUS_IGNORED): async with trio.open_nursery() as nursery: await nursery.start(self.database.run) if self._is_controller: await nursery.start(self._listen) for name in self._config.controllers: conn = ControllerConn(self, name) self._controllers[name] = conn task_status.started() for conn in self._controllers.values(): nursery.start_soon(conn.run) await trio.sleep_forever() def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes: m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256) if server_side: m.update(b"server$") else: m.update(b"client$") m.update(ssl_binding) return m.digest() def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello: return capport.comm.message.Hello( instance_id=self._instance_id.bytes, hostname=self._hostname, is_controller=self._is_controller, authentication=self._calc_authentication(ssl_binding, server_side), ) async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None: # send database (and all changes) to peers await self.send(*self.database.serialize(), to=peer_id) async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None: have = self._established.get(peer_id, None) if not have: # peer unknown, "normal start" # no "await" between get above and set here!!! self._established[peer_id] = conn # first wait for app to handle new peer await self._app.new_peer(peer_id=peer_id) task_status.started() await self._sync_new_connection(peer_id, conn) return # peer already known - immediately allow receiving messages, then sync connection task_status.started() await self._sync_new_connection(peer_id, conn) # now try to register connection for outgoing messages while True: # recheck whether peer is currently known (due to awaits since last get) have = self._established.get(peer_id, None) if have: # already got a connection, nothing to do as long as it lives await have.closed.wait() else: # make `conn` new outgoing connection for peer # no "await" between get above and set here!!! self._established[peer_id] = conn await self._app.new_peer(peer_id=peer_id) return def _lost_peer(self, peer_id: uuid.UUID, conn: Connection): have = self._established.get(peer_id, None) lost = False if have is conn: lost = True self._established.pop(peer_id) conn.closed.set() # only notify if this was the active connection if lost: # even when we failover to another connection we still need to resync # as we don't know which messages might have got lost # -> always trigger lost_peer self._app.lost_peer(peer_id=peer_id) async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None: variant = msg.to_variant() if isinstance(variant, capport.comm.message.Hello): pass elif isinstance(variant, capport.comm.message.AuthenticationResult): pass elif isinstance(variant, capport.comm.message.Ping): pass elif isinstance(variant, capport.comm.message.MacStates): await self._app.received_mac_state(from_peer_id=peer_id, states=variant) async with self.database.make_changes() as pu: for state in variant.states: pu.received_mac_state(state) if pu: # re-broadcast all received updates to all peers await self.broadcast(*pu.serialized, exclude=peer_id) await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu) else: await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg) def peer_is_controller(self, peer_id: uuid.UUID) -> bool: conn = self._established.get(peer_id) if conn: return conn.peer.is_controller return False async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID): conn = self._established.get(to) if conn: await conn.send_msg(*msgs) async def broadcast(self, *msgs: capport.comm.message.Message, exclude: uuid.UUID | None = None): async with trio.open_nursery() as nursery: for peer_id, conn in self._established.items(): if peer_id == exclude: continue nursery.start_soon(conn.send_msg, *msgs)