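"""Hub-to-hub communication for capport.

Peers connect over anonymous TLS (no certificates) and authenticate each
other with an HMAC over the TLS channel binding, keyed with the shared
comm secret from the config; all messages are length-prefixed protobufs.
"""
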
from __future__ import annotations

import hashlib
import hmac
import logging
import random
import socket
import ssl
import struct
import typing
import uuid

import trio

import capport.comm.message
import capport.database

if typing.TYPE_CHECKING:
    from ..config import Config


_logger = logging.getLogger(__name__)


class HubConnectionReadError(ConnectionError):
    pass


class HubConnectionClosedError(ConnectionError):
    pass


class LoopbackConnectionError(Exception):
    pass


class Channel:
    def __init__(self, hub: Hub, transport_stream, server_side: bool):
        self._hub = hub
        self._serverside = server_side
        self._ssl = trio.SSLStream(transport_stream, self._hub._anon_context, server_side=server_side)
        _logger.debug(f"{self}: created (server_side={server_side})")

    def __repr__(self) -> str:
        return f"Channel[0x{id(self):x}]"

    async def do_handshake(self) -> capport.comm.message.Hello:
        try:
            await self._ssl.do_handshake()
            ssl_binding = self._ssl.get_channel_binding()
            if not ssl_binding:
                # binding mustn't be None after successful handshake
                raise ConnectionError("Missing SSL channel binding")
        except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
            raise ConnectionError(e) from None
        msg = self._hub._make_hello(ssl_binding, server_side=self._serverside).to_message()
        await self.send_msg(msg)
        peer_hello = (await self.recv_msg()).to_variant()
        if not isinstance(peer_hello, capport.comm.message.Hello):
            raise HubConnectionReadError("Expected Hello as first message")
        expected_auth = self._hub._calc_authentication(ssl_binding, server_side=not self._serverside)
        auth_succ = peer_hello.authentication == expected_auth
        await self.send_msg(capport.comm.message.AuthenticationResult(success=auth_succ).to_message())
        peer_auth = (await self.recv_msg()).to_variant()
        if not isinstance(peer_auth, capport.comm.message.AuthenticationResult):
            raise HubConnectionReadError("Expected AuthenticationResult as second message")
        if not auth_succ or not peer_auth.success:
            raise HubConnectionReadError("Authentication failed")
        return peer_hello

    async def _read(self, num: int) -> bytes:
        assert num > 0
        buf = b""
        # _logger.debug(f"{self}:_read({num})")
        while num > 0:
            try:
                part = await self._ssl.receive_some(num)
            except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
                raise ConnectionError(e) from None
            # _logger.debug(f"{self}:_read({num}) got part {part!r}")
            if len(part) == 0:
                if len(buf) == 0:
                    raise HubConnectionClosedError()
                raise HubConnectionReadError("Unexpected end of TLS stream")
            buf += part
            num -= len(part)
            if num < 0:
                raise HubConnectionReadError("TLS receive_some returned too much")
        return buf

    async def _recv_raw_msg(self) -> bytes:
        len_bytes = await self._read(4)
        (chunk_size,) = struct.unpack("!I", len_bytes)
        if chunk_size == 0:
            # _read() requires num > 0; an empty chunk carries no payload
            return b""
        return await self._read(chunk_size)

    async def recv_msg(self) -> capport.comm.message.Message:
        try:
            chunk = await self._recv_raw_msg()
        except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
            raise ConnectionError(e) from None
        msg = capport.comm.message.Message()
        msg.ParseFromString(chunk)
        return msg

    async def _send_raw(self, chunk: bytes) -> None:
        try:
            await self._ssl.send_all(chunk)
        except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
            raise ConnectionError(e) from None

    async def send_msg(self, msg: capport.comm.message.Message):
        chunk = msg.SerializeToString(deterministic=True)
        chunk_size = len(chunk)
        len_bytes = struct.pack("!I", chunk_size)
        chunk = len_bytes + chunk
        await self._send_raw(chunk)

    async def aclose(self):
        try:
            await self._ssl.aclose()
        except (ssl.SSLSyscallError, trio.BrokenResourceError) as e:
            raise ConnectionError(e) from None


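# Wire framing used by Channel: a 4-byte big-endian length prefix followed by
# the serialized protobuf. An equivalent standalone sketch (illustration only,
# not used by this module):
#
#     def frame(payload: bytes) -> bytes:
#         return struct.pack("!I", len(payload)) + payload
#
#     def unframe(buf: bytes) -> tuple[bytes, bytes]:
#         """Split one frame off the front of buf; raise if incomplete."""
#         (size,) = struct.unpack("!I", buf[:4])
#         if len(buf) < 4 + size:
#             raise ValueError("incomplete frame")
#         return buf[4:4 + size], buf[4 + size:]

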
class Connection:
    PING_INTERVAL = 10
    RECEIVE_TIMEOUT = 15
    SEND_TIMEOUT = 5

    def __init__(self, hub: Hub, channel: Channel, peer: capport.comm.message.Hello):
        self._channel = channel
        self._hub = hub
        tx: trio.MemorySendChannel
        rx: trio.MemoryReceiveChannel
        (tx, rx) = trio.open_memory_channel(64)
        self._pending_tx = tx
        self._pending_rx = rx
        self.peer: capport.comm.message.Hello = peer
        self.peer_id: uuid.UUID = uuid.UUID(bytes=peer.instance_id)
        self.closed = trio.Event()  # set by Hub._lost_peer
        _logger.debug(f"{self._channel}: authenticated -> {self.peer_id}")

    async def _sender(self, cancel_scope: trio.CancelScope) -> None:
        try:
            msg: capport.comm.message.Message | None
            while True:
                msg = None
                # make sure we send something every PING_INTERVAL
                with trio.move_on_after(self.PING_INTERVAL):
                    msg = await self._pending_rx.receive()
                # if send blocks too long we're in trouble
                with trio.fail_after(self.SEND_TIMEOUT):
                    if msg:
                        await self._channel.send_msg(msg)
                    else:
                        await self._channel.send_msg(capport.comm.message.Ping(payload=b"ping").to_message())
        except trio.TooSlowError:
            _logger.warning(f"{self._channel}: send timed out")
        except ConnectionError as e:
            _logger.warning(f"{self._channel}: failed sending: {e!r}")
        except Exception:
            _logger.exception(f"{self._channel}: failed sending")
        finally:
            cancel_scope.cancel()

    async def _receive(self, cancel_scope: trio.CancelScope) -> None:
        try:
            while True:
                try:
                    with trio.fail_after(self.RECEIVE_TIMEOUT):
                        msg = await self._channel.recv_msg()
                except (HubConnectionClosedError, ConnectionResetError):
                    return
                except trio.TooSlowError:
                    _logger.warning(f"{self._channel}: receive timed out")
                    return
                await self._hub._received_msg(self.peer_id, msg)
        except ConnectionError as e:
            _logger.warning(f"{self._channel}: failed receiving: {e!r}")
        except Exception:
            _logger.exception(f"{self._channel}: failed receiving")
        finally:
            cancel_scope.cancel()

    async def _inner_run(self) -> None:
        if self.peer_id == self._hub._instance_id:
            # connected to ourself, don't need that
            raise LoopbackConnectionError()
        async with trio.open_nursery() as nursery:
            nursery.start_soon(self._sender, nursery.cancel_scope)
            # be nice and wait for new_peer before receiving messages
            # (won't work on failover to a second connection)
            await nursery.start(self._hub._new_peer, self.peer_id, self)
            nursery.start_soon(self._receive, nursery.cancel_scope)

    async def send_msg(self, *msgs: capport.comm.message.Message):
        try:
            for msg in msgs:
                await self._pending_tx.send(msg)
        except trio.ClosedResourceError:
            pass

    async def _run(self) -> None:
        try:
            await self._inner_run()
        finally:
            _logger.debug(f"{self._channel}: finished message handling")
            # basic (non-async) cleanup
            self._hub._lost_peer(self.peer_id, self)
            self._pending_tx.close()
            self._pending_rx.close()
            # allow 3 seconds for proper cleanup
            with trio.CancelScope(shield=True, deadline=trio.current_time() + 3):
                try:
                    await self._channel.aclose()
                except OSError:
                    pass

    @staticmethod
    async def run(hub: Hub, transport_stream, server_side: bool) -> None:
        channel = Channel(hub, transport_stream, server_side)
        try:
            with trio.fail_after(5):
                peer = await channel.do_handshake()
        except trio.TooSlowError:
            _logger.warning("Handshake timed out")
            return
        conn = Connection(hub, channel, peer)
        await conn._run()


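# Liveness: _sender transmits at least one message (a Ping) per PING_INTERVAL
# while _receive gives up after RECEIVE_TIMEOUT of silence, so an idle but
# healthy link keeps itself alive and a dead one is torn down within
# RECEIVE_TIMEOUT seconds.

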
class ControllerConn:
    def __init__(self, hub: Hub, hostname: str):
        self._hub = hub
        self.hostname = hostname
        self.loopback = False

    async def _connect(self):
        _logger.info(f"Connecting to controller at {self.hostname}")
        with trio.fail_after(5):
            try:
                stream = await trio.open_tcp_stream(self.hostname, self._hub._config.controller_port)
            except OSError as e:
                _logger.warning(f"Failed to connect to controller at {self.hostname}: {e}")
                return
        try:
            await Connection.run(self._hub, stream, server_side=False)
        finally:
            _logger.info(f"Connection to {self.hostname} closed")

    async def run(self):
        while True:
            try:
                await self._connect()
            except LoopbackConnectionError:
                _logger.debug(f"Connection to {self.hostname} reached ourself")
                self.loopback = True
                return
            except trio.TooSlowError:
                pass
            # try again later
            retry_splay = random.random() * 5
            await trio.sleep(10 + retry_splay)


class HubApplication:
    async def new_peer(self, *, peer_id: uuid.UUID) -> None:
        _logger.info(f"New peer {peer_id}")

    def lost_peer(self, *, peer_id: uuid.UUID) -> None:
        _logger.warning(f"Lost peer {peer_id}")

    async def received_unknown_message(self, *, from_peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
        _logger.warning(f"Received from {from_peer_id}: {str(msg).strip()}")

    async def received_mac_state(self, *, from_peer_id: uuid.UUID, states: capport.comm.message.MacStates) -> None:
        if _logger.isEnabledFor(logging.DEBUG):
            _logger.debug(f"Received states from {from_peer_id}: {str(states).strip()}")

    async def mac_states_changed(
        self,
        *,
        from_peer_id: uuid.UUID,
        pending_updates: capport.database.PendingUpdates,
    ) -> None:
        if _logger.isEnabledFor(logging.DEBUG):
            _logger.debug(f"Received new states from {from_peer_id}: {pending_updates}")


class Hub:
    def __init__(self, config: Config, app: HubApplication, *, is_controller: bool) -> None:
        self._config = config
        self._instance_id = uuid.uuid4()
        self._hostname = socket.getfqdn()
        self._app = app
        self._is_controller = is_controller
        state_filename: str
        if is_controller:
            state_filename = config.database_file
        else:
            state_filename = ""
        self.database = capport.database.Database(state_filename=state_filename)
        self._anon_context = ssl.SSLContext()
        # python ssl doesn't support setting tls1.3 ciphers yet, so make sure we stay
        # on 1.2 for now to enable anon
        self._anon_context.minimum_version = ssl.TLSVersion.TLSv1_2
        self._anon_context.maximum_version = ssl.TLSVersion.TLSv1_2
        # -> AECDH-AES256-SHA
        # sadly SECLEVEL=0 seems to be needed for aNULL, but we might accept bad curves too that way?
        self._anon_context.set_ciphers("HIGH+aNULL+AES256+kECDHE:@SECLEVEL=0")
        self._controllers: dict[str, ControllerConn] = {}
        self._established: dict[uuid.UUID, Connection] = {}

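    # Design note: the anonymous (aNULL) cipher suite gives an encrypted
    # channel without any certificate infrastructure; peer authentication
    # comes instead from the Hello/AuthenticationResult handshake, whose HMAC
    # covers the TLS channel binding and is therefore bound to this session.
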
    async def _accept(self, stream):
        remotename = stream.socket.getpeername()
        if isinstance(remotename, tuple) and len(remotename) == 2:
            remote = f"[{remotename[0]}]:{remotename[1]}"
        else:
            remote = str(remotename)
        try:
            await Connection.run(self, stream, server_side=True)
        except LoopbackConnectionError:
            pass
        except trio.TooSlowError:
            pass
        finally:
            _logger.debug(f"Connection from {remote} closed")

    async def _listen(self, task_status=trio.TASK_STATUS_IGNORED):
        await trio.serve_tcp(self._accept, self._config.controller_port, task_status=task_status)

    async def run(self, *, task_status=trio.TASK_STATUS_IGNORED):
        async with trio.open_nursery() as nursery:
            await nursery.start(self.database.run)
            if self._is_controller:
                await nursery.start(self._listen)

            for name in self._config.controllers:
                conn = ControllerConn(self, name)
                self._controllers[name] = conn

            task_status.started()

            for conn in self._controllers.values():
                nursery.start_soon(conn.run)

            await trio.sleep_forever()

    def _calc_authentication(self, ssl_binding: bytes, server_side: bool) -> bytes:
        m = hmac.new(self._config.comm_secret.encode("utf8"), digestmod=hashlib.sha256)
        if server_side:
            m.update(b"server$")
        else:
            m.update(b"client$")
        m.update(ssl_binding)
        return m.digest()

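    # Authentication sketch: each side proves knowledge of the shared secret
    # by HMACing the TLS channel binding, with a "server$"/"client$" prefix
    # for domain separation so the two roles never produce the same proof.
    # Illustration only (hypothetical values):
    #
    #     secret = b"example shared secret"
    #     binding = b"example channel binding"
    #     server_proof = hmac.new(secret, b"server$" + binding, hashlib.sha256).digest()
    #     client_proof = hmac.new(secret, b"client$" + binding, hashlib.sha256).digest()
    #     assert server_proof != client_proof
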
    def _make_hello(self, ssl_binding: bytes, server_side: bool) -> capport.comm.message.Hello:
        return capport.comm.message.Hello(
            instance_id=self._instance_id.bytes,
            hostname=self._hostname,
            is_controller=self._is_controller,
            authentication=self._calc_authentication(ssl_binding, server_side),
        )

    async def _sync_new_connection(self, peer_id: uuid.UUID, conn: Connection) -> None:
        # send database (and all changes) to peers
        await self.send(*self.database.serialize(), to=peer_id)

    async def _new_peer(self, peer_id: uuid.UUID, conn: Connection, task_status=trio.TASK_STATUS_IGNORED) -> None:
        have = self._established.get(peer_id, None)
        if not have:
            # peer unknown, "normal start"
            # no "await" between get above and set here!!!
            self._established[peer_id] = conn
            # first wait for app to handle new peer
            await self._app.new_peer(peer_id=peer_id)
            task_status.started()
            await self._sync_new_connection(peer_id, conn)
            return

        # peer already known - immediately allow receiving messages, then sync connection
        task_status.started()
        await self._sync_new_connection(peer_id, conn)
        # now try to register connection for outgoing messages
        while True:
            # recheck whether peer is currently known (due to awaits since last get)
            have = self._established.get(peer_id, None)
            if have:
                # already got a connection, nothing to do as long as it lives
                await have.closed.wait()
            else:
                # make `conn` new outgoing connection for peer
                # no "await" between get above and set here!!!
                self._established[peer_id] = conn
                await self._app.new_peer(peer_id=peer_id)
                return

    def _lost_peer(self, peer_id: uuid.UUID, conn: Connection):
        have = self._established.get(peer_id, None)
        lost = False
        if have is conn:
            lost = True
            self._established.pop(peer_id)
        conn.closed.set()
        # only notify if this was the active connection
        if lost:
            # even when we failover to another connection we still need to resync,
            # as we don't know which messages might have got lost
            # -> always trigger lost_peer
            self._app.lost_peer(peer_id=peer_id)

    async def _received_msg(self, peer_id: uuid.UUID, msg: capport.comm.message.Message) -> None:
        variant = msg.to_variant()
        if isinstance(variant, capport.comm.message.Hello):
            pass
        elif isinstance(variant, capport.comm.message.AuthenticationResult):
            pass
        elif isinstance(variant, capport.comm.message.Ping):
            pass
        elif isinstance(variant, capport.comm.message.MacStates):
            await self._app.received_mac_state(from_peer_id=peer_id, states=variant)
            async with self.database.make_changes() as pu:
                for state in variant.states:
                    pu.received_mac_state(state)
            if pu:
                # re-broadcast all received updates to all peers
                await self.broadcast(*pu.serialized, exclude=peer_id)
                await self._app.mac_states_changed(from_peer_id=peer_id, pending_updates=pu)
        else:
            await self._app.received_unknown_message(from_peer_id=peer_id, msg=msg)

    def peer_is_controller(self, peer_id: uuid.UUID) -> bool:
        conn = self._established.get(peer_id)
        if conn:
            return conn.peer.is_controller
        return False

    async def send(self, *msgs: capport.comm.message.Message, to: uuid.UUID):
        conn = self._established.get(to)
        if conn:
            await conn.send_msg(*msgs)

    async def broadcast(self, *msgs: capport.comm.message.Message, exclude: uuid.UUID | None = None):
        async with trio.open_nursery() as nursery:
            for peer_id, conn in self._established.items():
                if peer_id == exclude:
                    continue
                nursery.start_soon(conn.send_msg, *msgs)
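

# Usage sketch (hypothetical wiring; the real entry point lives elsewhere):
#
#     async def main(config: Config) -> None:
#         hub = Hub(config, HubApplication(), is_controller=True)
#         async with trio.open_nursery() as nursery:
#             await nursery.start(hub.run)  # returns once the hub is ready
#             ...  # application logic; use hub.send()/hub.broadcast()
#
#     # trio.run(main, config)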